From 2923244ded1f5987724a913cd0d1e3ff1d95253c Mon Sep 17 00:00:00 2001 From: hinriksnaer Date: Mon, 24 Nov 2025 16:17:56 +0000 Subject: [PATCH 001/338] [dynamo] Remove redundant _nonvar_fields assignments in UserDefinedObjectVariable subclasses (#167801) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed redundant `_nonvar_fields` assignments from 5 UserDefinedObjectVariable subclasses. These explicit re-assignments are unnecessary because Python's class attribute inheritance automatically provides access to parent class attributes. **Classes cleaned up:** - UserDefinedDictVariable - UserDefinedSetVariable - UserDefinedListVariable - UserDefinedTupleVariable - MutableMappingVariable All 5 classes inherit from `UserDefinedObjectVariable`, which defines `_nonvar_fields`. The pattern `_nonvar_fields = UserDefinedObjectVariable._nonvar_fields` is pure redundancy - the child classes will automatically inherit this attribute from the parent. ## Changes - **Lines removed:** 10 (5 redundant assignments + 5 blank lines) - **File modified:** `torch/_dynamo/variables/user_defined.py` ## Impact - **Code reduction:** -10 lines - **Maintainability:** ↑ (less redundancy) - **Risk:** Zero (identical behavior via inheritance) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167801 Approved by: https://github.com/guilhermeleobas --- torch/_dynamo/variables/user_defined.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index fb676295535df..e87af5b87a75a 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -2018,8 +2018,6 @@ class UserDefinedDictVariable(UserDefinedObjectVariable): UserDefinedObjectVariable. """ - _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__(self, value, dict_vt=None, **kwargs): super().__init__(value, **kwargs) self._dict_vt = dict_vt @@ -2092,8 +2090,6 @@ class UserDefinedSetVariable(UserDefinedObjectVariable): UserDefinedObjectVariable. """ - _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__(self, value, set_vt=None, **kwargs): super().__init__(value, **kwargs) self._set_vt = set_vt @@ -2167,8 +2163,6 @@ class UserDefinedListVariable(UserDefinedObjectVariable): UserDefinedObjectVariable. """ - _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__(self, value, list_vt=None, **kwargs): super().__init__(value, **kwargs) self._list_vt = list_vt @@ -2210,8 +2204,6 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable): UserDefinedObjectVariable. """ - _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): super().__init__(value, init_args=init_args, **kwargs) self._tuple_vt = tuple_vt @@ -2252,8 +2244,6 @@ def unpack_var_sequence(self, tx): class MutableMappingVariable(UserDefinedObjectVariable): - _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__(self, value, **kwargs): super().__init__(value, **kwargs) self.generic_dict_vt = variables.ConstDictVariable({}) From dfdf024e236d404f08fdeb5591e1a18a3c45b8fe Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 24 Nov 2025 17:35:47 +0000 Subject: [PATCH 002/338] Revert "[DTensor] compute shape and offset for arbitrary _StridedShard (#168146)" This reverts commit 08bfadf971742edf63d8e3eb16f151de0c9dc41b. Reverted https://github.com/pytorch/pytorch/pull/168146 on behalf of https://github.com/yangw-dev due to failed internal tests due to AttributeError: 'LocalIntNode' object has no attribute 'int_', please fix it and re-merge again ([comment](https://github.com/pytorch/pytorch/pull/168146#issuecomment-3571957707)) --- test/distributed/tensor/test_utils.py | 471 +++++++----------- torch/distributed/tensor/_api.py | 2 +- .../distributed/tensor/_ops/_common_rules.py | 5 +- torch/distributed/tensor/_ops/_matrix_ops.py | 2 +- torch/distributed/tensor/_sharding_prop.py | 2 +- torch/distributed/tensor/_utils.py | 227 +++++---- torch/distributed/tensor/placement_types.py | 9 +- 7 files changed, 318 insertions(+), 400 deletions(-) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 5f3225d174cb2..11b70c8554e52 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -16,6 +16,7 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._utils import ( _compute_local_shape_and_global_offset, + _explicit_order_placements, compute_global_tensor_info, compute_global_tensor_shape, compute_local_shape_and_global_offset, @@ -45,6 +46,85 @@ class LocalTest(TestCase): + def test_explicit_order_placements(self): + # mesh_shape: ShapeType, placements: Sequence[Placement] + test_cases = [ + { + "mesh_shape": [2, 4], + "placements": [Replicate(), Replicate()], + "ordered": [(0, Replicate()), (1, Replicate())], + }, + { + "mesh_shape": [3, 2], + "placements": [Shard(0), Replicate()], + "ordered": [(0, Shard(0)), (1, Replicate())], + }, + { + "mesh_shape": [2, 4], + "placements": [_StridedShard(0, split_factor=4), Shard(0)], + "ordered": [(1, Shard(0)), (0, Shard(0))], + }, + { + "mesh_shape": [2, 3, 4], + "placements": [Shard(0), _StridedShard(0, split_factor=4), Shard(0)], + "ordered": [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))], + }, + { + "mesh_shape": [2, 3, 4], + "placements": [ + _StridedShard(0, split_factor=12), + _StridedShard(0, split_factor=4), + Shard(0), + ], + "ordered": [(2, Shard(0)), (1, Shard(0)), (0, Shard(0))], + }, + ] + for test_case in test_cases: + actual = _explicit_order_placements( + test_case["mesh_shape"], test_case["placements"] + ) + expected = test_case["ordered"] + + self.assertEqual( + actual, + expected, + f"mesh_shape={test_case['mesh_shape']} placements={test_case['placements']}, output: {actual=}, {expected=}", + ) + + error_cases = [ + { + "mesh_shape": [2, 3, 4], + "placements": [Shard(0), _StridedShard(0, split_factor=3), Shard(0)], + "exception_type": RuntimeError, + "exception_text": "Can only convert _StridedShard to ordered Shard if split_factor", + }, + { + "mesh_shape": [2, 3, 4], + "placements": [ + _StridedShard(0, split_factor=3), + Shard(0), + Shard(0), + ], + "exception_type": NotImplementedError, + "exception_text": r"Strided sharding does not allow Shard\(\) to appear after the strided part has ended", + }, + { + "mesh_shape": [2, 3], + "placements": [ + Shard(0), + ], + "exception_type": RuntimeError, + "exception_text": "Expected one placement per mesh dim", + }, + ] + for test_case in error_cases: + with self.assertRaisesRegex( + test_case["exception_type"], test_case["exception_text"] + ): + _explicit_order_placements( + test_case["mesh_shape"], test_case["placements"] + ) + def test_compute_local_shape_and_global_offset_uneven(self): # This case is not only 'uneven' bug also has an empty shard # (e.g. most DP ranks have local shape 18,4096, one has 8,4096, one has 0,4096 @@ -71,225 +151,6 @@ def test_compute_local_shape_and_global_offset_uneven(self): self.assertEqual(local_shape, (expected_shard_size, 4096)) self.assertEqual(global_offset, (expected_shard_offset, 0)) - # S, S uneven without empty - global_shape = (18, 2) - DP = 4 - TP = 2 - mesh_shape = (DP, TP) - placements = [Shard(0), Shard(0)] - for my_coordinate in itertools.product(range(DP), range(TP)): - dp_rank, tp_rank = my_coordinate - local_shape, global_offset = _compute_local_shape_and_global_offset( - global_shape, mesh_shape, list(my_coordinate), placements - ) - - dp012_shard_size = 5 - if dp_rank in (0, 1, 2): - tp0_shard_size = 3 - if tp_rank == 0: - expected_shard_offset = dp012_shard_size * dp_rank - expected_shard_size = 3 - else: - assert tp_rank == 1 - expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size - expected_shard_size = 2 - else: - assert dp_rank == 3 - tp0_shard_size = 2 - if tp_rank == 0: - expected_shard_offset = dp012_shard_size * dp_rank - expected_shard_size = 2 - else: - assert tp_rank == 1 - expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size - expected_shard_size = 1 - self.assertEqual(local_shape, (expected_shard_size, 2)) - self.assertEqual(global_offset, (expected_shard_offset, 0)) - - # S, S uneven with empty - global_shape = (13, 2) - DP = 4 - TP = 2 - mesh_shape = (DP, TP) - placements = [Shard(0), Shard(0)] - for my_coordinate in itertools.product(range(DP), range(TP)): - dp_rank, tp_rank = my_coordinate - local_shape, global_offset = _compute_local_shape_and_global_offset( - global_shape, mesh_shape, list(my_coordinate), placements - ) - - dp012_shard_size = 4 - if dp_rank in (0, 1, 2): - tp0_shard_size = 2 - if tp_rank == 0: - expected_shard_offset = dp012_shard_size * dp_rank - expected_shard_size = 2 - else: - assert tp_rank == 1 - expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size - expected_shard_size = 2 - else: - assert dp_rank == 3 - tp0_shard_size = 1 - if tp_rank == 0: - expected_shard_offset = dp012_shard_size * dp_rank - expected_shard_size = 1 - else: - assert tp_rank == 1 - expected_shard_offset = global_shape[0] - expected_shard_size = 0 - self.assertEqual(local_shape, (expected_shard_size, 2)) - self.assertEqual(global_offset, (expected_shard_offset, 0)) - - # SS, Shard - global_shape = (18, 2) - DP = 4 - TP = 2 - mesh_shape = (DP, TP) - placements = [_StridedShard(0, split_factor=TP), Shard(0)] - TP_shard_size = int(global_shape[0] / TP) - for my_coordinate in itertools.product(range(DP), range(TP)): - dp_rank, tp_rank = my_coordinate - local_shape, global_offset = _compute_local_shape_and_global_offset( - global_shape, mesh_shape, list(my_coordinate), placements - ) - expected_shard_size = 3 - expected_shard_offset = ( - tp_rank * TP_shard_size + expected_shard_size * dp_rank - ) - if dp_rank == 3: - expected_shard_size = 0 - expected_shard_offset = 18 - self.assertEqual(local_shape, (expected_shard_size, 2)) - self.assertEqual(global_offset, (expected_shard_offset, 0)) - - # SS, SS - global_shape = (39, 2) - DP = 4 - TP = 2 - mesh_shape = (DP, TP) - placements = [ - _StridedShard(0, split_factor=3), - _StridedShard(0, split_factor=4), - ] - for my_coordinate in itertools.product(range(DP), range(TP)): - dp_rank, tp_rank = my_coordinate - local_shape, global_offset = _compute_local_shape_and_global_offset( - global_shape, mesh_shape, list(my_coordinate), placements - ) - if dp_rank in (0, 1, 2): - tp0_shard_size = 8 - if tp_rank == 0: - expected_shard_offset = 4 * dp_rank - expected_shard_size = tp0_shard_size - else: - assert tp_rank == 1 - expected_shard_offset = 4 * dp_rank + 2 - expected_shard_size = 4 - else: - assert dp_rank == 3 - tp0_shard_size = 3 - if tp_rank == 0: - expected_shard_offset = 4 * dp_rank - expected_shard_size = 3 - else: - assert tp_rank == 1 - expected_shard_offset = global_shape[0] - expected_shard_size = 0 - self.assertEqual(local_shape, (expected_shard_size, 2)) - self.assertEqual(global_offset, (expected_shard_offset, 0)) - - # (Shard, SS) - global_shape = (18, 2) - DP = 4 - TP = 2 - mesh_shape = (DP, TP) - placements = [Shard(0), _StridedShard(0, split_factor=2)] - for my_coordinate in itertools.product(range(DP), range(TP)): - dp_rank, tp_rank = my_coordinate - local_shape, global_offset = _compute_local_shape_and_global_offset( - global_shape, mesh_shape, list(my_coordinate), placements - ) - if dp_rank in (0, 1, 2): - tp0_shard_size = 3 - if tp_rank == 0: - expected_shard_offset = 5 * dp_rank - expected_shard_size = tp0_shard_size - else: - assert tp_rank == 1 - expected_shard_offset = 5 * dp_rank + 2 - expected_shard_size = 2 - else: - assert dp_rank == 3 - if tp_rank == 0: - expected_shard_offset = 5 * dp_rank - expected_shard_size = 2 - else: - assert tp_rank == 1 - expected_shard_offset = 5 * dp_rank + 1 - expected_shard_size = 1 - self.assertEqual(local_shape, (expected_shard_size, 2)) - self.assertEqual(global_offset, (expected_shard_offset, 0)) - - # (Shard, SS, Shard) - global_shape = (39, 2) - mesh0, mesh1, mesh2 = 4, 2, 3 - mesh_shape = (mesh0, mesh1, mesh2) - placements = [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] - for my_coordinate in itertools.product( - range(mesh0), range(mesh1), range(mesh2) - ): - mesh0_rank, mesh1_rank, mesh2_rank = my_coordinate - local_shape, global_offset = _compute_local_shape_and_global_offset( - global_shape, mesh_shape, list(my_coordinate), placements - ) - if mesh0_rank in (0, 1, 2): - if mesh1_rank == 0: - if mesh2_rank == 0: - expected_shard_offset = 10 * mesh0_rank - expected_shard_size = 2 - elif mesh2_rank == 1: - expected_shard_offset = 10 * mesh0_rank + 2 - expected_shard_size = 2 - else: - expected_shard_offset = 10 * mesh0_rank + 6 - expected_shard_size = 2 - else: - assert mesh1_rank == 1 - if mesh2_rank == 0: - expected_shard_offset = 10 * mesh0_rank + 3 - expected_shard_size = 2 - elif mesh2_rank == 1: - expected_shard_offset = 10 * mesh0_rank + 8 - expected_shard_size = 2 - else: - assert mesh2_rank == 2 - expected_shard_size = 0 - expected_shard_offset = global_shape[0] - else: - assert mesh0_rank == 3 - if mesh1_rank == 0: - if mesh2_rank in (0, 1): - expected_shard_offset = 10 * mesh0_rank + 2 * mesh2_rank - expected_shard_size = 2 - else: - assert mesh2_rank == 2 - expected_shard_offset = 10 * mesh0_rank + 6 - expected_shard_size = 1 - else: - assert mesh1_rank == 1 - if mesh2_rank == 0: - expected_shard_offset = 10 * mesh0_rank + 3 - expected_shard_size = 2 - elif mesh2_rank == 1: - expected_shard_offset = 10 * mesh0_rank + 7 - expected_shard_size = 2 - else: - expected_shard_offset = global_shape[0] - expected_shard_size = 0 - self.assertEqual(local_shape, (expected_shard_size, 2)) - self.assertEqual(global_offset, (expected_shard_offset, 0)) - class UtilTest(DTensorTestBase): @property @@ -431,78 +292,6 @@ def test_compute_local_shape_and_global_offset_2D(self): global_tensor[dim0_start:dim0_end, dim1_start:dim1_end], ) - @with_comms - def test_compute_local_shape_and_global_offset_3D(self): - global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) - mesh_size_0 = 2 - mesh_size_1 = 2 - mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), - ) - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - Shard(0), - Shard(0), - ] - local_shape, global_offset = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - mesh0_rank, mesh1_rank, mesh2_rank = global_mesh.get_coordinate() - self.assertEqual(local_shape, [2, 2 * self.world_size]) - self.assertEqual( - global_offset, (4 * mesh0_rank + 8 * mesh1_rank + 2 * mesh2_rank, 0) - ) - - @with_comms - def test_compute_local_shape_and_global_offset_4D(self): - global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) - mesh_size_0 = 1 - mesh_size_1 = 2 - mesh_size_2 = 2 - mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), - ) - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(1), - ] - local_shape, global_offset = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() - self.assertEqual( - local_shape, (2 * mesh_size_1 * mesh_size_3, 2 * mesh_size_0 * mesh_size_2) - ) - self.assertEqual( - global_offset, - (8 * mesh2_rank + 4 * mesh0_rank, 8 * mesh3_rank + 4 * mesh1_rank), - ) - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(0), - ] - local_shape, global_offset = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() - self.assertEqual( - local_shape, (2 * mesh_size_1, 2 * mesh_size_2 * mesh_size_3 * mesh_size_0) - ) - self.assertEqual( - global_offset, - (8 * mesh2_rank + 0 * mesh0_rank + 4 * mesh3_rank, 4 * mesh1_rank), - ) - @with_comms def test_fsdp_tp_meta_compute(self): # FSDP + TP sharding @@ -573,6 +362,106 @@ def test_hsdp_tp_meta_compute(self): self.assertEqual(local_shape, expected_local_shape) self.assertEqual(global_offset, expected_global_offset) + # TODO: remove this test once we support general meta compute on strided sharding + @with_comms + def test_strided_sharding_assumption_in_meta_compute(self): + # current ``compute_local_shape_and_global_offset`` does not allow Shard(i) + # placement to appear after the strided sharding part has ended. This test + # check that ``compute_local_shape_and_global_offset`` does not allow placements + # that violate the assumption and does not forbid the allowed ones. + + # Test 0: 2-D mesh + mesh_size_0 = 2 + mesh_size_1 = self.world_size // mesh_size_0 + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1), + mesh_dim_names=("mesh-0", "mesh-1"), + ) + global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) + + for shard_dim in [0, 1]: + placements = [ + _StridedShard(shard_dim, split_factor=mesh_size_1), + Shard(shard_dim), + ] + _, _ = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + + # Test 1: 3-D mesh + mesh_size_0 = 2 + mesh_size_1 = 2 + mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), + ) + + # legal placements: Shard() appear after the strided part but it's on another + # tensor dimension. + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + Shard(0), + Shard(1), + ] + _, _ = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + + # illegal placements: Shard() appear after the strided part and it's on the + # same tensor dimension. + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + Shard(0), + Shard(0), + ] + with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): + _, _ = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + + # Test 2: 4-D mesh + mesh_size_0 = 1 + mesh_size_1 = 2 + mesh_size_2 = 2 + mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), + ) + # legal placements: Shard() appear after the strided part but it's on another + # tensor dimension. + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(1), + ] + local_shape, _ = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + expected_local_shape = ( + 2 * mesh_size_1 * mesh_size_3, + 2 * mesh_size_0 * mesh_size_2, + ) + self.assertEqual(local_shape, expected_local_shape) + + # illegal placements: Shard() appear after the strided part and it's on the + # same tensor dimension. + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(0), + ] + with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): + _, _ = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + class UtilSingleDeviceTest(TestCase): def test_compute_global_tensor_info_unsupported_placement(self): diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index dabf9f6f194ce..fb072d8dce629 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -1071,7 +1071,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] # get local tensor shape local_shape, _ = compute_local_shape_and_global_offset( - size, device_mesh, placements, skip_offset=True + size, device_mesh, placements ) # initialize the local tensor diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 2d4a311b4bedd..1e7ff648f7fbd 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -168,10 +168,7 @@ def merge_sharding(dim: str, a: int, b: int) -> int: assert input_spec.tensor_meta is not None global_shape = input_spec.tensor_meta.shape local_shape, _ = compute_local_shape_and_global_offset( - global_shape, - input_spec.mesh, - input_spec.placements, - skip_offset=True, + global_shape, input_spec.mesh, input_spec.placements ) cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) # pyrefly: ignore [bad-argument-type] diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index ecd7938d75e2e..81b9e328f0604 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -1090,7 +1090,7 @@ def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta: meta: TensorMeta = spec.output_specs.tensor_meta local_stride = compute_local_stride(meta.stride, mesh, placements) local_shape, _ = compute_local_shape_and_global_offset( - meta.shape, mesh, placements, skip_offset=True + meta.shape, mesh, placements ) return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 2db44f387e4eb..f3dc04ef10f97 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -660,7 +660,7 @@ def _adjust_shape_and_stride_args( # adjust shape to be the same as that of the _local_tensor # of the DTensor input arg at index 0, which is inferred expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( - out_tensor_meta.shape, spec.mesh, spec.placements, skip_offset=True + out_tensor_meta.shape, spec.mesh, spec.placements ) # adjust the stride arg for aten.new_empty_strided.default diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index d7ee355500528..74ad2aaa80434 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,4 +1,5 @@ import threading +from collections import defaultdict from collections.abc import Sequence from typing import cast, Optional @@ -6,7 +7,6 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._api as dtensor from torch._prims_common import ShapeType -from torch.distributed._local_tensor import maybe_run_for_local_tensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._collective_utils import redistribute_cost from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -17,6 +17,7 @@ Replicate, Shard, ) +from torch.utils._typing_utils import not_none class ExplicitRedistributionContext: @@ -55,11 +56,61 @@ def __exit__(self, exc_type, exc_val, exc_tb): ExplicitRedistributionContext._local._active = self._prev +def _explicit_order_placements( + mesh_shape: ShapeType, placements: Sequence[Placement] +) -> Sequence[tuple[int, Placement]]: + """ + Replace Strided Shards with regular shards in an adjusted order. + + Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order. + + ex. + [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] -> + [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))] + + """ + if not len(placements) == len(mesh_shape): + raise RuntimeError( + "Expected one placement per mesh dim, " + f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims." + ) + ordered = [] + deferred_strided_placements = defaultdict(list) + strided_part_ended_for_dim = set() + for mesh_dim, p in enumerate(placements): + if isinstance(p, _StridedShard): + # validate the stride is the correct multiple of the meshdim and the earlier shard + deferred_strided_placements[p.dim].append((mesh_dim, p)) + + else: + ordered.append((mesh_dim, p)) + if isinstance(p, Shard): + if p.dim in strided_part_ended_for_dim: + raise NotImplementedError( + f"Strided sharding does not allow Shard() to appear after " + f"the strided part has ended. {p} at mesh dim {mesh_dim} in " + f"{placements} violates this assumption." + ) + + if p.dim in deferred_strided_placements: + strided_part_ended_for_dim.add(p.dim) + strided_placements = deferred_strided_placements.pop(p.dim) + aggregate_size = mesh_shape[mesh_dim] + while len(strided_placements) > 0: + strided_mesh_dim, strided = strided_placements.pop() + if not strided.split_factor == aggregate_size: + raise RuntimeError( + f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})" + f" == aggregate mesh size ({aggregate_size})" + ) + aggregate_size *= mesh_shape[strided_mesh_dim] + ordered.append((strided_mesh_dim, Shard(p.dim))) + + return ordered + + def compute_local_shape_and_global_offset( - global_shape: ShapeType, - mesh: DeviceMesh, - placements: Sequence[Placement], - skip_offset: bool = False, + global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Compute the local tensor shape and the global offsets into the original tensor @@ -92,55 +143,24 @@ def compute_local_shape_and_global_offset( global_shape (ShapeType): The global shape of the DTensor. mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on. placements (Sequence[:class:`Placement`]]): The placements of the DTensor. - skip_offset (bool): If True, skip computing the global offsets and return an empty - tuple for global_offset. This can improve performance when only the local shape - is needed. Defaults to False. Return: local_shape: the shape of the DTensor's _local_tensor on the current rank. global_offset: a tuple of offsets for each dimension of the global tensor shape, - identifying how this shard fits into the global tensor in each dimension. If - skip_offset is True, this will be an empty tuple. + identifying how this shard fits into the global tensor in each dimension. """ return _compute_local_shape_and_global_offset( - global_shape, mesh.shape, mesh.get_coordinate(), placements, skip_offset + global_shape, mesh.shape, mesh.get_coordinate(), placements ) -@maybe_run_for_local_tensor -def _compute_offsets( - placement, - shard_offsets: int, - shard_size: int, - zero_global_offset: int, - previous_offsets, -) -> torch.Tensor: - if shard_size == 0: - return torch.arange(zero_global_offset, zero_global_offset + 1) - if isinstance(placement, Shard) and not isinstance(placement, _StridedShard): - index = torch.arange(shard_offsets, shard_offsets + shard_size) - else: - assert isinstance(shard_offsets, list) - index = torch.tensor(shard_offsets) - if previous_offsets is None: - return index - else: - return previous_offsets[index] - - -@maybe_run_for_local_tensor -def _get_first_offset(offsets: torch.Tensor) -> int: - return int(offsets[0]) - - # accept 'plain data types' to enable simpler unit testing without creating device mesh def _compute_local_shape_and_global_offset( global_shape: ShapeType, mesh_shape: ShapeType, my_coordinate: Optional[list[int]], placements: Sequence[Placement], - skip_offset: bool = False, ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Suppose you have a full tensor with size global_shape, and you have sharded @@ -156,72 +176,85 @@ def _compute_local_shape_and_global_offset( This function is fairly simple if your tensor is evenly sharded; the complication is around uneven splits. There is also some complication for handling StridedShard, which changes the order you should apply sharding. - - Args: - global_shape (ShapeType): The global shape of the tensor. - mesh_shape (ShapeType): The shape of the device mesh. - my_coordinate (Optional[list[int]]): The coordinate of the current rank in the device mesh. - placements (Sequence[Placement]): The placements of the DTensor. - skip_offset (bool): If True, skip computing the global offsets and return an empty - tuple for global_offset. This can improve performance when only the local shape - is needed. Defaults to False. - - Returns: - tuple: A tuple containing: - - local_shape (tuple[int, ...]): The shape of the local shard on the current rank. - - global_offset (tuple[int, ...]): The offsets for each dimension identifying where - this shard begins in the global tensor. If skip_offset is True, this will be an - empty tuple. """ - empty_offset = () if my_coordinate is None: # if rank not in the mesh, return empty offset - return ((0,), empty_offset) + return ((0,), ()) + + # StridedShard implies a non-standard order to apply shards; get the + # correct order to start applying splits + ordered_placements = _explicit_order_placements(mesh_shape, placements) local_shape = list(global_shape) - # Perform shard from left to right. For example, - # global tensor: [0, 1, 2, 3, 4, 5, 6, 7] - # placements: S(0), SS(0, split_factor=2) - # mesh_shape: (2, 2) - # After S(0), shard_dim_to_global_offsets are - # {0: [0, 1, 2, 3]} on my_coordinate [0, 0] [0, 1] - # {0: [4, 5, 6, 7]} on my_coordinate [1, 0] [1, 1] - # After SS(0, split_factor=2), shard_dim_to_global_offsets are - # {0: [0, 2]} on my_coordinate [0, 0] - # {0: [1, 3]} on my_coordinate [0, 1] - # {0: [4, 6]} on my_coordinate [1, 0] - # {0: [5, 7]} on my_coordinate [1, 1] - shard_dim_to_global_offsets = {} - for mesh_dim, placement in enumerate(placements): - mesh_dim_size = mesh_shape[mesh_dim] - if not isinstance(placement, (Shard, _StridedShard)): - continue - shard_dim = placement.dim - zero_global_offset = global_shape[shard_dim] - assert shard_dim < len(local_shape), ( - f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" - ) - shard_size, shard_offsets = placement._local_shard_size_and_offset( - local_shape[shard_dim], - mesh_dim_size, - my_coordinate[mesh_dim], - ) - local_shape[shard_dim] = shard_size - if skip_offset: - continue - shard_dim_to_global_offsets[shard_dim] = _compute_offsets( - placement, - shard_offsets, - shard_size, - zero_global_offset, - shard_dim_to_global_offsets.get(shard_dim), - ) - if skip_offset: - return tuple(local_shape), empty_offset + # We'll compute the data for where the shard begins on a per-dim basis. + # However, a single dim can be sharded multiple times, so we will end up + # doing a Sum(size*stride) like computation to determine the location of our + # shard for each of the shardings on that dim. global_offset = [0] * len(global_shape) - for shard_dim, global_offsets in shard_dim_to_global_offsets.items(): - global_offset[shard_dim] = _get_first_offset(global_offsets) + + for mesh_dim, placement in ordered_placements: + mesh_dim_size = mesh_shape[mesh_dim] + if isinstance(placement, Shard): + shard_dim = placement.dim + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) + shard_size, shard_offset = placement._local_shard_size_and_offset( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[mesh_dim], + ) + + local_shape[shard_dim] = shard_size + + shard_global_offset = global_offset[shard_dim] + not_none(shard_offset) + + zero_global_offset = global_shape[shard_dim] + if isinstance(shard_global_offset, torch.SymInt) and not isinstance( + zero_global_offset, torch.SymInt + ): + zero_global_offset = torch.SymInt(zero_global_offset) + + global_offset[shard_dim] = torch.sym_ite( + shard_size == 0, + # Special case to fill in a standardized non-garbage value for + # the global_offset of zero-sized shards. This value is out + # of bounds of the tensor, so it won't conflict with any real + # offsets. DCP may rely on this value to de-duplicate shards. + # Note that you can end up with zero-size shards that are + # still otherwise in bounds for the tensor (TODO: give an + # example). + zero_global_offset, + # As we successively shard the same dimension, we keep + # advancing our pointer beyond our original offset until we + # get to the final chunk start. + shard_global_offset, + ) + + # NOTE: the offset compute relies on the local shard index and it has no + # problem when strided sharding is not present. To correctly compute, we assume + # that the ``_StridedShard.split_factor`` field encodes how many partitions + # each local tensor will be further split into when sharding on higher mesh + # dimensions. However, this number is only correct if the DTensor is not + # sharded after the strided sharding completes. For example, + # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements + # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on + # device mesh dim-2, and last on mesh dim-1. We define the + # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding + # part because strided sharding happens on mesh dim-1 and it was caused by + # the fact that sharding on dim-2 occurred ahead. In this case, there's no + # further sharding after this strided sharding part and ``split_factor`` + # correctly encodes the number. Another example is + # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's + # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh + # dim-2. This violates our assumption that no further sharding shall occur + # after the strided sharding part and ``split_factor`` won't correctly + # encode the number of further split. So far, the only case where _StridedShard + # placement would appear is FSDP2 + TP on 2D mesh and the above case could only + # happen on mesh of 3 or more dimensions. + # TODO: change this function to correctly address this. + # TODO: this logic can be applied to contiguous sharding as well return tuple(local_shape), tuple(global_offset) diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 65da0a7b1823b..726abc5971376 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -684,13 +684,12 @@ def _to_replicate_tensor( def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int: return len(sharded_indices[rank]) - # delete pyre-ignore once separating _StridedShard from Shard - def _local_shard_size_and_offset( # pyre-ignore[bad-override] + def _local_shard_size_and_offset( self, curr_local_size: int, num_chunks: int, rank: int, - ) -> tuple[int, list[int]]: + ) -> tuple[int, Optional[int]]: # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [curr_local_size] @@ -708,9 +707,9 @@ def _local_shard_size_and_offset( # pyre-ignore[bad-override] sharded_indices = [shard.view(-1) for shard in sharded_indices] local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank) - offsets = sharded_indices[rank].tolist() - return local_shard_size, offsets + # offsets from _StridedShard is never used + return local_shard_size, None class Replicate(torch._C._distributed.Replicate): From bec6e689b775385239c68690baff04c320f94f5a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 23 Nov 2025 12:03:54 -0800 Subject: [PATCH 003/338] [dynamo] Skip optree tests when optree isn't installed (#168931) This fixes an issue with the tests in fbcode Pull Request resolved: https://github.com/pytorch/pytorch/pull/168931 Approved by: https://github.com/anijain2305 --- test/dynamo/test_tree_map.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_tree_map.py b/test/dynamo/test_tree_map.py index 0e18d69129d56..dc43cca2bf65c 100644 --- a/test/dynamo/test_tree_map.py +++ b/test/dynamo/test_tree_map.py @@ -1,6 +1,9 @@ # Owner(s): ["module: dynamo"] -import optree +try: + import optree +except ImportError: # pragma: no cover + optree = None import torch import torch._dynamo @@ -46,10 +49,15 @@ def _tuple_is_leaf(node): return isinstance(node, tuple) -TREE_MAP_IMPLEMENTATIONS = [ - ("optree", optree.tree_map), - ("pytree_python", pytree.tree_map), -] +def _require_optree(test_case): + if optree is None: + test_case.skipTest("optree is unavailable") + + +TREE_MAP_IMPLEMENTATIONS = [] +if optree is not None: + TREE_MAP_IMPLEMENTATIONS.append(("optree", optree.tree_map)) +TREE_MAP_IMPLEMENTATIONS.append(("pytree_python", pytree.tree_map)) if cxx_pytree is not None: TREE_MAP_IMPLEMENTATIONS.append(("pytree_cxx", cxx_pytree.tree_map)) @@ -257,6 +265,8 @@ def fn(arg): _assert_trees_allclose(self, expected, result) def test_tree_map_none_nodes_reject_mismatched_siblings(self) -> None: + _require_optree(self) + def fn(a, b): return optree.tree_map(lambda u, v: (u, v), a, b) @@ -292,6 +302,8 @@ def fn(a, b): self.assertEqual(result, expected) def test_constantvariable_handles_none_is_leaf_kwarg(self) -> None: + _require_optree(self) + tree = {"none": None} def run_case(none_is_leaf_flag): @@ -317,6 +329,8 @@ def mapper(node): self.assertEqual(run_case(True), "visited") def test_constantvariable_handles_python_and_dtype_leaves(self) -> None: + _require_optree(self) + tree = { "int": 7, "nested": {"string": "foo", "dtype": torch.float32}, From 821047d4a4aa34d06aaca0eebfda25eeddc33862 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 23 Nov 2025 12:03:55 -0800 Subject: [PATCH 004/338] [dynamo] Fix local test failures for test/dynamo/test_after_aot.py (#168914) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168914 Approved by: https://github.com/anijain2305 ghstack dependencies: #168931 --- test/dynamo/test_after_aot.py | 34 ++++++++++++++++++++++++++++++++ test/dynamo/test_aot_compile.py | 21 ++++++++++++++++++++ torch/_dynamo/repro/after_aot.py | 14 +++++++++++++ 3 files changed, 69 insertions(+) diff --git a/test/dynamo/test_after_aot.py b/test/dynamo/test_after_aot.py index 1f8425a3ede7a..91fd1caea5de9 100644 --- a/test/dynamo/test_after_aot.py +++ b/test/dynamo/test_after_aot.py @@ -9,9 +9,11 @@ import torch._dynamo.test_case from torch._dynamo.repro.after_aot import InputReader, InputWriter, save_graph_repro +from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import IS_FBCODE from torch.utils._traceback import report_compile_source_on_error +from torch.utils._triton import has_triton def strip_trailing_whitespace(r): @@ -23,6 +25,31 @@ class TestAfterAot(torch._dynamo.test_case.TestCase): def test_save_graph_repro(self): # TODO: This triggers CUDA context initialization, even though # it is CPU only + saved_kernel_state = None + if has_triton(): + import triton + import triton.language as tl + + saved_kernel_state = ( + dict(kernel_side_table.id_to_kernel), + dict(kernel_side_table.kernel_to_id), + dict(kernel_side_table.constant_args), + ) + kernel_side_table.reset_table() + + @triton.jit + def _repro_kernel(x_ptr, y_ptr, size, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < size + tl.store( + y_ptr + offsets, + tl.load(x_ptr + offsets, mask=mask), + mask=mask, + ) + + kernel_side_table.add_kernel(_repro_kernel) + buf = io.StringIO() args = [torch.randn(4)] @@ -42,6 +69,13 @@ def f(x): with report_compile_source_on_error(): exec(r, {"__compile_source__": r}) + if saved_kernel_state is not None: + ( + kernel_side_table.id_to_kernel, + kernel_side_table.kernel_to_id, + kernel_side_table.constant_args, + ) = saved_kernel_state + @unittest.skipIf(sys.byteorder != "little", "checksum depends on endianness") def test_dump_tensor(self): def test(tensor, expected): diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py index 8ab8155aa9704..f68b4443f796e 100644 --- a/test/dynamo/test_aot_compile.py +++ b/test/dynamo/test_aot_compile.py @@ -275,6 +275,27 @@ def fn(x, y): actual = compiled_fn(*inputs) self.assertEqual(expected, actual) + def test_aot_compile_grad_mode_after_prior_compile(self): + def warmup_fn(x, y): + return x + y + + def target_fn(x, y): + return x - y + + torch.compile(warmup_fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4), torch.randn(3, 4)), {}) + ) + torch._dynamo.reset() + + with torch.no_grad(): + compiled_fn = torch.compile(target_fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4), torch.randn(3, 4)), {}) + ) + + inputs = (torch.randn(3, 4), torch.randn(3, 4)) + with torch.no_grad(): + self.assertEqual(compiled_fn(*inputs), target_fn(*inputs)) + def test_aot_compile_source_info(self): from torch._dynamo.package import SourceInfo diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 94f3c2d689b6a..d8465541cdfa3 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -355,6 +355,20 @@ def generate_compiler_repro_string( {maybe_fbcode_instructions()} """ ) + model_str += textwrap.dedent( + """ +if "__compile_source__" in globals(): + import inspect as __after_aot_inspect + import linecache as __after_aot_linecache + __after_aot_filename = __after_aot_inspect.currentframe().f_code.co_filename + __after_aot_linecache.cache[__after_aot_filename] = ( + len(__compile_source__), + None, + __compile_source__.splitlines(True), + __after_aot_filename, + ) +""" + ) if not stable_output: model_str += f"# torch version: {torch.version.__version__}\n" if hasattr(torch.version, "cuda"): From 55b10d723763c6c61935f4c70e8c677ddb47ed13 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 23 Nov 2025 12:03:55 -0800 Subject: [PATCH 005/338] [dynamo] Fix local test failures for test_compiler_bisector.py (#168915) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168915 Approved by: https://github.com/anijain2305 ghstack dependencies: #168931, #168914 --- test/dynamo/test_compiler_bisector.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index 8ebf35f3f0d3f..ae52490c243cf 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -108,8 +108,6 @@ def pass_fn(graph: torch.fx.Graph): args[1] = 2 nodes[0].args = tuple(args) - config.pre_grad_custom_pass = pass_fn - def foo(x): return x + 1 @@ -123,7 +121,8 @@ def test_fn(): return torch.allclose(out, out_c) - out = CompilerBisector.do_bisect(test_fn) + with config.patch(pre_grad_custom_pass=pass_fn): + out = CompilerBisector.do_bisect(test_fn) self.assertEqual(out.backend, "inductor") self.assertEqual(out.subsystem, "pre_grad_passes") self.assertEqual(out.bisect_number, 0) @@ -141,8 +140,6 @@ def pass_fn(graph: torch.fx.Graph): args[1] = 2 nodes[0].args = tuple(args) - config.joint_custom_post_pass = pass_fn - def foo(x): return x + 1 @@ -156,7 +153,8 @@ def test_fn(): return torch.allclose(out, out_c) - out = CompilerBisector.do_bisect(test_fn) + with config.patch(joint_custom_post_pass=pass_fn): + out = CompilerBisector.do_bisect(test_fn) self.assertEqual(out.backend, "inductor") self.assertEqual(out.subsystem, "joint_graph_passes") self.assertEqual(out.bisect_number, 4) From 627f6c7b84d5f6eb694342e85a9df9f4e1663be1 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 23 Nov 2025 12:03:55 -0800 Subject: [PATCH 006/338] [dynamo] Fix more cases of tests leaking config changes (#168924) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168924 Approved by: https://github.com/anijain2305 ghstack dependencies: #168931, #168914, #168915 --- test/dynamo/test_decorators.py | 13 +++++++------ test/dynamo/test_dicts.py | 6 ++++-- test/dynamo/test_repros.py | 17 ++++++++--------- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 09936044bd450..0e26ff2d4140b 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -1313,12 +1313,13 @@ def fn(B): B = torch.tensor(B_list, dtype=torch.int32) torch._dynamo.decorators.mark_static(B, 0) - torch._dynamo.config.capture_scalar_outputs = True - torch._dynamo.config.capture_dynamic_output_shape_ops = True - - self.assertEqual( - fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) - ) + with torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ): + self.assertEqual( + fn(B), + torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B), + ) def test_assume_constant_result_on_computation_with_graph_input(self): @torch._dynamo.assume_constant_result diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 4a4d2ff87718f..79ead6b348a75 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -1361,11 +1361,12 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase): # ==, !=, | def setUp(self): + self._prev_trace_unittest = torch._dynamo.config.enable_trace_unittest torch._dynamo.config.enable_trace_unittest = True super().setUp() def tearDown(self): - torch._dynamo.config.enable_trace_unittest = False + torch._dynamo.config.enable_trace_unittest = self._prev_trace_unittest return super().tearDown() def assertEqual(self, x, y): @@ -1780,11 +1781,12 @@ def test_popitem_kwarg(self): class OrderedDictSubclassOverload(torch._dynamo.test_case.TestCase): def setUp(self): + self._prev_trace_unittest = torch._dynamo.config.enable_trace_unittest torch._dynamo.config.enable_trace_unittest = True super().setUp() def tearDown(self): - torch._dynamo.config.enable_trace_unittest = False + torch._dynamo.config.enable_trace_unittest = self._prev_trace_unittest return super().tearDown() def assertEqual(self, x, y): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 8eefbefe9237f..3fc5da288786e 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7094,15 +7094,14 @@ def f(image_latent): expected = f(torch.randn((2, 12, 16, 32, 32))).sum() # https://github.com/pytorch/pytorch/issues/147171 - torch._inductor.config.fallback_random = True - - for backend in ["eager", "aot_eager"]: - torch.manual_seed(54321) - torch.cuda.manual_seed_all(54321) - actual = torch.compile(backend=backend, fullgraph=True)(f)( - torch.randn((2, 12, 16, 32, 32)) - ).sum() - self.assertEqual(actual, expected) + with torch._inductor.config.patch(fallback_random=True): + for backend in ["eager", "aot_eager"]: + torch.manual_seed(54321) + torch.cuda.manual_seed_all(54321) + actual = torch.compile(backend=backend, fullgraph=True)(f)( + torch.randn((2, 12, 16, 32, 32)) + ).sum() + self.assertEqual(actual, expected) def test_incompatible_configs(self): with torch._dynamo.config.patch( From 7dcdb3c9ad666257df8fd9366e3b865107daeb6c Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 23 Nov 2025 12:03:56 -0800 Subject: [PATCH 007/338] [dynamo] Isolate test/dynamo/test_aot_compile.py tests in subproc (#168925) They are leaking state and breaking other tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/168925 Approved by: https://github.com/anijain2305 ghstack dependencies: #168931, #168914, #168915, #168924 --- test/dynamo/test_aot_compile.py | 271 +++++++++++++++++++------------- 1 file changed, 161 insertions(+), 110 deletions(-) diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py index f68b4443f796e..7fcfbc68599fa 100644 --- a/test/dynamo/test_aot_compile.py +++ b/test/dynamo/test_aot_compile.py @@ -3,8 +3,10 @@ import copy import functools import inspect +import multiprocessing as mp import os import pickle +import tempfile import unittest from contextlib import contextmanager from unittest.mock import patch @@ -106,6 +108,162 @@ def forward(self, x): return super().forward(x) +def _subprocess_entry(fn, queue): + try: + fn() + except BaseException as exc: # noqa: BLE001 + import traceback + + queue.put((type(exc).__name__, str(exc), traceback.format_exc())) + raise + else: + queue.put(None) + + +def _run_in_subprocess(fn): + ctx = mp.get_context("spawn") + queue = ctx.Queue() + proc = ctx.Process(target=_subprocess_entry, args=(fn, queue)) + proc.start() + proc.join() + result = queue.get() + if result is not None: + name, msg, tb = result + raise AssertionError(f"Subprocess failure ({name}: {msg})\n{tb}") + + +def _subprocess_disable_guard_check(): + import torch + from torch._dynamo import config + + with config.patch(enable_aot_compile=True): + + def fn(x, y): + return x + y + + compiled_fn = torch.compile(fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4), torch.randn(3, 4)), {}) + ) + inputs = (torch.randn(3, 4), torch.randn(3, 4)) + expected = fn(*inputs) + prev_grad = torch.is_grad_enabled() + try: + torch.set_grad_enabled(not prev_grad) + try: + compiled_fn(*inputs) + except RuntimeError as exc: # pragma: no cover + if "GuardManager check failed" not in str(exc): + raise + else: # pragma: no cover + raise AssertionError("Guard check should have failed") + compiled_fn.disable_guard_check() + actual = compiled_fn(*inputs) + assert torch.allclose(actual, expected) + finally: + torch.set_grad_enabled(prev_grad) + + +def _subprocess_grad_mode_after_prior_compile(): + import torch + from torch._dynamo import config + + with config.patch(enable_aot_compile=True): + + def warmup_fn(x, y): + return x + y + + def target_fn(x, y): + return x - y + + torch.compile(warmup_fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4), torch.randn(3, 4)), {}) + ) + torch._dynamo.reset() + + with torch.no_grad(): + compiled_fn = torch.compile(target_fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4), torch.randn(3, 4)), {}) + ) + + inputs = (torch.randn(3, 4), torch.randn(3, 4)) + with torch.no_grad(): + actual = compiled_fn(*inputs) + expected = target_fn(*inputs) + assert torch.allclose(actual, expected) + + +def _subprocess_aot_compile_module(): + import torch + from torch._dynamo import config + + with config.patch(enable_aot_compile=True): + mod = SimpleLinearModule() + model = torch.compile( + mod, + fullgraph=True, + backend="inductor", + options={ + "guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe, + }, + ) + + @contextmanager + def train_mode(mdl): + mdl.train() + yield + + @contextmanager + def eval_mode(mdl): + mdl.eval() + yield + + inputs = [ + ModelInput( + args=(torch.randn(3, 3),), + kwargs={}, + contexts=[torch.no_grad(), eval_mode(model)], + ), + ModelInput( + args=(torch.randn(3, 3),), kwargs={}, contexts=[train_mode(model)] + ), + ] + assert isinstance(model, torch._dynamo.eval_frame.OptimizedModule) + model._aot_compile(inputs) + + with torch.compiler.set_stance("fail_on_recompile"): + model.eval() + eager_inputs = (torch.randn(3, 3),) + expected = mod(*eager_inputs) + actual = model(*eager_inputs) + assert torch.allclose(expected, actual) + model.train() + expected.sum().backward() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "model.pt") + model._save_aot_compiled_module(path) + torch._dynamo.reset() + model = torch.compile( + mod, + fullgraph=True, + backend="inductor", + options={ + "guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe, + }, + ) + assert isinstance(model, torch._dynamo.eval_frame.OptimizedModule) + with open(path, "rb") as f: + data = f.read() + model._load_aot_compiled_module(data) + + with torch.compiler.set_stance("fail_on_recompile"): + model.eval() + eager_inputs = (torch.randn(3, 3),) + expected = mod(*eager_inputs) + actual = model(*eager_inputs) + assert torch.allclose(expected, actual) + + @torch._dynamo.config.patch("enable_aot_compile", True) @instantiate_parametrized_tests class TestAOTCompile(torch._inductor.test_case.TestCase): @@ -260,41 +418,10 @@ def backend(gm, example_inputs): self.assertEqual(expected, actual) def test_aot_compile_disable_guard_check(self): - def fn(x, y): - return x + y - - with torch.no_grad(): - compiled_fn = torch.compile(fn, fullgraph=True).aot_compile( - ((torch.randn(3, 4), torch.randn(3, 4)), {}) - ) - inputs = (torch.randn(3, 4), torch.randn(3, 4)) - expected = fn(*inputs) - with self.assertRaisesRegex(RuntimeError, "GuardManager check failed"): - compiled_fn(*inputs) - compiled_fn.disable_guard_check() - actual = compiled_fn(*inputs) - self.assertEqual(expected, actual) + _run_in_subprocess(_subprocess_disable_guard_check) def test_aot_compile_grad_mode_after_prior_compile(self): - def warmup_fn(x, y): - return x + y - - def target_fn(x, y): - return x - y - - torch.compile(warmup_fn, fullgraph=True).aot_compile( - ((torch.randn(3, 4), torch.randn(3, 4)), {}) - ) - torch._dynamo.reset() - - with torch.no_grad(): - compiled_fn = torch.compile(target_fn, fullgraph=True).aot_compile( - ((torch.randn(3, 4), torch.randn(3, 4)), {}) - ) - - inputs = (torch.randn(3, 4), torch.randn(3, 4)) - with torch.no_grad(): - self.assertEqual(compiled_fn(*inputs), target_fn(*inputs)) + _run_in_subprocess(_subprocess_grad_mode_after_prior_compile) def test_aot_compile_source_info(self): from torch._dynamo.package import SourceInfo @@ -404,83 +531,7 @@ def fn(x, y): self.assertEqual(expected, actual) def test_aot_compile_module(self): - mod = SimpleLinearModule() - - model = torch.compile( - mod, - fullgraph=True, - backend="inductor", - options={ - "guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe, - }, - ) - - @contextmanager - def train_mode(model): - """ - Context manager that sets the model to training mode before entering the context. - """ - model.train() - yield - - @contextmanager - def eval_mode(model): - """ - Context manager that sets the model to evaluation mode before entering the context. - """ - model.eval() - yield - - inputs = [ - ModelInput( - args=(torch.randn(3, 3),), - kwargs={}, - contexts=[torch.no_grad(), eval_mode(model)], - ), - ModelInput( - args=(torch.randn(3, 3),), kwargs={}, contexts=[train_mode(model)] - ), - ] - assert isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - model._aot_compile( - inputs, - ) - with torch.compiler.set_stance("fail_on_recompile"): - model.eval() - inputs = (torch.randn(3, 3),) - expected = mod(*inputs) - actual = model(*inputs) - self.assertEqual(expected, actual) - - # Shouldn't recompile - model.train() - expected.sum().backward() - - model._save_aot_compiled_module(self.path()) - torch._dynamo.reset() - model = torch.compile( - mod, - fullgraph=True, - backend="inductor", - options={ - "guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe, - }, - ) - assert isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - with open(self.path(), "rb") as f: - data = f.read() - model._load_aot_compiled_module(data) - - with torch.compiler.set_stance("fail_on_recompile"): - model.eval() - inputs = (torch.randn(3, 3),) - expected = mod(*inputs) - actual = model(*inputs) - self.assertEqual(expected, actual) - - # Shouldn't recompile - model.train() - expected.sum().backward() + _run_in_subprocess(_subprocess_aot_compile_module) def test_aot_module_simplified_serializable_autograd(self): mod = SimpleLinearModule() From fd14c10030ad2ef1944cfb59bdd2b514fd5ebb1e Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 23 Nov 2025 12:03:56 -0800 Subject: [PATCH 008/338] [dynamo] Make test custom op naming unique (#168926) Test would fail because op names were already in use. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168926 Approved by: https://github.com/anijain2305 ghstack dependencies: #168931, #168914, #168915, #168924, #168925 --- test/dynamo/test_aot_autograd.py | 22 +++++++++++++++------- test/dynamo/test_error_messages.py | 10 ++++++---- test/dynamo/test_misc.py | 16 +++++++++------- test/dynamo/test_modes.py | 4 ++-- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 568bf23a4d196..cb9a646134a9d 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1165,9 +1165,13 @@ def test_data_ptr_access_copy(self): def test_data_ptr_access_fails_in_forward(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib) + torch.library.define( + "mylib::foo_data_ptr_forward", "(Tensor x) -> Tensor", lib=lib + ) - @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib) + @torch.library.impl( + "mylib::foo_data_ptr_forward", "CompositeImplicitAutograd", lib=lib + ) def _(x): x.data_ptr() return x.clone() @@ -1175,12 +1179,12 @@ def _(x): x = torch.randn(3) def data_ptr_graph_input(x): - r0 = torch.ops.mylib.foo(x) + r0 = torch.ops.mylib.foo_data_ptr_forward(x) return r0 def data_ptr_graph_intermediate(x): y = x.clone() - r0 = torch.ops.mylib.foo(y) + r0 = torch.ops.mylib.foo_data_ptr_forward(y) return r0 tests = [data_ptr_graph_input, data_ptr_graph_intermediate] @@ -1200,7 +1204,9 @@ def ctx(): def test_data_ptr_access_fails_in_backward(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib) + torch.library.define( + "mylib::foo_data_ptr_backward", "(Tensor x) -> Tensor", lib=lib + ) backward_called = False @@ -1216,12 +1222,14 @@ def backward(ctx, grad): grad.data_ptr() return grad.clone() - @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib) + @torch.library.impl( + "mylib::foo_data_ptr_backward", "CompositeImplicitAutograd", lib=lib + ) def _(x): return Foo.apply(x) def f(x): - return torch.ops.mylib.foo(x) + return torch.ops.mylib.foo_data_ptr_backward(x) x = torch.randn(3, requires_grad=True) with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"): diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 49f787bd25cd6..cdc87813d9151 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -812,7 +812,9 @@ def post_munge(s): ) def test_faketensor_nyi(self): - @torch.library.custom_op("mylib::foo", mutates_args=()) + op_name = "mylib::error_messages_faketensor" + + @torch.library.custom_op(op_name, mutates_args=()) def foo(x: torch.Tensor) -> torch.Tensor: return x.sin() @@ -821,14 +823,14 @@ def _(x): raise NotImplementedError def fn(x): - return torch.ops.mylib.foo(x) + return torch.ops.mylib.error_messages_faketensor(x) self.assertExpectedInlineMunged( Unsupported, lambda: torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(3)), """\ NotImplementedError/UnsupportedFakeTensorException when running FX node - Explanation: Dynamo failed to run FX node with fake tensors: call_function mylib.foo(*(FakeTensor(..., size=(3,)),), **{}): got NotImplementedError() + Explanation: Dynamo failed to run FX node with fake tensors: call_function mylib.error_messages_faketensor(*(FakeTensor(..., size=(3,)),), **{}): got NotImplementedError() Hint: If the op is a PyTorch op, please file an issue to PyTorch. Developer debug context: @@ -837,7 +839,7 @@ def fn(x): from user code: File "test_error_messages.py", line N, in fn - return torch.ops.mylib.foo(x)""", + return torch.ops.mylib.error_messages_faketensor(x)""", ) def test_data_dependent_branching_fullgraph(self): diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 781e95e0c7c95..a03537ad7d186 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -790,20 +790,22 @@ def fn(x, other_fn): def test_generate_trivial_abstract_impl(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( - "mylib::foo", + "mylib::foo_generate_trivial_abstract_impl", "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) - @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch.library.impl( + "mylib::foo_generate_trivial_abstract_impl", "cpu", lib=lib + ) @torch._dynamo.disable def foo_impl(x, y, z, w): x + y[0] + w return def f(x, y, z, w): - return torch.ops.mylib.foo(x, y, z, 2) + return torch.ops.mylib.foo_generate_trivial_abstract_impl(x, y, z, 2) x = torch.randn(3) y = (torch.randn(3), torch.randn(3)) @@ -10146,14 +10148,14 @@ def f(x, i): def test_validate_outputs_unbacked_by_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( - "mylib::foo", + "mylib::foo_validate_outputs_unbacked", "(Tensor a, Tensor b) -> (Tensor)", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch.library.register_fake("mylib::foo") + @torch.library.impl("mylib::foo_validate_outputs_unbacked", "cpu", lib=lib) + @torch.library.register_fake("mylib::foo_validate_outputs_unbacked") def foo_impl(x, y): return torch.cat([x, y]) @@ -10161,7 +10163,7 @@ def foo_impl(x, y): def f(x, i): i0, i1 = i.tolist() x0, x1 = x.split([i0, i1]) - return torch.ops.mylib.foo(x0, x1) + return torch.ops.mylib.foo_validate_outputs_unbacked(x0, x1) f(torch.randn(9, requires_grad=True), torch.tensor([3, 6])) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 82c87bde8c0ba..f163e7169bfa3 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -70,7 +70,7 @@ def test_torch_dispatch_ignore_compile_internals(self): counters.clear() from torch.utils._python_dispatch import TorchDispatchMode - @torch.library.custom_op("mylib::foo", mutates_args=()) + @torch.library.custom_op("mylib::modes_checksum", mutates_args=()) def foo(x: torch.Tensor) -> torch.Tensor: return x.clone() @@ -90,7 +90,7 @@ def __init__(self) -> None: def __torch_dispatch__(self, func, types, args, kwargs=None): kwargs = kwargs or {} - if func is torch.ops.mylib.foo.default: + if func is torch.ops.mylib.modes_checksum.default: # Do some compute, smoketest to see if there's a bad interaction _checksums.append(args[0].abs().sum()) From 35944cb42369321f07d8280e1ce6cde5e7ec1107 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 24 Nov 2025 18:16:40 +0000 Subject: [PATCH 009/338] Revert "Checking if the input is finite before calculation in lowering of pow func (#167723)" This reverts commit f1c49c9372b9af1063b98a70c8528969e68ba04d. Reverted https://github.com/pytorch/pytorch/pull/167723 on behalf of https://github.com/yangw-dev due to break trunk inductor tests test/inductor/test_triton_cpu_backend.py ([comment](https://github.com/pytorch/pytorch/pull/167723#issuecomment-3572098649)) --- test/inductor/test_torchinductor.py | 35 ----------------------------- torch/_inductor/lowering.py | 2 +- 2 files changed, 1 insertion(+), 36 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 3bc1dba12acd8..b1cea5eac77d7 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5528,32 +5528,6 @@ def fn(x): check_lowp=not is_halide_backend(self.device), # misaligned addr fp16 ) - def test_lp_pool1d_with_inf_norm(self): - # https://github.com/pytorch/pytorch/issues/167197 - # Test that LPPool1d works with infinity norm (should behave like max pooling) - def fn(x): - return torch.nn.functional.lp_pool1d( - x, norm_type=float("inf"), kernel_size=2, stride=2 - ) - - self.common( - fn, - (torch.randn(3, 4, 8),), - ) - - def test_lp_pool2d_with_inf_norm(self): - # https://github.com/pytorch/pytorch/issues/167197 - # Test that LPPool2d works with infinity norm (should behave like max pooling) - def fn(x): - return torch.nn.functional.lp_pool2d( - x, norm_type=float("inf"), kernel_size=2, stride=2 - ) - - self.common( - fn, - (torch.randn(3, 4, 8, 8),), - ) - @tf32_on_and_off(0.006) @skip_if_gpu_halide # slow def test_alexnet_prefix(self): @@ -6333,15 +6307,6 @@ def fn(x): x = torch.randn([16, 16], device=self.device) self.assertEqual(cfn(x), fn(x)) - def test_pow_infinite(self): - def fn(a, b): - return torch.pow(a, b) - - opt = torch.compile(fn, backend="inductor") - a = torch.randn((3, 4, 8), device=self.device) - b = float("inf") - self.assertTrue(same(opt(a, b), fn(a, b))) - def test_glu(self): def fn(x): return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d9890f1958edd..090265d208c92 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -6361,7 +6361,7 @@ def pow_native(a, b): @register_lowering(aten.pow, broadcast=True) def pow(a, b): - if isinstance(b, float) and math.isfinite(b) and b == int(b): + if isinstance(b, float) and b == int(b): return pow(a, int(b)) elif isinstance(b, float) and b == 0.5: return sqrt(a) From 33d4cf4fcb7f0cba6191b242dae53b48057e05b9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 24 Nov 2025 18:18:29 +0000 Subject: [PATCH 010/338] Revert "Move CUDAEvent to c10 (#158219)" This reverts commit 4909fd89dcd0016b367e2ab4ff7f594d5e7f1b9e. Reverted https://github.com/pytorch/pytorch/pull/158219 on behalf of https://github.com/jeffdaily due to broke ROCm dynamo inductor benchmarks on ciflow/inductor-periodic label which wasn't run by default for this PR ([comment](https://github.com/pytorch/pytorch/pull/158219#issuecomment-3572110617)) --- aten/src/ATen/cuda/CUDAEvent.h | 254 +++++++++++++++- .../hip/impl/HIPEventMasqueradingAsCUDA.h | 86 ------ c10/cuda/CMakeLists.txt | 1 - c10/cuda/CUDAEvent.h | 278 ------------------ torch/utils/hipify/cuda_to_hip_mappings.py | 11 - 5 files changed, 249 insertions(+), 381 deletions(-) delete mode 100644 aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h delete mode 100644 c10/cuda/CUDAEvent.h diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 73340604574ad..7a650b9cbcf35 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -3,15 +3,259 @@ #include #include #include -#include +#include #include +#include +#include + +#include + +#include +#include + +/* +* `cudaEventExternal` is a torch-specific flag that is used to +* indicate that the CUDAEvent will be used only for synchronization +* with work outside of the cuda graph, rather than creation of +* cross-stream dependencies within a cuda graph. Resources: +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e +*/ +#define cudaEventExternal 0x08 namespace at::cuda { -// EventPool - Thread-safe pool of CUDA events to avoid expensive -// cudaEventCreate calls. cudaEventCreate when concurrently invoked from -// multiple threads can be very expensive (especially on certain device/driver -// combinations). +/* +* CUDAEvents are movable not copyable wrappers around CUDA's events. +* +* CUDAEvents are constructed lazily when first recorded unless it is +* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this +* device is acquired from the first recording stream. However, if reconstructed +* from a handle, the device should be explicitly specified; or if ipc_handle() is +* called before the event is ever recorded, it will use the current device. +* Later streams that record the event must match this device. +*/ +struct TORCH_CUDA_CPP_API CUDAEvent { + // Constructors + // Default value for `flags` is specified below - it's cudaEventDisableTiming + CUDAEvent() noexcept = default; + CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} + + CUDAEvent( + DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) { + CUDAGuard guard(device_index_); + + AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); + is_created_ = true; + } + + // Note: event destruction done on creating device to avoid creating a + // CUDA context on other devices. + ~CUDAEvent() { + try { + if (is_created_) { + CUDAGuard guard(device_index_); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast(event_)); + } + AT_CUDA_CHECK(cudaEventDestroy(event_)); + } + } catch (...) { /* No throw */ } + } + + CUDAEvent(const CUDAEvent&) = delete; + CUDAEvent& operator=(const CUDAEvent&) = delete; + + CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); } + CUDAEvent& operator=(CUDAEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator cudaEvent_t() const { return event(); } + + // Less than operator (to allow use in sets) + friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { + return left.event_ < right.event_; + } + + std::optional device() const { + if (is_created_) { + return at::Device(at::kCUDA, device_index_); + } else { + return {}; + } + } + + bool isCreated() const { return is_created_; } + DeviceIndex device_index() const {return device_index_;} + cudaEvent_t event() const { return event_; } + + // Note: cudaEventQuery can be safely called from any device + bool query() const { + if (!is_created_) { + return true; + } + + cudaError_t err = cudaEventQuery(event_); + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; + } + + void record() { record(getCurrentCUDAStream()); } + + void recordOnce(const CUDAStream& stream) { + if (!was_recorded_) record(stream); + } + + // Note: cudaEventRecord must be called on the same device as the event. + void record(const CUDAStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_, + " does not match recording stream's device ", stream.device_index(), "."); + CUDAGuard guard(device_index_); + +#ifndef USE_ROCM + // it is an error to use cudaEventRecordExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventRecordExternal : cudaEventRecordDefault; + AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); +#else + AT_CUDA_CHECK(cudaEventRecord(event_, stream)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record(at::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + was_recorded_ = true; + } + + // Note: cudaStreamWaitEvent must be called on the same device as the stream. + // The event has no actual GPU resources associated with it. + void block(const CUDAStream& stream) { + if (is_created_) { + CUDAGuard guard(stream.device_index()); +#ifndef USE_ROCM + // it is an error to use cudaEventWaitExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault; + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); +#else + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait(at::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + } + } + + // Note: cudaEventElapsedTime can be safely called from any device + float elapsed_time(const CUDAEvent& other) const { + TORCH_CHECK_VALUE( + !(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + + float time_ms = 0; + // We do not strictly have to set the device index to the same as our event, + // but if we don't and the current device is not initialized, it will + // create a new cuda context, which will consume a lot of memory. + CUDAGuard guard(device_index_); + // raise cudaErrorNotReady if either event is recorded but not yet completed + AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + // Note: cudaEventSynchronize can be safely called from any device + void synchronize() const { + if (is_created_) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast(event_)); + } + AT_CUDA_CHECK(cudaEventSynchronize(event_)); + } + } + + // Note: cudaIpcGetEventHandle must be called on the same device as the event + void ipc_handle(cudaIpcEventHandle_t * handle) { + if (!is_created_) { + // this CUDAEvent object was initially constructed from flags but event_ + // is not created yet. + createEvent(getCurrentCUDAStream().device_index()); + } + CUDAGuard guard(device_index_); + AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); + } + +private: + unsigned int flags_ = cudaEventDisableTiming; + bool is_created_ = false; + bool was_recorded_ = false; + bool external_ = false; + DeviceIndex device_index_ = -1; + cudaEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + external_ = (flags_ & cudaEventExternal) != 0; +#ifdef USE_ROCM + TORCH_CHECK(!external_, "External events are disallowed in rocm"); +#endif + flags_ &= ~cudaEventExternal; + device_index_ = device_index; + CUDAGuard guard(device_index_); + AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast(event_)); + } + is_created_ = true; + } + + void moveHelper(CUDAEvent&& other) { + // Transfer ownership of all state from other to this + flags_ = other.flags_; + is_created_ = other.is_created_; + was_recorded_ = other.was_recorded_; + external_ = other.external_; + device_index_ = other.device_index_; + event_ = other.event_; + + // Reset other to a valid empty state to prevent double-free + // The moved-from object must not attempt to destroy the event + other.is_created_ = false; + other.event_ = cudaEvent_t{}; + } +}; + +// EventPool - Thread-safe pool of CUDA events to avoid expensive cudaEventCreate +// calls. cudaEventCreate when concurrently invoked from multiple threads can be +// very expensive (especially on certain device/driver combinations). using CUDAEventPtr = std::unique_ptr>; diff --git a/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h deleted file mode 100644 index f2741a32889fb..0000000000000 --- a/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h +++ /dev/null @@ -1,86 +0,0 @@ -#pragma once - -#include - -// Use of c10::hip namespace here makes hipification easier, because -// I don't have to also fix namespaces. Sorry! -namespace c10 { namespace hip { - -// See Note [Masquerading as CUDA] for motivation - -struct HIPEventMasqueradingAsCUDA { - HIPEventMasqueradingAsCUDA() noexcept = default; - HIPEventMasqueradingAsCUDA(unsigned int flags) noexcept - : event_(HIPEvent(flags)) {} - HIPEventMasqueradingAsCUDA( - DeviceIndex device_index, - const hipIpcEventHandle_t* handle) - : event_(HIPEvent(device_index, handle)) {} - - ~HIPEventMasqueradingAsCUDA() = default; - - HIPEventMasqueradingAsCUDA(const HIPEventMasqueradingAsCUDA&) = delete; - HIPEventMasqueradingAsCUDA& operator=(const HIPEventMasqueradingAsCUDA&) = delete; - HIPEventMasqueradingAsCUDA(HIPEventMasqueradingAsCUDA&& other) noexcept = default; - HIPEventMasqueradingAsCUDA& operator=(HIPEventMasqueradingAsCUDA&& other) noexcept = default; - - operator hipEvent_t() const { - return event_.event(); - } - - // Less than operator (to allow use in sets) - friend bool operator<( - const HIPEventMasqueradingAsCUDA& left, - const HIPEventMasqueradingAsCUDA& right) { - return left.event_ < right.event_; - } - - std::optional device() const { - // Unsafely coerce HIP device into CUDA device - return Device(c10::DeviceType::CUDA, event_.device_index()); - } - bool isCreated() const { - return event_.isCreated(); - } - DeviceIndex device_index() const { - return event_.device_index(); - } - hipEvent_t event() const { - return event_.event(); - } - bool query() const { - return event_.query(); - } - void record() { - return event_.record(); - } - - void recordOnce(const HIPStreamMasqueradingAsCUDA& stream) { - event_.recordOnce(stream.hip_stream()); - } - - void record(const HIPStreamMasqueradingAsCUDA& stream) { - event_.record(stream.hip_stream()); - } - - void block(const HIPStreamMasqueradingAsCUDA& stream) { - event_.block(stream.hip_stream()); - } - - float elapsed_time(const HIPEventMasqueradingAsCUDA& other) const { - return event_.elapsed_time(other.event_); - } - - void synchronize() const { - event_.synchronize(); - } - - void ipc_handle(hipIpcEventHandle_t* handle) { - event_.ipc_handle(handle); - } - - private: - HIPEvent event_; -}; - -}} // namespace c10::hip diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index fd80c45fcc79e..2604f677858d1 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -43,7 +43,6 @@ set(C10_CUDA_HEADERS CUDACachingAllocator.h CUDADeviceAssertionHost.h CUDAException.h - CUDAEvent.h CUDAFunctions.h CUDAGuard.h CUDAMacros.h diff --git a/c10/cuda/CUDAEvent.h b/c10/cuda/CUDAEvent.h deleted file mode 100644 index 6e5205044879f..0000000000000 --- a/c10/cuda/CUDAEvent.h +++ /dev/null @@ -1,278 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -/* - * `cudaEventExternal` is a torch-specific flag that is used to - * indicate that the CUDAEvent will be used only for synchronization - * with work outside of the cuda graph, rather than creation of - * cross-stream dependencies within a cuda graph. Resources: - * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events - * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 - * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e - */ -#define cudaEventExternal 0x08 - -namespace c10::cuda { - -/* - * CUDAEvents are movable not copyable wrappers around CUDA's events. - * - * CUDAEvents are constructed lazily when first recorded unless it is - * reconstructed from a cudaIpcEventHandle_t. The event has a device, and this - * device is acquired from the first recording stream. However, if reconstructed - * from a handle, the device should be explicitly specified; or if ipc_handle() - * is called before the event is ever recorded, it will use the current device. - * Later streams that record the event must match this device. - */ -struct CUDAEvent { - // Constructors - // Default value for `flags` is specified below - it's cudaEventDisableTiming - CUDAEvent() noexcept = default; - CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} - - CUDAEvent(DeviceIndex device_index, const cudaIpcEventHandle_t* handle) - : device_index_(device_index) { - CUDAGuard guard(device_index_); - - C10_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); - is_created_ = true; - } - - // Note: event destruction done on creating device to avoid creating a - // CUDA context on other devices. - ~CUDAEvent() { - if (is_created_) { - CUDAGuard guard(device_index_); - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_deletion( - c10::kCUDA, reinterpret_cast(event_)); - } - C10_CUDA_CHECK_WARN(cudaEventDestroy(event_)); - } - } - - CUDAEvent(const CUDAEvent&) = delete; - CUDAEvent& operator=(const CUDAEvent&) = delete; - - CUDAEvent(CUDAEvent&& other) noexcept { - moveHelper(std::move(other)); - } - CUDAEvent& operator=(CUDAEvent&& other) noexcept { - if (this != &other) { - moveHelper(std::move(other)); - } - return *this; - } - - operator cudaEvent_t() const { - return event(); - } - - // Less than operator (to allow use in sets) - friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { - return left.event_ < right.event_; - } - - std::optional device() const { - if (is_created_) { - return c10::Device(c10::kCUDA, device_index_); - } else { - return {}; - } - } - - bool isCreated() const { - return is_created_; - } - DeviceIndex device_index() const { - return device_index_; - } - cudaEvent_t event() const { - return event_; - } - - // Note: cudaEventQuery can be safely called from any device - bool query() const { - if (!is_created_) { - return true; - } - - cudaError_t err = cudaEventQuery(event_); - if (err == cudaSuccess) { - return true; - } else if (err != cudaErrorNotReady) { - C10_CUDA_CHECK(err); - } else { - // ignore and clear the error if not ready - (void)cudaGetLastError(); - } - - return false; - } - - void record() { - record(getCurrentCUDAStream()); - } - - void recordOnce(const CUDAStream& stream) { - if (!was_recorded_) - record(stream); - } - - // Note: cudaEventRecord must be called on the same device as the event. - void record(const CUDAStream& stream) { - if (!is_created_) { - createEvent(stream.device_index()); - } - - TORCH_CHECK( - device_index_ == stream.device_index(), - "Event device ", - device_index_, - " does not match recording stream's device ", - stream.device_index(), - "."); - CUDAGuard guard(device_index_); - -#ifndef USE_ROCM - // it is an error to use cudaEventRecordExternal when not doing stream - // capture - unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != - c10::cuda::CaptureStatus::None && - external_) - ? cudaEventRecordExternal - : cudaEventRecordDefault; - C10_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); -#else - C10_CUDA_CHECK(cudaEventRecord(event_, stream)); -#endif - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_record( - c10::kCUDA, - reinterpret_cast(event_), - reinterpret_cast(stream.stream())); - } - was_recorded_ = true; - } - - // Note: cudaStreamWaitEvent must be called on the same device as the stream. - // The event has no actual GPU resources associated with it. - void block(const CUDAStream& stream) { - if (is_created_) { - CUDAGuard guard(stream.device_index()); -#ifndef USE_ROCM - // it is an error to use cudaEventWaitExternal when not doing stream - // capture - unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != - c10::cuda::CaptureStatus::None && - external_) - ? cudaEventWaitExternal - : cudaEventWaitDefault; - C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); -#else - C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); -#endif - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_wait( - c10::kCUDA, - reinterpret_cast(event_), - reinterpret_cast(stream.stream())); - } - } - } - - // Note: cudaEventElapsedTime can be safely called from any device - float elapsed_time(const CUDAEvent& other) const { - TORCH_CHECK_VALUE( - !(flags_ & cudaEventDisableTiming) && - !(other.flags_ & cudaEventDisableTiming), - "Both events must be created with argument 'enable_timing=True'."); - TORCH_CHECK_VALUE( - is_created_ && other.isCreated(), - "Both events must be recorded before calculating elapsed time."); - TORCH_CHECK( - query() && other.query(), - "Both events must be completed before calculating elapsed time."); - - float time_ms = 0; - // We do not strictly have to set the device index to the same as our event, - // but if we don't and the current device is not initialized, it will - // create a new cuda context, which will consume a lot of memory. - CUDAGuard guard(device_index_); - // raise cudaErrorNotReady if either event is recorded but not yet completed - C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); - return time_ms; - } - - // Note: cudaEventSynchronize can be safely called from any device - void synchronize() const { - if (is_created_) { - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_synchronization( - c10::kCUDA, reinterpret_cast(event_)); - } - C10_CUDA_CHECK(cudaEventSynchronize(event_)); - } - } - - // Note: cudaIpcGetEventHandle must be called on the same device as the event - void ipc_handle(cudaIpcEventHandle_t* handle) { - if (!is_created_) { - // this CUDAEvent object was initially constructed from flags but event_ - // is not created yet. - createEvent(getCurrentCUDAStream().device_index()); - } - CUDAGuard guard(device_index_); - C10_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); - } - - private: - unsigned int flags_ = cudaEventDisableTiming; - bool is_created_ = false; - bool was_recorded_ = false; - bool external_ = false; - DeviceIndex device_index_ = -1; - cudaEvent_t event_{}; - - void createEvent(DeviceIndex device_index) { - external_ = (flags_ & cudaEventExternal) != 0; -#ifdef USE_ROCM - TORCH_CHECK(!external_, "External events are disallowed in rocm"); -#endif - flags_ &= ~cudaEventExternal; - device_index_ = device_index; - CUDAGuard guard(device_index_); - C10_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_creation( - c10::kCUDA, reinterpret_cast(event_)); - } - is_created_ = true; - } - - void moveHelper(CUDAEvent&& other) { - // Transfer ownership of all state from other to this - flags_ = other.flags_; - is_created_ = other.is_created_; - was_recorded_ = other.was_recorded_; - external_ = other.external_; - device_index_ = other.device_index_; - event_ = other.event_; - - // Reset other to a valid empty state to prevent double-free - // The moved-from object must not attempt to destroy the event - other.is_created_ = false; - other.event_ = cudaEvent_t{}; - } -}; - -} // namespace c10::cuda diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 18afecd18c9be..fb7dc1c7cb7f0 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -9231,8 +9231,6 @@ API_PYTORCH, ), ), - ("cuda::CUDAEvent", ("hip::HIPEventMasqueradingAsCUDA", API_PYTORCH)), - ("CUDAEvent", ("HIPEventMasqueradingAsCUDA", API_PYTORCH)), ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), ( @@ -9287,14 +9285,6 @@ "c10/cuda/CUDACachingAllocator.h", ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH), ), - ( - "ATen/cuda/CUDAEvent.h", # To keep BC, we have to keep this mapping - ("ATen/hip/HIPEvent.h", API_PYTORCH), - ), - ( - "c10/cuda/CUDAEvent.h", - ("ATen/hip/impl/HIPEventMasqueradingAsCUDA.h", API_PYTORCH), - ), ( "c10/cuda/CUDAStream.h", ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH), @@ -9435,7 +9425,6 @@ ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)), - ("c10/cuda/CUDAEvent.h", ("c10/hip/HIPEvent.h", API_C10)), ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)), ("c10/cuda/CUDAAllocatorConfig.h", ("c10/hip/HIPAllocatorConfig.h", API_C10)), From 89891302d444d9d84e9072f78767c77eceffbfa2 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 21 Nov 2025 16:32:42 -0800 Subject: [PATCH 011/338] Fix local_map default partitioner issue (#168396) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168396 Approved by: https://github.com/xmfan ghstack dependencies: #168289 --- torch/_higher_order_ops/local_map.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/_higher_order_ops/local_map.py b/torch/_higher_order_ops/local_map.py index 7970acbc5d6ad..1d4ad631ea102 100644 --- a/torch/_higher_order_ops/local_map.py +++ b/torch/_higher_order_ops/local_map.py @@ -334,6 +334,13 @@ def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]: static_lifetime_input_indices=[], ) + # Fix tags because min-cut does not respect fw/bw boundary, breaking + # default partitioner's assumptions. + for node in new_fw_gm.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" + for node in new_bw_gm.graph.nodes: + node.meta["partitioner_tag"] = "is_backward" + # Propagate meta onto fw/bw graphs, later will be set on proxied nodes new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs} From 0ecdf68d051b6e31e09b08b3d9d470b98730e68b Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Mon, 24 Nov 2025 19:25:46 +0000 Subject: [PATCH 012/338] [CI][CUDA][B200][Smoke Test] Add one more B200 smoke test. Change periodic frequency from every 6 hours to every 2 hours (#168990) Fix low utilization issue for linux.dgx.b200. linux.dgx.b200.8 is much busier. According to https://hud.pytorch.org/runners/pytorch?search=dgx.b200 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168990 Approved by: https://github.com/drisspg --- .ci/pytorch/test.sh | 1 + .github/workflows/test-b200.yml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 7e25c8c6d199c..fa884ecf2b52a 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -354,6 +354,7 @@ test_python_smoke_b200() { nn/attention/test_fa4 \ nn/attention/test_open_registry \ inductor/test_flex_flash \ + inductor/test_torchinductor \ $PYTHON_TEST_EXTRA_OPTION \ --upload-artifacts-while-running assert_git_not_dirty diff --git a/.github/workflows/test-b200.yml b/.github/workflows/test-b200.yml index 7cc935f46d6c8..54acc686d1ae4 100644 --- a/.github/workflows/test-b200.yml +++ b/.github/workflows/test-b200.yml @@ -23,7 +23,7 @@ on: - .github/workflows/test-b200.yml workflow_dispatch: schedule: - - cron: 0 4,10,16,22 * * * # every 6 hours + - cron: 0 */2 * * * # every 2 hours push: tags: - ciflow/b200/* From 09de09e01efd95c4fd85daa67b8d1f5ef4030c1f Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 24 Nov 2025 08:34:15 -0800 Subject: [PATCH 013/338] [AOTI] Skip emit_multi_arch_kernel when CUDA version is lower than 12.8 (#168985) Summary: Fix https://github.com/pytorch/pytorch/issues/168353. aot_inductor.emit_multi_arch_kernel requires a newer CUDA version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168985 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 4 ++-- test/inductor/test_aot_inductor_package.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index fd962c8bea70a..8cac7b8f929d1 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -245,8 +245,8 @@ def forward(self, x): # failure on CI @common_utils.parametrize("embed_kernel_binary", [False]) @unittest.skipIf( - torch.version.hip is None and _get_torch_cuda_version() < (12, 6), - "Test is only supported on CUDA 12.6+", + torch.version.hip is None and _get_torch_cuda_version() < (12, 8), + "Test is only supported on CUDA 12.8+", ) def test_simple_multi_arch(self, embed_kernel_binary): if self.device != GPU_TYPE: diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 2f67758eaa24e..f1b190caaf0f7 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -315,8 +315,8 @@ def forward(self, x, y): self.assertTrue(torch.allclose(actual, expected)) @unittest.skipIf( - torch.version.hip is None and _get_torch_cuda_version() < (12, 6), - "Test is only supported on CUDA 12.6+", + torch.version.hip is None and _get_torch_cuda_version() < (12, 8), + "Test is only supported on CUDA 12.8+", ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfXpu # doesn't support multi-arch binary From 3e12e0f0a1f7ff11d66510ce18ecdce8ae63a7cd Mon Sep 17 00:00:00 2001 From: Ting Lu Date: Mon, 24 Nov 2025 19:47:44 +0000 Subject: [PATCH 014/338] [CD] [aarch64] unify the build.sh to build for aarch64 wheel (#166044) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit related to https://github.com/pytorch/pytorch/issues/163970 Changes: Below are addressed from review from @malfet and @atalman: 1. Simplified the x86 TORCH_CUDA_ARCH_LIST logic to reuse the base list in`.ci/manywheel/build_cuda.sh`. 2. Added function filter_aarch64_archs() that filters the TORCH_CUDA_ARCH_LIST for aarch64 based on the x86 code. 3. Added function in `.ci/pytorch/build.sh` to report error if ACL is not present. 4. Deprecated previous aarch64 scripts (`.ci/aarch64_linux/` folder). Improvements: 1. Significant improvement in build time for CUDA ARM wheel build - Reduced build time from 5.5–6 hours to 1 hour 40–50 minutes taking this 13.0 build for example, 6h 11m 46s to 1h 50m 1s ≈ 70 % faster build time old: https://github.com/pytorch/pytorch/actions/runs/19304934204/job/55209695430 new: https://github.com/pytorch/pytorch/actions/runs/19301014750/job/55195226316 Reason: MAX_JOBS=5 is now removed after we move away from original aarch64 build workflow, previously it was OOM in building flash-attn, new MAX_JOBS is 12. https://github.com/pytorch/pytorch/pull/166044/files#diff-ccef31095e4f2d203710232531c38bff3251e41cf73ec84ee59f224bb64034aeL280 2. Unified workflow for building x86 and sbsa wheels - more maintainable code Pull Request resolved: https://github.com/pytorch/pytorch/pull/166044 Approved by: https://github.com/atalman --- .ci/aarch64_linux/README.md | 19 - .ci/aarch64_linux/aarch64_ci_build.sh | 53 - .ci/aarch64_linux/aarch64_ci_setup.sh | 21 - .ci/aarch64_linux/aarch64_wheel_ci_build.py | 333 ------ .ci/aarch64_linux/build_aarch64_wheel.py | 999 ------------------ .ci/aarch64_linux/embed_library.py | 87 -- .ci/manywheel/build.sh | 4 +- .ci/manywheel/build_common.sh | 27 +- .ci/manywheel/build_cpu.sh | 61 +- .ci/manywheel/build_cuda.sh | 142 ++- .ci/pytorch/check_binary.sh | 7 +- .../linux_binary_build_workflow.yml.j2 | 2 - .github/workflows/_binary-build-linux.yml | 7 +- ...linux-aarch64-binary-manywheel-nightly.yml | 707 +++++++++++++ ...enerated-linux-binary-libtorch-nightly.yml | 7 + ...nerated-linux-binary-manywheel-nightly.yml | 56 + ...d-linux-s390x-binary-manywheel-nightly.yml | 7 + .github/workflows/test-check-binary.yml | 4 + 18 files changed, 996 insertions(+), 1547 deletions(-) delete mode 100644 .ci/aarch64_linux/README.md delete mode 100644 .ci/aarch64_linux/aarch64_ci_build.sh delete mode 100755 .ci/aarch64_linux/aarch64_ci_setup.sh delete mode 100755 .ci/aarch64_linux/aarch64_wheel_ci_build.py delete mode 100755 .ci/aarch64_linux/build_aarch64_wheel.py delete mode 100644 .ci/aarch64_linux/embed_library.py diff --git a/.ci/aarch64_linux/README.md b/.ci/aarch64_linux/README.md deleted file mode 100644 index 583ed4af99844..0000000000000 --- a/.ci/aarch64_linux/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Aarch64 (ARM/Graviton) Support Scripts -Scripts for building aarch64 PyTorch PIP Wheels. These scripts build the following wheels: -* torch -* torchvision -* torchaudio -* torchtext -* torchdata -## Aarch64_ci_build.sh -This script is design to support CD operations within PyPi manylinux aarch64 container, and be executed in the container. It prepares the container and then executes __aarch64_wheel_ci_build.py__ to build the wheels. The script "assumes" the PyTorch repo is located at: ```/pytorch``` and will put the wheels into ```/artifacts```. -### Usage -```DESIRED_PYTHON= aarch64_ci_build.sh``` - -__NOTE:__ CI build is currently __EXPERMINTAL__ - -## Build_aarch64_wheel.py -This app allows a person to build using AWS EC3 resources and requires AWS-CLI and Boto3 with AWS credentials to support building EC2 instances for the wheel builds. Can be used in a codebuild CD or from a local system. - -### Usage -```build_aarch64_wheel.py --key-name --use-docker --python 3.8 --branch ``` diff --git a/.ci/aarch64_linux/aarch64_ci_build.sh b/.ci/aarch64_linux/aarch64_ci_build.sh deleted file mode 100644 index b25f3b21e8eb1..0000000000000 --- a/.ci/aarch64_linux/aarch64_ci_build.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -set -eux -o pipefail - -GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-} - -# Set CUDA architecture lists to match x86 build_cuda.sh -if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then - export TORCH_CUDA_ARCH_LIST="8.0;9.0" -elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then - export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0" -elif [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then - export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0" -elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then - export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX" -fi - -# Compress the fatbin with -compress-mode=size for CUDA 13 -if [[ "$DESIRED_CUDA" == *"13"* ]]; then - export TORCH_NVCC_FLAGS="-compress-mode=size" - # Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801 - export BUILD_BUNDLE_PTXAS=1 -fi - -SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" -source $SCRIPTPATH/aarch64_ci_setup.sh - -############################################################################### -# Run aarch64 builder python -############################################################################### -cd / -# adding safe directory for git as the permissions will be -# on the mounted pytorch repo -git config --global --add safe.directory /pytorch -pip install -r /pytorch/requirements.txt -pip install auditwheel==6.2.0 wheel -if [ "$DESIRED_CUDA" = "cpu" ]; then - echo "BASE_CUDA_VERSION is not set. Building cpu wheel." - python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn -else - echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA" - export USE_SYSTEM_NCCL=1 - - # Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic) - if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then - echo "Bundling CUDA libraries with wheel for aarch64." - else - echo "Using nvidia libs from pypi for aarch64." - echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS" - export USE_NVIDIA_PYPI_LIBS=1 - fi - - python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda -fi diff --git a/.ci/aarch64_linux/aarch64_ci_setup.sh b/.ci/aarch64_linux/aarch64_ci_setup.sh deleted file mode 100755 index 8ffba65d7fedd..0000000000000 --- a/.ci/aarch64_linux/aarch64_ci_setup.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash -set -eux -o pipefail - -# This script is used to prepare the Docker container for aarch64_ci_wheel_build.py python script -# By creating symlinks from desired /opt/python to /usr/local/bin/ - -NUMPY_VERSION=2.0.2 -if [[ "$DESIRED_PYTHON" == "3.13" || "$DESIRED_PYTHON" == "3.13t" ]]; then - NUMPY_VERSION=2.1.2 -fi - -SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" -source $SCRIPTPATH/../manywheel/set_desired_python.sh - -pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2 - -for tool in python python3 pip pip3 ninja scons patchelf; do - ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin; -done - -python --version diff --git a/.ci/aarch64_linux/aarch64_wheel_ci_build.py b/.ci/aarch64_linux/aarch64_wheel_ci_build.py deleted file mode 100755 index a99e5f8f65659..0000000000000 --- a/.ci/aarch64_linux/aarch64_wheel_ci_build.py +++ /dev/null @@ -1,333 +0,0 @@ -#!/usr/bin/env python3 -# encoding: UTF-8 - -import os -import shutil -from subprocess import check_call, check_output - - -def list_dir(path: str) -> list[str]: - """' - Helper for getting paths for Python - """ - return check_output(["ls", "-1", path]).decode().split("\n") - - -def replace_tag(filename) -> None: - with open(filename) as f: - lines = f.readlines() - for i, line in enumerate(lines): - if line.startswith("Tag:"): - lines[i] = line.replace("-linux_", "-manylinux_2_28_") - print(f"Updated tag from {line} to {lines[i]}") - break - - with open(filename, "w") as f: - f.writelines(lines) - - -def patch_library_rpath( - folder: str, - lib_name: str, - use_nvidia_pypi_libs: bool = False, - desired_cuda: str = "", -) -> None: - """Apply patchelf to set RPATH for a library in torch/lib""" - lib_path = f"{folder}/tmp/torch/lib/{lib_name}" - - if use_nvidia_pypi_libs: - # For PyPI NVIDIA libraries, construct CUDA RPATH - cuda_rpaths = [ - "$ORIGIN/../../nvidia/cudnn/lib", - "$ORIGIN/../../nvidia/nvshmem/lib", - "$ORIGIN/../../nvidia/nccl/lib", - "$ORIGIN/../../nvidia/cusparselt/lib", - ] - - if "130" in desired_cuda: - cuda_rpaths.append("$ORIGIN/../../nvidia/cu13/lib") - else: - cuda_rpaths.extend( - [ - "$ORIGIN/../../nvidia/cublas/lib", - "$ORIGIN/../../nvidia/cuda_cupti/lib", - "$ORIGIN/../../nvidia/cuda_nvrtc/lib", - "$ORIGIN/../../nvidia/cuda_runtime/lib", - "$ORIGIN/../../nvidia/cufft/lib", - "$ORIGIN/../../nvidia/curand/lib", - "$ORIGIN/../../nvidia/cusolver/lib", - "$ORIGIN/../../nvidia/cusparse/lib", - "$ORIGIN/../../nvidia/nvtx/lib", - "$ORIGIN/../../nvidia/cufile/lib", - ] - ) - - # Add $ORIGIN for local torch libs - rpath = ":".join(cuda_rpaths) + ":$ORIGIN" - else: - # For bundled libraries, just use $ORIGIN - rpath = "$ORIGIN" - - if os.path.exists(lib_path): - os.system( - f"cd {folder}/tmp/torch/lib/; " - f"patchelf --set-rpath '{rpath}' --force-rpath {lib_name}" - ) - - -def copy_and_patch_library( - src_path: str, - folder: str, - use_nvidia_pypi_libs: bool = False, - desired_cuda: str = "", -) -> None: - """Copy a library to torch/lib and patch its RPATH""" - if os.path.exists(src_path): - lib_name = os.path.basename(src_path) - shutil.copy2(src_path, f"{folder}/tmp/torch/lib/{lib_name}") - patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda) - - -def package_cuda_wheel(wheel_path, desired_cuda) -> None: - """ - Package the cuda wheel libraries - """ - folder = os.path.dirname(wheel_path) - os.mkdir(f"{folder}/tmp") - os.system(f"unzip {wheel_path} -d {folder}/tmp") - # Delete original wheel since it will be repackaged - os.system(f"rm {wheel_path}") - - # Check if we should use PyPI NVIDIA libraries or bundle system libraries - use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1" - - if use_nvidia_pypi_libs: - print("Using nvidia libs from pypi - skipping CUDA library bundling") - # For PyPI approach, we don't bundle CUDA libraries - they come from PyPI packages - # We only need to bundle non-NVIDIA libraries - minimal_libs_to_copy = [ - "/lib64/libgomp.so.1", - "/usr/lib64/libgfortran.so.5", - "/acl/build/libarm_compute.so", - "/acl/build/libarm_compute_graph.so", - "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0", - "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0", - "/usr/local/lib/libnvpl_lapack_core.so.0", - "/usr/local/lib/libnvpl_blas_core.so.0", - ] - - # Copy minimal libraries to unzipped_folder/torch/lib - for lib_path in minimal_libs_to_copy: - copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda) - - # Patch torch libraries used for searching libraries - torch_libs_to_patch = [ - "libtorch.so", - "libtorch_cpu.so", - "libtorch_cuda.so", - "libtorch_cuda_linalg.so", - "libtorch_global_deps.so", - "libtorch_python.so", - "libtorch_nvshmem.so", - "libc10.so", - "libc10_cuda.so", - "libcaffe2_nvrtc.so", - "libshm.so", - ] - for lib_name in torch_libs_to_patch: - patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda) - else: - print("Bundling CUDA libraries with wheel") - # Original logic for bundling system CUDA libraries - # Common libraries for all CUDA versions - common_libs = [ - # Non-NVIDIA system libraries - "/lib64/libgomp.so.1", - "/usr/lib64/libgfortran.so.5", - "/acl/build/libarm_compute.so", - "/acl/build/libarm_compute_graph.so", - # Common CUDA libraries (same for all versions) - "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0", - "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0", - "/usr/local/lib/libnvpl_lapack_core.so.0", - "/usr/local/lib/libnvpl_blas_core.so.0", - "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so", - "/usr/local/cuda/lib64/libcudnn.so.9", - "/usr/local/cuda/lib64/libcusparseLt.so.0", - "/usr/local/cuda/lib64/libcurand.so.10", - "/usr/local/cuda/lib64/libnccl.so.2", - "/usr/local/cuda/lib64/libnvshmem_host.so.3", - "/usr/local/cuda/lib64/libcudnn_adv.so.9", - "/usr/local/cuda/lib64/libcudnn_cnn.so.9", - "/usr/local/cuda/lib64/libcudnn_graph.so.9", - "/usr/local/cuda/lib64/libcudnn_ops.so.9", - "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9", - "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9", - "/usr/local/cuda/lib64/libcudnn_heuristic.so.9", - "/usr/local/cuda/lib64/libcufile.so.0", - "/usr/local/cuda/lib64/libcufile_rdma.so.1", - "/usr/local/cuda/lib64/libcusparse.so.12", - ] - - # CUDA version-specific libraries - if "13" in desired_cuda: - minor_version = desired_cuda[-1] - version_specific_libs = [ - "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13", - "/usr/local/cuda/lib64/libcublas.so.13", - "/usr/local/cuda/lib64/libcublasLt.so.13", - "/usr/local/cuda/lib64/libcudart.so.13", - "/usr/local/cuda/lib64/libcufft.so.12", - "/usr/local/cuda/lib64/libcusolver.so.12", - "/usr/local/cuda/lib64/libnvJitLink.so.13", - "/usr/local/cuda/lib64/libnvrtc.so.13", - f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}", - ] - elif "12" in desired_cuda: - # Get the last character for libnvrtc-builtins version (e.g., "129" -> "9") - minor_version = desired_cuda[-1] - version_specific_libs = [ - "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12", - "/usr/local/cuda/lib64/libcublas.so.12", - "/usr/local/cuda/lib64/libcublasLt.so.12", - "/usr/local/cuda/lib64/libcudart.so.12", - "/usr/local/cuda/lib64/libcufft.so.11", - "/usr/local/cuda/lib64/libcusolver.so.11", - "/usr/local/cuda/lib64/libnvJitLink.so.12", - "/usr/local/cuda/lib64/libnvrtc.so.12", - f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}", - ] - else: - raise ValueError(f"Unsupported CUDA version: {desired_cuda}.") - - # Combine all libraries - libs_to_copy = common_libs + version_specific_libs - - # Copy libraries to unzipped_folder/torch/lib - for lib_path in libs_to_copy: - copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda) - - # Make sure the wheel is tagged with manylinux_2_28 - for f in os.scandir(f"{folder}/tmp/"): - if f.is_dir() and f.name.endswith(".dist-info"): - replace_tag(f"{f.path}/WHEEL") - break - - os.system(f"wheel pack {folder}/tmp/ -d {folder}") - os.system(f"rm -rf {folder}/tmp/") - - -def complete_wheel(folder: str) -> str: - """ - Complete wheel build and put in artifact location - """ - wheel_name = list_dir(f"/{folder}/dist")[0] - - # Please note for cuda we don't run auditwheel since we use custom script to package - # the cuda dependencies to the wheel file using update_wheel() method. - # However we need to make sure filename reflects the correct Manylinux platform. - if "pytorch" in folder and not enable_cuda: - print("Repairing Wheel with AuditWheel") - check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder) - repaired_wheel_name = list_dir(f"/{folder}/wheelhouse")[0] - - print(f"Moving {repaired_wheel_name} wheel to /{folder}/dist") - os.rename( - f"/{folder}/wheelhouse/{repaired_wheel_name}", - f"/{folder}/dist/{repaired_wheel_name}", - ) - else: - repaired_wheel_name = list_dir(f"/{folder}/dist")[0] - - print(f"Copying {repaired_wheel_name} to artifacts") - shutil.copy2( - f"/{folder}/dist/{repaired_wheel_name}", f"/artifacts/{repaired_wheel_name}" - ) - - return repaired_wheel_name - - -def parse_arguments(): - """ - Parse inline arguments - """ - from argparse import ArgumentParser - - parser = ArgumentParser("AARCH64 wheels python CD") - parser.add_argument("--debug", action="store_true") - parser.add_argument("--build-only", action="store_true") - parser.add_argument("--test-only", type=str) - parser.add_argument("--enable-mkldnn", action="store_true") - parser.add_argument("--enable-cuda", action="store_true") - return parser.parse_args() - - -if __name__ == "__main__": - """ - Entry Point - """ - args = parse_arguments() - enable_mkldnn = args.enable_mkldnn - enable_cuda = args.enable_cuda - branch = check_output( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd="/pytorch" - ).decode() - - print("Building PyTorch wheel") - build_vars = "" - # MAX_JOB=5 is not required for CPU backend (see commit 465d98b) - if enable_cuda: - build_vars += "MAX_JOBS=5 " - - # Handle PyPI NVIDIA libraries vs bundled libraries - use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1" - if use_nvidia_pypi_libs: - print("Configuring build for PyPI NVIDIA libraries") - # Configure for dynamic linking (matching x86 logic) - build_vars += "ATEN_STATIC_CUDA=0 USE_CUDA_STATIC_LINK=0 USE_CUPTI_SO=1 " - else: - print("Configuring build for bundled NVIDIA libraries") - # Keep existing static linking approach - already configured above - - override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION") - desired_cuda = os.getenv("DESIRED_CUDA") - if override_package_version is not None: - version = override_package_version - build_vars += ( - f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version} PYTORCH_BUILD_NUMBER=1 " - ) - elif branch in ["nightly", "main"]: - build_date = ( - check_output(["git", "log", "--pretty=format:%cs", "-1"], cwd="/pytorch") - .decode() - .replace("-", "") - ) - version = ( - check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2] - ) - if enable_cuda: - build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 " - else: - build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 " - elif branch.startswith(("v1.", "v2.")): - build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 " - - if enable_mkldnn: - print("build pytorch with mkldnn+acl backend") - build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON " - build_vars += "ACL_ROOT_DIR=/acl " - if enable_cuda: - build_vars += "BLAS=NVPL " - else: - build_vars += "BLAS=OpenBLAS OpenBLAS_HOME=/opt/OpenBLAS " - else: - print("build pytorch without mkldnn backend") - - os.system(f"cd /pytorch; {build_vars} python3 -m build --wheel --no-isolation") - if enable_cuda: - print("Updating Cuda Dependency") - filename = os.listdir("/pytorch/dist/") - wheel_path = f"/pytorch/dist/{filename[0]}" - package_cuda_wheel(wheel_path, desired_cuda) - pytorch_wheel_name = complete_wheel("/pytorch/") - print(f"Build Complete. Created {pytorch_wheel_name}..") diff --git a/.ci/aarch64_linux/build_aarch64_wheel.py b/.ci/aarch64_linux/build_aarch64_wheel.py deleted file mode 100755 index a157ec57b574a..0000000000000 --- a/.ci/aarch64_linux/build_aarch64_wheel.py +++ /dev/null @@ -1,999 +0,0 @@ -#!/usr/bin/env python3 - -# This script is for building AARCH64 wheels using AWS EC2 instances. -# To generate binaries for the release follow these steps: -# 1. Update mappings for each of the Domain Libraries by adding new row to a table like this: -# "v1.11.0": ("0.11.0", "rc1"), -# 2. Run script with following arguments for each of the supported python versions and required tag, for example: -# build_aarch64_wheel.py --key-name --use-docker --python 3.8 --branch v1.11.0-rc3 - - -import os -import subprocess -import sys -import time -from typing import Optional, Union - -import boto3 - - -# AMI images for us-east-1, change the following based on your ~/.aws/config -os_amis = { - "ubuntu20_04": "ami-052eac90edaa9d08f", # login_name: ubuntu - "ubuntu22_04": "ami-0c6c29c5125214c77", # login_name: ubuntu - "redhat8": "ami-0698b90665a2ddcf1", # login_name: ec2-user -} - -ubuntu20_04_ami = os_amis["ubuntu20_04"] - - -def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]: - if key_name is None: - key_name = os.getenv("AWS_KEY_NAME") - if key_name is None: - return os.getenv("SSH_KEY_PATH", ""), "" - - homedir_path = os.path.expanduser("~") - default_path = os.path.join(homedir_path, ".ssh", f"{key_name}.pem") - return os.getenv("SSH_KEY_PATH", default_path), key_name - - -ec2 = boto3.resource("ec2") - - -def ec2_get_instances(filter_name, filter_value): - return ec2.instances.filter( - Filters=[{"Name": filter_name, "Values": [filter_value]}] - ) - - -def ec2_instances_of_type(instance_type="t4g.2xlarge"): - return ec2_get_instances("instance-type", instance_type) - - -def ec2_instances_by_id(instance_id): - rc = list(ec2_get_instances("instance-id", instance_id)) - return rc[0] if len(rc) > 0 else None - - -def start_instance( - key_name, ami=ubuntu20_04_ami, instance_type="t4g.2xlarge", ebs_size: int = 50 -): - inst = ec2.create_instances( - ImageId=ami, - InstanceType=instance_type, - SecurityGroups=["ssh-allworld"], - KeyName=key_name, - MinCount=1, - MaxCount=1, - BlockDeviceMappings=[ - { - "DeviceName": "/dev/sda1", - "Ebs": { - "DeleteOnTermination": True, - "VolumeSize": ebs_size, - "VolumeType": "standard", - }, - } - ], - )[0] - print(f"Create instance {inst.id}") - inst.wait_until_running() - running_inst = ec2_instances_by_id(inst.id) - print(f"Instance started at {running_inst.public_dns_name}") - return running_inst - - -class RemoteHost: - addr: str - keyfile_path: str - login_name: str - container_id: Optional[str] = None - ami: Optional[str] = None - - def __init__(self, addr: str, keyfile_path: str, login_name: str = "ubuntu"): - self.addr = addr - self.keyfile_path = keyfile_path - self.login_name = login_name - - def _gen_ssh_prefix(self) -> list[str]: - return [ - "ssh", - "-o", - "StrictHostKeyChecking=no", - "-i", - self.keyfile_path, - f"{self.login_name}@{self.addr}", - "--", - ] - - @staticmethod - def _split_cmd(args: Union[str, list[str]]) -> list[str]: - return args.split() if isinstance(args, str) else args - - def run_ssh_cmd(self, args: Union[str, list[str]]) -> None: - subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args)) - - def check_ssh_output(self, args: Union[str, list[str]]) -> str: - return subprocess.check_output( - self._gen_ssh_prefix() + self._split_cmd(args) - ).decode("utf-8") - - def scp_upload_file(self, local_file: str, remote_file: str) -> None: - subprocess.check_call( - [ - "scp", - "-i", - self.keyfile_path, - local_file, - f"{self.login_name}@{self.addr}:{remote_file}", - ] - ) - - def scp_download_file( - self, remote_file: str, local_file: Optional[str] = None - ) -> None: - if local_file is None: - local_file = "." - subprocess.check_call( - [ - "scp", - "-i", - self.keyfile_path, - f"{self.login_name}@{self.addr}:{remote_file}", - local_file, - ] - ) - - def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None: - self.run_ssh_cmd("sudo apt-get install -y docker.io") - self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}") - self.run_ssh_cmd("sudo service docker start") - self.run_ssh_cmd(f"docker pull {image}") - self.container_id = self.check_ssh_output( - f"docker run -t -d -w /root {image}" - ).strip() - - def using_docker(self) -> bool: - return self.container_id is not None - - def run_cmd(self, args: Union[str, list[str]]) -> None: - if not self.using_docker(): - return self.run_ssh_cmd(args) - assert self.container_id is not None - docker_cmd = self._gen_ssh_prefix() + [ - "docker", - "exec", - "-i", - self.container_id, - "bash", - ] - p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE) - p.communicate( - input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode( - "utf-8" - ) - ) - rc = p.wait() - if rc != 0: - raise subprocess.CalledProcessError(rc, docker_cmd) - - def check_output(self, args: Union[str, list[str]]) -> str: - if not self.using_docker(): - return self.check_ssh_output(args) - assert self.container_id is not None - docker_cmd = self._gen_ssh_prefix() + [ - "docker", - "exec", - "-i", - self.container_id, - "bash", - ] - p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE) - (out, err) = p.communicate( - input=" ".join(["source .bashrc && "] + self._split_cmd(args)).encode( - "utf-8" - ) - ) - rc = p.wait() - if rc != 0: - raise subprocess.CalledProcessError(rc, docker_cmd, output=out, stderr=err) - return out.decode("utf-8") - - def upload_file(self, local_file: str, remote_file: str) -> None: - if not self.using_docker(): - return self.scp_upload_file(local_file, remote_file) - tmp_file = os.path.join("/tmp", os.path.basename(local_file)) - self.scp_upload_file(local_file, tmp_file) - self.run_ssh_cmd( - ["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"] - ) - self.run_ssh_cmd(["rm", tmp_file]) - - def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None: - if not self.using_docker(): - return self.scp_download_file(remote_file, local_file) - tmp_file = os.path.join("/tmp", os.path.basename(remote_file)) - self.run_ssh_cmd( - ["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file] - ) - self.scp_download_file(tmp_file, local_file) - self.run_ssh_cmd(["rm", tmp_file]) - - def download_wheel( - self, remote_file: str, local_file: Optional[str] = None - ) -> None: - if self.using_docker() and local_file is None: - basename = os.path.basename(remote_file) - local_file = basename.replace( - "-linux_aarch64.whl", "-manylinux2014_aarch64.whl" - ) - self.download_file(remote_file, local_file) - - def list_dir(self, path: str) -> list[str]: - return self.check_output(["ls", "-1", path]).split("\n") - - -def wait_for_connection(addr, port, timeout=15, attempt_cnt=5): - import socket - - for i in range(attempt_cnt): - try: - with socket.create_connection((addr, port), timeout=timeout): - return - except (ConnectionRefusedError, TimeoutError): # noqa: PERF203 - if i == attempt_cnt - 1: - raise - time.sleep(timeout) - - -def update_apt_repo(host: RemoteHost) -> None: - time.sleep(5) - host.run_cmd("sudo systemctl stop apt-daily.service || true") - host.run_cmd("sudo systemctl stop unattended-upgrades.service || true") - host.run_cmd( - "while systemctl is-active --quiet apt-daily.service; do sleep 1; done" - ) - host.run_cmd( - "while systemctl is-active --quiet unattended-upgrades.service; do sleep 1; done" - ) - host.run_cmd("sudo apt-get update") - time.sleep(3) - host.run_cmd("sudo apt-get update") - - -def install_condaforge( - host: RemoteHost, suffix: str = "latest/download/Miniforge3-Linux-aarch64.sh" -) -> None: - print("Install conda-forge") - host.run_cmd(f"curl -OL https://github.com/conda-forge/miniforge/releases/{suffix}") - host.run_cmd(f"sh -f {os.path.basename(suffix)} -b") - host.run_cmd(f"rm -f {os.path.basename(suffix)}") - if host.using_docker(): - host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc") - else: - host.run_cmd( - [ - "sed", - "-i", - "'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH'", - ".bashrc", - ] - ) - - -def install_condaforge_python(host: RemoteHost, python_version="3.8") -> None: - if python_version == "3.6": - # Python-3.6 EOLed and not compatible with conda-4.11 - install_condaforge( - host, suffix="download/4.10.3-10/Miniforge3-4.10.3-10-Linux-aarch64.sh" - ) - host.run_cmd(f"conda install -y python={python_version} numpy pyyaml") - else: - install_condaforge( - host, suffix="download/4.11.0-4/Miniforge3-4.11.0-4-Linux-aarch64.sh" - ) - # Pytorch-1.10 or older are not compatible with setuptools=59.6 or newer - host.run_cmd( - f"conda install -y python={python_version} numpy pyyaml setuptools>=59.5.0" - ) - - -def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None: - host.run_cmd("pip3 install auditwheel") - host.run_cmd( - "conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf" - ) - from tempfile import NamedTemporaryFile - - with NamedTemporaryFile() as tmp: - tmp.write(embed_library_script.encode("utf-8")) - tmp.flush() - host.upload_file(tmp.name, "embed_library.py") - - print("Embedding libgomp into wheel") - if host.using_docker(): - host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag") - else: - host.run_cmd(f"python3 embed_library.py {wheel_name}") - - -def checkout_repo( - host: RemoteHost, - *, - branch: str = "main", - url: str, - git_clone_flags: str, - mapping: dict[str, tuple[str, str]], -) -> Optional[str]: - for prefix in mapping: - if not branch.startswith(prefix): - continue - tag = f"v{mapping[prefix][0]}-{mapping[prefix][1]}" - host.run_cmd(f"git clone {url} -b {tag} {git_clone_flags}") - return mapping[prefix][0] - - host.run_cmd(f"git clone {url} -b {branch} {git_clone_flags}") - return None - - -def build_torchvision( - host: RemoteHost, - *, - branch: str = "main", - use_conda: bool = True, - git_clone_flags: str, - run_smoke_tests: bool = True, -) -> str: - print("Checking out TorchVision repo") - build_version = checkout_repo( - host, - branch=branch, - url="https://github.com/pytorch/vision", - git_clone_flags=git_clone_flags, - mapping={ - "v1.7.1": ("0.8.2", "rc2"), - "v1.8.0": ("0.9.0", "rc3"), - "v1.8.1": ("0.9.1", "rc1"), - "v1.9.0": ("0.10.0", "rc1"), - "v1.10.0": ("0.11.1", "rc1"), - "v1.10.1": ("0.11.2", "rc1"), - "v1.10.2": ("0.11.3", "rc1"), - "v1.11.0": ("0.12.0", "rc1"), - "v1.12.0": ("0.13.0", "rc4"), - "v1.12.1": ("0.13.1", "rc6"), - "v1.13.0": ("0.14.0", "rc4"), - "v1.13.1": ("0.14.1", "rc2"), - "v2.0.0": ("0.15.1", "rc2"), - "v2.0.1": ("0.15.2", "rc2"), - }, - ) - print("Building TorchVision wheel") - - # Please note libnpg and jpeg are required to build image.so extension - if use_conda: - host.run_cmd("conda install -y libpng jpeg") - # Remove .so files to force static linking - host.run_cmd( - "rm miniforge3/lib/libpng.so miniforge3/lib/libpng16.so miniforge3/lib/libjpeg.so" - ) - # And patch setup.py to include libz dependency for libpng - host.run_cmd( - [ - 'sed -i -e \'s/image_link_flags\\.append("png")/image_link_flags += ["png", "z"]/\' vision/setup.py' - ] - ) - - build_vars = "" - if branch == "nightly": - version = host.check_output( - ["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"] - ).strip() - if len(version) == 0: - # In older revisions, version was embedded in setup.py - version = ( - host.check_output(["grep", '"version = \'"', "vision/setup.py"]) - .strip() - .split("'")[1][:-2] - ) - build_date = ( - host.check_output("cd vision && git log --pretty=format:%s -1") - .strip() - .split()[0] - .replace("-", "") - ) - build_vars += f"BUILD_VERSION={version}.dev{build_date}" - elif build_version is not None: - build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" - if host.using_docker(): - build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" - - host.run_cmd(f"cd vision && {build_vars} python3 -m build --wheel --no-isolation") - vision_wheel_name = host.list_dir("vision/dist")[0] - embed_libgomp(host, use_conda, os.path.join("vision", "dist", vision_wheel_name)) - - print("Copying TorchVision wheel") - host.download_wheel(os.path.join("vision", "dist", vision_wheel_name)) - if run_smoke_tests: - host.run_cmd( - f"pip3 install {os.path.join('vision', 'dist', vision_wheel_name)}" - ) - host.run_cmd("python3 vision/test/smoke_test.py") - print("Delete vision checkout") - host.run_cmd("rm -rf vision") - - return vision_wheel_name - - -def build_torchdata( - host: RemoteHost, - *, - branch: str = "main", - use_conda: bool = True, - git_clone_flags: str = "", -) -> str: - print("Checking out TorchData repo") - git_clone_flags += " --recurse-submodules" - build_version = checkout_repo( - host, - branch=branch, - url="https://github.com/pytorch/data", - git_clone_flags=git_clone_flags, - mapping={ - "v1.13.1": ("0.5.1", ""), - "v2.0.0": ("0.6.0", "rc5"), - "v2.0.1": ("0.6.1", "rc1"), - }, - ) - print("Building TorchData wheel") - build_vars = "" - if branch == "nightly": - version = host.check_output( - ["if [ -f data/version.txt ]; then cat data/version.txt; fi"] - ).strip() - build_date = ( - host.check_output("cd data && git log --pretty=format:%s -1") - .strip() - .split()[0] - .replace("-", "") - ) - build_vars += f"BUILD_VERSION={version}.dev{build_date}" - elif build_version is not None: - build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" - if host.using_docker(): - build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" - - host.run_cmd(f"cd data && {build_vars} python3 -m build --wheel --no-isolation") - wheel_name = host.list_dir("data/dist")[0] - embed_libgomp(host, use_conda, os.path.join("data", "dist", wheel_name)) - - print("Copying TorchData wheel") - host.download_wheel(os.path.join("data", "dist", wheel_name)) - - return wheel_name - - -def build_torchtext( - host: RemoteHost, - *, - branch: str = "main", - use_conda: bool = True, - git_clone_flags: str = "", -) -> str: - print("Checking out TorchText repo") - git_clone_flags += " --recurse-submodules" - build_version = checkout_repo( - host, - branch=branch, - url="https://github.com/pytorch/text", - git_clone_flags=git_clone_flags, - mapping={ - "v1.9.0": ("0.10.0", "rc1"), - "v1.10.0": ("0.11.0", "rc2"), - "v1.10.1": ("0.11.1", "rc1"), - "v1.10.2": ("0.11.2", "rc1"), - "v1.11.0": ("0.12.0", "rc1"), - "v1.12.0": ("0.13.0", "rc2"), - "v1.12.1": ("0.13.1", "rc5"), - "v1.13.0": ("0.14.0", "rc3"), - "v1.13.1": ("0.14.1", "rc1"), - "v2.0.0": ("0.15.1", "rc2"), - "v2.0.1": ("0.15.2", "rc2"), - }, - ) - print("Building TorchText wheel") - build_vars = "" - if branch == "nightly": - version = host.check_output( - ["if [ -f text/version.txt ]; then cat text/version.txt; fi"] - ).strip() - build_date = ( - host.check_output("cd text && git log --pretty=format:%s -1") - .strip() - .split()[0] - .replace("-", "") - ) - build_vars += f"BUILD_VERSION={version}.dev{build_date}" - elif build_version is not None: - build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" - if host.using_docker(): - build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" - - host.run_cmd(f"cd text && {build_vars} python3 -m build --wheel --no-isolation") - wheel_name = host.list_dir("text/dist")[0] - embed_libgomp(host, use_conda, os.path.join("text", "dist", wheel_name)) - - print("Copying TorchText wheel") - host.download_wheel(os.path.join("text", "dist", wheel_name)) - - return wheel_name - - -def build_torchaudio( - host: RemoteHost, - *, - branch: str = "main", - use_conda: bool = True, - git_clone_flags: str = "", -) -> str: - print("Checking out TorchAudio repo") - git_clone_flags += " --recurse-submodules" - build_version = checkout_repo( - host, - branch=branch, - url="https://github.com/pytorch/audio", - git_clone_flags=git_clone_flags, - mapping={ - "v1.9.0": ("0.9.0", "rc2"), - "v1.10.0": ("0.10.0", "rc5"), - "v1.10.1": ("0.10.1", "rc1"), - "v1.10.2": ("0.10.2", "rc1"), - "v1.11.0": ("0.11.0", "rc1"), - "v1.12.0": ("0.12.0", "rc3"), - "v1.12.1": ("0.12.1", "rc5"), - "v1.13.0": ("0.13.0", "rc4"), - "v1.13.1": ("0.13.1", "rc2"), - "v2.0.0": ("2.0.1", "rc3"), - "v2.0.1": ("2.0.2", "rc2"), - }, - ) - print("Building TorchAudio wheel") - build_vars = "" - if branch == "nightly": - version = ( - host.check_output(["grep", '"version = \'"', "audio/setup.py"]) - .strip() - .split("'")[1][:-2] - ) - build_date = ( - host.check_output("cd audio && git log --pretty=format:%s -1") - .strip() - .split()[0] - .replace("-", "") - ) - build_vars += f"BUILD_VERSION={version}.dev{build_date}" - elif build_version is not None: - build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" - if host.using_docker(): - build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" - - host.run_cmd( - f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \ - && ./packaging/ffmpeg/build.sh \ - && {build_vars} python3 -m build --wheel --no-isolation" - ) - - wheel_name = host.list_dir("audio/dist")[0] - embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name)) - - print("Copying TorchAudio wheel") - host.download_wheel(os.path.join("audio", "dist", wheel_name)) - - return wheel_name - - -def configure_system( - host: RemoteHost, - *, - compiler: str = "gcc-8", - use_conda: bool = True, - python_version: str = "3.8", -) -> None: - if use_conda: - install_condaforge_python(host, python_version) - - print("Configuring the system") - if not host.using_docker(): - update_apt_repo(host) - host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip") - else: - host.run_cmd("yum install -y sudo") - host.run_cmd("conda install -y ninja scons") - - if not use_conda: - host.run_cmd( - "sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip" - ) - host.run_cmd("pip3 install dataclasses typing-extensions") - if not use_conda: - print("Installing Cython + numpy from PyPy") - host.run_cmd("sudo pip3 install Cython") - host.run_cmd("sudo pip3 install numpy") - - -def build_domains( - host: RemoteHost, - *, - branch: str = "main", - use_conda: bool = True, - git_clone_flags: str = "", -) -> tuple[str, str, str, str]: - vision_wheel_name = build_torchvision( - host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags - ) - audio_wheel_name = build_torchaudio( - host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags - ) - data_wheel_name = build_torchdata( - host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags - ) - text_wheel_name = build_torchtext( - host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags - ) - return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name) - - -def start_build( - host: RemoteHost, - *, - branch: str = "main", - compiler: str = "gcc-8", - use_conda: bool = True, - python_version: str = "3.8", - pytorch_only: bool = False, - pytorch_build_number: Optional[str] = None, - shallow_clone: bool = True, - enable_mkldnn: bool = False, -) -> tuple[str, str, str, str, str]: - git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else "" - if host.using_docker() and not use_conda: - print("Auto-selecting conda option for docker images") - use_conda = True - if not host.using_docker(): - print("Disable mkldnn for host builds") - enable_mkldnn = False - - configure_system( - host, compiler=compiler, use_conda=use_conda, python_version=python_version - ) - - if host.using_docker(): - print("Move libgfortant.a into a standard location") - # HACK: pypa gforntran.a is compiled without PIC, which leads to the following error - # libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17' # noqa: E501, B950 - # Workaround by copying gfortran library from the host - host.run_ssh_cmd("sudo apt-get install -y gfortran-8") - host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8") - host.run_ssh_cmd( - [ - "docker", - "cp", - "/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a", - f"{host.container_id}:/opt/rh/devtoolset-10/root/usr/lib/gcc/aarch64-redhat-linux/10/", - ] - ) - - print("Checking out PyTorch repo") - host.run_cmd( - f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}" - ) - - host.run_cmd("pytorch/.ci/docker/common/install_openblas.sh") - - print("Building PyTorch wheel") - build_opts = "" - if pytorch_build_number is not None: - build_opts += f" -C--build-option=--build-number={pytorch_build_number}" - # Breakpad build fails on aarch64 - build_vars = "USE_BREAKPAD=0 " - if branch == "nightly": - build_date = ( - host.check_output("cd pytorch && git log --pretty=format:%s -1") - .strip() - .split()[0] - .replace("-", "") - ) - version = host.check_output("cat pytorch/version.txt").strip()[:-2] - build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1" - if branch.startswith(("v1.", "v2.")): - build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1" - if host.using_docker(): - build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" - if enable_mkldnn: - host.run_cmd("pytorch/.ci/docker/common/install_acl.sh") - print("build pytorch with mkldnn+acl backend") - build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON" - build_vars += " BLAS=OpenBLAS" - build_vars += " OpenBLAS_HOME=/opt/OpenBLAS" - build_vars += " ACL_ROOT_DIR=/acl" - host.run_cmd( - f"cd $HOME/pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}" - ) - print("Repair the wheel") - pytorch_wheel_name = host.list_dir("pytorch/dist")[0] - ld_library_path = "/acl/build:$HOME/pytorch/build/lib" - host.run_cmd( - f"export LD_LIBRARY_PATH={ld_library_path} && auditwheel repair $HOME/pytorch/dist/{pytorch_wheel_name}" - ) - print("replace the original wheel with the repaired one") - pytorch_repaired_wheel_name = host.list_dir("wheelhouse")[0] - host.run_cmd( - f"cp $HOME/wheelhouse/{pytorch_repaired_wheel_name} $HOME/pytorch/dist/{pytorch_wheel_name}" - ) - else: - print("build pytorch without mkldnn backend") - host.run_cmd( - f"cd pytorch && {build_vars} python3 -m build --wheel --no-isolation{build_opts}" - ) - - print("Deleting build folder") - host.run_cmd("cd pytorch && rm -rf build") - pytorch_wheel_name = host.list_dir("pytorch/dist")[0] - embed_libgomp(host, use_conda, os.path.join("pytorch", "dist", pytorch_wheel_name)) - print("Copying the wheel") - host.download_wheel(os.path.join("pytorch", "dist", pytorch_wheel_name)) - - print("Installing PyTorch wheel") - host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}") - - if pytorch_only: - return (pytorch_wheel_name, None, None, None, None) - domain_wheels = build_domains( - host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags - ) - - return (pytorch_wheel_name, *domain_wheels) - - -embed_library_script = """ -#!/usr/bin/env python3 - -from auditwheel.patcher import Patchelf -from auditwheel.wheeltools import InWheelCtx -from auditwheel.elfutils import elf_file_filter -from auditwheel.repair import copylib -from auditwheel.lddtree import lddtree -from subprocess import check_call -import os -import shutil -import sys -from tempfile import TemporaryDirectory - - -def replace_tag(filename): - with open(filename, 'r') as f: - lines = f.read().split("\\n") - for i,line in enumerate(lines): - if not line.startswith("Tag: "): - continue - lines[i] = line.replace("-linux_", "-manylinux2014_") - print(f'Updated tag from {line} to {lines[i]}') - - with open(filename, 'w') as f: - f.write("\\n".join(lines)) - - -class AlignedPatchelf(Patchelf): - def set_soname(self, file_name: str, new_soname: str) -> None: - check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name]) - - def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None: - check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name]) - - -def embed_library(whl_path, lib_soname, update_tag=False): - patcher = AlignedPatchelf() - out_dir = TemporaryDirectory() - whl_name = os.path.basename(whl_path) - tmp_whl_name = os.path.join(out_dir.name, whl_name) - with InWheelCtx(whl_path) as ctx: - torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib') - ctx.out_wheel=tmp_whl_name - new_lib_path, new_lib_soname = None, None - for filename, elf in elf_file_filter(ctx.iter_files()): - if not filename.startswith('torch/lib'): - continue - libtree = lddtree(filename) - if lib_soname not in libtree['needed']: - continue - lib_path = libtree['libs'][lib_soname]['path'] - if lib_path is None: - print(f"Can't embed {lib_soname} as it could not be found") - break - if lib_path.startswith(torchlib_path): - continue - - if new_lib_path is None: - new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher) - patcher.replace_needed(filename, lib_soname, new_lib_soname) - print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}') - if update_tag: - # Add manylinux2014 tag - for filename in ctx.iter_files(): - if os.path.basename(filename) != 'WHEEL': - continue - replace_tag(filename) - shutil.move(tmp_whl_name, whl_path) - - -if __name__ == '__main__': - embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag') -""" - - -def run_tests(host: RemoteHost, whl: str, branch="main") -> None: - print("Configuring the system") - update_apt_repo(host) - host.run_cmd("sudo apt-get install -y python3-pip git") - host.run_cmd("sudo pip3 install Cython") - host.run_cmd("sudo pip3 install numpy") - host.upload_file(whl, ".") - host.run_cmd(f"sudo pip3 install {whl}") - host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3))'") - host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch") - host.run_cmd("cd pytorch/test; python3 test_torch.py -v") - - -def get_instance_name(instance) -> Optional[str]: - if instance.tags is None: - return None - for tag in instance.tags: - if tag["Key"] == "Name": - return tag["Value"] - return None - - -def list_instances(instance_type: str) -> None: - print(f"All instances of type {instance_type}") - for instance in ec2_instances_of_type(instance_type): - ifaces = instance.network_interfaces - az = ifaces[0].subnet.availability_zone if len(ifaces) > 0 else None - print( - f"{instance.id} {get_instance_name(instance)} {instance.public_dns_name} {instance.state['Name']} {az}" - ) - - -def terminate_instances(instance_type: str) -> None: - print(f"Terminating all instances of type {instance_type}") - instances = list(ec2_instances_of_type(instance_type)) - for instance in instances: - print(f"Terminating {instance.id}") - instance.terminate() - print("Waiting for termination to complete") - for instance in instances: - instance.wait_until_terminated() - - -def parse_arguments(): - from argparse import ArgumentParser - - parser = ArgumentParser("Build and test AARCH64 wheels using EC2") - parser.add_argument("--key-name", type=str) - parser.add_argument("--debug", action="store_true") - parser.add_argument("--build-only", action="store_true") - parser.add_argument("--test-only", type=str) - group = parser.add_mutually_exclusive_group() - group.add_argument("--os", type=str, choices=list(os_amis.keys())) - group.add_argument("--ami", type=str) - parser.add_argument( - "--python-version", - type=str, - choices=[f"3.{d}" for d in range(6, 12)], - default=None, - ) - parser.add_argument("--alloc-instance", action="store_true") - parser.add_argument("--list-instances", action="store_true") - parser.add_argument("--pytorch-only", action="store_true") - parser.add_argument("--keep-running", action="store_true") - parser.add_argument("--terminate-instances", action="store_true") - parser.add_argument("--instance-type", type=str, default="t4g.2xlarge") - parser.add_argument("--ebs-size", type=int, default=50) - parser.add_argument("--branch", type=str, default="main") - parser.add_argument("--use-docker", action="store_true") - parser.add_argument( - "--compiler", - type=str, - choices=["gcc-7", "gcc-8", "gcc-9", "clang"], - default="gcc-8", - ) - parser.add_argument("--use-torch-from-pypi", action="store_true") - parser.add_argument("--pytorch-build-number", type=str, default=None) - parser.add_argument("--disable-mkldnn", action="store_true") - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_arguments() - ami = ( - args.ami - if args.ami is not None - else os_amis[args.os] - if args.os is not None - else ubuntu20_04_ami - ) - keyfile_path, key_name = compute_keyfile_path(args.key_name) - - if args.list_instances: - list_instances(args.instance_type) - sys.exit(0) - - if args.terminate_instances: - terminate_instances(args.instance_type) - sys.exit(0) - - if len(key_name) == 0: - raise RuntimeError(""" - Cannot start build without key_name, please specify - --key-name argument or AWS_KEY_NAME environment variable.""") - if len(keyfile_path) == 0 or not os.path.exists(keyfile_path): - raise RuntimeError(f""" - Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please - check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""") - - # Starting the instance - inst = start_instance( - key_name, ami=ami, instance_type=args.instance_type, ebs_size=args.ebs_size - ) - instance_name = f"{args.key_name}-{args.os}" - if args.python_version is not None: - instance_name += f"-py{args.python_version}" - inst.create_tags( - DryRun=False, - Tags=[ - { - "Key": "Name", - "Value": instance_name, - } - ], - ) - addr = inst.public_dns_name - wait_for_connection(addr, 22) - host = RemoteHost(addr, keyfile_path) - host.ami = ami - if args.use_docker: - update_apt_repo(host) - host.start_docker() - - if args.test_only: - run_tests(host, args.test_only) - sys.exit(0) - - if args.alloc_instance: - if args.python_version is None: - sys.exit(0) - install_condaforge_python(host, args.python_version) - sys.exit(0) - - python_version = args.python_version if args.python_version is not None else "3.10" - - if args.use_torch_from_pypi: - configure_system(host, compiler=args.compiler, python_version=python_version) - print("Installing PyTorch wheel") - host.run_cmd("pip3 install torch") - build_domains( - host, branch=args.branch, git_clone_flags=" --depth 1 --shallow-submodules" - ) - else: - start_build( - host, - branch=args.branch, - compiler=args.compiler, - python_version=python_version, - pytorch_only=args.pytorch_only, - pytorch_build_number=args.pytorch_build_number, - enable_mkldnn=not args.disable_mkldnn, - ) - if not args.keep_running: - print(f"Waiting for instance {inst.id} to terminate") - inst.terminate() - inst.wait_until_terminated() diff --git a/.ci/aarch64_linux/embed_library.py b/.ci/aarch64_linux/embed_library.py deleted file mode 100644 index 2834a4632989b..0000000000000 --- a/.ci/aarch64_linux/embed_library.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 - -import os -import shutil -import sys -from subprocess import check_call -from tempfile import TemporaryDirectory - -from auditwheel.elfutils import elf_file_filter -from auditwheel.lddtree import lddtree -from auditwheel.patcher import Patchelf -from auditwheel.repair import copylib -from auditwheel.wheeltools import InWheelCtx - - -def replace_tag(filename): - with open(filename) as f: - lines = f.read().split("\\n") - for i, line in enumerate(lines): - if not line.startswith("Tag: "): - continue - lines[i] = line.replace("-linux_", "-manylinux2014_") - print(f"Updated tag from {line} to {lines[i]}") - - with open(filename, "w") as f: - f.write("\\n".join(lines)) - - -class AlignedPatchelf(Patchelf): - def set_soname(self, file_name: str, new_soname: str) -> None: - check_call( - ["patchelf", "--page-size", "65536", "--set-soname", new_soname, file_name] - ) - - def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None: - check_call( - [ - "patchelf", - "--page-size", - "65536", - "--replace-needed", - soname, - new_soname, - file_name, - ] - ) - - -def embed_library(whl_path, lib_soname, update_tag=False): - patcher = AlignedPatchelf() - out_dir = TemporaryDirectory() - whl_name = os.path.basename(whl_path) - tmp_whl_name = os.path.join(out_dir.name, whl_name) - with InWheelCtx(whl_path) as ctx: - torchlib_path = os.path.join(ctx._tmpdir.name, "torch", "lib") - ctx.out_wheel = tmp_whl_name - new_lib_path, new_lib_soname = None, None - for filename, _ in elf_file_filter(ctx.iter_files()): - if not filename.startswith("torch/lib"): - continue - libtree = lddtree(filename) - if lib_soname not in libtree["needed"]: - continue - lib_path = libtree["libs"][lib_soname]["path"] - if lib_path is None: - print(f"Can't embed {lib_soname} as it could not be found") - break - if lib_path.startswith(torchlib_path): - continue - - if new_lib_path is None: - new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher) - patcher.replace_needed(filename, lib_soname, new_lib_soname) - print(f"Replacing {lib_soname} with {new_lib_soname} for {filename}") - if update_tag: - # Add manylinux2014 tag - for filename in ctx.iter_files(): - if os.path.basename(filename) != "WHEEL": - continue - replace_tag(filename) - shutil.move(tmp_whl_name, whl_path) - - -if __name__ == "__main__": - embed_library( - sys.argv[1], "libgomp.so.1", len(sys.argv) > 2 and sys.argv[2] == "--update-tag" - ) diff --git a/.ci/manywheel/build.sh b/.ci/manywheel/build.sh index 6b2a60bc5ca28..61beb47706b8f 100755 --- a/.ci/manywheel/build.sh +++ b/.ci/manywheel/build.sh @@ -5,13 +5,13 @@ set -ex SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" case "${GPU_ARCH_TYPE:-BLANK}" in - cuda) + cuda | cuda-aarch64) bash "${SCRIPTPATH}/build_cuda.sh" ;; rocm) bash "${SCRIPTPATH}/build_rocm.sh" ;; - cpu | cpu-cxx11-abi | cpu-s390x) + cpu | cpu-cxx11-abi | cpu-aarch64 | cpu-s390x) bash "${SCRIPTPATH}/build_cpu.sh" ;; xpu) diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index b84268fd12896..29dbc3822ed5c 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -18,12 +18,27 @@ retry () { $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) } +# Detect architecture first +ARCH=$(uname -m) +echo "Detected architecture: $ARCH" + PLATFORM="" # TODO move this into the Docker images OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release) if [[ "$OS_NAME" == *"AlmaLinux"* ]]; then retry yum install -q -y zip openssl - PLATFORM="manylinux_2_28_x86_64" + # Set platform based on architecture + case $ARCH in + x86_64) + PLATFORM="manylinux_2_28_x86_64" + ;; + aarch64) + PLATFORM="manylinux_2_28_aarch64" + ;; + *) + echo "Other architectures: $ARCH, not setting PLATFORM" + ;; + esac elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then retry dnf install -q -y zip openssl elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then @@ -38,6 +53,8 @@ else exit 1 fi +echo "Platform set to: $PLATFORM" + # We use the package name to test the package by passing this to 'pip install' # This is the env variable that setup.py uses to name the package. Note that # pip 'normalizes' the name first by changing all - to _ @@ -299,8 +316,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w # ROCm workaround for roctracer dlopens if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then patchedpath=$(fname_without_so_number $destpath) - # Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load - elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then + # Keep the so number for XPU dependencies, libgomp.so.1, ACL libraries, and NVPL libraries to avoid twice load + elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" || "$filename" == libarm_compute* || "$filename" == libnvpl* || "$filename" == "libgfortran.so.5" ]]; then patchedpath=$destpath else patchedpath=$(fname_with_sha256 $destpath) @@ -350,6 +367,10 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g') sed -i -e s#linux_x86_64#"${PLATFORM}"# $wheel_file; fi + if [[ $PLATFORM == "manylinux_2_28_aarch64" ]]; then + wheel_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/WHEEL/g') + sed -i -e s#linux_aarch64#"${PLATFORM}"# $wheel_file; + fi # regenerate the RECORD file with new hashes record_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/RECORD/g') diff --git a/.ci/manywheel/build_cpu.sh b/.ci/manywheel/build_cpu.sh index 9d982bd30e25a..9a6b14c0a5e37 100755 --- a/.ci/manywheel/build_cpu.sh +++ b/.ci/manywheel/build_cpu.sh @@ -15,6 +15,35 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then EXTRA_CAFFE2_CMAKE_FLAGS=() fi +# Detect architecture +ARCH=$(uname -m) +echo "Building CPU wheel for architecture: $ARCH" + +# Detect and configure OpenBLAS and ARM Compute Libraryfor CPU aarch64 +if [[ "$ARCH" == "aarch64" ]]; then + # Use OpenBLAS for BLAS/LAPACK on CPU aarch64 builds + if [[ ! -f "/opt/OpenBLAS/lib/libopenblas.so.0" ]]; then + echo "ERROR: OpenBLAS not found at /opt/OpenBLAS/lib/" + echo "OpenBLAS (BLAS/LAPACK) is required for CPU aarch64 builds" + exit 1 + fi + echo "Using OpenBLAS for CPU aarch64" + export BLAS=OpenBLAS + export OpenBLAS_HOME=/opt/OpenBLAS + + # ACL is required for aarch64 builds + if [[ ! -d "/acl" ]]; then + echo "ERROR: ARM Compute Library not found at /acl" + echo "ACL is required for aarch64 builds. Check Docker image setup." + exit 1 + fi + + export USE_MKLDNN=1 + export USE_MKLDNN_ACL=1 + export ACL_ROOT_DIR=/acl + echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl" +fi + WHEELHOUSE_DIR="wheelhousecpu" LIBTORCH_HOUSE_DIR="libtorch_housecpu" if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then @@ -34,8 +63,10 @@ elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then LIBGOMP_PATH="/usr/lib64/libgomp.so.1" elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then - if [[ "$(uname -m)" == "s390x" ]]; then + if [[ "$ARCH" == "s390x" ]]; then LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1" + elif [[ "$ARCH" == "aarch64" ]]; then + LIBGOMP_PATH="/usr/lib/aarch64-linux-gnu/libgomp.so.1" else LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1" fi @@ -49,6 +80,34 @@ DEPS_SONAME=( "libgomp.so.1" ) +# Add ARM-specific library dependencies for CPU builds +if [[ "$ARCH" == "aarch64" ]]; then + echo "Adding ARM-specific CPU library dependencies" + + # ARM Compute Library (if available) + if [[ -d "/acl/build" ]]; then + echo "Adding ARM Compute Library for CPU" + DEPS_LIST+=( + "/acl/build/libarm_compute.so" + "/acl/build/libarm_compute_graph.so" + ) + DEPS_SONAME+=( + "libarm_compute.so" + "libarm_compute_graph.so" + ) + fi + + # ARM system libraries + DEPS_LIST+=( + "/usr/lib64/libgfortran.so.5" + "/opt/OpenBLAS/lib/libopenblas.so.0" + ) + DEPS_SONAME+=( + "libgfortran.so.5" + "libopenblas.so.0" + ) +fi + rm -rf /usr/local/cuda* SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" diff --git a/.ci/manywheel/build_cuda.sh b/.ci/manywheel/build_cuda.sh index 2a822295e0361..b3258669b6f88 100644 --- a/.ci/manywheel/build_cuda.sh +++ b/.ci/manywheel/build_cuda.sh @@ -29,6 +29,35 @@ if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then EXTRA_CAFFE2_CMAKE_FLAGS=() fi +# Detect architecture +ARCH=$(uname -m) +echo "Building for architecture: $ARCH" + +# Detect and configure NVPL for BLAS/LAPACK and ARM Compute Library for CUDA aarch64 +if [[ "$ARCH" == "aarch64" ]]; then + # Use NVPL (NVIDIA Performance Libraries) for ARM + # NVPL provides optimized BLAS and LAPACK for better cpu performance on NVIDIA platforms + if [[ ! -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then + echo "ERROR: NVPL not found at /usr/local/lib/" + echo "NVPL (BLAS/LAPACK) is required for CUDA aarch64 builds" + exit 1 + fi + echo "Using NVPL BLAS/LAPACK for CUDA aarch64" + export BLAS=NVPL + + # ACL is required for aarch64 builds + if [[ ! -d "/acl" ]]; then + echo "ERROR: ARM Compute Library not found at /acl" + echo "ACL is required for aarch64 builds. Check Docker image setup." + exit 1 + fi + + export USE_MKLDNN=1 + export USE_MKLDNN_ACL=1 + export ACL_ROOT_DIR=/acl + echo "ARM Compute Library enabled for MKLDNN: ACL_ROOT_DIR=/acl" +fi + # Determine CUDA version and architectures to build for # # NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`, @@ -53,34 +82,60 @@ fi cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.') EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") +# Function to remove architectures from a list +remove_archs() { + local result="$1" + shift + for arch in "$@"; do + result="${result//${arch};/}" + done + echo "$result" +} + +# Function to filter CUDA architectures for aarch64 +# aarch64 ARM GPUs only support certain compute capabilities +# Keep: 8.0 (A100), 9.0+ (Hopper, Grace Hopper, newer) +# Remove: < 8.0 (no ARM GPUs), 8.6 (x86_64 RTX 3090/A6000 only) +filter_aarch64_archs() { + local arch_list="$1" + # Explicitly remove architectures not needed on aarch64 + arch_list=$(remove_archs "$arch_list" "5.0" "6.0" "7.0" "7.5" "8.6") + echo "$arch_list" +} + +# Base: Common architectures across all modern CUDA versions +TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0" + case ${CUDA_VERSION} in - #removing sm_50-sm_60 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases - #however we would like to keep sm_70 architecture see: https://github.com/pytorch/pytorch/issues/157517 - 12.8) - TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0" - ;; - 12.9) - TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0+PTX" - # WAR to resolve the ld error in libtorch build with CUDA 12.9 + 12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;${TORCH_CUDA_ARCH_LIST}" ;; # Only 12.6 includes Legacy Maxwell/Pascal that will be removed in future releases + 12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0" ;; # +Hopper/Blackwell support + 12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};10.0;12.0+PTX" # +Hopper/Blackwell support + PTX for forward compatibility if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then - TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX" + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//7.0;/}" # Remove 7.0 to resolve the ld error + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch fi ;; 13.0) - TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX" - ;; - 12.6) - TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0" - ;; - *) - echo "unknown cuda version $CUDA_VERSION" - exit 1 + TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX" + export TORCH_NVCC_FLAGS="-compress-mode=size" + export BUILD_BUNDLE_PTXAS=1 ;; + *) echo "unknown cuda version $CUDA_VERSION"; exit 1 ;; esac +# Filter for aarch64: Remove < 8.0 and 8.6 +[[ "$ARCH" == "aarch64" ]] && TORCH_CUDA_ARCH_LIST=$(filter_aarch64_archs "$TORCH_CUDA_ARCH_LIST") + +echo "TORCH_CUDA_ARCH_LIST set to: $TORCH_CUDA_ARCH_LIST" export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} echo "${TORCH_CUDA_ARCH_LIST}" +# Disable MAGMA for aarch64 as pre-built libraries are x86-64 only +if [[ "$ARCH" == "aarch64" ]]; then + echo "Disabling MAGMA for aarch64 architecture" + export USE_MAGMA=0 +fi + # Package directories WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot" LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot" @@ -244,6 +299,51 @@ else exit 1 fi +# Add ARM-specific library dependencies +if [[ "$ARCH" == "aarch64" ]]; then + echo "Adding ARM-specific library dependencies" + + # ARM Compute Library (if available) + if [[ -d "/acl/build" ]]; then + echo "Adding ARM Compute Library" + DEPS_LIST+=( + "/acl/build/libarm_compute.so" + "/acl/build/libarm_compute_graph.so" + ) + DEPS_SONAME+=( + "libarm_compute.so" + "libarm_compute_graph.so" + ) + fi + + # ARM system libraries + DEPS_LIST+=( + "/lib64/libgomp.so.1" + "/usr/lib64/libgfortran.so.5" + ) + DEPS_SONAME+=( + "libgomp.so.1" + "libgfortran.so.5" + ) + + # NVPL libraries (ARM optimized BLAS/LAPACK) + if [[ -d "/usr/local/lib" && -f "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ]]; then + echo "Adding NVPL libraries for ARM" + DEPS_LIST+=( + "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0" + "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" + "/usr/local/lib/libnvpl_lapack_core.so.0" + "/usr/local/lib/libnvpl_blas_core.so.0" + ) + DEPS_SONAME+=( + "libnvpl_lapack_lp64_gomp.so.0" + "libnvpl_blas_lp64_gomp.so.0" + "libnvpl_lapack_core.so.0" + "libnvpl_blas_core.so.0" + ) + fi +fi + # run_tests.sh requires DESIRED_CUDA to know what tests to exclude export DESIRED_CUDA="$cuda_version_nodot" @@ -251,9 +351,11 @@ export DESIRED_CUDA="$cuda_version_nodot" rm -rf /usr/local/cuda || true ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda -# Switch `/usr/local/magma` to the desired CUDA version -rm -rf /usr/local/magma || true -ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma +# Switch `/usr/local/magma` to the desired CUDA version (skip for aarch64) +if [[ "$ARCH" != "aarch64" ]]; then + rm -rf /usr/local/magma || true + ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma +fi export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130 export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0 diff --git a/.ci/pytorch/check_binary.sh b/.ci/pytorch/check_binary.sh index 0f632f8006c07..34a26a293ae44 100755 --- a/.ci/pytorch/check_binary.sh +++ b/.ci/pytorch/check_binary.sh @@ -237,7 +237,8 @@ if [[ "$OSTYPE" == "msys" ]]; then fi # Test that CUDA builds are setup correctly -if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'xpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRED_CUDA" != *"rocm"* && "$(uname -m)" != "s390x" ]]; then +# Skip CUDA hardware checks for aarch64 as they run on CPU-only runners +if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'xpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRED_CUDA" != *"rocm"* && "$(uname -m)" != "s390x" && "$(uname -m)" != "aarch64" ]]; then if [[ "$PACKAGE_TYPE" == 'libtorch' ]]; then build_and_run_example_cpp check-torch-cuda else @@ -276,7 +277,9 @@ fi # if cuda if [[ "$PACKAGE_TYPE" != 'libtorch' ]]; then pushd "$(dirname ${BASH_SOURCE[0]})/smoke_test" python -c "from smoke_test import test_linalg; test_linalg()" - if [[ "$DESIRED_CUDA" == *cuda* ]]; then + # Skip CUDA linalg test for aarch64 as they run on CPU-only runners + # TODO: Remove this once CUDA ARM runner is available + if [[ "$DESIRED_CUDA" == *cuda* && "$(uname -m)" != "aarch64" ]]; then python -c "from smoke_test import test_linalg; test_linalg('cuda')" fi popd diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index baff04967e3ae..1c4f88775b14a 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -97,7 +97,6 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - {%- if config["gpu_arch_type"] != "cuda-aarch64" %} !{{ config["build_name"] }}-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -220,7 +219,6 @@ jobs: - name: Teardown ROCm uses: ./.github/actions/teardown-rocm {%- endif %} - {%- endif %} {%- if branches == "nightly" %} !{{ upload.upload_binaries(config) }} diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index bfa035bc753b8..cb4cc738abaef 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -260,11 +260,8 @@ jobs: "${DOCKER_IMAGE}" ) docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh" - if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then - docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/aarch64_linux/aarch64_ci_build.sh" - else - docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh" - fi + # Unified build script for all architectures (x86_64, aarch64, s390x) + docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh" - name: Chown artifacts if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }} diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 6a22e14af09b7..ff5ad7e89f99b 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -68,6 +68,7 @@ jobs: build_environment: linux-aarch64-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cpu-aarch64-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -136,6 +137,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_10-cuda-aarch64-12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda-aarch64-12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda-aarch64-12_6 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -182,6 +208,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_10-cuda-aarch64-12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda-aarch64-12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: "12.8-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda-aarch64-12_8 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda-aarch64-12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -228,6 +279,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_10-cuda-aarch64-12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda-aarch64-12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: "12.9-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -274,6 +350,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_10-cuda-aarch64-13_0-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda-aarch64-13_0-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda-aarch64-13_0 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -317,6 +418,7 @@ jobs: build_environment: linux-aarch64-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cpu-aarch64-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -385,6 +487,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_11-cuda-aarch64-12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda-aarch64-12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda-aarch64-12_6 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -431,6 +558,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_11-cuda-aarch64-12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda-aarch64-12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: "12.8-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda-aarch64-12_8 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda-aarch64-12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -477,6 +629,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_11-cuda-aarch64-12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda-aarch64-12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: "12.9-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -523,6 +700,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_11-cuda-aarch64-13_0-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda-aarch64-13_0-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda-aarch64-13_0 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -566,6 +768,7 @@ jobs: build_environment: linux-aarch64-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cpu-aarch64-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -634,6 +837,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_12-cuda-aarch64-12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda-aarch64-12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda-aarch64-12_6 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -680,6 +908,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_12-cuda-aarch64-12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda-aarch64-12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: "12.8-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda-aarch64-12_8 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda-aarch64-12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -726,6 +979,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_12-cuda-aarch64-12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda-aarch64-12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: "12.9-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -772,6 +1050,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_12-cuda-aarch64-13_0-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda-aarch64-13_0-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda-aarch64-13_0 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -815,6 +1118,7 @@ jobs: build_environment: linux-aarch64-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cpu-aarch64-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -883,6 +1187,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13-cuda-aarch64-12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda-aarch64-12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda-aarch64-12_6 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -929,6 +1258,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13-cuda-aarch64-12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda-aarch64-12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: "12.8-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda-aarch64-12_8 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda-aarch64-12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -975,6 +1329,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13-cuda-aarch64-12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda-aarch64-12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: "12.9-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1021,6 +1400,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13-cuda-aarch64-13_0-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda-aarch64-13_0-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda-aarch64-13_0 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1064,6 +1468,7 @@ jobs: build_environment: linux-aarch64-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cpu-aarch64-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1132,6 +1537,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13t-cuda-aarch64-12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13t-cuda-aarch64-12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.13t" + build_name: manywheel-py3_13t-cuda-aarch64-12_6 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1178,6 +1608,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13t-cuda-aarch64-12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13t-cuda-aarch64-12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: "12.8-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.13t" + build_name: manywheel-py3_13t-cuda-aarch64-12_8 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda-aarch64-12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1224,6 +1679,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13t-cuda-aarch64-12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13t-cuda-aarch64-12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: "12.9-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.13t" + build_name: manywheel-py3_13t-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1270,6 +1750,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_13t-cuda-aarch64-13_0-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13t-cuda-aarch64-13_0-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DESIRED_PYTHON: "3.13t" + build_name: manywheel-py3_13t-cuda-aarch64-13_0 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1313,6 +1818,7 @@ jobs: build_environment: linux-aarch64-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cpu-aarch64-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1381,6 +1887,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14-cuda-aarch64-12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cuda-aarch64-12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda-aarch64-12_6 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1427,6 +1958,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14-cuda-aarch64-12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cuda-aarch64-12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: "12.8-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda-aarch64-12_8 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda-aarch64-12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1473,6 +2029,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14-cuda-aarch64-12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cuda-aarch64-12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: "12.9-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1519,6 +2100,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14-cuda-aarch64-13_0-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14-cuda-aarch64-13_0-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DESIRED_PYTHON: "3.14" + build_name: manywheel-py3_14-cuda-aarch64-13_0 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1562,6 +2168,7 @@ jobs: build_environment: linux-aarch64-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cpu-aarch64-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1630,6 +2237,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14t-cuda-aarch64-12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cuda-aarch64-12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda-aarch64-12_6 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1676,6 +2308,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14t-cuda-aarch64-12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cuda-aarch64-12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: "12.8-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda-aarch64-12_8 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda-aarch64-12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1722,6 +2379,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14t-cuda-aarch64-12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cuda-aarch64-12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: "12.9-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: @@ -1768,6 +2450,31 @@ jobs: timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_14t-cuda-aarch64-13_0-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_14t-cuda-aarch64-13_0-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DESIRED_PYTHON: "3.14t" + build_name: manywheel-py3_14t-cuda-aarch64-13_0 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: diff --git a/.github/workflows/generated-linux-binary-libtorch-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-nightly.yml index 446415807f204..e068d11ca5f1d 100644 --- a/.github/workflows/generated-linux-binary-libtorch-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-nightly.yml @@ -67,6 +67,7 @@ jobs: build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-cpu-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -133,6 +134,7 @@ jobs: build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-cuda12_6-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -201,6 +203,7 @@ jobs: build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-cuda12_8-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -269,6 +272,7 @@ jobs: build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-cuda12_9-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -337,6 +341,7 @@ jobs: build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-cuda13_0-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -406,6 +411,7 @@ jobs: build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-rocm7_0-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -524,6 +530,7 @@ jobs: build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-rocm7_1-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index a5f4e85ca58c1..ac04187e24d8c 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -66,6 +66,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -130,6 +131,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -196,6 +198,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -262,6 +265,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -328,6 +332,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -394,6 +399,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -509,6 +515,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -623,6 +630,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -732,6 +740,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -796,6 +805,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -862,6 +872,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -928,6 +939,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -994,6 +1006,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1060,6 +1073,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1175,6 +1189,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1289,6 +1304,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1398,6 +1414,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1462,6 +1479,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1528,6 +1546,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1594,6 +1613,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1660,6 +1680,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1726,6 +1747,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1841,6 +1863,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -1955,6 +1978,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2064,6 +2088,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2128,6 +2153,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2194,6 +2220,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2260,6 +2287,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2326,6 +2354,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2392,6 +2421,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2507,6 +2537,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2621,6 +2652,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2730,6 +2762,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2794,6 +2827,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2860,6 +2894,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2926,6 +2961,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2992,6 +3028,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3058,6 +3095,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3173,6 +3211,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3287,6 +3326,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3396,6 +3436,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3460,6 +3501,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3526,6 +3568,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3592,6 +3635,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3658,6 +3702,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3724,6 +3769,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3839,6 +3885,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3953,6 +4000,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4062,6 +4110,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4126,6 +4175,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4192,6 +4242,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4258,6 +4309,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4324,6 +4376,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4390,6 +4443,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4505,6 +4559,7 @@ jobs: build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4619,6 +4674,7 @@ jobs: PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 4a7ebe8366336..f9d668320ecb2 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -68,6 +68,7 @@ jobs: build_environment: linux-s390x-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cpu-s390x-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -132,6 +133,7 @@ jobs: build_environment: linux-s390x-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cpu-s390x-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -196,6 +198,7 @@ jobs: build_environment: linux-s390x-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cpu-s390x-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -260,6 +263,7 @@ jobs: build_environment: linux-s390x-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cpu-s390x-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -324,6 +328,7 @@ jobs: build_environment: linux-s390x-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cpu-s390x-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -388,6 +393,7 @@ jobs: build_environment: linux-s390x-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14-cpu-s390x-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -452,6 +458,7 @@ jobs: build_environment: linux-s390x-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_14t-cpu-s390x-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: diff --git a/.github/workflows/test-check-binary.yml b/.github/workflows/test-check-binary.yml index 5f0ad59d3a3bb..883b2d253aa8f 100644 --- a/.github/workflows/test-check-binary.yml +++ b/.github/workflows/test-check-binary.yml @@ -20,6 +20,8 @@ jobs: docker-image: python:3.11 docker-build-dir: "skip-docker-build" script: | + # Install dependencies FIRST (before torch) as torch imports may need them + pip install 'numpy>=1.21.2' 'protobuf>=3.20' 'typing-extensions>=4.8.0' pushd .ci/pytorch/ pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu DESIRED_PYTHON=3.11 DESIRED_CUDA=cpu PACKAGE_TYPE=manywheel ./check_binary.sh @@ -34,6 +36,8 @@ jobs: docker-image: python:3.11 docker-build-dir: "skip-docker-build" script: | + # Install dependencies FIRST (before torch) as torch imports may need them + pip install 'numpy>=1.21.2' 'protobuf>=3.20' 'typing-extensions>=4.8.0' STABLE_CUDA_VERSION=$(python3 .github/scripts/get_ci_variable.py --cuda-stable-version) CUDA_VERSION_NODOT=$(echo ${STABLE_CUDA_VERSION} | tr -d '.') pushd .ci/pytorch/ From 4ddca708b67b09c90c59bd1c16e1aa0d8c599ce1 Mon Sep 17 00:00:00 2001 From: manigkrish Date: Mon, 24 Nov 2025 20:03:31 +0000 Subject: [PATCH 015/338] Add human-readable type name comment to TYPE_MATCH guards for debugging clarity (#168272) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #168160 TYPE_MATCH guards currently generate code like: ___check_type_id(x, 94229757490048) The numeric type-id provides no information about the type being checked. This PR appends a human-readable `repr(type)` as a trailing comment: ___check_type_id(x, 94229757490048) # ### What This Change Does - Adds `repr(t)` to improve readability of guard output. - No behavior or semantics are changed — this is a debug-only improvement. ### Testing Verified that `repr(type)` produces readable, accurate names for built-in, user-defined, and torch.nn module types. Runtime behavior is unchanged; CI will validate everything end-to-end. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168272 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 --- torch/_dynamo/guards.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index cf621921cd59b..335323e638769 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1945,7 +1945,8 @@ def TYPE_MATCH(self, guard: Guard) -> None: guard._unserializable = True obj_id = self.id_ref(t, f"type({guard.name})") - code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})" + type_repr = repr(t) + code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}) # {type_repr}" self._set_guard_export_info(guard, [code]) self.get_guard_manager(guard).add_type_match_guard( From 43e23ee7f58e0e93537d29b059829e894c655d3a Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Fri, 21 Nov 2025 23:57:35 +0000 Subject: [PATCH 016/338] [3.14] Fix nn.Module annotations lookup (#168325) Found in https://github.com/pytorch/pytorch/pull/167407 but affects non-threaded builds as well Pull Request resolved: https://github.com/pytorch/pytorch/pull/168325 Approved by: https://github.com/williamwen42 --- tools/linter/adapters/import_linter.py | 1 + torch/_dynamo/functional_export.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tools/linter/adapters/import_linter.py b/tools/linter/adapters/import_linter.py index 69c5ecc19fa5c..1e1b6f79dffda 100644 --- a/tools/linter/adapters/import_linter.py +++ b/tools/linter/adapters/import_linter.py @@ -68,6 +68,7 @@ class LintMessage(NamedTuple): "torchrec", "numpy", "torch_xla", + "annotationlib", # added in python 3.14 ] ) diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 548a4b279b860..6eb2dcb59b7f3 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -1,5 +1,6 @@ import inspect import logging +import sys import traceback from collections import namedtuple from collections.abc import Callable @@ -651,7 +652,12 @@ def inner(*args: Any, **kwargs: Any) -> Any: graph_module._non_persistent_buffers_set = ( pyt.root._non_persistent_buffers_set.copy() ) - annotations = torch.nn.Module.__dict__.get("__annotations__", None) + if sys.version_info >= (3, 14): + import annotationlib # added in 3.14 + + annotations = annotationlib.get_annotations(torch.nn.Module) + else: + annotations = getattr(torch.nn.Module, "__annotations__", None) for name, value in pyt.root.__dict__.items(): if annotations and name not in annotations: graph_module.__dict__[name] = value From 9ff3bd7e8b36333f0948f3d70a7ce9df7b1158e1 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Mon, 24 Nov 2025 09:42:40 -0800 Subject: [PATCH 017/338] [precompile] set strict_autograd_cache=True when precompiling (#168989) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Right now we get a pretty hard to understand error message: ``` Traceback (most recent call last): File "/home/bobren/local/a/pytorch/spc.py", line 80, in .save_compiled_function(path) File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile.py", line 129, in save_compiled_function f.write(type(self).serialize(self)) File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile.py", line 145, in serialize type(compiled_fn).serialize_compile_artifacts(compiled_fn), File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile_types.py", line 54, in serialize_compile_artifacts def deserialize_compile_artifacts(cls, data: bytes) -> Any: TypeError: 'NoneType' object is not callable ``` which happens because cache is bypassed, so the "serialize" field on compiled_fn is set to None. after this PR we get a much more direct error message: ``` (/home/bobren/local/a/pytorch-env) [9:18] devgpu009:/home/bobren/local/a/pytorch [130] ❯ cache_tlp python spc.py Wrapped class is my_property: 123 Traceback (most recent call last): File "/home/bobren/local/a/pytorch/spc.py", line 79, in .aot_compile(((input_tensor,), {})) File "/home/bobren/local/a/pytorch/torch/_dynamo/eval_frame.py", line 806, in aot_compile return aot_compile_fullgraph( File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile.py", line 236, in aot_compile_fullgraph compiled_fn = backend( File "/home/bobren/local/a/pytorch/torch/__init__.py", line 2445, in __call__ return compile_fx(model_, inputs_, config_patches=self.config) File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2525, in compile_fx return _maybe_wrap_and_compile_fx_main( File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2602, in _maybe_wrap_and_compile_fx_main return _compile_fx_main( File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2797, in _compile_fx_main return aot_autograd( File "/home/bobren/local/a/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__ cg = aot_module_simplified(gm, example_inputs, **self.kwargs) File "/home/bobren/local/a/pytorch/torch/_functorch/aot_autograd.py", line 1097, in aot_module_simplified compiled_fn = AOTAutogradCache.try_load( File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 708, in try_load raise e File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 639, in try_load cache_key, debug_lines = autograd_cache_key( File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 499, in autograd_cache_key check_cacheable(gm) File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 292, in check_cacheable check_node_safe(node) File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 240, in check_node_safe raise BypassAOTAutogradCache( torch._functorch._aot_autograd.autograd_cache.BypassAOTAutogradCache: Unsupported call_function target tag_activation_checkpoint. Function module: torch.ops.higher_order, Function name: tag_activation_checkpoint ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168989 Approved by: https://github.com/jamesjwu --- torch/_dynamo/eval_frame.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 4253fa031d2ec..1075b6d66f7c8 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -801,14 +801,16 @@ def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: raise RuntimeError("aot compile requires a callable dynamo callback.") assert self._hooks is not None - return aot_compile_fullgraph( - fn, - example_inputs, - hooks=self._hooks, - backend=innermost_fn( - self.callback, unaltered_fn_attr="_torchdynamo_orig_backend" - ), - ) + + with torch._functorch.config.patch(strict_autograd_cache=True): + return aot_compile_fullgraph( + fn, + example_inputs, + hooks=self._hooks, + backend=innermost_fn( + self.callback, unaltered_fn_attr="_torchdynamo_orig_backend" + ), + ) # add context containing GraphModule to any GraphModule forward functions if isinstance(fn, GraphModule): From aa2cf78fb0c0a5137e289f13eebb9204f8d6ee0f Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sun, 23 Nov 2025 12:00:40 -0800 Subject: [PATCH 018/338] Refactor call_function (#168932) `call_function` is starting to get pretty long so pull the `nonstrict_traceable` portion out into a helper function before we make it even longer for #168890. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168932 Approved by: https://github.com/anijain2305 --- torch/_dynamo/variables/torch.py | 321 ++++++++++++++++--------------- 1 file changed, 165 insertions(+), 156 deletions(-) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 645a4e9595cc1..76da71f6fb323 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1435,162 +1435,7 @@ def call_function( from .builder import wrap_fx_proxy if self.nonstrict_traceable: - import torch._higher_order_ops.flat_apply as flat_apply - from torch._higher_order_ops.flat_apply import ( - func_to_graphable, - is_graphable_type, - ) - from torch._subclasses.fake_tensor import fake_tensor_tls - from torch.utils._pytree import tree_flatten - - from .base import AsPythonConstantNotImplementedError - - # 1. Convert `args, kwargs` into pytree-flattened proxy forms. - # - # Rather than reconstructing `args, kwargs` into python objects and - # then tree_flatten them, we just let Dynamo symbolically interpret - # `tree_flatten((args, kwargs))`. This saves us from having to - # worry about the reconstruction logic, side effects, and guards. - packed_input_vt = TupleVariable.build( - tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) - ) - out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type] - tx, [packed_input_vt], {} - ) - assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2 - flat_args_vts, input_spec_vt = out_vt.items - assert isinstance(flat_args_vts, ListVariable) - - # Handle the case when the input contains a non-graphable type. - for flat_arg_vt in flat_args_vts.items: - arg_type = flat_arg_vt.python_type() - if not is_graphable_type(arg_type): - type_name = flat_arg_vt.python_type().__qualname__ - unimplemented( - gb_type="Invalid input type for nonstrict_trace-ed function", - context=f"Encountered input of type <{type_name}>.", - explanation=( - "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) " - "or pytree containers of those are allowed as inputs. The provided argument contains " - "an unsupported type." - ), - hints=[ - "Use one of the following to register the type with pytree:\n" - "* `torch.utils._pytree.register_constant`\n" - "* `torch.utils._pytree.register_dataclass`\n" - "* `torch.utils._pytree.register_pytree_node`", - ], - ) - - # Since we checked with `is_graphable` above, `as_proxy` on the - # flat_arg VT should always work. - proxified_flat_args = [ - flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vts.items - ] - - # The downstream `flat_apply` call requires the input spec; however, - # the spec not a graphable type, so we still have to reconstruct it - # into a python object, and store it as a constant attribute on the - # fx graph. - try: - input_spec = input_spec_vt.as_python_constant() - except AsPythonConstantNotImplementedError as e: - typ = e.vt.python_type() - type_name = typ.__qualname__ - import torch.utils._pytree as pytree - - if pytree.is_constant_class(typ): - unimplemented( - gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region", - context=f"Input={input_spec_vt}, offending type <{type_name}>.", - explanation=( - "Calling a `nonstrict_trace`-ed function with an input that contains an object " - f"of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object " - "was constructed _inside_ the `torch.compile` region. This is not supported." - ), - hints=[ - "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.", - *graph_break_hints.SUPPORTABLE, - ], - from_exc=e, - ) - else: - unimplemented( - gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function", - context=f"Input={input_spec_vt}, offending type <{type_name}>.", - explanation=( - "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered " - f"with a `pytree_flatten` that places an object of type <{type_name}> into the context." - ), - hints=[ - "Modifying the `pytree_flatten` to avoid placing the object into the context.", - f"Apply one of the following to <{type_name}>:\n" - "* `torch.utils._pytree.register_constant`\n" - "* `torch.utils._pytree.register_dataclass`\n" - "* `torch.utils._pytree.register_pytree_node`", - *graph_break_hints.SUPPORTABLE, - ], - from_exc=e, - ) - - fn = self.value - - def patched_fn(*args, **kwargs): - # This enables reads to global/captured tensors, and we'll just - # treat them as constants in the graph. Note that after - # AOTDispatcher, this logic would disappear. - old_val = fake_tensor_tls.allow_non_fake_inputs_override - fake_tensor_tls.allow_non_fake_inputs_override = True - try: - res = fn(*args, **kwargs) - finally: # reset even when `fn` raises - fake_tensor_tls.allow_non_fake_inputs_override = old_val - return res - - # `flat_apply` wants a TreeSpec for the function input. - _, f_spec = func_to_graphable(patched_fn) - - # TreeSpec isn't graphable, so we register the function and input - # specs as attributes on the graph module. - f_spec_proxy = tx.output.register_static_attr_and_return_proxy( - f"{fn.__name__}_spec", f_spec - ) - input_spec_proxy = tx.output.register_static_attr_and_return_proxy( - fn.__name__ + "_input_spec", - # pyrefly: ignore [unbound-name] - input_spec, - ) - f_spec_proxy.node.type = type(f_spec) - # pyrefly: ignore [unbound-name] - input_spec_proxy.node.type = type(input_spec) - all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args) - - # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate - # the call and wrap output into a VariableTracker. - proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) - try: - # TODO support more output types once `flat_apply` supports - # pytree-able output types. We can have Dynamo trace through an - # unflatten call (just like we traced through a flatten above) - # to rebuild the actual output VT. - out_vt = wrap_fx_proxy(tx, proxy) - except ( - # From `handle_traced_output`. - torch._dynamo.exc.Unsupported, - # From `flat_apply` assert on output type. - torch._dynamo.exc.TorchRuntimeError, - ): - unimplemented( - gb_type="Unsupported output type for nonstrict_trace-ed function", - context=f"Function: {fn.__name__}", - explanation=( - "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list)" - " are allowed as output. The result of this call contains an unsupported type." - ), - hints=[*graph_break_hints.SUPPORTABLE], - ) - - return out_vt + return self._call_nonstrict_traceable_function(tx, args, kwargs) if self.torch_function_override_enabled(tx, args, kwargs): return dispatch_torch_function(tx, self, args, kwargs) @@ -1829,6 +1674,170 @@ def patched_fn(*args, **kwargs): return tensor_variable + def _call_nonstrict_traceable_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + import torch._higher_order_ops.flat_apply as flat_apply + from torch._higher_order_ops.flat_apply import ( + func_to_graphable, + is_graphable_type, + ) + from torch._subclasses.fake_tensor import fake_tensor_tls + from torch.utils._pytree import tree_flatten + + from .base import AsPythonConstantNotImplementedError + from .builder import wrap_fx_proxy + + # 1. Convert `args, kwargs` into pytree-flattened proxy forms. + # + # Rather than reconstructing `args, kwargs` into python objects and + # then tree_flatten them, we just let Dynamo symbolically interpret + # `tree_flatten((args, kwargs))`. This saves us from having to + # worry about the reconstruction logic, side effects, and guards. + packed_input_vt = TupleVariable.build( + tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) + ) + out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type] + tx, [packed_input_vt], {} + ) + assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2 + flat_args_vts, input_spec_vt = out_vt.items + assert isinstance(flat_args_vts, ListVariable) + + # Handle the case when the input contains a non-graphable type. + for flat_arg_vt in flat_args_vts.items: + arg_type = flat_arg_vt.python_type() + if not is_graphable_type(arg_type): + type_name = flat_arg_vt.python_type().__qualname__ + unimplemented( + gb_type="Invalid input type for nonstrict_trace-ed function", + context=f"Encountered input of type <{type_name}>.", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) " + "or pytree containers of those are allowed as inputs. The provided argument contains " + "an unsupported type." + ), + hints=[ + "Use one of the following to register the type with pytree:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + ], + ) + + # Since we checked with `is_graphable` above, `as_proxy` on the + # flat_arg VT should always work. + proxified_flat_args = [ + flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vts.items + ] + + # The downstream `flat_apply` call requires the input spec; however, + # the spec not a graphable type, so we still have to reconstruct it + # into a python object, and store it as a constant attribute on the + # fx graph. + try: + input_spec = input_spec_vt.as_python_constant() + except AsPythonConstantNotImplementedError as e: + typ = e.vt.python_type() + type_name = typ.__qualname__ + import torch.utils._pytree as pytree + + if pytree.is_constant_class(typ): + unimplemented( + gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function with an input that contains an object " + f"of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object " + "was constructed _inside_ the `torch.compile` region. This is not supported." + ), + hints=[ + "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, + ) + else: + unimplemented( + gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered " + f"with a `pytree_flatten` that places an object of type <{type_name}> into the context." + ), + hints=[ + "Modifying the `pytree_flatten` to avoid placing the object into the context.", + f"Apply one of the following to <{type_name}>:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, + ) + + fn = self.value + + def patched_fn(*args, **kwargs): + # This enables reads to global/captured tensors, and we'll just + # treat them as constants in the graph. Note that after + # AOTDispatcher, this logic would disappear. + old_val = fake_tensor_tls.allow_non_fake_inputs_override + fake_tensor_tls.allow_non_fake_inputs_override = True + try: + res = fn(*args, **kwargs) + finally: # reset even when `fn` raises + fake_tensor_tls.allow_non_fake_inputs_override = old_val + return res + + # `flat_apply` wants a TreeSpec for the function input. + _, f_spec = func_to_graphable(patched_fn) + + # TreeSpec isn't graphable, so we register the function and input + # specs as attributes on the graph module. + f_spec_proxy = tx.output.register_static_attr_and_return_proxy( + f"{fn.__name__}_spec", f_spec + ) + input_spec_proxy = tx.output.register_static_attr_and_return_proxy( + fn.__name__ + "_input_spec", + # pyrefly: ignore [unbound-name] + input_spec, + ) + f_spec_proxy.node.type = type(f_spec) + # pyrefly: ignore [unbound-name] + input_spec_proxy.node.type = type(input_spec) + all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args) + + # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate + # the call and wrap output into a VariableTracker. + proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) + try: + # TODO support more output types once `flat_apply` supports + # pytree-able output types. We can have Dynamo trace through an + # unflatten call (just like we traced through a flatten above) + # to rebuild the actual output VT. + out_vt = wrap_fx_proxy(tx, proxy) + except ( + # From `handle_traced_output`. + torch._dynamo.exc.Unsupported, + # From `flat_apply` assert on output type. + torch._dynamo.exc.TorchRuntimeError, + ): + unimplemented( + gb_type="Unsupported output type for nonstrict_trace-ed function", + context=f"Function: {fn.__name__}", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list)" + " are allowed as output. The result of this call contains an unsupported type." + ), + hints=[*graph_break_hints.SUPPORTABLE], + ) + + return out_vt + def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" if self.value is torch.nn.modules.utils._ntuple: From 40733d7891ca077a6f1e4121a8d94c7fc8bf7f14 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 24 Nov 2025 21:11:39 +0000 Subject: [PATCH 019/338] Revert "[DebugMode] wait before hashing collectives by default (#168119)" This reverts commit c56655268b4ae575ee4c89c312fd93ca2f5b3ba9. Reverted https://github.com/pytorch/pytorch/pull/168119 on behalf of https://github.com/yushangdi due to This PR caused DebugMode to hang/segfault sometimes. See repro in P2054777054 ([comment](https://github.com/pytorch/pytorch/pull/168119#issuecomment-3572738530)) --- .../tensor/debug/test_debug_mode.py | 26 ------------------- torch/utils/_debug_mode.py | 11 +------- 2 files changed, 1 insertion(+), 36 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index c0625d37c6dad..c8cc5930d4e67 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -576,32 +576,6 @@ def test_check_structure_mismatches(self): with self.assertRaisesRegex(ValueError, "Log lengths don't match"): DebugMode.check_hash_mismatches(dm1.logs, dm3.logs) - @unittest.skipIf( - not torch.cuda.is_available() - or torch.cuda.get_device_properties(0).total_memory < 2**26, - "Being conservative, test peak memory is 25MB?", - ) - def test_tensor_hash_waits_on_collective(self): - # test that hashing collectives gives correct results - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - - local_tensor = torch.ones(2**18, device=self.device_type) - dt = DTensor.from_local(local_tensor, mesh, [Shard(0)], run_check=False) - - with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(): - dt.redistribute(mesh, [Replicate()]) - - # Find all_gather hash - all_gather_logs = [ - op - for op in debug_mode.logs - if isinstance(op, _OpCall) - and op.op == torch.ops._c10d_functional.all_gather_into_tensor.default - ] - self.assertEqual(len(all_gather_logs), 1) - actual_hash = all_gather_logs[0].log["hash"] - self.assertEqual(actual_hash, float(local_tensor.numel() * self.world_size)) - def test_pretty_print_dtensor_make_fx(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 0b853997261a9..745b05d1904d7 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -924,9 +924,7 @@ def dispatch_hook(func, types, args, kwargs, result): @staticmethod @contextlib.contextmanager def log_tensor_hashes( - hash_fn: Union[Callable, str, list[str]] = "norm", - hash_inputs: bool = False, - wait_on_collectives: bool = True, + hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False ): """ Installs hook for tensor hash logging. @@ -938,7 +936,6 @@ def log_tensor_hashes( - "hash_tensor": uses torch.hash_tensor (XOR sum reduction) - List of strings: returns tuple of hashes from above options hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash". - wait_on_collectives: if True (default), waits on async collective Work handles before hashing. NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes. """ @@ -969,12 +966,6 @@ def _dispatch_hash_hook(func, types, args, kwargs, result): if "empty" in str(func) or "profiler" in str(func): return None - # Wait on async collective Work handles before hashing - if wait_on_collectives and isinstance(result, (tuple, list)): - for item in result: - if isinstance(item, torch.ScriptObject) and hasattr(item, "wait"): - item.wait() - out = {} out["hash"] = _tree_hash(result) if hash_inputs: From f6572c3014fb30bb973176f4c203cf2255aaf395 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 24 Nov 2025 12:02:36 -0800 Subject: [PATCH 020/338] [BE] Delete Pytorch-circleci-labels (#169003) Last change to this file was back in 2021, and last CircleCI job was wound down probably in 2022, so it's safe to assume it's unsued Pull Request resolved: https://github.com/pytorch/pytorch/pull/169003 Approved by: https://github.com/huydhn --- .github/pytorch-circleci-labels.yml | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 .github/pytorch-circleci-labels.yml diff --git a/.github/pytorch-circleci-labels.yml b/.github/pytorch-circleci-labels.yml deleted file mode 100644 index 6990a3d304b24..0000000000000 --- a/.github/pytorch-circleci-labels.yml +++ /dev/null @@ -1,21 +0,0 @@ -# For documentation concerning this configuration please refer to, -# https://github.com/pytorch/pytorch-probot#trigger-circleci-workflows -labels_to_circle_params: - ci/binaries: - parameter: run_binary_tests - default_true_on: - branches: - - nightly - - release/.* - tags: - - v[0-9]+(\.[0-9]+)*-rc[0-9]+ - set_to_false: - - run_build - ci/master: - parameter: run_master_build - set_to_false: - - run_build - ci/slow-gradcheck: - parameter: run_slow_gradcheck_build - set_to_false: - - run_build From 93fef4bd1dd265588863929e35d9ac89328d5695 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Fri, 21 Nov 2025 15:57:26 -0800 Subject: [PATCH 021/338] [inductor] fix picking wrong contiguous node (#168371) We may pick wrong contiguous node in mix-order reduction fusion due to dynamic shapes. Differential Revision: [D87788131](https://our.internmc.facebook.com/intern/diff/D87788131) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168371 Approved by: https://github.com/PaulZhang12 --- test/inductor/test_mix_order_reduction.py | 48 ++++++++++++++++++++++- torch/_inductor/scheduler.py | 4 +- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index cae48673f2332..7eee686ea99b0 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -382,7 +382,51 @@ def fwd_bwd(f): metrics.codegen_mix_order_reduction, ) - def test_layer_norm_bwd_with_dynamic_shape(self): + @parametrize("dynamic_dims", ([0], [1], [0, 1])) + def test_rms_norm_bwd_with_dynamic_shape(self, dynamic_dims): + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x, w, eps): + return F.rms_norm(x, x.shape[-1:], weight=w, eps=eps) + + def fwd_bwd(f): + x.grad = None + w.grad = None + out = f(x, w, eps) + out.backward(dy) + return x.grad, w.grad + + M0, M1, N = 251, 223, 128 + wbdtype = torch.float + xdtype = torch.float + x = torch.randn(M0, M1, N, dtype=xdtype, device=GPU_TYPE, requires_grad=True) + torch._dynamo.mark_dynamic(x, (0, 1)) + w = torch.randn(N, dtype=wbdtype, device=GPU_TYPE, requires_grad=True) + dy = torch.randn_like(x) + eps = 1e-5 + + opt_f = torch.compile( + f, + options={ + "split_reductions": False, + }, + ) + + ref = fwd_bwd(f) + act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f) + + self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}") + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) + + @parametrize("dynamic_dims", ([0], [1], [0, 1])) + def test_layer_norm_bwd_with_dynamic_shape(self, dynamic_dims): + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + def f(x, w, eps): return F.layer_norm(x, x.shape[-1:], weight=w, bias=None, eps=eps) @@ -397,7 +441,7 @@ def fwd_bwd(f): wbdtype = torch.float xdtype = torch.float x = torch.randn(M0, M1, N, dtype=xdtype, device=GPU_TYPE, requires_grad=True) - torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(x, dynamic_dims) w = torch.randn(N, dtype=wbdtype, device=GPU_TYPE, requires_grad=True) dy = torch.randn_like(x) eps = 1e-5 diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index e5bd34ea977e7..b084612b9acc7 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -273,7 +273,9 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: return False contiguous_node, other_node = ( - (node1, node2) if g1[1] == ncol else (node2, node1) + (node1, node2) + if V.graph.sizevars.evaluate_expr(sympy.Eq(g1[1], ncol)) + else (node2, node1) ) # We previously only check the contiguous_node has contiguous From 83bb24cdaf7d10567be01ff029fec223efcfbd42 Mon Sep 17 00:00:00 2001 From: ruisizhang123 Date: Mon, 24 Nov 2025 23:25:55 +0000 Subject: [PATCH 022/338] [simplefsdp] fix dsv3 estimation (#168199) As titled, there are comm size estimation regression after this PR: https://github.com/pytorch/pytorch/pull/167852/, which cause DSV3 dynamic shape estimation error: https://github.com/pytorch/torchtitan/issues/2037. Also added dynamic shape comm estimation test cases in the PR cc. @eellison @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/168199 Approved by: https://github.com/laithsakka --- test/distributed/test_inductor_collectives.py | 266 ++++++++++++++++++ torch/_inductor/comm_analysis.py | 33 ++- 2 files changed, 285 insertions(+), 14 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 33bf475b91460..fdf03fdf3a1f5 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -14,6 +14,7 @@ # for some reason importing functional collectives after dynamo breaks collectives handling! import torch.distributed._functional_collectives as _functional_collectives +from torch import nn from torch._C import FileCheck from torch._dynamo.testing import CompileCounter from torch._dynamo.utils import same @@ -2217,10 +2218,80 @@ def func(inp, group_size, group_name): ag_1_wait = torch.ops.c10d_functional.wait_tensor(ag_1_out) return ag_1_wait + # test for static shape input estimation gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name) g = gm.graph for n in g.nodes: if is_all_gather_into_tensor(n): + assert str(n.meta["val"].size()) in [ + "torch.Size([8, 4])", + "torch.Size([16, 4])", + ] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + + # test for unbacked dynamic shape input estimation + class TestModule(nn.Module): + def __init__(self, group_size, group_name): + super().__init__() + self.group_size = group_size + self.group_name = group_name + + def forward(self, x): + u = x.item() + # Use u as a dimension of a new tensor: + y = torch.empty(u, 4, device=x.device) + return func(y, self.group_size, self.group_name) + + inp = torch.tensor(1, device=self.device) + model = TestModule(group_size, group_name).to(self.device) + exported_program = torch.export.export( + model, + (inp,), + ) + gm = exported_program.module() + g = gm.graph + for n in g.nodes: + if is_all_gather_into_tensor(n): + assert str(n.meta["val"].size()) in [ + "torch.Size([2*u0, 4])", + "torch.Size([4*u0, 4])", + ] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + + # test for backed dynamic shape input estimation + inp = torch.ones(4, 4, device=self.device) + torch._dynamo.mark_dynamic(inp, 0, min=1, max=100) + gm = make_fx(func, tracing_mode="symbolic")(inp, group_size, group_name) + g = gm.graph + for n in g.nodes: + if is_all_gather_into_tensor(n): + assert str(n.meta["val"].size()) in [ + "torch.Size([16, 4])", + "torch.Size([2*s75, s75])", + "torch.Size([4*s75, s75])", + ] from torch._inductor.comm_analysis import ( estimate_nccl_collective_runtime_from_fx_node, ) @@ -2259,10 +2330,79 @@ def func(inp, group_size, group_name): rs_1_wait = torch.ops.c10d_functional.wait_tensor(rs_1_out) return rs_1_wait + # test for static shape input estimation gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name) g = gm.graph for n in g.nodes: if is_reduce_scatter_tensor(n): + assert str(n.meta["val"].size()) in [ + "torch.Size([1, 4])", + "torch.Size([2, 4])", + ] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + + # test for unbacked dynamic shape input estimation + class TestModule(nn.Module): + def __init__(self, group_size, group_name): + super().__init__() + self.group_size = group_size + self.group_name = group_name + + def forward(self, x): + u = x.item() + # Use u as a dimension of a new tensor: + y = torch.empty(u, 4, device=x.device) + return func(y, self.group_size, self.group_name) + + inp = torch.tensor(1, device=self.device) + model = TestModule(group_size, group_name).to(self.device) + exported_program = torch.export.export( + model, + (inp,), + ) + gm = exported_program.module() + g = gm.graph + for n in g.nodes: + if is_reduce_scatter_tensor(n): + assert str(n.meta["val"].size()) in [ + "torch.Size([(u0//2), 4])", + "torch.Size([(u0//4), 4])", + ] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + + # test for backed dynamic shape input estimation + inp = torch.ones(4, 4, device=self.device) + torch._dynamo.mark_dynamic(inp, 0, min=1, max=100) + gm = make_fx(func, tracing_mode="symbolic")(inp, group_size, group_name) + g = gm.graph + for n in g.nodes: + if is_reduce_scatter_tensor(n): + assert str(n.meta["val"].size()) in [ + "torch.Size([(s75//2), s75])", + "torch.Size([(s75//4), s75])", + ] from torch._inductor.comm_analysis import ( estimate_nccl_collective_runtime_from_fx_node, ) @@ -2299,10 +2439,70 @@ def func(inp, group_size, group_name): ar_1_wait = torch.ops.c10d_functional.wait_tensor(ar_1_out) return ar_1_wait + # test for static shape input estimation gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name) g = gm.graph for n in g.nodes: if is_all_reduce_tensor(n): + assert str(n.meta["val"].size()) in ["torch.Size([4, 4])"] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + + # test for unbacked dynamic shape input estimation + class TestModule(nn.Module): + def __init__(self, group_size, group_name): + super().__init__() + self.group_size = group_size + self.group_name = group_name + + def forward(self, x): + u = x.item() + # Use u as a dimension of a new tensor: + y = torch.empty(u, 4, device=x.device) + return func(y, self.group_size, self.group_name) + + inp = torch.tensor(1, device=self.device) + model = TestModule(group_size, group_name).to(self.device) + exported_program = torch.export.export( + model, + (inp,), + ) + gm = exported_program.module() + g = gm.graph + for n in g.nodes: + if is_all_reduce_tensor(n): + assert str(n.meta["val"].size()) in ["torch.Size([u0, 4])"] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + + # test for backed dynamic shape input estimation + inp = torch.ones(4, 4, device=self.device) + torch._dynamo.mark_dynamic(inp, 0, min=1, max=100) + gm = make_fx(func, tracing_mode="symbolic")(inp, group_size, group_name) + g = gm.graph + for n in g.nodes: + if is_all_reduce_tensor(n): + assert str(n.meta["val"].size()) in ["torch.Size([s75, s75])"] from torch._inductor.comm_analysis import ( estimate_nccl_collective_runtime_from_fx_node, ) @@ -2349,12 +2549,14 @@ def func(inp, group_size, group_name): a2a_1_wait = torch.ops.c10d_functional.wait_tensor(a2a_1_out) return a2a_1_wait + # test for static shape input estimation gm = make_fx(func)( torch.ones(group_size * 4, 1, device=self.device), group_size, group_name ) g = gm.graph for n in g.nodes: if is_all_to_all_tensor(n): + assert str(n.meta["val"].size()) in ["torch.Size([8, 1])"] from torch._inductor.comm_analysis import ( estimate_nccl_collective_runtime_from_fx_node, ) @@ -2368,6 +2570,70 @@ def func(inp, group_size, group_name): ) assert est_ms_nccl > 0 + # test for unbacked dynamic shape input estimation + class TestModule(nn.Module): + def __init__(self, group_size, group_name): + super().__init__() + self.group_size = group_size + self.group_name = group_name + + def forward(self, x): + u = x.item() + # Use u as a dimension of a new tensor: + y = torch.empty(u, 4, device=x.device) + return func(y, self.group_size, self.group_name) + + inp = torch.tensor(1, device=self.device) + model = TestModule(group_size, group_name).to(self.device) + exported_program = torch.export.export( + model, + (inp,), + ) + gm = exported_program.module() + g = gm.graph + for n in g.nodes: + if is_all_to_all_tensor(n): + assert str(n.meta["val"].size()) in ["torch.Size([4*u0, 4])"] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + # TODO(ruisizhang123): Currently, NCCL estimation API does not support kwargs input + # (input_split_sizes & output_split_sizes in all-to-all) with dynamic shapes. + # est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + # n, use_nccl_estimator=True + # ) + # assert est_ms_nccl > 0 + + # test for backed dynamic shape input estimation + inp = torch.ones(4, 4, device=self.device) + torch._dynamo.mark_dynamic(inp, 0, min=1, max=100) + gm = make_fx(func, tracing_mode="symbolic")(inp, group_size, group_name) + g = gm.graph + for n in g.nodes: + if is_all_to_all_tensor(n): + assert str(n.meta["val"].size()) in [ + "torch.Size([2*(((s75**2)//2)), s75])" + ] + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + # TODO(ruisizhang123): Currently, NCCL estimation API does not support kwargs input + # (input_split_sizes & output_split_sizes in all-to-all) with dynamic shapes. + # est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + # n, use_nccl_estimator=True + # ) + # assert est_ms_nccl > 0 + @skip_if_lt_x_gpu(2) @requires_gloo() def test_regression_use_nccl_estimate_with_gloo(self): diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 55279f393d3aa..5b174414a67b6 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -1,6 +1,7 @@ import functools import logging import math +import operator from enum import IntEnum from typing import Any, Optional @@ -8,6 +9,7 @@ import torch import torch.utils._pytree as pytree +from torch.fx.experimental.symbolic_shapes import hint_int from torch.fx.operator_schemas import normalize_function from . import ir @@ -69,18 +71,23 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL: return get_collective_type_from_kernel_name(name) -def get_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int: +def get_ir_node_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int: numel = sympy_product(size) if isinstance(numel, sympy.Integer): return int(numel) - return V.graph.sizevars.size_hint(numel, fallback=fallback) +def get_fx_node_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int: + numel = functools.reduce(operator.mul, size, 1) + result = hint_int(numel, fallback=fallback) + return result + + def get_collective_input_size_bytes(node: ir.IRNode) -> int: sz_bytes = 0 for inp in node.inputs: # type: ignore[attr-defined] - numel = get_size_numel(inp.layout.size) + numel = get_ir_node_size_numel(inp.layout.size) sz_bytes += numel * get_dtype_size(inp.layout.dtype) return sz_bytes @@ -350,18 +357,18 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: # dont double count pre-allocated buffer passed in kwargs.pop("out", None) - def tensor_bytes(t) -> int: - return get_size_numel(t.size()) * get_dtype_size(t.dtype) + def tensor_bytes(t: torch.Tensor) -> int: + return get_fx_node_size_numel(t.size()) * get_dtype_size(t.dtype) def add_inp_bytes(inp: torch.fx.Node): - t = inp.meta.get("val", None) - if t is None: + inp_val = inp.meta.get("val", None) + if not isinstance(inp_val, torch.Tensor): return nonlocal input_bytes if input_bytes is None: input_bytes = 0 - input_bytes += tensor_bytes(t) + input_bytes += tensor_bytes(inp_val) pytree.tree_map_only( torch.fx.Node, @@ -369,14 +376,12 @@ def add_inp_bytes(inp: torch.fx.Node): (args, kwargs), ) - output_tensor = fx_node.meta.get("val", None) + output_val = fx_node.meta.get("val", None) - if input_bytes is None or output_tensor is None: + if input_bytes is None or not isinstance(output_val, torch.Tensor): return 0 - output_bytes = ( - get_size_numel(output_tensor.size()) * output_tensor.element_size() - ) # pyre-ignore + output_bytes = tensor_bytes(output_val) return input_bytes + output_bytes @@ -467,7 +472,7 @@ def to_real_tensor(e: Any) -> Any: if isinstance(e, torch.fx.Node): return to_real_tensor(e.meta["val"]) if isinstance(e, torch.Tensor): - return _tensor([get_size_numel(e.size())], e.dtype, e.device) + return _tensor([get_fx_node_size_numel(e.size())], e.dtype, e.device) return e flat_args = [to_real_tensor(a) for a in flat_args] From 9708e048ac15d390da3d2c4c87774ee38beb304c Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 21 Nov 2025 11:23:32 -0800 Subject: [PATCH 023/338] [hoo] Invoke subgraph + effect (#167231) This PR adds support for effectful ops within invoke_subgraphs. * Most of the logic is in `invoke_subgraph.py_functionalize_impl`. * In the functionalization metadata collection phase, we note the tokens before going further down the dispatcher, and then note the tokens after coming back from the dispatcher. If there are nodes in the invoke_subgraph subgraph that contain effects, the number of effects should change, or the tokens used for an effect should. * We will store this effect difference in the `InvokeSubgraphCache` where the key is the identifier and value is the effect. For now we only support one effect within a subgraph. * During the tracing part of AOTAutograd, we will then wrap the subgraph to take in and output a token. Before: ``` def forward(self, x): repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', x) return invoke_subgraph def repeated_subgraph(self, x): record_memory = torch.ops.mylib.record_memory.default("forward", "N") add = torch.ops.aten.add(x, x) return add ``` After: ``` def forward(self, token, x): repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x) getitem = invoke_subgraph[0] # output token getitem_1 = invoke_subgraph[1] return (getitem, getitem_1) def repeated_subgraph(self, token, x): with_effects = torch.ops.higher_order.with_effects(token, torch.ops.mylib.record_memory.default, 'forward', 'N') getitem = with_effects[0] # output token add = torch.ops.aten.add(x, x) return (getitem, add) ``` * Then there is a bunch of logic within `_remove_effect_tokens` to handle removing the effects from the invoke_subgraph subgraph Differential Revision: [D87392741](https://our.internmc.facebook.com/intern/diff/D87392741) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167231 Approved by: https://github.com/anijain2305 --- test/export/test_converter.py | 2 +- test/export/test_passes.py | 15 +- test/export/test_torchbind.py | 12 +- test/higher_order_ops/test_with_effects.py | 100 ++++++++ torch/_guards.py | 18 ++ torch/_higher_order_ops/invoke_subgraph.py | 50 ++++ torch/export/_remove_effect_tokens_pass.py | 265 ++++++++++++--------- torch/export/_unlift.py | 24 +- 8 files changed, 352 insertions(+), 134 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index e739e5c346677..5b608503a1168 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1405,7 +1405,7 @@ def func3(x): # noqa: F841 ) # qnnpack not supported on s390x @xfailIfS390X - def test_ts2ep_convert_quantized_model(self): + def test_ts2ep_convert_quantized_model1(self): class Standalone(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 9cf442c27a2bb..866eeaaee3986 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -640,16 +640,13 @@ def forward(self, x): self.assertExpectedInline( without_token_ep.graph_module.code.strip(), """\ -def forward(self, token, obj_attr, x): - with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None - getitem = with_effects[0] - getitem_1 = with_effects[1] - getitem_2 = with_effects[2]; with_effects = None +def forward(self, obj_attr, x): + takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None + getitem_1 = takes_foo_tuple_return_default[0] + getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None - with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None - getitem_3 = with_effects_1[0] - getitem_4 = with_effects_1[1]; with_effects_1 = None - return (getitem_3, getitem_4)""", # noqa: B950 + takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None + return (takes_foo_default,)""", # noqa: B950 ) def test_fakify_script_objects(self): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 246122433e06c..adf0986811648 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -461,9 +461,9 @@ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr _guards_fn = self._guards_fn(x); _guards_fn = None - takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) - takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None - add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None + takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) + takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None + add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950 ) self.assertExpectedInline( @@ -1087,10 +1087,12 @@ def forward(self, token, tq, x): str(ep.graph_module.graph).strip(), """\ graph(): + %token : [num_users=1] = placeholder[target=token] %tq : [num_users=2] = placeholder[target=tq] %x : [num_users=1] = placeholder[target=x] - %queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {}) - return (tq,)""", # noqa: B950 + %with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {}) + %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {}) + return (getitem, tq)""", # noqa: B950 ) def test_deepcopy(self): diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index 2c4cf02bc1c8a..ce0d0eb5dc3ef 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -26,6 +26,7 @@ ) from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.node import has_side_effect from torch.testing import FileCheck from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater from torch.testing._internal.common_quantization import skipIfNoDynamoSupport @@ -870,6 +871,105 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token): finally: handle.destroy() + @unittest.skipIf(not TEST_CUDA, "triton") + def test_export_invoke_subgraph(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + recorded_list = [] + + @torch.library.custom_op("mylib::record_memory", mutates_args=()) + def record_memory(prefix: str, module_name: str) -> None: + torch.cuda.synchronize() + mem_alloc = torch.cuda.memory_allocated() / 1024**2 + mem_reserved = torch.cuda.memory_reserved() / 1024**2 + memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB" + recorded_list.append(memory_str) + + @record_memory.register_fake + def record_memory_fake(prefix, module_name): + return + + record_memory.register_effect(_EffectType.ORDERED) + has_side_effect(torch.ops.mylib.record_memory.default) + + class N(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(1024, 1024) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(1024, 1024) + + @torch.compiler.nested_compile_region + def forward(self, x): + torch.ops.mylib.record_memory("forward", "N") + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod_list = torch.nn.ModuleList(N() for _ in range(3)) + + def forward(self, x): + for m in self.mod_list: + x = m(x) + torch.ops.mylib.record_memory("forward", "N") + return (x,) + + model = M().to("cuda") + torch.cuda.reset_peak_memory_stats() + + x = torch.randn(32, 1024, requires_grad=True, device="cuda") + + ep = torch.export.export(model, (x,)) + ep = ep.run_decompositions() + self.assertEqual(len(list(ep.graph_module.named_modules())), 2) + + self.assertExpectedInline( + ep.graph_module.code.strip(), + """\ +def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x): + repeated_subgraph0 = self.repeated_subgraph0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None + getitem = invoke_subgraph[0] + getitem_1 = invoke_subgraph[1]; invoke_subgraph = None + repeated_subgraph0_1 = self.repeated_subgraph0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None + getitem_2 = invoke_subgraph_1[0] + getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None + repeated_subgraph0_2 = self.repeated_subgraph0 + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None + getitem_4 = invoke_subgraph_2[0] + getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None + with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None + getitem_6 = with_effects[0]; with_effects = None + return (getitem_6, getitem_5)""", + ) + + self.assertExpectedInline( + ep.graph_module.repeated_subgraph0.code.strip(), + """\ +def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): + with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None + getitem = with_effects[0]; with_effects = None + permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None + addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None + relu = torch.ops.aten.relu.default(addmm); addmm = None + permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None + addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None + return (getitem, addmm_1)""", + ) + + recorded_list.clear() + # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph + # eagerly twice. Once for get_output_metadata and then once for + # InvokeSubgraphAutogradOp. This causes record_memory to be called twice. + with torch.no_grad(): + out2 = ep.module()(x) + self.assertEqual(len(recorded_list), 4) + self.assertTrue(torch.allclose(model(x)[0], out2[0])) + if __name__ == "__main__": run_tests() diff --git a/torch/_guards.py b/torch/_guards.py index 32b796d71eea7..1bd32fc7f08ec 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -713,6 +713,9 @@ def __init__(self) -> None: self.lazy_bwd_cache: dict[ str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] ] = defaultdict(dict) + self.effects_cache: dict[ + str, set + ] = {} # Maps identifier -> set of effect types def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: self.dynamo_installed_submodules[fn_id].append(identifier) @@ -751,6 +754,21 @@ def get_lazy_bwd_entry( return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None)) + def add_effects(self, identifier: str, effects: set) -> None: + """Store the effect types for a given invoke_subgraph identifier.""" + if prev_effects := self.effects_cache.get(identifier, None): + assert effects == prev_effects, ( + "Different number of effects were found for invoke_subgraph " + f"call with identifier {identifier}. \n" + f"Previously we had the following effects: {prev_effects}.\n" + f"But now we have: {effects}." + ) + self.effects_cache[identifier] = effects + + def get_effects(self, identifier: str) -> Optional[set]: + """Retrieve the effect types for a given invoke_subgraph identifier.""" + return self.effects_cache.get(identifier, None) + class HopDispatchSetCache: def __init__(self) -> None: diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index e22b741631d3f..7d066e132e011 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -80,6 +80,7 @@ def __call__( assert all( isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator)) for o in operands + if o is not None ), ( f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}" ) @@ -562,7 +563,34 @@ def _(ctx, subgraph, identifier, *operands): do_auto_functionalize_v2, ) + # (in the functionalization metadata phase) Capture tokens before + tokens_before = dict(ctx.mode._tokens) + + # Check if this subgraph has effects stored in the cache + invoke_subgraph_cache = get_invoke_subgraph_cache() + effects = None + if invoke_subgraph_cache: + effects = invoke_subgraph_cache.get_effects(identifier) + + if effects: + assert len(effects) == 1, "Multiple effects within a subgraph NYI" + tokens = ctx.mode._tokens + effects = next(iter(effects)) + token_input = tokens[effects] + + operands = (token_input, *operands) + + def wrap_subgraph(subgraph): + def wrapped_subgraph(token, *args): + res = subgraph(*args) + return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res + + return wrapped_subgraph + + subgraph = wrap_subgraph(subgraph) + unwrapped_operands = ctx.unwrap_tensors(operands) + hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands) if can_auto_functionalize(hop_instance): # NOTE: [auto_functionalize x invoke_subgraph caching] @@ -587,6 +615,28 @@ def _(ctx, subgraph, identifier, *operands): # of invoke_subgraph ops if input aliasing/mutation is detected. functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph) out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands) + + if effects: + (new_token, *out) = out + ctx.mode._tokens[effects] = new_token + + # (in the functionalization metadata phase) Capture tokens after and see if + # there are any differences (there are new effects or the token value for an + # effect type has changed) + tokens_after = dict(ctx.mode._tokens) + discovered_effects = set() + for effect_type, token in tokens_after.items(): + if effect_type not in tokens_before or tokens_before[effect_type] is not token: + discovered_effects.add(effect_type) + + if discovered_effects: + assert ctx.mode._allow_token_discovery, ( + f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}." + ) + # Store discovered effects in the cache by identifier + if invoke_subgraph_cache: + invoke_subgraph_cache.add_effects(identifier, discovered_effects) + return ctx.wrap_tensors(out) diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 21930d81fe092..8504d1cbdb71f 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -15,113 +15,105 @@ ) -def _remove_effect_tokens_from_graph_helper( - ep, num_tokens, input_token_names, output_token_names +def _get_custom_obj_for_node(node, inputs_to_lifted_custom_objs, constants): + """Extract the custom object from a node's arguments.""" + custom_obj_node = node + custom_obj_meta = custom_obj_node.meta["val"] # type: ignore[union-attr] + assert isinstance(custom_obj_meta, CustomObjArgument) + + if custom_obj_meta.fake_val: + return custom_obj_meta.fake_val + elif custom_obj_node.name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] + return constants[inputs_to_lifted_custom_objs[custom_obj_node.name]] # type: ignore[union-attr] + else: + raise RuntimeError(f"Unable to find custom obj for node {node}") + + +def _replace_with_effects_node( + node, ep, inputs_to_lifted_custom_objs, output_tokens, input_tokens, module ): - inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs - - output_node = None - with_effect_nodes: list[torch.fx.Node] = [] - - # Output node need to check its args against output_token_names (collected from output_spec) - # Therefore, we only need to find the top-levele output node - output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output"))) - for module in ep.graph_module.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - - for node in module.graph.nodes: - if not (node.op == "call_function" and node.target is with_effects): - continue - - with_effect_nodes.append(node) - - # Remove tokens from outputs - assert output_node is not None - output_args = output_node.args[0] - assert len(output_args) >= num_tokens - out_token_nodes = output_args[:num_tokens] - output_node.args = (tuple(output_args[num_tokens:]),) - for out_token in out_token_nodes: - assert out_token.name in output_token_names - out_token.users.clear() - ep.graph.erase_node(out_token) - - # Replace with_effects(token, func, args) with just func(args) - for node in reversed(with_effect_nodes): - func = node.args[1] - assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) - - if func is torch.ops.higher_order.call_torchbind: - custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr] - assert isinstance(custom_obj_meta, CustomObjArgument) - if custom_obj_meta.fake_val: - custom_obj = custom_obj_meta.fake_val - elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] - custom_obj = ep.constants[ - inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr] - ] - else: - raise RuntimeError(f"Unable to find custom obj for node {node}") - schema = _get_schema(func, (custom_obj,) + node.args[3:]) - else: - schema = _get_schema(func, node.args[2:]) - - with ep.graph.inserting_before(node): - new_node = ep.graph.call_function(func, node.args[2:], node.kwargs) - for k, v in node.meta.items(): - new_node.meta[k] = v - if k == "unbacked_bindings": - # Remove the extra layer for effect token - old_bindings = new_node.meta[k] - new_bindings = { - k: path[1:] if path else path for k, path in old_bindings.items() - } - new_node.meta[k] = new_bindings - - node.replace_all_uses_with(new_node) - - # Update user getitem nodes - for user in list(new_node.users.keys()): - assert user.target is operator.getitem - # getitem(with_effects, 0) == token - if user.args[1] == 0: - ep.graph.erase_node(user) - - if len(schema.returns) == 1: - # If the function has 1 return then it will just directly return the - # result -- we don't need a getitem. So we can replace all the - # getitem(with_effects, 1) with just the note itself. - for user in list(new_node.users.keys()): - assert user.args[1] == 1 + """Replace a with_effects node with the underlying function call.""" + # Get the input nodes + token_node, func, *node_args = node.args + if token_node.op == "placeholder": + input_tokens.append(token_node) + + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + # Get the schema for the function + if func is torch.ops.higher_order.call_torchbind: + custom_obj = _get_custom_obj_for_node( + node_args[0], inputs_to_lifted_custom_objs, ep.constants + ) + schema = _get_schema(func, [custom_obj] + node_args[1:]) + else: + schema = _get_schema(func, node_args) + + # Create the replacement node + with module.graph.inserting_before(node): + new_node = module.graph.call_function(func, tuple(node_args), node.kwargs) + + # Update getitem nodes that extract outputs from with_effects + for user in list(node.users.keys()): + assert user.target is operator.getitem + # getitem(with_effects, 0) is the token node + if user.args[1] == 0: + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.append(user) + + # Copy metadata from old node to new node + for k, v in node.meta.items(): + new_node.meta[k] = v + if k == "unbacked_bindings": + # Remove the extra layer for effect token + old_bindings = new_node.meta[k] + new_bindings = { + k: path[1:] if path else path for k, path in old_bindings.items() + } + new_node.meta[k] = new_bindings + + # Fix up the getitem nodes based on return count + if len(schema.returns) == 1: + # Single return: replace getitem(with_effects, 1) with the node itself + for user in list(node.users.keys()): + if user.args[1] == 1: user.replace_all_uses_with(new_node) - - new_node.meta["val"] = node.meta["val"][1] - elif len(schema.returns) > 1: - # If the function has more than 1 return then since we got rid of - # the 1st return value (the token), we need to bump all the other - # getitem calls by 1 down - for user in list(new_node.users.keys()): - assert user.args[1] >= 1 - user.args = (user.args[0], user.args[1] - 1) - - new_node.meta["val"] = node.meta["val"][1:] - else: - assert len(schema.returns) == 0 - assert len(new_node.users) == 0 - new_node.meta["val"] = None - - ep.graph.erase_node(node) - - # Remove tokens from inputs - placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"] - assert len(placeholders) >= num_tokens - inp_token_nodes = placeholders[:num_tokens] - for inp_token in inp_token_nodes: - assert inp_token.name in input_token_names - ep.graph.erase_node(inp_token) - - ep.graph.eliminate_dead_code() + new_node.meta["val"] = node.meta["val"][1] + elif len(schema.returns) > 1: + # Multiple returns: shift getitem indices down by 1 + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (new_node, user.args[1] - 1) + new_node.meta["val"] = node.meta["val"][1:] + else: + # No returns + assert len(schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + +def _replace_invoke_subgraph_node(node, module, output_tokens, input_tokens): + """Replace an invoke_subgraph node to remove the token argument.""" + assert node.args[0].op == "get_attr" + submod = getattr(module, node.args[0].target) + if not submod.meta.get("has_with_effects", False): + return + + # Remove token from inputs + subgraph, identifier, token, *operands = node.args + node.args = (subgraph, identifier, *operands) + if token.op == "placeholder": + input_tokens.append(token) + + # Update getitem nodes to account for removed token output + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (node, user.args[1] - 1) + elif user.args[1] == 0: + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.append(user) def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: @@ -132,6 +124,64 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: This function does an inplace modification on the given ExportedProgram. """ + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + + # mark submodules with effects as having effects. This will be used in the following pass to remove effects from subgraphs + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + with_effect_nodes = [ + node for node in module.graph.nodes if node.target is with_effects + ] + if len(with_effect_nodes) > 0: + module.meta["has_with_effects"] = True + + # Process each module with the replace hook to ensure graph signature is updated + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + input_tokens = [] + output_tokens = [] + + # Process with_effects and invoke_subgraph nodes + for node in module.graph.nodes: + if node.target is with_effects: + _replace_with_effects_node( + node, + ep, + inputs_to_lifted_custom_objs, + output_tokens, + input_tokens, + module, + ) + elif node.target is torch.ops.higher_order.invoke_subgraph: + _replace_invoke_subgraph_node( + node, module, output_tokens, input_tokens + ) + + # Remove tokens from the output node + if len(output_tokens) > 0: + output_node = next(reversed(module.graph.find_nodes(op="output"))) + output_args = output_node.args[0] + assert len(output_args) >= len(output_tokens), ( + f"{output_args} output arguments found\n" + f"{output_tokens} output tokens found\n" + f"{module.graph}" + ) + output_node.args = (tuple(output_args[len(output_tokens) :]),) + + module.graph.eliminate_dead_code() + + # Remove tokens from the input placeholders + for node in module.graph.nodes: + if node.op == "placeholder" and node in input_tokens: + module.graph.erase_node(node) + + module.recompile() + num_tokens: int = 0 input_token_names: list[str] = [] new_input_specs: list[InputSpec] = [] @@ -159,9 +209,4 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: assert num_tokens == num_out_tokens - with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): - _remove_effect_tokens_from_graph_helper( - ep, num_tokens, input_token_names, output_token_names - ) - return ep diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 52d06a294fac1..6239c5899c233 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -748,11 +748,23 @@ def _unlift_exported_program_lifted_states( ) -> torch.fx.GraphModule: check_guards = check_guards and _ok_to_generate_guards_fn() + source_node_dict = { + node.name: node for node in ep.graph.nodes if node.op != "placeholder" + } + # placeholder node name might change after deepcopy + placeholder_source_node_dict = { + node.target: node for node in ep.graph.nodes if node.op == "placeholder" + } + + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) + new_gm.meta.update(ep.graph_module.meta) + ep = copy.copy(ep) + ep._graph_module = new_gm + # TODO T206340015 if ep.verifiers[0].dialect != "TRAINING": ep = _remove_effect_tokens(ep) - new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) forward_arg_names = ( sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None @@ -786,19 +798,13 @@ def _unlift_exported_program_lifted_states( for out_spec in ep.graph_signature.output_specs ] - source_node_dict = { - node.name: node for node in ep.graph.nodes if node.op != "placeholder" - } - # placeholder node name might change after deepcopy - placeholder_source_node_dict = { - node.target: node for node in ep.graph.nodes if node.op == "placeholder" - } for node in new_gm.graph.nodes: source_node = None if node.op == "placeholder": source_node = placeholder_source_node_dict.get(node.target) else: - source_node = source_node_dict.get(node.name) + if node.name in source_node_dict: + source_node = source_node_dict.get(node.name) node.meta["from_node"] = [ NodeSource( source_node, From 561c1eb14884075fa7c9a3b6bccfa41a477caa43 Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 21 Nov 2025 11:23:32 -0800 Subject: [PATCH 024/338] [invoke_subgraph] Don't run the graph twice when autograd enabled (#167245) In the [previous PR](https://github.com/pytorch/pytorch/pull/167231/files#diff-e2b74af5d8b538a7d07d18507d27010703742ddad5f819992b55f5abc6d9a502R964-R966) we found that the autograd eager impl of invoke_subgraph calls the subgraph twice. If the subgraph contains effects then effects will be run twice, which is bad. This PR fixes the issue by getting the output metadata from `subgraph`'s `node.meta` if it exists. Differential Revision: [D87392740](https://our.internmc.facebook.com/intern/diff/D87392740) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167245 Approved by: https://github.com/anijain2305 ghstack dependencies: #167231 --- test/higher_order_ops/test_with_effects.py | 6 +- torch/_higher_order_ops/invoke_subgraph.py | 64 ++++++++++++++++++++-- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index ce0d0eb5dc3ef..ec936de9d0595 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -962,11 +962,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): ) recorded_list.clear() - # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph - # eagerly twice. Once for get_output_metadata and then once for - # InvokeSubgraphAutogradOp. This causes record_memory to be called twice. - with torch.no_grad(): - out2 = ep.module()(x) + out2 = ep.module()(x) self.assertEqual(len(recorded_list), 4) self.assertTrue(torch.allclose(model(x)[0], out2[0])) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 7d066e132e011..bb0d6cef3ee6f 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -305,6 +305,62 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): def get_output_metadata(subgraph, *operands): + """ + Extract metadata about the subgraph outputs WITHOUT executing the subgraph. + This avoids running side-effectful operations twice (once here, once in forward). + We analyze the graph structure statically to extract metadata. + """ + # Unwrap FunctionalizeCtxWrapper if present + if isinstance(subgraph, FunctionalizeCtxWrapper): + subgraph = subgraph.subgraph + + # If not a GraphModule, fall back to execution-based metadata extraction + if not isinstance(subgraph, torch.fx.GraphModule): + return _get_output_metadata_by_execution(subgraph, *operands) + + output_metadata = OutputMetadata() + + # Extract output arguments from the output node + # The output node has args=(output_values,) where output_values is a tuple/list + output_node = next(reversed(subgraph.graph.find_nodes(op="output"))) + output_metadata.num_fw_outs = len(output_node.args[0]) + + for idx, output_arg in enumerate(output_node.args[0]): + if not isinstance(output_arg, torch.fx.Node): + if isinstance(output_arg, int): + output_metadata.indexes_with_symint.add(idx) + output_metadata.indexes_with_no_grad.add(idx) + continue + + # Check node metadata for type information + if output_arg.meta.get("val") is None: + # If we don't have complete metadata for all outputs, fall back to execution + # This is important for correctness (e.g., detecting SymInts) even though it + # runs side-effectful operations + return _get_output_metadata_by_execution(subgraph, *operands) + + val = output_arg.meta["val"] + if isinstance(val, torch.SymInt): + output_metadata.indexes_with_symint.add(idx) + output_metadata.indexes_with_no_grad.add(idx) + elif isinstance(val, torch.Tensor): + # Check if tensor requires grad from metadata + if hasattr(val, "requires_grad") and not val.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + else: + # Non-tensor, non-symint (shouldn't happen but be safe) + output_metadata.indexes_with_no_grad.add(idx) + + return output_metadata + + +def _get_output_metadata_by_execution(subgraph, *operands): + """ + Fallback: Extract metadata by executing the subgraph. + This should only be used when static analysis fails. + WARNING: This will run side-effectful operations! + """ + with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): # args are functional tensors, generate some example tensors @@ -324,19 +380,15 @@ def get_output_metadata(subgraph, *operands): num_fw_outs = len(fw_outs) - # Collect the indexes of none in the output to check that the grad - # is None at the corresponding index in the backward. This check is - # performed in the autograd.Function - InvokeSubgraphAutogradOp. - # Also collect the indexes of no_grad in the output to filter out - # the grad_outs in the `backward` method. output_metadata = OutputMetadata() - output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): if isinstance(fw_out, torch.SymInt): output_metadata.indexes_with_symint.add(idx) elif not fw_out.requires_grad: output_metadata.indexes_with_no_grad.add(idx) + return output_metadata From 53b2f6292592c34bc9bccca39242fc461f887b6e Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 21 Nov 2025 11:23:33 -0800 Subject: [PATCH 025/338] [hoo] Fix unlift of effects with invoke_subgraph (#167363) Updates the implementation of `unlift_tokens` to handle unlifting invoke_subgraph. The context of `unlift_tokens` is currently tokens are threaded as inputs and outputs of the toplevel graph produced by AOTAutograd. However we don't want the inductor traced graph to have any notion of effects/tokens, just that the tokens should introduce some extra dependency behavior. So, we unlift the tokens from the toplevel graph. Instead of placeholder nodes the tokens will come from a `_make_token` call, and instead of outputting the tokens we will sink all tokens into `_sink_tokens`. Similarly, we want the invoke_subgraph subgraph to not have any notion of tokens, so we will also remove the tokens from the inputs of the invoke_subgraph subgraph. However, we still need a way mark the invoke_subgraph call as being effectful at the toplevel module to prevent invoke_subgraph calls from being reordered, so I wrap the invoke_subgraph with an effects. Before: ``` def forward(self, token, x): repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x) getitem = invoke_subgraph[0] # output token getitem_1 = invoke_subgraph[1] return (getitem, getitem_1) def repeated_subgraph(self, token, x): with_effects = torch.ops.higher_order.with_effects(token, torch.ops.mylib.record_memory.default, 'forward', 'N') getitem = with_effects[0] # output token add = torch.ops.aten.add(x, x) return (getitem, add) ``` After: ``` def forward(self, x): token = torch.ops.prims._make_token.default() repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.with_effects( token, torch.ops.higher_order.invoke_subgraph, repeated_subgraph0, 'subgraph_0', token, x ) getitem = invoke_subgraph[0] # output token getitem_1 = invoke_subgraph[1] _ = torch.ops.prims._sink_tokens.default([getitem]) return (getitem_1,) def repeated_subgraph(self, x): token = torch.ops.prims._make_token.default() with_effects = torch.ops.higher_order.with_effects(token, torch.ops.mylib.record_memory.default, 'forward', 'N') getitem = with_effects[0] # output token add = torch.ops.aten.add(x, x) _ = torch.ops.prims._sink_tokens.default([getitem]) return (add,) ``` Differential Revision: [D87668981](https://our.internmc.facebook.com/intern/diff/D87668981) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167363 Approved by: https://github.com/fxdawnn ghstack dependencies: #167231, #167245 --- test/higher_order_ops/test_with_effects.py | 84 ++++-- torch/_functorch/_aot_autograd/utils.py | 291 +++++++++++++++------ torch/_higher_order_ops/effects.py | 3 +- torch/_prims/__init__.py | 2 + 4 files changed, 280 insertions(+), 100 deletions(-) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index ec936de9d0595..b7840c0729e27 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -18,6 +18,7 @@ nop, ) from torch._functorch.aot_autograd import aot_export_module +from torch._guards import tracing, TracingContext from torch._higher_order_ops.effects import ( _EffectType, _get_effect, @@ -137,7 +138,7 @@ def forward(self, arg1_1): with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None getitem_2 = with_effects_1[0]; with_effects_1 = None _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]); getitem_2 = _sink_tokens_default = None - return [add]""", # noqa: B950 + return (add,)""", # noqa: B950 ) def test_torchbind_custom_op(self): @@ -917,18 +918,19 @@ def forward(self, x): torch.ops.mylib.record_memory("forward", "N") return (x,) - model = M().to("cuda") - torch.cuda.reset_peak_memory_stats() + model = M().to("cuda") + torch.cuda.reset_peak_memory_stats() - x = torch.randn(32, 1024, requires_grad=True, device="cuda") + x = torch.randn(32, 1024, requires_grad=True, device="cuda") - ep = torch.export.export(model, (x,)) - ep = ep.run_decompositions() - self.assertEqual(len(list(ep.graph_module.named_modules())), 2) + # Test torch.export + ep = torch.export.export(model, (x,)) + decomp = ep.run_decompositions() + self.assertEqual(len(list(ep.graph_module.named_modules())), 2) - self.assertExpectedInline( - ep.graph_module.code.strip(), - """\ + self.assertExpectedInline( + decomp.graph_module.code.strip(), + """\ def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x): repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None @@ -945,11 +947,11 @@ def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None getitem_6 = with_effects[0]; with_effects = None return (getitem_6, getitem_5)""", - ) + ) - self.assertExpectedInline( - ep.graph_module.repeated_subgraph0.code.strip(), - """\ + self.assertExpectedInline( + decomp.graph_module.repeated_subgraph0.code.strip(), + """\ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None getitem = with_effects[0]; with_effects = None @@ -959,12 +961,56 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None return (getitem, addmm_1)""", - ) + ) - recorded_list.clear() - out2 = ep.module()(x) - self.assertEqual(len(recorded_list), 4) - self.assertTrue(torch.allclose(model(x)[0], out2[0])) + recorded_list.clear() + out2 = ep.module()(x) + self.assertEqual(len(recorded_list), 4) + self.assertTrue(torch.allclose(model(x)[0], out2[0])) + + # Test when we unlift the tokens from the graph. This is used in the inductor path. + with ( + tracing(TracingContext(None)), + torch._functorch.config.patch(unlift_effect_tokens=True), + ): + gm, gs = aot_export_module(ep.module(), (x,), trace_joint=False) + self.assertExpectedInline( + str(gm.code).strip(), + """\ +def forward(self, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1): + _make_token_default = torch.ops.prims._make_token.default() + repeated_subgraph0 = self.repeated_subgraph0 + with_effects_1 = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.higher_order.invoke_subgraph, repeated_subgraph0, 'subgraph_0', arg13_1, arg1_1, arg2_1, arg3_1, arg4_1); _make_token_default = repeated_subgraph0 = arg13_1 = arg1_1 = arg2_1 = arg3_1 = arg4_1 = None + getitem = with_effects_1[0] + getitem_1 = with_effects_1[1]; with_effects_1 = None + repeated_subgraph0_1 = self.repeated_subgraph0 + with_effects_2 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.invoke_subgraph, repeated_subgraph0_1, 'subgraph_0', getitem_1, arg5_1, arg6_1, arg7_1, arg8_1); getitem = repeated_subgraph0_1 = getitem_1 = arg5_1 = arg6_1 = arg7_1 = arg8_1 = None + getitem_2 = with_effects_2[0] + getitem_3 = with_effects_2[1]; with_effects_2 = None + repeated_subgraph0_2 = self.repeated_subgraph0 + with_effects_3 = torch.ops.higher_order.with_effects(getitem_2, torch.ops.higher_order.invoke_subgraph, repeated_subgraph0_2, 'subgraph_0', getitem_3, arg9_1, arg10_1, arg11_1, arg12_1); getitem_2 = repeated_subgraph0_2 = getitem_3 = arg9_1 = arg10_1 = arg11_1 = arg12_1 = None + getitem_4 = with_effects_3[0] + getitem_5 = with_effects_3[1]; with_effects_3 = None + with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None + getitem_6 = with_effects[0]; with_effects = None + _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_6]); getitem_6 = _sink_tokens_default = None + return (getitem_5,)""", # noqa: B950 + ) + self.assertExpectedInline( + str(gm.repeated_subgraph0.code).strip(), + """\ +def forward(self, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): + _make_token_default = torch.ops.prims._make_token.default() + with_effects = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.mylib.record_memory.default, 'forward', 'N'); _make_token_default = None + getitem = with_effects[0]; with_effects = None + t = torch.ops.aten.t.default(arg2_1); arg2_1 = None + addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, t); arg3_1 = arg1_1 = t = None + relu = torch.ops.aten.relu.default(addmm); addmm = None + t_1 = torch.ops.aten.t.default(arg4_1); arg4_1 = None + addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, t_1); arg5_1 = relu = t_1 = None + _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]); getitem = _sink_tokens_default = None + return (addmm_1,)""", # noqa: B950 + ) if __name__ == "__main__": diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 7a290161bb25b..e1255a6de8bf6 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -248,14 +248,28 @@ def maybe_to_fresh_input(idx, t, meta): def is_with_effects(node): - return ( + if ( node.op == "call_function" and node.target is torch.ops.higher_order.with_effects - ) - - -def is_with_effects_op(node, op): - return is_with_effects(node) and node.args[1] == op + ): + return True + elif ( + node.op == "call_function" + and node.target is torch.ops.higher_order.invoke_subgraph + ): + # Check if subgraph has effects by looking in the cache + from torch._guards import InvokeSubgraphCache, TracingContext + + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + effects = invoke_subgraph_cache.get_effects(node.args[1]) + return effects is not None + return False def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None): @@ -264,96 +278,215 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None): # _make_token() to create a token, and _sink_tokens() to collect the # tokens. See Note [Side-Effectful Tokens in AOTAutograd] # Logic: - # 1. Inputs identified as input tokens: - # - If used as a first argument in with_effects + # 1. In the case of with_effects: + # Before: + # ``` + # def forward(self, token, arg1_1): + # with_effects = torch.ops.higher_order.with_effects(token, ...) + # getitem = with_effects[0] + # getitem_1 = with_effects[0] + # return (getitem, getitem_1) + # ``` # - # 2. Outputs identified as output tokens: - # - If Produced by getitem(with_effects, 0) + # After: + # ``` + # def forward(self, arg1_1): + # _make_token_default = torch.ops.prims._make_token.default() + # with_effects = torch.ops.higher_order.with_effects(_make_token_default, ...) + # getitem = with_effects[0] + # getitem_1 = with_effects[0] + # _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]); + # return (getitem_1,) + # ``` # - # 3. Checks invariants of number input output tokens: - # forward: - # expected_num_erased_inputs == len(fw_metadata.tokens) - # expected_num_erased_outputs == len(fw_metadata.tokens) - # backward: - # expected_num_erased_inputs == fw_metadata.num_backward_tokens - # expected_num_erased_outputs == fw_metadata.num_backward_tokens + # 2. In the case of an invoke_subgraph node, we will use the + # InvokeSubgraphCache to determine if the subgraph has effects. Then we will + # turn it into a `with_effects` node. This is so that at the toplevel graph, + # the nodes will have the correct with_effects threading. We will apply this + # pass recursively to submodules so the tokens will be removed from the + # subgraph's inputs. + # + # Before: + # ``` + # def forward(self, token, arg1_1): + # repeated_subgraph0 = self.repeated_subgraph0 + # invoke_subgraph = torch.ops.higher_order.invoke_subgraph( + # repeated_subgraph0, 'subgraph_0', token, x, arg1_1) + # getitem = invoke_subgraph[0] + # getitem_1 = invoke_subgraph[1] + # return (getitem, getitem1) + # ``` + # + # After: + # ``` + # def forward(self, arg1_1): + # _make_token_default = torch.ops.prims._make_token.default() + # repeated_subgraph0 = self.repeated_subgraph0 + # with_effects_1 = torch.ops.higher_order.with_effects( + # _make_token_default, torch.ops.higher_order.invoke_subgraph, + # repeated_subgraph0, 'subgraph_0', arg1_1) + # getitem = with_effects_1[0] + # getitem_1 = with_effects_1[1]; with_effects_1 = None + # _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]) + # return (getitem_1,) + # ``` + # + # 3. The toplevel module should have the following invariants: + # forward: + # expected_num_erased_inputs == len(fw_metadata.tokens) + # expected_num_erased_outputs == len(fw_metadata.tokens) + # backward: + # expected_num_erased_inputs == fw_metadata.num_backward_tokens + # expected_num_erased_outputs == fw_metadata.num_backward_tokens num_forward_tokens = len(fw_metadata.tokens) num_backward_tokens = fw_metadata.num_backward_tokens - def rewrite_with_effects_input_token(module, node): + def replace_input_token_with_make_token(module, node): with module.graph.inserting_before(node): new_token_node = module.graph.call_function( torch.ops.prims._make_token.default, () ) new_token_node.meta["val"] = torch.tensor([]) new_token_node.meta["tensor_meta"] = torch.tensor([]) + node.replace_all_uses_with(new_token_node) + module.graph.erase_node(node) + + def get_output_tokens(node: torch.fx.Node) -> set[torch.fx.Node]: + output_tokens = set() + for user in list(node.users.keys()): + # Check if this is a getitem accessing index 0 (the token) + if ( + user.op == "call_function" + and user.target is operator.getitem + and len(user.args) > 1 + and user.args[1] == 0 + ): + # Check if this getitem is used in an output + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.add(user) + return output_tokens + + def _unlift_tokens_from_module_helper( + module: torch.fx.GraphModule, + subgraph_str: str, + expected_num_erased: Optional[int], + ): + input_token_nodes = set() + output_token_nodes = set() - args = list(node.args) - args[0] = new_token_node - node.args = tuple(args) - - def rewrite_output(module, node, output_token_nodes, other_output_args): - for output_token_node in output_token_nodes: - assert ( - output_token_node.op == "call_function" - and output_token_node.target is operator.getitem - and output_token_node.args[1] == 0 - ) - with module.graph.inserting_before(node): + for node in module.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.higher_order.with_effects + ): + if node.args[0].op == "placeholder": + input_token_nodes.add(node.args[0]) + replace_input_token_with_make_token(module, node.args[0]) + + tokens_from_with_effects = get_output_tokens(node) + output_token_nodes = output_token_nodes | tokens_from_with_effects + + elif ( + node.op == "call_function" + and node.target is torch.ops.higher_order.invoke_subgraph + ): + subgraph_node, identifier, *operands = node.args + + # Check if subgraph has effects by looking in the cache + from torch._guards import InvokeSubgraphCache, TracingContext + + effects = None + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = ( + tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + effects = invoke_subgraph_cache.get_effects(identifier) + + if effects is not None: + # Wrap invoke_subgraph with with_effects + # Before: invoke_subgraph(subgraph, id, token, *args) -> (token_out, result) + # After: with_effects(token, invoke_subgraph, subgraph, id, *args) -> (token_out, result) + # + # Note: The subgraph itself will be unlifted separately when we iterate + # through named_modules() below. + + num_tokens = len(effects) + assert num_tokens == 1, "Multiple token subgraph NYI" + token_args = operands[:num_tokens] + non_token_args = operands[num_tokens:] + + # Create with_effects wrapper around invoke_subgraph + # with_effects(token, op, *args) where op is invoke_subgraph + # Pass the subgraph and non-token args to invoke_subgraph + with module.graph.inserting_before(node): + new_node = module.graph.call_function( + torch.ops.higher_order.with_effects, + ( + token_args[0], # pyrefly: ignore[bad-argument-type] + torch.ops.higher_order.invoke_subgraph, + subgraph_node, + identifier, + *tuple(non_token_args), + ), + ) + node.replace_all_uses_with(new_node) + new_node.meta = node.meta + module.graph.erase_node(node) + + for token in token_args: + if token.op == "placeholder": + input_token_nodes.add(token) + replace_input_token_with_make_token(module, token) + + # Get output tokens from the new with_effects node + tokens_from_invoke_subgraph = get_output_tokens(new_node) + output_token_nodes = ( + output_token_nodes | tokens_from_invoke_subgraph + ) + + output_node = next(reversed(module.graph.find_nodes(op="output"))) + assert output_node is not None + with module.graph.inserting_before(output_node): module.graph.call_function( torch.ops.prims._sink_tokens.default, - (output_token_nodes,), + (list(output_token_nodes),), ) - node.args = (other_output_args,) - - def do(module, subgraph, expected_num_erased): - num_erased_inputs = 0 - num_erased_outs = 0 - input_nodes = [] - input_token_nodes = set() - with_effect_nodes = [] - output_token_nodes = [] - other_output_nodes = [] - for node in module.graph.nodes: - if node.op == "placeholder": - input_nodes.append(node) - elif is_with_effects(node): - with_effect_nodes.append(node) - if node.args[0] in input_nodes: - input_token_nodes.add(node.args[0]) - rewrite_with_effects_input_token(module, node) - elif node.op == "output": - outs = node.args[0] - for out in outs: - if ( - isinstance(out, torch.fx.node.Node) - and out.op == "call_function" - and out.target is operator.getitem - and out.args[1] == 0 - and out.args[0] in with_effect_nodes - ): - # pyrefly: ignore [missing-attribute] - output_token_nodes.append(out) - else: - other_output_nodes.append(out) - - rewrite_output(module, node, output_token_nodes, other_output_nodes) - num_erased_outs = len(output_token_nodes) - - for input_token_node in input_token_nodes: - module.graph.erase_node(input_token_node) - - num_erased_inputs = len(input_token_nodes) - - assert num_erased_inputs == expected_num_erased, ( - f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}" - ) - assert num_erased_outs == expected_num_erased, ( - f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}" + new_out_args = tuple( + [out for out in output_node.args[0] if out not in output_token_nodes] ) + output_node.args = (new_out_args,) + + if expected_num_erased: + assert len(input_token_nodes) == expected_num_erased, ( + f"{subgraph_str} num_erased_inputs:{len(input_token_nodes)} " + f"{input_token_nodes} != expected {expected_num_erased} \n" + f"{fw_module.print_readable(print_output=False)}" + ) + assert len(output_token_nodes) == expected_num_erased, ( + f"{subgraph_str} num_erased_outs:{len(output_token_nodes)} " + f"{output_token_nodes} != expected {expected_num_erased} \n" + f"{fw_module.print_readable(print_output=False)}" + ) module.recompile() + def unlift_tokens_from_module(module, subgraph_str, expected_num_erased): + for name, m in module.named_modules(): + if isinstance(m, torch.fx.GraphModule): + if name == "": + _unlift_tokens_from_module_helper( + m, subgraph_str, expected_num_erased + ) + else: + # Subgraph -- we may or may not have effects applied + _unlift_tokens_from_module_helper(m, f"{subgraph_str}_{name}", None) + if num_forward_tokens > 0: if aot_config.enable_log: from torch._dynamo.utils import lazy_format_graph_code @@ -369,7 +502,7 @@ def do(module, subgraph, expected_num_erased): colored=True, ), ) - do( + unlift_tokens_from_module( fw_module, "forward", num_forward_tokens, @@ -390,7 +523,7 @@ def do(module, subgraph, expected_num_erased): colored=True, ), ) - do(bw_module, "backward", num_backward_tokens) + unlift_tokens_from_module(bw_module, "backward", num_backward_tokens) # This is sad, but we need to update the metadata to get rid of # the tokens. diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index 2c8d75c67c791..b2fc74b7328f1 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -91,7 +91,6 @@ def __call__( ) -> tuple[Any, ...]: assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) assert not has_aliasing(op), "Ops with aliasing is not supported" - assert has_effects(op) assert isinstance(kwargs, dict) return super().__call__(token, op, *args, **kwargs) @@ -102,7 +101,7 @@ def __call__( def has_aliasing(op: OpType): # NOT FOR PUBLIC USE if isinstance(op, torch._ops.HigherOrderOperator): - return not _get_effect(op) + return False for arg in op._schema.arguments: if arg.alias_info is not None: diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 2c2b16373f8a0..e2e3220bb26d5 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -2994,6 +2994,8 @@ def _sink_tokens_aten(tokens) -> None: doc="Sink all of the tokens which were previously used for keeping track of side effects.", ) +torch.fx.node.has_side_effect(_sink_tokens) + register_rng_prims() register_debug_prims() From 9e88d50af667c4060872a15daf07ade99784203c Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 24 Nov 2025 23:40:33 +0000 Subject: [PATCH 026/338] [3/N] Use context managers (#167788) This PR uses context managers and suppresses ruff `SIM115` warnings in some places. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167788 Approved by: https://github.com/albanD --- test/distributed/argparse_util_test.py | 2 +- test/distributed/launcher/api_test.py | 2 +- test/distributed/launcher/test_run.py | 2 +- torch/_functorch/compilers.py | 11 ++--- .../runtime/caching/implementations.py | 4 +- torch/distributed/__init__.py | 2 +- .../elastic/timer/file_based_local_timer.py | 25 ++++++----- torch/hub.py | 2 +- torch/serialization.py | 2 +- torch/testing/_internal/common_utils.py | 2 +- .../_internal/distributed/distributed_test.py | 42 +++++++++---------- 11 files changed, 44 insertions(+), 52 deletions(-) diff --git a/test/distributed/argparse_util_test.py b/test/distributed/argparse_util_test.py index a3b3ef2bc717e..1902faf992734 100644 --- a/test/distributed/argparse_util_test.py +++ b/test/distributed/argparse_util_test.py @@ -16,7 +16,7 @@ class ArgParseUtilTest(unittest.TestCase): def setUp(self): # remove any lingering environment variables - for e in os.environ.keys(): # noqa: SIM118 + for e in os.environ.keys(): if e.startswith("PET_"): del os.environ[e] diff --git a/test/distributed/launcher/api_test.py b/test/distributed/launcher/api_test.py index 330fd302bbd45..48465516a913b 100644 --- a/test/distributed/launcher/api_test.py +++ b/test/distributed/launcher/api_test.py @@ -137,7 +137,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() # remove any lingering environment variables. - for env in os.environ.keys(): # noqa: SIM118 + for env in os.environ.keys(): if env.startswith("PET_"): del os.environ[env] diff --git a/test/distributed/launcher/test_run.py b/test/distributed/launcher/test_run.py index 484a975051d4f..1ba51bfa13908 100644 --- a/test/distributed/launcher/test_run.py +++ b/test/distributed/launcher/test_run.py @@ -70,7 +70,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() # remove any lingering environment variables - for env in os.environ.keys(): # noqa: SIM118 + for env in os.environ.keys(): if env.startswith("PET_"): del os.environ[env] diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index 8070e47153ca5..88954a636f915 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -391,13 +391,10 @@ def graph_saver_helper(gm_to_save, args, type_name): gm.to_folder( f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}" ) - pickle.dump( - input_meta, - open( - f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", # noqa: B950 - "wb", - ), - ) # noqa: E501 + with open( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input" + ) as f: + pickle.dump(input_meta, f) if dump_example_input: torch.save( args, diff --git a/torch/_inductor/runtime/caching/implementations.py b/torch/_inductor/runtime/caching/implementations.py index 690855304b89d..ed83e490fd316 100644 --- a/torch/_inductor/runtime/caching/implementations.py +++ b/torch/_inductor/runtime/caching/implementations.py @@ -311,7 +311,7 @@ def insert(self, key: Any, value: Any) -> bool: r_fp, w_fp, inserted = None, None, False try: - w_fp = open(fpath, "xb") + w_fp = open(fpath, "xb") # noqa: SIM115 except FileExistsError: is_stale: bool = False with open(fpath, "rb") as r_fp: @@ -322,7 +322,7 @@ def insert(self, key: Any, value: Any) -> bool: # match so we choose to remove the old entry so that the new # k/v pair can be cached fpath.unlink() - w_fp = open(fpath, "xb") + w_fp = open(fpath, "xb") # noqa: SIM115 else: w_fp = None finally: diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 6c8912ffa4fa3..4e20a2b27e99d 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -76,7 +76,7 @@ class _DistributedPdb(pdb.Pdb): def interaction(self, *args, **kwargs): _stdin = sys.stdin try: - sys.stdin = open("/dev/stdin") + sys.stdin = open("/dev/stdin") # noqa: SIM115 pdb.Pdb.interaction(self, *args, **kwargs) finally: sys.stdin = _stdin diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index d0f61bf1cef32..8ed457a19f115 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -281,23 +281,22 @@ def _watchdog_loop(self) -> None: # 2. We are running the watchdog loop in a separate daemon # thread, which will not block the process to stop. try: - fd = open(self._file_path) + with open(self._file_path) as fd: + self._is_client_started = True + while not self._stop_signaled: + try: + run_once = self._run_once + self._run_watchdog(fd) + if run_once: + break + self._last_progress_time = int(time.time()) + except Exception: + logger.exception("Error running watchdog") + except Exception: logger.exception("Could not open the FileTimerServer pipe") raise - with fd: - self._is_client_started = True - while not self._stop_signaled: - try: - run_once = self._run_once - self._run_watchdog(fd) - if run_once: - break - self._last_progress_time = int(time.time()) - except Exception: - logger.exception("Error running watchdog") - def _run_watchdog(self, fd: io.TextIOWrapper) -> None: timer_requests = self._get_requests(fd, self._max_interval) self.register_timers(timer_requests) diff --git a/torch/hub.py b/torch/hub.py index 0862f4f84eaa0..bf138f7784347 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -736,7 +736,7 @@ def download_url_to_file( for _ in range(tempfile.TMP_MAX): tmp_dst = dst + "." + uuid.uuid4().hex + ".partial" try: - f = open(tmp_dst, "w+b") + f = open(tmp_dst, "w+b") # noqa: SIM115 except FileExistsError: continue break diff --git a/torch/serialization.py b/torch/serialization.py index ffa77cec732ed..398d011f324b5 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -746,7 +746,7 @@ def __exit__(self, *args): class _open_file(_opener[IO[bytes]]): def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None: - super().__init__(open(name, mode)) + super().__init__(open(name, mode)) # noqa: SIM115 def __exit__(self, *args): self.file_like.close() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 815cc8859080f..ef199e07d6a04 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1425,7 +1425,7 @@ def TemporaryFileName(*args, **kwargs): raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.") else: kwargs['delete'] = False - f = tempfile.NamedTemporaryFile(*args, **kwargs) + f = tempfile.NamedTemporaryFile(*args, **kwargs) # noqa:SIM115 try: f.close() yield f.name diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 478d3c978120b..8e6a5beb45ee7 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -87,6 +87,7 @@ skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, skipIfRocm, + TemporaryFileName, ) from torch.utils._python_dispatch import TorchDispatchMode from torch.utils.data.distributed import DistributedSampler @@ -215,10 +216,7 @@ def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False): def get_profiler_nccl_meta(prof): """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" We will need to test metadata obtained from profiler here""" - with tempfile.NamedTemporaryFile(mode="w+t", suffix=".json") as tf: - tf.close() - trace_file = tf.name - + with TemporaryFileName(mode="w+t", suffix=".json") as trace_file: prof.export_chrome_trace(trace_file) with open(trace_file) as f: events = json.load(f)["traceEvents"] @@ -7075,27 +7073,25 @@ def _validate_execution_trace_nccl(self, et_file: str) -> None: def test_ddp_profiling_execution_trace(self): self.assertEqual(dist.get_backend(), "nccl") # Create a temp file to save execution trace data - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) - fp.close() - et_file = fp.name - et = ExecutionTraceObserver().register_callback(et_file) + with TemporaryFileName("w+t", suffix=".et.json") as et_file: + et = ExecutionTraceObserver().register_callback(et_file) - # first profiler context need not have ET - torch_profiler_ctx1 = torch.profiler.profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) - # collect ET in second profiler pass - torch_profiler_ctx2 = torch.profiler.profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - execution_trace_observer=et, - ) - self._test_ddp_profiling( - profiler_ctx=torch_profiler_ctx1, - profiler_ctx2=torch_profiler_ctx2, - ) + # first profiler context need not have ET + torch_profiler_ctx1 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) + # collect ET in second profiler pass + torch_profiler_ctx2 = torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + execution_trace_observer=et, + ) + self._test_ddp_profiling( + profiler_ctx=torch_profiler_ctx1, + profiler_ctx2=torch_profiler_ctx2, + ) - print(f"Execution trace saved at {fp.name}") - self._validate_execution_trace_nccl(et_file) + print(f"Execution trace saved at {et_file}") + self._validate_execution_trace_nccl(et_file) @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( From 3aeb7b0763f717096f6354adae53e350281bec77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakov=20Smoli=C4=87?= Date: Mon, 24 Nov 2025 23:48:27 +0000 Subject: [PATCH 027/338] inductor: fix failure in test_flex_decoding on class TestFlexDecoding(InductorTestCase): File "/root/pytorch/test/inductor/test_flex_decoding.py", line 751, in TestFlexDecoding @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") ^^^^^^^^^^^^^^ NameError: name 'SKIP_UT_ON_CPU' is not defined ``` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/165404 Approved by: https://github.com/drisspg, https://github.com/Skylion007 --- test/inductor/test_flex_decoding.py | 32 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 995262b0f2104..27fdcc8fac404 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -31,7 +31,6 @@ skipXPUIf, ) from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS -from torch.testing._internal.inductor_utils import HAS_GPU from torch.utils._triton import has_triton_tma_device @@ -59,22 +58,21 @@ ) TEST_ON_XPU = torch.xpu.is_available() and torch.utils._triton.has_triton() -if HAS_GPU: - if TEST_ON_CUDA: - test_device = ("cuda",) - test_dtypes = ( - [torch.float32, torch.bfloat16, torch.float16] - if PLATFORM_SUPPORTS_BF16 - else [torch.float16, torch.float32] - ) - test_dtypes_fast = [torch.float16] - SKIP_UT_ON_CPU = False - elif TEST_ON_XPU: - torch._C._set_onednn_allow_tf32(True) - test_device = ("xpu",) - test_dtypes = [torch.float32, torch.bfloat16, torch.float16] - test_dtypes_fast = [torch.float16] - SKIP_UT_ON_CPU = False +if TEST_ON_CUDA: + test_device = ("cuda",) + test_dtypes = ( + [torch.float32, torch.bfloat16, torch.float16] + if PLATFORM_SUPPORTS_BF16 + else [torch.float16, torch.float32] + ) + test_dtypes_fast = [torch.float16] + SKIP_UT_ON_CPU = False +elif TEST_ON_XPU: + torch._C._set_onednn_allow_tf32(True) + test_device = ("xpu",) + test_dtypes = [torch.float32, torch.bfloat16, torch.float16] + test_dtypes_fast = [torch.float16] + SKIP_UT_ON_CPU = False else: test_device = ("cpu",) torch_config_string = torch.__config__.show() From 42ab53de0b2361ec765420ca90102697eaf305f4 Mon Sep 17 00:00:00 2001 From: KaaustaaubShankar <76492384+KaaustaaubShankar@users.noreply.github.com> Date: Tue, 25 Nov 2025 00:01:15 +0000 Subject: [PATCH 028/338] Fix take_along_dim negative index handling (#146211) (#152161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes: #146211 This PR fixes an issue with `torch.take_along_dim()` not correctly handling negative indices. Previously, using negative values in the `indices` tensor caused an out-of-bounds error. This update wraps indices correctly, matching Python-style indexing semantics. ### 🔧 Changes - Modified `_take_along_dim_helper` to apply modulo logic for dimension-safe negative indexing. - Added a unit test `test_take_along_dim_negative_indices` to `test/test_indexing.py` to assert correctness of negative indexing behavior. ### 🧪 Testing ```bash pytest test/test_indexing.py -k test_take_along_dim_negative_indices ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/152161 Approved by: https://github.com/albanD --- aten/src/ATen/native/TensorAdvancedIndexing.cpp | 3 +++ torch/_refs/__init__.py | 4 ++++ torch/testing/_internal/common_methods_invocations.py | 8 ++++++++ 3 files changed, 15 insertions(+) diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 6c7efb3c161b0..537faf2a9194f 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -2669,6 +2669,9 @@ inline std::tuple _take_along_dim_helper( broadcast_shape = infer_size_symint(indices_sizes, self.sym_sizes()); auto self_broadcasted = at::broadcast_to_symint(self, broadcast_shape); + // Wrap negative indices to positive (Python-style) + indices_broadcasted = + indices_broadcasted.remainder(self_broadcasted.size(dim)); return std::make_tuple( std::move(self_broadcasted), std::move(indices_broadcasted), diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index e56163266caa1..4255142614103 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -4914,6 +4914,10 @@ def take_along_dim( broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size()) self_broadcast = broadcast_to(a, broadcast_shape) + # wrap negative indices + dim_size = self_broadcast.size(dim) + indices_broadcast = indices_broadcast % dim_size + return torch.gather(self_broadcast, dim, indices_broadcast) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 0cf0f50c23ef5..5f3454ef54cca 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3007,6 +3007,14 @@ def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs yield SampleInput( make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device)) + # Negative indices sample — guarded against python_ref + if not kwargs.get('is_python_ref', False): + neg_idx = gather_variable((S, S), 1, S, True, device=device) - S + yield SampleInput( + make_arg((S, S)), + neg_idx, + 1) + def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs): From 6936e335f9b8ad0a83b65aa8af32b6b350e6f7d6 Mon Sep 17 00:00:00 2001 From: mansiag05 Date: Tue, 25 Nov 2025 00:20:21 +0000 Subject: [PATCH 029/338] =?UTF-8?q?Adding=20check=20for=20step=20size=3D0?= =?UTF-8?q?=20in=20unfold=20backward=20to=20avoid=20divide=20by=200=20?= =?UTF-8?q?=E2=80=A6=20(#162720)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …or FPE. Fixes #142462 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162720 Approved by: https://github.com/isuruf --- aten/src/ATen/native/UnfoldBackward.cpp | 1 + test/test_shape_ops.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/aten/src/ATen/native/UnfoldBackward.cpp b/aten/src/ATen/native/UnfoldBackward.cpp index ec4a2d7bf64c7..10c8f2b1a7bd0 100644 --- a/aten/src/ATen/native/UnfoldBackward.cpp +++ b/aten/src/ATen/native/UnfoldBackward.cpp @@ -21,6 +21,7 @@ Tensor unfold_backward( int64_t size, int64_t step ) { + TORCH_CHECK_VALUE(step > 0, "step is ", step, " but must be > 0"); auto grad_input = at::zeros(input_sizes, grad.options()); if (step >= size) { auto gI_unfolded = grad_input.unfold(dim, size, step); diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py index 24c8122d5aeec..c8a06a49b5975 100644 --- a/test/test_shape_ops.py +++ b/test/test_shape_ops.py @@ -843,6 +843,16 @@ def test_unfold_errors(self, device): with self.assertRaisesRegex(RuntimeError, "step is -1 but must be > 0"): x.unfold(0, 1, -1) + def test_unfold_backward_errors(self, device): + grad_in = torch.randn(2, 3, device=device) + input_sizes = [6] + + with self.assertRaisesRegex(ValueError, "step is 0 but must be > 0"): + torch.ops.aten.unfold_backward(grad_in, input_sizes, 0, 3, 0) + + with self.assertRaisesRegex(RuntimeError, "size is -1 but must be >= 0"): + torch.ops.aten.unfold_backward(grad_in, input_sizes, 0, -1, 1) + instantiate_device_type_tests(TestShapeOps, globals()) From b5a4cde874bc667652facede4161d64b68a8db01 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 25 Nov 2025 00:22:18 +0000 Subject: [PATCH 030/338] [6/N] Use key in dict for existence checks (#168350) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168350 Approved by: https://github.com/albanD --- .../maml_omniglot/support/omniglot_loaders.py | 2 +- test/functorch/discover_coverage.py | 12 +++++------ test/inductor/test_codecache.py | 20 +++++++++---------- test/inductor/test_decompose_mem_bound_mm.py | 2 +- test/inductor/test_fuzzer.py | 4 ++-- test/inductor/test_group_batch_fusion.py | 2 +- test/inductor/test_kernel_optimization.py | 2 +- test/inductor/test_loop_ordering.py | 2 +- test/inductor/test_max_autotune.py | 2 +- test/inductor/test_profiler.py | 2 +- test/inductor/test_quantization.py | 2 +- .../inductor/test_split_cat_fx_aten_passes.py | 2 +- test/inductor/test_split_cat_fx_passes.py | 2 +- test/optim/test_lrscheduler.py | 16 +++++++-------- 14 files changed, 36 insertions(+), 36 deletions(-) diff --git a/functorch/examples/maml_omniglot/support/omniglot_loaders.py b/functorch/examples/maml_omniglot/support/omniglot_loaders.py index ccba01ce181e8..b405174e58f1c 100644 --- a/functorch/examples/maml_omniglot/support/omniglot_loaders.py +++ b/functorch/examples/maml_omniglot/support/omniglot_loaders.py @@ -171,7 +171,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): temp = {} # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} for img, label in self.x: - if label in temp.keys(): + if label in temp: temp[label].append(img) else: temp[label] = [img] diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py index 2ac21e56c5c9c..c0c0db8762bde 100644 --- a/test/functorch/discover_coverage.py +++ b/test/functorch/discover_coverage.py @@ -90,7 +90,7 @@ def get_public_overridable_apis(pytorch_root="/raid/rzou/pt/debug-cpu"): def get_method_only_ops_we_care_about(): apis = get_public_overridable_apis() result = [] - for key in apis.keys(): + for key in apis: if not key.startswith("torch.Tensor"): continue if key in denylist: @@ -99,7 +99,7 @@ def get_method_only_ops_we_care_about(): # filter out in-place if api.endswith("_"): continue - if f"torch.{api}" not in apis.keys(): + if f"torch.{api}" not in apis: result.append(api) return result @@ -110,11 +110,11 @@ def get_method_only_ops_we_care_about(): def get_public_overridable_ops(): results = get_public_overridable_apis() cpy = copy.deepcopy(results) - for key in cpy.keys(): + for key in cpy: if not key.startswith("torch.Tensor"): continue api = key.split(".")[2] - if f"torch.{api}" in results.keys(): + if f"torch.{api}" in results: del results[key] return results @@ -122,7 +122,7 @@ def get_public_overridable_ops(): def get_public_overridable_outplace_ops(): results = get_public_overridable_ops() cpy = copy.deepcopy(results) - for key in cpy.keys(): + for key in cpy: # NB: there are no dunder methods bcs we don't document those if key.endswith("_"): del results[key] @@ -132,7 +132,7 @@ def get_public_overridable_outplace_ops(): def get_public_overridable_outplace_we_care_about(): results = get_public_overridable_outplace_ops() cpy = copy.deepcopy(results) - for key in cpy.keys(): + for key in cpy: # quantization if "quant" in key or ".q_" in key: del results[key] diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 4b9030b5cae4b..1ab261051f4c6 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -521,7 +521,7 @@ def fn(x, y): self.assertEqual(global_stats.fx_graph, Stats(2, 3, 2)) # Check that the cache entries seem reasonable - for k in global_stats.fx_graph.cache.keys(): + for k in global_stats.fx_graph.cache: self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+") @requires_triton() @@ -2955,9 +2955,9 @@ def f(x, y, a, b): self.assertEqual(global_stats.autotune_remote, Stats(2, 2, 2)) # Check that the cache entries seem reasonable - for k in global_stats.autotune_remote.cache.keys(): + for k in global_stats.autotune_remote.cache: self.assertRegex(k, r"[0-9a-z]{52}") - for k in global_stats.triton.cache.keys(): + for k in global_stats.triton.cache: self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") @requires_gpu_and_triton @@ -2996,9 +2996,9 @@ def f(x, y, a, b): self.assertEqual(global_stats.autotune_remote, Stats(2, 2, 2)) # Check that the cache entries seem reasonable - for k in global_stats.autotune_remote.cache.keys(): + for k in global_stats.autotune_remote.cache: self.assertRegex(k, r"[0-9a-z]{52}") - for k in global_stats.triton.cache.keys(): + for k in global_stats.triton.cache: self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") @requires_gpu_and_triton @@ -3054,11 +3054,11 @@ def f(a, b, c, d, e, f): self.assertEqual(global_stats.bundled_autotune, Stats(1, 1, 1)) # Check that the cache entries seem reasonable - for k in global_stats.autotune_local.cache.keys(): + for k in global_stats.autotune_local.cache: self.assertRegex(k, r"tmp[^/]*/([^/]{2})/[^/]{64}\.best_config") - for k in global_stats.bundled_autotune.cache.keys(): + for k in global_stats.bundled_autotune.cache: self.assertRegex(k, r"pt2:bundled-autotune-v1::[0-9a-z]{64}:c[0-9]+") - for k in global_stats.triton.cache.keys(): + for k in global_stats.triton.cache: self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") @requires_triton() @@ -3159,10 +3159,10 @@ def f(a, b): self.assertEqual(global_stats.fx_graph, Stats(2, 1, 2)) # Check that the cache entries seem reasonable - for k in global_stats.aot_autograd.cache.keys(): + for k in global_stats.aot_autograd.cache: self.assertRegex(k, r"pt2:autograd-experimental::[0-9a-z]{52}:c[0-9]+") - for k in global_stats.fx_graph.cache.keys(): + for k in global_stats.fx_graph.cache: self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+") @requires_gpu_and_triton diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py index 4c07bc3e295aa..e880ed0d3573a 100644 --- a/test/inductor/test_decompose_mem_bound_mm.py +++ b/test/inductor/test_decompose_mem_bound_mm.py @@ -84,7 +84,7 @@ def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None): self.setup_tolerance(rtol, atol) if len(set(ref_dict.keys())) != len(set(res_dict.keys())): return False - for key1 in ref_dict.keys(): + for key1 in ref_dict: key2 = "_orig_mod." + key1 assert key2 in res_dict, f"{key1} does not exist in traced module" if not torch.allclose( diff --git a/test/inductor/test_fuzzer.py b/test/inductor/test_fuzzer.py index d08f4c9282fa4..90871b3524d5e 100644 --- a/test/inductor/test_fuzzer.py +++ b/test/inductor/test_fuzzer.py @@ -150,7 +150,7 @@ def myfn(): self.assertEqual(len(new_results), 1) self.assertEqual( set(key_1.keys()), - {j for i in new_results.keys() for j in i} + {j for i in new_results.keys() for j in i} # noqa: SIM118 - set(MODULE_DEFAULTS["torch._inductor.config"].keys()), ) @@ -184,7 +184,7 @@ def myfn(): self.assertEqual(len(new_results), 1) self.assertEqual( set(key_1.keys()), - {j for i in new_results.keys() for j in i} + {j for i in new_results for j in i} # noqa: SIM118 - set(MODULE_DEFAULTS["torch._dynamo.config"].keys()), ) diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 7111e10a69fc6..adccebe785916 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -322,7 +322,7 @@ class TestGroupBatchFusion(TestCase): def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): if len(set(ref_dict.keys())) != len(set(res_dict.keys())): return False - for key1 in ref_dict.keys(): + for key1 in ref_dict: key2 = "_orig_mod." + key1 assert key2 in res_dict, f"{key1} does not exist in traced module" if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): diff --git a/test/inductor/test_kernel_optimization.py b/test/inductor/test_kernel_optimization.py index b5ec255129805..dce810fd2cd14 100644 --- a/test/inductor/test_kernel_optimization.py +++ b/test/inductor/test_kernel_optimization.py @@ -32,7 +32,7 @@ class TestKernelOptimization(TestCase): def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): if len(set(ref_dict.keys())) != len(set(res_dict.keys())): return False - for key1 in ref_dict.keys(): + for key1 in ref_dict: key2 = "_orig_mod." + key1 assert key2 in res_dict, f"{key1} does not exist in traced module" if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 60b4ce077bfcd..8be54c4adc022 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -812,7 +812,7 @@ def fn(nodes): n0, n1 = list(fused_norm_read_writes.var_ranges.keys()) # translation of above is n0 + 6 * n1 - self.assertTrue((n0 + 6 * n1) in fused_norm_read_writes.reads.keys()) + self.assertTrue((n0 + 6 * n1) in fused_norm_read_writes.reads) return nodes diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 90714b58951b1..db34336aeda99 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1405,7 +1405,7 @@ def test_inf_timing(self, multi_template): def mock_lookup(self, *args, **kwargs): timings = lookup(self, *args, **kwargs) - return {choice: float("inf") for choice in timings.keys()} + return {choice: float("inf") for choice in timings} a = torch.zeros([16, 16], device=GPU_TYPE) b = torch.zeros([16, 16], device=GPU_TYPE) diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index be35a2aedfe9e..b4e671c9ba68e 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -269,7 +269,7 @@ def fn(a, b, c): triton_events = [ event for event in trace_json["traceEvents"] - if "kernel_backend" in event.get("args", {}).keys() + if "kernel_backend" in event.get("args", {}) ] print(triton_events) diff --git a/test/inductor/test_quantization.py b/test/inductor/test_quantization.py index ecc46d00d1b87..0f137703d4f82 100644 --- a/test/inductor/test_quantization.py +++ b/test/inductor/test_quantization.py @@ -66,7 +66,7 @@ class TestQuantization(TestCase): def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): if len(set(ref_dict.keys())) != len(set(res_dict.keys())): return False - for key1 in ref_dict.keys(): + for key1 in ref_dict: key2 = "_orig_mod." + key1 assert key2 in res_dict, f"{key1} does not exist in traced module" # if both of them are None, continue diff --git a/test/inductor/test_split_cat_fx_aten_passes.py b/test/inductor/test_split_cat_fx_aten_passes.py index a575c3b71374b..9b2f62e27488e 100644 --- a/test/inductor/test_split_cat_fx_aten_passes.py +++ b/test/inductor/test_split_cat_fx_aten_passes.py @@ -224,7 +224,7 @@ class TestSplitCatAten(TestCase): def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): if len(set(ref_dict.keys())) != len(set(res_dict.keys())): return False - for key1 in ref_dict.keys(): + for key1 in ref_dict: key2 = "_orig_mod." + key1 assert key2 in res_dict, f"{key1} does not exist in traced module" if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index 4286bdfda7cd9..c1fc5ab8dd93f 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -1547,7 +1547,7 @@ def fn(x, y): numpy_compat_normalization(fn_t.graph) for n in fn_t.graph.nodes: - for k in n.kwargs.keys(): + for k in n.kwargs: self.assertTrue(k not in {"x", "x1", "x2", "a", "axis", "keepdims"}) @patch diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index 797822ea4deee..34066e633e844 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -2129,7 +2129,7 @@ def test_reduce_lr_on_plateau_state_dict(self): self.opt, mode="max", factor=0.5, patience=10 ) scheduler_copy.load_state_dict(scheduler.state_dict()) - for key in scheduler.__dict__.keys(): + for key in scheduler.__dict__: if key not in {"optimizer", "is_better"}: self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) @@ -2140,7 +2140,7 @@ def test_lambda_lr_state_dict_fn(self): scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x) scheduler_copy.load_state_dict(state) - for key in scheduler.__dict__.keys(): + for key in scheduler.__dict__: if key not in {"optimizer", "lr_lambdas"}: self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) @@ -2151,7 +2151,7 @@ def test_lambda_lr_state_dict_obj(self): scheduler_copy = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(-1)) scheduler_copy.load_state_dict(state) - for key in scheduler.__dict__.keys(): + for key in scheduler.__dict__: if key not in {"optimizer"}: self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) @@ -2176,7 +2176,7 @@ def _check_scheduler_state_dict(self, constr, constr2, epochs=10): scheduler.step() scheduler_copy = constr2() scheduler_copy.load_state_dict(scheduler.state_dict()) - for key in scheduler.__dict__.keys(): + for key in scheduler.__dict__: if key != "optimizer": self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr()) @@ -2328,7 +2328,7 @@ def _test_cycle_lr( ): for batch_num in range(batch_iterations): if verbose: - if "momentum" in self.opt.param_groups[0].keys(): + if "momentum" in self.opt.param_groups[0]: print( "batch{}:\tlr={},momentum={}".format( batch_num, @@ -2336,7 +2336,7 @@ def _test_cycle_lr( self.opt.param_groups[0]["momentum"], ) ) - elif use_beta1 and "betas" in self.opt.param_groups[0].keys(): + elif use_beta1 and "betas" in self.opt.param_groups[0]: print( "batch{}:\tlr={},beta1={}".format( batch_num, @@ -2364,7 +2364,7 @@ def _test_cycle_lr( rtol=0, ) - if use_beta1 and "betas" in param_group.keys(): + if use_beta1 and "betas" in param_group: self.assertEqual( momentum_target[batch_num], param_group["betas"][0], @@ -2376,7 +2376,7 @@ def _test_cycle_lr( atol=1e-5, rtol=0, ) - elif "momentum" in param_group.keys(): + elif "momentum" in param_group: self.assertEqual( momentum_target[batch_num], param_group["momentum"], From a2413b9ced3bbd4193f776193461dd1a51dc5bfb Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Mon, 24 Nov 2025 13:30:32 -0800 Subject: [PATCH 031/338] [precompile] move strict_autograd_cache=True decorator to aot_compile_fullgraph (#169008) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169008 Approved by: https://github.com/zhxchen17 ghstack dependencies: #168989 --- torch/_dynamo/aot_compile.py | 1 + torch/_dynamo/eval_frame.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py index 20259b4595af7..196c073d6df99 100644 --- a/torch/_dynamo/aot_compile.py +++ b/torch/_dynamo/aot_compile.py @@ -185,6 +185,7 @@ def aot_compile_fullgraph( with ( get_metrics_context(), dynamo_timed("fullgraph_capture"), + torch._functorch.config.patch(strict_autograd_cache=True), ): capture_output = convert_frame.fullgraph_capture(model, args, kwargs) graph_capture_output = capture_output.graph_capture_output diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 1075b6d66f7c8..a9091767f70fd 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -802,15 +802,14 @@ def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: assert self._hooks is not None - with torch._functorch.config.patch(strict_autograd_cache=True): - return aot_compile_fullgraph( - fn, - example_inputs, - hooks=self._hooks, - backend=innermost_fn( - self.callback, unaltered_fn_attr="_torchdynamo_orig_backend" - ), - ) + return aot_compile_fullgraph( + fn, + example_inputs, + hooks=self._hooks, + backend=innermost_fn( + self.callback, unaltered_fn_attr="_torchdynamo_orig_backend" + ), + ) # add context containing GraphModule to any GraphModule forward functions if isinstance(fn, GraphModule): From 7c350369a719992df60352c8540d6432e001fb79 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Tue, 25 Nov 2025 00:16:42 +0000 Subject: [PATCH 032/338] [xpu][fix] Refine memory pool logic when expandable segement enabled (#168956) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Motivation This is a bug in the interaction between the memory pool for XPUGraph and expandable segments. When `unmap_block` is called, the allocator decreases `allocation_count` as expected: (see lines 862–867) https://github.com/pytorch/pytorch/blob/265397e178dab071294f6a10e35226fe333b2983/c10/xpu/XPUCachingAllocator.cpp#L862-L867 However, when an expandable segment is created via `try_allocate_expandable_block`, we never increment `allocation_count`. As a result, `allocation_count` can drop below its correct value after unmapping. # Solution This patch fixes the issue by ensuring `allocation_count` is incremented when creating a new expandable segment. # Additional Context PyTorch currently does not support using a custom allocator together with the expandable-segment feature in the memory pool. Therefore, we add an assertion to fail fast when this unsupported condition is detected. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168956 Approved by: https://github.com/EikanWang --- c10/xpu/XPUCachingAllocator.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index d7eeb10caba1b..d97388c8703be 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -717,14 +717,23 @@ class DeviceCachingAllocator { if (isRetry) { stats.num_alloc_retries += 1; } + bool active_pool = + p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator(); if (set_fraction && stats.reserved_bytes[static_cast(StatType::AGGREGATE)].current + size > allowed_memory_maximum) { return false; } else if (AcceleratorAllocatorConfig::use_expandable_segments()) { + TORCH_CHECK( + !active_pool, + "torch.xpu.MemPool doesn't currently support expandable_segments."); p.block = try_allocate_expandable_block(device, p.queue(), p.pool, p.size()); + if (p.block && p.pool->owner_PrivatePool) { + // The block is used only for XPU graph's PrivatePool. + p.pool->owner_PrivatePool->allocation_count++; + } return bool(p.block); } void* ptr = sycl::aligned_alloc_device( From a928c9d9ea39b056c22cfe92371539907e36307e Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Tue, 25 Nov 2025 01:38:50 +0000 Subject: [PATCH 033/338] Remove useless parent method delegation (#168355) Remove redundant parent method delegations in Python code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168355 Approved by: https://github.com/Lucaskabela --- torch/_dynamo/variables/dicts.py | 5 ----- torch/_dynamo/variables/lists.py | 9 --------- torch/_inductor/ir.py | 6 ------ torch/ao/nn/intrinsic/qat/modules/conv_fused.py | 4 ---- .../fx/experimental/migrate_gradual_types/constraint.py | 6 ------ torch/testing/_internal/common_subclass.py | 3 --- 6 files changed, 33 deletions(-) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 7a74f487ff96c..93af8c46de01c 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -1341,11 +1341,6 @@ def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass - def install_dict_contains_guard( - self, tx: "InstructionTranslator", args: list[VariableTracker] - ) -> None: - super().install_dict_contains_guard(tx, args) - class FrozensetVariable(SetVariable): def debug_repr(self) -> str: diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 05129fcf8fb45..4f21e35479fb8 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -1153,15 +1153,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach(self.items) codegen.append_output(create_build_tuple(len(self.items))) - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "__class__": source = AttrSource(self.source, name) if self.source else None diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0f29d38cb44d0..6a2183f42886a 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -958,9 +958,6 @@ def _to_str(self, names: Sequence[str]) -> str: + [f"origin_node={self.origin_node!r}"] ) - def __post_init__(self) -> None: - super().__post_init__() - def __str__(self) -> str: return self._to_str(("ranges",)) @@ -8230,9 +8227,6 @@ def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any: # pyrefly: ignore [bad-return] return outputs - def apply_constraint(self) -> None: - return super().apply_constraint() - @ir_dataclass(frozen=False) class ComplexView(FallbackKernel): diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 1e49a274e129c..e52baa6a4e730 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -261,10 +261,6 @@ def _forward_slow(self, input): return conv_bn - def extra_repr(self): - # TODO(jerryzh): extend - return super().extra_repr() - def forward(self, input): return self._forward(input) diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index 388d716245d4f..e46b3a607044a 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -138,9 +138,6 @@ def __init__(self, lhs, rhs, op): ) super().__init__(lhs, rhs, op) - def __eq__(self, other): - return super().__eq__(other) - class BinConstraintD(BinaryConstraint): """ @@ -153,9 +150,6 @@ def __init__(self, lhs, rhs, op): super().__init__(lhs, rhs, op) - def __eq__(self, other): - return super().__eq__(other) - class TGreatestUpperBound(Constraint): """ diff --git a/torch/testing/_internal/common_subclass.py b/torch/testing/_internal/common_subclass.py index 3aeb78035cb84..cca291133d3e9 100644 --- a/torch/testing/_internal/common_subclass.py +++ b/torch/testing/_internal/common_subclass.py @@ -200,9 +200,6 @@ def wrap(e): rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))) return rs - # To show how things happen later - def __rmul__(self, other): - return super().__rmul__(other) _SPECIAL_IMPLS = {} From b3bc797d23fa1430197bb2d6d1d70a3469daa91a Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Wed, 19 Nov 2025 02:02:48 +0000 Subject: [PATCH 034/338] [Inductor][Quant]Support qconv_pointwise.tensor and qconv2d_pointwise.binary_tensor (#166608) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166608 Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jansel --- test/inductor/test_mkldnn_pattern_matcher.py | 84 ++++++++++++++++++++ torch/_inductor/fx_passes/quantization.py | 74 +++++++++++------ torch/_inductor/mkldnn_ir.py | 6 +- torch/_inductor/mkldnn_lowerings.py | 62 ++++++++++----- torch/_meta_registrations.py | 2 + 5 files changed, 181 insertions(+), 47 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index c135d05f060f1..a793a052c059d 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1164,6 +1164,25 @@ def matcher_check_fn(): quantization_with_autocast=quantization_with_autocast, ) + if torch._inductor.config.cpp_wrapper: + self._test_code_common( + mod, + (v,), + [f"aoti_torch_{device}__qconv_pointwise_tensor"], + [], + check_quantization=True, + num_include_ops=[3], + ) + else: + self._test_code_common( + mod, + (v,), + ["torch.ops.onednn.qconv_pointwise.tensor"], + [], + check_quantization=True, + num_include_ops=[3], + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm @@ -1270,6 +1289,25 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, ) + if torch._inductor.config.cpp_wrapper: + self._test_code_common( + mod, + (v,), + [f"aoti_torch_{device}__qconv_pointwise_tensor"], + [], + check_quantization=True, + num_include_ops=[2], + ) + else: + self._test_code_common( + mod, + (v,), + ["torch.ops.onednn.qconv_pointwise.tensor"], + [], + check_quantization=True, + num_include_ops=[2], + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_relu_cpu(self): @@ -1548,6 +1586,32 @@ def matcher_check_fn(): check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, ) + if not TEST_ACL: + if torch._inductor.config.cpp_wrapper: + self._test_code_common( + mod, + (v,), + [ + f"aoti_torch_{device}__qconv_pointwise_tensor", + f"aoti_torch_{device}__qconv2d_pointwise_binary_tensor", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + ) + else: + self._test_code_common( + mod, + (v,), + [ + "torch.ops.onednn.qconv_pointwise.tensor", + "torch.ops.onednn.qconv2d_pointwise.binary_tensor", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + ) + def _qconv2d_add_test_helper2( self, device="cpu", use_relu=False, int8_mixed_bf16=False ): @@ -1645,6 +1709,26 @@ def matcher_check_fn(): check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, ) + if not TEST_ACL: + if torch._inductor.config.cpp_wrapper: + self._test_code_common( + mod, + (x, x2, x3), + [f"aoti_torch_{device}__qconv2d_pointwise_binary_tensor"], + [], + check_quantization=True, + num_include_ops=[2], + ) + else: + self._test_code_common( + mod, + (x, x2, x3), + ["torch.ops.onednn.qconv2d_pointwise.binary_tensor"], + [], + check_quantization=True, + num_include_ops=[2], + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_add_cpu(self): diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index a0567da118109..ceb0ce3a2f6e6 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -179,9 +179,14 @@ def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): ) -def get_qconv_pt2e_pattern(users=1): +def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): + qconv_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) return CallFunction( - torch.ops.onednn.qconv_pointwise.default, + qconv_op, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), @@ -203,9 +208,14 @@ def get_qconv_pt2e_pattern(users=1): ) -def get_qconv2d_binary_pt2e_pattern(users=1): +def get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): + qconv_op = ( + torch.ops.onednn.qconv2d_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv2d_pointwise.binary + ) return CallFunction( - torch.ops.onednn.qconv2d_pointwise.binary, + qconv_op, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), @@ -431,7 +441,13 @@ def qconv(match: Match, *args, **kwargs): kwargs["groups"], ) output_dtype = _get_pattern_output_dtype(match) - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float8_e4m3fn, + torch.float32, + torch.bfloat16, + ] # Output QParams o_inv_scale = kwargs["output_scale"] o_zero_point = kwargs["output_zero_point"] @@ -816,12 +832,17 @@ def qconv_binary(match: Match, *args, **kwargs): def _register_quantization_unary_lowering(): # QConv2d - for users in [1, 2]: - qconv_pattern = get_qconv_pt2e_pattern(users) + for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]): + qconv_pattern = get_qconv_pt2e_pattern(x_scale_zp_are_tensors, users) + computation_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) _register_quantized_conv_lowering( qconv_pattern, 2, # pass_number - torch.ops.onednn.qconv_pointwise.default, # computation_op + computation_op, ) # QLinear @@ -841,12 +862,17 @@ def _register_quantization_unary_lowering(): def _register_quantization_binary_lowering(): # QConv2d - for users in (1, 2): - qconv_pattern = get_qconv2d_binary_pt2e_pattern(users) + for x_scale_zp_are_tensors, users in itertools.product([False, True], [1, 2]): + qconv_pattern = get_qconv2d_binary_pt2e_pattern(x_scale_zp_are_tensors, users) + computation_op = ( + torch.ops.onednn.qconv2d_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv2d_pointwise.binary + ) _register_quantized_conv_binary_lowering( qconv_pattern, 2, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + computation_op, ) # QLinear @@ -3027,13 +3053,13 @@ def _register_qconv_unary_fusion(): PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(users=1), aten.relu.default ), ), PostOpAttr( @@ -3041,7 +3067,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), 1, is_bf16, ), @@ -3052,7 +3078,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), 2, is_bf16, ), @@ -3063,7 +3089,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), 2, is_bf16, ), @@ -3083,14 +3109,14 @@ def _register_qconv_unary_fusion(): # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(users=1), aten.relu.default ), PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), 1, is_bf16, ), @@ -3102,7 +3128,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), 2, is_bf16, ), @@ -3114,7 +3140,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(users=1 if is_bf16 else 2), 2, is_bf16, ), @@ -3146,7 +3172,7 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -3158,7 +3184,7 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -3185,7 +3211,7 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -3223,7 +3249,7 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 0fb7bde84450d..0040d77a00afd 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -603,7 +603,7 @@ def __init__( inputs, constant_args, None, - op_overload=torch.ops.onednn.qconv_pointwise.default, + op_overload=torch.ops.onednn.qconv_pointwise.tensor, cpp_kernel_name=f"aoti_torch_{self.device_type}__qconv_pointwise_tensor", ) @@ -623,7 +623,7 @@ def create( x_zero_point: Union["ShapeAsConstantBuffer", "TensorBox"], qw: "TensorBox", # qw w_scale: "TensorBox", - w_zero_point: "TensorBox", + w_zero_point, bias: "TensorBox", stride: list[int], padding: list[int], @@ -711,7 +711,7 @@ def __init__( inputs, constant_args, None, - op_overload=torch.ops.onednn.qconv2d_pointwise.binary, + op_overload=torch.ops.onednn.qconv2d_pointwise.binary_tensor, cpp_kernel_name=( f"aoti_torch_{self.device_type}__qconv2d_pointwise_binary_tensor" ), diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 65261b2dff61b..14b492aff35ad 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -538,7 +538,7 @@ def qconvolution_unary( x_zp, packed_weight: TensorBox, w_scale: TensorBox, - w_zp: TensorBox, + w_zp, bias: TensorBox, stride, padding, @@ -551,15 +551,26 @@ def qconvolution_unary( scalars, algorithm, ): - # To align with qlinear where x_scale and x_zp are converted to Tensor - assert type(x_scale) is float - x_scale = V.graph.add_tensor_constant( - torch.tensor(x_scale, dtype=torch.float32), name="x_scale" - ) - assert type(x_zp) is int - x_zp = V.graph.add_tensor_constant( - torch.tensor(x_zp, dtype=torch.int32), name="x_zp" - ) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) is float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + + if x_zp is None: + x_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) is int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + + if w_zp is None: + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) return TensorBox.create( mkldnn_ir.QConvPointWisePT2E.create( @@ -595,7 +606,7 @@ def qconvolution_binary( x_zp, packed_weight: TensorBox, w_scale: TensorBox, - w_zp: TensorBox, + w_zp, accum: TensorBox, bias: TensorBox, stride, @@ -613,15 +624,26 @@ def qconvolution_binary( unary_scalars, unary_algorithmm, ): - # To align with qlinear where x_scale and x_zp are converted to Tensor - assert type(x_scale) is float - x_scale = V.graph.add_tensor_constant( - torch.tensor(x_scale, dtype=torch.float32), name="x_scale" - ) - assert type(x_zp) is int - x_zp = V.graph.add_tensor_constant( - torch.tensor(x_zp, dtype=torch.int32), name="x_zp" - ) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) is float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + + if x_zp is None: + x_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) is int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + + if w_zp is None: + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) if ( binary_attr == "sum" diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 2ed88a4ec2344..cd397a0bc29c9 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2552,6 +2552,7 @@ def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size): @register_meta(torch.ops.onednn.qconv2d_pointwise.default) @register_meta(torch.ops.onednn.qconv_pointwise.default) + @register_meta(torch.ops.onednn.qconv_pointwise.tensor) def meta_qconv_pointwise( x, x_scale, @@ -2603,6 +2604,7 @@ def meta_qconv_pointwise( return out @register_meta(torch.ops.onednn.qconv2d_pointwise.binary) + @register_meta(torch.ops.onednn.qconv2d_pointwise.binary_tensor) def meta_qconv2d_pointwise_binary( x, x_scale, From ca3e8b315f89c9f055b301ce31a939d8936ee4d1 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Tue, 25 Nov 2025 01:50:11 +0000 Subject: [PATCH 035/338] [1/N] Use TYPE_CHECKING (#165852) This PR moves typing imports into the `TYPE_CHECKING` block. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165852 Approved by: https://github.com/Lucaskabela --- torch/_export/utils.py | 4 ++-- torch/ao/nn/qat/dynamic/modules/linear.py | 2 +- torch/distributed/checkpoint/state_dict_saver.py | 9 +++++---- torch/fx/passes/_tensorify_python_scalars.py | 11 +++++++---- torch/onnx/_internal/exporter/_torchlib/ops/nn.py | 2 +- torch/onnx/_internal/fx/type_utils.py | 2 +- 6 files changed, 17 insertions(+), 13 deletions(-) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 3828dc97ac9bc..50a921a936d7d 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -24,6 +24,8 @@ if TYPE_CHECKING: + import sympy + from torch._export.passes.lift_constants_pass import ConstantAttrMap from torch._ops import OperatorBase from torch.export import ExportedProgram @@ -433,8 +435,6 @@ def _check_symint( def _check_input_constraints_for_graph( input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints ) -> None: - import sympy # noqa: TC002 - if len(flat_args_with_path) != len(input_placeholders): raise RuntimeError( "Unexpected number of inputs " diff --git a/torch/ao/nn/qat/dynamic/modules/linear.py b/torch/ao/nn/qat/dynamic/modules/linear.py index 689a5361a7903..dc2238eedf6f9 100644 --- a/torch/ao/nn/qat/dynamic/modules/linear.py +++ b/torch/ao/nn/qat/dynamic/modules/linear.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: - from torch.ao.quantization.qconfig import QConfig # noqa: TC004 + from torch.ao.quantization.qconfig import QConfig __all__ = ["Linear"] diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 38ab2dcb510a8..204c2be176d14 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -6,15 +6,12 @@ from concurrent.futures import Future from dataclasses import dataclass from enum import Enum -from typing import cast, Optional, Union +from typing import cast, Optional, TYPE_CHECKING, Union from typing_extensions import deprecated import torch import torch.distributed as dist from torch.distributed._state_dict_utils import STATE_DICT_TYPE -from torch.distributed.checkpoint._async_executor import ( # noqa: TC001 - _AsyncCheckpointExecutor, -) from torch.distributed.checkpoint._async_process_executor import ( _ProcessBasedAsyncCheckpointExecutor, ) @@ -38,6 +35,10 @@ from .utils import _api_bc_check, _DistWrapper, _profile +if TYPE_CHECKING: + from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor + + __all__ = [ "save_state_dict", "save", diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index 089780e84705b..3e4c6c56bddf9 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -2,7 +2,7 @@ import logging import os -from typing import Any, Union +from typing import Any, TYPE_CHECKING, Union from sympy import Integer, Number, Symbol from sympy.logic.boolalg import BooleanAtom @@ -13,16 +13,14 @@ from torch._dynamo.symbolic_convert import TensorifyState from torch._dynamo.utils import get_metrics_context from torch._prims_common import get_computation_dtype -from torch._subclasses import fake_tensor # noqa: TCH001 from torch._subclasses.fake_tensor import FakeTensor from torch._utils_internal import justknobs_check from torch.fx._utils import lazy_format_graph_code -from torch.fx.experimental.symbolic_shapes import ( # noqa: TCH001 +from torch.fx.experimental.symbolic_shapes import ( guard_scalar, has_free_symbols, ShapeEnv, ) -from torch.fx.graph_module import GraphModule # noqa: TCH001 # TODO: refactor from torch.fx.passes.runtime_assert import _get_sym_val @@ -32,6 +30,11 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT +if TYPE_CHECKING: + from torch._subclasses import fake_tensor + from torch.fx.graph_module import GraphModule + + __all__: list[str] = [] log = logging.getLogger(__name__) diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py index 3f165dd0facc3..83eb5278380e1 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -1,7 +1,7 @@ """torch.ops.aten operators under the `core` module.""" # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" # pyrefly: ignore-errors -# ruff: noqa: TCH001,TCH002 +# ruff: noqa: TC001,TC002 # flake8: noqa: B950 from __future__ import annotations diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index 072f9f10e2646..7f6203d1d697c 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: - import onnx.defs # noqa: TCH004 + import onnx.defs # Enable both TorchScriptTensor and torch.Tensor to be tested From a5436a5e8e4ee42d1debf52c2786c7ae0043a434 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 25 Nov 2025 01:59:08 +0000 Subject: [PATCH 036/338] [CPU] add onednn context cache for qlinear to improve performance (#168150) **Summary** We noticed big framework overhead of `qlinear`. It's because to call onednn's primitive, we need to prepare a bunch of data structs as its args, which has big overhead. In the past, such things are cached in the context and attached to torch jit graph. However, Inductor does not support non-tensor data on graph. This PR adds a cache of those data structs by using a static `std::unordered_map`, whose key is weight data address as an `int64` and value is a struct that contains all data needed to run a primitive. This cache is safe in most cases where weight data address won't change during inference and weight data are not reused by different layers. However, since we cannot guarantee the assumption, we define an environment variable `"ONEDNN_CACHE_CONTEXT_UNSAFE"` to control this feature. Users should use it at their own risk. We found >5% E2E performance gain when running ViT with PT2E static quantization on an 6th gen of Intel Xeon CPU. **Test plan** ``` pytest -sv test/test_quantization.py -k "qlinear and pt2e" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168150 Approved by: https://github.com/mingfeima, https://github.com/jerryzh168 --- .../ATen/native/quantized/cpu/OnednnUtils.h | 36 ++++++++++ .../src/ATen/native/quantized/cpu/qlinear.cpp | 69 ++++++++++++++----- test/quantization/core/test_quantized_op.py | 53 +++++++++----- 3 files changed, 125 insertions(+), 33 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h index 963a47a21fa9f..e3fe5c33406b6 100644 --- a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h +++ b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h @@ -462,4 +462,40 @@ at::Tensor _qconv_prepack_onednn( #define FP8E4M3_MAX 448.0 +#define CACHE_ONEDNN_CONTEXT_FLAG "ONEDNN_CACHE_CONTEXT_UNSAFE" + +struct QlinearForwardParams { + dnnl::matmul primitive; + ideep::exec_args args; + ideep::tensor packed_weight; + ideep::tensor weight_scales; + std::optional src_scale; + std::optional src_zero_point; + std::optional dst_scale; + std::optional dst_zero_point; + std::optional bias; + ideep::tensor scratchpad; + + void init_args() { + args.insert({DNNL_ARG_WEIGHTS, packed_weight}); + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); + if (bias.has_value()) { + args.insert({DNNL_ARG_BIAS, bias.value()}); + } + if (src_scale.has_value()) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scale.value()}); + } + if (dst_scale.has_value()) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale.value()}); + } + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, weight_scales}); + if (src_zero_point.has_value()) { + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point.value()}); + } + if (dst_zero_point.has_value()) { + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_point.value()}); + } + } +}; + #endif // #if AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 7a80b166f8cb7..ea1e6456d22d0 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -1147,24 +1147,13 @@ static at::Tensor linear_int8_with_onednn_weight( dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous(); auto src = at::native::itensor_from_tensor(input_contig); - auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight); - int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1); + int64_t K = input.size(dim - 1), M = input.numel() / K, N = onednn_weight.size(1); auto output_size = input.sizes().vec(); output_size[dim - 1] = N; - std::optional onednn_bias{std::nullopt}; bool with_bias = bias.has_value(); - at::Tensor bias_val_float; - if (with_bias) { - bias_val_float = bias.value().to(at::kFloat); - if (bias_val_float.dim() == 1) { - auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)}); - onednn_bias = at::native::itensor_view_from_dense(b_reshape); - } else { - onednn_bias = at::native::itensor_view_from_dense(bias_val_float); - } - } + std::vector src_dims = {M, K}; std::vector dst_dims = {M, N}; auto out_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type(); @@ -1185,6 +1174,39 @@ static at::Tensor linear_int8_with_onednn_weight( at::native::itensor_view_from_dense(other.value().reshape({-1, other.value().size(dim - 1)})) : empty_tensor; + // Fast path with cache of params + static const char* env_var = std::getenv(CACHE_ONEDNN_CONTEXT_FLAG); + static const std::string cache_flag_str = env_var ? std::string(env_var) : ""; + static const bool context_cache_enabled = cache_flag_str != "" && cache_flag_str == "1"; + static std::unordered_map qlinear_forward_params_map; + int64_t weight_addr = at::native::data_ptr_from_mkldnn(onednn_weight); + if (context_cache_enabled) { + auto it = qlinear_forward_params_map.find(weight_addr); + if (it != qlinear_forward_params_map.end()) { + auto& params = it->second; + auto& args = params.args; + args[DNNL_ARG_SRC] = std::move(src); + args[DNNL_ARG_DST] = std::move(dst); + if (binary_post_op == "add") { + args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = std::move(src1); + } + params.primitive.execute(ideep::stream::default_stream(), args); + return dim == 2 ? output : output.resize_(output_size); + } + } + + // Regular path + auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight); + tensor onednn_bias; + if (with_bias) { + at::Tensor bias_val_float = bias.value(); + if (bias_val_float.dim() == 1) { + auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)}); + onednn_bias = at::native::itensor_view_from_dense(b_reshape); + } else { + onednn_bias = at::native::itensor_view_from_dense(bias_val_float); + } + } // Create onednn primitive auto src_dtype = at::native::get_mkldnn_dtype(input.scalar_type()); auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any); @@ -1192,7 +1214,7 @@ static at::Tensor linear_int8_with_onednn_weight( auto dst_dtype = dst.get_data_type(); auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any); auto bias_desc = with_bias ? - tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) : + tensor::desc(onednn_bias.get_dims(), onednn_bias.get_data_type(), ideep::format_tag::any) : empty_tensor_desc; // Get op attr for primitive // Note: output_scale & output_zero_point are for re-quantization of the final output. @@ -1249,7 +1271,7 @@ static at::Tensor linear_int8_with_onednn_weight( args.insert({DNNL_ARG_DST, dst}); args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); if (with_bias) { - args.insert({DNNL_ARG_BIAS, onednn_bias.value()}); + args.insert({DNNL_ARG_BIAS, onednn_bias}); } tensor src_scales_t = tensor(ideep::scale_t(1, input_scale)); tensor wei_scales_t = at::native::itensor_from_tensor(weight_scales); @@ -1273,7 +1295,22 @@ static at::Tensor linear_int8_with_onednn_weight( args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, src1}); } primitive.execute(ideep::stream::default_stream(), args); - return dim == 2 ? output : output.reshape(output_size); + // Update cache if needed + if (context_cache_enabled) { + QlinearForwardParams params; + params.primitive = primitive; + params.packed_weight = expected_weight; + params.weight_scales = wei_scales_t; + params.src_scale = input_scale != 1.0f ? std::make_optional(src_scales_t) : std::nullopt; + params.dst_scale = output_scale != 1.0f ? std::make_optional(dst_scales_t) : std::nullopt; + params.src_zero_point = input_zero_point != 0 ? std::make_optional(src_zp_t) : std::nullopt; + params.dst_zero_point = output_zero_point != 0 ? std::make_optional(dst_zp_t) : std::nullopt; + params.bias = with_bias ? std::make_optional(onednn_bias) : std::nullopt; + params.scratchpad = scratchpad; + params.init_args(); + qlinear_forward_params_map[weight_addr] = params; + } + return dim == 2 ? output : output.resize_(output_size); } #if AT_MKLDNN_ACL_ENABLED() diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 68053cdc61f81..7328870a64227 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4563,7 +4563,11 @@ def _test_qlinear_pt2e_helper( post_op="none", unary_post_op_args=(), post_op_algorithms=("none",), + test_fast_path=False, ): + if test_fast_path: + import os + os.environ["ONEDNN_CACHE_CONTEXT_UNSAFE"] = "1" qlinear_prepack = torch.ops.onednn.qlinear_prepack linear_op = F.linear in_channels_list = [4, 8] @@ -4615,12 +4619,14 @@ def _test_qlinear_pt2e_helper( qw_cpu = qw.int_repr() qw_packed = qlinear_prepack(qw_cpu, x.shape) + num_iter = 2 if test_fast_path else 1 # rerun to use cache if post_op in ("none", "relu", "gelu"): - qy_cpu = qlinear_op( - qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - b, used_y_scale, used_y_zp, output_dtype, - post_op, unary_post_op_args, post_op_algo - ) + for _ in range(num_iter): + qy_cpu = qlinear_op( + qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, + b, used_y_scale, used_y_zp, output_dtype, + post_op, unary_post_op_args, post_op_algo + ) if post_op == "relu": y_ref = F.relu(y_ref) elif post_op == "gelu": @@ -4637,12 +4643,14 @@ def _test_qlinear_pt2e_helper( accum = qx2.int_repr() if output_dtype is None else qx2.dequantize() if bfloat16_out: accum = accum.bfloat16() - qy_cpu = qlinear_op( - qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - accum, b, used_y_scale, used_y_zp, output_dtype, - x2_scale, x2_zp, "sum", binary_alpha, - unary_post_op, unary_post_op_args, post_op_algo - ) + for _ in range(num_iter): + # clone accum otherwise it gets accumulated multiple times + qy_cpu = qlinear_op( + qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, + accum.clone(), b, used_y_scale, used_y_zp, output_dtype, + x2_scale, x2_zp, "sum", binary_alpha, + unary_post_op, unary_post_op_args, post_op_algo + ) y_ref = y_ref + x2 * binary_alpha if unary_post_op == "relu": y_ref = F.relu(y_ref) @@ -4655,12 +4663,13 @@ def _test_qlinear_pt2e_helper( x2 = torch.randn(y_ref.size()) * 10 unary_post_op = "relu" if post_op == "add_relu" else "none" binary_alpha = 1.0 # we only support alpha=1.0 now - qy_cpu = qlinear_op( - qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - x2, b, used_y_scale, used_y_zp, output_dtype, - 1.0, 0, "add", binary_alpha, - unary_post_op, unary_post_op_args, post_op_algo - ) + for _ in range(num_iter): + qy_cpu = qlinear_op( + qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, + x2, b, used_y_scale, used_y_zp, output_dtype, + 1.0, 0, "add", binary_alpha, + unary_post_op, unary_post_op_args, post_op_algo + ) y_ref = y_ref + x2 * binary_alpha if unary_post_op == "relu": y_ref = F.relu(y_ref) @@ -4686,17 +4695,22 @@ def _test_qlinear_pt2e_helper( y_s: {y_scale}, y_zp: {y_zp}""", ) + if test_fast_path: + del os.environ["ONEDNN_CACHE_CONTEXT_UNSAFE"] + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN def test_qlinear_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise self._test_qlinear_pt2e_helper(qlinear, "none") + self._test_qlinear_pt2e_helper(qlinear, "none", test_fast_path=True) @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN def test_qlinear_relu_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise self._test_qlinear_pt2e_helper(qlinear, "relu") + self._test_qlinear_pt2e_helper(qlinear, "relu", test_fast_path=True) @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN @@ -4704,30 +4718,35 @@ def test_qlinear_gelu_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise post_op_algorithms = ['none', 'tanh'] self._test_qlinear_pt2e_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms) + self._test_qlinear_pt2e_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms, test_fast_path=True) @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN def test_qlinear_sum_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_pt2e_helper(qlinear, "sum") + self._test_qlinear_pt2e_helper(qlinear, "sum", test_fast_path=True) @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN def test_qlinear_sum_relu_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_pt2e_helper(qlinear, "sum_relu") + self._test_qlinear_pt2e_helper(qlinear, "sum_relu", test_fast_path=True) @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN def test_qlinear_add_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_pt2e_helper(qlinear, "add") + self._test_qlinear_pt2e_helper(qlinear, "add", test_fast_path=True) @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN def test_qlinear_add_relu_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_pt2e_helper(qlinear, "add_relu") + self._test_qlinear_pt2e_helper(qlinear, "add_relu", test_fast_path=True) def _test_qlinear_fp8_helper( self, From 5a607febc04c3a2b5824c75f3f60307867439a2c Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 27 Nov 2025 20:20:06 +0000 Subject: [PATCH 037/338] Back out "Make PT2 compile backprop through custom op without autograd key a hard error (#166367)" (#168142) Summary: Original commit changeset: 7148dc4803f5 Original Phabricator Diff: D86736500 Differential Revision: D87407335 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168142 Approved by: https://github.com/wdvr --- aten/src/ATen/native/TensorCompare.cpp | 9 -- aten/src/ATen/native/native_functions.yaml | 5 -- test/distributed/test_inductor_collectives.py | 10 ++- test/test_autograd_fallback.py | 11 ++- torch/_functorch/aot_autograd.py | 4 - torch/_higher_order_ops/effects.py | 1 - torch/_library/autograd.py | 11 --- torch/_subclasses/fake_impls.py | 5 -- .../autograd_not_implemented_fallback.cpp | 90 +++++++------------ torch/fx/node.py | 1 - torchgen/native_function_generation.py | 1 - 11 files changed, 48 insertions(+), 100 deletions(-) diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 8a0b38eafab36..1a3843e9cdca8 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -480,14 +479,6 @@ Tensor isfinite(const Tensor& self) { }); } -void _async_error(std::string_view msg) { - TORCH_CHECK(0, msg); -} - -void _async_error_meta(std::string_view msg) { - // Do NOT error, it's an async error! -} - void _assert_async_cpu(const Tensor& self) { TORCH_CHECK( native::is_nonzero(self), diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4fa24ff378d72..81a782f733245 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -192,11 +192,6 @@ CompositeExplicitAutograd: _assert_tensor_metadata Meta: _assert_tensor_metadata_meta_symint -- func: _async_error(str msg) -> () - dispatch: - CompositeExplicitAutograd: _async_error - Meta: _async_error_meta - - func: _print(str s) -> () dispatch: CompositeExplicitAutograd: _print diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index fdf03fdf3a1f5..52062616a8562 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1348,11 +1348,13 @@ def func(inp, *, tag, ranks, group_size): assert counter.op_count == 3 # It generates 2 getattr to unpack the array assert same(out, correct) - # This doesn't work in all cases, and now we properly loudly error. - # See: https://github.com/pytorch/pytorch/issues/151240 - # When differentiable funcols are implemented can revert. - @unittest.expectedFailure def test_backwards(self): + """ + It's probably not that common to need backwards support for collectives. + + However, I wanted to at least see if it was possible to support it as a design goal. + """ + def func(inp): ar = _functional_collectives.all_reduce(inp, "sum", "0") return ar diff --git a/test/test_autograd_fallback.py b/test/test_autograd_fallback.py index 5748b5c4cca4b..d6252ac6f34a3 100644 --- a/test/test_autograd_fallback.py +++ b/test/test_autograd_fallback.py @@ -6,7 +6,6 @@ import numpy as np import torch -from torch._library.autograd import autograd_fallback_mode from torch.library import _scoped_library from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -16,6 +15,16 @@ ) +@contextlib.contextmanager +def autograd_fallback_mode(mode): + prev = torch._C._get_autograd_fallback_mode() + try: + torch._C._set_autograd_fallback_mode(mode) + yield + finally: + torch._C._set_autograd_fallback_mode(prev) + + class TestAutogradFallback(TestCase): test_ns = "_test_autograd_fallback" diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 8555026122ece..9fdebe6396d4b 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -26,7 +26,6 @@ from torch._guards import detect_fake_mode from torch._inductor.cudagraph_utils import BoxedDeviceIndex from torch._inductor.utils import BoxedBool -from torch._library.autograd import autograd_fallback_mode from torch._subclasses import FakeTensor, FakeTensorMode from torch.export._tree_utils import reorder_kwargs from torch.fx.experimental.proxy_tensor import make_fx @@ -529,9 +528,6 @@ def create_aot_state( stack.enter_context( torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing() ) - # Make it an error to backprop through PT2 compliant ops that silently - # detach autograd - stack.enter_context(autograd_fallback_mode("error")) from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj from torch._library.opaque_object import is_opaque_type diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index b2fc74b7328f1..86707a4f55ef1 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -59,7 +59,6 @@ def _get_effect(op: _op_identifier) -> Optional[_EffectType]: _register_effectful_op("aten::_print", _EffectType.ORDERED) -_register_effectful_op("aten::_async_error", _EffectType.ORDERED) _register_effectful_op("profiler::_record_function_exit._RecordFunction", None) _register_effectful_op(call_torchbind, _EffectType.ORDERED) _register_effectful_op(hop_print, _EffectType.ORDERED) diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py index 125ed5b73d8e2..2707d07059edf 100644 --- a/torch/_library/autograd.py +++ b/torch/_library/autograd.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import contextlib import dataclasses from collections.abc import Callable from dataclasses import dataclass @@ -236,16 +235,6 @@ def not_list_of_optional_tensor(tree): return True -@contextlib.contextmanager -def autograd_fallback_mode(mode): - prev = _C._get_autograd_fallback_mode() - try: - _C._set_autograd_fallback_mode(mode) - yield - finally: - _C._set_autograd_fallback_mode(prev) - - flatten = _pytree.tree_flatten unflatten = _pytree.tree_unflatten spec_t = _pytree.TreeSpec diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 530c8d939d77f..ff309af8a29e0 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -223,11 +223,6 @@ def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs): return r -@register_op_impl(aten._async_error.default) -def _async_error(fake_mode, func, msg: str): - pass - - @register_op_impl(aten.to.prim_Device) @register_op_impl(aten.to.device) def non_kwarg_to(fake_mode, func, *args, **kwargs): diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index a4a9afec1a7cc..386a8a9df534d 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -6,12 +6,6 @@ #include #include -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#endif - #include #include #include @@ -70,6 +64,7 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn; } // namespace void setAutogradFallbackMode(AutogradFallbackMode mode) { + TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'"); kAutogradFallbackMode = mode; } @@ -77,60 +72,41 @@ AutogradFallbackMode getAutogradFallbackMode() { return kAutogradFallbackMode; } -static void reportAutogradNotImplemented( - const std::string& op_name, - bool is_warn) { - if (is_warn) { - TORCH_WARN( - op_name, - ": an autograd kernel was not registered to the Autograd key(s) ", - "but we are trying to backprop through it. This may lead to silently incorrect behavior. ", - "This behavior is deprecated and will be removed in a future version of PyTorch. ", - "If your operator is differentiable, please ensure you have registered an " - "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, " - "DispatchKey::CompositeImplicitAutograd). If your operator is not " - "differentiable, or to squash this warning and use the previous behavior, " - "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd."); - } else { - at::_async_error(c10::str( - op_name, - ": an autograd kernel was not registered to the Autograd key(s) ", - "but we are trying to backprop through it. This can lead to silently incorrect behavior. ", - "If your operator is differentiable, please ensure you have registered an " - "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, " - "). If your operator is not " - "differentiable and ensure NO gradients flow through this operator, " - "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.")); - } +static void warnAutogradNotImplemented(const std::string& op_name) { + TORCH_WARN( + op_name, + ": an autograd kernel was not registered to the Autograd key(s) ", + "but we are trying to backprop through it. This may lead to silently incorrect behavior. ", + "This behavior is deprecated and will be removed in a future version of PyTorch. ", + "If your operator is differentiable, please ensure you have registered an " + "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, " + "DispatchKey::CompositeImplicitAutograd). If your operator is not " + "differentiable, or to squash this warning and use the previous behavior, " + "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd."); } -struct NotImplementedBackward : public Node { - NotImplementedBackward( +struct WarnNotImplemented : public Node { + WarnNotImplemented( std::string op_name, size_t num_outputs, - bool is_warn, edge_list&& next_edges) : Node(std::move(next_edges)), op_name(std::move(op_name)), - num_outputs(num_outputs), - is_warn(is_warn) {} + num_outputs(num_outputs) {} - NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn) - : op_name(std::move(op_name)), - num_outputs(num_outputs), - is_warn(is_warn) {} + WarnNotImplemented(std::string op_name, size_t num_outputs) + : op_name(std::move(op_name)), num_outputs(num_outputs) {} variable_list apply(variable_list&& inputs) override; std::string op_name; size_t num_outputs; - bool is_warn; }; // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) -auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list { +auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list { auto inputsLocal = std::move(inputs); - reportAutogradNotImplemented(op_name, is_warn); + warnAutogradNotImplemented(op_name); std::vector output(num_outputs); return output; } @@ -149,6 +125,8 @@ static void basicAutogradNotImplementedFallbackImpl( op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack); return; } + TORCH_INTERNAL_ASSERT( + getAutogradFallbackMode() == AutogradFallbackMode::Warn); bool any_input_requires_grad = false; _foreach_tensor( @@ -164,9 +142,7 @@ static void basicAutogradNotImplementedFallbackImpl( // by putting it after the requires_grad checks. any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled(); - bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn; - - std::shared_ptr grad_fn; + std::shared_ptr grad_fn; if (any_input_requires_grad) { // NB: It is standard to collect edges from all tensors // (see generated/VariableTypeEverything.cpp for examples) @@ -178,9 +154,8 @@ static void basicAutogradNotImplementedFallbackImpl( stack, stack_start, num_arguments); - grad_fn = std::shared_ptr( - new NotImplementedBackward( - op_name, all_tensors_on_stack.size(), is_warn), + grad_fn = std::shared_ptr( + new WarnNotImplemented(op_name, all_tensors_on_stack.size()), deleteNode); grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack)); } @@ -216,8 +191,8 @@ static void basicAutogradNotImplementedFallbackImpl( // >>> y = op(k) // >>> torch.autograd.grad(z.sum(), w) if (t.requires_grad()) { - t.register_hook([op_name, is_warn](const at::Tensor& grad) { - reportAutogradNotImplemented(op_name, is_warn); + t.register_hook([op_name](const at::Tensor& grad) { + warnAutogradNotImplemented(op_name); }); // If history is rebased, then we will attempt to warn // on the view's base. This will catch most cases (because @@ -227,19 +202,18 @@ static void basicAutogradNotImplementedFallbackImpl( const auto& base = t._base(); if (base.requires_grad()) { // Can only register_hook on tensors that require grad. - base.register_hook( - [op_name, is_warn](const at::TensorBase& grad) { - reportAutogradNotImplemented(op_name, is_warn); - }); + base.register_hook([op_name](const at::TensorBase& grad) { + warnAutogradNotImplemented(op_name); + }); } } return; } // If the post-autograd implementation returns any Tensors that - // don't require grad, then we install the NotImplementedBackward - // grad_fn. This grad_fn warns in backward and returns undefined - // tensor gradients. + // don't require grad, then we install the WarnNotImplemented grad_fn. + // This grad_fn warns in backward and returns undefined tensor + // gradients. // // NOTE [autograd fallback and in-place operations] // If the schema says the output is mutable, and the output diff --git a/torch/fx/node.py b/torch/fx/node.py index 294e15c550235..5afabe40ec341 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -90,7 +90,6 @@ _side_effectful_functions: set[Callable[..., Any]] = { torch._assert, torch._assert_async, - _ops.aten._async_error.default, _ops.aten._assert_async.msg, _ops.aten._assert_scalar.default, _ops.aten._assert_tensor_metadata.default, diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 6cbb05682894e..f986c77f8faaa 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -55,7 +55,6 @@ # All of these operators don't have any tensor like returns FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ - "_async_error", "_assert_async", # no return "_assert_async.msg", # no return "_assert_tensor_metadata", # no return From 99024dec888ec1e50b546822a32b6fb2f35e5eaa Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 27 Nov 2025 20:21:01 +0000 Subject: [PATCH 038/338] Revert "Removed deprecated `split_cat_fx_passes` (#167738)" This reverts commit 7833690a37737e9284d1b87d5e0d8db23e9167ac. Reverted https://github.com/pytorch/pytorch/pull/167738 on behalf of https://github.com/wdvr due to split_cat_fx_passes is used much more widely than thought, we'll give teams some time to get rid of this before re-merging ([comment](https://github.com/pytorch/pytorch/pull/167738#issuecomment-3587188011)) --- test/inductor/test_perf.py | 2 ++ torch/_inductor/config.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 5ad37c10b2c1a..8a48bee86ba4e 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -278,6 +278,7 @@ def f(a, b): inp = (T(10, 10), T(10, 10)) self.assertExpectedInline(count_numel(f, *inp), """680""") + @patch.object(config, "split_cat_fx_passes", False) @patch.object( config, "pre_grad_fusion_options", @@ -299,6 +300,7 @@ def f(*inputs): inp = (T(10, 10) for _ in range(16)) self.assertExpectedInline(count_numel(f, *inp), """6400""") + @patch.object(config, "split_cat_fx_passes", False) @patch.object( config, "pre_grad_fusion_options", diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 45fa2d74acaed..7048990692da0 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -303,6 +303,9 @@ def prologue_fusion_enabled() -> bool: ] ] = None +# Deprecated +split_cat_fx_passes = True + # Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. efficient_conv_bn_eval_fx_passes = False From 9844fbeadd5cebdf1281d6fbf79164139c352693 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 27 Nov 2025 20:32:06 +0000 Subject: [PATCH 039/338] Revert "[DTensor] update redistribute_cost, add disable_graph_based_transform (#166747)" This reverts commit a9184a03c8686caf3c8105bb104a70bfe17f1f5e. Reverted https://github.com/pytorch/pytorch/pull/166747 on behalf of https://github.com/atalman due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/166747#issuecomment-3587207604)) --- .../tensor/debug/test_debug_mode.py | 16 +++- test/distributed/tensor/test_op_strategy.py | 31 ------- test/distributed/tensor/test_redistribute.py | 74 --------------- torch/distributed/tensor/_collective_utils.py | 55 ++++------- torch/distributed/tensor/_redistribute.py | 93 ++----------------- 5 files changed, 37 insertions(+), 232 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index c8cc5930d4e67..801cb0ab64219 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -215,8 +215,12 @@ def test_debug_mode_densor_redistribution_trace(self): debug_mode.debug_string(), """\ aten::mm(dt: f32[128, 8]| S(0)[0]S(0)[1], dt: f32[8, 128]| S(1)[0]S(1)[1]) - redistribute_input(1, S(1)[0]S(1)[1] -> RR) - redistribute_input(t: f32[8, 16], trace: S(1)[0]S(1)[1]->S(1)R->RR) + redistribute_input(0, S(0)[0]S(0)[1] -> S(0)R) + redistribute_input(t: f32[16, 8], trace: S(0)[0]S(0)[1]->S(0)R) + _c10d_functional::all_gather_into_tensor(t: f32[16, 8], 2, 3) + _c10d_functional::wait_tensor(t: f32[32, 8]) + redistribute_input(1, S(1)[0]S(1)[1] -> RS(1)) + redistribute_input(t: f32[8, 16], trace: S(1)[0]S(1)[1]->S(1)R->RR->RS(1)) _c10d_functional::all_gather_into_tensor(t: f32[8, 16], 2, 3) _c10d_functional::wait_tensor(t: f32[16, 16]) aten::chunk(t: f32[16, 16], 2) @@ -225,9 +229,11 @@ def test_debug_mode_densor_redistribution_trace(self): _c10d_functional::wait_tensor(t: f32[32, 32]) aten::chunk(t: f32[32, 32], 4) aten::cat(['t: f32[8, 32]', 't: f32[8, 32]', 't: f32[8, 32]', 't: f32[8, 32]'], 1) - aten::mm(t: f32[16, 8], t: f32[8, 128]) - aten::sum(dt: f32[128, 128]| S(0)[0]S(0)[1]) - aten::sum(t: f32[16, 128])""", + aten::chunk(t: f32[8, 128], 2, 1) + aten::clone(t: f32[8, 64]) + aten::mm(t: f32[32, 8], t: f32[8, 64]) + aten::sum(dt: f32[128, 128]| S(0)S(1)) + aten::sum(t: f32[32, 64])""", ) def test_debug_mode_einsum(self): diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index e1d3f96e9e5f4..4819f40c74334 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -380,37 +380,6 @@ def test_bmm_strategies(self): ) self.assertFalse(output_sharding.needs_redistribute) - def test_redistribute_cost_with_order(self): - mesh_2d = DeviceMesh( - self.device_type, torch.arange(self.world_size).reshape(2, 2) - ) - - # Source: Shard on dim 0 across all three mesh dimensions - source_placement = (Shard(0), Shard(0)) - - # Target: Replicate on first mesh dimension, shard on others - # This requires 2 allgathers, one on dim=0 and one on dim=1 - replicate_mesh_dim0 = (Replicate(), Shard(0)) - - # Target: Replicate on second mesh dimension, shard on others - # This requires 1 allgather on dim=1 - replicate_mesh_dim1 = (Shard(0), Replicate()) - - global_tensor = torch.randn(4, 4) - global_tensor_meta = extract_tensor_meta(global_tensor) - - source_spec = DTensorSpec(mesh_2d, source_placement, global_tensor_meta) - target_spec_dim0 = DTensorSpec(mesh_2d, replicate_mesh_dim0, global_tensor_meta) - target_spec_dim1 = DTensorSpec(mesh_2d, replicate_mesh_dim1, global_tensor_meta) - - # Calculate costs for allgather on each mesh dimension - cost_mesh_dim0 = redistribute_cost(source_spec, target_spec_dim0) - cost_mesh_dim1 = redistribute_cost(source_spec, target_spec_dim1) - - # Cost increases with earlier mesh dimensions due to the way - # mesh dimensions are ordered (outer to inner in device hierarchy) - self.assertGreater(cost_mesh_dim0, cost_mesh_dim1) - # -------------Test op strategy registration------------- # custom op without List[Tensor] as input diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index ebb2c5f01668f..ec1d69e9b02e6 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -21,10 +21,6 @@ ) from torch.distributed.tensor._collective_utils import shard_dim_alltoall from torch.distributed.tensor._dtensor_spec import ShardOrderEntry -from torch.distributed.tensor._redistribute import ( - _gen_transform_infos, - use_min_cost_redistribution_plan, -) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -884,76 +880,6 @@ def test_ordered_redistribute(self): ) self.assertEqual(sharded_dt.to_local(), expected_dt.to_local()) - @with_comms - def test_force_min_cost_redistribution_plan(self): - """ - Test that the disable_graph_based_transform context manager correctly controls - the redistribution algorithm selection (graph-based vs greedy). - """ - # Set deterministic seed for reproducible tensor generation - torch.manual_seed(21) - mesh = init_device_mesh(self.device_type, (2, 2, 2)) - input_data = torch.randn((8, 8, 8), device=self.device_type) - - # the redistribution path differs if we use graph-based or greedy search solution - src_placement, src_order = ( - [Shard(0), Shard(0), Shard(0)], # All mesh dims shard tensor dim 0 - ( - ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1, 2)), - ), # Device order: 0→1→2 - ) - dst_placement, dst_order = ( - [Shard(1), Shard(1), Shard(1)], # All mesh dims shard tensor dim 1 - ( - ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 1, 2)), - ), # Device order: 0→1→2 - ) - - # Test both graph-based (enable_graph=True) and greedy (enable_graph=False) algorithms - for idx, enable_graph in enumerate([True, False]): - sharded_dt = _distribute_tensor( - input_data.clone(), mesh, src_placement, shard_order=src_order - ) - - with ( - use_min_cost_redistribution_plan(enabled=enable_graph), - DebugMode(record_torchfunction=False) as debug_mode, - ): - sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order) - trace_str = self._extract_redistribute_trace_from_debug_mode( - debug_mode.debug_string() - ) - - # Validate graph-based algorithm trace (idx=0, disable_graph=False) - # Graph-based uses optimal path search (Dijkstra's algorithm) - # Expected path has 6 transformations with strategic intermediate states - # Path: S(0)[0,1,2] → S(0)[0,1]S(2) → S(0)S(2)[1,0] → - # S(1)S(2)[1,0] → S(1)[0,1]S(2) → S(1)[0,1,2] - if idx == 0: - self.assertExpectedInline( - trace_str, - """S(0)[0]S(0)[1]S(0)[2]->S(0)[0]S(0)[1]S(2)->S(0)S(2)[1]S(2)[0]->S(1)S(2)[1]S(2)[0]->S(1)[0]S(1)[1]S(2)->S(1)[0]S(1)[1]S(1)[2]""", - ) - # Validate greedy algorithm trace (idx=1, disable_graph=True) - # Greedy uses simple heuristic approach (processes mesh dims sequentially) - # Expected path has 6 transformations but with different intermediate states - # Path: S(0)[0,1,2] → S(0)[0,1]R → S(0)RR → - # S(1)RR → S(1)[0,1]R → S(1)[0,1,2] - elif idx == 1: - self.assertExpectedInline( - trace_str, - """S(0)[0]S(0)[1]S(0)[2]->S(0)[0]S(0)[1]R->S(0)RR->S(1)RR->S(1)[0]S(1)[1]R->S(1)[0]S(1)[1]S(1)[2]""", - ) - expected_dt = _distribute_tensor( - input_data.clone(), mesh, dst_placement, shard_order=dst_order - ) - self.assertEqual(sharded_dt.to_local(), expected_dt.to_local()) - - # Clear the transformation cache between iterations. Without this, - # the second iteration would use cached paths from the first, - # causing the trace validation to fail because: - _gen_transform_infos.cache_clear() - @with_comms def test_generate_shard_orders(self): """Check if `generate_shard_orders` generates unique sharding combinations""" diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 90f32efafd395..dff426a6d5e5a 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -227,7 +227,6 @@ def check_tensor_meta( return None -# TODO: autoparallel depends on this function, we will keep it until we update autoparallel redistribute_cost def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: assert spec.tensor_meta is not None, "spec should have tensor meta defined!" return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) @@ -339,61 +338,39 @@ def redistribute_cost( mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) cost = 0.0 + comm_bytes_gb = ( + spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 + ) # Transformation that considered for redistribute cost: # 1. allgather 2. alltoall # 3. allreduce 4. reduce_scatter - from torch.distributed._functional_collectives import _are_we_tracing - from torch.distributed.tensor._redistribute import ( - _gen_transform_infos, - _gen_transform_infos_non_cached, - ) - - # No redistribution needed when placements are already identical. - # This also prevents potential failures in _gen_transform_infos for certain configurations - # (e.g., sub-meshes) where finding a transform path between identical states may error out. - # TODO(zpcore): test placements with _StridedShard. - if current_spec.placements == target_spec.placements: - return cost - if _are_we_tracing(): - transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) - else: - transform_infos = _gen_transform_infos(current_spec, target_spec) - for transform_info in transform_infos: - assert current_spec.tensor_meta is not None, ( - "spec should have tensor meta defined!" - ) - comm_bytes_gb = ( - current_spec.tensor_meta.dtype.itemsize - * math.prod(transform_info.logical_shape) - / 1024 - / 1024 - / 1024 - ) - current = transform_info.src_dst_placements[0] - target = transform_info.src_dst_placements[1] + for i, (current, target) in enumerate( + zip(current_spec.placements, target_spec.placements) + ): if current == target: continue - mesh_dim = transform_info.mesh_dim - num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] if current.is_shard() and target.is_replicate(): + # allgather gives larger comm bytes + comm_bytes_gb *= num_devices_on_mesh_dim # add up allgather comm cost - cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + cost += allgather_cost(comm_bytes_gb, mesh_topo, i) elif current.is_shard() and target.is_shard(): - # should be alltoall comm, since we haven't implement it yet, add 1.0 as penalty + # should be alltoall comm, since we haven't implement it yet, add penalty # to favor allgather instead - # TODO: add alltoall_cost - comm_bytes_gb /= num_devices_on_mesh_dim - cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + 1.0 + cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0 elif current.is_partial() and target.is_replicate(): # add up allreduce comm cost - cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim) + cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) elif current.is_partial() and target.is_shard(): # add up reduce_scatter comm cost - cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim) + cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) # after reduce_scatter the comm bytes for further collectives halved. comm_bytes_gb /= num_devices_on_mesh_dim elif current.is_shard() and target.is_partial(): # ban shard -> partial as it does not make sense to perform # this redistribute return float("inf") + return cost diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 84e58c4df169c..a407ba6ca91df 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -32,72 +32,6 @@ logger = logging.getLogger(__name__) -# Global configuration flag to control the redistribution planning strategy. -# When True, forces the graph-based algorithm using Dijkstra's shortest path. -# When False, prefers the greedy algorithm for faster planning. Uses the graph-based algorithm -# only when necessary to support strided-shard redistribution -_FORCE_MIN_COST_REDISTRIBUTION_PLAN: Optional[bool] = None - - -@contextlib.contextmanager -def use_min_cost_redistribution_plan(enabled: bool = True): - """ - Context manager to control the redistribution planning strategy for DTensor operations. - - This context manager allows you to choose between two algorithms for computing the - sequence of collective operations needed to redistribute a DTensor from one placement - to another: - - - **Graph-based**: Uses Dijkstra's algorithm to find the minimum-cost path - through all possible placement transformations. This approach considers the global - cost of all collective operations and finds the optimal sequence. Best for complex - redistribution patterns where reducing communication cost and memory overhead is critical. - - - **Greedy**: Uses a heuristic approach that makes locally optimal choices - at each step. This is faster to compute but may not produce the globally optimal - transformation sequence. Best for simple redistribution patterns or when planning - speed is more important than optimal communication. - - **Default Behavior (without this context manager):** - - When this context manager is NOT used, the algorithm selection follows this priority: - - 1. **Non-default shard orders** - → Always use graph-based algorithm (required for correctness) - - 2. **Explicit `use_graph_based_transform` parameter** to `_gen_transform_infos_non_cached` - → Use the specified algorithm (True = graph-based, False = greedy) - - 3. **No explicit parameter** (default case) - → Use greedy algorithm for faster planning - - **Behavior with this context manager:** - - This context manager overrides the default selection by setting the global flag - `_FORCE_MIN_COST_REDISTRIBUTION_PLAN`, which takes precedence over the explicit - `use_graph_based_transform` parameter (but not over non-default shard order requirements). - - **Cache Considerations:** - - The redistribution planner caches transform info for performance via the `@cache` - decorator on `_gen_transform_infos`. If you need to change the algorithm selection - for the same input specs, clear the cache using `_gen_transform_infos.cache_clear()` - to ensure the new setting takes effect and doesn't reuse cached results from a - previous run. - - Args: - enabled (bool): If True, forces the use of the graph-based algorithm. - If False, forces the use of the greedy algorithm. - Default: True - """ - global _FORCE_MIN_COST_REDISTRIBUTION_PLAN - old_value = _FORCE_MIN_COST_REDISTRIBUTION_PLAN - _FORCE_MIN_COST_REDISTRIBUTION_PLAN = enabled - try: - yield - finally: - _FORCE_MIN_COST_REDISTRIBUTION_PLAN = old_value - class _TransformInfo(NamedTuple): mesh_dim: int @@ -714,29 +648,22 @@ def _gen_transform_infos_non_cached( dst_spec: DTensorSpec, use_graph_based_transform: Optional[bool] = None, ) -> list[_TransformInfo]: + transform_infos: list[_TransformInfo] = [] device_mesh = src_spec.device_mesh src_shard_order = src_spec.shard_order dst_shard_order = dst_spec.shard_order # DTensorSpec should automatically generate shard_order, and it can be () if # no shard. assert src_shard_order is not None and dst_shard_order is not None - - # Determine which transform strategy to use: - # 1. Non-standard device order → always use graph-based - # 2. Global flag or explicit parameter True → use graph-based - # 3. Otherwise → use greedy - has_non_default_order = not all( - DTensorSpec.is_default_device_order(order) - for order in (src_shard_order, dst_shard_order) - ) - - if has_non_default_order is True: - use_graph_based_transform = True - elif _FORCE_MIN_COST_REDISTRIBUTION_PLAN is not None: - use_graph_based_transform = _FORCE_MIN_COST_REDISTRIBUTION_PLAN - elif use_graph_based_transform is None: - use_graph_based_transform = False - + if use_graph_based_transform is None: + if all( + DTensorSpec.is_default_device_order(order) + for order in (src_shard_order, dst_shard_order) + ): + use_graph_based_transform = False + else: + # switch to graph search algorithm if the device order is not the default + use_graph_based_transform = True drp = get_redistribute_planner(device_mesh, len(src_spec.shape)) if use_graph_based_transform: transform_infos = drp.generate_graph_based_transform_infos( From 2f9b7dad7b5419b063bd0f2e204de192720ebb94 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 27 Nov 2025 20:33:52 +0000 Subject: [PATCH 040/338] Revert "Support AC in default partitioner when functionalization is enabled (#166610)" This reverts commit 2e1821bfda3602044657e0edb33d5700c9b86671. Reverted https://github.com/pytorch/pytorch/pull/166610 on behalf of https://github.com/atalman due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/166610#issuecomment-3587209482)) --- .../distributed/tensor/test_dtensor_export.py | 2 + test/dynamo/test_activation_checkpointing.py | 267 +++---------- test/functorch/test_aotdispatch.py | 15 +- test/higher_order_ops/test_local_map.py | 4 +- .../_aot_autograd/functional_utils.py | 20 +- .../_aot_autograd/graph_capture_wrappers.py | 5 - torch/_functorch/partitioners.py | 370 +++++++----------- 7 files changed, 204 insertions(+), 479 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index 4a88cf9a6e0b1..bd75668ab4856 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import contextlib +import unittest import torch import torch.distributed as dist @@ -356,6 +357,7 @@ def test_export_parallelize_module_with_dtensor_input( # aot_export_joint_with_descriptors on strict-exported exported_program.module() # is producing a joint graph with backward region missing + @unittest.expectedFailure def test_strict_export_parallelize_module_with_dtensor_input(self): self._run_test(strict_export_and_aot_export_joint_with_descriptors) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 768555efd1d4c..0d32a9e4917f5 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -15,7 +15,7 @@ import torch.distributed as dist import torch.nn as nn import torch.utils.checkpoint -from functorch.compile import default_partition, min_cut_rematerialization_partition +from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch._dynamo.testing import ( AotEagerAndRecordGraphs, @@ -24,7 +24,7 @@ ) from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor @@ -281,14 +281,7 @@ def runtime_wrapper(*runtime_args): run(export_compiler) - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_tags_function(self, device, partition_fn): + def test_tags_function(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -304,22 +297,11 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - ) + backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) @requires_cuda_and_triton - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_tags_function_via_global_checkpoint(self, device, partition_fn): + def test_tags_function_via_global_checkpoint(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -334,28 +316,17 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - ) + backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) @requires_cuda_and_triton - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_tags_function_with_kwargs(self, device, partition_fn): + def test_tags_function_with_kwargs(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): return torch.utils.checkpoint.checkpoint( - gn, torch.sin(x), y, use_reentrant=False + gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False ) x = torch.randn(4, 4, device=device, requires_grad=True) @@ -365,22 +336,11 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - ) + backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) @requires_cuda_and_triton - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_tags_sequential_layers(self, device, partition_fn): + def test_tags_sequential_layers(self, device): def gn(x): x = x.cos() for _ in range(3): @@ -401,22 +361,11 @@ def fn(x): freqs=[2, 18], ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default], ) # mm recomputed in the bwd - backend = aot_autograd( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - ) + backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) @requires_cuda_and_triton - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_tags_multiple_checkpoints(self, device, partition_fn): + def test_tags_multiple_checkpoints(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -434,22 +383,11 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=6, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - ) + backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) @requires_cuda_and_triton - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_tags_module(self, device, partition_fn): + def test_tags_module(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -473,22 +411,11 @@ def fn(x): bw_compiler = functools.partial( count_ops, freq=1, op=torch.ops.aten.sigmoid.default ) - backend = aot_autograd( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - ) + backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) @requires_cuda_and_triton - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_tags_decomps(self, device, partition_fn): + def test_tags_decomps(self, device): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -516,7 +443,6 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, decompositions=lambda: import_module( "torch._inductor.compile_fx" ).select_decomp_table(), @@ -776,14 +702,7 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn): + def test_compile_selective_checkpoint_must_recompute(self, device): def context_fn_must_recompute_mm(): must_recompute_list = [ torch.ops.aten.mm.default, @@ -804,9 +723,9 @@ def context_fn_no_recompute_mm(): ), ) - def _test(context_fn, bw_compiler, partition_fn): + def _test(context_fn, bw_compiler): def gn(x): - return torch.cos(torch.sin(torch.matmul(x, x) @ x)) + return torch.sigmoid(torch.matmul(x, x)) def fn(x): return torch.utils.checkpoint.checkpoint( @@ -820,14 +739,14 @@ def fn(x): fw_compiler = functools.partial( count_ops, - freq=2, + freq=1, op=torch.ops.aten.mm.default, ) backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x) @@ -835,19 +754,17 @@ def fn(x): context_fn=context_fn_must_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6) + freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3) op=torch.ops.aten.mm.default, ), - partition_fn=partition_fn, ) _test( context_fn=context_fn_no_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=4, # 2 bwd mm ops per fwd matmul + freq=2, # 2 bwd mm ops per fwd matmul op=torch.ops.aten.mm.default, ), - partition_fn=partition_fn, ) def test_sac_with_partial_context_fn(self): @@ -884,16 +801,7 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_must_not_recompute_gemm( - self, device, partition_fn - ): + def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -933,22 +841,15 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization( - self, device, partition_fn + self, device ): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -988,7 +889,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, disable_functionalization=True, ) self._validate(fn, backend, x, y) @@ -996,14 +897,7 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn): + def test_compile_selective_checkpoint_triton_kernel(self, device): # Copy of the above test, but make sure that having a triton kernel in the # region does not error. def add_one(x): @@ -1063,21 +957,14 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn): + def test_compile_selective_checkpoint_tensor_subclass(self, device): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1120,21 +1007,14 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn): + def test_compile_selective_checkpoint_custom_rule(self, device): def _get_custom_policy(meta): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1192,21 +1072,14 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn): + def test_compile_selective_checkpoint_partial_ctx_fn(self, device): def selective_checkpointing_context_fn(no_recompute_list): return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) @@ -1245,21 +1118,14 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn): + def test_compile_selective_checkpoint_outplace_op(self, device): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1297,21 +1163,14 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_list_ops(self, device, partition_fn): + def test_compile_selective_checkpoint_list_ops(self, device): def selective_checkpointing_context_fn(): # recompute everything no_recompute_list = [] @@ -1347,7 +1206,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1358,14 +1217,7 @@ def fn(x, y): "requires TorchDispatchMode + torch.compile work to complete" ) @requires_cuda_and_triton - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn): + def test_compile_selective_checkpoint_inplace_op(self, device): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1405,7 +1257,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1413,14 +1265,7 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_random_op(self, device, partition_fn): + def test_compile_selective_checkpoint_random_op(self, device): for preserve_rng_state in [True, False]: def selective_checkpointing_context_fn(): @@ -1467,7 +1312,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, @@ -1479,14 +1324,7 @@ def fn(x): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_invalid_context(self, partition_fn): + def test_compile_selective_checkpoint_invalid_context(self): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) * y @@ -1515,7 +1353,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) with self.assertRaisesRegex( Exception, "must generate a tuple of two `TorchDispatchMode`s" @@ -1524,14 +1362,7 @@ def fn(x, y): @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) - @parametrize( - "partition_fn", - [ - min_cut_rematerialization_partition, - default_partition, - ], - ) - def test_compile_selective_checkpoint_parametrization(self, partition_fn): + def test_compile_selective_checkpoint_parametrization(self): def sac_policy(): def _recomp_policy(): def _custom_policy(ctx, func, *args, **kwargs): @@ -1594,9 +1425,7 @@ def reset_parameters(self): bw_compiler = functools.partial( count_ops, freqs=[ - # 1 from mul recompute, 1 from mul backward - # w/o CSE, we have one extra mul - 3 if partition_fn is default_partition else 2, + 2, # 1 from mul recompute, 1 from mul backward 1, ], ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], @@ -1605,7 +1434,7 @@ def reset_parameters(self): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=partition_fn, + partition_fn=min_cut_rematerialization_partition, ) model = MLPModule() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index c452f18e95d75..6cae42d8929da 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2640,7 +2640,7 @@ def backward(ctx, grad_output): return grad_output * x, grad_output * x def f(a, b): - return FwBwMutation.apply(a, b).sin_().clone() + return FwBwMutation.apply(a, b) inps = [ torch.ones(3, 3, requires_grad=True), @@ -2689,22 +2689,17 @@ def forward(self, primals_1, primals_2): add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None - clone = torch.ops.aten.clone.default(mul) - sin_ = torch.ops.aten.sin_.default(mul); mul = None - clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None - return (clone_1, add, clone)""", + return (mul, add)""", ) # important bit: there is 1 mutation in the bw self.assertExpectedInline( bw_graph[0].code.strip(), """\ -def forward(self, add, clone, tangents_1): - cos = torch.ops.aten.cos.default(clone); clone = None - mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None +def forward(self, add, tangents_1): _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None - mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None - return (mul_2, None)""", + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None + return (mul_1, None)""", ) def test_fw_bw_mutation_no_functionalization2(self): diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 7b5f01d236e7f..a585f2055e89f 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -911,8 +911,8 @@ def inputs_fn(): op="call_function", target=torch.ops.aten.mm.default ) self.assertEqual(len(mm_nodes), 4) - self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward") - self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward") + self.assertNotIn("partitioner_tag", mm_nodes[0].meta) + self.assertNotIn("partitioner_tag", mm_nodes[1].meta) self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 5af4fc9ee1195..fcbf861e537db 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -10,7 +10,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional import torch from torch import Tensor @@ -450,7 +449,7 @@ def was_tensor_metadata_updated(arg, new_arg): # Returns the number of detected copy_ -def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: allowed_mutation_ops = [ torch.ops.aten.copy_.default, torch.ops.aten.set_.source_Tensor, @@ -463,7 +462,6 @@ def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: # NB: It would also be nice to verify that the mutations all happen at the # end, but we also do some administrative views after mutations so this # isn't actually true. (TODO: Could this cause problems for Inductor?) - error = None for n in fx_g.nodes: if n.op == "placeholder": placeholders.add(n) @@ -473,18 +471,14 @@ def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: # this is mostly a hack to avoid failing XLA tests. # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 if "set_buffer_donor_" not in str(n.args[0]): - if n.args[0] not in placeholders: - error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + assert n.args[0] in placeholders, ( + f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + ) mutation_count += 1 else: - if n.target._schema.is_mutable: - error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" - return error, mutation_count - - -def assert_functional_graph(fx_g: torch.fx.Graph) -> int: - error, mutation_count = _is_functional_graph(fx_g) - assert error is None, error + assert not n.target._schema.is_mutable, ( + f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + ) return mutation_count diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py index 2ef84cb488604..bc4dc87ddeced 100644 --- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -27,7 +27,6 @@ from torch._prims_common import CUDARngStateHelper from torch.fx.experimental.proxy_tensor import ( _proxy_tensor_disable_update_tensor_tracker, - get_proxy_mode, maybe_disable_thunkify, maybe_enable_thunkify, ) @@ -296,10 +295,6 @@ def inner_fn( (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs( fn, primals ) - mode = get_proxy_mode() - assert mode is not None, "Expected non-None proxy mode" - for node in mode.tracer.graph.nodes: - node.meta["partitioner_tag"] = "is_forward" # TODO: I think this hook can also be eliminated now if joint_fn_handle and joint_fn_handle.post_forward: diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 3b79a50ff9e21..c273ba39ce167 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -10,7 +10,6 @@ import os import os.path import re -import warnings from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass, replace @@ -53,7 +52,6 @@ ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput -from ._aot_autograd.functional_utils import _is_functional_graph from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems @@ -300,10 +298,6 @@ def _has_tag_is_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "is_backward" -def _has_tag_is_forward(node: fx.Node) -> bool: - return node.meta.get("partitioner_tag", None) == "is_forward" - - def _has_tag_must_be_in_forward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_forward" @@ -1028,87 +1022,69 @@ def default_partition( Returns: Returns the generated forward and backward Fx graph modules. """ - # Respect the original placement of ops rather than rely on dataflow. - forward_nodes = [] - last_node = None - for node in joint_module.graph.nodes: - if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node): - last_node = node - assert last_node is not None - for node in joint_module.graph.nodes: - if not _is_tangent(node): - forward_nodes.append(node) - if node is last_node: - break - forward_node_names = OrderedSet( - node.name for node in forward_nodes if node.op != "output" + if has_recomputable_ops(joint_module): + return min_cut_rematerialization_partition( + joint_module, + _joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( + _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) ) - graph_has_recomputable_ops = has_recomputable_ops(joint_module) - graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) - if graph_has_recomputable_ops: - if _is_functional_graph(joint_module.graph)[0] is not None: - # Fall-back to previous behavior to avoid bc-breaking, although can - # eventually flip the switch to make this a hard error. - warnings.warn( - "Trying to unsafely apply AC to a non-functional graph with the " - "default partitioner. Falling back to min-cut partitioner." - ) - return min_cut_rematerialization_partition( - joint_module, - _joint_inputs, - num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - - joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True) - - if not config.unsafe_allow_optimization_of_collectives: - force_save_collectives(joint_module) - - force_save_bw_mutation_src(joint_module) - - if static_lifetime_input_indices is None: - static_lifetime_input_indices = [] - node_info = classify_nodes( - joint_module, static_lifetime_input_indices, num_fwd_outputs + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" ) - + forward_node_names = OrderedSet( + node.name for node in forward_only_graph.nodes if node.op != "output" + ) + order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} saved_values = [] saved_sym_nodes = [] - distributed_enabled = torch.distributed.is_available() - - def is_tensor(node): - return "tensor_meta" in node.meta or isinstance( - node.meta.get("val"), torch._subclasses.FakeTensor - ) - - def is_multi_output(node): - return ( - all(user.target == operator.getitem for user in node.users) - and len(node.users) > 0 - ) - - def is_impure(node): - # wait tensor is an "impure" op according to DCE's definition of impure - # (see is_impure in torch/fx/node.py), but it survives past - # functionalization and can be safely dup'd and reordered under the - # assumption SPMD. - return ( - node.is_impure(impure_random=False) - and node.op - not in ( - "placeholder", - "output", - ) - and ( - not distributed_enabled - or node.target is not torch.ops._c10d_functional.wait_tensor.default - ) - ) + def is_mutated_later_in_fw(node): + if _has_tag_is_backward(node): + return False + tensor_arg_aliases = [ + x + for x in node.args + if isinstance(x, fx.Node) + and "val" in x.meta + and isinstance(x.meta["val"], torch.Tensor) + ] + while len(tensor_arg_aliases) > 0: + a = tensor_arg_aliases.pop() + for u in a.users: + if not isinstance(u.target, torch._ops.OpOverload): + continue + # If we witness a mutation on our node later, and that mutation is not "must be in backward", + # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values) + if ( + # one of the args was mutated + u.target._schema.is_mutable + # and the mutation happens "later" + and order[u] > order[node] + # and the mutation happened during the forward + and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u)) + ): + for idx, alias_info in enumerate(u.target._schema.arguments): + if alias_info.is_write and u.args[idx] is a: + return True + elif u.target.is_view: + tensor_arg_aliases.append(u) + return False for node in joint_module.graph.nodes: if node.name not in forward_node_names: + # if a node isn't "required" to be in the forward, but any of its arguments + # are later mutated in the forward, then it must have been run in the forward + # (if not, and the node's arg was saved for backward, we would have mutated a saved value) + # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated + if is_mutated_later_in_fw(node): + saved_values.append(node) continue if node.target is torch.ops.aten._assert_scalar.default: continue @@ -1116,48 +1092,37 @@ def is_impure(node): # Symints must be kept separate from tensors so that PythonFunction only calls # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes.append(node) - continue - if is_multi_output(node): - # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE. - continue - if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE: - saved_values.append(node) - continue - if is_impure(node): - assert not graph_has_recomputable_ops, ( - "Trying to apply AC on a graph with impure op", - node, - node.target, - ) - saved_values.append(node) - continue - assert is_tensor(node) or node.op != "call_function", ( - f"Expected {node} to be a tensor" - ) - backward_usages = [n for n in node.users if n.name not in forward_node_names] - if all(is_sym_node(n) for n in backward_usages): - # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, - # and not the actual tensor data, - # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. - # - # Note that saving the tensor could also cause compilation problems: - # If the user mutated an input in the forward and uses its sizes/strides in the backward, - # then we would be obligated to clone the input before saving it to appease autograd. - # (This is how we originally found this bug). - saved_sym_nodes.extend(backward_usages) - continue - if not must_recompute(node): - saved_values.append(node) - + elif ( + "tensor_meta" not in node.meta + and node.op == "call_function" + and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) + ): + # Since we can't save tuple of tensor values, we need to flatten out what we're saving + users = node.users + assert all(user.target is operator.getitem for user in users) + saved_values.extend(users) + else: + backward_usages = [ + n for n in node.users if n.name not in forward_node_names + ] + if "tensor_meta" in node.meta and all( + is_sym_node(n) for n in backward_usages + ): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). + saved_sym_nodes.extend(backward_usages) + else: + saved_values.append(node) saved_values = list(dict.fromkeys(saved_values).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) - if config._sync_decision_cross_ranks: - saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values) - - if static_lifetime_input_nodes is None: - static_lifetime_input_nodes = node_info.static_lifetime_input_nodes - fw_module, bw_module = _extract_fwd_bwd_modules( + return _extract_fwd_bwd_modules( joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, @@ -1165,37 +1130,6 @@ def is_impure(node): static_lifetime_input_nodes=static_lifetime_input_nodes, ) - # Run DCE while overriding the definition of is_impure_node - def is_not_collective(node): - if not distributed_enabled: - return True - if node.target is torch.ops._c10d_functional.wait_tensor.default: - return False - if node.target is torch.ops._c10d_functional.all_gather_into_tensor.default: - return False - return True - - fw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) - bw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) - - if graph_has_recomputable_ops: - if graph_has_recomputable_rng_ops: - fw_module, bw_module = functionalize_rng_ops( - joint_module, fw_module, bw_module, len(saved_sym_nodes) - ) - bw_module = reordering_to_mimic_autograd_engine(bw_module) - - # raise all getitem ops to as early as possible - # this is helpful for memory, especially in the case of aot_eager backend - fw_module = raise_getitems(fw_module) - bw_module = raise_getitems(bw_module) - - fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) - if len(node_info.required_bw_nodes) > 0: - bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) - - return fw_module, bw_module - INT_INF = int(1e6) @@ -1690,16 +1624,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: break -def is_getitem_of_multi_output(node): - if node.target != operator.getitem: - return False - parent = node.args[0] - return "tensor_meta" not in parent.meta and node.op == "call_function" - - -def cleanup_recompute_tags( - joint_module: fx.GraphModule, *, is_default_partition: bool -) -> fx.GraphModule: +def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in between, we would still want to stash the tensor at the boundary of @@ -1736,20 +1661,6 @@ def cleanup_recompute_tags( # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` # in forward graph outputs. With this, we can break the above circular dependency. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE - elif ( - "ac_graph_id" not in node.meta - and any(must_recompute(user) for user in node.users) - and not ( - # Avoid saving getitem nodes which are not labeled with "ac_graph_id" - is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta - ) - and is_default_partition - ): - # This node is not part of the AC region and a user is marked as recompute. - # This means it's an input to the AC region and we should save it. - # For ease of landing, gate this to default partitioner only, but we should think - # about flipping the switch in general as well. - node.meta["recompute"] = CheckpointPolicy.MUST_SAVE return joint_module @@ -2859,59 +2770,6 @@ def thread_graphsafe_rng_from_hops(module, is_backward): return module -def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs): - name_to_node = get_name_to_node(joint_module.graph) - required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() - for node in joint_module.graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - elif _must_be_in_backward(node): - required_bw_nodes.add(node) - - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) - required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - ) - unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - ) - static_lifetime_input_nodes = OrderedSet( - p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices - ) - fw_cnt = 0 - fw_order = {} - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - fw_order[node] = fw_cnt - fw_cnt += 1 - return NodeInfo( - inputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - fw_order, - static_lifetime_input_nodes, - ) - - def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, @@ -2960,16 +2818,68 @@ def min_cut_rematerialization_partition( graph_has_recomputable_ops = has_recomputable_ops(joint_module) graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False) + joint_module = cleanup_recompute_tags(joint_module) if not config.unsafe_allow_optimization_of_collectives: force_save_collectives(joint_module) force_save_bw_mutation_src(joint_module) + def classify_nodes(joint_module, static_lifetime_input_indices): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + elif _must_be_in_backward(node): + required_bw_nodes.add(node) + + if node in required_bw_nodes: + required_bw_nodes.update(node.users) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list( + filter(_is_fwd_seed_offset, joint_module.graph.nodes) + ) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( + _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" + ) + required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + ) + unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + ) + static_lifetime_input_nodes = OrderedSet( + p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices + ) + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, + required_fw_nodes, + required_bw_nodes, + unclaimed_nodes, + fw_order, + static_lifetime_input_nodes, + ) + if static_lifetime_input_indices is None: static_lifetime_input_indices = [] - node_info = classify_nodes( - joint_module, static_lifetime_input_indices, num_fwd_outputs - ) + node_info = classify_nodes(joint_module, static_lifetime_input_indices) # networkx blows up on graphs with no required backward nodes # Since there's nothing to partition anyway, and the default partitioner can "handle" From 088048f2fea28ff7d450f65c72419ca45780d30b Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 28 Nov 2025 00:51:01 +0000 Subject: [PATCH 041/338] Remove cudaProfilerInitialize (#168918) cudaProfilerInitialize is not required since CUDA 12+. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168918 Approved by: https://github.com/eqy --- torch/CMakeLists.txt | 5 ----- torch/csrc/cuda/shared/cudart.cpp | 20 -------------------- torch/utils/hipify/cuda_to_hip_mappings.py | 4 ---- 3 files changed, 29 deletions(-) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index c7a43f30e49d5..3a3ca0f1236ec 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -309,11 +309,6 @@ if(USE_NCCL AND NOT WIN32) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL) endif() -if(NOT MSVC) - # cudaProfilerInitialize must go away - set_source_files_properties(${TORCH_SRC_DIR}/csrc/cuda/shared/cudart.cpp PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") -endif() - # coreml if(USE_COREML_DELEGATE) list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/jit/backends/coreml/cpp/backend.cpp) diff --git a/torch/csrc/cuda/shared/cudart.cpp b/torch/csrc/cuda/shared/cudart.cpp index e7012fe82dd8f..378811f3ce46d 100644 --- a/torch/csrc/cuda/shared/cudart.cpp +++ b/torch/csrc/cuda/shared/cudart.cpp @@ -28,17 +28,6 @@ void initCudartBindings(PyObject* module) { // By splitting the names of these objects into two literals we prevent the // HIP rewrite rules from changing these names when building with HIP. -#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12000 - // cudaOutputMode_t is used in cudaProfilerInitialize only. The latter is gone - // in CUDA 12. - py::enum_( - cudart, - "cuda" - "OutputMode") - .value("KeyValuePair", cudaKeyValuePair) - .value("CSV", cudaCSV); -#endif - py::enum_( cudart, "cuda" @@ -100,15 +89,6 @@ void initCudartBindings(PyObject* module) { // NOLINTNEXTLINE(performance-no-int-to-ptr) return C10_CUDA_ERROR_HANDLED(cudaStreamDestroy((cudaStream_t)ptr)); }); -#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12000 - // cudaProfilerInitialize is no longer needed after CUDA 12: - // https://forums.developer.nvidia.com/t/cudaprofilerinitialize-is-deprecated-alternative/200776/3 - cudart.def( - "cuda" - "ProfilerInitialize", - cudaProfilerInitialize, - py::call_guard()); -#endif cudart.def( "cuda" "MemGetInfo", diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index fb7dc1c7cb7f0..9a4b81ab5cfb2 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -5529,10 +5529,6 @@ ), ), ("cudaDeviceGetLimit", ("hipDeviceGetLimit", CONV_DEVICE, API_RUNTIME)), - ( - "cudaProfilerInitialize", - ("hipProfilerInitialize", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED), - ), ("cudaProfilerStart", ("hipProfilerStart", CONV_OTHER, API_RUNTIME)), ("cudaProfilerStop", ("hipProfilerStop", CONV_OTHER, API_RUNTIME)), ( From 9cd055e547e9b67a5f9827f8999c38d7eda1bcb8 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 28 Nov 2025 00:54:57 +0000 Subject: [PATCH 042/338] [2/N] Remove unused header inclusion (#165831) Remove unused header inclusion in JIT code and other locations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165831 Approved by: https://github.com/ngimel, https://github.com/albanD --- aten/src/ATen/core/DeprecatedTypeProperties.cpp | 2 -- aten/src/ATen/core/Formatting.cpp | 1 - aten/src/ATen/core/NamedRegistrations.cpp | 2 -- aten/src/ATen/core/PythonFallbackKernel.cpp | 1 - aten/src/ATen/core/Tensor.cpp | 3 --- aten/src/ATen/core/VariableFallbackKernel.cpp | 1 - aten/src/ATen/core/Vitals.cpp | 1 - aten/src/ATen/core/class_type.cpp | 2 -- aten/src/ATen/core/custom_class.cpp | 1 - aten/src/ATen/core/interned_strings.cpp | 2 -- aten/src/ATen/core/tensor_type.cpp | 1 - aten/src/ATen/core/type.cpp | 2 -- aten/src/ATen/core/union_type.cpp | 5 ----- c10/cuda/CUDADeviceAssertionHost.cpp | 1 - c10/cuda/CUDAException.cpp | 1 - c10/cuda/CUDAMiscFunctions.cpp | 1 - torch/csrc/Device.cpp | 3 --- torch/csrc/Dtype.cpp | 4 ---- torch/csrc/DynamicTypes.cpp | 9 --------- torch/csrc/Event.cpp | 4 ---- torch/csrc/Layout.cpp | 3 --- torch/csrc/MemoryFormat.cpp | 1 - torch/csrc/QScheme.cpp | 3 --- torch/csrc/Size.cpp | 4 ++-- torch/csrc/Stream.cpp | 3 --- torch/csrc/TypeInfo.cpp | 4 ---- .../autograd/functions/recvrpc_backward.cpp | 1 - torch/csrc/distributed/autograd/init.cpp | 3 --- .../autograd/rpc_messages/rpc_with_autograd.cpp | 1 - torch/csrc/distributed/autograd/utils.cpp | 3 --- torch/csrc/jit/api/function_impl.cpp | 1 - torch/csrc/jit/api/module.cpp | 7 ------- torch/csrc/jit/backends/backend_debug_handler.cpp | 2 -- torch/csrc/jit/backends/backend_init.cpp | 2 -- torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp | 1 - .../jit/backends/nnapi/nnapi_backend_preprocess.cpp | 2 -- torch/csrc/jit/codegen/cuda/interface.cpp | 8 -------- torch/csrc/jit/codegen/fuser/codegen.cpp | 4 ---- torch/csrc/jit/codegen/fuser/compiler.cpp | 7 ------- torch/csrc/jit/codegen/fuser/executor.cpp | 3 --- torch/csrc/jit/codegen/fuser/fallback.cpp | 2 -- torch/csrc/jit/codegen/fuser/interface.cpp | 3 --- torch/csrc/jit/codegen/onednn/decompose_silu.cpp | 2 -- torch/csrc/jit/codegen/onednn/graph_fuser.cpp | 2 -- torch/csrc/jit/codegen/onednn/graph_helper.cpp | 2 -- torch/csrc/jit/codegen/onednn/graph_rewriter.cpp | 3 --- torch/csrc/jit/codegen/onednn/guard_shape.cpp | 2 -- torch/csrc/jit/codegen/onednn/interface.cpp | 2 -- torch/csrc/jit/codegen/onednn/kernel.cpp | 1 - torch/csrc/jit/frontend/builtin_functions.cpp | 2 -- .../jit/frontend/canonicalize_modified_loop.cpp | 2 -- torch/csrc/jit/frontend/error_report.cpp | 2 -- torch/csrc/jit/frontend/inline_loop_condition.cpp | 2 -- torch/csrc/jit/frontend/ir_emitter.cpp | 5 ----- torch/csrc/jit/frontend/lexer.cpp | 2 -- torch/csrc/jit/frontend/schema_matching.cpp | 1 - torch/csrc/jit/frontend/sugared_value.cpp | 2 -- torch/csrc/jit/frontend/tracer.cpp | 8 -------- torch/csrc/jit/frontend/versioned_symbols.cpp | 3 --- torch/csrc/jit/ir/alias_analysis.cpp | 1 - torch/csrc/jit/ir/constants.cpp | 4 ---- torch/csrc/jit/ir/ir.cpp | 3 --- torch/csrc/jit/ir/irparser.cpp | 2 -- torch/csrc/jit/ir/node_hashing.cpp | 3 --- torch/csrc/jit/ir/type_hashing.cpp | 3 --- torch/csrc/jit/jit_log.cpp | 1 - torch/csrc/jit/jit_opt_limit.cpp | 4 ---- torch/csrc/jit/mobile/compatibility/backport.cpp | 2 -- .../jit/mobile/compatibility/backport_manager.cpp | 2 -- torch/csrc/jit/mobile/function.cpp | 2 -- torch/csrc/jit/mobile/import.cpp | 2 -- torch/csrc/jit/mobile/interpreter.cpp | 2 -- torch/csrc/jit/mobile/module.cpp | 3 --- torch/csrc/jit/mobile/parse_bytecode.cpp | 1 - torch/csrc/jit/mobile/register_ops_common_utils.cpp | 1 - torch/csrc/jit/mobile/train/export_data.cpp | 1 - torch/csrc/jit/mobile/train/optim/sgd.cpp | 3 --- torch/csrc/jit/mobile/train/sequential.cpp | 1 - torch/csrc/jit/mobile/upgrader_mobile.cpp | 1 - torch/csrc/jit/operator_upgraders/utils.cpp | 3 +-- torch/csrc/jit/passes/autocast.cpp | 1 - torch/csrc/jit/passes/bailout_graph.cpp | 3 --- torch/csrc/jit/passes/check_strict_fusion.cpp | 1 - .../jit/passes/common_subexpression_elimination.cpp | 2 -- torch/csrc/jit/passes/concat_opt.cpp | 2 -- torch/csrc/jit/passes/constant_propagation.cpp | 2 -- torch/csrc/jit/passes/create_autodiff_subgraphs.cpp | 2 -- torch/csrc/jit/passes/create_functional_graphs.cpp | 1 - .../dbr_quantization/remove_redundant_aliases.cpp | 1 - torch/csrc/jit/passes/dtype_analysis.cpp | 5 ----- torch/csrc/jit/passes/erase_number_types.cpp | 1 - torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp | 3 --- torch/csrc/jit/passes/fold_conv_bn.cpp | 1 - torch/csrc/jit/passes/frozen_concat_linear.cpp | 6 ------ .../csrc/jit/passes/frozen_conv_add_relu_fusion.cpp | 6 ------ .../jit/passes/frozen_conv_add_relu_fusion_cuda.cpp | 2 -- torch/csrc/jit/passes/frozen_conv_folding.cpp | 2 -- .../csrc/jit/passes/frozen_graph_optimizations.cpp | 4 ---- torch/csrc/jit/passes/frozen_linear_transpose.cpp | 3 --- torch/csrc/jit/passes/fuse_relu.cpp | 1 - torch/csrc/jit/passes/graph_fuser.cpp | 4 ---- torch/csrc/jit/passes/guard_elimination.cpp | 3 --- torch/csrc/jit/passes/hoist_conv_packed_params.cpp | 2 -- torch/csrc/jit/passes/inliner.cpp | 2 -- torch/csrc/jit/passes/insert_guards.cpp | 1 - torch/csrc/jit/passes/integer_value_refinement.cpp | 1 - torch/csrc/jit/passes/liveness.cpp | 2 -- torch/csrc/jit/passes/lower_tuples.cpp | 1 - torch/csrc/jit/passes/metal_rewrite.cpp | 4 ---- torch/csrc/jit/passes/mkldnn_rewrite.cpp | 1 - torch/csrc/jit/passes/normalize_ops.cpp | 2 -- torch/csrc/jit/passes/onnx.cpp | 3 --- torch/csrc/jit/passes/onnx/constant_map.cpp | 2 -- torch/csrc/jit/passes/onnx/eval_peephole.cpp | 2 -- .../csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp | 2 -- torch/csrc/jit/passes/onnx/helper.cpp | 2 -- .../csrc/jit/passes/onnx/list_model_parameters.cpp | 2 -- .../autograd_function_process.cpp | 2 -- .../pattern_conversion/pattern_encapsulation.cpp | 3 --- torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp | 1 - torch/csrc/jit/passes/peephole.cpp | 3 --- torch/csrc/jit/passes/peephole_alias_sensitive.cpp | 6 ------ torch/csrc/jit/passes/peephole_list_idioms.cpp | 3 --- torch/csrc/jit/passes/peephole_non_tensor.cpp | 1 - torch/csrc/jit/passes/prepack_folding.cpp | 1 - .../jit/passes/quantization/insert_observers.cpp | 2 -- torch/csrc/jit/passes/refine_tuple_types.cpp | 2 -- torch/csrc/jit/passes/remove_redundant_profiles.cpp | 2 -- .../jit/passes/replacement_of_old_operators.cpp | 1 - torch/csrc/jit/passes/requires_grad_analysis.cpp | 1 - torch/csrc/jit/passes/restore_mutation.cpp | 2 -- torch/csrc/jit/passes/shape_analysis.cpp | 3 --- torch/csrc/jit/passes/symbolic_shape_analysis.cpp | 4 ---- .../jit/passes/symbolic_shape_runtime_fusion.cpp | 1 - torch/csrc/jit/passes/tensorexpr_fuser.cpp | 2 -- .../update_differentiable_graph_requires_grad.cpp | 1 - torch/csrc/jit/passes/utils/memory_dag.cpp | 1 - torch/csrc/jit/passes/utils/subgraph_utils.cpp | 1 - torch/csrc/jit/passes/vulkan_rewrite.cpp | 1 - torch/csrc/jit/python/pybind_utils.cpp | 2 -- torch/csrc/jit/python/python_custom_class.cpp | 2 -- torch/csrc/jit/python/python_dict.cpp | 1 - torch/csrc/jit/python/python_interpreter.cpp | 13 ------------- torch/csrc/jit/python/python_ir.cpp | 4 ---- torch/csrc/jit/python/python_sugared_value.cpp | 4 ---- torch/csrc/jit/python/python_tracer.cpp | 1 - torch/csrc/jit/runtime/autodiff.cpp | 1 - torch/csrc/jit/runtime/decomposition_registry.cpp | 1 - .../jit/runtime/decomposition_registry_util.cpp | 3 --- torch/csrc/jit/runtime/graph_executor.cpp | 1 - torch/csrc/jit/runtime/interpreter.cpp | 8 -------- torch/csrc/jit/runtime/jit_trace.cpp | 6 ------ torch/csrc/jit/runtime/operator.cpp | 1 - .../jit/runtime/profiling_graph_executor_impl.cpp | 3 --- torch/csrc/jit/runtime/profiling_record.cpp | 1 - torch/csrc/jit/runtime/register_c10_ops.cpp | 3 --- torch/csrc/jit/runtime/register_cuda_ops.cpp | 1 - torch/csrc/jit/runtime/register_distributed_ops.cpp | 3 --- .../csrc/jit/runtime/register_prim_ops_fulljit.cpp | 11 ----------- torch/csrc/jit/runtime/register_special_ops.cpp | 3 --- torch/csrc/jit/runtime/script_profile.cpp | 1 - .../runtime/serialized_shape_function_registry.cpp | 3 --- torch/csrc/jit/runtime/static/fusion.cpp | 1 - torch/csrc/jit/runtime/static/impl.cpp | 3 --- torch/csrc/jit/runtime/static/memory_planner.cpp | 3 --- torch/csrc/jit/runtime/static/native_ops.cpp | 5 ----- torch/csrc/jit/runtime/static/passes.cpp | 2 -- torch/csrc/jit/runtime/symbolic_shape_registry.cpp | 1 - .../jit/runtime/symbolic_shape_registry_util.cpp | 5 ----- torch/csrc/jit/testing/file_check.cpp | 1 - torch/csrc/python_dimname.cpp | 1 - torch/csrc/utils/cpp_stacktraces.cpp | 3 --- torch/csrc/utils/device_lazy_init.cpp | 1 - torch/csrc/utils/disable_torch_function.cpp | 1 - torch/csrc/utils/init.cpp | 3 --- torch/csrc/utils/object_ptr.cpp | 2 -- torch/csrc/utils/python_dispatch.cpp | 8 -------- torch/csrc/utils/tensor_apply.cpp | 2 -- torch/csrc/utils/tensor_dtypes.cpp | 1 - torch/csrc/utils/tensor_layouts.cpp | 3 --- torch/csrc/utils/tensor_list.cpp | 2 -- torch/csrc/utils/tensor_memoryformats.cpp | 2 -- torch/csrc/utils/tensor_qschemes.cpp | 2 -- torch/csrc/utils/tensor_types.cpp | 2 -- torch/csrc/utils/throughput_benchmark.cpp | 2 -- 185 files changed, 3 insertions(+), 462 deletions(-) diff --git a/aten/src/ATen/core/DeprecatedTypeProperties.cpp b/aten/src/ATen/core/DeprecatedTypeProperties.cpp index a97a6828571e7..369556aad9152 100644 --- a/aten/src/ATen/core/DeprecatedTypeProperties.cpp +++ b/aten/src/ATen/core/DeprecatedTypeProperties.cpp @@ -1,7 +1,5 @@ #include -#include -#include #include namespace at { diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index eddd5e4b4d6cf..62b16a83e523b 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -9,7 +9,6 @@ #include #include #include -#include namespace c10 { std::ostream& operator<<(std::ostream& out, Backend b) { diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index b78a563b673b0..fc2193e70cb19 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -1,7 +1,5 @@ #include -#include - using torch::CppFunction; TORCH_LIBRARY_IMPL(_, Named, m) { diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index 39f4e7cb69764..7b2b32531f059 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include namespace { diff --git a/aten/src/ATen/core/Tensor.cpp b/aten/src/ATen/core/Tensor.cpp index 090e77e703736..70907a60b65ae 100644 --- a/aten/src/ATen/core/Tensor.cpp +++ b/aten/src/ATen/core/Tensor.cpp @@ -1,8 +1,5 @@ #include -#include #include -#include -#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp index dad3f090bb1ea..94422df404558 100644 --- a/aten/src/ATen/core/VariableFallbackKernel.cpp +++ b/aten/src/ATen/core/VariableFallbackKernel.cpp @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/aten/src/ATen/core/Vitals.cpp b/aten/src/ATen/core/Vitals.cpp index ac1ee45d58345..db58c03830539 100644 --- a/aten/src/ATen/core/Vitals.cpp +++ b/aten/src/ATen/core/Vitals.cpp @@ -1,6 +1,5 @@ #include #include -#include #include namespace at::vitals { diff --git a/aten/src/ATen/core/class_type.cpp b/aten/src/ATen/core/class_type.cpp index a65124e80979e..ec1dba6192ac3 100644 --- a/aten/src/ATen/core/class_type.cpp +++ b/aten/src/ATen/core/class_type.cpp @@ -1,12 +1,10 @@ #include #include -#include #include #include #include #include -#include #include namespace c10 { diff --git a/aten/src/ATen/core/custom_class.cpp b/aten/src/ATen/core/custom_class.cpp index 2c9cc465466a3..820d27097c4db 100644 --- a/aten/src/ATen/core/custom_class.cpp +++ b/aten/src/ATen/core/custom_class.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include diff --git a/aten/src/ATen/core/interned_strings.cpp b/aten/src/ATen/core/interned_strings.cpp index 799f6821bb928..018ee82fe3227 100644 --- a/aten/src/ATen/core/interned_strings.cpp +++ b/aten/src/ATen/core/interned_strings.cpp @@ -2,12 +2,10 @@ #undef TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include #include #include #include #include -#include #include namespace c10 { diff --git a/aten/src/ATen/core/tensor_type.cpp b/aten/src/ATen/core/tensor_type.cpp index d428aceb3d04c..debd5e92bbc04 100644 --- a/aten/src/ATen/core/tensor_type.cpp +++ b/aten/src/ATen/core/tensor_type.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 35a729ccc9f39..215f91eed68be 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -1,10 +1,8 @@ #include -#include #include #include #include #include -#include #include #include #include diff --git a/aten/src/ATen/core/union_type.cpp b/aten/src/ATen/core/union_type.cpp index 8731c2cbc4952..6113041f15476 100644 --- a/aten/src/ATen/core/union_type.cpp +++ b/aten/src/ATen/core/union_type.cpp @@ -1,10 +1,5 @@ #include -#include -#include -#include -#include #include -#include #include #include #include diff --git a/c10/cuda/CUDADeviceAssertionHost.cpp b/c10/cuda/CUDADeviceAssertionHost.cpp index 08e657a411614..43dbb92531c14 100644 --- a/c10/cuda/CUDADeviceAssertionHost.cpp +++ b/c10/cuda/CUDADeviceAssertionHost.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 4e4419b4369a8..f0dbe49d2ea6c 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -2,7 +2,6 @@ #include #include -#include #include diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index 49bad41dda866..70bb0f841b35c 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -1,6 +1,5 @@ #include #include -#include #include namespace c10::cuda { diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index da7b287369dab..b3acb4e4bb466 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -2,15 +2,12 @@ #include #include -#include #include #include #include -#include #include -#include #include #include diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp index c302378de81e4..bff17ca0cbc79 100644 --- a/torch/csrc/Dtype.cpp +++ b/torch/csrc/Dtype.cpp @@ -1,15 +1,11 @@ #include #include -#include #include #include #include #include #include -#include -#include -#include #include PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) { diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index d5621146fef88..9db1903eec33a 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -1,18 +1,9 @@ -#include -#include #include #include #include #include #include -#include -#include -#include -#include - -#include -#include #include #include diff --git a/torch/csrc/Event.cpp b/torch/csrc/Event.cpp index fd7d72228fcea..f5bb1b60eac57 100644 --- a/torch/csrc/Event.cpp +++ b/torch/csrc/Event.cpp @@ -1,9 +1,6 @@ -#include #include #include #include -#include -#include #include #include @@ -12,7 +9,6 @@ #include #include -#include #include PyTypeObject* THPEventClass = nullptr; diff --git a/torch/csrc/Layout.cpp b/torch/csrc/Layout.cpp index 06b49d56f649d..af7dfc74379de 100644 --- a/torch/csrc/Layout.cpp +++ b/torch/csrc/Layout.cpp @@ -4,9 +4,6 @@ #include #include -#include - -#include #include #include diff --git a/torch/csrc/MemoryFormat.cpp b/torch/csrc/MemoryFormat.cpp index 5bd3f9eed42d6..0a8e212500cf1 100644 --- a/torch/csrc/MemoryFormat.cpp +++ b/torch/csrc/MemoryFormat.cpp @@ -6,7 +6,6 @@ #include -#include #include #include diff --git a/torch/csrc/QScheme.cpp b/torch/csrc/QScheme.cpp index 3fbabc1026f5e..e178ec9247ea5 100644 --- a/torch/csrc/QScheme.cpp +++ b/torch/csrc/QScheme.cpp @@ -4,9 +4,6 @@ #include #include -#include - -#include #include #include diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index ea39424cf8ea7..7a136420d7981 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -1,12 +1,12 @@ #include #include #include -#include +// #include #include -#include #include #include +#include #include #include diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp index 6993f726597cb..3b290b0cfbe55 100644 --- a/torch/csrc/Stream.cpp +++ b/torch/csrc/Stream.cpp @@ -1,9 +1,6 @@ -#include #include #include #include -#include -#include #include #include diff --git a/torch/csrc/TypeInfo.cpp b/torch/csrc/TypeInfo.cpp index de23b79536033..355202c7e40f9 100644 --- a/torch/csrc/TypeInfo.cpp +++ b/torch/csrc/TypeInfo.cpp @@ -2,18 +2,14 @@ #include #include -#include #include #include #include -#include #include #include -#include -#include #include #include diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index c2d4630bdd0df..7bd98144439b4 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/torch/csrc/distributed/autograd/init.cpp b/torch/csrc/distributed/autograd/init.cpp index 115d371524d0e..1d4bacc094322 100644 --- a/torch/csrc/distributed/autograd/init.cpp +++ b/torch/csrc/distributed/autograd/init.cpp @@ -2,10 +2,7 @@ #include #include #include -#include #include -#include -#include namespace torch::distributed::autograd { diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp index fd5ab54e58cfa..4e9af6d1240ab 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp @@ -2,7 +2,6 @@ #include #include #include -#include namespace torch::distributed::autograd { diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 84ddaa1a5ce07..ec1bbf13375f3 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -1,7 +1,4 @@ -#include -#include #include -#include #include #include #include diff --git a/torch/csrc/jit/api/function_impl.cpp b/torch/csrc/jit/api/function_impl.cpp index 0c911970347bd..5fb7fc1f01781 100644 --- a/torch/csrc/jit/api/function_impl.cpp +++ b/torch/csrc/jit/api/function_impl.cpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 61c32680c7c0b..88d235b8a27c2 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -1,22 +1,15 @@ -#include #include #include #include #include -#include #include #include -#include -#include -#include #include -#include #include #include #include #include #include -#include #include #include diff --git a/torch/csrc/jit/backends/backend_debug_handler.cpp b/torch/csrc/jit/backends/backend_debug_handler.cpp index 0d41034130395..ec9f2e4fa5611 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.cpp +++ b/torch/csrc/jit/backends/backend_debug_handler.cpp @@ -1,7 +1,5 @@ #include -#include - namespace torch::jit { std::atomic BackendDebugInfoRecorder::unique_debug_handle_{0}; diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp index b10aba884c721..ea71203412ef5 100644 --- a/torch/csrc/jit/backends/backend_init.cpp +++ b/torch/csrc/jit/backends/backend_init.cpp @@ -2,10 +2,8 @@ #include #include -#include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp index 18c1bc62b8c6d..b0e368d6a3027 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp index 070e96c4f18d7..af6e9909deaa1 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp @@ -1,7 +1,5 @@ #include -#include #include -#include #include namespace py = pybind11; diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index 8dfa2bcc09c4a..d47c9f654d2bb 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -1,13 +1,5 @@ #include -#include -#include -#include -#include -#include -#include -#include - namespace torch::jit::fuser::cuda { static std::atomic cuda_fusion_guard_mode{true}; diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index a5cd6f4e3a43d..cb787cc2b58b3 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -1,11 +1,8 @@ #include -#include #include #include #include -#include -#include #include #include @@ -15,7 +12,6 @@ #include #include #include -#include #include namespace torch::jit::fuser { diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index a1ff6cb613e86..21d1a4734f70d 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -1,25 +1,18 @@ #include -#include #include #include #include #include -#include #include #include #include -#include #include #include #include -#include #include -#include -#include #include -#include #include #include diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index 67c4501dc2758..d66c8f94db4e7 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -1,8 +1,6 @@ #include -#include #include -#include #include #include #include @@ -13,7 +11,6 @@ #include #include -#include #include namespace torch::jit::fuser { diff --git a/torch/csrc/jit/codegen/fuser/fallback.cpp b/torch/csrc/jit/codegen/fuser/fallback.cpp index 698e2882d6a55..a3655b6382407 100644 --- a/torch/csrc/jit/codegen/fuser/fallback.cpp +++ b/torch/csrc/jit/codegen/fuser/fallback.cpp @@ -1,13 +1,11 @@ #include -#include //fmap #include #include #include #include #include #include -#include namespace torch::jit::fuser { diff --git a/torch/csrc/jit/codegen/fuser/interface.cpp b/torch/csrc/jit/codegen/fuser/interface.cpp index 41efa23e2b434..90537815be4e1 100644 --- a/torch/csrc/jit/codegen/fuser/interface.cpp +++ b/torch/csrc/jit/codegen/fuser/interface.cpp @@ -3,11 +3,8 @@ #include #include #include -#include #include -#include -#include namespace torch::jit { diff --git a/torch/csrc/jit/codegen/onednn/decompose_silu.cpp b/torch/csrc/jit/codegen/onednn/decompose_silu.cpp index 8a9e36c2973e4..0a03cf6c87190 100644 --- a/torch/csrc/jit/codegen/onednn/decompose_silu.cpp +++ b/torch/csrc/jit/codegen/onednn/decompose_silu.cpp @@ -1,9 +1,7 @@ #include #include -#include #include -#include namespace torch::jit::fuser::onednn { diff --git a/torch/csrc/jit/codegen/onednn/graph_fuser.cpp b/torch/csrc/jit/codegen/onednn/graph_fuser.cpp index 1c68edca761ba..2c6c96e6ede0f 100644 --- a/torch/csrc/jit/codegen/onednn/graph_fuser.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_fuser.cpp @@ -1,9 +1,7 @@ #include -#include #include #include #include -#include namespace torch::jit::fuser::onednn { diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp index 2ef9f3cfa955c..46e65cac23d06 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -1,7 +1,5 @@ -#include #include -#include #include #include diff --git a/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp index c8d7617fe8651..6780fffac01bb 100644 --- a/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp @@ -1,9 +1,6 @@ #include #include #include -#include -#include -#include namespace torch::jit::fuser::onednn { diff --git a/torch/csrc/jit/codegen/onednn/guard_shape.cpp b/torch/csrc/jit/codegen/onednn/guard_shape.cpp index a71f980d631f5..f7f1dc3776eed 100644 --- a/torch/csrc/jit/codegen/onednn/guard_shape.cpp +++ b/torch/csrc/jit/codegen/onednn/guard_shape.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include namespace torch::jit::fuser::onednn { diff --git a/torch/csrc/jit/codegen/onednn/interface.cpp b/torch/csrc/jit/codegen/onednn/interface.cpp index 2d29c8fa0f755..459fd9684c408 100644 --- a/torch/csrc/jit/codegen/onednn/interface.cpp +++ b/torch/csrc/jit/codegen/onednn/interface.cpp @@ -8,8 +8,6 @@ #include #include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/codegen/onednn/kernel.cpp b/torch/csrc/jit/codegen/onednn/kernel.cpp index 85afc5fa8dc7b..2d6d48921847d 100644 --- a/torch/csrc/jit/codegen/onednn/kernel.cpp +++ b/torch/csrc/jit/codegen/onednn/kernel.cpp @@ -1,7 +1,6 @@ #include #include -#include #include namespace torch::jit::fuser::onednn { diff --git a/torch/csrc/jit/frontend/builtin_functions.cpp b/torch/csrc/jit/frontend/builtin_functions.cpp index 2225f58e54e75..38f142fb0ee28 100644 --- a/torch/csrc/jit/frontend/builtin_functions.cpp +++ b/torch/csrc/jit/frontend/builtin_functions.cpp @@ -1,8 +1,6 @@ #include #include -#include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp b/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp index f2ef8b0e953c4..63369535a9e77 100644 --- a/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp +++ b/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp @@ -1,8 +1,6 @@ -#include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/frontend/error_report.cpp b/torch/csrc/jit/frontend/error_report.cpp index 47a9343c5387f..6942c1bfb5944 100644 --- a/torch/csrc/jit/frontend/error_report.cpp +++ b/torch/csrc/jit/frontend/error_report.cpp @@ -1,7 +1,5 @@ #include -#include - namespace torch::jit { // Avoid storing objects with destructor in thread_local for mobile build. diff --git a/torch/csrc/jit/frontend/inline_loop_condition.cpp b/torch/csrc/jit/frontend/inline_loop_condition.cpp index da23769f402ae..6d3129c31a127 100644 --- a/torch/csrc/jit/frontend/inline_loop_condition.cpp +++ b/torch/csrc/jit/frontend/inline_loop_condition.cpp @@ -1,8 +1,6 @@ #include #include -#include -#include #include #include diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index fba613b5ea8f7..f1941215fcb96 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -2,10 +2,8 @@ #include #include -#include #include #include -#include #include #include #include @@ -18,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -29,7 +26,6 @@ #include #include #include -#include #include @@ -39,7 +35,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/jit/frontend/lexer.cpp b/torch/csrc/jit/frontend/lexer.cpp index 187721671e6e2..7fd0b66bba55e 100644 --- a/torch/csrc/jit/frontend/lexer.cpp +++ b/torch/csrc/jit/frontend/lexer.cpp @@ -1,7 +1,5 @@ #include -#include - #include #include #include diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index c3525ac9c8a20..83742b40ae9cc 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index f9a80cf4da5e4..9ebf9a7e06d4d 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -2,9 +2,7 @@ #include #include -#include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 3ccbd5257ae25..0a0709fdf506d 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -1,25 +1,17 @@ #include -#include #include #include #include -#include #include #include -#include -#include #include #include #include #include -#include #include #include -#include #include -#include -#include #include #include diff --git a/torch/csrc/jit/frontend/versioned_symbols.cpp b/torch/csrc/jit/frontend/versioned_symbols.cpp index 0a468d12d0216..6808804ba5f0b 100644 --- a/torch/csrc/jit/frontend/versioned_symbols.cpp +++ b/torch/csrc/jit/frontend/versioned_symbols.cpp @@ -1,8 +1,5 @@ #include -#include -#include - #include namespace torch::jit { diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 513258236ac4b..51dbb09db9ea0 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index e17c981a746e3..d3524f1ac1044 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -1,11 +1,7 @@ -#include #include -#include #include #include -#include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 9b00a703e352e..c5dfa56b48a2e 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1,16 +1,13 @@ #include -#include #include #include -#include #include #include #include #include #include #include -#include #include #include diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index 2fadc7d573e25..0fbf660da3b04 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include @@ -9,7 +8,6 @@ #ifndef AT_PER_OPERATOR_HEADERS #include #else -#include #include #endif diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index 1551e610c3d10..5e1e3c5aab153 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -1,15 +1,12 @@ #include #include -#include -#include #include #include #include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/ir/type_hashing.cpp b/torch/csrc/jit/ir/type_hashing.cpp index 5d1c03cb493b2..2929f5aabd656 100644 --- a/torch/csrc/jit/ir/type_hashing.cpp +++ b/torch/csrc/jit/ir/type_hashing.cpp @@ -1,8 +1,5 @@ #include -#include -#include -#include #include #include diff --git a/torch/csrc/jit/jit_log.cpp b/torch/csrc/jit/jit_log.cpp index 83f0e158d31bb..9ae31ab11d1d0 100644 --- a/torch/csrc/jit/jit_log.cpp +++ b/torch/csrc/jit/jit_log.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/jit_opt_limit.cpp b/torch/csrc/jit/jit_opt_limit.cpp index c4c1a2307659f..d182a70fda443 100644 --- a/torch/csrc/jit/jit_opt_limit.cpp +++ b/torch/csrc/jit/jit_opt_limit.cpp @@ -1,13 +1,9 @@ -#include #include #include #include -#include -#include #include #include -#include #include // NOTE: Don't try to migrate jit to C++17 yet diff --git a/torch/csrc/jit/mobile/compatibility/backport.cpp b/torch/csrc/jit/mobile/compatibility/backport.cpp index e8d13b1955795..0b264a650da0c 100644 --- a/torch/csrc/jit/mobile/compatibility/backport.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport.cpp @@ -1,5 +1,3 @@ -#include -#include #include #include #include diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 4422608423ee7..c84e05f8a3f12 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -1,11 +1,9 @@ #include #include -#include #include #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 87128a180a6d6..3dc960040c88e 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -2,9 +2,7 @@ #include #include #include -#include #include -#include #include #include diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index ab05e48143e3e..16b1cf29b4e8b 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -17,12 +17,10 @@ #include #include #include -#include #include #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 41fc8d49efb16..a0e0959d6033d 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -13,7 +12,6 @@ #include #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index fb38d70d6f340..fbe262b2fbc66 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -1,12 +1,9 @@ #include #include -#include #include #include -#include -#include #include #include diff --git a/torch/csrc/jit/mobile/parse_bytecode.cpp b/torch/csrc/jit/mobile/parse_bytecode.cpp index 1a1e278e371f8..1cb1661396276 100644 --- a/torch/csrc/jit/mobile/parse_bytecode.cpp +++ b/torch/csrc/jit/mobile/parse_bytecode.cpp @@ -6,7 +6,6 @@ #include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/mobile/register_ops_common_utils.cpp b/torch/csrc/jit/mobile/register_ops_common_utils.cpp index 11e1481a8de4f..147bd7cbd569c 100644 --- a/torch/csrc/jit/mobile/register_ops_common_utils.cpp +++ b/torch/csrc/jit/mobile/register_ops_common_utils.cpp @@ -1,5 +1,4 @@ #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/mobile/train/export_data.cpp b/torch/csrc/jit/mobile/train/export_data.cpp index 2d0a91096a0c1..867a2be6b3d9f 100644 --- a/torch/csrc/jit/mobile/train/export_data.cpp +++ b/torch/csrc/jit/mobile/train/export_data.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/mobile/train/optim/sgd.cpp b/torch/csrc/jit/mobile/train/optim/sgd.cpp index 1523c5629a9cb..1fedb07e3b4aa 100644 --- a/torch/csrc/jit/mobile/train/optim/sgd.cpp +++ b/torch/csrc/jit/mobile/train/optim/sgd.cpp @@ -1,10 +1,7 @@ #include -#include #include -#include - namespace torch::jit::mobile { bool SGDParamGroup::has_options() const { diff --git a/torch/csrc/jit/mobile/train/sequential.cpp b/torch/csrc/jit/mobile/train/sequential.cpp index 3b76db5e8d0cb..e249b2e340f79 100644 --- a/torch/csrc/jit/mobile/train/sequential.cpp +++ b/torch/csrc/jit/mobile/train/sequential.cpp @@ -1,5 +1,4 @@ #include -#include #include #include diff --git a/torch/csrc/jit/mobile/upgrader_mobile.cpp b/torch/csrc/jit/mobile/upgrader_mobile.cpp index 04bc12f1d1046..c78fc4397218e 100644 --- a/torch/csrc/jit/mobile/upgrader_mobile.cpp +++ b/torch/csrc/jit/mobile/upgrader_mobile.cpp @@ -5,7 +5,6 @@ * cd ~/pytorch && python torchgen/operator_versions/gen_mobile_upgraders.py */ -#include #include namespace c10 { diff --git a/torch/csrc/jit/operator_upgraders/utils.cpp b/torch/csrc/jit/operator_upgraders/utils.cpp index 98819b08d640b..fe110b5d570f3 100644 --- a/torch/csrc/jit/operator_upgraders/utils.cpp +++ b/torch/csrc/jit/operator_upgraders/utils.cpp @@ -2,9 +2,8 @@ #include #include -#include +#include #include -#include #include #include diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 4699cceec5b0d..79ead2c5ee6c3 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/bailout_graph.cpp b/torch/csrc/jit/passes/bailout_graph.cpp index 5bea5e42c0d28..4fb339d0d53c1 100644 --- a/torch/csrc/jit/passes/bailout_graph.cpp +++ b/torch/csrc/jit/passes/bailout_graph.cpp @@ -2,14 +2,11 @@ #include #include -#include #include #include #include -#include #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/check_strict_fusion.cpp b/torch/csrc/jit/passes/check_strict_fusion.cpp index 731382c316398..2a5ed995c1050 100644 --- a/torch/csrc/jit/passes/check_strict_fusion.cpp +++ b/torch/csrc/jit/passes/check_strict_fusion.cpp @@ -1,7 +1,6 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index cfa0ee4978826..e5d214762e2ab 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -5,8 +5,6 @@ #include #include -#include - namespace torch::jit { namespace { diff --git a/torch/csrc/jit/passes/concat_opt.cpp b/torch/csrc/jit/passes/concat_opt.cpp index a651458eb5e93..b21a65ea98dbe 100644 --- a/torch/csrc/jit/passes/concat_opt.cpp +++ b/torch/csrc/jit/passes/concat_opt.cpp @@ -11,9 +11,7 @@ #include #include #include -#include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index 97d4c6262ed9e..5e95f1eae39ec 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -1,10 +1,8 @@ #include -#include #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index cac257125b0fc..d0c836df9ffca 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -4,9 +4,7 @@ #include #include #include -#include #include -#include #include #include diff --git a/torch/csrc/jit/passes/create_functional_graphs.cpp b/torch/csrc/jit/passes/create_functional_graphs.cpp index 86e9fa13893f6..562659788d7b4 100644 --- a/torch/csrc/jit/passes/create_functional_graphs.cpp +++ b/torch/csrc/jit/passes/create_functional_graphs.cpp @@ -6,7 +6,6 @@ #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp b/torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp index 1d35b30c05024..6fcf40a4ded3e 100644 --- a/torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp +++ b/torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp @@ -2,7 +2,6 @@ #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/dtype_analysis.cpp b/torch/csrc/jit/passes/dtype_analysis.cpp index 9cbe6a936232b..64f3165dbce82 100644 --- a/torch/csrc/jit/passes/dtype_analysis.cpp +++ b/torch/csrc/jit/passes/dtype_analysis.cpp @@ -1,14 +1,9 @@ -#include #include -#include -#include #include -#include #include #include #include #include -#include #include #ifndef AT_PER_OPERATOR_HEADERS diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 03b370576d57c..fd5fbdbaf3d83 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -2,7 +2,6 @@ #include #include -#include #include diff --git a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp index 1bfa045d2d3f8..5320f88e12ccd 100644 --- a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp +++ b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp @@ -2,13 +2,10 @@ #include #include -#include #include #include #include -#include - namespace torch::jit { namespace { diff --git a/torch/csrc/jit/passes/fold_conv_bn.cpp b/torch/csrc/jit/passes/fold_conv_bn.cpp index c2cb24f3287ca..b1c6c52e99e46 100644 --- a/torch/csrc/jit/passes/fold_conv_bn.cpp +++ b/torch/csrc/jit/passes/fold_conv_bn.cpp @@ -10,7 +10,6 @@ #ifndef AT_PER_OPERATOR_HEADERS #include #else -#include #include #include #include diff --git a/torch/csrc/jit/passes/frozen_concat_linear.cpp b/torch/csrc/jit/passes/frozen_concat_linear.cpp index e2270aa8bd763..fc864a6346991 100644 --- a/torch/csrc/jit/passes/frozen_concat_linear.cpp +++ b/torch/csrc/jit/passes/frozen_concat_linear.cpp @@ -1,14 +1,8 @@ -#include #include #include -#include #include #include -#include -#include -#include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp b/torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp index 20edcdd96180b..3434b35760c56 100644 --- a/torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp +++ b/torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp @@ -1,14 +1,8 @@ -#include #include #include -#include #include -#include -#include -#include #ifdef USE_CUDA -#include #endif namespace torch::jit { diff --git a/torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp b/torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp index af0c0a6a7880d..be03a2750e6d9 100644 --- a/torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp +++ b/torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -8,7 +7,6 @@ #include #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/frozen_conv_folding.cpp b/torch/csrc/jit/passes/frozen_conv_folding.cpp index 6bc75bfcc8cf6..e210f09ef3279 100644 --- a/torch/csrc/jit/passes/frozen_conv_folding.cpp +++ b/torch/csrc/jit/passes/frozen_conv_folding.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -11,7 +10,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp index e76575a2370a7..f086906a19f6c 100644 --- a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp +++ b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp @@ -1,12 +1,8 @@ -#include -#include -#include #include #include #include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/frozen_linear_transpose.cpp b/torch/csrc/jit/passes/frozen_linear_transpose.cpp index 9595227d2587d..ccd97942f2c0f 100644 --- a/torch/csrc/jit/passes/frozen_linear_transpose.cpp +++ b/torch/csrc/jit/passes/frozen_linear_transpose.cpp @@ -1,9 +1,7 @@ #include -#include #include #include #include -#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -12,7 +10,6 @@ #include #endif -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/fuse_relu.cpp b/torch/csrc/jit/passes/fuse_relu.cpp index 1a8ee88b3da5c..953dc8fe2c37a 100644 --- a/torch/csrc/jit/passes/fuse_relu.cpp +++ b/torch/csrc/jit/passes/fuse_relu.cpp @@ -1,7 +1,6 @@ #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 8dfa836f87bd8..03c418260e219 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -3,18 +3,14 @@ #include #include #include -#include #include #include #include #include #include -#include #include -#include #include -#include #include #include diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp index 7b0fed5dc15f5..5f76a0ce0cf8f 100644 --- a/torch/csrc/jit/passes/guard_elimination.cpp +++ b/torch/csrc/jit/passes/guard_elimination.cpp @@ -2,9 +2,6 @@ #include #include -#include -#include -#include #include #include diff --git a/torch/csrc/jit/passes/hoist_conv_packed_params.cpp b/torch/csrc/jit/passes/hoist_conv_packed_params.cpp index 5ef4e5d576cb9..1222b7cb39be3 100644 --- a/torch/csrc/jit/passes/hoist_conv_packed_params.cpp +++ b/torch/csrc/jit/passes/hoist_conv_packed_params.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include #include diff --git a/torch/csrc/jit/passes/inliner.cpp b/torch/csrc/jit/passes/inliner.cpp index 1ddbb02f9278c..9c06f748e43a9 100644 --- a/torch/csrc/jit/passes/inliner.cpp +++ b/torch/csrc/jit/passes/inliner.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/insert_guards.cpp b/torch/csrc/jit/passes/insert_guards.cpp index 2bb810199e844..602a5086e7361 100644 --- a/torch/csrc/jit/passes/insert_guards.cpp +++ b/torch/csrc/jit/passes/insert_guards.cpp @@ -1,7 +1,6 @@ #include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/integer_value_refinement.cpp b/torch/csrc/jit/passes/integer_value_refinement.cpp index 7405608bb4ca0..c760c5cb13798 100644 --- a/torch/csrc/jit/passes/integer_value_refinement.cpp +++ b/torch/csrc/jit/passes/integer_value_refinement.cpp @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/torch/csrc/jit/passes/liveness.cpp b/torch/csrc/jit/passes/liveness.cpp index 138c6fc78f752..5fc13b44f17d8 100644 --- a/torch/csrc/jit/passes/liveness.cpp +++ b/torch/csrc/jit/passes/liveness.cpp @@ -1,8 +1,6 @@ #include -#include #include -#include #include #include diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp index ff8c1642f6281..cfeb04f5f19e6 100644 --- a/torch/csrc/jit/passes/lower_tuples.cpp +++ b/torch/csrc/jit/passes/lower_tuples.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/metal_rewrite.cpp b/torch/csrc/jit/passes/metal_rewrite.cpp index 630701cab6dbb..82400a1cdcb1d 100644 --- a/torch/csrc/jit/passes/metal_rewrite.cpp +++ b/torch/csrc/jit/passes/metal_rewrite.cpp @@ -1,9 +1,5 @@ -#include -#include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/passes/mkldnn_rewrite.cpp b/torch/csrc/jit/passes/mkldnn_rewrite.cpp index 769d96eec218c..934f44a9ccf33 100644 --- a/torch/csrc/jit/passes/mkldnn_rewrite.cpp +++ b/torch/csrc/jit/passes/mkldnn_rewrite.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp index 1c0a453c28c52..4ce0afa2d3000 100644 --- a/torch/csrc/jit/passes/normalize_ops.cpp +++ b/torch/csrc/jit/passes/normalize_ops.cpp @@ -1,7 +1,5 @@ #include -#include - namespace torch::jit { namespace { diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index d3231222cb935..720688ccc76c0 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -1,9 +1,7 @@ #include -#include #include #include -#include #include #include #include @@ -13,7 +11,6 @@ #include #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 902dc5f8924cd..60699a1e75ef4 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -1,7 +1,5 @@ #include -#include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.cpp b/torch/csrc/jit/passes/onnx/eval_peephole.cpp index 0334d5706a6eb..72fd0cb969074 100644 --- a/torch/csrc/jit/passes/onnx/eval_peephole.cpp +++ b/torch/csrc/jit/passes/onnx/eval_peephole.cpp @@ -1,10 +1,8 @@ #include #include #include -#include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index 2687ee9fb07dc..2f18a6d8c99cf 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -3,9 +3,7 @@ #include #include #include -#include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 8eab378c89223..3897a8d5cae5e 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -10,8 +10,6 @@ #include #endif -#include - namespace torch::jit { namespace onnx { using namespace ::c10::onnx; diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index f9aa740c44ada..491106f0cb24d 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -1,7 +1,5 @@ #include #include -#include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.cpp index 3283b82eb4673..e7f228f29b267 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.cpp @@ -1,8 +1,6 @@ #include #include -#include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp index a51801ac8363c..186a25873efa6 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp @@ -1,8 +1,5 @@ -#include -#include #include #include -#include // EDITING THIS FILE? READ THIS FIRST! // see Note [Edit Pattern Encapsulation] in pattern_encapsulation.h diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp index 5f35a85b2aa89..a00e98708e208 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp @@ -4,7 +4,6 @@ #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 92dfa86da4b1b..125a8f53b7950 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include #include #include @@ -11,7 +9,6 @@ #include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/peephole_alias_sensitive.cpp b/torch/csrc/jit/passes/peephole_alias_sensitive.cpp index e3fca5c215f3b..e6ec265fc98b0 100644 --- a/torch/csrc/jit/passes/peephole_alias_sensitive.cpp +++ b/torch/csrc/jit/passes/peephole_alias_sensitive.cpp @@ -1,12 +1,6 @@ -#include #include -#include #include -#include -#include #include -#include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/peephole_list_idioms.cpp b/torch/csrc/jit/passes/peephole_list_idioms.cpp index e07496dee2e52..71734d32bbf8d 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_list_idioms.cpp @@ -2,11 +2,8 @@ #include #include #include -#include -#include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/peephole_non_tensor.cpp b/torch/csrc/jit/passes/peephole_non_tensor.cpp index dbc8fce10da62..a6bd622fc5db9 100644 --- a/torch/csrc/jit/passes/peephole_non_tensor.cpp +++ b/torch/csrc/jit/passes/peephole_non_tensor.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/torch/csrc/jit/passes/prepack_folding.cpp b/torch/csrc/jit/passes/prepack_folding.cpp index 608432602ddbb..6efd442758586 100644 --- a/torch/csrc/jit/passes/prepack_folding.cpp +++ b/torch/csrc/jit/passes/prepack_folding.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index 5fab235044453..d1dc726faaae2 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -4,8 +4,6 @@ #include #include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/passes/refine_tuple_types.cpp b/torch/csrc/jit/passes/refine_tuple_types.cpp index 08d91d43150fc..14349438eef59 100644 --- a/torch/csrc/jit/passes/refine_tuple_types.cpp +++ b/torch/csrc/jit/passes/refine_tuple_types.cpp @@ -1,8 +1,6 @@ #include #include -#include - #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/remove_redundant_profiles.cpp b/torch/csrc/jit/passes/remove_redundant_profiles.cpp index 1bfb6396ebafc..e636433cff825 100644 --- a/torch/csrc/jit/passes/remove_redundant_profiles.cpp +++ b/torch/csrc/jit/passes/remove_redundant_profiles.cpp @@ -1,8 +1,6 @@ -#include #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/passes/replacement_of_old_operators.cpp b/torch/csrc/jit/passes/replacement_of_old_operators.cpp index 090f4a46b1414..4e9f123918b61 100644 --- a/torch/csrc/jit/passes/replacement_of_old_operators.cpp +++ b/torch/csrc/jit/passes/replacement_of_old_operators.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/requires_grad_analysis.cpp b/torch/csrc/jit/passes/requires_grad_analysis.cpp index 88367b58a81bc..17a8289ba75cd 100644 --- a/torch/csrc/jit/passes/requires_grad_analysis.cpp +++ b/torch/csrc/jit/passes/requires_grad_analysis.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/restore_mutation.cpp b/torch/csrc/jit/passes/restore_mutation.cpp index 8e02f4f55e241..fbefcd7ed7ac1 100644 --- a/torch/csrc/jit/passes/restore_mutation.cpp +++ b/torch/csrc/jit/passes/restore_mutation.cpp @@ -1,5 +1,3 @@ -#include -#include #include #include diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 57dc2552c661c..7493667a2f027 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -11,9 +11,6 @@ #include #include -#include - -#include #include #include diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp index 999f8247b7c84..75ec0e12016c6 100644 --- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp @@ -12,16 +12,12 @@ #include #include #include -#include -#include #include #include #include #include #include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp index 603631165717b..c8c6953f3447f 100644 --- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 672a9949c6b91..e0f03324454ef 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -13,7 +12,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp index 3333bfeefb120..fb46771cdbcd8 100644 --- a/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp +++ b/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp @@ -1,7 +1,6 @@ #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/passes/utils/memory_dag.cpp b/torch/csrc/jit/passes/utils/memory_dag.cpp index 8ad213082f52f..1b56cddf79d80 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.cpp +++ b/torch/csrc/jit/passes/utils/memory_dag.cpp @@ -2,7 +2,6 @@ #include #include -#include namespace torch::jit { namespace { diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index f54adbd7223a2..6f92e821e5b44 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -3,7 +3,6 @@ #include #include -#include #include #include diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp index 7d9b3b8210c2b..4914a2f81869d 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.cpp +++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 9f7c2756d0d73..31f24bf1b4b92 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -8,7 +7,6 @@ #include -#include #include #include diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index 32ba91df0ab34..25f5088368e58 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -1,8 +1,6 @@ #include #include -#include - #include namespace torch::jit { diff --git a/torch/csrc/jit/python/python_dict.cpp b/torch/csrc/jit/python/python_dict.cpp index ea64f5a985de0..82fd4449a4f15 100644 --- a/torch/csrc/jit/python/python_dict.cpp +++ b/torch/csrc/jit/python/python_dict.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/python/python_interpreter.cpp b/torch/csrc/jit/python/python_interpreter.cpp index 7b29134cf0e84..7e78cbd28f7e8 100644 --- a/torch/csrc/jit/python/python_interpreter.cpp +++ b/torch/csrc/jit/python/python_interpreter.cpp @@ -1,24 +1,11 @@ #include -#include -#include -#include -#include -#include #include #include #include #include -#include #include -#include -#include -#include -#include -#include -#include - namespace py = pybind11; namespace torch::jit { diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 6e5dcde957ddb..bd1290cbdf9e8 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -12,11 +12,7 @@ #include #include #include -#include -#include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 8b16e089aa50e..26c8fe067a621 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include @@ -9,15 +8,12 @@ #include #include #include -#include #include #include #include #include #include -#include - namespace torch::jit { std::string typeString(py::handle h) { diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 9210311997384..5cf3bd900f351 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index f1e58a9bd3e38..214a07872d0ac 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/decomposition_registry.cpp b/torch/csrc/jit/runtime/decomposition_registry.cpp index 31ee76d142994..fbaa10ee32b2d 100644 --- a/torch/csrc/jit/runtime/decomposition_registry.cpp +++ b/torch/csrc/jit/runtime/decomposition_registry.cpp @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/torch/csrc/jit/runtime/decomposition_registry_util.cpp b/torch/csrc/jit/runtime/decomposition_registry_util.cpp index d0a4fa3b04fb2..ad48d1cd89370 100644 --- a/torch/csrc/jit/runtime/decomposition_registry_util.cpp +++ b/torch/csrc/jit/runtime/decomposition_registry_util.cpp @@ -5,10 +5,7 @@ * To re-generate, please run: * cd ~/pytorch && python torchgen/decompositions/gen_jit_decompositions.py */ -#include -#include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index bb152df094f5a..4bdab8c5dcb22 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 95b74376d2eb2..7fd16c08a9e73 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -1,17 +1,10 @@ #include -#include #include #include -#include #include #include #include -#include -#include -#include -#include -#include #include #include #include @@ -40,7 +33,6 @@ using torch::distributed::autograd::DistAutogradContainer; #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/jit_trace.cpp b/torch/csrc/jit/runtime/jit_trace.cpp index 45be4fe21bb4b..8a1daabf54ca9 100644 --- a/torch/csrc/jit/runtime/jit_trace.cpp +++ b/torch/csrc/jit/runtime/jit_trace.cpp @@ -1,16 +1,10 @@ -#include -#include #include #include #include #include #include -#include -#include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 6f9dec70cddc9..30105754c5ee2 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 98acf24dd1df3..680244b363c36 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -16,11 +15,9 @@ #include #include #include -#include #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/profiling_record.cpp b/torch/csrc/jit/runtime/profiling_record.cpp index 30e8c58d65a0f..fb01aa2a25574 100644 --- a/torch/csrc/jit/runtime/profiling_record.cpp +++ b/torch/csrc/jit/runtime/profiling_record.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp index 85e8c0a2b037c..be7bfbd4acd24 100644 --- a/torch/csrc/jit/runtime/register_c10_ops.cpp +++ b/torch/csrc/jit/runtime/register_c10_ops.cpp @@ -1,8 +1,5 @@ -#include #include #include -#include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/runtime/register_cuda_ops.cpp b/torch/csrc/jit/runtime/register_cuda_ops.cpp index 9ca5e01b0dd01..dec70000d2ce7 100644 --- a/torch/csrc/jit/runtime/register_cuda_ops.cpp +++ b/torch/csrc/jit/runtime/register_cuda_ops.cpp @@ -1,6 +1,5 @@ // This file registers special JIT operators used to implement the PyTorch CUDA // API in TorchScript. -#include #include #include #include diff --git a/torch/csrc/jit/runtime/register_distributed_ops.cpp b/torch/csrc/jit/runtime/register_distributed_ops.cpp index a09a0f99f25ff..8ce967aca0f05 100644 --- a/torch/csrc/jit/runtime/register_distributed_ops.cpp +++ b/torch/csrc/jit/runtime/register_distributed_ops.cpp @@ -1,8 +1,5 @@ -#include -#include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index b09cc45ce33f7..b74fc4316c24f 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -4,25 +4,14 @@ #include #include #include -#include #include #include -#include -#include #include -#include #include #include -#include -#include -#include -#include #include #include -#include -#include -#include #include #include diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 0f2447e05a9f8..c7343914cb639 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -10,11 +9,9 @@ #include #include #include -#include #include #include -#include #include diff --git a/torch/csrc/jit/runtime/script_profile.cpp b/torch/csrc/jit/runtime/script_profile.cpp index a1e1ad6972e4a..a9151d0e00fbc 100644 --- a/torch/csrc/jit/runtime/script_profile.cpp +++ b/torch/csrc/jit/runtime/script_profile.cpp @@ -7,7 +7,6 @@ #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp index d77e0b3a10d64..89537c4b40422 100644 --- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp +++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp @@ -6,9 +6,6 @@ * cd ~/pytorch && python * torchgen/shape_functions/gen_jit_shape_functions.py */ -#include -#include -#include #include // clang-format off diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp index 61f2e5614ef05..1dc66c85f1dc4 100644 --- a/torch/csrc/jit/runtime/static/fusion.cpp +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 8ad348bb162c1..4cd12cf19fbb6 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -17,14 +16,12 @@ #include #include #include -#include #include #include #include #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/static/memory_planner.cpp b/torch/csrc/jit/runtime/static/memory_planner.cpp index 8660183867e08..d1051b94b63e2 100644 --- a/torch/csrc/jit/runtime/static/memory_planner.cpp +++ b/torch/csrc/jit/runtime/static/memory_planner.cpp @@ -1,11 +1,8 @@ #include #include -#include -#include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 716202f45687a..9478dd98fce30 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -3,13 +3,8 @@ #include #include -#include -#include -#include #include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index fdb0919da45ce..1029dd7019f8c 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index b1f0f410f14fe..6fb34bc2027b4 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp index ac0cd61fd2fef..c14277aebeb14 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry_util.cpp @@ -1,9 +1,4 @@ -#include -#include -#include -#include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index fb1280400a89d..0e792934472d1 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/python_dimname.cpp b/torch/csrc/python_dimname.cpp index d7046552f80f5..07f604600b22b 100644 --- a/torch/csrc/python_dimname.cpp +++ b/torch/csrc/python_dimname.cpp @@ -1,5 +1,4 @@ #include -#include #include #include diff --git a/torch/csrc/utils/cpp_stacktraces.cpp b/torch/csrc/utils/cpp_stacktraces.cpp index 641dffe08bc59..79c8253b91a62 100644 --- a/torch/csrc/utils/cpp_stacktraces.cpp +++ b/torch/csrc/utils/cpp_stacktraces.cpp @@ -1,8 +1,5 @@ #include -#include -#include - #include #include diff --git a/torch/csrc/utils/device_lazy_init.cpp b/torch/csrc/utils/device_lazy_init.cpp index e531cca4fb273..6083b55064c75 100644 --- a/torch/csrc/utils/device_lazy_init.cpp +++ b/torch/csrc/utils/device_lazy_init.cpp @@ -3,7 +3,6 @@ #include #include -#include #include #ifndef WIN32 diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index becbe1681f000..d75c0351fb6c4 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/utils/init.cpp b/torch/csrc/utils/init.cpp index 30e4082b0330b..986df49c571e1 100644 --- a/torch/csrc/utils/init.cpp +++ b/torch/csrc/utils/init.cpp @@ -2,9 +2,6 @@ #include #include -#include -#include - namespace torch::throughput_benchmark { void initThroughputBenchmarkBindings(PyObject* module) { diff --git a/torch/csrc/utils/object_ptr.cpp b/torch/csrc/utils/object_ptr.cpp index ff314fdad145a..c77797c0a48e3 100644 --- a/torch/csrc/utils/object_ptr.cpp +++ b/torch/csrc/utils/object_ptr.cpp @@ -1,8 +1,6 @@ #include #include -#include - template <> TORCH_PYTHON_API void THPPointer::free() { if (ptr && C10_LIKELY(Py_IsInitialized())) diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 3380bb0a13e57..69971fe09839b 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -1,14 +1,11 @@ #include #include -#include #include -#include #include #include #include #include -#include #include #include @@ -22,14 +19,9 @@ #include #include -#include -#include #include -#include #include -#include -#include #include #include diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp index c8a731d8d5fe7..efb0b2c6889ef 100644 --- a/torch/csrc/utils/tensor_apply.cpp +++ b/torch/csrc/utils/tensor_apply.cpp @@ -1,11 +1,9 @@ #include #include -#include #include #include -#include #include using namespace at; diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index e7c58540d74e4..39df9be68868a 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/utils/tensor_layouts.cpp b/torch/csrc/utils/tensor_layouts.cpp index be8816c8a9aba..d0bccbcf9106f 100644 --- a/torch/csrc/utils/tensor_layouts.cpp +++ b/torch/csrc/utils/tensor_layouts.cpp @@ -1,9 +1,6 @@ -#include -#include #include #include #include -#include #include #include diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp index f25175af2dcc1..0a264e11e3586 100644 --- a/torch/csrc/utils/tensor_list.cpp +++ b/torch/csrc/utils/tensor_list.cpp @@ -2,10 +2,8 @@ #include #include -#include #include #include -#include #include using namespace at; diff --git a/torch/csrc/utils/tensor_memoryformats.cpp b/torch/csrc/utils/tensor_memoryformats.cpp index 28d56291bc945..c1a3ff326493a 100644 --- a/torch/csrc/utils/tensor_memoryformats.cpp +++ b/torch/csrc/utils/tensor_memoryformats.cpp @@ -1,11 +1,9 @@ #include #include -#include #include #include -#include #include namespace torch::utils { diff --git a/torch/csrc/utils/tensor_qschemes.cpp b/torch/csrc/utils/tensor_qschemes.cpp index 4c2e6f20557e9..f85d091bd57a0 100644 --- a/torch/csrc/utils/tensor_qschemes.cpp +++ b/torch/csrc/utils/tensor_qschemes.cpp @@ -2,11 +2,9 @@ #include #include -#include #include #include -#include #include namespace torch::utils { diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index c46baea82a442..620086f9ad50d 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -1,10 +1,8 @@ -#include #include #include #include -#include #include #include diff --git a/torch/csrc/utils/throughput_benchmark.cpp b/torch/csrc/utils/throughput_benchmark.cpp index 2f0ba77979a53..8e8016567c721 100644 --- a/torch/csrc/utils/throughput_benchmark.cpp +++ b/torch/csrc/utils/throughput_benchmark.cpp @@ -1,8 +1,6 @@ #include -#include #include -#include namespace torch::throughput_benchmark { From a0f3937b94422354538ebbd47202d5b0e8a3fd0d Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 28 Nov 2025 01:43:48 +0000 Subject: [PATCH 043/338] =?UTF-8?q?Remove=20unused=20inplace=20loop=20in?= =?UTF-8?q?=20test=5Fconv2d=5Fclamp=20of=20test=5Fjit=5Fllga=5Ffuser.py?= =?UTF-8?q?=E2=80=8E=20(#166691)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `inplace` is not used and not the test target in this unit test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166691 Approved by: https://github.com/rec, https://github.com/albanD --- test/test_jit_llga_fuser.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/test_jit_llga_fuser.py b/test/test_jit_llga_fuser.py index 1707288a318cd..d7c7f2885f6d5 100644 --- a/test/test_jit_llga_fuser.py +++ b/test/test_jit_llga_fuser.py @@ -507,13 +507,12 @@ def forward(self, x): x = torch.clamp(x, max=2) return x - for inplace in [False, True]: # noqa: F841 - for memory_format in [torch.contiguous_format, torch.channels_last]: - x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format) - m = M() - _, graph = self.checkTrace(m, [x], dtype) - self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5) - self.assertFused(graph, ['aten::_convolution', "aten::clamp"]) + for memory_format in [torch.contiguous_format, torch.channels_last]: + x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format) + m = M() + _, graph = self.checkTrace(m, [x], dtype) + self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5) + self.assertFused(graph, ['aten::_convolution', "aten::clamp"]) @onlyCPU @dtypes(torch.float32, torch.bfloat16) From a20f775e82564d2a9979221ed7f3b8d7cf54ce90 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Wed, 19 Nov 2025 08:48:20 +0000 Subject: [PATCH 044/338] [Quant][CPU] fix fp8 qconv (#167611) Summary: Fix fp8 qconv to support fp8_input/fp8_weight/bf16_bias in and bf16_output out. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167611 Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jerryzh168 --- aten/src/ATen/native/quantized/cpu/qconv.cpp | 3 ++- test/quantization/core/test_quantized_op.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index cd8fb6df37f0e..c054d576516ce 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1426,8 +1426,9 @@ static at::Tensor _fp8_convolution_onednn_ref( w_scales_new_shape[0] = -1; auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape); auto output_padding = std::vector(kSpatialDim, 0); + auto bias_float = bias.has_value() ? bias.value().to(at::kFloat) : bias; auto y_f32 = at::convolution( - dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups + dqx, dqw, bias_float, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups ); if (!binary_attr.has_value() || binary_attr == "none") { if (unary_attr == "relu") { diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 7328870a64227..ce7eab2050d3a 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -7868,7 +7868,7 @@ def test_qconv1d_relu_pt2e(self): def _make_qconv_tensors_fp8( self, batch_size, input_channels_per_group, input_feature_map_shape, output_channels_per_group, groups, kernels, strides, pads, dilations, - use_bias, use_channelwise, use_transpose, + use_bias, use_channelwise, use_transpose, bfloat16_output, device=torch.device("cpu"), ): assert not (use_channelwise and use_transpose), \ @@ -7898,9 +7898,10 @@ def _make_qconv_tensors_fp8( X_q, X_scale = _quantize_fp8e4m3(X, channelwise=False) W = torch.randn(output_shape + kernels, device=device) * 0.1 W_q, W_scale = _quantize_fp8e4m3(W, channelwise=use_channelwise) - bias_float = torch.randn((output_channels,), device=device) if use_bias else None + bias_dtype = torch.bfloat16 if bfloat16_output else torch.float + bias = torch.randn((output_channels,), dtype=bias_dtype, device=device) if use_bias else None - return X, W, X_q, W_q, X_scale, W_scale, bias_float + return X, W, X_q, W_q, X_scale, W_scale, bias def _test_qconv_impl_cpu_tensor_fp8( self, @@ -7932,7 +7933,7 @@ def _test_qconv_impl_cpu_tensor_fp8( batch_size = 3 device = torch.device("cpu") use_transpose = False - X, W, X_q, W_q, X_scale, W_scale, bias_float = self._make_qconv_tensors_fp8( + X, W, X_q, W_q, X_scale, W_scale, bias = self._make_qconv_tensors_fp8( batch_size, input_channels_per_group, input_feature_map_shape, @@ -7945,11 +7946,13 @@ def _test_qconv_impl_cpu_tensor_fp8( use_bias, use_channelwise, use_transpose, + bfloat16_output, device=device, ) # Assign weights dqW = _dequantize_fp8e4m3(W_q, W_scale) dqX = _dequantize_fp8e4m3(X_q, X_scale) + bias_float = bias.float() if use_bias and bfloat16_output else bias conv_op.weight = torch.nn.Parameter(dqW, requires_grad=False) conv_op.bias = ( torch.nn.Parameter(bias_float, requires_grad=False) if use_bias else None @@ -8030,7 +8033,7 @@ def _test_qconv_impl_cpu_tensor_fp8( W_scale, torch.zeros([], dtype=torch.int8), # W_zero_point accum, - bias_float, + bias, strides, pads, dilations, @@ -8054,7 +8057,7 @@ def _test_qconv_impl_cpu_tensor_fp8( packed_weight, W_scale, torch.zeros([], dtype=torch.int8), # W_zero_point - bias_float, + bias, strides, pads, dilations, From f1076f5510920044912247b1abb8760cb820f598 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 27 Nov 2025 15:42:32 -0800 Subject: [PATCH 045/338] [dynamo] Support functools.partial as dict key (#169016) Improve error message along the way Fixes https://github.com/pytorch/pytorch/issues/169010 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169016 Approved by: https://github.com/guilhermeleobas, https://github.com/williamwen42, https://github.com/jansel --- test/dynamo/test_dicts.py | 44 +++++++++++++++++++++++++++++++- torch/_dynamo/variables/dicts.py | 20 ++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 79ead6b348a75..bda5f8759b46c 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -9,7 +9,9 @@ import unittest import weakref from collections import defaultdict, namedtuple, OrderedDict, UserDict -from typing import Any +from collections.abc import Callable +from functools import partial +from typing import Any, NamedTuple import torch import torch._dynamo.test_case @@ -1705,6 +1707,46 @@ def test_dict___iter__(self): it = d.__iter__() self.assertEqual(next(it), 1) + def test_functools_partial_key(self): + def gn(x, y): + return x + y + + def fn(x): + new_dict = {} + new_gn1 = partial(gn, x=1) + new_dict[new_gn1] = 5 + return x * new_dict[new_gn1] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + def test_namedtuple_functools(self): + class Container(NamedTuple): + partial_fn: Callable + const: int + + def gn(x, y): + return x + y + + def fn(x): + new_dict = {} + + new_gn = partial(gn, x=1) + key = Container(new_gn, 4) + new_dict[key] = 5 + return x * new_dict[key] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 93af8c46de01c..422cae7c4d3f1 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -105,6 +105,12 @@ def is_hashable(x: VariableTracker) -> bool: and isinstance(x.value, int) ): return isinstance(x.value, py_Hashable) + elif isinstance(x, variables.FunctoolsPartialVariable): + return ( + is_hashable(x.func) + and all(is_hashable(arg) for arg in x.args) + and all(is_hashable(value) for value in x.keywords.values()) + ) else: return isinstance( x, @@ -191,6 +197,11 @@ def underlying_value(self) -> Any: # an object as key (`class _ZeroSentinel(int): ...`): # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual return self.vt.value # type: ignore[attr-defined,union-attr] + elif isinstance(self.vt, variables.FunctoolsPartialVariable): + Hashable = ConstDictVariable._HashableTracker + items = (self.vt.func, *self.vt.args, *self.vt.keywords.values()) + x = tuple(Hashable(e).underlying_value for e in items) + return x else: x = self.vt.as_python_constant() return x @@ -420,7 +431,14 @@ def getitem_const_raise_exception_if_absent( ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: - raise_observed_exception(KeyError, tx) + try: + error_message = ( + f"Dict key lookup failed for {str(arg)}. " + f"Debug representation of the key is {arg.debug_repr()!r}" + ) + except Exception: + error_message = f"Dict key lookup failed for {str(arg)}" + raise_observed_exception(KeyError, tx, msg=error_message) return self.items[key] def getitem_const( From 9a296e640fc88aa44d275b48cd9cc30c573b169d Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 28 Nov 2025 03:53:19 +0000 Subject: [PATCH 046/338] Remove unused thrust inclusion (#169051) This PR removes unused thrust header to facilitate its moving to cccl. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169051 Approved by: https://github.com/albanD --- aten/src/ATen/native/cuda/block_reduce.cuh | 2 -- aten/src/ATen/native/sparse/cuda/SoftMax.cu | 18 +++--------------- .../native/sparse/cuda/SparseCUDATensor.cu | 1 - .../native/sparse/cuda/SparseCUDATensorMath.cu | 3 +-- .../ATen/native/sparse/cuda/SparseMatMul.cu | 1 - 5 files changed, 4 insertions(+), 21 deletions(-) diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index 1818987c6a588..019e4613bd014 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -1,7 +1,5 @@ #pragma once -#include - #include #include diff --git a/aten/src/ATen/native/sparse/cuda/SoftMax.cu b/aten/src/ATen/native/sparse/cuda/SoftMax.cu index 2ee8de3fd5edf..ec0d6f068ebf5 100644 --- a/aten/src/ATen/native/sparse/cuda/SoftMax.cu +++ b/aten/src/ATen/native/sparse/cuda/SoftMax.cu @@ -31,11 +31,13 @@ #include #include #include +#include #include +#include #include #include #include -#include +#include #include #include @@ -47,20 +49,6 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index b59221a3231a5..410c511bebef6 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -33,7 +33,6 @@ #include #include #include -#include #include namespace at::native { diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 62deedfc2a712..fab4f5438d5d4 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -37,10 +37,9 @@ #include #endif +#include #include #include -#include -#include #include #include diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index 49bea10c65104..745c9eb9af6ab 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -35,7 +35,6 @@ #include #include #include -#include #include #include From 7bc2a66ded06a0b2549aa51d807edc5dc3e73d1b Mon Sep 17 00:00:00 2001 From: linhaifeng <1371675203@qq.com> Date: Fri, 28 Nov 2025 04:58:46 +0000 Subject: [PATCH 047/338] [CUDA][BugFix] fix truncated error messages (#168942) Inspired by #168369 I found https://github.com/pytorch/pytorch/blob/9a38bb8622e5427e28b655df89b81293f63ecaac/c10/core/Device.h#L19 When device indices (DeviceIndex) with value 0 are passed to TORCH_CHECK macros,they are interpreted as string terminators (\0), causing error messages to be truncated. For example: ```cpp #include #include #include #include int8_t device = 0; int main() { std::cout << std::strlen((std::stringstream() << "Head" << device << "Tail").str().c_str()) << std::endl; std::cout << std::strlen((std::stringstream() << "Head" << static_cast(device) << "Tail").str().c_str()) << std::endl; std::cout << std::strlen((std::stringstream() << "Head" << +device << "Tail").str().c_str()) << std::endl; return 0; } ``` output ```bash 4 9 9 ``` Maybe we can use `+` instead of `static_cast`, but it needs discussion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168942 Approved by: https://github.com/cyyever, https://github.com/eqy --- aten/src/ATen/cuda/PeerToPeerAccess.cpp | 6 +++--- aten/src/ATen/cuda/detail/CUDAHooks.cpp | 2 +- c10/cuda/CUDAFunctions.cpp | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/cuda/PeerToPeerAccess.cpp b/aten/src/ATen/cuda/PeerToPeerAccess.cpp index 66a75db6ea067..a03d66f6147fc 100644 --- a/aten/src/ATen/cuda/PeerToPeerAccess.cpp +++ b/aten/src/ATen/cuda/PeerToPeerAccess.cpp @@ -42,10 +42,10 @@ void init_p2p_access_cache(int64_t num_devices) { bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) { at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); - TORCH_CHECK(dev >= 0 || dev < num_devices_, dev, " is not a device"); + TORCH_CHECK(dev >= 0 || dev < num_devices_, static_cast(dev), " is not a device"); TORCH_CHECK( dev_to_access >= 0 || dev_to_access < num_devices_, - dev_to_access, + static_cast(dev_to_access), " is not a device"); TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized"); @@ -147,7 +147,7 @@ bool get_fabric_access(c10::DeviceIndex dev) { #if !defined USE_ROCM && defined CUDA_VERSION && CUDA_VERSION >= 12040 && defined PYTORCH_C10_DRIVER_API_SUPPORTED at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); - TORCH_CHECK(dev >= 0 || dev < num_devices_, dev, " is not a device"); + TORCH_CHECK(dev >= 0 || dev < num_devices_, static_cast(dev), " is not a device"); auto& cache = fabricAccessEnabled_[dev]; if (cache != -1) { return cache; diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index b2b9be4498e5b..a4fd454633dc0 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -60,7 +60,7 @@ void set_magma_init_fn(void (*fn)()) { namespace { bool _hasPrimaryContext(DeviceIndex device_index) { TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(), - "hasPrimaryContext expects a valid device index, but got device_index=", device_index); + "hasPrimaryContext expects a valid device index, but got device_index=", static_cast(device_index)); unsigned int ctx_flags = 0; // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero. diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 422652bb021b1..ec3a9e7badb56 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -242,7 +242,8 @@ cudaError_t GetDevice(DeviceIndex* device) { } cudaError_t SetDevice(DeviceIndex device, const bool force) { - TORCH_CHECK(device >= 0, "device id must be non-negative!", device); + TORCH_CHECK( + device >= 0, "device id must be non-negative!", static_cast(device)); targetDeviceIndex = -1; if (force) { return cudaSetDevice(device); @@ -323,7 +324,8 @@ cudaError_t GetDevice(DeviceIndex* device) { } cudaError_t SetDevice(DeviceIndex device, const bool force) { - TORCH_CHECK(device >= 0, "device id must be non-negative!", device); + TORCH_CHECK( + device >= 0, "device id must be non-negative!", static_cast(device)); if (force) { return cudaSetDevice(device); } From 7c648509a7470ace9fb2bae960dd4790f7e943e9 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 28 Nov 2025 06:12:47 +0000 Subject: [PATCH 048/338] [8/N] Use Python 3.10 typing (#168334) This PR applies Python 3.10 typing syntax to some files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168334 Approved by: https://github.com/Lucaskabela --- tools/dynamo/gb_id_mapping.py | 8 ++-- tools/experimental/torchfuzz/codegen.py | 3 +- tools/experimental/torchfuzz/fuzzer.py | 11 +++-- .../torchfuzz/multi_process_fuzzer.py | 11 +++-- tools/experimental/torchfuzz/operators/arg.py | 4 +- .../torchfuzz/operators/argsort.py | 3 +- .../experimental/torchfuzz/operators/base.py | 11 +++-- .../torchfuzz/operators/constant.py | 4 +- .../torchfuzz/operators/gather.py | 5 +-- .../torchfuzz/operators/index_select.py | 5 +-- .../experimental/torchfuzz/operators/item.py | 4 +- .../torchfuzz/operators/layout.py | 21 +++++---- .../torchfuzz/operators/masked_select.py | 5 +-- .../torchfuzz/operators/matrix_multiply.py | 9 ++-- .../torchfuzz/operators/nn_functional.py | 35 ++++++++------- .../torchfuzz/operators/nonzero.py | 5 +-- .../torchfuzz/operators/registry.py | 6 +-- .../torchfuzz/operators/scalar_pointwise.py | 3 +- .../torchfuzz/operators/unique.py | 4 +- tools/experimental/torchfuzz/ops_fuzzer.py | 9 ++-- tools/experimental/torchfuzz/tensor_fuzzer.py | 24 +++++------ tools/linter/adapters/_linter/block.py | 4 +- tools/linter/adapters/header_only_linter.py | 14 +++--- tools/linter/adapters/no_workflows_on_fork.py | 12 +++--- tools/nightly_hotpatch.py | 4 +- tools/setup_helpers/cmake_utils.py | 4 +- tools/stats/upload_stats_lib.py | 6 +-- .../upload_utilization_stats.py | 15 +++---- tools/stats/utilization_stats_lib.py | 43 +++++++++---------- tools/testing/update_slow_tests.py | 6 +-- tools/testing/upload_artifacts.py | 4 +- torch/onnx/_internal/exporter/_dispatching.py | 4 +- 32 files changed, 134 insertions(+), 172 deletions(-) diff --git a/tools/dynamo/gb_id_mapping.py b/tools/dynamo/gb_id_mapping.py index 1333e6d28cf1b..f7ec2347ba92e 100644 --- a/tools/dynamo/gb_id_mapping.py +++ b/tools/dynamo/gb_id_mapping.py @@ -3,10 +3,10 @@ import json import re from pathlib import Path -from typing import Any, Optional +from typing import Any -def get_source_segment(source: str, node: ast.AST) -> Optional[str]: +def get_source_segment(source: str, node: ast.AST) -> str | None: return ast.get_source_segment(source, node) @@ -48,7 +48,7 @@ def clean_string(s: Any) -> Any: return s -def expand_hints(hints: list[str], dynamo_dir: Optional[str] = None) -> list[str]: +def expand_hints(hints: list[str], dynamo_dir: str | None = None) -> list[str]: """ Expands hint references to their actual values from graph_break_hints. Uses exec() to avoid import dependencies. @@ -116,7 +116,7 @@ def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any: def find_unimplemented_calls( - path: str, dynamo_dir: Optional[str] = None + path: str, dynamo_dir: str | None = None ) -> list[dict[str, Any]]: results = [] path_obj = Path(path) diff --git a/tools/experimental/torchfuzz/codegen.py b/tools/experimental/torchfuzz/codegen.py index c06df40a01bb4..3913e34b88cc9 100644 --- a/tools/experimental/torchfuzz/codegen.py +++ b/tools/experimental/torchfuzz/codegen.py @@ -1,6 +1,5 @@ # mypy: ignore-errors import os -from typing import Optional import torch @@ -504,7 +503,7 @@ def epilogue_codegen(self): def convert_graph_to_python_code( operation_graph: OperationGraph, - seed: Optional[int] = None, + seed: int | None = None, template: str = "default", ) -> str: """ diff --git a/tools/experimental/torchfuzz/fuzzer.py b/tools/experimental/torchfuzz/fuzzer.py index 5c54fded9f8a9..50a00853f0a54 100644 --- a/tools/experimental/torchfuzz/fuzzer.py +++ b/tools/experimental/torchfuzz/fuzzer.py @@ -4,7 +4,6 @@ import os import random import sys -from typing import Optional # Add parent directory to path so we can import torchfuzz as a module @@ -50,12 +49,12 @@ def _parse_supported_ops_with_weights(spec: str) -> tuple[list[str], dict[str, f def fuzz_and_execute( - seed: Optional[int] = None, - max_depth: Optional[int] = None, + seed: int | None = None, + max_depth: int | None = None, log_at_faluire: bool = False, template: str = "default", - supported_ops: Optional[list[str]] = None, - op_weights: Optional[dict[str, float]] = None, + supported_ops: list[str] | None = None, + op_weights: dict[str, float] | None = None, ) -> None: """ Generate a fuzzed operation stack, convert it to Python code, and execute it. @@ -328,7 +327,7 @@ def log(success: bool) -> None: # Single seed execution mode print("Running single fuzz_and_execute...") # Parse supported ops and any inline weights from that flag - parsed_supported_ops: Optional[list[str]] = None + parsed_supported_ops: list[str] | None = None parsed_weights: dict[str, float] = {} if args.supported_ops: parsed_supported_ops, parsed_weights = _parse_supported_ops_with_weights( diff --git a/tools/experimental/torchfuzz/multi_process_fuzzer.py b/tools/experimental/torchfuzz/multi_process_fuzzer.py index 21359b5e9da1a..2de88d47637cd 100644 --- a/tools/experimental/torchfuzz/multi_process_fuzzer.py +++ b/tools/experimental/torchfuzz/multi_process_fuzzer.py @@ -10,7 +10,6 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import Optional try: @@ -84,7 +83,7 @@ def is_ignored_output(output: str) -> int: def run_fuzzer_with_seed( seed: int, template: str = "default", - supported_ops: Optional[str] = None, + supported_ops: str | None = None, ) -> FuzzerResult: """ Run fuzzer.py with a specific seed. @@ -208,12 +207,12 @@ def handle_result_output( def run_multi_process_fuzzer( - num_processes: Optional[int] = None, + num_processes: int | None = None, seed_start: int = 0, seed_count: int = 100, verbose: bool = False, template: str = "default", - supported_ops: Optional[str] = None, + supported_ops: str | None = None, ) -> None: """ Run the multi-process fuzzer. @@ -504,10 +503,10 @@ def _print_operation_distribution(results: list[FuzzerResult]) -> None: def run_until_failure( - num_processes: Optional[int] = None, + num_processes: int | None = None, verbose: bool = False, template: str = "default", - supported_ops: Optional[str] = None, + supported_ops: str | None = None, ) -> None: """ Run the multi-process fuzzer with a random starting seed, iterating until a failure is found. diff --git a/tools/experimental/torchfuzz/operators/arg.py b/tools/experimental/torchfuzz/operators/arg.py index 8a9cc042cdb4d..edcc6c11f457f 100644 --- a/tools/experimental/torchfuzz/operators/arg.py +++ b/tools/experimental/torchfuzz/operators/arg.py @@ -1,7 +1,5 @@ """Arg operator implementation.""" -from typing import Optional - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import Spec @@ -13,7 +11,7 @@ def __init__(self): super().__init__("arg") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Arg is not a torch operation, it represents function arguments.""" return None diff --git a/tools/experimental/torchfuzz/operators/argsort.py b/tools/experimental/torchfuzz/operators/argsort.py index 428c2b2fc308c..4281fc27daf2e 100644 --- a/tools/experimental/torchfuzz/operators/argsort.py +++ b/tools/experimental/torchfuzz/operators/argsort.py @@ -1,7 +1,6 @@ """Argsort operator implementation.""" import random -from typing import Optional import torch @@ -17,7 +16,7 @@ def __init__(self): super().__init__("argsort") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.argsort" diff --git a/tools/experimental/torchfuzz/operators/base.py b/tools/experimental/torchfuzz/operators/base.py index 3135a96a971f6..3e28f4f0bb2d9 100644 --- a/tools/experimental/torchfuzz/operators/base.py +++ b/tools/experimental/torchfuzz/operators/base.py @@ -1,7 +1,6 @@ """Base operator implementation.""" from abc import ABC, abstractmethod -from typing import Optional from torchfuzz.tensor_fuzzer import Spec @@ -22,7 +21,7 @@ def __init__(self, name: str, weight: float = 1.0): @property @abstractmethod - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """ Return the torch operation name this operator represents. @@ -57,10 +56,10 @@ def codegen( def get_weight( self, *, - target_spec: Optional[Spec] = None, - depth: Optional[int] = None, - stack_size: Optional[int] = None, - template: Optional[str] = None, + target_spec: Spec | None = None, + depth: int | None = None, + stack_size: int | None = None, + template: str | None = None, ) -> float: """ Return the selection weight for this operator. diff --git a/tools/experimental/torchfuzz/operators/constant.py b/tools/experimental/torchfuzz/operators/constant.py index ec3c95a3bdff9..67419672c2a4e 100644 --- a/tools/experimental/torchfuzz/operators/constant.py +++ b/tools/experimental/torchfuzz/operators/constant.py @@ -1,7 +1,5 @@ """Constant operator implementation.""" -from typing import Optional - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import ( fuzz_scalar, @@ -20,7 +18,7 @@ def __init__(self): self.template = "default" # Track template for DTensor compatibility @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Constant is not a torch operation, it generates constant values.""" return None diff --git a/tools/experimental/torchfuzz/operators/gather.py b/tools/experimental/torchfuzz/operators/gather.py index 3daa1bcd7554e..cd7fa8d9fa4f2 100644 --- a/tools/experimental/torchfuzz/operators/gather.py +++ b/tools/experimental/torchfuzz/operators/gather.py @@ -1,7 +1,4 @@ -from typing import Optional - import torch - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import Spec, TensorSpec @@ -13,7 +10,7 @@ def __init__(self): super().__init__("gather") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.gather" diff --git a/tools/experimental/torchfuzz/operators/index_select.py b/tools/experimental/torchfuzz/operators/index_select.py index 340b0ab6f434c..08ab682561166 100644 --- a/tools/experimental/torchfuzz/operators/index_select.py +++ b/tools/experimental/torchfuzz/operators/index_select.py @@ -1,7 +1,4 @@ -from typing import Optional - import torch - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import Spec, TensorSpec @@ -13,7 +10,7 @@ def __init__(self): super().__init__("index_select") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.index_select" diff --git a/tools/experimental/torchfuzz/operators/item.py b/tools/experimental/torchfuzz/operators/item.py index 88bb2795b57ca..fc8d3e8bd26de 100644 --- a/tools/experimental/torchfuzz/operators/item.py +++ b/tools/experimental/torchfuzz/operators/item.py @@ -1,7 +1,5 @@ """Item operator implementation.""" -from typing import Optional - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec @@ -13,7 +11,7 @@ def __init__(self): super().__init__("item") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Item is a tensor method, not a direct torch operation.""" return None diff --git a/tools/experimental/torchfuzz/operators/layout.py b/tools/experimental/torchfuzz/operators/layout.py index e753d93af5a63..66209812b7c37 100644 --- a/tools/experimental/torchfuzz/operators/layout.py +++ b/tools/experimental/torchfuzz/operators/layout.py @@ -1,7 +1,6 @@ """Tensor layout operator implementations.""" import random -from typing import Optional from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import fuzz_tensor_size, Spec, TensorSpec @@ -23,7 +22,7 @@ def __init__(self): super().__init__("view") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.Tensor.view" @@ -104,7 +103,7 @@ def __init__(self): super().__init__("reshape") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.reshape" @@ -179,7 +178,7 @@ def __init__(self): super().__init__("flatten") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.flatten" @@ -271,7 +270,7 @@ def __init__(self): super().__init__("squeeze") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.squeeze" @@ -323,7 +322,7 @@ def __init__(self): super().__init__("unsqueeze") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.unsqueeze" @@ -410,7 +409,7 @@ def __init__(self): super().__init__("split") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.split" @@ -490,7 +489,7 @@ def __init__(self): super().__init__("expand") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.expand" @@ -559,7 +558,7 @@ def __init__(self): super().__init__("cat") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.cat" @@ -664,7 +663,7 @@ def __init__(self): super().__init__("stack") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.stack" @@ -754,7 +753,7 @@ def __init__(self): super().__init__("chunk") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.chunk" diff --git a/tools/experimental/torchfuzz/operators/masked_select.py b/tools/experimental/torchfuzz/operators/masked_select.py index 5c68005dd111f..e88d031f95571 100644 --- a/tools/experimental/torchfuzz/operators/masked_select.py +++ b/tools/experimental/torchfuzz/operators/masked_select.py @@ -1,9 +1,6 @@ """Masked select operator implementation.""" -from typing import Optional - import torch - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import Spec, TensorSpec @@ -15,7 +12,7 @@ def __init__(self): super().__init__("masked_select") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.masked_select" diff --git a/tools/experimental/torchfuzz/operators/matrix_multiply.py b/tools/experimental/torchfuzz/operators/matrix_multiply.py index 515623420f293..baa9e1c09ca33 100644 --- a/tools/experimental/torchfuzz/operators/matrix_multiply.py +++ b/tools/experimental/torchfuzz/operators/matrix_multiply.py @@ -1,7 +1,6 @@ """Matrix multiplication operator implementations.""" import random -from typing import Optional import torch @@ -52,7 +51,7 @@ def __init__(self): self.weight = 5.0 @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.mm" @@ -137,7 +136,7 @@ def __init__(self): self.weight = 5.0 @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.addmm" @@ -230,7 +229,7 @@ def __init__(self): self.weight = 5.0 @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.bmm" @@ -315,7 +314,7 @@ def __init__(self): self.weight = 500.0 @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.matmul" diff --git a/tools/experimental/torchfuzz/operators/nn_functional.py b/tools/experimental/torchfuzz/operators/nn_functional.py index 8f063926f933c..3eca2eb051c02 100644 --- a/tools/experimental/torchfuzz/operators/nn_functional.py +++ b/tools/experimental/torchfuzz/operators/nn_functional.py @@ -2,7 +2,6 @@ import math import random -from typing import Optional import torch @@ -27,7 +26,7 @@ def __init__(self): super().__init__("torch.nn.functional.embedding") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.embedding" @@ -109,7 +108,7 @@ def __init__(self): super().__init__("torch.nn.functional.linear") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.linear" @@ -207,7 +206,7 @@ def __init__(self): super().__init__("torch.nn.functional.relu") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.relu" @@ -250,7 +249,7 @@ def __init__(self): super().__init__("torch.nn.functional.softmax") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.softmax" @@ -297,7 +296,7 @@ def __init__(self): super().__init__("torch.nn.functional.dropout") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.dropout" @@ -341,7 +340,7 @@ def __init__(self): super().__init__("torch.nn.functional.layer_norm") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.layer_norm" @@ -438,7 +437,7 @@ def __init__(self): self.weight = 5.0 @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.rms_norm" @@ -512,7 +511,7 @@ def __init__(self): super().__init__("torch.nn.functional.gelu") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.gelu" @@ -554,7 +553,7 @@ def __init__(self): super().__init__("torch.sigmoid") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.sigmoid" @@ -596,7 +595,7 @@ def __init__(self): super().__init__("torch.tanh") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.tanh" @@ -638,7 +637,7 @@ def __init__(self): super().__init__("torch.nn.functional.batch_norm") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.batch_norm" @@ -742,7 +741,7 @@ def __init__(self): super().__init__("torch.nn.functional.group_norm") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.group_norm" @@ -846,7 +845,7 @@ def __init__(self): super().__init__("torch.nn.functional.leaky_relu") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.leaky_relu" @@ -888,7 +887,7 @@ def __init__(self): super().__init__("torch.nn.functional.elu") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.elu" @@ -930,7 +929,7 @@ def __init__(self): super().__init__("torch.nn.functional.silu") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.silu" @@ -972,7 +971,7 @@ def __init__(self): super().__init__("torch.nn.functional.scaled_dot_product_attention") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.scaled_dot_product_attention" @@ -1038,7 +1037,7 @@ def __init__(self): super().__init__("torch.nn.functional.multi_head_attention_forward") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nn.functional.multi_head_attention_forward" diff --git a/tools/experimental/torchfuzz/operators/nonzero.py b/tools/experimental/torchfuzz/operators/nonzero.py index 00b651e939b5d..ef22c3b700674 100644 --- a/tools/experimental/torchfuzz/operators/nonzero.py +++ b/tools/experimental/torchfuzz/operators/nonzero.py @@ -1,9 +1,6 @@ """Nonzero operator implementation.""" -from typing import Optional - import torch - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import Spec, TensorSpec @@ -15,7 +12,7 @@ def __init__(self): super().__init__("nonzero") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.nonzero" diff --git a/tools/experimental/torchfuzz/operators/registry.py b/tools/experimental/torchfuzz/operators/registry.py index de9fb2618f4ad..aa1dd777efc58 100644 --- a/tools/experimental/torchfuzz/operators/registry.py +++ b/tools/experimental/torchfuzz/operators/registry.py @@ -1,7 +1,5 @@ """Operator registry for mapping operation names to operator instances.""" -from typing import Optional - from torchfuzz.operators.arg import ArgOperator from torchfuzz.operators.argsort import ArgsortOperator from torchfuzz.operators.base import Operator @@ -145,7 +143,7 @@ def register(self, operator: Operator): """Register an operator in the registry.""" self._operators[operator.name] = operator - def get(self, op_name: str) -> Optional[Operator]: + def get(self, op_name: str) -> Operator | None: """Get an operator by name.""" # Handle special arg_ operations by mapping them to the ArgOperator if op_name.startswith("arg_"): @@ -161,7 +159,7 @@ def list_operators(self) -> dict[str, Operator]: _global_registry = OperatorRegistry() -def get_operator(op_name: str) -> Optional[Operator]: +def get_operator(op_name: str) -> Operator | None: """Get an operator from the global registry.""" return _global_registry.get(op_name) diff --git a/tools/experimental/torchfuzz/operators/scalar_pointwise.py b/tools/experimental/torchfuzz/operators/scalar_pointwise.py index 6350c01206313..ff30feb840c4b 100644 --- a/tools/experimental/torchfuzz/operators/scalar_pointwise.py +++ b/tools/experimental/torchfuzz/operators/scalar_pointwise.py @@ -1,7 +1,6 @@ """Scalar pointwise operator implementation.""" import random -from typing import Optional import torch @@ -17,7 +16,7 @@ def __init__(self, name: str, symbol: str): self.symbol = symbol @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Scalar operations don't have specific torch ops, they use Python operators.""" return None diff --git a/tools/experimental/torchfuzz/operators/unique.py b/tools/experimental/torchfuzz/operators/unique.py index 5fa09dbe43153..9836fc5f3d942 100644 --- a/tools/experimental/torchfuzz/operators/unique.py +++ b/tools/experimental/torchfuzz/operators/unique.py @@ -1,7 +1,5 @@ """Unique operator implementation.""" -from typing import Optional - from torchfuzz.operators.base import Operator from torchfuzz.tensor_fuzzer import Spec, TensorSpec @@ -13,7 +11,7 @@ def __init__(self): super().__init__("unique") @property - def torch_op_name(self) -> Optional[str]: + def torch_op_name(self) -> str | None: """Return the torch operation name.""" return "torch.unique" diff --git a/tools/experimental/torchfuzz/ops_fuzzer.py b/tools/experimental/torchfuzz/ops_fuzzer.py index 3ff17bb5b559a..dda3dc6efcfe1 100644 --- a/tools/experimental/torchfuzz/ops_fuzzer.py +++ b/tools/experimental/torchfuzz/ops_fuzzer.py @@ -2,7 +2,6 @@ import random from dataclasses import dataclass -from typing import Optional import torch @@ -31,7 +30,7 @@ def _get_cached_operators(): def _get_template_filtered_operators( - template: str = "default", supported_ops: Optional[list[str]] = None + template: str = "default", supported_ops: list[str] | None = None ): """Get operators filtered by template's supported_ops, with user override. @@ -274,7 +273,7 @@ def fuzz_op( depth, stack_size, template: str = "default", - supported_ops: Optional[list[str]] = None, + supported_ops: list[str] | None = None, ) -> tuple[str, list[Spec]]: """ Given an output specification, returns an operation that can @@ -429,9 +428,9 @@ def _get_arg_args_specs(target_spec: Spec) -> tuple[str, list[Spec]]: def fuzz_operation_graph( target_spec: Spec, max_depth: int = 7, - seed: Optional[int] = None, + seed: int | None = None, template: str = "default", - supported_ops: Optional[list[str]] = None, + supported_ops: list[str] | None = None, ) -> OperationGraph: """ Generate a graph of operations that produces the target specification. diff --git a/tools/experimental/torchfuzz/tensor_fuzzer.py b/tools/experimental/torchfuzz/tensor_fuzzer.py index 0357d6cbca182..3ff71a03c2c2e 100644 --- a/tools/experimental/torchfuzz/tensor_fuzzer.py +++ b/tools/experimental/torchfuzz/tensor_fuzzer.py @@ -1,6 +1,6 @@ # mypy: ignore-errors import random -from typing import NamedTuple, Optional, Union +from typing import NamedTuple, Union import torch @@ -25,7 +25,7 @@ class ScalarSpec(NamedTuple): """Specification for a scalar argument.""" dtype: torch.dtype - constant: Optional[Union[int, float, bool, complex]] = ( + constant: int | float | bool | complex | None = ( None # If set, use this constant value instead of fuzzing ) @@ -334,10 +334,10 @@ def _compute_storage_size_needed( def fuzz_tensor( - size: Optional[tuple[int, ...]] = None, - stride: Optional[tuple[int, ...]] = None, - dtype: Optional[torch.dtype] = None, - seed: Optional[int] = None, + size: tuple[int, ...] | None = None, + stride: tuple[int, ...] | None = None, + dtype: torch.dtype | None = None, + seed: int | None = None, ) -> tuple[torch.Tensor, int]: """ Create a tensor with fuzzed size, stride, and dtype. @@ -423,10 +423,10 @@ def fuzz_tensor( def fuzz_tensor_simple( - size: Optional[tuple[int, ...]] = None, - stride: Optional[tuple[int, ...]] = None, - dtype: Optional[torch.dtype] = None, - seed: Optional[int] = None, + size: tuple[int, ...] | None = None, + stride: tuple[int, ...] | None = None, + dtype: torch.dtype | None = None, + seed: int | None = None, ) -> torch.Tensor: """ Convenience function that returns just the tensor without the seed. @@ -445,7 +445,7 @@ def fuzz_tensor_simple( def fuzz_non_contiguous_dense_tensor( - size: Optional[tuple[int, ...]] = None, dtype: Optional[torch.dtype] = None + size: tuple[int, ...] | None = None, dtype: torch.dtype | None = None ) -> torch.Tensor: """ Specifically generates tensors that are non-contiguous but dense and non-overlapping. @@ -492,7 +492,7 @@ def fuzz_non_contiguous_dense_tensor( return tensor -def fuzz_scalar(spec, seed: Optional[int] = None) -> Union[float, int, bool, complex]: +def fuzz_scalar(spec, seed: int | None = None) -> float | int | bool | complex: """ Create a Python scalar value from a ScalarSpec. diff --git a/tools/linter/adapters/_linter/block.py b/tools/linter/adapters/_linter/block.py index 4097da50a7e4e..7e506a49835c9 100644 --- a/tools/linter/adapters/_linter/block.py +++ b/tools/linter/adapters/_linter/block.py @@ -5,7 +5,7 @@ import token from enum import Enum from functools import cached_property, total_ordering -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from typing_extensions import Self @@ -64,7 +64,7 @@ class Category(str, Enum): is_method: bool = dc.field(default=False, repr=False) # A block index to the parent of this block, or None for a top-level block. - parent: Optional[int] = None + parent: int | None = None # A list of block indexes for the children children: list[int] = dc.field(default_factory=list) diff --git a/tools/linter/adapters/header_only_linter.py b/tools/linter/adapters/header_only_linter.py index 2548dae4c1994..f34a0bc55002d 100644 --- a/tools/linter/adapters/header_only_linter.py +++ b/tools/linter/adapters/header_only_linter.py @@ -10,7 +10,7 @@ import re from enum import Enum from pathlib import Path -from typing import NamedTuple, Union +from typing import NamedTuple LINTER_CODE = "HEADER_ONLY_LINTER" @@ -24,15 +24,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Union[str, None] - line: Union[int, None] - char: Union[int, None] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Union[str, None] - replacement: Union[str, None] - description: Union[str, None] + original: str | None + replacement: str | None + description: str | None CPP_TEST_GLOBS = [ diff --git a/tools/linter/adapters/no_workflows_on_fork.py b/tools/linter/adapters/no_workflows_on_fork.py index 02efd5f6f62a7..0f08b922eeccf 100644 --- a/tools/linter/adapters/no_workflows_on_fork.py +++ b/tools/linter/adapters/no_workflows_on_fork.py @@ -22,7 +22,7 @@ import re from enum import Enum from pathlib import Path -from typing import Any, NamedTuple, Optional, TYPE_CHECKING +from typing import Any, NamedTuple, TYPE_CHECKING from yaml import load @@ -63,10 +63,10 @@ def load_yaml(path: Path) -> Any: def gen_lint_message( - filename: Optional[str] = None, - original: Optional[str] = None, - replacement: Optional[str] = None, - description: Optional[str] = None, + filename: str | None = None, + original: str | None = None, + replacement: str | None = None, + description: str | None = None, ) -> LintMessage: return LintMessage( path=filename, @@ -85,7 +85,7 @@ def check_file(filename: str) -> list[LintMessage]: logging.debug("Checking file %s", filename) workflow = load_yaml(Path(filename)) - bad_jobs: dict[str, Optional[str]] = {} + bad_jobs: dict[str, str | None] = {} if type(workflow) is not dict: return [] diff --git a/tools/nightly_hotpatch.py b/tools/nightly_hotpatch.py index d8e78a82664d8..f4d3ab4e95fe9 100644 --- a/tools/nightly_hotpatch.py +++ b/tools/nightly_hotpatch.py @@ -7,7 +7,7 @@ import sys import tempfile import urllib.request -from typing import cast, NoReturn, Optional +from typing import cast, NoReturn def parse_arguments() -> argparse.Namespace: @@ -133,7 +133,7 @@ def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str: sys.exit(1) -def apply_patch(patch_file: str, target_dir: Optional[str], strip_count: int) -> None: +def apply_patch(patch_file: str, target_dir: str | None, strip_count: int) -> None: """ Applies the downloaded patch to the specified directory using the given strip count. diff --git a/tools/setup_helpers/cmake_utils.py b/tools/setup_helpers/cmake_utils.py index f89c2c99d38c5..a7e8ebe2edd06 100644 --- a/tools/setup_helpers/cmake_utils.py +++ b/tools/setup_helpers/cmake_utils.py @@ -6,10 +6,10 @@ from __future__ import annotations import re -from typing import IO, Optional, Union +from typing import IO -CMakeValue = Optional[Union[bool, str]] +CMakeValue = bool | str | None def convert_cmake_value_to_python_value( diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index 34548b80d76ba..dc006b5979fb8 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -9,7 +9,7 @@ import zipfile from functools import lru_cache from pathlib import Path -from typing import Any, cast, Optional, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING import boto3 # type: ignore[import] import requests @@ -94,7 +94,7 @@ def download_s3_artifacts( prefix: str, workflow_run_id: int, workflow_run_attempt: int, - job_id: Optional[int] = None, + job_id: int | None = None, ) -> list[Path]: bucket = get_s3_resource().Bucket(GHA_ARTIFACTS_BUCKET) objs = bucket.objects.filter( @@ -136,7 +136,7 @@ def upload_to_dynamodb( dynamodb_table: str, repo: str, docs: list[Any], - generate_partition_key: Optional[Callable[[str, dict[str, Any]], str]], + generate_partition_key: Callable[[str, dict[str, Any]], str] | None, ) -> None: print(f"Writing {len(docs)} documents to DynamoDB {dynamodb_table}") # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/dynamodb.html#batch-writing diff --git a/tools/stats/upload_utilization_stats/upload_utilization_stats.py b/tools/stats/upload_utilization_stats/upload_utilization_stats.py index 5b69c1a555952..66348e42a08a0 100644 --- a/tools/stats/upload_utilization_stats/upload_utilization_stats.py +++ b/tools/stats/upload_utilization_stats/upload_utilization_stats.py @@ -3,7 +3,6 @@ import os import sys from pathlib import Path -from typing import Union sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) @@ -11,7 +10,7 @@ import json import zipfile from dataclasses import asdict -from typing import Any, Optional +from typing import Any import pandas as pd # type: ignore[import] from tools.stats.upload_stats_lib import download_s3_artifacts, upload_to_s3 @@ -284,7 +283,7 @@ def get_log_data_from_local( file_path: str, artifact_prefix: str = "", ) -> tuple[ - Optional[UtilizationMetadata], list[UtilizationRecord], list[UtilizationRecord] + UtilizationMetadata | None, list[UtilizationRecord], list[UtilizationRecord] ]: test_log_content = read_file(file_path) if not test_log_content: @@ -302,7 +301,7 @@ def get_log_data_from_s3( workflow_run_attempt: int, artifact_prefix: str = JOB_TEST_ARTIFACT_PREFIX, ) -> tuple[ - Optional[UtilizationMetadata], list[UtilizationRecord], list[UtilizationRecord] + UtilizationMetadata | None, list[UtilizationRecord], list[UtilizationRecord] ]: artifact_paths = download_s3_artifacts( artifact_prefix, workflow_run_id, workflow_run_attempt, job_id @@ -331,9 +330,7 @@ def get_log_data_from_s3( print(f"Converted Log Model: UtilizationMetadata:\n {metadata}") return metadata, records, error_records - def _process_raw_record( - self, line: str - ) -> tuple[Optional[UtilizationRecord], bool]: + def _process_raw_record(self, line: str) -> tuple[UtilizationRecord | None, bool]: try: record = UtilizationRecord.from_json(line) if record.error: @@ -360,7 +357,7 @@ def convert_to_log_models( self, content: str, ) -> tuple[ - Optional[UtilizationMetadata], list[UtilizationRecord], list[UtilizationRecord] + UtilizationMetadata | None, list[UtilizationRecord], list[UtilizationRecord] ]: if not content: return None, [], [] @@ -397,7 +394,7 @@ def handle_file(file_path: Path) -> str: return "" -def read_file(file_path: Union[str, Path]) -> str: +def read_file(file_path: str | Path) -> str: try: if isinstance(file_path, Path): if file_path.is_file(): diff --git a/tools/stats/utilization_stats_lib.py b/tools/stats/utilization_stats_lib.py index 306cd7fe9e1f7..21ceb46a93d38 100644 --- a/tools/stats/utilization_stats_lib.py +++ b/tools/stats/utilization_stats_lib.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Optional # pyrefly: ignore [missing-import] from dataclasses_json import DataClassJsonMixin # type: ignore[import-not-found] @@ -12,9 +11,9 @@ # data model for test log usage @dataclass class UtilizationStats: - avg: Optional[float] = None - max: Optional[float] = None - raw: Optional[list[float]] = None + avg: float | None = None + max: float | None = None + raw: list[float] | None = None @dataclass @@ -27,38 +26,38 @@ class UtilizationMetadata(DataClassJsonMixin): # type: ignore[misc, no-any-unim usage_collect_interval: float data_model_version: float start_at: int - gpu_count: Optional[int] = None - cpu_count: Optional[int] = None - gpu_type: Optional[str] = None - error: Optional[str] = None + gpu_count: int | None = None + cpu_count: int | None = None + gpu_type: str | None = None + error: str | None = None @dataclass class GpuUsage(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] - uuid: Optional[str] = None - util_percent: Optional[UtilizationStats] = None - mem_util_percent: Optional[UtilizationStats] = None - allocated_mem_percent: Optional[UtilizationStats] = None - allocated_mem_value: Optional[UtilizationStats] = None - total_mem_value: Optional[float] = None + uuid: str | None = None + util_percent: UtilizationStats | None = None + mem_util_percent: UtilizationStats | None = None + allocated_mem_percent: UtilizationStats | None = None + allocated_mem_value: UtilizationStats | None = None + total_mem_value: float | None = None @dataclass class RecordData(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] - cpu: Optional[UtilizationStats] = None - memory: Optional[UtilizationStats] = None - gpu_usage: Optional[list[GpuUsage]] = None + cpu: UtilizationStats | None = None + memory: UtilizationStats | None = None + gpu_usage: list[GpuUsage] | None = None @dataclass class UtilizationRecord(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] level: str timestamp: int - data: Optional[RecordData] = None - cmd_names: Optional[list[str]] = None - error: Optional[str] = None - log_duration: Optional[str] = None - logs: Optional[list[str]] = None + data: RecordData | None = None + cmd_names: list[str] | None = None + error: str | None = None + log_duration: str | None = None + logs: list[str] | None = None # the db schema related to this is: diff --git a/tools/testing/update_slow_tests.py b/tools/testing/update_slow_tests.py index c54399e18cdef..1b36defba67fa 100644 --- a/tools/testing/update_slow_tests.py +++ b/tools/testing/update_slow_tests.py @@ -3,7 +3,7 @@ import subprocess import time from pathlib import Path -from typing import Any, cast, Optional +from typing import Any, cast import requests from clickhouse import query_clickhouse # type: ignore[import] @@ -159,9 +159,7 @@ def add_labels(source_repo: str, pr_number: int, labels: list[str]) -> None: ) -def search_for_open_pr( - source_repo: str, search_string: str -) -> Optional[tuple[int, str]]: +def search_for_open_pr(source_repo: str, search_string: str) -> tuple[int, str] | None: params = { "q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}", "sort": "created", diff --git a/tools/testing/upload_artifacts.py b/tools/testing/upload_artifacts.py index 50f08c0f33cde..21a67f0786e2c 100644 --- a/tools/testing/upload_artifacts.py +++ b/tools/testing/upload_artifacts.py @@ -6,7 +6,7 @@ import zipfile from functools import lru_cache from pathlib import Path -from typing import Any, Optional +from typing import Any from filelock import FileLock, Timeout @@ -154,7 +154,7 @@ def parse_xml_and_upload_json() -> None: uploading the same file from multiple processes. """ try: - job_id: Optional[int] = int(os.environ.get("JOB_ID", 0)) + job_id: int | None = int(os.environ.get("JOB_ID", 0)) if job_id == 0: job_id = None except (ValueError, TypeError): diff --git a/torch/onnx/_internal/exporter/_dispatching.py b/torch/onnx/_internal/exporter/_dispatching.py index 1f935cfed192d..92df182c82c03 100644 --- a/torch/onnx/_internal/exporter/_dispatching.py +++ b/torch/onnx/_internal/exporter/_dispatching.py @@ -86,7 +86,7 @@ def _param_type_compatible_with_arg( assigned_types: dict[str, ir.TypeProtocol], ) -> bool: # Handle Python types first - if isinstance(value, bool): # noqa: SIM102 + if isinstance(value, bool): if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.BOOL)}: return True if isinstance(value, int) and param.type_constraint.allowed_types & { @@ -124,7 +124,7 @@ def _param_type_compatible_with_arg( ir.TensorType(ir.DataType.COMPLEX128), }: return True - if isinstance(value, str): # noqa: SIM102 + if isinstance(value, str): if param.type_constraint.allowed_types & {ir.TensorType(ir.DataType.STRING)}: return True if isinstance(value, (list, tuple)): From 29856679769b3dede478767e2fe6cfb51197cb25 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Fri, 28 Nov 2025 06:29:10 +0000 Subject: [PATCH 049/338] Fix an unsafe indexing in fx exception handling (#169140) There is an unsafe indexing into `e.args` in FX's exception handling, which would lead to nested exception and mask the actual error. This PR fixes it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169140 Approved by: https://github.com/cyyever, https://github.com/Skylion007 --- torch/fx/_symbolic_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 150c8ed746872..dfd777dc58056 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -881,7 +881,7 @@ def forward(*args, **kwargs): self.submodule_paths = None except RuntimeError as e: - if isinstance(e.args[0], str) and "data-dependent" in e.args[0]: + if e.args and isinstance(e.args[0], str) and "data-dependent" in e.args[0]: partial_fx_graph = self.graph.python_code( root_module="self", verbose=True, From f47dd0ddef1359e5b43e4b962412f67b30ecde56 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 28 Nov 2025 08:00:09 +0000 Subject: [PATCH 050/338] Enable SIM118 (#167399) This PR enables the `SIM118` rule of ruff, which checks for key-existence checks against dict.keys() calls. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167399 Approved by: https://github.com/albanD --- .ci/lumen_cli/cli/lib/common/gh_summary.py | 2 +- .github/scripts/get_workflow_job_id.py | 2 +- .github/scripts/test_trymerge.py | 6 ++---- .github/scripts/trymerge.py | 7 +++---- benchmarks/dynamo/common.py | 2 +- benchmarks/dynamo/microbenchmarks/operator_inp_utils.py | 2 +- benchmarks/dynamo/training_loss.py | 2 +- benchmarks/functional_autograd_benchmark/vision_models.py | 2 +- benchmarks/operator_benchmark/benchmark_core.py | 2 +- docs/source/scripts/exportdb/generate_example_rst.py | 2 +- pyproject.toml | 1 - test/distributed/argparse_util_test.py | 2 +- test/distributed/launcher/api_test.py | 2 +- test/distributed/launcher/test_run.py | 2 +- test/dynamo/test_dicts.py | 4 ++-- test/dynamo/test_guard_serialization.py | 2 +- test/dynamo/test_higher_order_ops.py | 2 +- test/dynamo/test_modules.py | 2 +- test/dynamo/test_nested_graph_breaks.py | 5 +---- test/dynamo/test_precompile_context.py | 2 +- test/dynamo/test_repros.py | 2 +- test/test_torchfuzz_repros.py | 1 - tools/autograd/gen_variable_type.py | 2 +- tools/stats/monitor.py | 2 +- tools/stats/upload_stats_lib.py | 2 +- tools/test/heuristics/test_utils.py | 2 +- tools/test/test_test_selections.py | 6 +++--- tools/testing/target_determination/heuristics/interface.py | 4 ++-- torch/_inductor/tiling_utils.py | 2 +- torch/distributed/_local_tensor/_c10d.py | 4 +--- torch/distributed/_tools/fsdp2_mem_tracker.py | 2 +- torch/distributed/flight_recorder/components/builder.py | 2 +- torch/distributed/flight_recorder/components/utils.py | 2 +- torch/utils/_debug_mode.py | 4 ++-- torchgen/operator_versions/gen_mobile_upgraders.py | 2 +- 35 files changed, 41 insertions(+), 51 deletions(-) diff --git a/.ci/lumen_cli/cli/lib/common/gh_summary.py b/.ci/lumen_cli/cli/lib/common/gh_summary.py index 72bfaa76e7068..73ae0aa20c39c 100644 --- a/.ci/lumen_cli/cli/lib/common/gh_summary.py +++ b/.ci/lumen_cli/cli/lib/common/gh_summary.py @@ -117,7 +117,7 @@ def md_kv_table(rows: Iterable[Mapping[str, str | int | float]]) -> str: Render a list of dicts as a Markdown table using Jinja template. """ rows = list(rows) - cols = list({k for r in rows for k in r.keys()}) + cols = list({k for r in rows for k in r}) md = _TPL_TABLE.render(cols=cols, rows=rows).strip() + "\n" return md diff --git a/.github/scripts/get_workflow_job_id.py b/.github/scripts/get_workflow_job_id.py index 54e66621c9fd0..db3d8a4e493b1 100644 --- a/.github/scripts/get_workflow_job_id.py +++ b/.github/scripts/get_workflow_job_id.py @@ -88,7 +88,7 @@ def fetch_jobs(url: str, headers: dict[str, str]) -> list[dict[str, str]]: response, links = fetch_url(url, headers=headers, reader=parse_json_and_links) jobs = response["jobs"] assert type(jobs) is list - while "next" in links.keys(): + while "next" in links: response, links = fetch_url( links["next"]["url"], headers=headers, reader=parse_json_and_links ) diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index 790deb85ef8c3..9eb41a9b623cb 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -435,15 +435,13 @@ def test_get_checkruns_many_runs(self, *args: Any) -> None: pr = GitHubPR("pytorch", "pytorch", 105260) conclusions = pr.get_checkrun_conclusions() self.assertEqual(len(conclusions), 221) - self.assertTrue( - "pull / linux-docs / build-docs-cpp-false" in conclusions.keys() - ) + self.assertTrue("pull / linux-docs / build-docs-cpp-false" in conclusions) def test_cancelled_gets_ignored(self, *args: Any) -> None: """Tests that cancelled workflow does not override existing successful status""" pr = GitHubPR("pytorch", "pytorch", 110367) conclusions = pr.get_checkrun_conclusions() - lint_checks = [name for name in conclusions.keys() if "Lint" in name] + lint_checks = [name for name in conclusions if "Lint" in name] self.assertTrue(len(lint_checks) > 0) self.assertTrue( all(conclusions[name].status == "SUCCESS" for name in lint_checks) diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 697ab6992793d..2100f96427574 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -2232,12 +2232,12 @@ def categorize_checks( # If required_checks is not set or empty, consider all names are relevant relevant_checknames = [ name - for name in check_runs.keys() + for name in check_runs if not required_checks or any(x in name for x in required_checks) ] for checkname in required_checks: - if all(checkname not in x for x in check_runs.keys()): + if all(checkname not in x for x in check_runs): pending_checks.append((checkname, None, None)) for checkname in relevant_checknames: @@ -2398,8 +2398,7 @@ def merge( ) pending, failing, _ = categorize_checks( checks, - required_checks - + [x for x in checks.keys() if x not in required_checks], + required_checks + [x for x in checks if x not in required_checks], ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD if ignore_flaky_failures else 0, diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index b3484e7196a83..3d3065ade8a5b 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -2472,7 +2472,7 @@ def dump_max_mean_values(tol, ref, res): for refi, resi in zip(ref, res): dump_max_mean_values(tol, refi, resi) elif isinstance(ref, dict): - for k in ref.keys(): + for k in ref: dump_max_mean_values(tol, ref[k], res[k]) elif isinstance(ref, torch.Tensor): res = res.to(base_device) diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py index 8a6978dd448be..4387c9097af7e 100644 --- a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py +++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py @@ -293,7 +293,7 @@ def get_inputs_for_operator( yield args, kwargs def get_all_ops(self): - for key in self.operator_db.keys(): + for key in self.operator_db: try: op = eval(key) except AttributeError: diff --git a/benchmarks/dynamo/training_loss.py b/benchmarks/dynamo/training_loss.py index 1e7e57dfdbaea..911f00e1a50b2 100644 --- a/benchmarks/dynamo/training_loss.py +++ b/benchmarks/dynamo/training_loss.py @@ -153,7 +153,7 @@ def main(): "bert-base-cased", num_labels=5 ) optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer) - if "capturable" in inspect.signature(optimizer_cls).parameters.keys(): + if "capturable" in inspect.signature(optimizer_cls).parameters: optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True) else: optimizer = optimizer_cls(model.parameters(), lr=args.lr) diff --git a/benchmarks/functional_autograd_benchmark/vision_models.py b/benchmarks/functional_autograd_benchmark/vision_models.py index a33ac09da43ee..e5eac60017668 100644 --- a/benchmarks/functional_autograd_benchmark/vision_models.py +++ b/benchmarks/functional_autograd_benchmark/vision_models.py @@ -133,7 +133,7 @@ def forward(*new_params: Tensor) -> Tensor: weight_dict = criterion.weight_dict final_loss = cast( Tensor, - sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict), + sum(loss[k] * weight_dict[k] for k in loss if k in weight_dict), ) return final_loss diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index 7a8f0988a1fbf..5e88af6738a05 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -303,7 +303,7 @@ def split(s): break_idxs = [-1] curr_brackets = [] for i, c in enumerate(s): - if c in open_to_close.keys(): + if c in open_to_close: curr_brackets.append(c) elif c in open_to_close.values(): assert curr_brackets and open_to_close[curr_brackets[-1]] == c, ( diff --git a/docs/source/scripts/exportdb/generate_example_rst.py b/docs/source/scripts/exportdb/generate_example_rst.py index 8fdacad11053e..9d470d07e17e3 100644 --- a/docs/source/scripts/exportdb/generate_example_rst.py +++ b/docs/source/scripts/exportdb/generate_example_rst.py @@ -122,7 +122,7 @@ def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules): {module_contents} """ - tag_names = "\n ".join(t for t in tag_to_modules.keys()) + tag_names = "\n ".join(t for t in tag_to_modules) with open(os.path.join(PWD, "blurb.txt")) as file: blurb = file.read() diff --git a/pyproject.toml b/pyproject.toml index d9927122352f6..dfc622650f5e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,7 +177,6 @@ ignore = [ "SIM115", # Checks for cases where files are opened without using a context manager. "SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM117", - "SIM118", "SIM300", # Yoda condition detected "UP007", # keep-runtime-typing "UP045", # keep-runtime-typing diff --git a/test/distributed/argparse_util_test.py b/test/distributed/argparse_util_test.py index 1902faf992734..a3b3ef2bc717e 100644 --- a/test/distributed/argparse_util_test.py +++ b/test/distributed/argparse_util_test.py @@ -16,7 +16,7 @@ class ArgParseUtilTest(unittest.TestCase): def setUp(self): # remove any lingering environment variables - for e in os.environ.keys(): + for e in os.environ.keys(): # noqa: SIM118 if e.startswith("PET_"): del os.environ[e] diff --git a/test/distributed/launcher/api_test.py b/test/distributed/launcher/api_test.py index 48465516a913b..32e5f74cd6770 100644 --- a/test/distributed/launcher/api_test.py +++ b/test/distributed/launcher/api_test.py @@ -137,7 +137,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() # remove any lingering environment variables. - for env in os.environ.keys(): + for env in os.environ.keys(): # noqa:SIM118 if env.startswith("PET_"): del os.environ[env] diff --git a/test/distributed/launcher/test_run.py b/test/distributed/launcher/test_run.py index 1ba51bfa13908..484a975051d4f 100644 --- a/test/distributed/launcher/test_run.py +++ b/test/distributed/launcher/test_run.py @@ -70,7 +70,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() # remove any lingering environment variables - for env in os.environ.keys(): + for env in os.environ.keys(): # noqa: SIM118 if env.startswith("PET_"): del os.environ[env] diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index bda5f8759b46c..cdaeb2d91fbfb 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -143,7 +143,7 @@ def test_dict_subclass_methods_fallback_readonly(self): def fn(x): for value in sd.values(): x = x * value - for key in sd.keys(): + for key in sd: x = x * key for k, v in sd.items(): x = x * k @@ -189,7 +189,7 @@ def fn(sd, x): for value in sd.values(): x = x * value sd[6] = 14 - for key in sd.keys(): + for key in sd: x = x * key for k, v in sd.items(): x = x * k diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index efa9b7572b2be..9e3a62477db97 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -339,7 +339,7 @@ def _test_serialization(self, guard_type, fn, *args, **kwargs): # NB: This is super janky and might cause unforeseen problems if kwarg_gen_fn is not None: kwargs = kwarg_gen_fn() - for key in self._frame_state.f_locals.keys(): + for key in self._frame_state.f_locals: if key in kwargs and isinstance(kwargs[key], Iterator): self._frame_state.f_locals[key] = kwargs[key] diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 21398490e7b03..1f1a92b8c2b2b 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2182,7 +2182,7 @@ def _check_map_graph_and_extract(self, fn, args): gm = backend.graphs[0] graph = gm.code.strip() subgraphs = [] - for module_name in gm._modules.keys(): + for module_name in gm._modules: subgraphs.append(getattr(gm, module_name).code.strip()) return (graph, *subgraphs) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 4718ef0795897..bacab94e345d4 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -2788,7 +2788,7 @@ def __init__(self) -> None: ) def forward(self, x): - for activation_name in self.activations.keys(): + for activation_name in self.activations: x = self.activations[activation_name](x) return x diff --git a/test/dynamo/test_nested_graph_breaks.py b/test/dynamo/test_nested_graph_breaks.py index c3ce926b8dd5d..ca6fc89af651d 100644 --- a/test/dynamo/test_nested_graph_breaks.py +++ b/test/dynamo/test_nested_graph_breaks.py @@ -835,10 +835,7 @@ def f8(x): self.assertEqual(len(torch._dynamo.utils.counters["resumes"]), 2) for name in ("resume_in_f4", "resume_in_f7"): self.assertTrue( - any( - name in key - for key in torch._dynamo.utils.counters["resumes"].keys() - ) + any(name in key for key in torch._dynamo.utils.counters["resumes"]) ) def test_disable_nested_graph_breaks(self): diff --git a/test/dynamo/test_precompile_context.py b/test/dynamo/test_precompile_context.py index 6c72f65f53ae2..af86220d0cdf1 100644 --- a/test/dynamo/test_precompile_context.py +++ b/test/dynamo/test_precompile_context.py @@ -58,7 +58,7 @@ def simple_function(x): result.sum().backward() self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1) self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1) - for key in PrecompileContext._backend_artifacts_by_key.keys(): + for key in PrecompileContext._backend_artifacts_by_key: result = PrecompileContext.serialize_artifact_by_key(key) assert isinstance(result, BackendCacheArtifact) self.assertEqual(result.key, key) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 3fc5da288786e..a07bd92331faa 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5035,7 +5035,7 @@ def cat(instance_lists: list["Instances"]) -> "Instances": for i in instance_lists[1:]: assert i.image_size == image_size ret = Instances(image_size) - for k in instance_lists[0]._fields.keys(): + for k in instance_lists[0]._fields: values = [i.get(k) for i in instance_lists] v0 = values[0] if isinstance(v0, torch.Tensor): diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index b77701948d8ce..e00f0bb66aa75 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -13,7 +13,6 @@ import torch from torch.testing._internal.common_utils import run_tests, TestCase -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class TestFuzzerCompileIssues(TestCase): diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index e1a518aca6704..4b6ce65bb0bff 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -1003,7 +1003,7 @@ def gen_variable_type_func( result[f"type_derived_method_definitions_{key}"] = [type_definition] result[f"wrapper_registrations_{key}"] = [wrapper_registration] else: - for key in fn.info.keys(): + for key in fn.info: type_definition = METHOD_DEFINITION.substitute( return_type=cpp.returns_type( f.func.returns, symint=True diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index 38d1f94b178b2..97c00a4d09239 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -354,7 +354,7 @@ def _calculate_gpu_utilization(self, data_list: list[UsageData]) -> list[GpuUsag gpu_allocated_mem_values[gpu.uuid].append(gpu.allocated_mem_value) gpu_total_mem_values[gpu.uuid] = gpu.total_mem_value - for gpu_uuid in gpu_utilization.keys(): + for gpu_uuid in gpu_utilization: gpu_util_stats = self._generate_stats(gpu_utilization[gpu_uuid]) gpu_mem_util_stats = self._generate_stats(gpu_mem_utilization[gpu_uuid]) gpu_allocated_mem_stats = self._generate_stats(gpu_allocated_mem[gpu_uuid]) diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index dc006b5979fb8..9d9b52da9259d 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -49,7 +49,7 @@ def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]: headers=_get_request_headers(), ) artifacts = response.json()["artifacts"] - while "next" in response.links.keys(): + while "next" in response.links: response = requests.get( response.links["next"]["url"], headers=_get_request_headers() ) diff --git a/tools/test/heuristics/test_utils.py b/tools/test/heuristics/test_utils.py index e1f47b8453e17..39b5132b70062 100644 --- a/tools/test/heuristics/test_utils.py +++ b/tools/test/heuristics/test_utils.py @@ -21,7 +21,7 @@ def assertDictAlmostEqual( self, first: dict[TestRun, Any], second: dict[TestRun, Any] ) -> None: self.assertEqual(first.keys(), second.keys()) - for key in first.keys(): + for key in first: self.assertAlmostEqual(first[key], second[key]) def test_normalize_ratings(self) -> None: diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index f5164ddbc3a17..ea8d3e208db54 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -374,7 +374,7 @@ def test_split_shards(self) -> None: expected_shards, calculate_shards( 2, - [TestRun(t) for t in test_times.keys()], + [TestRun(t) for t in test_times], test_times, gen_class_times(test_times), ), @@ -404,7 +404,7 @@ def test_split_shards(self) -> None: expected_shards, calculate_shards( 2, - [TestRun(t) for t in test_times.keys()], + [TestRun(t) for t in test_times], test_times, gen_class_times(test_times), ), @@ -422,7 +422,7 @@ def test_split_shards(self) -> None: expected_shards, calculate_shards( 2, - [TestRun(t) for t in test_times.keys()], + [TestRun(t) for t in test_times], test_times, gen_class_times(test_times), ), diff --git a/tools/testing/target_determination/heuristics/interface.py b/tools/testing/target_determination/heuristics/interface.py index 48fbfa342a93f..4a33bb129dd34 100644 --- a/tools/testing/target_determination/heuristics/interface.py +++ b/tools/testing/target_determination/heuristics/interface.py @@ -75,7 +75,7 @@ def set_test_score(self, test_run: TestRun, new_score: float) -> None: return # We don't need this test relevant_test_runs: list[TestRun] = [ - tr for tr in self._test_scores.keys() if tr & test_run and tr != test_run + tr for tr in self._test_scores if tr & test_run and tr != test_run ] # Set the score of all the tests that are covered by test_run to the same score @@ -95,7 +95,7 @@ def add_test_score(self, test_run: TestRun, score_to_add: float) -> None: return relevant_test_runs: list[TestRun] = [ - tr for tr in self._test_scores.keys() if tr & test_run + tr for tr in self._test_scores if tr & test_run ] for relevant_test_run in relevant_test_runs: diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index ae529a355f275..89ad329abd70b 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -162,7 +162,7 @@ def find_broadcast_var( variables[v] = get_hint(v) zero_index = sympy_subs(index, variables) - for v in var_ranges.keys(): + for v in var_ranges: if v not in index.free_symbols: continue diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index a6a8c41103c9f..ab2387af051dc 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -238,9 +238,7 @@ def _local_functional_all_to_all_single( ): local_ints = dict(input_split_size.node._local_ints.items()) else: - local_ints = { - rank: int(input_split_size) for rank in tensor._local_tensors.keys() - } + local_ints = {rank: int(input_split_size) for rank in tensor._local_tensors} for rank, split_size in local_ints.items(): if rank not in split_local_sizes: split_local_sizes[rank] = [] diff --git a/torch/distributed/_tools/fsdp2_mem_tracker.py b/torch/distributed/_tools/fsdp2_mem_tracker.py index 52a601b895a89..8ac6dcb55e189 100644 --- a/torch/distributed/_tools/fsdp2_mem_tracker.py +++ b/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -383,7 +383,7 @@ def _instrument_fsdp_module(self) -> None: if not unique_handlers.get(fsdp_state._post_forward_hook_handle): unique_handlers[fsdp_state._post_forward_hook_handle] = True # call remove on the handles once - for f_hook_handle in unique_handlers.keys(): + for f_hook_handle in unique_handlers: f_hook_handle.remove() # pyrefly: ignore # missing-attribute for module in self._root_mod.modules(): diff --git a/torch/distributed/flight_recorder/components/builder.py b/torch/distributed/flight_recorder/components/builder.py index f3c9d324fc479..56736450e3f2a 100644 --- a/torch/distributed/flight_recorder/components/builder.py +++ b/torch/distributed/flight_recorder/components/builder.py @@ -181,7 +181,7 @@ def build_collectives( mismatch = {_groups[g].id: 0 for g in _groups} # For best effort partial analysis. - dumps_ranks = {int(key) for key in all_entries.keys()} + dumps_ranks = {int(key) for key in all_entries} """ - it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast) diff --git a/torch/distributed/flight_recorder/components/utils.py b/torch/distributed/flight_recorder/components/utils.py index 4e4e448158124..25c5350381187 100644 --- a/torch/distributed/flight_recorder/components/utils.py +++ b/torch/distributed/flight_recorder/components/utils.py @@ -701,7 +701,7 @@ def check_no_missing_dump_files( all_ranks = set() for membership in memberships: all_ranks.add(int(membership.global_rank)) - dumps_ranks = {int(key) for key in entries.keys()} + dumps_ranks = {int(key) for key in entries} assert dumps_ranks == all_ranks, ( f"Missing dump files from ranks {all_ranks - dumps_ranks}" ) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 745b05d1904d7..14c1607383e1c 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -39,7 +39,7 @@ import traceback import weakref from collections.abc import Callable -from typing import Any, Optional, TYPE_CHECKING, Union # noqa: F401 +from typing import Any, TYPE_CHECKING, Union # noqa: F401 import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode @@ -1101,7 +1101,7 @@ def check_hash_mismatches( def compare_triton_hashes(hashes1, hashes2, is_input): assert set(hashes1.keys()) == set(hashes2.keys()) # type: ignore[union-attr] - for key in hashes1.keys(): + for key in hashes1: if hashes1[key] != hashes2[key]: difference_info.append( { diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index d29b274f71bd2..15b74ac9c21a7 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -305,7 +305,7 @@ def get_upgrader_bytecode_function_to_index_map( upgrader_bytecode_function_to_index_map = {} index = 0 for upgrader_bytecode in upgrader_dict: - for upgrader_name in upgrader_bytecode.keys(): + for upgrader_name in upgrader_bytecode: if upgrader_name in EXCLUE_UPGRADER_SET: continue upgrader_bytecode_function_to_index_map[upgrader_name] = index From 9f8ef8855d3078d70f7b782540ff2aaf158d6742 Mon Sep 17 00:00:00 2001 From: Robert Hardwick Date: Thu, 27 Nov 2025 20:46:33 +0000 Subject: [PATCH 051/338] Unset LD_LIBRARY_PATH for binary wheel checks (#168349) This PR does 2 things - Unset LD_LIBRARY_PATH in check_binary.sh so that system libaries are not found by the linker. This will ensure that only packaged .so files in the whl are used by the runtime linker. E.g. in https://github.com/pytorch/pytorch/pull/166044 we had a problem of missing libopenblas.so but the check_binary.sh passed. - Allow DESIRED_PYTHON format cp310_cp310t. The script in .ci/python/set_desired_python.sh manipulates the env variable DESIRED_PYTHON, so this change allows the check_binary.sh to be used with the alt format. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168349 Approved by: https://github.com/atalman --- .ci/pytorch/check_binary.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.ci/pytorch/check_binary.sh b/.ci/pytorch/check_binary.sh index 34a26a293ae44..95d57f35ce4bd 100755 --- a/.ci/pytorch/check_binary.sh +++ b/.ci/pytorch/check_binary.sh @@ -25,6 +25,8 @@ set -eux -o pipefail # Pythonless binary, then it expects to be in the root folder of the unzipped # libtorch package. +# ensure we don't link to system libraries, linked libraries should be found from RPATH +unset LD_LIBRARY_PATH if [[ -z ${DESIRED_PYTHON:-} ]]; then export DESIRED_PYTHON=${MATRIX_PYTHON_VERSION:-} @@ -46,7 +48,10 @@ if [[ "$PACKAGE_TYPE" == libtorch ]]; then export install_root="$PWD" else - if [[ $DESIRED_PYTHON =~ ([0-9].[0-9]+)t ]]; then + if [[ $DESIRED_PYTHON =~ ^cp([0-9])([0-9][0-9])(-cp[0-9]+)?t?$ ]]; then + # Handle inputs like cp310-cp310 or cp310-cp310t + py_dot="${BASH_REMATCH[1]}.${BASH_REMATCH[2]}" + elif [[ $DESIRED_PYTHON =~ ([0-9].[0-9]+)t ]]; then # For python that is maj.mint keep original version py_dot="$DESIRED_PYTHON" elif [[ $DESIRED_PYTHON =~ ([0-9].[0-9]+) ]]; then From 6658a04c7ca67acb64512341342e7b3ee13ee386 Mon Sep 17 00:00:00 2001 From: zhudada Date: Fri, 28 Nov 2025 14:02:08 +0000 Subject: [PATCH 052/338] [Code Clean] Better error handling in aten/src/ATen/native/* and aten/src/ATen/mkl/Exceptions.h (#165290) Replace the runtime_error of the vallina C++ exceptions with TORCH_CEHCK Including: - aten/src/ATen/native/* - aten/src/ATen/mkl/Exceptions.h fix partialy #148114 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165290 Approved by: https://github.com/fffrog, https://github.com/albanD --- aten/src/ATen/mkl/Exceptions.h | 7 +--- aten/src/ATen/native/ComparisonUtils.cpp | 37 ++++++++++--------- aten/src/ATen/native/TriangularOps.cpp | 5 +-- .../ao_sparse/quantized/cpu/packed_params.h | 5 +-- aten/src/ATen/native/cuda/jit_utils.cpp | 4 +- .../hip/bgemm_kernels/bgemm_kernel_template.h | 9 +---- aten/src/ATen/native/hip/ck_gemm_template.h | 14 ++----- .../ATen/native/nested/NestedTensorMath.cpp | 9 ++--- aten/src/ATen/native/quantized/PackedParams.h | 29 ++++----------- aten/src/ATen/native/quantized/cudnn/utils.h | 9 ++--- .../src/ATen/native/sparse/SparseBlasImpl.cpp | 7 +--- 11 files changed, 48 insertions(+), 87 deletions(-) diff --git a/aten/src/ATen/mkl/Exceptions.h b/aten/src/ATen/mkl/Exceptions.h index c70a7ab7a593e..4bcb5ac30555f 100644 --- a/aten/src/ATen/mkl/Exceptions.h +++ b/aten/src/ATen/mkl/Exceptions.h @@ -5,16 +5,13 @@ #include #include #include +#include namespace at::native { static inline void MKL_DFTI_CHECK(MKL_INT status) { - if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) { - std::ostringstream ss; - ss << "MKL FFT error: " << DftiErrorMessage(status); - throw std::runtime_error(ss.str()); - } + TORCH_CHECK(!status || DftiErrorClass(status, DFTI_NO_ERROR), "MKL FFT error: ", DftiErrorMessage(status)); } } // namespace at::native diff --git a/aten/src/ATen/native/ComparisonUtils.cpp b/aten/src/ATen/native/ComparisonUtils.cpp index 13bef0a00b9c9..e0fc7e630accc 100644 --- a/aten/src/ATen/native/ComparisonUtils.cpp +++ b/aten/src/ATen/native/ComparisonUtils.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #ifdef AT_PER_OPERATOR_HEADERS #include @@ -14,14 +15,12 @@ namespace native { template static void _assert_match(const O& original, const C& compared, const std::string& name) { - if (compared) { - bool equal = (original == compared.value()); - if (!equal) { - std::stringstream msg; - msg << "Tensor " << name << " mismatch! Expected: " << compared.value() << ", Got: " << original; - throw std::runtime_error(msg.str()); - } - } + TORCH_CHECK(!compared || original == compared.value(), "Tensor ", + name, + " mismatch! Expected: ", + compared.value(), + ", Got: ", + original); } template<> @@ -31,19 +30,21 @@ void _assert_match>( const std::string& name) { if (compared) { const c10::Device& expected = compared.value(); - if (original.type() != expected.type()) { - std::stringstream msg; - msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original; - throw std::runtime_error(msg.str()); - } + TORCH_CHECK(original.type() == expected.type(), "Tensor ", + name, + " mismatch! Expected: ", + expected, + ", Got: ", + original); // If the expected device doesn't have an index (e.g., just "cuda"), // or if both devices have the same index, consider them equal - if (expected.has_index() && original.has_index() && expected.index() != original.index()) { - std::stringstream msg; - msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original; - throw std::runtime_error(msg.str()); - } + TORCH_CHECK(!expected.has_index() || !original.has_index() || expected.index() == original.index(), "Tensor ", + name, + " mismatch! Expected: ", + expected, + ", Got: ", + original); } } diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp index 08b666e296ed7..5560f3e79f273 100644 --- a/aten/src/ATen/native/TriangularOps.cpp +++ b/aten/src/ATen/native/TriangularOps.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -181,9 +182,7 @@ TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) { } Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) { - if (sizes.size() != 2) { - throw std::runtime_error("expected matrix input"); - } + TORCH_CHECK(sizes.size() == 2, "expected matrix input"); auto grad_input = at::zeros_symint(sizes[0] * sizes[1], grad.options()); auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong)); diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h b/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h index 14f98b5a49782..eae44f2a6071a 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/packed_params.h @@ -3,6 +3,7 @@ #include #include +#include namespace ao::sparse { @@ -62,9 +63,7 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { virtual std::optional bias() = 0; virtual void set_bias(const std::optional& bias) { - throw std::runtime_error( - "set_bias is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "set_bias is not implemented for this packed parameter type"); } protected: diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index e65fa4ceb38e9..5c0cb1d534db1 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -12,7 +12,7 @@ #include #include #include - +#include #include #include #include @@ -1615,7 +1615,7 @@ NvrtcFunction jit_pwise_function( AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLogSize(program, &logsize)); std::string log(logsize, '\0'); AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLog(program, &log[0])); - throw std::runtime_error(code + log); + TORCH_CHECK(false, code + log); } size_t ptx_size = 0; diff --git a/aten/src/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h b/aten/src/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h index 7cf35e13349ff..52b3651ebee73 100644 --- a/aten/src/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h +++ b/aten/src/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h @@ -4,7 +4,7 @@ #include #include #include - +#include #include #include #include @@ -151,12 +151,7 @@ void bgemm_kernel_impl(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { b_element_op, cde_element_op ); - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } + TORCH_CHECK(gemm.IsSupportedArgument(argument), "wrong! device_gemm with the specified compilation parameters does not support this GEMM problem"); auto stream = at::cuda::getCurrentHIPStream().stream(); invoker.Run(argument, StreamConfig{stream, false}); } diff --git a/aten/src/ATen/native/hip/ck_gemm_template.h b/aten/src/ATen/native/hip/ck_gemm_template.h index b34a8b132674a..2e54eb0ea5078 100644 --- a/aten/src/ATen/native/hip/ck_gemm_template.h +++ b/aten/src/ATen/native/hip/ck_gemm_template.h @@ -14,7 +14,7 @@ #include #include #include - +#include #include #include #include @@ -225,12 +225,7 @@ void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) { c_element_op); - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } + TORCH_CHECK(gemm.IsSupportedArgument(argument), "wrong! device_gemm with the specified compilation parameters does not support this GEMM problem"); auto stream = at::cuda::getCurrentHIPStream().stream(); @@ -384,10 +379,7 @@ void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) { { printf("error shape = %ld %ld %ld TRANSA=%d TRANSB=%d \n", n, m, k,TRANSA, TRANSB); - - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + TORCH_CHECK(false, "wrong! device_gemm with the specified compilation parameters does not support this GEMM problem"); } diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index ed7442b1c5969..8956890a88b72 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -745,12 +746,8 @@ inline std::tuple NestedTensor_compute_size_stride( numel_reshaped *= size_reshaped; } else if (size_reshaped == -1) { - if (infer_index > -1) { - throw std::runtime_error("only one dimension can be inferred"); - } - else { - infer_index = idim; - } + TORCH_CHECK(infer_index <= -1, "only one dimension can be inferred"); + infer_index = idim; } else { TORCH_CHECK(false, "invalid shape dimension ", size_reshaped); diff --git a/aten/src/ATen/native/quantized/PackedParams.h b/aten/src/ATen/native/quantized/PackedParams.h index d73bc0adbc4ef..bd78cc01e9a01 100644 --- a/aten/src/ATen/native/quantized/PackedParams.h +++ b/aten/src/ATen/native/quantized/PackedParams.h @@ -2,6 +2,7 @@ #include #include +#include struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { virtual at::Tensor apply( @@ -19,9 +20,7 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { double /*output_scale*/, int64_t /*output_zero_point*/, at::Tensor& output) { - throw std::runtime_error( - "apply_out is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_out is not implemented for this packed parameter type"); return output; } @@ -30,9 +29,7 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { double /*output_scale*/, int64_t /*output_zero_point*/, at::Tensor& output) { - throw std::runtime_error( - "apply_relu_out is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_relu_out is not implemented for this packed parameter type"); return output; } @@ -55,9 +52,7 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { at::Tensor input, double input_scale, int64_t input_zero_point) { - throw std::runtime_error( - "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed parameter type"); return {}; } @@ -79,9 +74,7 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { at::Tensor input, double input_scale, int64_t input_zero_point) { - throw std::runtime_error( - "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed parameter type"); return {}; } @@ -96,18 +89,14 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { const at::Tensor& /* input */, at::Tensor& output, bool /* reduce_range */) { - throw std::runtime_error( - "apply_dynamic_out is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_dynamic_out is not implemented for this packed parameter type"); return output; } virtual at::Tensor& apply_dynamic_relu_out( const at::Tensor& /* input */, at::Tensor& output, bool /* reduce_range */) { - throw std::runtime_error( - "apply_dynamic_relu_out is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_dynamic_relu_out is not implemented for this packed parameter type"); return output; } @@ -116,9 +105,7 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { virtual std::optional bias() = 0; virtual void set_bias(std::optional /*bias*/) { - throw std::runtime_error( - "set_bias is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "set_bias is not implemented for this packed parameter type"); } }; diff --git a/aten/src/ATen/native/quantized/cudnn/utils.h b/aten/src/ATen/native/quantized/cudnn/utils.h index 824694d363a01..0b46f743fa68d 100644 --- a/aten/src/ATen/native/quantized/cudnn/utils.h +++ b/aten/src/ATen/native/quantized/cudnn/utils.h @@ -13,6 +13,7 @@ This file contains some of the auxiliary functions used by both Conv.cpp & Linea #include #include #include +#include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include @@ -43,14 +44,10 @@ struct PackedLinearWeightCudnn : public LinearPackedParamsBase { int64_t output_zero_point) override; at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) override { - throw std::runtime_error( - "apply_dynamic is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_dynamic is not implemented for this packed parameter type"); } at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) override { - throw std::runtime_error( - "apply_dynamic_relu is not implemented for this packed " - "parameter type"); + TORCH_CHECK(false, "apply_dynamic_relu is not implemented for this packed parameter type"); } std::tuple> unpack() override; diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp index c841da8354b5f..9047594c5565e 100644 --- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp +++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp @@ -7,7 +7,7 @@ // Required for checking whether Triton kernels are available #include - +#include #ifndef AT_PER_OPERATOR_HEADERS #include #include @@ -248,10 +248,7 @@ Tensor& _compressed_row_strided_addmm_out( try { return triton_kernel.call(self, mat1, mat2, beta, alpha, result); } catch (std::runtime_error& e) { - const std::string msg = e.what(); - if (msg != std::string("Unable to cast NotImplemented to Tensor")) { - throw std::runtime_error(msg); - } + TORCH_CHECK(e.what() == std::string("Unable to cast NotImplemented to Tensor"), e.what()); } /* else triton_kernel returned NotImplemented, continue with the generic method below */ } From b39813b4a04931682b0491adba2138d01d716d99 Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:49:28 +0000 Subject: [PATCH 053/338] [ROCm][CI] Use hash for download-artifact action instead of tag (#169217) Security-related update Hash chosen from https://github.com/actions/download-artifact/tree/v4.1.7 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169217 Approved by: https://github.com/zxiiro, https://github.com/jeffdaily --- .github/workflows/docker-cache-rocm.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index 380b8c2d1e257..ffb2007ca105f 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -37,7 +37,7 @@ jobs: pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }} steps: - name: Download artifacts - uses: actions/download-artifact@v4.1.7 + uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e #4.1.7 with: run-id: ${{ github.event.workflow_run.id || github.event.inputs.run_id }} path: ./docker-builds-artifacts From b7f6b9a4fc6259f7af068f31868b3119bb1bac3e Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 28 Nov 2025 18:46:50 +0000 Subject: [PATCH 054/338] [cuDNN] Leak `BenchmarkCache` to avoid teardown segfault, correct compile-time cuDNN version to 9.10.2 in CI (#169153) Apparent segfaults observed since 9.10.0 (thanks @nWEIdia for obtaining the stack trace!) seem to be due to teardown of `ExecutionPlan` objects in the benchmark cache, we attempt to workaround this by leaking the `BenchmarkCache` at teardown. The backend version of cuDNN in CI in the 12.8 build is also upgraded as a litmus test to verify that this change works. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169153 Approved by: https://github.com/Skylion007 --- .ci/docker/common/install_cuda.sh | 2 +- aten/src/ATen/native/cudnn/Conv_v8.cpp | 35 ++++++++++++++++++-------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index fe2f9ae3185a3..fe0cb8cc79c4f 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -129,7 +129,7 @@ function install_129 { } function install_128 { - CUDNN_VERSION=9.8.0.87 + CUDNN_VERSION=9.10.2.21 echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index 75ab950e19bbb..7bc7a80cbb891 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -350,11 +350,26 @@ struct BenchmarkCache { // @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to // be thread safe across all engines see Limitations in // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html -thread_local BenchmarkCache - benchmark_cache; -thread_local BenchmarkCache - benchmark_cache_fused; +// +// We also leak them due to apparent teardown segfaults observed since cuDNN +// version 9.10+ +BenchmarkCache* +_get_benchmark_cache() { + static thread_local BenchmarkCache< + cudnn_frontend::ExecutionPlan, + CacheKeyWrapper>* benchmark_cache = + new BenchmarkCache(); + return benchmark_cache; +} +BenchmarkCache* +_get_benchmark_cache_fused() { + static thread_local BenchmarkCache< + cudnn_frontend::ExecutionPlan, + CacheKeyFusedWrapper>* benchmark_cache_fused = + new BenchmarkCache(); + return benchmark_cache_fused; +} } // namespace void run_conv_plan( @@ -876,7 +891,7 @@ void try_plans( for (auto& plan : plans) { try { run_conv_plan(handle, x, y, w, plan, operation); - benchmark_cache.update(key, plan); + _get_benchmark_cache()->update(key, plan); return; } catch (cudnn_frontend::cudnnException&) { } catch (CuDNNError&) { @@ -900,7 +915,7 @@ void try_plans_fused( for (auto& plan : plans) { try { run_conv_plan_fused(handle, x, y, w, z, b, plan); - benchmark_cache_fused.update(key, plan); + _get_benchmark_cache_fused()->update(key, plan); return; } catch (cudnn_frontend::cudnnException&) { } catch (CuDNNError&) { @@ -931,7 +946,7 @@ bool try_configs( continue; } run_conv_plan(handle, x, y, w, plan, operation); - benchmark_cache.update(key, plan); + _get_benchmark_cache()->update(key, plan); return true; } catch (cudnn_frontend::cudnnException&) { } catch (CuDNNError&) { @@ -962,7 +977,7 @@ bool try_configs_fused( continue; } run_conv_plan_fused(handle, x, y, w, z, b, plan); - benchmark_cache_fused.update(key, plan); + _get_benchmark_cache_fused()->update(key, plan); return true; } catch (cudnn_frontend::cudnnException&) { } catch (CuDNNError&) { @@ -998,7 +1013,7 @@ void run_single_conv( deterministic, allow_tf32); // TODO: is this thread safe if cache is updated? is pointer stale? - auto search = benchmark_cache.find(key); + auto search = _get_benchmark_cache()->find(key); if (search) { try { run_conv_plan(handle, x, y, w, *search, operation); @@ -1098,7 +1113,7 @@ void run_fused_conv( groups, deterministic, allow_tf32); - auto search = benchmark_cache_fused.find(key); + auto search = _get_benchmark_cache_fused()->find(key); if (search) { try { run_conv_plan_fused(handle, x, y, w, z, b, *search); From 4c246677784c6a14bc2dbb9ff8773ef0a3a3222f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 28 Nov 2025 20:02:15 +0000 Subject: [PATCH 055/338] Revert "Fix local_map default partitioner issue (#168396)" This reverts commit 89891302d444d9d84e9072f78767c77eceffbfa2. Reverted https://github.com/pytorch/pytorch/pull/168396 on behalf of https://github.com/atalman due to breaks https://github.com/pytorch/pytorch/issues/169221 ([comment](https://github.com/pytorch/pytorch/pull/168396#issuecomment-3590271910)) --- torch/_higher_order_ops/local_map.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch/_higher_order_ops/local_map.py b/torch/_higher_order_ops/local_map.py index 1d4ad631ea102..7970acbc5d6ad 100644 --- a/torch/_higher_order_ops/local_map.py +++ b/torch/_higher_order_ops/local_map.py @@ -334,13 +334,6 @@ def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]: static_lifetime_input_indices=[], ) - # Fix tags because min-cut does not respect fw/bw boundary, breaking - # default partitioner's assumptions. - for node in new_fw_gm.graph.nodes: - node.meta["partitioner_tag"] = "is_forward" - for node in new_bw_gm.graph.nodes: - node.meta["partitioner_tag"] = "is_backward" - # Propagate meta onto fw/bw graphs, later will be set on proxied nodes new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs} From 409a5fee945c46a3edaf5df162812f201bfd7b2f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 28 Nov 2025 20:27:32 +0800 Subject: [PATCH 056/338] [pytree][dynamo] make `tree_*` functions accept both Python and C++ `PyTreeSpec` (#152624) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152624 Approved by: https://github.com/Lucaskabela --- test/test_pytree.py | 40 ++++++++- torch/_dynamo/polyfills/pytree.py | 7 +- torch/utils/_cxx_pytree.py | 34 ++++++-- torch/utils/_pytree.py | 135 +++++++++++++++++++----------- 4 files changed, 158 insertions(+), 58 deletions(-) diff --git a/test/test_pytree.py b/test/test_pytree.py index 09cf0bbd47a43..92ab336e6e7bc 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -22,6 +22,7 @@ parametrize, run_tests, subtest, + TEST_WITH_TORCHDYNAMO, TestCase, ) @@ -52,6 +53,14 @@ def __init__(self, x, y): self.x = x self.y = y + def __eq__(self, other): + if not isinstance(other, GlobalDummyType): + return NotImplemented + return self.x == other.x and self.y == other.y + + def __hash__(self): + return hash((self.x, self.y)) + cxx_pytree.register_pytree_node( GlobalDummyType, @@ -1490,6 +1499,25 @@ def setUp(self): if IS_FBCODE: raise unittest.SkipTest("C++ pytree tests are not supported in fbcode") + def assertEqual(self, x, y, *args, **kwargs): + x_typename, y_typename = type(x).__name__, type(y).__name__ + if not ("treespec" in x_typename.lower() or "treespec" in y_typename.lower()): + super().assertEqual(x, y, *args, **kwargs) + + # The Dynamo polyfill returns a polyfilled Python class for C++ PyTreeSpec instead of the + # C++ class. So we compare the type names and reprs instead because the types themselves + # won't be equal. + super().assertEqual(x_typename, y_typename, *args, **kwargs) + if not TEST_WITH_TORCHDYNAMO or type(x) is type(y): + super().assertEqual(x, y, *args, **kwargs) + else: + super().assertEqual( + x.unflatten(range(x.num_leaves)), + y.unflatten(range(y.num_leaves)), + *args, + **kwargs, + ) + def test_treespec_equality(self): self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf()) @@ -1530,7 +1558,9 @@ def test_pytree_serialize(self, spec): serialized_spec = cxx_pytree.treespec_dumps(spec) self.assertIsInstance(serialized_spec, str) - self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec)) + + roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec) + self.assertEqual(roundtrip_spec, spec) def test_pytree_serialize_namedtuple(self): python_pytree._register_namedtuple( @@ -1563,6 +1593,14 @@ def __init__(self, x, y): self.x = x self.y = y + def __eq__(self, other): + if not isinstance(other, LocalDummyType): + return NotImplemented + return self.x == other.x and self.y == other.y + + def __hash__(self): + return hash((self.x, self.y)) + cxx_pytree.register_pytree_node( LocalDummyType, lambda dummy: ([dummy.x, dummy.y], None), diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index 63a72afa43a6d..f5f9c18303336 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -23,6 +23,7 @@ ) import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 +import torch.utils._pytree as python_pytree from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES from ..decorators import substitute_in_graph @@ -430,8 +431,8 @@ def unflatten(self, leaves: Iterable[Any], /) -> PyTree: return self._unflatten_func(self._metadata, subtrees) -def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: - return isinstance(obj, PyTreeSpec) +def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]: + return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec)) @substitute_in_graph( # type: ignore[arg-type] @@ -701,7 +702,7 @@ def tree_structure( def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: if not _is_pytreespec_instance(treespec): raise TypeError( - f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " + f"Expected `treespec` to be an instance of " f"PyTreeSpec but got item of type {type(treespec)}." ) return treespec.unflatten(leaves) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index f9350124d135a..e88209398302b 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -13,6 +13,7 @@ """ import functools +import sys import types from collections.abc import Callable, Iterable, Mapping from typing import Any, overload, TypeAlias, TypeVar, Union @@ -266,8 +267,20 @@ def _private_register_pytree_node( ) -def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]: - return isinstance(obj, TreeSpec) +def _is_pytreespec_instance( + obj: Any, + /, +) -> TypeIs[Union[TreeSpec, python_pytree.PyTreeSpec]]: + if isinstance(obj, (TreeSpec, python_pytree.PyTreeSpec)): + return True + if "torch._dynamo.polyfills.pytree" in sys.modules: + # The PyTorch Dynamo pytree module is not always available, so we check if it is loaded. + # If the PyTorch Dynamo pytree module is loaded, we can check if the treespec + # is an instance of the PyTorch Dynamo TreeSpec class. + import torch._dynamo.polyfills.pytree as dynamo_pytree + + return isinstance(obj, dynamo_pytree.PyTreeSpec) + return False def treespec_leaf() -> TreeSpec: @@ -394,7 +407,15 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: The reconstructed pytree, containing the ``leaves`` placed in the structure described by ``treespec``. """ - return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type] + if not _is_pytreespec_instance(treespec): + if not _is_pytreespec_instance(leaves): + raise TypeError( + f"Expected `treespec` to be an instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + # Allow passing the PyTreeSpec instance as the first argument + leaves, treespec = treespec, leaves + return treespec.unflatten(leaves) def tree_iter( @@ -959,8 +980,9 @@ def _broadcast_to_and_flatten( is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[Any] | None: if not _is_pytreespec_instance(treespec): - raise AssertionError( - f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}" + raise TypeError( + f"Expected `treespec` to be an instance of " + f"PyTreeSpec but got item of type {type(treespec)}." ) full_tree = tree_unflatten([0] * treespec.num_leaves, treespec) try: @@ -973,7 +995,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str: """Serialize a treespec to a JSON string.""" if not _is_pytreespec_instance(treespec): raise TypeError( - f"treespec_dumps(treespec): Expected `treespec` to be instance of " + f"Expected `treespec` to be an instance of " f"PyTreeSpec but got item of type {type(treespec)}." ) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 16877719718af..eca0c0c7ab5c7 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -20,6 +20,7 @@ import importlib import importlib.metadata import json +import sys import threading import types import warnings @@ -35,15 +36,20 @@ NoReturn, overload, Protocol, + TYPE_CHECKING, TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self +from typing_extensions import deprecated, NamedTuple, Self, TypeIs from torch.torch_version import TorchVersion as _TorchVersion +if TYPE_CHECKING: + import torch.utils._cxx_pytree as cxx_pytree + + __all__ = [ "PyTree", "Context", @@ -249,9 +255,9 @@ def register_pytree_node( return if _cxx_pytree_imported: - from . import _cxx_pytree as cxx + import torch.utils._cxx_pytree as cxx_pytree - cxx._private_register_pytree_node( + cxx_pytree._private_register_pytree_node( cls, flatten_fn, unflatten_fn, @@ -1176,12 +1182,12 @@ def child(self, index: int) -> Self: return self._children[index] def flatten_up_to(self, tree: PyTree) -> list[PyTree]: - def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: + def helper(treespec: TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None: if treespec.is_leaf(): - subtrees.append(tree) + subtrees.append(node) return - node_type = _get_node_type(tree) + node_type = _get_node_type(node) if treespec.type not in BUILTIN_TYPES: # Always require custom node types to match exactly if node_type != treespec.type: @@ -1190,7 +1196,7 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: f"expected {treespec.type!r}, but got {node_type!r}.", ) flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - children, context = flatten_fn(tree) + children, context = flatten_fn(node) if len(children) != treespec.num_children: raise ValueError( f"Node arity mismatch; " @@ -1212,10 +1218,10 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: f"Node type mismatch; " f"expected {treespec.type!r}, but got {node_type!r}.", ) - if len(tree) != treespec.num_children: + if len(node) != treespec.num_children: raise ValueError( f"Node arity mismatch; " - f"expected {treespec.num_children}, but got {len(tree)}.", + f"expected {treespec.num_children}, but got {len(node)}.", ) if both_standard_dict: @@ -1227,7 +1233,7 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: else treespec._context[1] ) expected_keys = dict_context - got_key_set = set(tree) + got_key_set = set(node) expected_key_set = set(expected_keys) if got_key_set != expected_key_set: missing_keys = expected_key_set.difference(got_key_set) @@ -1238,11 +1244,11 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: if extra_keys: message += f"; extra key(s): {extra_keys}" raise ValueError(f"Node keys mismatch{message}.") - children = [tree[key] for key in expected_keys] + children = [node[key] for key in expected_keys] else: # node_type is treespec.type flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - children, context = flatten_fn(tree) + children, context = flatten_fn(node) if ( node_type is not deque # ignore mismatch of `maxlen` for deque ) and context != treespec._context: @@ -1366,6 +1372,44 @@ def treespec_dict( return TreeSpec(dict, list(dct.keys()), list(dct.values())) +def _is_pytreespec_instance( + obj: Any, +) -> TypeIs[Union[TreeSpec, "cxx_pytree.PyTreeSpec"]]: + if isinstance(obj, TreeSpec): + return True + if "torch.utils._cxx_pytree" in sys.modules: + # The C++ pytree module is not always available, so we check if it is loaded. + # If the C++ pytree module is loaded, we can check if the treespec + # is an instance of the C++ TreeSpec class. + import torch.utils._cxx_pytree as cxx_pytree + + if isinstance(obj, cxx_pytree.PyTreeSpec): + return True + if "torch._dynamo.polyfills.pytree" in sys.modules: + # The PyTorch Dynamo pytree module is not always available, so we check if it is loaded. + # If the PyTorch Dynamo pytree module is loaded, we can check if the treespec + # is an instance of the PyTorch Dynamo TreeSpec class. + import torch._dynamo.polyfills.pytree as dynamo_pytree + + return isinstance(obj, dynamo_pytree.PyTreeSpec) + return False + + +def _ensure_python_treespec_instance( + treespec: Union[TreeSpec, "cxx_pytree.PyTreeSpec"], +) -> TreeSpec: + if isinstance(treespec, TreeSpec): + return treespec + + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"Expected `treespec` to be an instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + dummy_tree = treespec.unflatten([0] * treespec.num_leaves) + return tree_structure(dummy_tree) + + def tree_flatten( tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, @@ -1396,11 +1440,14 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: """Given a list of values and a TreeSpec, builds a pytree. This is the inverse operation of `tree_flatten`. """ - if not isinstance(treespec, TreeSpec): - raise TypeError( - f"tree_unflatten(leaves, treespec): Expected `treespec` to be " - f"instance of TreeSpec but got item of type {type(treespec)}.", - ) + if not _is_pytreespec_instance(treespec): + if not _is_pytreespec_instance(leaves): + raise TypeError( + f"Expected `treespec` to be an instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + # Allow passing the PyTreeSpec instance as the first argument + leaves, treespec = treespec, leaves return treespec.unflatten(leaves) @@ -1830,35 +1877,31 @@ def _broadcast_to_and_flatten( treespec: TreeSpec, is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[Any] | None: - if not isinstance(treespec, TreeSpec): - raise AssertionError("treespec must be a TreeSpec") - - if tree_is_leaf(tree, is_leaf=is_leaf): - return [tree] * treespec.num_leaves - if treespec.is_leaf(): - return None - node_type = _get_node_type(tree) - if node_type != treespec.type: - return None - - flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - child_pytrees, context = flatten_fn(tree) + def broadcast_prefix( + prefix_tree: PyTree, + full_tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> list[Any]: + result: list[Any] = [] + + def add_leaves(x: Any, subtree: PyTree) -> None: + subtreespec = tree_structure(subtree, is_leaf=is_leaf) + result.extend([x] * subtreespec.num_leaves) + + tree_map_( + add_leaves, + prefix_tree, + full_tree, + is_leaf=is_leaf, + ) + return result - # Check if the Node is different from the spec - if len(child_pytrees) != treespec.num_children or context != treespec._context: + full_tree = tree_unflatten([0] * treespec.num_leaves, treespec) + try: + return broadcast_prefix(tree, full_tree, is_leaf=is_leaf) + except ValueError: return None - # Recursively flatten the children - result: list[Any] = [] - for child, child_spec in zip(child_pytrees, treespec._children, strict=True): - flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) - if flat is not None: - result += flat - else: - return None - - return result - @dataclasses.dataclass class _TreeSpecSchema: @@ -1971,11 +2014,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str: - if not isinstance(treespec, TreeSpec): - raise TypeError( - f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " - f"TreeSpec but got item of type {type(treespec)}.", - ) + treespec = _ensure_python_treespec_instance(treespec) if protocol is None: protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL From 1ccb743b7b5be955f49736c162c4f5004b8a0dd8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 28 Nov 2025 20:12:49 +0800 Subject: [PATCH 057/338] [BE][4/5] fix typos in aten/ (aten/src/ATen/native/) (#157553) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157553 Approved by: https://github.com/albanD --- .lintrunner.toml | 1 - aten/src/ATen/native/cpu/UpSampleKernel.cpp | 2 +- aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h | 4 ++-- aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu | 2 +- aten/src/ATen/native/cuda/Blas.cpp | 2 +- aten/src/ATen/native/cuda/CUDAJitLoops.cuh | 2 +- aten/src/ATen/native/cuda/Dropout.cu | 4 ++-- aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu | 2 +- aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu | 2 +- aten/src/ATen/native/cuda/GridSampler.cu | 8 ++++---- aten/src/ATen/native/cuda/GroupMM.cu | 2 +- aten/src/ATen/native/cuda/GroupedBlas.cpp | 2 +- aten/src/ATen/native/cuda/IGammaKernel.cu | 4 ++-- aten/src/ATen/native/cuda/IndexKernel.cu | 2 +- aten/src/ATen/native/cuda/Indexing.cu | 2 +- aten/src/ATen/native/cuda/KernelUtils.cuh | 4 ++-- aten/src/ATen/native/cuda/LogAddExpKernel.cu | 2 +- aten/src/ATen/native/cuda/LossCTC.cu | 2 +- aten/src/ATen/native/cuda/Math.cuh | 2 +- aten/src/ATen/native/cuda/MemoryAccess.cuh | 2 +- aten/src/ATen/native/cuda/MultinomialKernel.cu | 2 +- aten/src/ATen/native/cuda/Normalization.cuh | 4 ++-- aten/src/ATen/native/cuda/Randperm.cu | 2 +- aten/src/ATen/native/cuda/Reduce.cuh | 4 ++-- aten/src/ATen/native/cuda/ReflectionPad.cu | 2 +- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 2 +- aten/src/ATen/native/cuda/ScaledGroupMM.cu | 2 +- aten/src/ATen/native/cuda/group_norm_kernel.cu | 2 +- aten/src/ATen/native/cuda/jit_utils.cpp | 10 +++++----- aten/src/ATen/native/cuda/layer_norm_kernel.cu | 2 +- aten/src/ATen/native/cudnn/MHA.cpp | 8 ++++---- aten/src/ATen/native/hip/ck_bgemm_bfloat16.hip | 2 +- aten/src/ATen/native/hip/ck_gemm_bfloat16.hip | 4 ++-- aten/src/ATen/native/hip/ck_gemm_float.hip | 2 +- aten/src/ATen/native/hip/ck_gemm_half.hip | 4 ++-- aten/src/ATen/native/metal/MetalShaders.h | 2 +- aten/src/ATen/native/metal/ops/MetalNeurons.mm | 2 +- aten/src/ATen/native/mkldnn/xpu/detail/Attr.h | 4 ++-- aten/src/ATen/native/mps/OperationUtils.mm | 6 +++--- aten/src/ATen/native/mps/kernels/GridSampler.metal | 6 +++--- aten/src/ATen/native/mps/kernels/Indexing.metal | 2 +- aten/src/ATen/native/mps/kernels/Quantized.metal | 2 +- aten/src/ATen/native/mps/kernels/UnaryKernel.metal | 2 +- aten/src/ATen/native/mps/kernels/UpSample.metal | 2 +- aten/src/ATen/native/nested/NestedTensorMath.cpp | 8 ++++---- tools/linter/dictionary.txt | 7 ++++--- 46 files changed, 74 insertions(+), 74 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 0f46b398ca501..cd5d338b63639 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1113,7 +1113,6 @@ exclude_patterns = [ # These files are all grandfathered in, feel free to remove from this list # as necessary # NOTE: remove the patterns in the order they are listed - 'aten/src/ATen/native/[a-pA-P]*/**', 'aten/src/ATen/[a-mA-M]*/**', 'test/**', ] diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index e59e5985bf7f3..79583b59edaf1 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -1017,7 +1017,7 @@ struct HelperInterpBase { while (aligned_interp_size % sizeof(int32_t) != 0) { aligned_interp_size += 1; } - // assert that we wont go out of bounds + // assert that we won't go out of bounds TORCH_INTERNAL_ASSERT(aligned_interp_size * sizeof(int16_t) < interp_size * sizeof(double)); } diff --git a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h index 146c60e5cd0fa..c1bf79dfa44e6 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h +++ b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -655,7 +655,7 @@ void ImagingResampleHorizontalConvolution8u4x( // last element auto mmk = _mm256_set1_epi32(k[i]); // For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes - // lines 0, 1 and 2 wont go out of allocated memory bounds + // lines 0, 1 and 2 won't go out of allocated memory bounds auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256( mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)), mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1); @@ -1312,7 +1312,7 @@ void ImagingResampleVerticalConvolution8u( // Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3 // It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data. - // We also wont go out of bounds of lineOut memory allocation + // We also won't go out of bounds of lineOut memory allocation std::memcpy(lineOut + j, (uint8_t *) &o, 4); } diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu index 47c705a667b52..e1ef5e2204dac 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu @@ -705,7 +705,7 @@ namespace { ); } while (!done && max_threads); if (!done) { - TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accomodate sharedMemPerBlock limit"); + TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accommodate sharedMemPerBlock limit"); } break; } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 75a4d357a1c0b..7fe95e86b6299 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -182,7 +182,7 @@ static bool isInputCompliesAddmmCudaLt( // NOTE: row-major result is important when bias is 1D. // This is because Lt broadcasts 1D bias over the columns // while the aten::addmm API broadcasts it over the rows, - // and this is in conjuction with the data preparation + // and this is in conjunction with the data preparation // procedure that does not transpose arguments with // col-major result. For col-major result we need // to explicitly transpose the problem so that bias is diff --git a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh index c4c3af83ccd80..384b1f61771e0 100644 --- a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh +++ b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh @@ -298,7 +298,7 @@ static void jitted_gpu_kernel_impl( at::opmath_type scalar_val, const std::tuple& extra_args) { - // TODO: Memory use can probably be optimized by re-using kernels across GPUs with + // TODO: Memory use can probably be optimized by reusing kernels across GPUs with // the same compute capability static std::mutex jiterator_mutex; static std::vector device_caches(c10::cuda::device_count()); diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index 9c1a6e046de78..fe63594f272cf 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -75,7 +75,7 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo // We'll use this to actually cause vectorized loads later LoadT *value = reinterpret_cast(&src); - //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything + //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for Halfs, so generate float for everything // Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4) // sets of rand. if ((VEC >= 4) || (gridxvec_loop_state == 0)) { @@ -159,7 +159,7 @@ fused_dropout_kernel(cuda::detail::TensorInfo a, for (IndexType linearIndex = idx; linearIndex < rounded_size; linearIndex += gridDim.x * blockDim.x*UNROLL) { -//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything +//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for Halfs, so generate float for everything float4 rand = curand_uniform4(&state); scalar_t src[UNROLL]; rand.x = rand.x < p; diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu index 6ce419137345f..250a05898bea4 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu @@ -24,7 +24,7 @@ namespace at::native { namespace { /* This code computes the sum of the weights in two-steps: - 1) Each GPU warp sums `NROWS_PER_THREAD` number of row given by `indeces` + 1) Each GPU warp sums `NROWS_PER_THREAD` number of row given by `indices` 2) Each partial-sum from 1) are summed and scatter into `grad_weight` Notice, `NROWS_PER_THREAD` impacts the Achieved Occupancy of the diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu index 9ac0e875b2d68..93d05b9db3987 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu @@ -204,7 +204,7 @@ Scalar scalar_reciprocal(const Scalar& scalar) { return Scalar(1. / scalar.toComplexDouble()); } TORCH_INTERNAL_ASSERT( - false, "divison with ", scalar.type(), " not supported"); + false, "division with ", scalar.type(), " not supported"); } void foreach_tensor_div_scalar_kernel_cuda_( diff --git a/aten/src/ATen/native/cuda/GridSampler.cu b/aten/src/ATen/native/cuda/GridSampler.cu index 2c9128eee2217..6ef8edef3f516 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cu +++ b/aten/src/ATen/native/cuda/GridSampler.cu @@ -57,7 +57,7 @@ namespace { const index_t n = index / (out_H * out_W); const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; - // get the corresponding input x, y co-ordinates from grid + // get the corresponding input x, y coordinates from grid opmath_t x = grid.data[grid_offset]; opmath_t y = grid.data[grid_offset + grid_sCoor]; @@ -193,7 +193,7 @@ namespace { const index_t n = index / (out_D * out_H * out_W); const index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid opmath_t x = grid.data[grid_offset]; opmath_t y = grid.data[grid_offset + grid_sCoor]; opmath_t z = grid.data[grid_offset + 2 * grid_sCoor]; @@ -358,7 +358,7 @@ namespace { const index_t n = index / (out_H * out_W); const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; - // get the corresponding input x, y co-ordinates from grid + // get the corresponding input x, y coordinates from grid scalar_t x = grid.data[grid_offset]; scalar_t y = grid.data[grid_offset + grid_sCoor]; @@ -572,7 +572,7 @@ namespace { const index_t n = index / (out_D * out_H * out_W); const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid scalar_t ix = grid.data[grid_offset]; scalar_t iy = grid.data[grid_offset + grid_sCoor]; scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu index 3f4f998d92cd6..aa55c02e48138 100644 --- a/aten/src/ATen/native/cuda/GroupMM.cu +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -8,7 +8,7 @@ #include -// Three warninngs in Cutlass included header files +// Three warnings in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index f4b229156d79f..2052f344adf64 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -528,7 +528,7 @@ _scaled_grouped_mm_cuda_v2( "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ", mat_b.size(dim_b)); // Note: only (-1, -2) is currently supported - TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only"); + TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only"); } else { TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); } diff --git a/aten/src/ATen/native/cuda/IGammaKernel.cu b/aten/src/ATen/native/cuda/IGammaKernel.cu index 73db6272be9ef..63b5cc1be700b 100644 --- a/aten/src/ATen/native/cuda/IGammaKernel.cu +++ b/aten/src/ATen/native/cuda/IGammaKernel.cu @@ -377,7 +377,7 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.4 [igam1]) - * - if x > 1.1 and x < a, using the substraction from the regularized lower + * - if x > 1.1 and x < a, using the subtraction from the regularized lower * incomplete gamma * - otherwise, calculate the series from [igam2] eq (5) */ @@ -460,7 +460,7 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.3 [igam1]) - * - if x > 1 and x > a, using the substraction from the regularized upper + * - if x > 1 and x > a, using the subtraction from the regularized upper * incomplete gamma * - otherwise, calculate the series from [igam2] eq (4) */ diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index db85f62c8d124..04b0756817d51 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -323,7 +323,7 @@ void cuda_take_put_kernel( const auto offset_calc = make_offset_calculator<2>(iter); using uindex_t = std::make_unsigned_t; - // OffsetCalculator needs the sizes and strides reveresed + // OffsetCalculator needs the sizes and strides reversed const auto indexed_sizes = std::vector(indexed.sizes().rbegin(), indexed.sizes().rend()); const auto indexed_strides = std::vector(indexed.strides().rbegin(), indexed.strides().rend()); const auto* indexed_strides_data = indexed_strides.data(); diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index dacef18c79b68..d8a87774ce72c 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -1611,7 +1611,7 @@ void index_select_out_cuda_impl( // SmallIndexKernel is more performant when the number of indices is small, and pre-loading // the index reduces memory accesses. When the number of indices is large, we avoid that - // and increase parallellism by calling gather_out which is a generalization of index_select + // and increase parallelism by calling gather_out which is a generalization of index_select if (cuda::detail::canUse32BitIndexMath(out) && cuda::detail::canUse32BitIndexMath(self) && cuda::detail::canUse32BitIndexMath(index) && diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 5c8b98105bb26..a400bb19988a9 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -269,7 +269,7 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd( scalar_t* dst = self_ptr + index; - //pack coalseced bf16 and fp16 + //pack coalesced bf16 and fp16 if constexpr (std::is_same::value || std::is_same::value) { typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; @@ -312,7 +312,7 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd( } } - // not coalsced, so now let try to capture lane-matches... + // not coalesced, so now let try to capture lane-matches... if (numel > 16 /*<-hueristic threshold*/ * 64 ) { // well shucks, unlikely to capture same-dest atomics in a wave. diff --git a/aten/src/ATen/native/cuda/LogAddExpKernel.cu b/aten/src/ATen/native/cuda/LogAddExpKernel.cu index 910d3c1cddc93..90356f51a668a 100644 --- a/aten/src/ATen/native/cuda/LogAddExpKernel.cu +++ b/aten/src/ATen/native/cuda/LogAddExpKernel.cu @@ -70,7 +70,7 @@ __host__ __device__ c10::complex _fast_build_exp_inf(const c10::comple // this function only handles the case where the real part of x is infinite const auto ximag = std::imag(x); constexpr auto exp_x_abs = std::numeric_limits::infinity(); - if (!::isfinite(ximag)) { // add this to make consitent with std::exp(x+yi) + if (!::isfinite(ximag)) { // add this to make consistent with std::exp(x+yi) return {exp_x_abs, std::numeric_limits::quiet_NaN()}; } const auto sin = std::sin(ximag); diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index 4c5eabd049687..b1bce2948a5a0 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -343,7 +343,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data, if (input_length == 0) return; - // "first" row, the beta initialization before eq (10) (t=target_length - differes per batch) + // "first" row, the beta initialization before eq (10) (t=target_length - differs per batch) for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) { int64_t s = threadIdx.x + block_s; scalar_t lb; diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index 1fa245af1a4d1..cc43c6015c9e2 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -816,7 +816,7 @@ const auto erfcx_string = jiterator_stringify( with the usual checks for overflow etcetera. Performance-wise, it seems to be substantially faster than either - the SLATEC DERFC function [or an erfcx function derived therefrom] + the SLATEC DERFC function [or an erfcx function derived there from] or Cody's CALERF function (from netlib.org/specfun), while retaining near machine precision in accuracy. */ diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index d29ba35393a08..373b44cca7901 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -370,7 +370,7 @@ struct vectorized { #ifdef USE_ROCM // This is similar to vectorized policy above, but this one supports -// heterogenous input tensor types as templated parameters. +// heterogeneous input tensor types as templated parameters. // Its use should be limited to frequently used heterogeneous data types // as each instantiation will generate a separate kernel, leading to code // bloating if applied to all combinations supported in PyTorch. Assumption: all diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 8132e7df57b51..c5668c9af3b00 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -309,7 +309,7 @@ __global__ void sampleMultinomialOnce( } else { // This should address a rare bug where we don't select a valid index. This likely occurs when // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but - // and our uniform sample is greater than this value. In this case we likely have unitialized memory + // and our uniform sample is greater than this value. In this case we likely have uninitialized memory // in dest[curDist]. So basically we will loop through the distribution and pick the largest index // where the distribution is non-zero. This is obviously terribly inefficient, but due to the // rarity in which this occurs, this should not be an issue. diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index d211adc3f6a78..bbd65419bbb92 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -1654,7 +1654,7 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; - // Input is guarunteed to be channels-last compatible + // Input is guaranteed to be channels-last compatible at::Tensor grad_input = at::empty_like(input); dim3 block; @@ -1722,7 +1722,7 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const auto reduction_size = input.numel() / stride; auto norm_fct = 1.0 / reduction_size; - // Input is guarunteed to be channels-last compatible + // Input is guaranteed to be channels-last compatible at::Tensor grad_input = at::empty_like(input); dim3 block; diff --git a/aten/src/ATen/native/cuda/Randperm.cu b/aten/src/ATen/native/cuda/Randperm.cu index bde5457e8cdd8..4764a51d46a5c 100644 --- a/aten/src/ATen/native/cuda/Randperm.cu +++ b/aten/src/ATen/native/cuda/Randperm.cu @@ -37,7 +37,7 @@ namespace at::native { // threshold probability for having non-duplicate keys, then it can be proved that[1] // the number of bits required is: ceil(log2(n - (6 n^2 + 1) / (12 log(q)))) // -// Then after sort, we lauch a separate kernel that additionally shuffles any islands +// Then after sort, we launch a separate kernel that additionally shuffles any islands // of values whose keys matched. The algorithm of this kernel is as follows: // Each thread reads its key and the keys of its neighbors to tell if it's part of an island. // For each island, the first thread in the island sees a key match at index i+1 but not index i-1. diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 22d82df5f205f..91cd5a2a09938 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1086,12 +1086,12 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ // load instructions. // // Case 1: "vectorize along input" - // This case happens when we are reducing along fastest moving dimesion. In such case, threads + // This case happens when we are reducing along fastest moving dimension. In such case, threads // with the same threadIdx.y works on the same reduction cooperatively and will produce results // for the same output. In such case, values in each loaded vector always correspond to the same output. // // Case 2: "vectorize along output" - // This case happens when the fastest moving dimesion is not the dimension of reduction. In such case, + // This case happens when the fastest moving dimension is not the dimension of reduction. In such case, // threads with different threadIdx.x are independent and will produce results for different outputs. // In such case, values in each loaded vector always correspond to different outputs. if (fastest_moving_stride == sizeof(scalar_t)) { diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu index 228f0321026f5..935471dad5c13 100644 --- a/aten/src/ATen/native/cuda/ReflectionPad.cu +++ b/aten/src/ATen/native/cuda/ReflectionPad.cu @@ -273,7 +273,7 @@ __global__ void reflection_pad2d_backward_det_out_kernel( const int64_t dist_cols = ::abs(inp_col - (input_dim_x - 1)); // we were dist_rows after, now we want to be dist_rows before - // we were dist_cols before, now we wnat to be dist_cols after + // we were dist_cols before, now we want to be dist_cols after const int64_t reflect_tr_out_row = (corner_tr_out_row - dist_rows); const int64_t reflect_tr_out_col = (corner_tr_out_col + dist_cols); const int64_t reflect_tr_out = diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 8971e05094651..032228e7abc05 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -5,7 +5,7 @@ #include #include -// Two warninngs in Cutlass included header files +// Two warnings in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-field-initializers") diff --git a/aten/src/ATen/native/cuda/ScaledGroupMM.cu b/aten/src/ATen/native/cuda/ScaledGroupMM.cu index 71c9c8dac766d..4b1d186d58e01 100644 --- a/aten/src/ATen/native/cuda/ScaledGroupMM.cu +++ b/aten/src/ATen/native/cuda/ScaledGroupMM.cu @@ -7,7 +7,7 @@ #include #include -// Two warninngs in Cutlass included header files +// Two warnings in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index d144a9954ed33..0ef6434f909de 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -457,7 +457,7 @@ __global__ void GammaBetaBackwardCUDAKernel2( } } - // Do warp reduce for the 2st 16 cols in the tile. + // Do warp reduce for the 2nd 16 cols in the tile. sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y]; sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y]; sum1 = cuda_utils::WarpReduceSum(sum1); diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 5c0cb1d534db1..fc788d7a0254e 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -1556,19 +1556,19 @@ NvrtcFunction jit_pwise_function( ss << '_' << hash_code; file_path = ss.str(); - std::ifstream readin{file_path, std::ios::in | std::ifstream::binary}; - if (readin.fail()) { + std::ifstream read_stream{file_path, std::ios::in | std::ifstream::binary}; + if (read_stream.fail()) { // NOTE: this does not warn because the file might not exist // TODO: consider if this should explicitly check for the file's existence or not to throw // an informative warning - readin.close(); + read_stream.close(); } else { // TODO: try passing the "mapped" file directly to cuModuleLoadCall instead of using an intermediate buffer - std::vector buffer(std::istreambuf_iterator(readin), {}); + std::vector buffer(std::istreambuf_iterator(read_stream), {}); AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), buffer.data())); AT_CUDA_DRIVER_CHECK( nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str())); - readin.close(); + read_stream.close(); return compiled_kernel_; } } diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 937008f1e83bd..6f5112c605fab 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -1049,7 +1049,7 @@ void launch_vectorized_layer_norm_kernel( C10_CUDA_KERNEL_LAUNCH_CHECK(); #ifdef USE_ROCM - // the blocks.x contains the max grid x dimention without invalid configuration error + // the blocks.x contains the max grid x dimension without invalid configuration error // Fix invalid configuration https://github.com/pytorch/pytorch/issues/136291 // Ensure all elements are processed. Prepare for next round int64_t remaining = M - blocks.x; diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 7604244997bcf..504688f203333 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -177,7 +177,7 @@ bool use_ragged_in_dense( TORCH_WARN_ONCE( "TORCH_CUDNN_SDPA_AVOID_RECOMPILE=1 only works with Q, K, V, and output in BSHD memory layout," "e.g., Q, K, V must be allocated with torch.randn((B, S, H, D).transpose(1, 2)." - "Falling back to regualr dense case, which may trigger excessive recompilation."); + "Falling back to regular dense case, which may trigger excessive recompilation."); } return all_bshd; } @@ -771,7 +771,7 @@ std::unique_ptr build_graph_nestedtensor( if (attn_bias.has_value()) { TORCH_CHECK( false, - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); + "attn_bias not yet supported with cuDNN Attention and NestedTensor"); scaled_dot_product_flash_attention_options.set_bias( mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(BIAS) @@ -1196,7 +1196,7 @@ std::unique_ptr build_graph_backward_nestedtensor( if (attn_bias.has_value()) { TORCH_CHECK( false, - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); + "attn_bias not yet supported with cuDNN Attention and NestedTensor"); sdpa_backward_options.set_bias( mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(BIAS) @@ -1864,7 +1864,7 @@ void run_cudnn_SDP_bprop_nestedtensor( } TORCH_CHECK( !attn_bias.has_value(), - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); + "attn_bias not yet supported with cuDNN Attention and NestedTensor"); auto workspace_size = mha_graph.get_workspace_size(); auto workspace_ptr = diff --git a/aten/src/ATen/native/hip/ck_bgemm_bfloat16.hip b/aten/src/ATen/native/hip/ck_bgemm_bfloat16.hip index 3872edb37f332..ea3bc875e0f19 100644 --- a/aten/src/ATen/native/hip/ck_bgemm_bfloat16.hip +++ b/aten/src/ATen/native/hip/ck_bgemm_bfloat16.hip @@ -30,7 +30,7 @@ static const std::unordered_map< }; -// This is the heursitic to choose a kernel based on inputs +// This is the heuristic to choose a kernel based on inputs BGEMMKernel_BFloat16 dispatch_bfloat16_bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { // Optional/future use: directly lookup shape tuples to map to instances /* diff --git a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip index 0050e8419e850..c223644e12920 100644 --- a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip +++ b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip @@ -11,7 +11,7 @@ using S = ck::Sequence; namespace at::native { void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { - // If any of the shapes cant be tiled, we must use padding. + // If any of the shapes can't be tiled, we must use padding. bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); // Dispatch to best implementation. // TODO add more configurations. Optimize. @@ -471,7 +471,7 @@ void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { } void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { - // If any of the shapes cant be tiled, we must use padding. + // If any of the shapes can't be tiled, we must use padding. bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); // Dispatch to best implementation. // TODO add more configurations. Optimize. diff --git a/aten/src/ATen/native/hip/ck_gemm_float.hip b/aten/src/ATen/native/hip/ck_gemm_float.hip index c4fea6088d3f0..16c796c5270e3 100644 --- a/aten/src/ATen/native/hip/ck_gemm_float.hip +++ b/aten/src/ATen/native/hip/ck_gemm_float.hip @@ -11,7 +11,7 @@ using S = ck::Sequence; namespace at::native { void dispatch_float_gemm(CUDABLAS_GEMM_ARGTYPES(float)) { - // If any of the shapes cant be tiled, we must use padding. + // If any of the shapes can't be tiled, we must use padding. bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); // Dispatch to best implementation. // TODO add more configurations. Optimize. diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip index 1b39283f9f944..75cbeec7c085f 100644 --- a/aten/src/ATen/native/hip/ck_gemm_half.hip +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -13,7 +13,7 @@ namespace at::native { void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { #if 0 - // If any of the shapes cant be tiled, we must use padding. + // If any of the shapes can't be tiled, we must use padding. bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); // Dispatch to best implementation. // TODO add more configurations. Optimize. @@ -299,7 +299,7 @@ void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { #endif } void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) { - // If any of the shapes cant be tiled, we must use padding. + // If any of the shapes can't be tiled, we must use padding. bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); // Dispatch to best implementation. // TODO add more configurations. Optimize. diff --git a/aten/src/ATen/native/metal/MetalShaders.h b/aten/src/ATen/native/metal/MetalShaders.h index 3fcc84173d396..81ea5daf3403b 100644 --- a/aten/src/ATen/native/metal/MetalShaders.h +++ b/aten/src/ATen/native/metal/MetalShaders.h @@ -545,7 +545,7 @@ kernel void reshape(texture2d_array in_arr[[texture(0), func const ushort slices2 = divRoundUp(C2, 4); const ushort slices1 = divRoundUp(C1, 4); const ushort n2 = gid.z / slices2; //image index - const ushort s2 = gid.z - n2 * slices2; // slice offest + const ushort s2 = gid.z - n2 * slices2; // slice offset half4 value; for (int idx = 0; idx < 4; ++idx){ // we compute the "linear index" of the output element, diff --git a/aten/src/ATen/native/metal/ops/MetalNeurons.mm b/aten/src/ATen/native/metal/ops/MetalNeurons.mm index 09944092f6a1c..4e928949ae4c4 100644 --- a/aten/src/ATen/native/metal/ops/MetalNeurons.mm +++ b/aten/src/ATen/native/metal/ops/MetalNeurons.mm @@ -86,4 +86,4 @@ static Tensor tanh(const Tensor& input) { m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), TORCH_FN(hardsigmoid_)); } -} // namepsace at::native::metal +} // namespace at::native::metal diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h index 49a249b5aea84..a5f084dba0be8 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h @@ -34,7 +34,7 @@ namespace at::native::onednn { /* oneDNN postops usage: - Currently, oneDNN supports 5 kinds of post ops. More details can be refered + Currently, oneDNN supports 5 kinds of post ops. More details can be referred to oneDNN doc. https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise @@ -399,7 +399,7 @@ static inline void construct_attr_for_unary( } else { TORCH_CHECK( unary_post_op == "none", - "onednn qlinear: unspported unary post op", + "onednn qlinear: unsupported unary post op", unary_post_op); } } diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 196d514a2c580..d5ed84aec5617 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -845,7 +845,7 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} break; } default: - TORCH_INTERNAL_ASSERT(false, "Unsupported number of paramaters ", nparams); + TORCH_INTERNAL_ASSERT(false, "Unsupported number of parameters ", nparams); } return libMap[key] = lib; } @@ -1173,9 +1173,9 @@ static dispatch_data_t getSectionData(const std::string& name) { } void MetalKernelFunction::dispatch(c10::ArrayRef length, c10::OptionalArrayRef group_size) { - TORCH_CHECK(!length.empty() && length.size() < 4, "Dispatch dimentions must be less than 3 and non-empty"); + TORCH_CHECK(!length.empty() && length.size() < 4, "Dispatch dimensions must be less than 3 and non-empty"); TORCH_CHECK(!group_size.has_value() || group_size->size() == length.size(), - "size and group_size must have same number of dimentions"); + "size and group_size must have same number of dimensions"); const auto max_tg_size = getMaxThreadsPerThreadgroup(); const auto group_size_length = group_size.has_value() ? group_size->size() : 0; auto tg_size = MTLSizeMake(group_size_length > 0 ? group_size->at(0) : max_tg_size, diff --git a/aten/src/ATen/native/mps/kernels/GridSampler.metal b/aten/src/ATen/native/mps/kernels/GridSampler.metal index 84bfbb57f8f03..fa66ff5e6a0b8 100644 --- a/aten/src/ATen/native/mps/kernels/GridSampler.metal +++ b/aten/src/ATen/native/mps/kernels/GridSampler.metal @@ -59,7 +59,7 @@ static GridSamplerOffsets find_grid_sampler_offsets( return offsets; } -// Mod function which gives postive output when `a` is negative +// Mod function which gives positive output when `a` is negative static int32_t mod(int32_t a, int32_t b) { auto r = a % b; return r + (r < 0 ? b : 0); @@ -191,9 +191,9 @@ void grid_sampler_single_element( int32_t right_indices[3]; opmath_t scales[3]; - // For each dimension, find the pair of indices in the cooresponding dimension + // For each dimension, find the pair of indices in the corresponding dimension // of `input` which surround the grid coordinate in that dimension. We'll do - // this by mapping different coordiante spaces onto each other. There are + // this by mapping different coordinate spaces onto each other. There are // basically three different coordinate spaces to keep in mind: // // * aligned grid space diff --git a/aten/src/ATen/native/mps/kernels/Indexing.metal b/aten/src/ATen/native/mps/kernels/Indexing.metal index ebe078d01781e..09fe380b4c2b3 100644 --- a/aten/src/ATen/native/mps/kernels/Indexing.metal +++ b/aten/src/ATen/native/mps/kernels/Indexing.metal @@ -178,7 +178,7 @@ kernel void index_put_serial( constant uint4& ndim_nindices_numel, device ErrorMessages* error_buffer, uint thread_index [[thread_position_in_grid]]) { - (void)thread_index; // Suppress unused vairable varning + (void)thread_index; // Suppress unused variable warning for (uint idx = 0; idx < ndim_nindices_numel.z; ++idx) { index_put_impl( output, diff --git a/aten/src/ATen/native/mps/kernels/Quantized.metal b/aten/src/ATen/native/mps/kernels/Quantized.metal index b84c033a07f49..a3f9a42457da5 100644 --- a/aten/src/ATen/native/mps/kernels/Quantized.metal +++ b/aten/src/ATen/native/mps/kernels/Quantized.metal @@ -112,7 +112,7 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]], constant uchar *B_ptr = B + ((n * K) / k_pack_factor); thread float4 result = float4(0.0); - // We multipy group of 4 channels with these scales. + // We multiply group of 4 channels with these scales. // Because corresponding values from weight matrix are effectively left // shifted. This is to avoid doing right shift on those values which ends up // affecting performance. This is the trick applied in MLX kernels. diff --git a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal index a6ec9d036dce3..3779d4be7b7bb 100644 --- a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal @@ -387,7 +387,7 @@ struct log1p_functor { } template inline enable_if_t, T> operator()(const T x) { - // TODO: Implement proper log1p algoirthm + // TODO: Implement proper log1p algorithm auto magnitude = ::precise::sqrt((1.0f + x.x) * (1.0f + x.x) + x.y * x.y); auto real = ::precise::log(magnitude); auto imag = (x.x == -1 && x.y == 0) ? 0 : ::precise::atan2(x.y, 1.0 + x.x); diff --git a/aten/src/ATen/native/mps/kernels/UpSample.metal b/aten/src/ATen/native/mps/kernels/UpSample.metal index 393c9e1b4d422..fa9b5a1bb107d 100644 --- a/aten/src/ATen/native/mps/kernels/UpSample.metal +++ b/aten/src/ATen/native/mps/kernels/UpSample.metal @@ -448,7 +448,7 @@ kernel void upsample_trilinear_backward( // See Note [ Weights computation for uint8_t and multiplication trick ] // Essentially fall back to fixed floating point arithmetic during uint8 -// interpolation, which is not necesserily more accurate (see example below), +// interpolation, which is not necessarily more accurate (see example below), // but matches closes to what CPU can deliver // I.e. mid-point 152+249+172+35 is 152, but algorithm yields 153 as horizontal // and vertical interpolation is done in separate steps and results are rounded diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 8956890a88b72..318bbb3728a85 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -42,7 +42,7 @@ Tensor pad_tensor_to_shape( const Tensor& t, IntArrayRef goal_shape, double value = 0) { - std::vector padd; + std::vector padding; auto tup = t.sizes(); TORCH_CHECK( t.dim() == (int64_t)(goal_shape.size()), @@ -52,10 +52,10 @@ Tensor pad_tensor_to_shape( goal_shape.size(), " of goal shape."); for (int64_t i = static_cast(tup.size()) - 1; i >= 0; i--) { - padd.push_back(0); - padd.push_back(goal_shape[i] - tup[i]); + padding.push_back(0); + padding.push_back(goal_shape[i] - tup[i]); } - Tensor new_tensor = at::constant_pad_nd(t, IntArrayRef(padd), value); + Tensor new_tensor = at::constant_pad_nd(t, IntArrayRef(padding), value); new_tensor = new_tensor.reshape(goal_shape); return new_tensor; } diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index c4a250db04836..7668a4bca228d 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -28,15 +28,16 @@ inp inps inpt inpts -matA -matB -matC +mata +matb +matc nd nin NotIn nout NowNs numer +OffsetT oH optins ot From f4dedf78fc30fd4b93975787ca6074ee89db9467 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 28 Nov 2025 20:12:50 +0800 Subject: [PATCH 058/338] [BE][5/5] fix typos in aten/ (aten/src/ATen/) (#157554) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157554 Approved by: https://github.com/yewentao256, https://github.com/albanD ghstack dependencies: #157553 --- .lintrunner.toml | 1 - aten/src/ATen/core/DimVector.h | 2 +- aten/src/ATen/core/GeneratorForPrivateuseone.cpp | 2 +- aten/src/ATen/core/IListRef.h | 4 ++-- aten/src/ATen/core/IListRef_inl.h | 2 +- aten/src/ATen/core/Variadic.h | 2 +- aten/src/ATen/core/dispatch/Dispatcher.cpp | 8 ++++---- aten/src/ATen/core/dispatch/Dispatcher.h | 5 +++-- aten/src/ATen/core/ivalue_inl.h | 2 +- aten/src/ATen/core/jit_type.h | 4 ++-- aten/src/ATen/cpu/vec/vec512/vec512_float8.h | 8 ++++---- aten/src/ATen/cuda/CUDAGreenContext.cpp | 2 +- aten/src/ATen/cuda/CUDASparseDescriptors.cpp | 2 +- aten/src/ATen/cuda/CachingHostAllocator.h | 4 ++-- aten/src/ATen/cuda/detail/TensorInfo.cuh | 2 +- aten/src/ATen/cuda/jiterator.cu | 2 +- aten/src/ATen/functorch/LegacyVmapTransforms.h | 2 +- aten/src/ATen/functorch/TensorWrapper.h | 2 +- 18 files changed, 28 insertions(+), 28 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index cd5d338b63639..8b577c47e0d0a 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1113,7 +1113,6 @@ exclude_patterns = [ # These files are all grandfathered in, feel free to remove from this list # as necessary # NOTE: remove the patterns in the order they are listed - 'aten/src/ATen/[a-mA-M]*/**', 'test/**', ] init_command = [ diff --git a/aten/src/ATen/core/DimVector.h b/aten/src/ATen/core/DimVector.h index 576b9e142ebf1..aadb3fa867f4a 100644 --- a/aten/src/ATen/core/DimVector.h +++ b/aten/src/ATen/core/DimVector.h @@ -3,7 +3,7 @@ namespace at { -// Re-declaring 'DimVector' type and size inside 'at' namespace. +// Redeclaring 'DimVector' type and size inside 'at' namespace. // This is done to avoid modifying every use into their 'c10' // equivalent. diff --git a/aten/src/ATen/core/GeneratorForPrivateuseone.cpp b/aten/src/ATen/core/GeneratorForPrivateuseone.cpp index 030e9f70851a6..7dca153436dbf 100644 --- a/aten/src/ATen/core/GeneratorForPrivateuseone.cpp +++ b/aten/src/ATen/core/GeneratorForPrivateuseone.cpp @@ -16,7 +16,7 @@ _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) { TORCH_WARN_DEPRECATION( "REGISTER_GENERATOR_PRIVATEUSE1 is deprecated. \ - Please derive PrivateUse1HooksInterface to implememt getNewGenerator instead.") + Please derive PrivateUse1HooksInterface to implement getNewGenerator instead.") TORCH_CHECK( !GetGeneratorPrivate().has_value(), diff --git a/aten/src/ATen/core/IListRef.h b/aten/src/ATen/core/IListRef.h index a11a78c03a3bb..8ea6249f2b699 100644 --- a/aten/src/ATen/core/IListRef.h +++ b/aten/src/ATen/core/IListRef.h @@ -149,7 +149,7 @@ * First, keep in mind that we assume that boxed containers will * have to deal with `IValue` (e.g. `c10::List`). In this context, * what may be happening is that `IValue` doesn't store internally - * your type `T`. Instead, it constructs a type new `T` everytime + * your type `T`. Instead, it constructs a type new `T` every time * you try to get `T` for it (see `IListRef`). */ @@ -186,7 +186,7 @@ class IListRef; * This macro is useful because it allows us to handle different * types (that correspond to different tags) to be implemented * only once. We can do it even when the implementation of the - * different tags aren't syntatically the same, by dispatching + * different tags aren't syntactically the same, by dispatching * it to a function (e.g. `ImplT::(this_)`). */ #define TORCH_ILISTREF_UNWRAP(TAG, BODY) \ diff --git a/aten/src/ATen/core/IListRef_inl.h b/aten/src/ATen/core/IListRef_inl.h index df320c13d9c23..425a80a710f6b 100644 --- a/aten/src/ATen/core/IListRef_inl.h +++ b/aten/src/ATen/core/IListRef_inl.h @@ -42,7 +42,7 @@ class IListRefTagImplBase { /* * We have these function (besides the `unwrap`s above) because the * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*` - * weren't syntatically equal for the existing tags at the time + * weren't syntactically equal for the existing tags at the time * (`Unboxed` and `Boxed`). */ static IListRefConstRef front(const list_type& lst) { diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h index da4df1b1b1a66..f594deb566547 100644 --- a/aten/src/ATen/core/Variadic.h +++ b/aten/src/ATen/core/Variadic.h @@ -12,7 +12,7 @@ namespace at { // in order. This is most commonly used in autogenerated code, // where it is convenient to have a function that can uniformly // take arguments of different types. If your arguments -// are homogenous consider using a std::initializer_list instead. +// are homogeneous consider using a std::initializer_list instead. // // For examples of this in use, see torch/csrc/utils/variadic.h template diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 5facca30a54f3..1291b4d3c3227 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -111,7 +111,7 @@ void Dispatcher::waitForDef(const FunctionSchema& schema) { TORCH_INTERNAL_ASSERT(r, "Expected main interpreter to define ", schema.operator_name(), ", but this didn't happen within timeout. Are you trying to load " - "different models in the same torchdeploy/multipy instance? You " + "different models in the same torchdeploy/multipy instance? You " // codespell:ignore "must warmup each interpreter identically, e.g., import all " "the same dependencies."); } @@ -129,7 +129,7 @@ void Dispatcher::waitForImpl(const OperatorName& op_name, std::optional= 0 && static_cast(idx) < backendFallbackKernels_.size(), "idx=", idx); - // NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time, - // refer to https://github.com/pytorch/pytorch/issues/163979 for more informations. + // NB: Preserve BC for registering fallback for AutogradPrivateUse1 multiple time, + // refer to https://github.com/pytorch/pytorch/issues/163979 for more information. TORCH_CHECK( dispatchKey == DispatchKey::AutogradPrivateUse1 || !backendFallbackKernels_[idx].kernel.isValid(), diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 880de786b708d..6b63bd48009ee 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -222,7 +222,8 @@ class TORCH_API Dispatcher final { return backendFallbackKernels_[dispatch_ix].kernel.isValid(); } - // Used by torchdeploy/multipy for multiple interpreters racing. + // Used by torchdeploy/multipy for multiple // codespell:ignore: multipy + // interpreters racing. void waitForDef(const FunctionSchema& schema); void waitForImpl( const OperatorName& op_name, @@ -414,7 +415,7 @@ class TORCH_API Dispatcher final { std::unique_ptr listeners_; // This condition variable gets notified whenever we add a new def/impl to the - // dispatch table. This is primarily used by multipy/torchdeploy, when + // dispatch table. This is primarily used by multiply/torchdeploy, when // we have multiple interpreters trying to register to the dispatch table. // In this situation, whenever the non-primary interpreter would have tried // to register to the dispatch table, instead it will check to see if the diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index ac7540cffd18f..f384a3ea46f28 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -992,7 +992,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { std::unique_lock lock(mutex_); if (completed_) { // This should be rare and shouldn't cause log spew. Its important to - // log errors and thats why we have this log here. + // log errors and that's why we have this log here. std::string msg = c10::str( "Skipping setting following error on the Future since " "it is already marked completed (this is not necessarily " diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 535831ea11d6e..5378bd0b3d14b 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -887,7 +887,7 @@ struct TORCH_API ListType // this function will return the global singleton type pointer // the type List. // The extra "identifier" argument is needed because we have multiple container types - // that all re-use this function (List, array, etc.) + // that all reuse this function (List, array, etc.) static TypePtr get(const std::string& identifier, TypePtr inner); // common cast List[Tensor] @@ -985,7 +985,7 @@ struct TORCH_API DictType : public SharedType { // this function will return the global singleton type pointer // the type List. // The extra "identifier" argument is needed because we have multiple container types - // that all re-use this function (Dict and unordered_map) + // that all reuse this function (Dict and unordered_map) static TypePtr get(const std::string& identifier, TypePtr key, TypePtr val); private: diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float8.h b/aten/src/ATen/cpu/vec/vec512/vec512_float8.h index 12ee4c460641f..0a54986d82b78 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float8.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float8.h @@ -498,8 +498,8 @@ static inline Vectorized binary_fp8_op_as_fp32( // Refer to // https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, -// -, *, /, planed to be deleted in the future and here is just to make compiler -// happy +// -, *, /, planned to be deleted in the future and here is just to make +// compiler happy Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { @@ -585,8 +585,8 @@ class Vectorized : public Vectorizedf8 { // Refer to // https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, -// -, *, /, planed to be deleted in the future and here is just to make compiler -// happy +// -, *, /, planned to be deleted in the future and here is just to make +// compiler happy Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cuda/CUDAGreenContext.cpp b/aten/src/ATen/cuda/CUDAGreenContext.cpp index 8aa05b80f82f9..a579e45e16066 100644 --- a/aten/src/ATen/cuda/CUDAGreenContext.cpp +++ b/aten/src/ATen/cuda/CUDAGreenContext.cpp @@ -7,7 +7,7 @@ #define HAS_CUDA_GREEN_CONTEXT() 1 #else #define HAS_CUDA_GREEN_CONTEXT() 0 -// Suppress unsued private field warnings as this class is not supposed to be called +// Suppress unused private field warnings as this class is not supposed to be called C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field") #endif diff --git a/aten/src/ATen/cuda/CUDASparseDescriptors.cpp b/aten/src/ATen/cuda/CUDASparseDescriptors.cpp index d5f04df55f9c2..c7ab4fbfc95df 100644 --- a/aten/src/ATen/cuda/CUDASparseDescriptors.cpp +++ b/aten/src/ATen/cuda/CUDASparseDescriptors.cpp @@ -179,7 +179,7 @@ CuSparseSpMatCsrDescriptor::CuSparseSpMatCsrDescriptor(const Tensor& input, int6 batch_offset * values_batch_stride * values.itemsize(), index_type, // data type of row offsets index index_type, // data type of col indices - CUSPARSE_INDEX_BASE_ZERO, // base index of row offset and col indes + CUSPARSE_INDEX_BASE_ZERO, // base index of row offset and col index value_type // data type of values )); diff --git a/aten/src/ATen/cuda/CachingHostAllocator.h b/aten/src/ATen/cuda/CachingHostAllocator.h index b9486314b1c21..53b0cdced4c18 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.h +++ b/aten/src/ATen/cuda/CachingHostAllocator.h @@ -10,7 +10,7 @@ namespace at::cuda { // // A caching allocator for CUDA host allocations (pinned memory). // -// This provides a drop-in replacement for THCudaHostAllocator, which re-uses +// This provides a drop-in replacement for THCudaHostAllocator, which reuses // freed pinned (page-locked) memory allocations. This avoids device // synchronizations due to cudaFreeHost calls. // @@ -26,7 +26,7 @@ inline TORCH_CUDA_CPP_API at::HostAllocator* getCachingHostAllocator() { } // Records an event in the specified stream. The allocation corresponding to the -// input `ptr`/`ctx` will not be re-used until the event has occurred. +// input `ptr`/`ctx` will not be reused until the event has occurred. C10_DEPRECATED_MESSAGE( "at::cuda::CachingHostAllocator_recordEvent(...) is deprecated. Please use at::getHostAllocator(at::kCUDA)->record_event(...) instead.") inline TORCH_CUDA_CPP_API bool CachingHostAllocator_recordEvent( diff --git a/aten/src/ATen/cuda/detail/TensorInfo.cuh b/aten/src/ATen/cuda/detail/TensorInfo.cuh index a320000ae881f..9f3f7d31add5c 100644 --- a/aten/src/ATen/cuda/detail/TensorInfo.cuh +++ b/aten/src/ATen/cuda/detail/TensorInfo.cuh @@ -93,7 +93,7 @@ struct IndexToOffset { } }; -// Uses dynamic (runtime) instead of static (compiletime) dims +// Uses dynamic (runtime) instead of static (compile time) dims template struct IndexToOffset { static inline __host__ __device__ IndexType get( diff --git a/aten/src/ATen/cuda/jiterator.cu b/aten/src/ATen/cuda/jiterator.cu index d664c828bdad6..0545c8354eda3 100644 --- a/aten/src/ATen/cuda/jiterator.cu +++ b/aten/src/ATen/cuda/jiterator.cu @@ -32,7 +32,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic( // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) // fn_ptr is set to the appropriate function based on the vec size and GPU used - // TODO: Memory use can probably be optimized by re-using kernels across GPUs with + // TODO: Memory use can probably be optimized by reusing kernels across GPUs with // the same compute capability std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype); diff --git a/aten/src/ATen/functorch/LegacyVmapTransforms.h b/aten/src/ATen/functorch/LegacyVmapTransforms.h index 390989d45bf73..bf21951f22268 100644 --- a/aten/src/ATen/functorch/LegacyVmapTransforms.h +++ b/aten/src/ATen/functorch/LegacyVmapTransforms.h @@ -143,7 +143,7 @@ struct TORCH_API VmapPhysicalView { // mapping a physical tensor to a new logical tensor (BatchedTensor) VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; - // Maps a logical shape to a physical shape by pre-pending the batch + // Maps a logical shape to a physical shape by prepending the batch // sizes to the logical shape. VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const; diff --git a/aten/src/ATen/functorch/TensorWrapper.h b/aten/src/ATen/functorch/TensorWrapper.h index bf7b14fd41689..281682fa8bc0a 100644 --- a/aten/src/ATen/functorch/TensorWrapper.h +++ b/aten/src/ATen/functorch/TensorWrapper.h @@ -27,7 +27,7 @@ namespace at::functorch { // // There are alternative designs we could have chosen (e.g. each grad transform // stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper -// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels) +// design is that we can reuse existing VariableType kernels (i.e. Autograd kernels) // without much modification. Since a TensorWrapper looks like a regular Tensor, // the VariableType kernel can pull out the AutogradMeta struct from where it // expects and extend the autograd graph From 6864e309092a71f8ab0ca6a4dc7f8a4073fd31c4 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 28 Nov 2025 20:12:51 +0800 Subject: [PATCH 059/338] [BE][1/6] fix typos in test/ (#157635) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157635 Approved by: https://github.com/yewentao256, https://github.com/albanD ghstack dependencies: #157553, #157554 --- .lintrunner.toml | 4 - .../libtorch_agnostic_2_10/ops.py | 2 +- test/distributed/_pycute/test_complement.py | 2 +- .../fsdp/test_distributed_checkpoint.py | 2 +- test/distributed/tensor/test_random_ops.py | 2 +- .../test_aten_comm_compute_reordering.py | 2 +- test/distributed/test_inductor_collectives.py | 2 +- test/fx/test_partitioner_order.py | 16 +-- test/higher_order_ops/test_invoke_subgraph.py | 2 +- test/inductor/test_flex_attention.py | 4 +- test/inductor/test_fxir_backend.py | 2 +- test/inductor/test_memory.py | 6 +- test/inductor/test_provenance_tracing.py | 2 +- test/inductor/test_unbacked_symints.py | 2 +- test/jit/fixtures_srcs/generate_models.py | 2 +- test/jit/test_class_type.py | 2 +- test/jit/test_freezing.py | 4 +- test/jit/test_list_dict.py | 14 ++- test/jit/test_peephole.py | 4 +- test/jit/test_remove_mutation.py | 2 +- test/jit/test_symbolic_shape_analysis.py | 8 +- test/mobile/model_test/gen_test_model.py | 2 +- .../model_test/update_production_ops.py | 6 +- test/nn/test_convolution.py | 4 +- test/nn/test_module_hooks.py | 2 +- test/nn/test_parametrization.py | 4 +- test/nn/test_pruning.py | 2 +- test/onnx/test_op_consistency.py | 2 +- test/onnx/test_pytorch_jit_onnx.py | 2 +- test/onnx/test_pytorch_onnx_onnxruntime.py | 10 +- test/package/test_glob_group.py | 7 +- test/package/test_model.py | 2 +- test/package/test_package_script.py | 2 +- test/package/test_save_load.py | 3 +- test/profiler/test_profiler.py | 4 +- test/quantization/core/test_quantized_op.py | 8 +- .../quantization/core/test_workflow_module.py | 6 +- .../eager/test_quantize_eager_qat.py | 2 +- test/quantization/fx/test_model_report_fx.py | 16 +-- test/quantization/fx/test_quantize_fx.py | 10 +- test/quantization/jit/test_quantize_jit.py | 12 +- .../pt2e/test_quantize_pt2e_qat.py | 2 +- test/run_doctests.sh | 2 +- test/run_test.py | 4 +- test/scripts/cuda_memcheck_common.py | 2 +- test/test_mps.py | 10 +- test/test_ops.py | 2 +- test/torch_np/numpy_tests/core/test_dtype.py | 4 +- test/torch_np/numpy_tests/core/test_einsum.py | 6 +- .../numpy_tests/core/test_indexing.py | 114 ++++++++++-------- .../numpy_tests/core/test_multiarray.py | 16 +-- .../numpy_tests/lib/test_histograms.py | 2 +- .../numpy_tests/lib/test_index_tricks.py | 2 +- test/torch_np/test_ndarray_methods.py | 8 +- 54 files changed, 191 insertions(+), 175 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 8b577c47e0d0a..9b4c68070571c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1110,10 +1110,6 @@ exclude_patterns = [ 'torch/_inductor/fx_passes/serialized_patterns/**', 'torch/_inductor/autoheuristic/artifacts/**', 'torch/utils/model_dump/preact.mjs', - # These files are all grandfathered in, feel free to remove from this list - # as necessary - # NOTE: remove the patterns in the order they are listed - 'test/**', ] init_command = [ 'python3', diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index b68839dc565c7..d53e481ca4a10 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -210,7 +210,7 @@ def my_shape(t) -> tuple[int]: Args: t: Tensor - input tensor - Returns: tuple - shape of the imput tensor. + Returns: tuple - shape of the input tensor. """ return torch.ops.libtorch_agnostic_2_10.my_shape.default(t) diff --git a/test/distributed/_pycute/test_complement.py b/test/distributed/_pycute/test_complement.py index fd6413bcd112e..e54364732f049 100644 --- a/test/distributed/_pycute/test_complement.py +++ b/test/distributed/_pycute/test_complement.py @@ -52,7 +52,7 @@ def helper_test_complement(self, layout): _LOGGER.debug(f"{layout} => {layoutR}") - # Post-condition: test disjointness of the codomains + # Post-condition: test disjointedness of the codomains for a in range(size(layout)): for b in range(size(layoutR)): assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0) diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index 67f8e1af9abbd..0885e70141e78 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -30,7 +30,7 @@ ) sys.exit(0) -# NB: this iterable needs to be orderd as otherwise different ranks may run with +# NB: this iterable needs to be ordered as otherwise different ranks may run with # conflicting settings when e.g., @parametrize(_DISTRIBUTED_STATE_DICT_IMPLS) is # used to decorate tests _DISTRIBUTED_STATE_DICT_IMPLS = ( diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 4bcddc198836b..15c9be4485379 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -304,7 +304,7 @@ def test_rng_tracker_init(self): + torch.initial_seed() ) torch.distributed.broadcast(seed_local, src=0) - # if localtensor, it should automaticall reconcile after the broadcast + # if local tensor, it should automatically reconcile after the broadcast # since all virtual ranks should have rank 0's initial_seed() seed_from_rank_0 = seed_local diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 0e76da0dbe9c0..966f84ff0ee56 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -397,7 +397,7 @@ def fn(g1, g2, g3): self.rank, self.world_size, self.backend(device_type), fake_pg=True ): # all_reduces remain in order! - # note: this isnt actually invariant of pass currently.. + # note: this isn't actually invariant of pass currently.. # but we should keep collectives stable without reordering opportunities _, code = run_and_get_aten_graph(fn, g1, g2, g3) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 52062616a8562..4be02cbafbe1f 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1680,7 +1680,7 @@ def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): compiled = torch.compile(func) code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) - # shouldnt have bucketed + # shouldn't have bucketed FileCheck().check_count("wait_tensor.default(", 2, exactly=True).run(code) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index f4c3ef072f9a6..670f675f3f94d 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -33,17 +33,17 @@ def forward(self, x): class TestPartitionerOrder(TestCase): - # partitoner test to check graph node order remains the same with the original graph after partitioning + # partitioner test to check graph node order remains the same with the original graph after partitioning def test_partitioner_graph_node_order(self): m = AddModule() traced_m = torch.fx.symbolic_trace(m) origin_node_order = [n.name for n in traced_m.graph.nodes] - partions = DummyPartitioner(traced_m).propose_partitions() - partion_nodes = [list(partition.nodes) for partition in partions] - partition_node_order = [n.name for n in partion_nodes[0]] + partitions = DummyPartitioner(traced_m).propose_partitions() + partition_nodes = [list(partition.nodes) for partition in partitions] + partition_node_order = [n.name for n in partition_nodes[0]] self.assertTrue(partition_node_order == origin_node_order) - # partitoner test to check graph node order remains the same during multiple runs + # partitioner test to check graph node order remains the same during multiple runs def test_partitioner_multiple_runs_order(self): m = AddModule() traced_m = torch.fx.symbolic_trace(m) @@ -52,9 +52,9 @@ def test_partitioner_multiple_runs_order(self): node_order = [n.name for n in partition_nodes[0]] for _ in range(10): traced_m = torch.fx.symbolic_trace(m) - new_partion = DummyPartitioner(traced_m).propose_partitions() - new_partion_nodes = [list(partition.nodes) for partition in new_partion] - new_node_order = [n.name for n in new_partion_nodes[0]] + new_partition = DummyPartitioner(traced_m).propose_partitions() + new_partition_nodes = [list(partition.nodes) for partition in new_partition] + new_node_order = [n.name for n in new_partition_nodes[0]] self.assertTrue(node_order == new_node_order) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 00cb0e7b8b21a..c8a4ac1b67a84 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -2597,7 +2597,7 @@ def forward(self, l_x_: "f32[8, 8]", l_y_: "f32[8, 8]"): """, ) - # High piority - grads are wrong + # High priority - grads are wrong @unittest.expectedFailure def test_grad_accuracy_check(self): class Foo: diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index c095243df7654..13cd35fc67735 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -2290,8 +2290,8 @@ def run(q, k, v): def _opaque_mask(b, h, q_idx, kv_idx): ref = ql // frame - mot = kl // frame - limit = (ref + mot) * frame + mot = kl // frame # codespell:ignore + limit = (ref + mot) * frame # codespell:ignore return q_idx < limit block_mask = create_block_mask( diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index 2c232594f3329..8f443cd43edcc 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -516,7 +516,7 @@ def test_dynamic_shapes_precomputed_size(self): def test_dynamic_launch_grid_calc(self): """ - Test the dyanmic launch grid calculation. + Test the dynamic launch grid calculation. """ func = torch.add diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 2bb3cf9d66432..1efcd546720a0 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -242,9 +242,9 @@ def reorder_with_only_dfs( @mock.patch.object(config, "allow_buffer_reuse", False) @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available") @config.patch("test_configs.track_memory_lifecycle", "assert") - def test_mutation_size_propogation(self): + def test_mutation_size_propagation(self): """ - This tests correct size propogation in the case of mutations. + This tests correct size propagation in the case of mutations. In this example, buf1 is a mutation of buf0; we should have: * buf0: has size_alloc 2048 and size_free 0; * buf1: has size_alloc 0 and size_free 2048. @@ -444,7 +444,7 @@ def replace_foreach(gm): "allow_buffer_reuse": False, # make sure the mm is at the end so # the earlier deallocation is not at the last step, - # which doesnt distinguish between returned tensors + # which doesn't distinguish between returned tensors # and which tensors are deallocated immediately prior "reorder_for_peak_memory": False, } diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index 3fd27cc02b006..93397b6eae072 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -480,7 +480,7 @@ def get_node_with_target(self, gm, target): @requires_gpu_and_triton # test only works for cuda pattern matcher def test_pattern_matcher_transfer_meta(self): """ - Test that stack trace is transfered when node is decomposed in post_grad_passes + Test that stack trace is transferred when node is decomposed in post_grad_passes """ class Model(torch.nn.Module): diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 2574d2210da60..04c8c0573e99d 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -650,7 +650,7 @@ def fn(x): torch.testing.assert_close(actual, expected) @skipIfXpu( - msg="Invalid SPIR-V modul,https://github.com/intel/torch-xpu-ops/issues/2329" + msg="Invalid SPIR-V module,https://github.com/intel/torch-xpu-ops/issues/2329" ) @skipGPUIf(not HAS_GPU, "requires gpu and triton") @inductor_config.patch({"max_autotune": True}) diff --git a/test/jit/fixtures_srcs/generate_models.py b/test/jit/fixtures_srcs/generate_models.py index 233295cf8b4b9..6935d64cf23bd 100644 --- a/test/jit/fixtures_srcs/generate_models.py +++ b/test/jit/fixtures_srcs/generate_models.py @@ -173,7 +173,7 @@ def get_output_model_version(script_module: torch.nn.Module) -> int: Loop through all test modules. If the corresponding model doesn't exist in `test/jit/fixtures`, generate one. For the following reason, a model won't be exported: -1. The test module doens't cover the changed operator. For example, test_versioned_div_tensor_example_v4 +1. The test module doesn't cover the changed operator. For example, test_versioned_div_tensor_example_v4 is supposed to test the operator aten::div.Tensor. If the model doesn't include this operator, it will fail. The error message includes the actual operator list from the model. diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index 0ae1c3dcfd307..4b5f2ad9a0d77 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -1534,7 +1534,7 @@ def forward(self): def test_class_attribute_wrong_type(self): """ - Test that the error message displayed when convering a class type + Test that the error message displayed when converting a class type to an IValue that has an attribute of the wrong type. """ diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 91ecf6f3629b2..6e13b1d14a58d 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -3321,7 +3321,7 @@ def forward(self, x): scripted = torch.jit.freeze(torch.jit.script(mod)) optimized = torch.jit.optimize_for_inference(scripted) inp = torch.rand([1, 8, 8, 8]) - # a1 cant be inplaced for first use, can for second + # a1 can't be inplaced for first use, can for second FileCheck().check("ScalarMul(").check("ScalarMul_").run(optimized.graph) self.assertEqual(optimized(inp), mod(inp)) @@ -3413,7 +3413,7 @@ def __init__(self, tensor): def forward(self, x): # x can't be inplaced because its a return value, - # check that the inplacing pass doesnt try to inplace + # check that the inplacing pass doesn't try to inplace # self.tensor because its always alive return x * self.tensor, x diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 90dbc30d5d790..1949ec46557dd 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -1773,7 +1773,7 @@ def setdefault( return x self.checkScript(setdefault, (self.dict(), "a", torch.randn(2, 2))) - self.checkScript(setdefault, (self.dict(), "nonexistant", torch.randn(2, 2))) + self.checkScript(setdefault, (self.dict(), "nonexistent", torch.randn(2, 2))) @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_update(self): @@ -1894,9 +1894,13 @@ def type_default() -> Dict[str, Tensor]: @torch.jit.script def missing_index(x: Dict[str, int]) -> int: - return x["dne"] + return x["dne"] # codespell:ignore - with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", 'x["dne"'): + with self.assertRaisesRegexWithHighlight( + RuntimeError, + "KeyError", + 'x["dne"', # codespell:ignore + ): missing_index({"item": 20, "other_item": 120}) code = dedent( @@ -2368,7 +2372,7 @@ class TestScriptDict(JitTestCase): The vast majority of tests are for making sure that objects returned by torch.jit.script behave like dictionaries do so that they are fungible - in almost all cirumstances with regular dictionaries. + in almost all circumstances with regular dictionaries. """ def _script_dict_add(self, d: torch._C.ScriptDict, k: int, v: int): @@ -2605,7 +2609,7 @@ class TestScriptList(JitTestCase): The vast majority of tests are for making sure that instances of torch._C.ScriptList behave like lists do so that they are fungible - in almost all cirumstances with regular list. + in almost all circumstances with regular list. """ def _script_list_add(self, l: torch._C.ScriptList, e: int): diff --git a/test/jit/test_peephole.py b/test/jit/test_peephole.py index 12b9c3f18348a..61c443fc1b659 100644 --- a/test/jit/test_peephole.py +++ b/test/jit/test_peephole.py @@ -360,7 +360,7 @@ def foo(x: List[int], b: List[int]): torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) torch._C._jit_pass_constant_propagation(foo.graph) - # cant infer anything + # can't infer anything test_const_tuple_output(foo.graph, []) @torch.jit.script @@ -374,7 +374,7 @@ def foo(x: List[int], b: List[int]): torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) torch._C._jit_pass_constant_propagation(foo.graph) - # we cant infer anything, only len(b) != 4 + # we can't infer anything, only len(b) != 4 test_const_tuple_output(foo.graph, []) @torch.jit.script diff --git a/test/jit/test_remove_mutation.py b/test/jit/test_remove_mutation.py index 3250a86f80453..31230e522b2a9 100644 --- a/test/jit/test_remove_mutation.py +++ b/test/jit/test_remove_mutation.py @@ -292,7 +292,7 @@ def forward(self): FileCheck().check_not("aten::add_").run(mod_script.forward.graph) self.assertEqual(mod(), mod_script()) - # test that the output doesnt alias the input + # test that the output doesn't alias the input for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]: result = torch_op(inputs) sums = [ten.sum() for ten in result] diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py index 702fdd851954c..ad1f4fc7a157a 100644 --- a/test/jit/test_symbolic_shape_analysis.py +++ b/test/jit/test_symbolic_shape_analysis.py @@ -85,7 +85,7 @@ def test_write(self): def foo(a, b): return a * b - # broadcast appends cant be removed, so we bail on propagation + # broadcast appends can't be removed, so we bail on propagation torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) FileCheck().check("Tensor = aten::mul").run(foo.graph) @@ -521,7 +521,7 @@ def test_returning_input_symbolic_shapes(self): torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph) ) g = shape_compute_graph.partial_eval_shape_graph() - # to make into a jit function cant have multiple outputs + # to make into a jit function can't have multiple outputs g.makeMultiOutputIntoTuple() func = torch._C._create_function_from_graph("partial_eval_graph", g) out = func([20, 16, 5, 10]) @@ -543,7 +543,7 @@ def test_partial_eval_graph_conv(self): self.assertTrue(output_sizes[i] < 0) self.assertTrue(output_sizes[1] >= 0) g = shape_compute_graph.partial_eval_shape_graph() - # to make into a jit function cant have multiple outputs + # to make into a jit function can't have multiple outputs g.makeMultiOutputIntoTuple() func = torch._C._create_function_from_graph("partial_eval_graph", g) inp = torch.randn(20, 16, 5, 10) @@ -667,7 +667,7 @@ def test_stitching_multi_output(self): outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes() ) g = shape_compute_graph.partial_eval_shape_graph() - # to make into a jit function cant have multiple outputs + # to make into a jit function can't have multiple outputs g.makeMultiOutputIntoTuple() func = torch._C._create_function_from_graph("partial_eval_graph", g) mapping = shape_compute_graph.graph_output_to_symbolic_shape_dim() # noqa: F841 diff --git a/test/mobile/model_test/gen_test_model.py b/test/mobile/model_test/gen_test_model.py index 5e760a739cec7..680e01ba27c70 100644 --- a/test/mobile/model_test/gen_test_model.py +++ b/test/mobile/model_test/gen_test_model.py @@ -92,7 +92,7 @@ # "dynamic_quant_ops": DynamicQuantModule(), "static_quant_ops": StaticQuantModule(), "fused_quant_ops": FusedQuantModule(), - # TorchScript buildin ops + # TorchScript builtin ops "torchscript_builtin_ops": TSBuiltinOpsModule(), "torchscript_collection_ops": TSCollectionOpsModule(), # vision diff --git a/test/mobile/model_test/update_production_ops.py b/test/mobile/model_test/update_production_ops.py index dbec56e64261a..7879403b90bbb 100644 --- a/test/mobile/model_test/update_production_ops.py +++ b/test/mobile/model_test/update_production_ops.py @@ -22,9 +22,9 @@ # aggregate occurrence per op traced_operators[op] = 1 + (traced_operators.get(op, 0)) # merge dtypes for each kernel - for kernal, dtypes in info["kernel_metadata"].items(): - new_dtypes = dtypes + (kernel_metadata.get(kernal, [])) - kernel_metadata[kernal] = list(set(new_dtypes)) + for kernel, dtypes in info["kernel_metadata"].items(): + new_dtypes = dtypes + (kernel_metadata.get(kernel, [])) + kernel_metadata[kernel] = list(set(new_dtypes)) # Only test these built-in ops. No custom ops or non-CPU ops. diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 83f4d0ccc9600..b92137ca3430e 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -1063,7 +1063,7 @@ def test_grouped_conv_cudnn_nhwc_support(self): @unittest.skipIf(not TEST_CUDNN, "needs cudnn") def test_conv_cudnn_memory_layout_dominance(self): # desired behavior here is to have the memory_layout of conv.weight to - # dominant the layout of output. + # dominate the layout of output. # which is not the same as current behavior, we'll fix this in # following up PRs and remove the `expectedFailure` tag input = torch.randint( @@ -3659,7 +3659,7 @@ def helper( input_format=input_format, weight_format=weight_format, ) - # test when input channel is 1 and not converted to channels last + # test when input channels is 1 and not converted to channels last helper( nn.Conv2d, 2, diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index 4e8821656b7e1..aedb1d343c0ae 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -1529,7 +1529,7 @@ def hook_pre(mod, grad_output): ): mod(inp.clone(), True) - # Input inplace error should throw an error if we try to re-use the view after they have + # Input inplace error should throw an error if we try to reuse the view after they have # been modified local_inp = inp.clone() out = mod(local_inp, False) diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py index aee8d4df50e6e..5dca91f0d2c80 100644 --- a/test/nn/test_parametrization.py +++ b/test/nn/test_parametrization.py @@ -199,9 +199,7 @@ def forward(self, x): self.assertTrue(parametrize.is_parametrized(model, "bias")) self.assertEqual(model.bias[0].item(), 0.0) self.assertEqual(model.bias[-1].item(), 0.0) - self.assertEqual( - len(list(model.parameters())), 2 - ) # Nothing weird has happpened + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened # Should not throw sgd = torch.optim.SGD(model.parameters(), lr=0.01) diff --git a/test/nn/test_pruning.py b/test/nn/test_pruning.py index 51078cbcf64fb..451eae8e4a418 100644 --- a/test/nn/test_pruning.py +++ b/test/nn/test_pruning.py @@ -498,7 +498,7 @@ def test_l1_unstructured_pruning_with_importance_scores(self): def test_unstructured_pruning_same_magnitude(self): r"""Since it may happen that the tensor to prune has entries with the same exact magnitude, it is important to check that pruning happens - consistenly based on the bottom % of weights, and not by threshold, + consistently based on the bottom % of weights, and not by threshold, which would instead kill off *all* units with magnitude = threshold. """ AMOUNT = 0.2 diff --git a/test/onnx/test_op_consistency.py b/test/onnx/test_op_consistency.py index 762279b71d851..ee4742b25498e 100644 --- a/test/onnx/test_op_consistency.py +++ b/test/onnx/test_op_consistency.py @@ -192,7 +192,7 @@ "scatter_reduce", # ONNX has not include_self parameter and default is include_self=True mode matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", + reason="ONNX doesn't support include_self=False option", ), skip( "stft", diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py index bc3c64ab8679b..1a9c78195afd8 100644 --- a/test/onnx/test_pytorch_jit_onnx.py +++ b/test/onnx/test_pytorch_jit_onnx.py @@ -55,7 +55,7 @@ class _TestJITIRToONNX: ort_providers = ["CPUExecutionProvider"] check_shape = True check_dtype = True - ignore_none = True # True for tracing, and Flase for scripting + ignore_none = True # True for tracing, and False for scripting def run_test(self, graph_ir, example_inputs, parse_tensor_constants=False): graph = torch._C.parse_ir(graph_ir, parse_tensor_constants) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 2e96f70cf56f2..5394ba762c9fa 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -101,7 +101,7 @@ def _construct_tensor_for_quantization_test( """Helper function to generate weights and test inputs in a deterministic way. Due to difference in implementation details between PyTorch and ONNXRuntime, randomly generated - test data for quantization tests can be flaky. To help stablize the test, this helper function is + test data for quantization tests can be flaky. To help stabilize the test, this helper function is used to generate weights and test inputs in a deterministic way. Args: @@ -6697,7 +6697,7 @@ def forward(self, x, y): @skipIfUnsupportedMinOpsetVersion(9) def test_new_empty(self): - class Emtpy(torch.nn.Module): + class Empty(torch.nn.Module): def forward(self, x): return ( x.new_empty(x.shape[0]).fill_(0), @@ -6705,8 +6705,8 @@ def forward(self, x): ) x = torch.randn(2, 3, 4) - self.run_test(Emtpy(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) - self.run_test(Emtpy(), x, remained_onnx_input_idx=[]) + self.run_test(Empty(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]}) + self.run_test(Empty(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_full(self): @@ -9935,7 +9935,7 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) - @skipScriptTest() # Scripting fails for add lists for opsets < 11. Chek test_derive_index_scripting + @skipScriptTest() # Scripting fails for add lists for opsets < 11. Check test_derive_index_scripting def test_derive_index(self): class MyModule(torch.nn.Module): def forward(self, x: Tensor): diff --git a/test/package/test_glob_group.py b/test/package/test_glob_group.py index f41f2a86f6da2..65c106b364aea 100644 --- a/test/package/test_glob_group.py +++ b/test/package/test_glob_group.py @@ -42,8 +42,11 @@ def test_one_star_middle(self): ) def test_one_star_partial(self): - glob_group = GlobGroup("fo*.bar") - self.assertMatchesGlob(glob_group, ["fo.bar", "foo.bar", "foobar.bar"]) + glob_group = GlobGroup("fo*.bar") # codespell:ignore + self.assertMatchesGlob( + glob_group, + ["fo.bar", "foo.bar", "foobar.bar"], # codespell:ignore + ) self.assertNotMatchesGlob(glob_group, ["oij.bar", "f.bar", "foo"]) def test_one_star_multiple_in_component(self): diff --git a/test/package/test_model.py b/test/package/test_model.py index ea0d2c0788b61..959c683d40b29 100644 --- a/test/package/test_model.py +++ b/test/package/test_model.py @@ -98,7 +98,7 @@ def test_model_save(self): # use the same API to load the package. # The convention is for each model to provide a - # 'model' package with a 'load' function that actual + # 'model' package with a 'load' function that actually # reads the model out of the archive. # How the load function is implemented is up to the diff --git a/test/package/test_package_script.py b/test/package/test_package_script.py index 13c2426f197c3..a9b8165380eeb 100644 --- a/test/package/test_package_script.py +++ b/test/package/test_package_script.py @@ -241,7 +241,7 @@ def test_save_scriptmodules_submod_redefinition(self): """ Test to verify saving multiple ScriptModules with same top module but different submodules works. Submodule is redefined to between - the defintion of the top module to check that the different concrete + the definition of the top module to check that the different concrete types of the modules are thoroughly recognized by serializaiton code. """ diff --git a/test/package/test_save_load.py b/test/package/test_save_load.py index edbba9f6f8ee8..8dd47604822ef 100644 --- a/test/package/test_save_load.py +++ b/test/package/test_save_load.py @@ -110,7 +110,8 @@ def test_bad_dunder_imports(self): buffer = BytesIO() with PackageExporter(buffer) as e: e.save_source_string( - "m", '__import__(these, unresolvable, "things", wont, crash, me)' + "m", + '__import__(these, unresolvable, "things", won, crash, me)', # codespell:ignore ) def test_save_module_binary(self): diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 831f99aafff0a..f8865488fa58e 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -1492,7 +1492,7 @@ def test_profiler_type(self): def test_profiler_correlation_id(self): """ - We expect the correlation_id to be unique across multiple invokation of the profiler, + We expect the correlation_id to be unique across multiple invocation of the profiler, So we will reuse id_uniqueness_set. """ id_uniqueness_set = set() @@ -3276,7 +3276,7 @@ def check_metadata(prof, op_name, metadata_key): check_metadata(prof, op_name="aten::add", metadata_key="Ev Idx") - @unittest.skipIf(not torch.cuda.is_available(), "requries CUDA") + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") def test_profiler_debug_autotuner(self): """ This test makes sure that profiling events will be present when the kernel is run using the DebugAutotuner. diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index ce7eab2050d3a..75bc27453e01a 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -2163,7 +2163,7 @@ def test_qtopk(self): test_cases = itertools.product(x_dims, sides, dims, largest, sorted, dtypes, is_nhwc) k = 2 - for x_dim, side, dim, larg, sort, dtype, nhwc in test_cases: + for x_dim, side, dim, large, sort, dtype, nhwc in test_cases: if nhwc and x_dim != 4: # NHWC requires 4 dimensions continue if dim >= x_dim: # Dimension to find top-k for should exist @@ -2176,12 +2176,12 @@ def test_qtopk(self): qX = qX.permute([0, 3, 1, 2]) X = np.transpose(X, [0, 3, 1, 2]) - unquantized_out = torch.topk(qX.dequantize(), k, dim=dim, largest=larg, sorted=sort) + unquantized_out = torch.topk(qX.dequantize(), k, dim=dim, largest=large, sorted=sort) values = torch.quantize_per_tensor(X, scale, zp, dtype) indices = torch.tensor(X).long() - quantized_out = torch.topk(qX, k, dim=dim, largest=larg, sorted=sort) + quantized_out = torch.topk(qX, k, dim=dim, largest=large, sorted=sort) assert len(unquantized_out) == len(quantized_out) torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0]) @@ -5983,7 +5983,7 @@ def test_benchmark(self): "out_channel:", out_channel, "kernel_size:", kernel_size, "height:", height, - "widht:", width + "width:", width ) conv = torch.nn.Conv2d(in_channel, out_channel, kernel_size).cuda() input = torch.randn((batch_size, in_channel, height, width), device='cuda') diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 93993fe33a49c..3c0cd31b82a24 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -581,7 +581,7 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type): norm = norm + _get_norm(delta_begin, delta_end, density, norm_type) return norm - assert self.histogram.size()[0] == self.bins, "bins mistmatch" + assert self.histogram.size()[0] == self.bins, "bins mismatch" bin_width = (self.max_val - self.min_val) / self.bins # cumulative sum @@ -808,7 +808,7 @@ def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, red def test_histogram_observer_extreme_inputs(self): """ Ensures that the HistogramObserver is able to work correctly in - a rare case: extreme samll max values + a rare case: extreme small max values """ obs = HistogramObserver() test_input = torch.tensor( @@ -1139,7 +1139,7 @@ def forward(self, x): def test_syncbn_preserves_qconfig(self): """ Makes sure that if a BatchNorm is not fused and a qconfig exists, - convering the module to SyncBatchNorm preserves the qconfig. + converting the module to SyncBatchNorm preserves the qconfig. """ m = nn.Sequential( nn.Conv2d(1, 1, 1), diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index da67f19488a4f..a6655798c5cff 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -565,7 +565,7 @@ def checkQuantized(model): def test_train_save_load_eval(self): r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict - During eval, we first call prepare_qat and conver on the model and then load the state_dict + During eval, we first call prepare_qat and convert on the model and then load the state_dict and compare results against original model """ for qengine in supported_qengines: diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index adf1fee586723..cab72394ae29d 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -499,7 +499,7 @@ def forward(self, x): - Reset for each epoch is correctly resetting the values Partition on Output -- the calcuation of the ratio is occurring correctly +- the calculation of the ratio is occurring correctly """ @@ -918,7 +918,7 @@ def test_constructor(self): @skipIfNoFBGEMM def test_prepare_model_callibration(self): """ - Tests model_report.prepare_detailed_calibration that prepares the model for callibration + Tests model_report.prepare_detailed_calibration that prepares the model for calibration Specifically looks at: - Whether observers are properly inserted into regular nn.Module - Whether the target and the arguments of the observers are proper @@ -1150,7 +1150,7 @@ def test_qconfig_mapping_generation(self): """ Tests for generation of qconfigs by ModelReport API - Tests that qconfigmapping is generated - - Tests that mappings include information for for relavent modules + - Tests that mappings include information for for relevant modules """ with override_quantized_engine('fbgemm'): # set the backend for this test @@ -1209,7 +1209,7 @@ def test_equalization_mapping_generation(self): """ Tests for generation of qconfigs by ModelReport API - Tests that equalization config generated when input-weight equalization detector used - - Tests that mappings include information for for relavent modules + - Tests that mappings include information for for relevant modules """ with override_quantized_engine('fbgemm'): # set the backend for this test @@ -1305,7 +1305,7 @@ def get_example_inputs(self): return (torch.arange(27).reshape((1, 3, 3, 3)),) def _get_prepped_for_calibration_model(self, model, detector_set, fused=False): - r"""Returns a model that has been prepared for callibration and corresponding model_report""" + r"""Returns a model that has been prepared for calibration and corresponding model_report""" # pass in necessary inputs to helper example_input = model.get_example_inputs()[0] @@ -1530,7 +1530,7 @@ def get_outlier_inputs(self): def _get_prepped_for_calibration_model(self, model, detector_set, use_outlier_data=False): - r"""Returns a model that has been prepared for callibration and corresponding model_report""" + r"""Returns a model that has been prepared for calibration and corresponding model_report""" # call the general helper function to calibrate example_input = model.get_example_inputs()[0] @@ -1762,7 +1762,7 @@ class TestFxModelReportVisualizer(QuantizationTestCase): def _callibrate_and_generate_visualizer(self, model, prepared_for_callibrate_model, mod_report): r""" - Callibrates the passed in model, generates report, and returns the visualizer + Calibrates the passed in model, generates report, and returns the visualizer """ # now we actually calibrate the model example_input = model.get_example_inputs()[0] @@ -1937,7 +1937,7 @@ def test_generate_tables_single_feat_match(self): self.assertEqual(channel_info_features, 1) def _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused: bool = False): - r"""Returns a model that has been prepared for callibration and corresponding model_report""" + r"""Returns a model that has been prepared for calibration and corresponding model_report""" # set the backend for this test torch.backends.quantized.engine = "fbgemm" diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index f2b3091b75d6c..1b1aada9d34a1 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -523,7 +523,7 @@ def test_fuse_conv_bn_add_relu_by_default(self): @skipIfNoONEDNN def test_fuse_conv_bn_add_relu_lowering(self): """ Test fusion and lowering of Conv2d - (bn -) ReLU - by FX. For onednn backedn only. + by FX. For onednn backend only. """ from torch.ao.quantization.backend_config import get_onednn_backend_config qconfig_mapping = get_default_qconfig_mapping('onednn') @@ -5693,12 +5693,12 @@ def forward(self, x): self.assertTrue( type(mod_prep.untraceable_module_class.linear) is not torch.ao.nn.qat.modules.linear.Linear, - "prepare_qat_fx shold not convert anything inside untraced module classes", + "prepare_qat_fx should not convert anything inside untraced module classes", ) self.assertTrue( type(mod_prep.untraceable_module_name.linear) is not torch.ao.nn.qat.modules.linear.Linear, - "prepare_qat_fx shold not convert anything inside modules named in untraced_module_names", + "prepare_qat_fx should not convert anything inside modules named in untraced_module_names", ) def test_qconfig_dict_setup(self): @@ -6315,7 +6315,7 @@ def _test_linear_activation_fusion_lowering_helper( @skipIfNoONEDNN def test_linear_leaky_relu_lowering(self): """ Test fusion and lowering of Linear - (bn -) LeakyReLU - by FX. For onednn backedn only. + by FX. For onednn backend only. """ from torch.ao.quantization.backend_config import get_onednn_backend_config qconfig_mapping = get_default_qconfig_mapping('onednn') @@ -6334,7 +6334,7 @@ def test_linear_leaky_relu_lowering(self): @skipIfNoONEDNN def test_linear_tanh_lowering(self): """ Test fusion and lowering of Linear - Tanh - by FX. For onednn backedn only. + by FX. For onednn backend only. """ from torch.ao.quantization.backend_config import get_onednn_backend_config qconfig_mapping = get_default_qconfig_mapping('onednn') diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index 81bdd50adbd43..59e78f8694d8f 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -1069,15 +1069,15 @@ def forward(self, x): m = prepare_jit(m, qconfig_dict) # observers for input, output and value between conv1/conv2 assert len(attrs_with_prefix(m, "_observer_")) == 3, ( - "Expected to have 3 obervers" + "Expected to have 3 observers" ) # observer for weight assert len(attrs_with_prefix(m.conv1, "_observer_")) == 1, ( - "Expected to have 1 obervers" + "Expected to have 1 observers" ) # observer for weight assert len(attrs_with_prefix(m.conv2, "_observer_")) == 1, ( - "Expected to have 1 obervers" + "Expected to have 1 observers" ) data = torch.randn(1, 3, 10, 10, dtype=torch.float) @@ -1088,13 +1088,13 @@ def forward(self, x): # check all observers have been removed assert len(attrs_with_prefix(m, "_observer_")) == 0, ( - "Expected to have 0 obervers" + "Expected to have 0 observers" ) assert len(attrs_with_prefix(m.conv1, "_observer_")) == 0, ( - "Expected to have 0 obervers" + "Expected to have 0 observers" ) assert len(attrs_with_prefix(m.conv2, "_observer_")) == 0, ( - "Expected to have 0 obervers" + "Expected to have 0 observers" ) quant_func = ( diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index aa8743c32297f..db394f69d6f6d 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -1093,7 +1093,7 @@ def forward(self, x): permute_out = torch.permute(conv_out, (0, 2, 3, 1)) linear_out = self.linears(permute_out) my_linear_out = self.my_linear(linear_out) - # Hardtanh doesnt get quantized via xnnpack quantizer in this test + # Hardtanh doesn't get quantized via xnnpack quantizer in this test # because it relies on the propagation rules # Need to fix this return torch.nn.functional.hardtanh(my_linear_out) diff --git a/test/run_doctests.sh b/test/run_doctests.sh index 2942e961c9da8..f327ed14184f2 100755 --- a/test/run_doctests.sh +++ b/test/run_doctests.sh @@ -21,7 +21,7 @@ if [[ ! -d "$TORCH_MODPATH" ]] ; then else export XDOCTEST_GLOBAL_EXEC="from torch import nn\nimport torch.nn.functional as F\nimport torch" export XDOCTEST_OPTIONS="+IGNORE_WHITESPACE" - # Note: google wont catch numpy style docstrings (a few exist) but it also wont fail + # Note: google won't catch numpy style docstrings (a few exist) but it also won't fail # on things not intended to be doctests. export XDOCTEST_STYLE="google" xdoctest torch "$TORCH_MODPATH" --style="$XDOCTEST_STYLE" --global-exec "$XDOCTEST_GLOBAL_EXEC" --options="$XDOCTEST_OPTIONS" diff --git a/test/run_test.py b/test/run_test.py index 63285f67a27d4..39b13980c2f04 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -798,7 +798,7 @@ def read_pytest_cache(key: str) -> Any: # skip it and move on sc_command = f"--scs={stepcurrent_key}" print_to_file( - "Test succeeeded in new process, continuing with the rest of the tests" + "Test succeeded in new process, continuing with the rest of the tests" ) elif num_failures[current_failure] >= 3: # This is for log classifier so it can prioritize consistently @@ -2157,7 +2157,7 @@ def __str__(self): if IS_CI: for test, _ in all_failures: test_stats = test_prioritizations.get_test_stats(test) - print_to_stderr("Emiting td_test_failure_stats_v2") + print_to_stderr("Emitting td_test_failure_stats_v2") emit_metric( "td_test_failure_stats_v2", { diff --git a/test/scripts/cuda_memcheck_common.py b/test/scripts/cuda_memcheck_common.py index 016cb3d035413..82518c88d4bcb 100644 --- a/test/scripts/cuda_memcheck_common.py +++ b/test/scripts/cuda_memcheck_common.py @@ -47,7 +47,7 @@ def __init__(self, lines): def parse(message): """A simple parser that parses the report of cuda-memcheck. This parser is meant to be simple and it only split the report into separate errors and a summary. Where each error is further - splitted into error message and backtrace. No further details are parsed. + split into error message and backtrace. No further details are parsed. A report contains multiple errors and a summary on how many errors are detected. It looks like: diff --git a/test/test_mps.py b/test/test_mps.py index 51f2637e4d55e..bf837e788e74c 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -8245,7 +8245,7 @@ def test_inplace_bitwise_not(self, dtype): self.assertEqual(x_mps.cpu(), x_cpu) def test_empty_posneginf(self): - # just to check that it doesnt crash + # just to check that it doesn't crash input_tensor = torch.empty(0, device="mps") out_pos = torch.isposinf(input_tensor) out_neg = torch.isposinf(input_tensor) @@ -8253,7 +8253,7 @@ def test_empty_posneginf(self): self.assertEqual(out_neg.numel(), 0) def test_empty_dot(self): - # just to check that it doesnt crash + # just to check that it doesn't crash a = torch.rand((0), device="mps") b = torch.rand((0), device="mps") self.assertEqual(a.dot(b), a.cpu().dot(b.cpu())) @@ -9667,7 +9667,7 @@ def get_mps_memory_usage(): memory_footprints = [] for _ in range(100): output = F.scaled_dot_product_attention(query, key, value) - # syncronize to wait for the GPU computation to return + # synchronize to wait for the GPU computation to return torch.mps.synchronize() current_mem, driver_mem = get_mps_memory_usage() memory_footprints.append((current_mem, driver_mem)) @@ -12977,8 +12977,8 @@ def test_reduction_utils(self, dtype): idx = 25 x[idx] = torch.nan lib.do_max(z0, z1, x) - self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements shold have been nan") - self.assertTrue((z1 == idx).all().item(), f"results are {z1}, but all elements shold have been {idx}") + self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements should have been nan") + self.assertTrue((z1 == idx).all().item(), f"results are {z1}, but all elements should have been {idx}") @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16]) def test_atomic_add(self, dtype): diff --git a/test/test_ops.py b/test/test_ops.py index 5f44a3ba0841b..dbcc0567ea1da 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3002,7 +3002,7 @@ def test_0d_tensor_with_python_scalar(self, device, dtype, op): if torch.float not in op.supported_backward_dtypes(device): raise unittest.SkipTest("Does not support autograd") - # skip if operator doesnt support forward AD + # skip if operator doesn't support forward AD if not op.supports_forward_ad: raise unittest.SkipTest("Does not support forward_ad") diff --git a/test/torch_np/numpy_tests/core/test_dtype.py b/test/torch_np/numpy_tests/core/test_dtype.py index 19b41d877ca8d..5f5d1a5dc7563 100644 --- a/test/torch_np/numpy_tests/core/test_dtype.py +++ b/test/torch_np/numpy_tests/core/test_dtype.py @@ -87,7 +87,7 @@ def test_invalid_types(self): assert_raises(TypeError, np.dtype, "l8") assert_raises(TypeError, np.dtype, "L8") - # XXX: what is 'q'? on my 64-bit ubuntu matching it's int64, same as 'l' + # XXX: what is 'q'? on my 64-bit ubuntu machine it's int64, same as 'l' # if np.dtype('q').itemsize == 8: # assert_raises(TypeError, np.dtype, 'q4') # assert_raises(TypeError, np.dtype, 'Q4') @@ -351,7 +351,7 @@ class dt: np.dtype(dt_instance) -@skip(reason="Parameteric dtypes, our stuff is simpler.") +@skip(reason="Parametric dtypes, our stuff is simpler.") @instantiate_parametrized_tests class TestClassGetItem(TestCase): def test_dtype(self) -> None: diff --git a/test/torch_np/numpy_tests/core/test_einsum.py b/test/torch_np/numpy_tests/core/test_einsum.py index 45c1d97474872..8e4dcafc621a4 100644 --- a/test/torch_np/numpy_tests/core/test_einsum.py +++ b/test/torch_np/numpy_tests/core/test_einsum.py @@ -922,7 +922,7 @@ def test_einsum_fixedstridebug(self): tp = np.tensordot(A, B, axes=(0, 0)) assert_equal(es, tp) # The following is the original test case from the bug report, - # made repeatable by changing random arrays to aranges. + # made repeatable by changing random arrays to aranges. # codespell:ignore aranges A = np.arange(3 * 3).reshape(3, 3).astype(np.float64) B = np.arange(3 * 3 * 64 * 64).reshape(3, 3, 64, 64).astype(np.float32) es = np.einsum("cl, cpxy->lpxy", A, B) @@ -1092,7 +1092,7 @@ def test_expand(self): self.optimize_compare("ab,cd,de->abcde") self.optimize_compare("ab,cd,de->be") self.optimize_compare("ab,bcd,cd->abcd") - self.optimize_compare("ab,bcd,cd->abd") + self.optimize_compare("ab,bcd,cd->abd") # codespell:ignore def test_edge_cases(self): # Difficult edge cases for optimization @@ -1105,7 +1105,7 @@ def test_edge_cases(self): self.optimize_compare("ed,fcd,ff,bcf->be") self.optimize_compare("baa,dcf,af,cde->be") self.optimize_compare("bd,db,eac->ace") - self.optimize_compare("fff,fae,bef,def->abd") + self.optimize_compare("fff,fae,bef,def->abd") # codespell:ignore self.optimize_compare("efc,dbc,acf,fd->abe") self.optimize_compare("ba,ac,da->bcd") diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index 16d89c0321984..08af1303c4a3e 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -464,8 +464,8 @@ def test_indexing_array_weird_strides(self): def test_indexing_array_negative_strides(self): # From gh-8264, # core dumps if negative strides are used in iteration - arro = np.zeros((4, 4)) - arr = arro[::-1, ::-1] + arro = np.zeros((4, 4)) # codespell:ignore + arr = arro[::-1, ::-1] # codespell:ignore slices = (slice(None), [0, 1, 2, 3]) arr[slices] = 10 @@ -716,41 +716,41 @@ def _get_multi_index(self, arr, indices): # check if this is fancy indexing (set no_copy). ndim = 0 ellipsis_pos = None # define here mostly to replace all but first. - for i, indx in enumerate(in_indices): - if indx is None: + for i, indx in enumerate(in_indices): # codespell:ignore + if indx is None: # codespell:ignore continue - if isinstance(indx, np.ndarray) and indx.dtype == bool: + if isinstance(indx, np.ndarray) and indx.dtype == bool: # codespell:ignore no_copy = False - if indx.ndim == 0: + if indx.ndim == 0: # codespell:ignore raise IndexError # boolean indices can have higher dimensions - ndim += indx.ndim - fancy_dim += indx.ndim + ndim += indx.ndim # codespell:ignore + fancy_dim += indx.ndim # codespell:ignore continue - if indx is Ellipsis: + if indx is Ellipsis: # codespell:ignore if ellipsis_pos is None: ellipsis_pos = i continue # do not increment ndim counter raise IndexError - if isinstance(indx, slice): + if isinstance(indx, slice): # codespell:ignore ndim += 1 continue - if not isinstance(indx, np.ndarray): + if not isinstance(indx, np.ndarray): # codespell:ignore # This could be open for changes in numpy. # numpy should maybe raise an error if casting to intp # is not safe. It rejects np.array([1., 2.]) but not # [1., 2.] as index (same for ie. np.take). # (Note the importance of empty lists if changing this here) try: - indx = np.array(indx, dtype=np.intp) + indx = np.array(indx, dtype=np.intp) # codespell:ignore except ValueError: raise IndexError from None - in_indices[i] = indx - elif indx.dtype.kind != "b" and indx.dtype.kind != "i": + in_indices[i] = indx # codespell:ignore + elif indx.dtype.kind != "b" and indx.dtype.kind != "i": # codespell:ignore raise IndexError( "arrays used as indices must be of integer (or boolean) type" ) - if indx.ndim != 0: + if indx.ndim != 0: # codespell:ignore no_copy = False ndim += 1 fancy_dim += 1 @@ -771,37 +771,42 @@ def _get_multi_index(self, arr, indices): arr.ndim - ndim ) - for ax, indx in enumerate(in_indices): - if isinstance(indx, slice): + for ax, indx in enumerate(in_indices): # codespell:ignore + if isinstance(indx, slice): # codespell:ignore # convert to an index array - indx = np.arange(*indx.indices(arr.shape[ax])) - indices.append(["s", indx]) + indx = np.arange(*indx.indices(arr.shape[ax])) # codespell:ignore + indices.append(["s", indx]) # codespell:ignore continue - elif indx is None: + elif indx is None: # codespell:ignore # this is like taking a slice with one element from a new axis: indices.append(["n", np.array([0], dtype=np.intp)]) arr = arr.reshape(arr.shape[:ax] + (1,) + arr.shape[ax:]) continue - if isinstance(indx, np.ndarray) and indx.dtype == bool: - if indx.shape != arr.shape[ax : ax + indx.ndim]: + if isinstance(indx, np.ndarray) and indx.dtype == bool: # codespell:ignore + if indx.shape != arr.shape[ax : ax + indx.ndim]: # codespell:ignore raise IndexError try: flat_indx = np.ravel_multi_index( - np.nonzero(indx), arr.shape[ax : ax + indx.ndim], mode="raise" + np.nonzero(indx), # codespell:ignore + arr.shape[ax : ax + indx.ndim], # codespell:ignore + mode="raise", ) except Exception: error_unless_broadcast_to_empty = True # fill with 0s instead, and raise error later - flat_indx = np.array([0] * indx.sum(), dtype=np.intp) + flat_indx = np.array( + [0] * indx.sum(), # codespell:ignore + dtype=np.intp, + ) # concatenate axis into a single one: - if indx.ndim != 0: + if indx.ndim != 0: # codespell:ignore arr = arr.reshape( arr.shape[:ax] - + (np.prod(arr.shape[ax : ax + indx.ndim]),) - + arr.shape[ax + indx.ndim :] + + (np.prod(arr.shape[ax : ax + indx.ndim]),) # codespell:ignore + + arr.shape[ax + indx.ndim :] # codespell:ignore ) - indx = flat_indx + indx = flat_indx # codespell:ignore else: # This could be changed, a 0-d boolean index can # make sense (even outside the 0-d indexed array case) @@ -811,27 +816,30 @@ def _get_multi_index(self, arr, indices): else: # If the index is a singleton, the bounds check is done # before the broadcasting. This used to be different in <1.9 - if indx.ndim == 0: - if indx >= arr.shape[ax] or indx < -arr.shape[ax]: + if indx.ndim == 0: # codespell:ignore + if ( + indx >= arr.shape[ax] # codespell:ignore + or indx < -arr.shape[ax] # codespell:ignore + ): raise IndexError - if indx.ndim == 0: + if indx.ndim == 0: # codespell:ignore # The index is a scalar. This used to be two fold, but if # fancy indexing was active, the check was done later, # possibly after broadcasting it away (1.7. or earlier). # Now it is always done. - if indx >= arr.shape[ax] or indx < -arr.shape[ax]: + if indx >= arr.shape[ax] or indx < -arr.shape[ax]: # codespell:ignore raise IndexError if len(indices) > 0 and indices[-1][0] == "f" and ax != ellipsis_pos: # NOTE: There could still have been a 0-sized Ellipsis # between them. Checked that with ellipsis_pos. - indices[-1].append(indx) + indices[-1].append(indx) # codespell:ignore else: # We have a fancy index that is not after an existing one. # NOTE: A 0-d array triggers this as well, while one may # expect it to not trigger it, since a scalar would not be # considered fancy indexing. num_fancy += 1 - indices.append(["f", indx]) + indices.append(["f", indx]) # codespell:ignore if num_fancy > 1 and not no_copy: # We have to flush the fancy indexes left @@ -841,16 +849,16 @@ def _get_multi_index(self, arr, indices): new_indices.insert(0, ["f"]) ni = 0 ai = 0 - for indx in indices: + for indx in indices: # codespell:ignore ni += 1 - if indx[0] == "f": - new_indices[0].extend(indx[1:]) + if indx[0] == "f": # codespell:ignore + new_indices[0].extend(indx[1:]) # codespell:ignore del new_indices[ni] ni -= 1 - for ax in range(ai, ai + len(indx[1:])): + for ax in range(ai, ai + len(indx[1:])): # codespell:ignore fancy_axes.append(ax) axes.remove(ax) - ai += len(indx) - 1 # axis we are at + ai += len(indx) - 1 # axis we are at # codespell:ignore indices = new_indices # and now we need to transpose arr: arr = arr.transpose(*(fancy_axes + axes)) @@ -858,46 +866,52 @@ def _get_multi_index(self, arr, indices): # We only have one 'f' index now and arr is transposed accordingly. # Now handle newaxis by reshaping... ax = 0 - for indx in indices: - if indx[0] == "f": - if len(indx) == 1: + for indx in indices: # codespell:ignore + if indx[0] == "f": # codespell:ignore + if len(indx) == 1: # codespell:ignore continue # First of all, reshape arr to combine fancy axes into one: orig_shape = arr.shape - orig_slice = orig_shape[ax : ax + len(indx[1:])] + orig_slice = orig_shape[ax : ax + len(indx[1:])] # codespell:ignore arr = arr.reshape( arr.shape[:ax] + (np.prod(orig_slice).astype(int),) - + arr.shape[ax + len(indx[1:]) :] + + arr.shape[ax + len(indx[1:]) :] # codespell:ignore ) # Check if broadcasting works - res = np.broadcast(*indx[1:]) + res = np.broadcast(*indx[1:]) # codespell:ignore # unfortunately the indices might be out of bounds. So check # that first, and use mode='wrap' then. However only if # there are any indices... if res.size != 0: if error_unless_broadcast_to_empty: raise IndexError - for _indx, _size in zip(indx[1:], orig_slice): + for _indx, _size in zip(indx[1:], orig_slice): # codespell:ignore if _indx.size == 0: continue if np.any(_indx >= _size) or np.any(_indx < -_size): raise IndexError - if len(indx[1:]) == len(orig_slice): + if len(indx[1:]) == len(orig_slice): # codespell:ignore if np.prod(orig_slice) == 0: # Work around for a crash or IndexError with 'wrap' # in some 0-sized cases. try: mi = np.ravel_multi_index( - indx[1:], orig_slice, mode="raise" + indx[1:], # codespell:ignore + orig_slice, + mode="raise", # codespell:ignore ) except Exception as exc: # This happens with 0-sized orig_slice (sometimes?) # here it is a ValueError, but indexing gives a: raise IndexError("invalid index into 0-sized") from exc else: - mi = np.ravel_multi_index(indx[1:], orig_slice, mode="wrap") + mi = np.ravel_multi_index( + indx[1:], # codespell:ignore + orig_slice, + mode="wrap", + ) else: # Maybe never happens... raise ValueError @@ -911,7 +925,7 @@ def _get_multi_index(self, arr, indices): continue # If we are here, we have a 1D array for take: - arr = arr.take(indx[1], axis=ax) + arr = arr.take(indx[1], axis=ax) # codespell:ignore ax += 1 return arr, no_copy diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index cc5e64874a05e..4f4bc16f53221 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -1703,7 +1703,7 @@ def test_sort_size_0(self): msg = "test empty array sort with axis=None" assert_equal(np.sort(a, axis=None), a.ravel(), msg) - @skip(reason="waaay tooo sloooow") + @skip(reason="waaay tooo sloooow") # codespell:ignore def test_sort_degraded(self): # test degraded dataset would take minutes to run with normal qsort d = np.arange(1000000) @@ -2647,7 +2647,7 @@ def test_dot_out_mem_overlap(self): assert_raises(ValueError, np.dot, a, b, out=b[::2]) assert_raises(ValueError, np.dot, a, b, out=b.T) - @xpassIfTorchDynamo_np # (reason="TODO: overlapping memor in matmul") + @xpassIfTorchDynamo_np # (reason="TODO: overlapping memory in matmul") def test_matmul_out(self): # overlapping memory a = np.arange(18).reshape(2, 3, 3) @@ -3330,8 +3330,8 @@ def test_combinations(self, data): assert_equal(np.argmax(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmax(rarr)], val, err_msg=f"{rarr!r}") - padd = np.repeat(np.min(arr), 513) - rarr = np.concatenate((arr, padd)) + padding = np.repeat(np.min(arr), 513) + rarr = np.concatenate((arr, padding)) rpos = pos assert_equal(np.argmax(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmax(rarr)], val, err_msg=f"{rarr!r}") @@ -3439,8 +3439,8 @@ def test_combinations(self, data): assert_equal(np.argmin(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmin(rarr)], min_val, err_msg=f"{rarr!r}") - padd = np.repeat(np.max(arr), 513) - rarr = np.concatenate((arr, padd)) + padding = np.repeat(np.max(arr), 513) + rarr = np.concatenate((arr, padding)) rpos = pos assert_equal(np.argmin(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmin(rarr)], min_val, err_msg=f"{rarr!r}") @@ -4318,7 +4318,7 @@ def test_array_base(self, obj): # See also gh-21612 if isinstance(obj, str): # @parametrize breaks with bytes objects - obj = bytes(obj, enconding="latin-1") + obj = bytes(obj, encoding="latin-1") new = np.frombuffer(obj) assert new.base is obj @@ -4432,7 +4432,7 @@ def test_basic(self): ) assert_array_equal(x[9:].ravel(), 0) - @skip(reason="how to find if someone is refencing an array") + @skip(reason="how to find if someone is referencing an array") def test_check_reference(self): x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) y = x diff --git a/test/torch_np/numpy_tests/lib/test_histograms.py b/test/torch_np/numpy_tests/lib/test_histograms.py index f638e994c1f4c..24986b15883c7 100644 --- a/test/torch_np/numpy_tests/lib/test_histograms.py +++ b/test/torch_np/numpy_tests/lib/test_histograms.py @@ -351,7 +351,7 @@ def test_signed_overflow_bounds(self): self.do_signed_overflow_bounds(np.short) self.do_signed_overflow_bounds(np.intc) - @xfail # (reason="int->float conversin loses precision") + @xfail # (reason="int->float conversion loses precision") def test_signed_overflow_bounds_2(self): self.do_signed_overflow_bounds(np.int_) self.do_signed_overflow_bounds(np.longlong) diff --git a/test/torch_np/numpy_tests/lib/test_index_tricks.py b/test/torch_np/numpy_tests/lib/test_index_tricks.py index 6b373e87f2b5e..2a90d7a70484e 100644 --- a/test/torch_np/numpy_tests/lib/test_index_tricks.py +++ b/test/torch_np/numpy_tests/lib/test_index_tricks.py @@ -284,7 +284,7 @@ def test_mgrid_size_none_handling(self, start, stop, step, expected): assert_equal(grid.size, expected[0]) assert_equal(grid_small.size, expected[1]) - @xfail # (reason="mgrid not implementd") + @xfail # (reason="mgrid not implemented") def test_accepts_npfloating(self): # regression test for #16466 grid64 = mgrid[0.1:0.33:0.1,] diff --git a/test/torch_np/test_ndarray_methods.py b/test/torch_np/test_ndarray_methods.py index b25faac56cb83..27da866aaaa44 100644 --- a/test/torch_np/test_ndarray_methods.py +++ b/test/torch_np/test_ndarray_methods.py @@ -480,8 +480,8 @@ def test_combinations(self, data): assert_equal(np.argmax(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmax(rarr)], val, err_msg=f"{rarr!r}") - padd = np.repeat(np.min(arr), 513) - rarr = np.concatenate((arr, padd)) + padding = np.repeat(np.min(arr), 513) + rarr = np.concatenate((arr, padding)) rpos = pos assert_equal(np.argmax(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmax(rarr)], val, err_msg=f"{rarr!r}") @@ -593,8 +593,8 @@ def test_combinations(self, data): assert_equal(np.argmin(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmin(rarr)], min_val, err_msg=f"{rarr!r}") - padd = np.repeat(np.max(arr), 513) - rarr = np.concatenate((arr, padd)) + padding = np.repeat(np.max(arr), 513) + rarr = np.concatenate((arr, padding)) rpos = pos assert_equal(np.argmin(rarr), rpos, err_msg=f"{rarr!r}") assert_equal(rarr[np.argmin(rarr)], min_val, err_msg=f"{rarr!r}") From a2973fb00ec002dd4b6bbf07385f066efb259b8c Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sat, 29 Nov 2025 03:06:18 +0000 Subject: [PATCH 060/338] [9/N] Use Python 3.10 typing (#167806) This PR applies Python 3.10 typing syntax to some files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167806 Approved by: https://github.com/albanD --- .../_checkpoint/checkpoint_wrapper.py | 4 +- .../algorithms/_comm_hooks/default_hooks.py | 7 +- .../ddp_comm_hooks/ddp_zero_hook.py | 4 +- torch/distributed/algorithms/join.py | 8 +- .../algorithms/model_averaging/averagers.py | 9 +- .../hierarchical_model_averager.py | 5 +- .../algorithms/model_averaging/utils.py | 10 +- torch/distributed/collective_utils.py | 32 ++- torch/distributed/constants.py | 3 +- torch/distributed/device_mesh.py | 87 ++++--- torch/distributed/distributed_c10d.py | 228 +++++++++--------- torch/distributed/elastic/agent/server/api.py | 28 +-- .../agent/server/local_elastic_agent.py | 14 +- torch/distributed/elastic/events/__init__.py | 6 +- torch/distributed/elastic/events/api.py | 6 +- torch/distributed/elastic/metrics/__init__.py | 2 +- torch/distributed/elastic/metrics/api.py | 5 +- .../elastic/multiprocessing/__init__.py | 10 +- .../elastic/multiprocessing/api.py | 74 +++--- .../multiprocessing/errors/__init__.py | 6 +- .../multiprocessing/errors/error_handler.py | 4 +- .../subprocess_handler/handlers.py | 3 +- .../subprocess_handler/subprocess_handler.py | 10 +- .../elastic/multiprocessing/tail_log.py | 6 +- .../elastic/rendezvous/_etcd_stub.py | 6 +- torch/distributed/elastic/rendezvous/api.py | 16 +- .../rendezvous/c10d_rendezvous_backend.py | 10 +- .../elastic/rendezvous/dynamic_rendezvous.py | 48 ++-- .../elastic/rendezvous/etcd_rendezvous.py | 5 +- .../rendezvous/etcd_rendezvous_backend.py | 12 +- .../elastic/rendezvous/etcd_server.py | 12 +- .../elastic/rendezvous/etcd_store.py | 5 +- .../rendezvous/static_tcp_rendezvous.py | 4 +- torch/distributed/elastic/rendezvous/utils.py | 18 +- torch/distributed/elastic/timer/api.py | 10 +- .../elastic/timer/file_based_local_timer.py | 8 +- .../utils/data/elastic_distributed_sampler.py | 6 +- .../distributed/elastic/utils/distributed.py | 3 +- torch/distributed/elastic/utils/logging.py | 7 +- torch/distributed/elastic/utils/store.py | 7 +- torch/distributed/launcher/api.py | 24 +- torch/distributed/nn/api/remote_module.py | 43 ++-- .../distributed/optim/functional_adadelta.py | 3 +- torch/distributed/optim/functional_adagrad.py | 3 +- torch/distributed/optim/functional_adam.py | 5 +- torch/distributed/optim/functional_adamax.py | 3 +- torch/distributed/optim/functional_adamw.py | 5 +- torch/distributed/optim/functional_rmsprop.py | 3 +- torch/distributed/optim/functional_rprop.py | 3 +- torch/distributed/optim/functional_sgd.py | 9 +- torch/distributed/optim/named_optimizer.py | 10 +- torch/distributed/optim/optimizer.py | 3 +- .../optim/zero_redundancy_optimizer.py | 24 +- torch/distributed/pipelining/_backward.py | 20 +- .../pipelining/_schedule_visualizer.py | 20 +- torch/distributed/pipelining/_utils.py | 5 +- torch/distributed/pipelining/microbatch.py | 8 +- torch/distributed/pipelining/schedules.py | 222 +++++++++-------- torch/distributed/pipelining/stage.py | 56 ++--- torch/distributed/remote_device.py | 9 +- torch/distributed/rendezvous.py | 3 +- torch/distributed/rpc/options.py | 10 +- torch/distributed/run.py | 9 +- torch/distributed/tensor/_api.py | 84 +++---- torch/distributed/tensor/_collective_utils.py | 4 +- torch/distributed/tensor/_dispatch.py | 4 +- torch/distributed/tensor/_dtensor_spec.py | 12 +- torch/distributed/tensor/_op_schema.py | 24 +- .../distributed/tensor/_ops/_common_rules.py | 4 +- torch/distributed/tensor/_ops/_mask_buffer.py | 3 +- torch/distributed/tensor/_ops/_math_ops.py | 12 +- torch/distributed/tensor/_ops/_matrix_ops.py | 6 +- .../distributed/tensor/_ops/_pointwise_ops.py | 10 +- torch/distributed/tensor/_ops/_tensor_ops.py | 10 +- torch/distributed/tensor/_ops/_view_ops.py | 14 +- torch/distributed/tensor/_random.py | 8 +- torch/distributed/tensor/_redistribute.py | 16 +- torch/distributed/tensor/_sharding_prop.py | 24 +- torch/distributed/tensor/_utils.py | 4 +- .../examples/comm_mode_features_example.py | 4 +- .../tensor/examples/flex_attention_cp.py | 5 +- .../_context_parallel/_attention.py | 54 ++--- .../_context_parallel/_load_balancer.py | 5 +- .../tensor/experimental/_func_map.py | 6 +- .../tensor/experimental/_register_sharding.py | 3 +- .../tensor/experimental/_tp_transform.py | 4 +- .../tensor/parallel/_data_parallel_utils.py | 4 +- torch/distributed/tensor/parallel/api.py | 7 +- torch/distributed/tensor/parallel/ddp.py | 4 +- torch/distributed/tensor/parallel/fsdp.py | 10 +- .../tensor/parallel/input_reshard.py | 6 +- torch/distributed/tensor/parallel/loss.py | 8 +- torch/distributed/tensor/parallel/style.py | 48 ++-- torch/distributed/tensor/placement_types.py | 18 +- torch/distributed/utils.py | 4 +- 95 files changed, 800 insertions(+), 877 deletions(-) diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index eae76e8cc72af..081d397a9c1f1 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterator from enum import auto, Enum from functools import partial -from typing import Any, Optional +from typing import Any import torch import torch.nn as nn @@ -248,7 +248,7 @@ def apply_activation_checkpointing( model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=lambda _: True, - auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None, + auto_wrap_policy: Callable[[nn.Module, bool, int], bool] | None = None, ): """ Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration. diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 872ad0e2a7673..76cd01c2265b1 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools -from typing import Optional import torch import torch.distributed as dist @@ -136,7 +135,7 @@ def _low_precision_hook( prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, - output: Optional[torch.Tensor], + output: torch.Tensor | None, ): if grad.dtype != prec: grad.data = grad.data.to(prec) @@ -151,7 +150,7 @@ def _low_precision_hook( def fp16_compress_hook( - state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None + state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None ): r""" Implement FSDP communication hook for a simple gradient compression approach. @@ -172,7 +171,7 @@ def fp16_compress_hook( def bf16_compress_hook( - state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None + state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None ): r""" Implement FSDP communication hook for a simple gradient compression approach . diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 2e55941b370cd..fa8c865c89151 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import weakref from collections.abc import Callable -from typing import Any, Optional +from typing import Any import torch import torch.distributed as dist @@ -47,7 +47,7 @@ def _perform_local_step( # expects `None` in a list position to indicate that the corresponding # parameter should not be updated num_local_optim_params = len(zero.optim.param_groups[0]["params"]) - gradients: list[Optional[torch.Tensor]] = [ + gradients: list[torch.Tensor | None] = [ _NO_PARAM_UPDATE for _ in range(num_local_optim_params) ] assert bucket_index in overlap_info.offsets, ( diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index bf7cb117f87ee..52d0c52fbfb59 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -2,7 +2,7 @@ import warnings from abc import ABC, abstractmethod from types import TracebackType -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple import torch import torch.distributed as dist @@ -228,9 +228,9 @@ def __enter__(self): ... def __exit__( self, - type: Optional[type[BaseException]], - value: Optional[BaseException], - traceback: Optional[TracebackType], + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, ): r""" Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index dd97e5191808f..5d669d4ea5922 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -2,7 +2,6 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Optional, Union import torch import torch.distributed as dist @@ -23,7 +22,7 @@ class ModelAverager(ABC): will be used. (default: ``None``) """ - def __init__(self, process_group: Optional[dist.ProcessGroup] = None): + def __init__(self, process_group: dist.ProcessGroup | None = None): self.process_group = ( process_group if process_group is not None else _not_none(dist.group.WORLD) ) @@ -88,7 +87,7 @@ class PeriodicModelAverager(ModelAverager): """ def __init__( - self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None + self, period, warmup_steps=0, process_group: dist.ProcessGroup | None = None ): super().__init__(process_group) if warmup_steps < 0: @@ -108,9 +107,7 @@ def __init__( def average_parameters( self, - params: Union[ - Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] - ], + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], ): """ Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``. diff --git a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py index 33cde4cb3a743..4f7edc447d108 100644 --- a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py +++ b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -4,7 +4,6 @@ import warnings from collections import OrderedDict from collections.abc import Iterable -from typing import Union import torch import torch.distributed as dist @@ -160,9 +159,7 @@ def _find_process_group(self): def average_parameters( self, - params: Union[ - Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] - ], + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], ): """ Averages parameters or parameter groups of an optimizer. diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index fa8cc184eddc5..6a61c036913ed 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import itertools from collections.abc import Iterable, Iterator -from typing import Union import torch import torch.distributed as dist @@ -51,10 +50,7 @@ def average_parameters( def get_params_to_average( - params: Union[ - Iterable[torch.nn.Parameter], - Iterable[dict[str, torch.nn.Parameter]], - ], + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], ): """ Return a list of parameters that need to average. @@ -83,9 +79,7 @@ def get_params_to_average( def average_parameters_or_parameter_groups( - params: Union[ - Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] - ], + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], process_group: ProcessGroup, ): """Averages parameters of a model or parameter groups of an optimizer.""" diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index e608e26a3a854..cb20c58f13309 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -13,7 +13,7 @@ import logging from collections import defaultdict from dataclasses import dataclass -from typing import Any, cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar if TYPE_CHECKING: @@ -37,19 +37,19 @@ @dataclass class SyncPayload(Generic[T]): - stage_name: Optional[str] + stage_name: str | None success: bool payload: T - exception: Optional[Exception] = None + exception: Exception | None = None def broadcast( - data_or_fn: Union[T, Callable[[], T]], + data_or_fn: T | Callable[[], T], *, success: bool = True, - stage_name: Optional[str] = None, + stage_name: str | None = None, rank: int = 0, - pg: Optional[dist.ProcessGroup] = None, + pg: dist.ProcessGroup | None = None, ) -> T: """ Broadcasts the data payload from rank 0 to all other ranks. @@ -79,8 +79,8 @@ def broadcast( "Data or Function is expected to be None if not successful" ) - payload: Optional[T] = None - exception: Optional[Exception] = None + payload: T | None = None + exception: Exception | None = None # if no pg is passed then execute if rank is 0 if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): # determine if it is an executable function or data payload only @@ -124,9 +124,9 @@ def broadcast( def all_gather( - data_or_fn: Union[T, Callable[[], T]], - stage_name: Optional[str] = None, - pg: Optional[dist.ProcessGroup] = None, + data_or_fn: T | Callable[[], T], + stage_name: str | None = None, + pg: dist.ProcessGroup | None = None, ) -> list[T]: """ A simple all_gather primitive with basic synchronization guard logic, @@ -144,8 +144,8 @@ def all_gather( Example usage: >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) """ - payload: Optional[T] = None - exception: Optional[Exception] = None + payload: T | None = None + exception: Exception | None = None success = True # determine if it is an executable function or data payload only if callable(data_or_fn): @@ -247,7 +247,7 @@ def _summarize_ranks(ranks: Iterable[int]) -> str: raise AssertionError("ranks should all be positive") if len(set(ranks)) != len(ranks): raise AssertionError("ranks should not contain duplicates") - curr: Optional[Union[int, range]] = None + curr: int | range | None = None ranges = [] while ranks: x = ranks.pop(0) @@ -345,9 +345,7 @@ def _desync_table_str(tag: str, value_ranks: dict[Any, set[int]]) -> str: return str(f"{headers}\n{row_str}") -def _check_rng_sync( - generator: torch.Generator, group: dist.ProcessGroup -) -> Optional[str]: +def _check_rng_sync(generator: torch.Generator, group: dist.ProcessGroup) -> str | None: value_ranks, value_header = _check_rng_sync_internal(generator, group) log_str = None if len(value_ranks) > 1: diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index c1e604bc86753..0a077bd6d4e5e 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -1,5 +1,4 @@ from datetime import timedelta -from typing import Optional from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT @@ -19,7 +18,7 @@ try: from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT - default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT + default_pg_nccl_timeout: timedelta | None = _DEFAULT_PG_NCCL_TIMEOUT except ImportError: # if C++ NCCL support is not compiled, we don't have access to the default nccl value. # if anyone is actually trying to use nccl in this state, it should error. diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 05ded47876a8c..86bdd44fa3656 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -65,7 +65,7 @@ def _init_device_mesh_stub(): "DeviceMesh requires numpy >= 1.21 to be installed for type checking" ) - BackendConfig = tuple[Optional[str], Optional[C10dBackend.Options]] + BackendConfig = tuple[str | None, C10dBackend.Options | None] torch.serialization.add_safe_globals([_MeshLayout]) class _MeshEnv(threading.local): @@ -175,7 +175,7 @@ class DeviceMesh: _device_type: str _rank_map: torch.Tensor - _mesh_dim_names: Optional[tuple[str, ...]] + _mesh_dim_names: tuple[str, ...] | None _layout: _MeshLayout _root_mesh: Optional["DeviceMesh"] = None # Record flatten mesh name to its flattened mesh in root mesh. @@ -184,14 +184,14 @@ class DeviceMesh: def __init__( self, device_type: str, - mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, + mesh: Union[torch.Tensor, "ArrayLike"] | None = None, *, - mesh_dim_names: Optional[tuple[str, ...]] = None, - backend_override: Optional[tuple[BackendConfig, ...]] = None, + mesh_dim_names: tuple[str, ...] | None = None, + backend_override: tuple[BackendConfig, ...] | None = None, _init_backend: bool = True, - _rank: Optional[int] = None, - _layout: Optional[_MeshLayout] = None, - _rank_map: Optional[torch.Tensor] = None, + _rank: int | None = None, + _layout: _MeshLayout | None = None, + _rank_map: torch.Tensor | None = None, _root_mesh: Optional["DeviceMesh"] = None, ) -> None: # no-op in OSS, logs API usage metrics in meta-internal runs @@ -292,7 +292,7 @@ def __init__( raise AssertionError( f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}" ) - self._coordinate_on_dim: Optional[list[int]] = ( + self._coordinate_on_dim: list[int] | None = ( rank_coords[0].tolist() if rank_coords.size(0) > 0 else None ) @@ -317,7 +317,7 @@ def mesh(self) -> torch.Tensor: ) @property - def mesh_dim_names(self) -> Optional[tuple[str, ...]]: + def mesh_dim_names(self) -> tuple[str, ...] | None: """Returns the names of mesh dimensions.""" return self._mesh_dim_names @@ -378,7 +378,7 @@ def _init_one_process_group( rank_map: torch.Tensor, dim_name: str, backend_override: BackendConfig, - ) -> Optional[str]: + ) -> str | None: # Generate a 2D global mesh tensor for the current dim for PG creation. pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map) backend, pg_options = backend_override @@ -471,7 +471,7 @@ def _init_one_process_group( def _init_process_groups( layout: _MeshLayout, rank_map: torch.Tensor, - mesh_dim_names: Optional[tuple[str, ...]], + mesh_dim_names: tuple[str, ...] | None, backend_override: tuple[BackendConfig, ...], ) -> list[str]: # group_name associated with each mesh dimension, each @@ -543,9 +543,7 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) - def __getitem__( - self, mesh_dim_names: Union[str, tuple[str, ...]] - ) -> "DeviceMesh": + def __getitem__(self, mesh_dim_names: str | tuple[str, ...]) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. The submesh created consists of the dimensions and the communicators indicated by @@ -613,7 +611,7 @@ def __getitem__( submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names) return submesh - def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: + def get_group(self, mesh_dim: int | str | None = None) -> ProcessGroup: """ Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. @@ -705,7 +703,7 @@ def _create_sub_mesh( def _create_flatten_mesh( self, - mesh_dim_name: Optional[str] = None, + mesh_dim_name: str | None = None, backend_override: BackendConfig = (None, None), ) -> "DeviceMesh": root_mesh = self._get_root_mesh() @@ -754,7 +752,7 @@ def _create_flatten_mesh( return res_flattened_mesh - def _get_root_mesh_dim(self) -> Optional[int]: + def _get_root_mesh_dim(self) -> int | None: """ Returns the index of the mesh dim in the root mesh. The device_mesh passed in needs to be sliced out from the root mesh @@ -893,11 +891,11 @@ def _get_all_submeshes(self, mesh_dim_name: str) -> list["DeviceMesh"]: @staticmethod def from_group( - group: Union[ProcessGroup, list[ProcessGroup]], + group: ProcessGroup | list[ProcessGroup], device_type: str, - mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, + mesh: Union[torch.Tensor, "ArrayLike"] | None = None, *, - mesh_dim_names: Optional[tuple[str, ...]] = None, + mesh_dim_names: tuple[str, ...] | None = None, ) -> "DeviceMesh": """ Constructs a :class:`DeviceMesh` with ``device_type`` from an @@ -986,7 +984,7 @@ def from_group( device_mesh._dim_group_names = [group.group_name for group in groups] return device_mesh - def size(self, mesh_dim: Optional[int] = None) -> int: + def size(self, mesh_dim: int | None = None) -> int: if mesh_dim is not None: return self._layout[mesh_dim].numel() return self._layout.numel() @@ -1005,7 +1003,7 @@ def get_rank(self) -> int: """ return get_rank() - def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: + def get_local_rank(self, mesh_dim: int | str | None = None) -> int: """ Returns the local rank of the given mesh_dim of the DeviceMesh. @@ -1049,7 +1047,7 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: ) return not_none(get_rank(mesh_dim_group)) - def get_coordinate(self) -> Optional[list[int]]: + def get_coordinate(self) -> list[int] | None: """ Return the relative indices of this rank relative to all dimensions of the mesh. If this rank is not part of the mesh, return None. @@ -1058,10 +1056,11 @@ def get_coordinate(self) -> Optional[list[int]]: def _flatten( self, - mesh_dim_name: Optional[str] = None, - backend_override: Union[ - None, str, C10dBackend.Options, tuple[str, C10dBackend.Options] - ] = None, + mesh_dim_name: str | None = None, + backend_override: None + | str + | C10dBackend.Options + | tuple[str, C10dBackend.Options] = None, ) -> "DeviceMesh": """ Returns a 1D DeviceMesh by flattening the current DeviceMesh. @@ -1095,7 +1094,7 @@ def _create_unflatten_mesh( mesh_sizes: tuple[int, ...], mesh_dim_names: tuple[str, ...], backend_override: tuple[ - tuple[Optional[str], Optional[C10dBackend.Options]], ... + tuple[str | None, C10dBackend.Options | None], ... ] = ((None, None),), ) -> "DeviceMesh": inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes)) @@ -1140,15 +1139,13 @@ def _create_unflatten_mesh( def _unflatten( self, - dim: Union[int, str], + dim: int | str, mesh_sizes: tuple[int, ...], mesh_dim_names: tuple[str, ...], - backend_override: Optional[ - dict[ - str, - Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], - ] - ] = None, + backend_override: dict[ + str, str | C10dBackend.Options | tuple[str, C10dBackend.Options] + ] + | None = None, ) -> "DeviceMesh": """ Returns a DeviceMesh by unflatten the current DeviceMesh. @@ -1239,11 +1236,11 @@ def _concatenate(device_mesh_list: list["DeviceMesh"]) -> "DeviceMesh": def _normalize_backend_override( backend_override: dict[ - Union[int, str], - Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], + int | str, + str | C10dBackend.Options | tuple[str, C10dBackend.Options], ], ndim: int, - mesh_dim_names: Optional[tuple[str, ...]] = None, + mesh_dim_names: tuple[str, ...] | None = None, ) -> Iterator[BackendConfig]: if mesh_dim_names is None: mesh_dim_names = () @@ -1278,13 +1275,11 @@ def init_device_mesh( device_type: str, mesh_shape: tuple[int, ...], *, - mesh_dim_names: Optional[tuple[str, ...]] = None, - backend_override: Optional[ - dict[ - Union[int, str], - Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], - ] - ] = None, + mesh_dim_names: tuple[str, ...] | None = None, + backend_override: dict[ + int | str, str | C10dBackend.Options | tuple[str, C10dBackend.Options] + ] + | None = None, ) -> DeviceMesh: """ Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 801716e3855ac..b7a3dbf33f91f 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -17,7 +17,7 @@ from collections import namedtuple from collections.abc import Callable from datetime import timedelta -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING from typing_extensions import deprecated import torch @@ -309,7 +309,7 @@ def register_backend( name, func, extended_api=False, - devices: Optional[Union[str, list[str]]] = None, + devices: str | list[str] | None = None, ) -> None: """ Register a new backend with the given name and instantiating function. @@ -504,10 +504,10 @@ def __init__( self, op: Callable, tensor: torch.Tensor, - peer: Optional[int] = None, - group: Optional[ProcessGroup] = None, + peer: int | None = None, + group: ProcessGroup | None = None, tag: int = 0, - group_peer: Optional[int] = None, + group_peer: int | None = None, ): """Init.""" self.op = op @@ -523,10 +523,10 @@ def __new__( cls, op: Callable, tensor: torch.Tensor, - peer: Optional[int] = None, - group: Optional[ProcessGroup] = None, + peer: int | None = None, + group: ProcessGroup | None = None, tag: int = 0, - group_peer: Optional[int] = None, + group_peer: int | None = None, ): """Create and return a new instance of the class.""" _check_op(op) @@ -566,9 +566,9 @@ def __init__( self, op: Callable, tensor: torch.Tensor, - dst_tensor: Optional[torch.Tensor] = None, - redop: Optional[ReduceOp] = None, - root: Optional[int] = None, + dst_tensor: torch.Tensor | None = None, + redop: ReduceOp | None = None, + root: int | None = None, ): self.op = op self.tensor = tensor @@ -587,7 +587,7 @@ def __init__( _group_count = 0 _tags_to_pg: dict[str, list[ProcessGroup]] = {} _pg_to_tag: dict[ProcessGroup, str] = {} -_backend: Optional[str] = None +_backend: str | None = None class _World: @@ -605,7 +605,7 @@ def __init__(self) -> None: self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {} @property - def default_pg(self) -> Optional[ProcessGroup]: + def default_pg(self) -> ProcessGroup | None: """ Process group that includes all ranks of the cluster. @@ -730,11 +730,11 @@ class _WorldMeta(type): # Points to the default PG once initialized. @property - def WORLD(cls) -> Optional[ProcessGroup]: + def WORLD(cls) -> ProcessGroup | None: return _world.default_pg @WORLD.setter - def WORLD(cls, pg: Optional[ProcessGroup]): + def WORLD(cls, pg: ProcessGroup | None): _world.default_pg = pg @@ -772,12 +772,12 @@ def _check_valid_timeout(timeout: Any) -> None: # Default process group state -_default_pg_init_method: Optional[str] = None +_default_pg_init_method: str | None = None STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" -def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str: +def _get_object_coll_device(group: ProcessGroup | None = None) -> str: """ .. note:: This is an internal helper and does not have backward compatibility, please use with caution. @@ -843,7 +843,7 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str: return devices[0].type -def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device: +def _get_pg_default_device(group: ProcessGroup | None = None) -> torch.device: """ .. note:: This method will be deprecated, it only stays for backward-compatiblity reason. Alternatives: @@ -923,7 +923,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device return rv -def _device_capability(group: Optional[ProcessGroup] = None) -> list[str]: +def _device_capability(group: ProcessGroup | None = None) -> list[str]: """ Return the device type(s) supported by ``group``. @@ -1007,7 +1007,7 @@ def _store_based_barrier( ) -def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool: +def _rank_not_in_group(group: ProcessGroup | None) -> bool: """Check if the current process's rank is not in a given group.""" if group is None: return False @@ -1089,7 +1089,7 @@ def _get_global_rank(group, rank) -> int: return get_global_rank(group, rank) -def get_process_group_ranks(group: Optional[ProcessGroup]) -> list[int]: +def get_process_group_ranks(group: ProcessGroup | None) -> list[int]: """ Get all ranks associated with ``group``. @@ -1148,7 +1148,7 @@ def _check_tensor_list(param, param_name) -> None: ) -def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup: +def _group_or_default_group(group: ProcessGroup | None = None) -> ProcessGroup: if group is None or group is GroupMember.WORLD: group = _get_default_group() return group @@ -1156,8 +1156,8 @@ def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGrou def _canonicalize_group_rank( group: ProcessGroup, - global_rank: Optional[int] = None, - group_rank: Optional[int] = None, + global_rank: int | None = None, + group_rank: int | None = None, return_global: bool = False, ) -> int: """ @@ -1361,7 +1361,7 @@ def _update_default_pg(pg) -> None: torch._C._distributed_c10d._set_global_rank(rank) -def get_backend_config(group: Optional[ProcessGroup] = None) -> str: +def get_backend_config(group: ProcessGroup | None = None) -> str: """ Return the backend configuration of the given process group. @@ -1381,7 +1381,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str: return str(not_none(backend_config)) -def get_backend(group: Optional[ProcessGroup] = None) -> Backend: +def get_backend(group: ProcessGroup | None = None) -> Backend: """ Return the backend of the given process group. @@ -1407,7 +1407,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend: return Backend(not_none(pg_store)[0]) -def get_default_backend_for_device(device: Union[str, torch.device]) -> str: +def get_default_backend_for_device(device: str | torch.device) -> str: """ Return the default backend for the given device. @@ -1441,7 +1441,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int: return -1 -def _get_pg_config(group: Optional[ProcessGroup] = None) -> dict[str, Any]: +def _get_pg_config(group: ProcessGroup | None = None) -> dict[str, Any]: """ Return the pg configuration of the given process group. @@ -1473,7 +1473,7 @@ def get_pg_count() -> int: return _world.group_count -def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: +def get_node_local_rank(fallback_rank: int | None = None) -> int: """ Return the local rank of the current process relative to the node. @@ -1526,7 +1526,7 @@ def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None: backend._add_ephemeral_timeout(timeout) -def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None: +def _set_pg_timeout(timeout: timedelta, group: ProcessGroup | None = None) -> None: """ Set the timeout for the given process group when users want to use a different timeout instead of default values. @@ -1575,16 +1575,16 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> @_exception_logger @_time_logger def init_process_group( - backend: Optional[str] = None, - init_method: Optional[str] = None, - timeout: Optional[timedelta] = None, + backend: str | None = None, + init_method: str | None = None, + timeout: timedelta | None = None, world_size: int = -1, rank: int = -1, - store: Optional[Store] = None, + store: Store | None = None, group_name: str = "", - pg_options: Optional[Any] = None, - device_id: Optional[Union[torch.device, int]] = None, - _ranks: Optional[list[int]] = None, + pg_options: Any | None = None, + device_id: torch.device | int | None = None, + _ranks: list[int] | None = None, ) -> None: """ Initialize the default distributed process group. @@ -2216,7 +2216,7 @@ def _new_process_group_helper( return pg, prefix_store -def destroy_process_group(group: Optional[ProcessGroup] = None): +def destroy_process_group(group: ProcessGroup | None = None): """ Destroy a given process group, and deinitialize the distributed package. @@ -2305,7 +2305,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): _unregister_process_group(pg.group_name) -def _abort_process_group(group: Optional[ProcessGroup] = None): +def _abort_process_group(group: ProcessGroup | None = None): """ Abort a given process group. If group.WORLD (i.e. `None`) is given, all process groups including the default one will be aborted. @@ -2397,7 +2397,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None): _unregister_process_group(pg.group_name) -def get_rank(group: Optional[ProcessGroup] = None) -> int: +def get_rank(group: ProcessGroup | None = None) -> int: """ Return the rank of the current process in the provided ``group``, default otherwise. @@ -2424,7 +2424,7 @@ def get_rank(group: Optional[ProcessGroup] = None) -> int: return get_group_rank(group, default_pg.rank()) -def get_world_size(group: Optional[ProcessGroup] = None) -> int: +def get_world_size(group: ProcessGroup | None = None) -> int: """ Return the number of processes in the current process group. @@ -2445,11 +2445,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int: def isend( tensor: torch.Tensor, - dst: Optional[int] = None, - group: Optional[ProcessGroup] = None, + dst: int | None = None, + group: ProcessGroup | None = None, tag: int = 0, - group_dst: Optional[int] = None, -) -> Optional[Work]: + group_dst: int | None = None, +) -> Work | None: """ Send a tensor asynchronously. @@ -2490,11 +2490,11 @@ def isend( def irecv( tensor: torch.Tensor, - src: Optional[int] = None, - group: Optional[ProcessGroup] = None, + src: int | None = None, + group: ProcessGroup | None = None, tag: int = 0, - group_src: Optional[int] = None, -) -> Optional[Work]: + group_src: int | None = None, +) -> Work | None: """ Receives a tensor asynchronously. @@ -2536,10 +2536,10 @@ def irecv( @_exception_logger def send( tensor: torch.Tensor, - dst: Optional[int] = None, - group: Optional[ProcessGroup] = None, + dst: int | None = None, + group: ProcessGroup | None = None, tag: int = 0, - group_dst: Optional[int] = None, + group_dst: int | None = None, ) -> None: """ Send a tensor synchronously. @@ -2568,10 +2568,10 @@ def send( @_exception_logger def recv( tensor: torch.Tensor, - src: Optional[int] = None, - group: Optional[ProcessGroup] = None, + src: int | None = None, + group: ProcessGroup | None = None, tag: int = 0, - group_src: Optional[int] = None, + group_src: int | None = None, ) -> int: """ Receives a tensor synchronously. @@ -2623,7 +2623,7 @@ class _CoalescingManager: def __init__(self) -> None: self.works: list[Work] = [] - def append(self, work: Optional[Work] = None): + def append(self, work: Work | None = None): if work: self.works.append(work) @@ -2634,8 +2634,8 @@ def wait(self): @contextlib.contextmanager def _coalescing_manager( - group: Optional[ProcessGroup] = None, - device: Optional[torch.device] = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, async_ops: bool = False, ): """ @@ -2731,13 +2731,13 @@ def _coalescing_manager( class _TimeEstimator: def __init__(self) -> None: - self.estimated_time: Optional[float] = None + self.estimated_time: float | None = None @contextlib.contextmanager def _time_estimator( - group: Optional[ProcessGroup] = None, - device: Optional[torch.device] = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, ): """ Context manager used to estimate time of collectives. @@ -2862,10 +2862,10 @@ def peer_kwarg(op: P2POp) -> dict[str, int]: @_exception_logger def broadcast( tensor: torch.Tensor, - src: Optional[int] = None, - group: Optional[ProcessGroup] = None, + src: int | None = None, + group: ProcessGroup | None = None, async_op: bool = False, - group_src: Optional[int] = None, + group_src: int | None = None, ): """ Broadcasts the tensor to the whole group. @@ -3084,11 +3084,11 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): @_exception_logger def reduce( tensor: torch.Tensor, - dst: Optional[int] = None, + dst: int | None = None, op=ReduceOp.SUM, - group: Optional[ProcessGroup] = None, + group: ProcessGroup | None = None, async_op: bool = False, - group_dst: Optional[int] = None, + group_dst: int | None = None, ): """ Reduces the tensor data across all machines. @@ -3268,10 +3268,10 @@ def all_gather_object(object_list, obj, group=None): @_exception_logger def gather_object( obj: Any, - object_gather_list: Optional[list[Any]] = None, - dst: Optional[int] = None, - group: Optional[ProcessGroup] = None, - group_dst: Optional[int] = None, + object_gather_list: list[Any] | None = None, + dst: int | None = None, + group: ProcessGroup | None = None, + group_dst: int | None = None, ): """ Gathers picklable objects from the whole group in a single process. @@ -3399,10 +3399,10 @@ def gather_object( @_exception_logger def send_object_list( object_list: list[Any], - dst: Optional[int] = None, - group: Optional[ProcessGroup] = None, - device: Optional[torch.device] = None, - group_dst: Optional[int] = None, + dst: int | None = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, + group_dst: int | None = None, use_batch: bool = False, ): """ @@ -3517,10 +3517,10 @@ def send_object_list( @_exception_logger def recv_object_list( object_list: list[Any], - src: Optional[int] = None, - group: Optional[ProcessGroup] = None, - device: Optional[torch.device] = None, - group_src: Optional[int] = None, + src: int | None = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, + group_src: int | None = None, use_batch: bool = False, ): """ @@ -3659,10 +3659,10 @@ def recv_object_list( @_exception_logger def broadcast_object_list( object_list: list[Any], - src: Optional[int] = None, - group: Optional[ProcessGroup] = None, - device: Optional[torch.device] = None, - group_src: Optional[int] = None, + src: int | None = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, + group_src: int | None = None, ): """ Broadcasts picklable objects in ``object_list`` to the whole group. @@ -3791,10 +3791,10 @@ def broadcast_object_list( @_exception_logger def scatter_object_list( scatter_object_output_list: list[Any], - scatter_object_input_list: Optional[list[Any]] = None, - src: Optional[int] = None, - group: Optional[ProcessGroup] = None, - group_src: Optional[int] = None, + scatter_object_input_list: list[Any] | None = None, + src: int | None = None, + group: ProcessGroup | None = None, + group_src: int | None = None, ): """ Scatters picklable objects in ``scatter_object_input_list`` to the whole group. @@ -4265,11 +4265,11 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): @_exception_logger def gather( tensor: torch.Tensor, - gather_list: Optional[list[torch.Tensor]] = None, - dst: Optional[int] = None, - group: Optional[ProcessGroup] = None, + gather_list: list[torch.Tensor] | None = None, + dst: int | None = None, + group: ProcessGroup | None = None, async_op: bool = False, - group_dst: Optional[int] = None, + group_dst: int | None = None, ): """ Gathers a list of tensors in a single process. @@ -4348,11 +4348,11 @@ def gather( @_exception_logger def scatter( tensor: torch.Tensor, - scatter_list: Optional[list[torch.Tensor]] = None, - src: Optional[int] = None, - group: Optional[ProcessGroup] = None, + scatter_list: list[torch.Tensor] | None = None, + src: int | None = None, + group: ProcessGroup | None = None, async_op: bool = False, - group_src: Optional[int] = None, + group_src: int | None = None, ): """ Scatters a list of tensors to all processes in a group. @@ -4895,7 +4895,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False @_exception_logger def barrier( - group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None + group: ProcessGroup | None = GroupMember.WORLD, async_op=False, device_ids=None ): """ Synchronize all processes. @@ -4967,7 +4967,7 @@ def barrier( def monitored_barrier( - group: Optional[ProcessGroup] = GroupMember.WORLD, + group: ProcessGroup | None = GroupMember.WORLD, timeout=None, wait_all_ranks=False, ): @@ -5104,7 +5104,7 @@ def _process_group_name(ranks, use_hashed_name): return pg_name -def _get_backend_from_str(backend: Optional[str] = None) -> Backend: +def _get_backend_from_str(backend: str | None = None) -> Backend: # Default to the same backend as the global process group # if backend is not specified. if not backend: @@ -5124,12 +5124,12 @@ def _is_safe_to_split() -> bool: @_time_logger def split_group( - parent_pg: Optional[ProcessGroup] = None, - split_ranks: Optional[list] = None, - timeout: Optional[timedelta] = None, - pg_options: Optional[Any] = None, - group_desc: Optional[str] = None, -) -> Optional[ProcessGroup]: + parent_pg: ProcessGroup | None = None, + split_ranks: list | None = None, + timeout: timedelta | None = None, + pg_options: Any | None = None, + group_desc: str | None = None, +) -> ProcessGroup | None: """ Create a new process group split from the given parent process group. @@ -5290,7 +5290,7 @@ def new_group( pg_options=None, use_local_synchronization=False, group_desc=None, - device_id: Optional[torch.device] = None, + device_id: torch.device | None = None, ): """ Create a new distributed group. @@ -5380,7 +5380,7 @@ def _new_group_with_tag( pg_tag=None, use_local_synchronization=False, group_desc=None, - device_id: Optional[torch.device] = None, + device_id: torch.device | None = None, ): """ Variant of ``new_group`` that exposes tag creation. @@ -5693,7 +5693,7 @@ def new_subgroups_by_enumeration( return cur_subgroup, subgroups -def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGroup]: +def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> ProcessGroup | None: if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"): tag = f"user:{tag}" @@ -5765,9 +5765,9 @@ def _get_process_group_store(pg: ProcessGroup) -> Store: @_time_logger def shrink_group( ranks_to_exclude: list[int], - group: Optional[ProcessGroup] = None, + group: ProcessGroup | None = None, shrink_flags: int = SHRINK_DEFAULT, - pg_options: Optional[Any] = None, + pg_options: Any | None = None, ) -> ProcessGroup: """ Shrinks a process group by excluding specified ranks. @@ -5857,7 +5857,7 @@ def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> N ) -def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict: +def _prepare_shrink_target_group(group: ProcessGroup | None) -> dict: """Prepare and validate the target group for shrinking.""" target_pg = group if group is not None else _get_default_group() @@ -6107,7 +6107,7 @@ def _create_shrunk_process_group( return new_pg -def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None: +def _destroy_all_other_groups(exclude_group: ProcessGroup | None = None) -> None: """ Destroy all process groups except the excluded group and clean up all global state. @@ -6223,9 +6223,9 @@ def _update_process_group_global_state( store: Store, group_name: str, backend_config: str, - rank_mapping: Optional[dict[int, int]] = None, - pg_tag: Optional[str] = None, - user_tag: Optional[str] = None, + rank_mapping: dict[int, int] | None = None, + pg_tag: str | None = None, + user_tag: str | None = None, ) -> None: """ Update all global state dictionaries for a process group. diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 1122913ed95db..2575aa137a581 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional, Union +from typing import Any import torch.distributed.elastic.rendezvous as rdzv import torch.distributed.elastic.utils.store as store_util @@ -89,19 +89,19 @@ class WorkerSpec: role: str local_world_size: int rdzv_handler: rdzv.RendezvousHandler - fn: Optional[Callable] = None + fn: Callable | None = None # TODO @kiuk - make entrypoint a required field - entrypoint: Union[Callable, str, None] = None + entrypoint: Callable | str | None = None args: tuple = () max_restarts: int = 3 monitor_interval: float = 0.1 - master_port: Optional[int] = None - master_addr: Optional[str] = None - local_addr: Optional[str] = None + master_port: int | None = None + master_addr: str | None = None + local_addr: str | None = None event_log_handler: str = "null" - numa_options: Optional[NumaOptions] = None - duplicate_stdout_filters: Optional[list[str]] = None - duplicate_stderr_filters: Optional[list[str]] = None + numa_options: NumaOptions | None = None + duplicate_stdout_filters: list[str] | None = None + duplicate_stderr_filters: list[str] | None = None virtual_local_rank: bool = False def __post_init__(self): @@ -807,11 +807,11 @@ def _construct_event( self, state: str, source: EventSource, - worker: Optional[Worker] = None, - raw_error: Optional[str] = None, - duration_ms: Optional[float] = None, - exit_code: Optional[int] = None, - worker_pid: Optional[int] = None, + worker: Worker | None = None, + raw_error: str | None = None, + duration_ms: float | None = None, + exit_code: int | None = None, + worker_pid: int | None = None, ) -> Event: wg = self._worker_group spec = wg.spec diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 5fd3b7d3526db..ef281b6c58c31 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -15,7 +15,7 @@ import time import uuid from string import Template -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch.distributed.elastic.timer as timer from torch.distributed.elastic import events @@ -152,16 +152,16 @@ def __init__( logs_specs: LogsSpecs, start_method="spawn", exit_barrier_timeout: float = 300, - log_line_prefix_template: Optional[str] = None, + log_line_prefix_template: str | None = None, ): super().__init__(spec, exit_barrier_timeout) self._start_method = start_method - self._pcontext: Optional[PContext] = None + self._pcontext: PContext | None = None self._rdzv_handler = spec.rdzv_handler self._log_line_prefix_template = log_line_prefix_template - self._worker_watchdog: Optional[timer.FileTimerServer] = None + self._worker_watchdog: timer.FileTimerServer | None = None self._logs_specs = logs_specs - self._health_check_server: Optional[HealthCheckServer] = None + self._health_check_server: HealthCheckServer | None = None def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None: enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER @@ -244,7 +244,7 @@ def _get_fq_hostname(self) -> str: def _log_watchdog_event( self, name: str, - request: Optional[timer.FileTimerRequest], + request: timer.FileTimerRequest | None, ) -> None: wg = self._worker_group spec = wg.spec @@ -297,7 +297,7 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: args: dict[int, tuple] = {} envs: dict[int, dict[str, str]] = {} - log_line_prefixes: Optional[dict[int, str]] = ( + log_line_prefixes: dict[int, str] | None = ( {} if self._log_line_prefix_template else None ) for worker in worker_group.workers: diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index 02e158b021a0e..deea40f3899ae 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -86,10 +86,10 @@ def construct_and_record_rdzv_event( node_state: NodeState, name: str = "", hostname: str = "", - pid: Optional[int] = None, + pid: int | None = None, master_endpoint: str = "", - local_id: Optional[int] = None, - rank: Optional[int] = None, + local_id: int | None = None, + rank: int | None = None, ) -> None: """ Initialize rendezvous event object and record its operations. diff --git a/torch/distributed/elastic/events/api.py b/torch/distributed/elastic/events/api.py index 939ab0793f65d..31afe29ff5f59 100644 --- a/torch/distributed/elastic/events/api.py +++ b/torch/distributed/elastic/events/api.py @@ -10,7 +10,7 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Optional, Union +from typing import Union __all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] @@ -95,8 +95,8 @@ class RdzvEvent: pid: int node_state: NodeState master_endpoint: str = "" - rank: Optional[int] = None - local_id: Optional[int] = None + rank: int | None = None + local_id: int | None = None error_trace: str = "" def __str__(self): diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py index b07671fbac9d3..b2c2330924879 100644 --- a/torch/distributed/elastic/metrics/__init__.py +++ b/torch/distributed/elastic/metrics/__init__.py @@ -158,7 +158,7 @@ def emit(self, metric_data): ) -def initialize_metrics(cfg: Optional[MetricsConfig] = None): +def initialize_metrics(cfg: MetricsConfig | None = None): pass diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 07d0f9fc43cc7..102049481538d 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -11,7 +11,6 @@ import time from collections import namedtuple from functools import wraps -from typing import Optional from typing_extensions import deprecated @@ -37,7 +36,7 @@ class MetricsConfig: __slots__ = ["params"] - def __init__(self, params: Optional[dict[str, str]] = None): + def __init__(self, params: dict[str, str] | None = None): self.params = params if self.params is None: self.params = {} @@ -77,7 +76,7 @@ def add_value(self, metric_name: str, metric_value: int): # pyre-fixme[9]: group has type `str`; used as `None`. -def configure(handler: MetricHandler, group: Optional[str] = None): +def configure(handler: MetricHandler, group: str | None = None): if group is None: global _default_metrics_handler # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index a68968bac8f4d..60b7cd32fd253 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -102,15 +102,15 @@ def trainer(a, b, c): def start_processes( name: str, - entrypoint: Union[Callable, str], + entrypoint: Callable | str, args: dict[int, tuple], envs: dict[int, dict[str, str]], logs_specs: LogsSpecs, - log_line_prefixes: Optional[dict[int, str]] = None, + log_line_prefixes: dict[int, str] | None = None, start_method: str = "spawn", - numa_options: Optional[NumaOptions] = None, - duplicate_stdout_filters: Optional[list[str]] = None, - duplicate_stderr_filters: Optional[list[str]] = None, + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, ) -> PContext: """ Start ``n`` copies of ``entrypoint`` processes with the provided options. diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index dd1633252cb48..41252bc35e00b 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -25,7 +25,7 @@ from enum import IntFlag from multiprocessing import synchronize from types import FrameType -from typing import Any, Optional, TextIO, Union +from typing import Any, TextIO, Union import torch.multiprocessing as mp from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record @@ -73,7 +73,7 @@ def __init__(self, msg: str, sigval: signal.Signals) -> None: self.sigval = sigval -def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None: +def _terminate_process_handler(signum: int, frame: FrameType | None) -> None: """Termination handler that raises exceptions on the main process. When the process receives death signal(SIGTERM, SIGINT), this termination handler will @@ -156,9 +156,7 @@ def to_std(v: str) -> Std: # type: ignore[return] ) -def to_map( - val_or_map: Union[Std, dict[int, Std]], local_world_size: int -) -> dict[int, Std]: +def to_map(val_or_map: Std | dict[int, Std], local_world_size: int) -> dict[int, Std]: """ Certain APIs take redirect settings either as a single value (e.g. apply to all local ranks) or as an explicit user-provided mapping. This method is a convenience @@ -216,10 +214,10 @@ class LogsSpecs(ABC): def __init__( self, - log_dir: Optional[str] = None, - redirects: Union[Std, dict[int, Std]] = Std.NONE, - tee: Union[Std, dict[int, Std]] = Std.NONE, - local_ranks_filter: Optional[set[int]] = None, + log_dir: str | None = None, + redirects: Std | dict[int, Std] = Std.NONE, + tee: Std | dict[int, Std] = Std.NONE, + local_ranks_filter: set[int] | None = None, ) -> None: self._root_log_dir = log_dir self._redirects = redirects @@ -254,10 +252,10 @@ class DefaultLogsSpecs(LogsSpecs): def __init__( self, - log_dir: Optional[str] = None, - redirects: Union[Std, dict[int, Std]] = Std.NONE, - tee: Union[Std, dict[int, Std]] = Std.NONE, - local_ranks_filter: Optional[set[int]] = None, + log_dir: str | None = None, + redirects: Std | dict[int, Std] = Std.NONE, + tee: Std | dict[int, Std] = Std.NONE, + local_ranks_filter: set[int] | None = None, ) -> None: if log_dir != os.devnull: if not log_dir: @@ -275,7 +273,7 @@ def __init__( def root_log_dir(self) -> str: return str(self._root_log_dir) - def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str): + def _make_log_dir(self, log_dir: str | None, rdzv_run_id: str): base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_") os.makedirs(base_log_dir, exist_ok=True) dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir) @@ -465,13 +463,13 @@ class PContext(abc.ABC): def __init__( self, name: str, - entrypoint: Union[Callable, str], + entrypoint: Callable | str, args: dict[int, tuple], envs: dict[int, dict[str, str]], logs_specs: LogsSpecs, - log_line_prefixes: Optional[dict[int, str]] = None, - duplicate_stdout_filters: Optional[list[str]] = None, - duplicate_stderr_filters: Optional[list[str]] = None, + log_line_prefixes: dict[int, str] | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, ): self.name = name # validate that all mappings have the same number of keys and @@ -491,8 +489,8 @@ def __init__( self.stderrs = logs_dest.stderrs self.error_files = logs_dest.error_files self.nprocs = nprocs - self.filtered_stdout: Optional[TextIO] = None - self.filtered_stderr: Optional[TextIO] = None + self.filtered_stdout: TextIO | None = None + self.filtered_stderr: TextIO | None = None self._tail_logs = [ TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes), @@ -582,7 +580,7 @@ def _start(self) -> None: raise NotImplementedError @abc.abstractmethod - def _poll(self) -> Optional[RunProcsResult]: + def _poll(self) -> RunProcsResult | None: """ Poll the run status of the processes running under this context. This method follows an "all-or-nothing" policy and returns @@ -592,7 +590,7 @@ def _poll(self) -> Optional[RunProcsResult]: """ raise NotImplementedError - def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]: + def wait(self, timeout: float = -1, period: float = 1) -> RunProcsResult | None: """ Wait for the specified ``timeout`` seconds, polling every ``period`` seconds for the processes to be done. Returns ``None`` if the processes are still running @@ -646,9 +644,7 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: """ raise NotImplementedError - def close( - self, death_sig: Optional[signal.Signals] = None, timeout: int = 30 - ) -> None: + def close(self, death_sig: signal.Signals | None = None, timeout: int = 30) -> None: r""" Terminates all processes managed by this context and cleans up any meta resources (e.g. redirect, error_file files). @@ -685,7 +681,7 @@ def _wrap( stderr_redirects: dict[int, str], # redirect file for stderr (to console if None) ret_vals: dict[int, mp.SimpleQueue], queue_finished_reading_event: synchronize.Event, - numa_options: Optional[NumaOptions], + numa_options: NumaOptions | None, ) -> None: # get the per-rank params up front so we fail fast if no mapping is found args_ = args[local_rank] @@ -721,10 +717,10 @@ def __init__( envs: dict[int, dict[str, str]], start_method: str, logs_specs: LogsSpecs, - log_line_prefixes: Optional[dict[int, str]] = None, - numa_options: Optional[NumaOptions] = None, - duplicate_stdout_filters: Optional[list[str]] = None, - duplicate_stderr_filters: Optional[list[str]] = None, + log_line_prefixes: dict[int, str] | None = None, + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, ): super().__init__( name, @@ -746,12 +742,12 @@ def __init__( # see comments in ``join()`` for what this is self._return_values: dict[int, Any] = {} - self._pc: Optional[mp.ProcessContext] = None + self._pc: mp.ProcessContext | None = None # Note: set method should ONLY be invoked for the use case when all processes finished # successfully. If any process died on event.wait() calling set() method will deadlock. self._worker_finished_event = mp.get_context(self.start_method).Event() - self._numa_options: Optional[NumaOptions] = numa_options + self._numa_options: NumaOptions | None = numa_options def _start(self): if self._pc: @@ -780,7 +776,7 @@ def _start(self): def _is_done(self) -> bool: return len(self._return_values) == self.nprocs - def _poll(self) -> Optional[RunProcsResult]: + def _poll(self) -> RunProcsResult | None: assert self._pc is not None # assertion for mypy type checker try: @@ -910,10 +906,10 @@ def __init__( args: dict[int, tuple], envs: dict[int, dict[str, str]], logs_specs: LogsSpecs, - log_line_prefixes: Optional[dict[int, str]] = None, - numa_options: Optional[NumaOptions] = None, - duplicate_stdout_filters: Optional[list[str]] = None, - duplicate_stderr_filters: Optional[list[str]] = None, + log_line_prefixes: dict[int, str] | None = None, + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, ): super().__init__( name, @@ -930,7 +926,7 @@ def __init__( self._running_local_ranks: set[int] = set(range(self.nprocs)) self._failures: dict[int, ProcessFailure] = {} self.subprocess_handlers: dict[int, SubprocessHandler] = {} - self._numa_options: Optional[NumaOptions] = numa_options + self._numa_options: NumaOptions | None = numa_options def _start(self): if self.subprocess_handlers: @@ -965,7 +961,7 @@ def _capture_process_failures(self, done_local_ranks: set[int]): ) # else: --> succeeded; nothing to do - def _poll(self) -> Optional[RunProcsResult]: + def _poll(self) -> RunProcsResult | None: done_local_ranks: set[int] = set() self._capture_process_failures(done_local_ranks) diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index fa6abc8794b65..f61c99dc5c777 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -312,8 +312,8 @@ def _format_failure( def record( - fn: Callable[_P, _R], error_handler: Optional[ErrorHandler] = None -) -> Callable[_P, Union[_R, None]]: + fn: Callable[_P, _R], error_handler: ErrorHandler | None = None +) -> Callable[_P, _R | None]: """ Syntactic sugar to record errors/exceptions that happened in the decorated function using the provided ``error_handler``. @@ -353,7 +353,7 @@ def main(): if not error_handler: error_handler = get_error_handler() - def wrap(f: Callable[_P, _R]) -> Callable[_P, Union[_R, None]]: + def wrap(f: Callable[_P, _R]) -> Callable[_P, _R | None]: @wraps(f) def wrapper(*args: _P.args, **kwargs: _P.kwargs): assert error_handler is not None # assertion for mypy type checker diff --git a/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/torch/distributed/elastic/multiprocessing/errors/error_handler.py index 437a9c07d2cf9..ab6613e54dee1 100644 --- a/torch/distributed/elastic/multiprocessing/errors/error_handler.py +++ b/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -13,7 +13,7 @@ import time import traceback import warnings -from typing import Any, Optional +from typing import Any __all__ = ["ErrorHandler"] @@ -33,7 +33,7 @@ class ErrorHandler: Subclasses should override ``initialize()`` and ``record_exception()``. """ - def _get_error_file_path(self) -> Optional[str]: + def _get_error_file_path(self) -> str | None: """ Return the error file path. diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index 947ce7b001ef7..ea1742626e285 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( SubprocessHandler, @@ -21,7 +20,7 @@ def get_subprocess_handler( stdout: str, stderr: str, local_rank_id: int, - numa_options: Optional[NumaOptions] = None, + numa_options: NumaOptions | None = None, ) -> SubprocessHandler: return SubprocessHandler( entrypoint=entrypoint, diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index eae4e632e0856..d4642541a191c 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -9,7 +9,7 @@ import signal import sys from subprocess import Popen -from typing import Any, Optional +from typing import Any from torch.numa.binding import maybe_wrap_command_args_with_numa_binding, NumaOptions @@ -38,10 +38,10 @@ def __init__( entrypoint: str, args: tuple, env: dict[str, str], - stdout: Optional[str], - stderr: Optional[str], + stdout: str | None, + stderr: str | None, local_rank_id: int, - numa_options: Optional[NumaOptions], + numa_options: NumaOptions | None, ): self._stdout = open(stdout, "w") if stdout else None self._stderr = open(stderr, "w") if stderr else None @@ -76,7 +76,7 @@ def _popen(self, args: tuple, env: dict[str, str]) -> Popen: **kwargs, ) - def close(self, death_sig: Optional[signal.Signals] = None) -> None: + def close(self, death_sig: signal.Signals | None = None) -> None: if not death_sig: death_sig = _get_default_signal() if IS_WINDOWS: diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index ad7c37e82c098..77d410cce55c0 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -13,7 +13,7 @@ from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Optional, TextIO, TYPE_CHECKING +from typing import TextIO, TYPE_CHECKING if TYPE_CHECKING: @@ -30,7 +30,7 @@ def tail_logfile( dst: TextIO, finished: Event, interval_sec: float, - log_line_filter: Optional[Callable[[str], bool]] = None, + log_line_filter: Callable[[str], bool] | None = None, ): while not os.path.exists(file): if finished.is_set(): @@ -98,7 +98,7 @@ def __init__( name: str, log_files: dict[int, str], dst: TextIO, - log_line_prefixes: Optional[dict[int, str]] = None, + log_line_prefixes: dict[int, str] | None = None, interval_sec: float = 0.1, log_line_filter: Callable[[str], bool] = (lambda _: True), ): diff --git a/torch/distributed/elastic/rendezvous/_etcd_stub.py b/torch/distributed/elastic/rendezvous/_etcd_stub.py index 066a1c973e4d9..5890a97c672a6 100644 --- a/torch/distributed/elastic/rendezvous/_etcd_stub.py +++ b/torch/distributed/elastic/rendezvous/_etcd_stub.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Optional +from typing import Any """ @@ -65,11 +65,11 @@ def read(self, key: str) -> None: raise EtcdStubError def write( - self, key: str, value: Any, ttl: Optional[int] = None, **kwargs: Any + self, key: str, value: Any, ttl: int | None = None, **kwargs: Any ) -> None: raise EtcdStubError def test_and_set( - self, key: str, value: Any, prev_value: Any, ttl: Optional[int] = None + self, key: str, value: Any, prev_value: Any, ttl: int | None = None ) -> None: raise EtcdStubError diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 9e66c0228daa7..2b3fa8183dfb8 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar from torch.distributed import Store from torch.distributed.elastic.utils.distributed import get_free_port @@ -72,8 +72,8 @@ class RendezvousStoreInfo: def build( rank: int, store: Store, - local_addr: Optional[str], - server_port: Optional[int] = None, + local_addr: str | None, + server_port: int | None = None, ) -> "RendezvousStoreInfo": """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. @@ -137,7 +137,7 @@ def world_size(self) -> int: return self._world_size @property - def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]: + def bootstrap_store_info(self) -> RendezvousStoreInfo | None: """Store information that can used by trainer code to bootstrap distributed comms.""" return self._bootstrap_store_info @@ -265,7 +265,7 @@ def __init__( run_id: str, min_nodes: int, max_nodes: int, - local_addr: Optional[str] = None, + local_addr: str | None = None, **kwargs, ): if not backend: @@ -293,7 +293,7 @@ def get(self, key: str, default: Any = None) -> Any: """Return the value for ``key`` if ``key`` exists, else ``default``.""" return self.config.get(key, default) - def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]: + def get_as_bool(self, key: str, default: bool | None = None) -> bool | None: """Return the value for ``key`` as a ``bool``.""" value = self.get(key, default) if value is None or isinstance(value, bool): @@ -312,7 +312,7 @@ def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool f"The rendezvous configuration option '{key}' does not represent a valid boolean value." ) - def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]: + def get_as_int(self, key: str, default: int | None = None) -> int | None: """Return the value for ``key`` as an ``int``.""" value = self.get(key, default) if value is None: @@ -350,7 +350,7 @@ def register(self, backend: str, creator: RendezvousHandlerCreator) -> None: if not backend: raise ValueError("The rendezvous backend name must be a non-empty string.") - current_creator: Optional[RendezvousHandlerCreator] + current_creator: RendezvousHandlerCreator | None try: current_creator = self._registry[backend] except KeyError: diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index 982ff267a06a9..0296c4d45ddc1 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -11,7 +11,7 @@ import tempfile from base64 import b64decode, b64encode from datetime import timedelta -from typing import Any, cast, Optional +from typing import Any, cast from torch.distributed import FileStore, Store, TCPStore from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState @@ -70,15 +70,15 @@ def name(self) -> str: """See base class.""" return "c10d" - def get_state(self) -> Optional[tuple[bytes, Token]]: + def get_state(self) -> tuple[bytes, Token] | None: """See base class.""" base64_state: bytes = self._call_store("get", self._key) return self._decode_state(base64_state) def set_state( - self, state: bytes, token: Optional[Token] = None - ) -> Optional[tuple[bytes, Token, bool]]: + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: """See base class.""" base64_state_str: str = b64encode(state).decode() @@ -117,7 +117,7 @@ def _call_store(self, store_op: str, *args, **kwargs) -> Any: "The connection to the C10d store has failed. See inner exception for details." ) from exc - def _decode_state(self, base64_state: bytes) -> Optional[tuple[bytes, Token]]: + def _decode_state(self, base64_state: bytes) -> tuple[bytes, Token] | None: if base64_state == self._NULL_SENTINEL.encode(): return None diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index 2a0e44aef31af..35496e62ba6ac 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta, timezone from enum import Enum -from typing import Any, Optional +from typing import Any import torch.distributed as dist from torch.distributed import Store @@ -68,7 +68,7 @@ def name(self) -> str: """Get the name of the backend.""" @abstractmethod - def get_state(self) -> Optional[tuple[bytes, Token]]: + def get_state(self) -> tuple[bytes, Token] | None: """Get the rendezvous state. Returns: @@ -84,8 +84,8 @@ def get_state(self) -> Optional[tuple[bytes, Token]]: @abstractmethod def set_state( - self, state: bytes, token: Optional[Token] = None - ) -> Optional[tuple[bytes, Token, bool]]: + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: """Set the rendezvous state. The new rendezvous state is set conditionally: @@ -154,10 +154,10 @@ class RendezvousTimeout: def __init__( self, - join: Optional[timedelta] = None, - last_call: Optional[timedelta] = None, - close: Optional[timedelta] = None, - heartbeat: Optional[timedelta] = None, + join: timedelta | None = None, + last_call: timedelta | None = None, + close: timedelta | None = None, + heartbeat: timedelta | None = None, ) -> None: self._set_timeouts( join=join, last_call=last_call, close=close, heartbeat=heartbeat @@ -183,7 +183,7 @@ def heartbeat(self) -> timedelta: """Get the keep-alive heartbeat timeout.""" return self._heartbeat - def _set_timeouts(self, **timeouts: Optional[timedelta]): + def _set_timeouts(self, **timeouts: timedelta | None): for name, timeout in timeouts.items(): if timeout is None: timeout = self._DEFAULT_TIMEOUTS[name] @@ -258,7 +258,7 @@ def __init__(self) -> None: # An integer that is incremented with each call to generate(). self._local_id = 0 - def generate(self, local_addr: Optional[str] = None) -> _NodeDesc: + def generate(self, local_addr: str | None = None) -> _NodeDesc: # This method can be called by multiple threads concurrently; therefore, # we must increment the integer atomically. with self._lock: @@ -297,7 +297,7 @@ class _RendezvousState: round: int complete: bool - deadline: Optional[datetime] + deadline: datetime | None closed: bool participants: dict[_NodeDesc, int] wait_list: set[_NodeDesc] @@ -345,7 +345,7 @@ def state(self) -> _RendezvousState: """Get the local state.""" @abstractmethod - def sync(self) -> Optional[bool]: + def sync(self) -> bool | None: """Read or writes the latest state. Returns: @@ -408,13 +408,13 @@ def state(self) -> _RendezvousState: """See base class.""" return self._state - def sync(self) -> Optional[bool]: + def sync(self) -> bool | None: """See base class.""" - state_bits: Optional[bytes] = None + state_bits: bytes | None = None token = None - has_set: Optional[bool] + has_set: bool | None if self._dirty: has_set = False @@ -574,7 +574,7 @@ def run( self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float, - update_deadline: Optional[Callable[[timedelta], float]] = None, + update_deadline: Callable[[timedelta], float] | None = None, ) -> None: """Execute a rendezvous operation. @@ -638,7 +638,7 @@ def run( self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float, - update_deadline: Optional[Callable[[timedelta], float]] = None, + update_deadline: Callable[[timedelta], float] | None = None, ) -> None: """See base class.""" action = None @@ -1006,7 +1006,7 @@ class DynamicRendezvousHandler(RendezvousHandler): _state_holder: _RendezvousStateHolder _op_executor: _RendezvousOpExecutor _heartbeat_lock: threading.Lock - _keep_alive_timer: Optional[_PeriodicTimer] + _keep_alive_timer: _PeriodicTimer | None @classmethod def from_backend( @@ -1016,8 +1016,8 @@ def from_backend( backend: RendezvousBackend, min_nodes: int, max_nodes: int, - local_addr: Optional[str] = None, - timeout: Optional[RendezvousTimeout] = None, + local_addr: str | None = None, + timeout: RendezvousTimeout | None = None, keep_alive_interval: int = 5, keep_alive_max_attempt: int = 3, ): @@ -1102,15 +1102,15 @@ def __init__( self._keep_alive_timer = None # Cached shared store server reference - self._shared_tcp_store_server: Optional[dist.Store] = None + self._shared_tcp_store_server: dist.Store | None = None - self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None + self._bootstrap_store_info: RendezvousStoreInfo | None = None def _record( self, message: str, node_state: NodeState = NodeState.RUNNING, - rank: Optional[int] = None, + rank: int | None = None, ) -> None: construct_and_record_rdzv_event( name=f"{self.__class__.__name__}.{get_method_name()}", @@ -1379,7 +1379,7 @@ def _get_deadline(self, timeout: timedelta) -> float: return time.monotonic() + timeout.total_seconds() -def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]: +def _get_timeout(params: RendezvousParameters, key: str) -> timedelta | None: timeout = params.get_as_int(key + "_timeout") if timeout is None: return None diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index 300399414d9ce..93a7073bed87a 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -12,7 +12,6 @@ import sys import threading import time -from typing import Optional try: @@ -153,7 +152,7 @@ class EtcdRendezvousHandler(RendezvousHandler): +--------------------------------------------+--------------------------+ """ - def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]): + def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: str | None): """ Args: rdzv_impl: the implementation of the rendezvous @@ -542,7 +541,7 @@ def join_rendezvous(self, expected_version): # When reaching min workers, or changing state to frozen, we'll set # the active_version node to be ephemeral. - set_ttl: Optional[int] = None + set_ttl: int | None = None if len(state["participants"]) == self._num_max_workers: state["status"] = "frozen" state["keep_alives"] = [] diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py index a0012607ce36f..4cda28221ff4e 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -7,7 +7,7 @@ import binascii from base64 import b64decode, b64encode -from typing import cast, Optional +from typing import cast import urllib3.exceptions # type: ignore[import] @@ -49,8 +49,8 @@ def __init__( self, client: etcd.Client, run_id: str, - key_prefix: Optional[str] = None, - ttl: Optional[int] = None, + key_prefix: str | None = None, + ttl: int | None = None, ) -> None: if not run_id: raise ValueError("The run id must be a non-empty string.") @@ -72,7 +72,7 @@ def name(self) -> str: """See base class.""" return "etcd-v2" - def get_state(self) -> Optional[tuple[bytes, Token]]: + def get_state(self) -> tuple[bytes, Token] | None: """See base class.""" try: result = self._client.read(self._key) @@ -86,8 +86,8 @@ def get_state(self) -> Optional[tuple[bytes, Token]]: return self._decode_state(result) def set_state( - self, state: bytes, token: Optional[Token] = None - ) -> Optional[tuple[bytes, Token, bool]]: + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: """See base class.""" base64_state = b64encode(state).decode() diff --git a/torch/distributed/elastic/rendezvous/etcd_server.py b/torch/distributed/elastic/rendezvous/etcd_server.py index 7e54fdd9839af..347e7339d9a46 100644 --- a/torch/distributed/elastic/rendezvous/etcd_server.py +++ b/torch/distributed/elastic/rendezvous/etcd_server.py @@ -15,7 +15,7 @@ import subprocess import tempfile import time -from typing import Optional, TextIO, Union +from typing import TextIO try: @@ -64,7 +64,7 @@ def find_free_port(): raise RuntimeError("Failed to create a socket") -def stop_etcd(subprocess, data_dir: Optional[str] = None): +def stop_etcd(subprocess, data_dir: str | None = None): if subprocess and subprocess.poll() is None: logger.info("stopping etcd server") subprocess.terminate() @@ -107,7 +107,7 @@ class EtcdServer: etcd_binary_path: path of etcd server binary (see above for fallback path) """ - def __init__(self, data_dir: Optional[str] = None): + def __init__(self, data_dir: str | None = None): self._port = -1 self._host = "localhost" @@ -123,7 +123,7 @@ def __init__(self, data_dir: Optional[str] = None): data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data") ) self._etcd_cmd = None - self._etcd_proc: Optional[subprocess.Popen] = None + self._etcd_proc: subprocess.Popen | None = None def _get_etcd_server_process(self) -> subprocess.Popen: if not self._etcd_proc: @@ -149,7 +149,7 @@ def start( self, timeout: int = 60, num_retries: int = 3, - stderr: Union[int, TextIO, None] = None, + stderr: int | TextIO | None = None, ) -> None: """ Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests. @@ -185,7 +185,7 @@ def start( atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir) def _start( - self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None + self, data_dir: str, timeout: int = 60, stderr: int | TextIO | None = None ) -> None: sock = find_free_port() sock_peer = find_free_port() diff --git a/torch/distributed/elastic/rendezvous/etcd_store.py b/torch/distributed/elastic/rendezvous/etcd_store.py index 781a40e20e91c..faaf77587bc9d 100644 --- a/torch/distributed/elastic/rendezvous/etcd_store.py +++ b/torch/distributed/elastic/rendezvous/etcd_store.py @@ -9,7 +9,6 @@ import random import time from base64 import b64decode, b64encode -from typing import Optional # pyre-ignore[21]: Could not find name `Store` in `torch.distributed`. from torch.distributed import Store @@ -40,7 +39,7 @@ def __init__( etcd_client, etcd_store_prefix, # Default timeout same as in c10d/Store.hpp - timeout: Optional[datetime.timedelta] = None, + timeout: datetime.timedelta | None = None, ): super().__init__() # required for pybind trampoline. @@ -121,7 +120,7 @@ def add(self, key, num: int) -> int: except etcd.EtcdCompareFailed: cas_delay() - def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None): + def wait(self, keys, override_timeout: datetime.timedelta | None = None): """ Wait until all of the keys are published, or until timeout. diff --git a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py index e6395b70be2b4..52b6800053088 100644 --- a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -9,7 +9,7 @@ import datetime import logging -from typing import cast, Optional +from typing import cast from torch.distributed import PrefixStore, Store, TCPStore from torch.distributed.elastic.rendezvous import ( @@ -51,7 +51,7 @@ def __init__( self.world_size = world_size self.run_id = run_id self.timeout = datetime.timedelta(seconds=timeout) - self._store: Optional[Store] = None + self._store: Store | None = None def get_backend(self) -> str: return "static" diff --git a/torch/distributed/elastic/rendezvous/utils.py b/torch/distributed/elastic/rendezvous/utils.py index e4717959232d1..05ebbba55913f 100644 --- a/torch/distributed/elastic/rendezvous/utils.py +++ b/torch/distributed/elastic/rendezvous/utils.py @@ -14,7 +14,7 @@ from collections.abc import Callable from datetime import timedelta from threading import Event, Thread -from typing import Any, Optional, Union +from typing import Any __all__ = ["parse_rendezvous_endpoint"] @@ -44,7 +44,7 @@ def _parse_rendezvous_config(config_str: str) -> dict[str, str]: "=,...,=." ) - value: Optional[str] + value: str | None if values: value = values[0].strip() else: @@ -58,7 +58,7 @@ def _parse_rendezvous_config(config_str: str) -> dict[str, str]: return config -def _try_parse_port(port_str: str) -> Optional[int]: +def _try_parse_port(port_str: str) -> int | None: """Try to extract the port number from ``port_str``.""" if port_str and re.match(r"^[0-9]{1,5}$", port_str): return int(port_str) @@ -66,7 +66,7 @@ def _try_parse_port(port_str: str) -> Optional[int]: def parse_rendezvous_endpoint( - endpoint: Optional[str], default_port: int + endpoint: str | None, default_port: int ) -> tuple[str, int]: """Extract the hostname and the port number from a rendezvous endpoint. @@ -166,7 +166,7 @@ def _matches_machine_hostname(host: str) -> bool: return False -def _delay(seconds: Union[float, tuple[float, float]]) -> None: +def _delay(seconds: float | tuple[float, float]) -> None: """Suspend the current thread for ``seconds``. Args: @@ -200,9 +200,9 @@ class _Context: kwargs: dict[str, Any] stop_event: Event - _name: Optional[str] - _thread: Optional[Thread] - _finalizer: Optional[weakref.finalize] + _name: str | None + _thread: Thread | None + _finalizer: weakref.finalize | None # The context that is shared between the timer and the background thread. _ctx: _Context @@ -227,7 +227,7 @@ def __init__( self._finalizer = None @property - def name(self) -> Optional[str]: + def name(self) -> str | None: """Get the name of the timer.""" return self._name diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 7c856f078d89a..efe942022246e 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -10,7 +10,7 @@ import time from contextlib import contextmanager from inspect import getframeinfo, stack -from typing import Any, Optional +from typing import Any __all__ = [ @@ -130,7 +130,7 @@ def __init__( self._request_queue = request_queue self._max_interval = max_interval self._daemon = daemon - self._watchdog_thread: Optional[threading.Thread] = None + self._watchdog_thread: threading.Thread | None = None self._stop_signaled = False @abc.abstractmethod @@ -234,7 +234,7 @@ def stop(self) -> None: logger.info("No watchdog thread running, doing nothing") -_timer_client: Optional[TimerClient] = None +_timer_client: TimerClient | None = None def configure(timer_client: TimerClient): @@ -247,9 +247,7 @@ def configure(timer_client: TimerClient): @contextmanager -def expires( - after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None -): +def expires(after: float, scope: str | None = None, client: TimerClient | None = None): """ Acquires a countdown timer that expires in ``after`` seconds from now, unless the code-block that it wraps is finished within the timeframe. diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 8ed457a19f115..14ec6e6af8537 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -14,7 +14,7 @@ import threading import time from collections.abc import Callable -from typing import Optional, TypeVar +from typing import TypeVar from typing_extensions import ParamSpec from torch.distributed.elastic.timer.api import TimerClient, TimerRequest @@ -131,7 +131,7 @@ def __init__( self.signal = signal @_retry(max_retries=10, sleep_time=0.1) - def _open_non_blocking(self) -> Optional[io.TextIOWrapper]: + def _open_non_blocking(self) -> io.TextIOWrapper | None: # The server may have crashed or may haven't started yet. # In such case, calling open() in blocking model blocks the client. # To avoid such issue, open it in non-blocking mode, and an OSError will @@ -200,7 +200,7 @@ def __init__( run_id: str, max_interval: float = 10, daemon: bool = True, - log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None, + log_event: Callable[[str, FileTimerRequest | None], None] | None = None, ) -> None: self._file_path = file_path self._run_id = run_id @@ -208,7 +208,7 @@ def __init__( self._daemon = daemon self._timers: dict[tuple[int, str], FileTimerRequest] = {} self._stop_signaled = False - self._watchdog_thread: Optional[threading.Thread] = None + self._watchdog_thread: threading.Thread | None = None self._is_client_started = False if os.path.exists(self._file_path): diff --git a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py index a10d49ae4897f..c824cc2fd018c 100644 --- a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py +++ b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -8,7 +8,7 @@ import math from collections.abc import Iterator, Sized -from typing import cast, Optional, TypeVar +from typing import cast, TypeVar import torch from torch.utils.data import Dataset @@ -44,8 +44,8 @@ class ElasticDistributedSampler(DistributedSampler[T]): def __init__( self, dataset: Dataset[T], - num_replicas: Optional[int] = None, - rank: Optional[int] = None, + num_replicas: int | None = None, + rank: int | None = None, start_index: int = 0, ): super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank) diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 34a8cd8a22bb5..7b294d222ea7d 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -10,7 +10,6 @@ import os import socket from contextlib import closing -from typing import Optional import torch.distributed as dist from torch.distributed.elastic.utils.logging import get_logger @@ -35,7 +34,7 @@ def create_c10d_store( timeout: float = (60 * 10), # 10 min wait_for_workers: bool = True, retries=3, - use_libuv: Optional[bool] = None, + use_libuv: bool | None = None, ): if use_libuv is not None: logger.warning( diff --git a/torch/distributed/elastic/utils/logging.py b/torch/distributed/elastic/utils/logging.py index c7d56374e7d38..aadf37eb16b80 100644 --- a/torch/distributed/elastic/utils/logging.py +++ b/torch/distributed/elastic/utils/logging.py @@ -10,12 +10,11 @@ import logging import os import warnings -from typing import Optional from torch.distributed.elastic.utils.log_level import get_log_level -def get_logger(name: Optional[str] = None) -> logging.Logger: +def get_logger(name: str | None = None) -> logging.Logger: """ Util function to set up a simple logger that writes into stderr. The loglevel is fetched from the LOGLEVEL @@ -32,13 +31,13 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: return _setup_logger(name or _derive_module_name(depth=2)) -def _setup_logger(name: Optional[str] = None) -> logging.Logger: +def _setup_logger(name: str | None = None) -> logging.Logger: logger = logging.getLogger(name) logger.setLevel(os.environ.get("LOGLEVEL", get_log_level())) return logger -def _derive_module_name(depth: int = 1) -> Optional[str]: +def _derive_module_name(depth: int = 1) -> str | None: """ Derives the name of the caller module from the stack frames. diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py index e01991114bef8..598899e936aa0 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -10,7 +10,6 @@ from collections.abc import Callable, Iterable from contextlib import contextmanager from datetime import timedelta -from typing import Optional import torch @@ -109,7 +108,7 @@ def _try_detecting_missing_ranks( rank: int, rank_decoder: Callable[[int], str], trace_timeout: float, -) -> Optional[Iterable[str]]: +) -> Iterable[str] | None: store.set(f"{key_prefix}{rank}{_TRACE}", "") def _find_missing_ranks(): @@ -169,8 +168,8 @@ def barrier( world_size: int, key_prefix: str, barrier_timeout: float = 300, - rank: Optional[int] = None, - rank_tracing_decoder: Optional[Callable[[int], str]] = None, + rank: int | None = None, + rank_tracing_decoder: Callable[[int], str] | None = None, trace_timeout: float = 10, ) -> None: """ diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 666fb24463f0d..2adf5549fecf1 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -11,7 +11,7 @@ import uuid from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any import torch import torch.distributed.elastic.rendezvous.registry as rdzv_registry @@ -90,7 +90,7 @@ class LaunchConfig: min_nodes: int max_nodes: int nproc_per_node: int - logs_specs: Optional[LogsSpecs] = None + logs_specs: LogsSpecs | None = None run_id: str = "" role: str = "default_role" rdzv_endpoint: str = "" @@ -100,14 +100,14 @@ class LaunchConfig: max_restarts: int = 3 monitor_interval: float = 0.1 start_method: str = "spawn" - log_line_prefix_template: Optional[str] = None + log_line_prefix_template: str | None = None metrics_cfg: dict[str, str] = field(default_factory=dict) - local_addr: Optional[str] = None + local_addr: str | None = None event_log_handler: str = "null" - numa_options: Optional[NumaOptions] = None + numa_options: NumaOptions | None = None signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT" - duplicate_stdout_filters: Optional[list[str]] = None - duplicate_stderr_filters: Optional[list[str]] = None + duplicate_stdout_filters: list[str] | None = None + duplicate_stderr_filters: list[str] | None = None virtual_local_rank: bool = False def __post_init__(self): @@ -161,7 +161,7 @@ def main(): def __init__( self, config: LaunchConfig, - entrypoint: Union[Callable, str, None], + entrypoint: Callable | str | None, ): self._config = config self._entrypoint = entrypoint @@ -170,9 +170,7 @@ def __call__(self, *args): return launch_agent(self._config, self._entrypoint, list(args)) -def _get_entrypoint_name( - entrypoint: Union[Callable, str, None], args: list[Any] -) -> str: +def _get_entrypoint_name(entrypoint: Callable | str | None, args: list[Any]) -> str: """Retrieve entrypoint name with the rule: 1. If entrypoint is a function, use ``entrypoint.__qualname__``. 2. If entrypoint is a string, check its value: @@ -194,7 +192,7 @@ def _get_entrypoint_name( def _get_addr_and_port( rdzv_parameters: RendezvousParameters, -) -> tuple[Optional[str], Optional[int]]: +) -> tuple[str | None, int | None]: if rdzv_parameters.backend != "static": return (None, None) endpoint = rdzv_parameters.endpoint @@ -213,7 +211,7 @@ def _get_addr_and_port( def launch_agent( config: LaunchConfig, - entrypoint: Union[Callable, str, None], + entrypoint: Callable | str | None, args: list[Any], ) -> dict[int, Any]: if not config.run_id: diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index d2db28d4371de..728bf9c0288a2 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -5,7 +5,7 @@ import sys import types from collections.abc import Callable, Iterator, Mapping -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar, Union from typing_extensions import Self import torch @@ -122,8 +122,8 @@ def __init__( self, remote_device: str, module_cls: type[nn.Module], - args: Optional[tuple] = None, - kwargs: Optional[dict[str, Any]] = None, + args: tuple | None = None, + kwargs: dict[str, Any] | None = None, _module_interface_cls: Any = None, ): """ @@ -310,32 +310,32 @@ def __setstate__(self, state): ) def register_buffer( - self, name: str, tensor: Optional[Tensor], persistent: bool = True + self, name: str, tensor: Tensor | None, persistent: bool = True ) -> None: _raise_not_supported(self.register_buffer.__name__) - def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + def register_parameter(self, name: str, param: Parameter | None) -> None: _raise_not_supported(self.register_parameter.__name__) - def add_module(self, name: str, module: Optional[Module]) -> None: + def add_module(self, name: str, module: Module | None) -> None: _raise_not_supported(self.add_module.__name__) def apply(self, fn: Callable[[Module], None]) -> Self: # type: ignore[return] _raise_not_supported(self.apply.__name__) - def cuda(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return] + def cuda(self, device: int | device | None = None) -> Self: # type: ignore[return] _raise_not_supported(self.cuda.__name__) - def ipu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return] + def ipu(self, device: int | device | None = None) -> Self: # type: ignore[return] _raise_not_supported(self.ipu.__name__) - def xpu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return] + def xpu(self, device: int | device | None = None) -> Self: # type: ignore[return] _raise_not_supported(self.xpu.__name__) def cpu(self) -> Self: # type: ignore[return] _raise_not_supported(self.cpu.__name__) - def type(self, dst_type: Union[dtype, str]) -> Self: # type: ignore[return] + def type(self, dst_type: dtype | str) -> Self: # type: ignore[return] _raise_not_supported(self.type.__name__) def float(self) -> Self: # type: ignore[return] @@ -355,19 +355,16 @@ def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var] def register_backward_hook( # type: ignore[return] self, - hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]], + hook: Callable[[Module, _grad_t, _grad_t], None | _grad_t], # pyrefly: ignore [bad-return] ) -> RemovableHandle: _raise_not_supported(self.register_backward_hook.__name__) def register_forward_pre_hook( # type: ignore[return] self, - hook: Union[ - Callable[[T, tuple[Any, ...]], Optional[Any]], - Callable[ - [T, tuple[Any, ...], dict[str, Any]], - Optional[tuple[Any, dict[str, Any]]], - ], + hook: Callable[[T, tuple[Any, ...]], Any | None] + | Callable[ + [T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None ], prepend: bool = False, with_kwargs: bool = False, @@ -377,10 +374,8 @@ def register_forward_pre_hook( # type: ignore[return] def register_forward_hook( # type: ignore[return, override] self, - hook: Union[ - Callable[[T, tuple[Any, ...], Any], Optional[Any]], - Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]], - ], + hook: Callable[[T, tuple[Any, ...], Any], Any | None] + | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None], prepend: bool = False, with_kwargs: bool = False, # pyrefly: ignore [bad-return] @@ -435,7 +430,7 @@ def modules(self) -> Iterator[Module]: # type: ignore[return] def named_modules( self, - memo: Optional[set[Module]] = None, + memo: set[Module] | None = None, prefix: str = "", remove_duplicate: bool = True, ): @@ -694,8 +689,8 @@ def __init__( self, remote_device: str, module_cls: type[nn.Module], - args: Optional[tuple] = None, - kwargs: Optional[dict[str, Any]] = None, + args: tuple | None = None, + kwargs: dict[str, Any] | None = None, ): super().__init__(remote_device, module_cls, args, kwargs) diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index 9af7bba4680dc..e8455c5ef5a41 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -53,7 +52,7 @@ def __init__( self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index 5820a94183c72..3da4e29b3f015 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -70,7 +69,7 @@ def __init__( "step": torch.tensor(0.0), } - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index b736cd4d164f7..1763edd14c9da 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -68,7 +67,7 @@ def __init__( # param group as it's not a common use case. self.param_group = {"params": params} - def step_param(self, param: Tensor, grad: Optional[Tensor]): + def step_param(self, param: Tensor, grad: Tensor | None): """ Similar to step, but operates on a single parameter and optionally a gradient tensor. @@ -128,7 +127,7 @@ def step_param(self, param: Tensor, grad: Optional[Tensor]): found_inf=None, ) - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index 9327eca3abfbb..595a5668a78fc 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -64,7 +63,7 @@ def __init__( # param group as it's not a common use case. self.param_group = {"params": params} - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 8d79cc0f27f0e..d695ce8b473af 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -68,7 +67,7 @@ def __init__( # param group as it's not a common use case. self.param_group = {"params": params} - def step_param(self, param: Tensor, grad: Optional[Tensor]): + def step_param(self, param: Tensor, grad: Tensor | None): params_with_grad = [] grads = [] exp_avgs = [] @@ -129,7 +128,7 @@ def step_param(self, param: Tensor, grad: Optional[Tensor]): has_complex=has_complex, ) - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index 424c2276bff08..45341b03237b4 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -57,7 +56,7 @@ def __init__( self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 877ea6bddef47..ffc9c510dabca 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -51,7 +50,7 @@ def __init__( self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index e0a00cf02e976..aed92403e6fb6 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch import torch.optim._functional as F @@ -56,7 +55,7 @@ def __init__( # param group as it's not a common use case. self.param_group = {"params": params} - def step_param(self, param: Tensor, grad: Optional[Tensor]): + def step_param(self, param: Tensor, grad: Tensor | None): """Similar to self.step, but operates on a single parameter and its gradient. """ @@ -67,7 +66,7 @@ def step_param(self, param: Tensor, grad: Optional[Tensor]): dampening = self.defaults["dampening"] lr = self.defaults["lr"] params = [param] - momentum_buffer_list: list[Optional[Tensor]] = [] + momentum_buffer_list: list[Tensor | None] = [] grads = [] has_sparse_grad = False @@ -106,11 +105,11 @@ def step_param(self, param: Tensor, grad: Optional[Tensor]): if momentum_buffer is not None: state["momentum_buffer"] = momentum_buffer - def step(self, gradients: list[Optional[Tensor]]): + def step(self, gradients: list[Tensor | None]): params = self.param_group["params"] params_with_grad = [] grads = [] - momentum_buffer_list: list[Optional[Tensor]] = [] + momentum_buffer_list: list[Tensor | None] = [] lr = self.defaults["lr"] weight_decay = self.defaults["weight_decay"] momentum = self.defaults["momentum"] diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index c2384dabd9dad..a8432e198a083 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Callable, Collection, Mapping from copy import deepcopy -from typing import Any, Optional, overload, Union +from typing import Any, overload import torch import torch.nn as nn @@ -62,10 +62,10 @@ class _NamedOptimizer(optim.Optimizer): def __init__( self, - named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]], + named_parameters: Mapping[str, torch.Tensor | ShardedTensor], optimizer_class: optim.Optimizer, - param_groups: Optional[Collection[Mapping[str, Any]]] = None, - module: Optional[nn.Module] = None, + param_groups: Collection[Mapping[str, Any]] | None = None, + module: nn.Module | None = None, *args: tuple[Any, ...], **kwargs: dict[str, Any], ) -> None: @@ -152,7 +152,7 @@ def step(self, closure: None = None) -> None: ... @overload def step(self, closure: Callable[[], float]) -> float: ... - def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + def step(self, closure: Callable[[], float] | None = None) -> float | None: """ Perform a single optimization step. diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index 9d17601a4e3fb..f9477aa414b42 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -2,7 +2,6 @@ import logging from collections import defaultdict from threading import Lock -from typing import Optional import torch import torch.distributed.autograd as dist_autograd @@ -51,7 +50,7 @@ def __init__(self, optim_cls, local_params_rref, *args, **kwargs): def step(self, autograd_ctx_id: int): all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) # apply functional optimizer step with a list of gradients - grads: list[Optional[Tensor]] = [ + grads: list[Tensor | None] = [ all_local_grads[p] if p in all_local_grads else None # noqa: SIM401 for p in self._local_params ] diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 8c82b53eff757..3183299a48347 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -13,7 +13,7 @@ import logging from collections.abc import Callable from itertools import chain -from typing import Any, Optional, Union +from typing import Any import torch import torch.distributed as dist @@ -173,7 +173,7 @@ def __init__( # DDP guarantees all parameters in the bucket have the same device # pyrefly: ignore [read-only] self.device: torch.device = self.parameters[0].device - self.tensor: Optional[torch.Tensor] = None + self.tensor: torch.Tensor | None = None class _OverlapStatus(enum.IntEnum): @@ -252,7 +252,7 @@ def __init__(self, world_size) -> None: # Group Ranks self.assigned_ranks_per_bucket: list[set[int]] = [] self.num_bucket_assignments: int = 0 - self.total_size: Optional[int] = None + self.total_size: int | None = None # Modified per iteration self.broadcast_handles: list[Any] = [] @@ -377,7 +377,7 @@ def __init__( self, params, optimizer_class: type[Optimizer], - process_group: Optional[Any] = None, + process_group: Any | None = None, parameters_as_bucket_view: bool = False, overlap_with_ddp: bool = False, **defaults: Any, @@ -649,7 +649,7 @@ def _partition_param_group( def _partition_parameters( self, - params_per_rank: Optional[list[list[torch.Tensor]]] = None, + params_per_rank: list[list[torch.Tensor]] | None = None, ) -> list[list[dict]]: r""" Partitions parameters across distributed data parallel ranks. @@ -869,7 +869,7 @@ def _device_to_params_per_rank( def _get_min_index( self, values: list[int], - disallowed_indices: Optional[set[int]] = None, + disallowed_indices: set[int] | None = None, ) -> int: r""" Return ``values.index(min(values))``, except only uses one pass. @@ -1036,10 +1036,10 @@ def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]: def _local_step( self, - gradients: Optional[list[Optional[torch.Tensor]]] = None, - closure: Optional[Callable[[], float]] = None, + gradients: list[torch.Tensor | None] | None = None, + closure: Callable[[], float] | None = None, **kwargs: Any, - ) -> Optional[float]: + ) -> float | None: r""" Perform a single optimizer step without syncing parameters across ranks. @@ -1111,9 +1111,9 @@ def _local_step( # pyrefly: ignore [bad-override] def step( self, - closure: Optional[Callable[[], float]] = None, + closure: Callable[[], float] | None = None, **kwargs: Any, - ) -> Optional[float]: + ) -> float | None: r""" Perform a single optimizer step and syncs parameters across all ranks. @@ -1403,7 +1403,7 @@ def _build_ddp_param_buckets(self) -> None: def _verify_and_init_params( self, params: Any, - ) -> Union[list[torch.Tensor], list[dict]]: + ) -> list[torch.Tensor] | list[dict]: r""" Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters. diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index e34460449e1e0..bfcf294c2946c 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -3,7 +3,7 @@ import collections import logging from collections.abc import Iterator -from typing import Any, Optional, Union +from typing import Any import torch from torch.autograd.graph import GradientEdge, Node @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]: +def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None: """ Get the grad function or grad accumulator for a tensor. @@ -142,10 +142,10 @@ def get_param_groups( def stage_backward_input( stage_outputs_or_loss: list[torch.Tensor], - output_grads: Optional[list[torch.Tensor]], + output_grads: list[torch.Tensor] | None, input_values: list[torch.Tensor], weights: Iterator[Parameter], -) -> tuple[tuple[Optional[torch.Tensor], ...], list[dict[str, Any]]]: +) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]]]: """ Compute the gradients for only the stage inputs with respect to the stage outputs (if non-last stage) or loss (if last stage) @@ -225,10 +225,10 @@ def hook(grad_inputs): def stage_backward_weight( weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False -) -> tuple[Optional[torch.Tensor], ...]: +) -> tuple[torch.Tensor | None, ...]: # map weights to param_group_weights grad_acc_to_weight = {} - weight_grads: list[Optional[torch.Tensor]] = [] + weight_grads: list[torch.Tensor | None] = [] for index, weight in enumerate(weights): grad_acc = _get_grad_fn_or_grad_acc(weight) grad_acc_to_weight[grad_acc] = weight, index @@ -282,8 +282,8 @@ def stage_backward( stage_output, output_grads, input_values, - outputs_with_grads_idxs: Optional[list[int]] = None, # deprecated, not used -) -> tuple[Optional[torch.Tensor], ...]: + outputs_with_grads_idxs: list[int] | None = None, # deprecated, not used +) -> tuple[torch.Tensor | None, ...]: """ This is a helper function to: 1. compute the gradients for the stage inputs, and @@ -303,7 +303,7 @@ def stage_backward( # stage_output may be a composite datatype like dict. Extract all individual # tensor values here stage_output_tensors: list[torch.Tensor] = [] - output_grad_tensors: list[Optional[torch.Tensor]] = [] + output_grad_tensors: list[torch.Tensor | None] = [] def extract_tensors_with_grads( output_val, @@ -363,7 +363,7 @@ def extract_tensors_with_grads( ) # Extract gradients wrt the input values - grad_inputs: list[Optional[torch.Tensor]] = [] + grad_inputs: list[torch.Tensor | None] = [] for val in input_values: if isinstance(val, torch.Tensor): grad_inputs.append(val.grad) diff --git a/torch/distributed/pipelining/_schedule_visualizer.py b/torch/distributed/pipelining/_schedule_visualizer.py index e5891c775a687..5ecc5bf19ab17 100644 --- a/torch/distributed/pipelining/_schedule_visualizer.py +++ b/torch/distributed/pipelining/_schedule_visualizer.py @@ -10,7 +10,7 @@ """ import collections -from typing import NamedTuple, Optional, Union +from typing import NamedTuple from unittest import mock from torch.distributed.pipelining.schedules import ( @@ -32,13 +32,13 @@ class OpKey(NamedTuple): def get_schedule_ops( - schedule: Union[str, type[_PipelineSchedule]], + schedule: str | type[_PipelineSchedule], pp_degree: int, num_microbatches: int, - num_stages_per_rank: Optional[int] = None, + num_stages_per_rank: int | None = None, add_spacing: bool = False, with_comms: bool = False, -) -> list[list[Optional[_Action]]]: +) -> list[list[_Action | None]]: """ Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists where each inner list represents a rank and each element in the inner list represents an action. @@ -86,7 +86,7 @@ def get_schedule_ops( assert schedule_instance.pipeline_order is not None # Convert to List[List[_Action]] - all_actions: list[list[Optional[_Action]]] = [] + all_actions: list[list[_Action | None]] = [] if with_comms: runtime = _PipelineScheduleRuntime(stages, num_microbatches) runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order) @@ -136,8 +136,8 @@ def __init__( def add_schedule_op_spacing( - schedule: list[list[Optional[_Action]]], -) -> list[list[Optional[_Action]]]: + schedule: list[list[_Action | None]], +) -> list[list[_Action | None]]: """ Add spacing to the schedule based on dependencies between ranks. @@ -169,7 +169,7 @@ def add_schedule_op_spacing( ) num_ranks = len(schedule) - spaced_schedule: list[list[Optional[_Action]]] = [[] for _ in range(num_ranks)] + spaced_schedule: list[list[_Action | None]] = [[] for _ in range(num_ranks)] rank_ops = [collections.deque(ops) for ops in schedule] # Track completion times: (stage_index, action_type, microbatch_index) -> completion_time @@ -331,8 +331,8 @@ def schedule_action(action: _Action, rank: int, timestep: int) -> int: def visualize_schedule( - schedule: list[list[Optional[_Action]]], - filename: Optional[str] = None, + schedule: list[list[_Action | None]], + filename: str | None = None, ) -> None: """ Visualize the schedule using matplotlib. diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index 2f0472211b8c8..79b74be406814 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -3,7 +3,6 @@ import logging from dataclasses import dataclass -from typing import Union import torch from torch import fx @@ -76,8 +75,8 @@ def validate_tensor_metadata(desc, expected, given): def validate_tensors_metadata( desc, - expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], - actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + expected_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], + actual_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], ): if len(expected_tensors) != len(actual_tensors): raise PipeliningShapeError( diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index 251d53a22bf27..a82f83072fa18 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -3,7 +3,7 @@ import logging import operator from collections.abc import Sequence -from typing import Any, Optional +from typing import Any import torch from torch.fx.node import map_aggregate @@ -307,10 +307,10 @@ def _shard_dict_of_args( def split_args_kwargs_into_chunks( args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]], + kwargs: dict[str, Any] | None, chunks: int, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, ) -> tuple[list[tuple], list[dict]]: """ Given a sequence of args and kwargs, split them into a number of chunks diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 7bdf3c65e4e8f..5657068f0bcd7 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -11,7 +11,7 @@ from collections.abc import Callable from enum import Enum from functools import lru_cache -from typing import Any, cast, NamedTuple, Optional, Protocol, Union +from typing import Any, cast, NamedTuple, Protocol import torch import torch.distributed as dist @@ -131,8 +131,8 @@ def from_str(action): class _Action(NamedTuple): stage_index: int computation_type: _ComputationType - microbatch_index: Optional[int] = None - sub_actions: Optional[tuple["_Action", ...]] = None + microbatch_index: int | None = None + sub_actions: tuple["_Action", ...] | None = None def __str__(self): return self.__repr__() @@ -220,8 +220,8 @@ def _get_profiler_function_name(action: _Action) -> str: def _format_pipeline_order( - pipeline_order: dict[int, list[Optional[_Action]]], - error_step_number: Optional[int] = None, + pipeline_order: dict[int, list[_Action | None]], + error_step_number: int | None = None, ) -> str: """ Formats the pipeline order in a timestep (row) x rank (column) grid of actions @@ -286,10 +286,10 @@ class _PipelineSchedule(ABC): def __init__( self, n_microbatches: int, - loss_fn: Optional[Callable[..., torch.Tensor]] = None, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + loss_fn: Callable[..., torch.Tensor] | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, scale_grads: bool = True, ): # From arguments @@ -360,10 +360,10 @@ def _update_losses(self, stages, losses): @abstractmethod def _step_microbatches( self, - arg_mbs: Optional[list] = None, - kwarg_mbs: Optional[list] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, return_outputs: bool = True, ): """ @@ -382,7 +382,7 @@ def step( self, *args, target=None, - losses: Optional[list] = None, + losses: list | None = None, return_outputs=True, **kwargs, ): @@ -399,7 +399,7 @@ def step( """ raise NotImplementedError - def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs): + def eval(self, *args, target=None, losses: list | None = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. Will chunk the input into microbatches automatically, and go through the @@ -421,10 +421,10 @@ def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs): def _check_inputs( self, - arg_mbs: Optional[list] = None, - kwarg_mbs: Optional[list] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, ) -> tuple[list, list]: """ Pre-process/check inputs @@ -463,7 +463,7 @@ def _compute_loss(self, output, target): def _split_inputs( self, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ): """ Splits a full-batch input into chunks (i.e. microbatches) and returns @@ -494,9 +494,7 @@ def _merge_outputs(self, output_chunks: list[Any]) -> Any: ) -def _batch_p2p( - p2p_ops: list[dist.P2POp], desc: Optional[str] = None -) -> list[dist.Work]: +def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None) -> list[dist.Work]: """ Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. """ @@ -508,7 +506,7 @@ def _batch_p2p( def _sorted_batch_p2p( - p2p_ops: list[dist.P2POp], desc: Optional[str] = None + p2p_ops: list[dist.P2POp], desc: str | None = None ) -> dict[int, list[dist.Work]]: """ Sorts the list of P2P ops by the peer rank, and then calls @@ -557,10 +555,10 @@ def __init__( self, stage: _PipelineStageBase, n_microbatches: int, - loss_fn: Optional[Callable] = None, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, scale_grads: bool = True, ): # Init parent @@ -584,7 +582,7 @@ def __init__( or equal to the number of stages ({self._num_stages})." ) - self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = ( + self.pipeline_order: dict[int, list[_Action | None]] | None = ( self._get_pipeline_order() ) @@ -608,7 +606,7 @@ def step( self, *args, target=None, - losses: Optional[list] = None, + losses: list | None = None, return_outputs: bool = True, **kwargs, ): @@ -656,7 +654,7 @@ def step( else: return None - def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: """ Returns the pipeline execution order as a schedule IR. @@ -683,10 +681,10 @@ class _ScheduleForwardOnly(PipelineScheduleSingle): def _step_microbatches( self, - arg_mbs: Optional[list] = None, - kwarg_mbs: Optional[list] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, return_outputs: bool = True, ): """ @@ -734,10 +732,10 @@ class ScheduleGPipe(PipelineScheduleSingle): def _step_microbatches( self, - arg_mbs: Optional[list] = None, - kwarg_mbs: Optional[list] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, return_outputs: bool = True, ): """ @@ -812,7 +810,7 @@ def _step_microbatches( self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) - def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: """ Returns the pipeline order for GPipe schedule. @@ -822,7 +820,7 @@ def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: pp_group_size = self._num_stages for rank in range(pp_group_size): - actions: list[Optional[_Action]] = [] + actions: list[_Action | None] = [] # 1. Initial delay based on rank position warmup_delay = rank @@ -853,10 +851,10 @@ class Schedule1F1B(PipelineScheduleSingle): def _step_microbatches( self, - arg_mbs: Optional[list] = None, - kwarg_mbs: Optional[list] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, return_outputs: bool = True, ): """ @@ -995,7 +993,7 @@ def _step_microbatches( self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) - def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: """ Returns the pipeline order for 1F1B schedule. @@ -1005,7 +1003,7 @@ def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: pp_group_size = self._num_stages for rank in range(pp_group_size): - actions: list[Optional[_Action]] = [] + actions: list[_Action | None] = [] # 1. Warmup phase: initial delay based on rank actions.extend([None] * rank) @@ -1069,13 +1067,13 @@ def _requires_reduce_grad(action_type: _ComputationType) -> bool: def _add_reduce_grad( - actions: list[Optional[_Action]], n_microbatches: int -) -> list[Optional[_Action]]: + actions: list[_Action | None], n_microbatches: int +) -> list[_Action | None]: """ REDUCE_GRAD refers to joint across minibatches grad reduction. reduce_grad frees memory and we want to schedule it just after the last "backward"-like stage. """ - actions_with_reduce_grad: list[Optional[_Action]] = [] + actions_with_reduce_grad: list[_Action | None] = [] cnt: dict[int, int] = defaultdict(int) def _leaf_action(a, to_schedule): @@ -1102,7 +1100,7 @@ def _leaf_action(a, to_schedule): def _add_unshard_reshard( - compute_actions: list[Optional[_Action]], + compute_actions: list[_Action | None], max_active_stages: int = 3, ) -> list[_Action]: """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP. @@ -1117,9 +1115,7 @@ def _add_unshard_reshard( (to account for having one f and one b active, and something else prefetching?) """ - def next_stage_indices( - count: int, next_actions: list[Optional[_Action]] - ) -> list[int]: + def next_stage_indices(count: int, next_actions: list[_Action | None]) -> list[int]: """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" seen: set[int] = set() ret: list[int] = [] @@ -1187,7 +1183,7 @@ def _reshard(stage_index: int): def _merge_bw( - compute_actions: list[Optional[_Action]], + compute_actions: list[_Action | None], ) -> list[_Action]: """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops. (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD) @@ -1259,9 +1255,7 @@ def _get_comms(action: _Action) -> tuple[_Action, _Action]: recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx) return send, recv - def _ready_to_schedule( - action: Optional[_Action], prev_actions: set[_Action] - ) -> bool: + def _ready_to_schedule(action: _Action | None, prev_actions: set[_Action]) -> bool: """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. This helps ensure a sane (non-hanging) ordering of sends and recvs. But it also means we might not be able to schedule our next compute action yet. @@ -1343,7 +1337,7 @@ def _ready_to_schedule( def _validate_schedule( - actions: dict[int, list[Optional[_Action]]], + actions: dict[int, list[_Action | None]], pp_group_size: int, num_stages: int, num_microbatches: int, @@ -1479,11 +1473,11 @@ def __init__( self, stages: list[_PipelineStageBase], n_microbatches: int, - loss_fn: Optional[Callable] = None, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, - use_full_backward: Optional[bool] = None, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + use_full_backward: bool | None = None, scale_grads: bool = True, backward_requires_autograd: bool = True, ): @@ -1516,7 +1510,7 @@ def __init__( self._should_compute_loss = lambda stage: stage.is_last and has_loss # This will be set during init of derived schedules - self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[_Action | None]] = {} # When using a custom backward function, we may or may not need autograd to be used # for the backward pass. This flag is used to determine whether or torch.is_grad_enabled() @@ -1559,7 +1553,7 @@ def _initialize_stages(self, args: tuple[Any, ...], kwargs): self._stages_backward_initialized = True def _validate_and_set_stage_mapping( - self, actions: dict[int, list[Optional[_Action]]] + self, actions: dict[int, list[_Action | None]] ) -> None: """ Allocates the stage index to rank mapping which is needed for communication @@ -1600,7 +1594,7 @@ def step( self, *args, target=None, - losses: Optional[list] = None, + losses: list | None = None, return_outputs: bool = True, **kwargs, ): @@ -1657,10 +1651,10 @@ def step( def _step_microbatches( self, - arg_mbs: Optional[list] = None, - kwarg_mbs: Optional[list] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, return_outputs: bool = True, ): """ @@ -1851,10 +1845,10 @@ class _PipelineContext: def __init__( self, schedule_ref: _PipelineSchedule, - arg_mbs: Optional[list[tuple]] = None, - kwarg_mbs: Optional[list[dict]] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list[tuple] | None = None, + kwarg_mbs: list[dict] | None = None, + target_mbs: list | None = None, + losses: list | None = None, ): self.schedule_ref = schedule_ref self.arg_mbs = arg_mbs @@ -1931,7 +1925,7 @@ def register_custom_function( def _prepare_schedule_with_comms( self, - actions: dict[int, list[Optional[_Action]]], + actions: dict[int, list[_Action | None]], format: str = "compute_only", ): """ @@ -2042,10 +2036,10 @@ def _assert_unsharded(self, stage: _PipelineStageBase): def _step_microbatches( self, - arg_mbs: Optional[list] = None, - kwarg_mbs: Optional[list] = None, - target_mbs: Optional[list] = None, - losses: Optional[list] = None, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, return_outputs: bool = True, ): """ @@ -2304,8 +2298,8 @@ def __init__( self, stages: list[_PipelineStageBase], n_microbatches: int, - loss_fn: Optional[Union[Callable, _Loss]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + loss_fn: Callable | _Loss | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, scale_grads: bool = True, backward_requires_autograd: bool = True, ): @@ -2321,7 +2315,7 @@ def __init__( # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[_Action | None]] = {} # ======================================================================== for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) @@ -2338,7 +2332,7 @@ def _calculate_single_rank_operations(self, rank): # Store the list of operations used for that rank # Pre-padding, rank starts with no-ops based on the warmup. - rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] + rank_ops: list[_Action | None] = [None for _ in range(rank)] for stage_index in stage_indices: rank_ops.extend( @@ -2378,7 +2372,7 @@ def _get_1f1b_rank_ops( # Store the list of operations used for that rank # Pre-padding, rank starts with no-ops based on the warmup. - rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] + rank_ops: list[_Action | None] = [None for _ in range(rank)] # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. # Formula: @@ -2518,10 +2512,10 @@ def __init__( self, stages: list[_PipelineStageBase], n_microbatches: int, - loss_fn: Optional[Callable] = None, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, scale_grads: bool = True, backward_requires_autograd: bool = True, ): @@ -2549,7 +2543,7 @@ def __init__( # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[_Action | None]] = {} for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops @@ -2557,7 +2551,7 @@ def __init__( # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime self._prepare_schedule_with_comms(self.pipeline_order) - def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: def get_rank_warmup_ops(rank): # Warms up operations for last stage warmups_ops_last_stage = ( @@ -2632,10 +2626,10 @@ def __init__( self, stages: list[_PipelineStageBase], n_microbatches: int, - loss_fn: Optional[Callable] = None, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, scale_grads: bool = True, backward_requires_autograd: bool = True, ): @@ -2665,7 +2659,7 @@ def __init__( # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[_Action | None]] = {} for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops @@ -2680,7 +2674,7 @@ def __init__( # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime self._prepare_schedule_with_comms(self.pipeline_order) - def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: def get_rank_warmup_ops(rank): # Warms up operations for last stage warmups_ops_last_stage = ( @@ -2758,7 +2752,7 @@ def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): return False seen_ops: set[tuple[int, _ComputationType, int]] = set() - result: dict[int, list[Optional[_Action]]] = {} + result: dict[int, list[_Action | None]] = {} next_pointer: dict[int, int] = {} bubbles_added: dict[int, int] = {} total_bubbles_added = 0 @@ -2831,10 +2825,10 @@ def __init__( self, stages: list[_PipelineStageBase], n_microbatches: int, - loss_fn: Optional[Callable] = None, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, scale_grads: bool = True, backward_requires_autograd: bool = True, ): @@ -2870,7 +2864,7 @@ def __init__( # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[_Action | None]] = {} for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops @@ -2878,11 +2872,11 @@ def __init__( # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime self._prepare_schedule_with_comms(self.pipeline_order) - def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: # max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least # as large of the number of microbatches needed to fully utilize the pipeline n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches) - rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] + rank_ops: list[_Action | None] = [None for _ in range(rank)] # Forward and backward action counts for stage chunk 0 and chunk 1 f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0 @@ -3009,10 +3003,10 @@ def __init__( self, stages: list[_PipelineStageBase], n_microbatches: int, - loss_fn: Optional[Callable] = None, - args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, scale_grads: bool = True, backward_requires_autograd: bool = True, ): @@ -3053,7 +3047,7 @@ def __init__( # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[_Action | None]] = {} for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops @@ -3061,8 +3055,8 @@ def __init__( # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime self._prepare_schedule_with_comms(self.pipeline_order) - def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: - actions: list[Optional[_Action]] = [] + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + actions: list[_Action | None] = [] counters: dict[ tuple[int, _ComputationType], int ] = {} # (stage_index, computation_type) -> mb_index @@ -3271,12 +3265,12 @@ def _simulate_comms_compute( _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule} - def add_to_schedule(rank: int, action: Optional[_Action]): + def add_to_schedule(rank: int, action: _Action | None): _schedule[rank].append(action) if action is not None: _prev_ops_rank[rank].add(action) - def _ready_to_schedule(action: Optional[_Action]) -> bool: + def _ready_to_schedule(action: _Action | None) -> bool: if action is None: return True diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index a232f5519c9ee..cc0d51020458b 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -4,7 +4,7 @@ import operator from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any, cast, Optional, Union +from typing import Any, cast, Union import torch import torch.distributed as dist @@ -99,7 +99,7 @@ def __repr__(self): def _make_tensor_from_meta( - example: Union[torch.Tensor, FakeTensor], + example: torch.Tensor | FakeTensor, device: torch.device, ) -> torch.Tensor: """ @@ -126,8 +126,8 @@ def __init__( stage_index: int, num_stages: int, device: torch.device, - group: Optional[dist.ProcessGroup] = None, - dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + group: dist.ProcessGroup | None = None, + dw_builder: Callable[[], Callable[..., None]] | None = None, ): """ Args: @@ -176,11 +176,11 @@ def __init__( ) # Run time states - self._outputs_meta: Optional[tuple[torch.Tensor, ...]] = None + self._outputs_meta: tuple[torch.Tensor, ...] | None = None # map microbatch ID to list of forward tensor args self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {} # map microbatch ID to list of backward grad tensor args - self.bwd_cache: dict[int, tuple[Optional[torch.Tensor], ...]] = {} + self.bwd_cache: dict[int, tuple[torch.Tensor | None, ...]] = {} # Caching chunk outputs for final output merge or reduction self.output_chunks: list[Any] = [] @@ -196,10 +196,10 @@ def __init__( # Backward infra will created lazily self.grad_recv_info: dict = {} - self.grad_send_info: Optional[list] = None + self.grad_send_info: list | None = None # To be populated later by the Schedule - self.chunks: Optional[int] = None + self.chunks: int | None = None self.stage_index_to_group_rank: dict[int, int] = { i: i % self.group_size for i in range(self.num_stages) } @@ -261,11 +261,11 @@ def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: def _create_grad_send_info( self, args_recv_info: tuple, - ) -> list[Optional[int]]: + ) -> list[int | None]: """ Create a list of stage indices to send gradients to. """ - grad_send_info: list[Optional[int]] = [] + grad_send_info: list[int | None] = [] def map_recv_to_send(a): # Note: we send gradients back to previous stage as long as in @@ -288,7 +288,7 @@ def _prepare_forward_infra( self, num_microbatches: int, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> tuple[Any, ...]: raise NotImplementedError @@ -388,7 +388,7 @@ def get_local_bwd_output(self, mb_index): return self.bwd_cache.pop(mb_index) def set_local_bwd_input( - self, next_stage_bwd_outputs: tuple[Optional[torch.Tensor], ...], mb_index: int + self, next_stage_bwd_outputs: tuple[torch.Tensor | None, ...], mb_index: int ) -> None: """ Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. @@ -588,7 +588,7 @@ def backward_maybe_with_nosync( backward_type, bwd_kwargs: dict, last_backward: bool = False, - ) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]]: + ) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None]: """ Whether using PP with FSDP, DDP, or replicate there are some runtime differences between the last backward step and the other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but @@ -600,7 +600,7 @@ def perform_backward( backward_type, ) -> Callable[ [], - tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]], + tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None], ]: if backward_type == "full": return lambda: ( @@ -663,7 +663,7 @@ def forward_one_chunk( self, fwd_chunk_id: int, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, save_forward_output: bool = True, ): """ @@ -779,7 +779,7 @@ def backward_one_chunk( "input_values": input_values, } - grads_input: tuple[Optional[torch.Tensor], ...] = () + grads_input: tuple[torch.Tensor | None, ...] = () # Custom backward function if self.dw_builder: @@ -1019,7 +1019,7 @@ def __init__( stage_index: int, pipe_info: PipeInfo, device: torch.device, - group: Optional[dist.ProcessGroup] = None, + group: dist.ProcessGroup | None = None, ): """ Create a pipeline stage given a stage_module to be wrapped by this stage @@ -1086,7 +1086,7 @@ def _prepare_forward_infra( self, num_microbatches: int, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> tuple[Any, ...]: """ Create send/recv infrastructures for activations (during forward) @@ -1183,7 +1183,7 @@ def create_recv_tensor(placeholder, arg_node): def find_dst_rank( self, user: fx.Node, - ) -> Optional[int]: + ) -> int | None: """ Find the destination rank of a `user` node. If the `user` is not a submod, `None` may be returned. @@ -1293,7 +1293,7 @@ def build_stage( stage_index: int, pipe_info: PipeInfo, device: torch.device, - group: Optional[dist.ProcessGroup] = None, + group: dist.ProcessGroup | None = None, ) -> _PipelineStage: """ Create a pipeline stage given a stage_module to be wrapped by this stage @@ -1347,14 +1347,14 @@ def __init__( stage_index: int, num_stages: int, device: torch.device, - input_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, - output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, - group: Optional[dist.ProcessGroup] = None, - dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + input_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, + output_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, + group: dist.ProcessGroup | None = None, + dw_builder: Callable[[], Callable[..., None]] | None = None, ): super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) - self.inputs: Optional[list[torch.Tensor]] = None - self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None + self.inputs: list[torch.Tensor] | None = None + self.inputs_meta: tuple[torch.Tensor, ...] | None = None # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it # might be breaking for existing users. if input_args is None: @@ -1410,7 +1410,7 @@ def __init__( def _shape_inference( self, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ): if kwargs is None: kwargs = {} @@ -1522,7 +1522,7 @@ def _prepare_forward_infra( self, num_microbatches: int, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> tuple[Any, ...]: # TODO move self.device to an argument from step API (from its input tensors)? assert num_microbatches is not None, "TODO fix num_microbatches" diff --git a/torch/distributed/remote_device.py b/torch/distributed/remote_device.py index a71e15c9c349b..3ad0076f5e890 100644 --- a/torch/distributed/remote_device.py +++ b/torch/distributed/remote_device.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional, Union import torch @@ -22,14 +21,14 @@ class _remote_device: and "cuda:1", just represent local devices. """ - def __init__(self, remote_device: Union[str, torch.device]): + def __init__(self, remote_device: str | torch.device): PARSE_ERROR = ( f"Could not parse remote_device: {remote_device}. The valid format is " "'/' or 'rank:/' or ''" ) self._worker_name = None self._rank = None - self._device: Optional[Union[str, int, torch.device]] = None + self._device: str | int | torch.device | None = None if isinstance(remote_device, torch.device): self._device = remote_device @@ -81,11 +80,11 @@ def _is_valid_local_device(device): except Exception: return False - def worker_name(self) -> Optional[str]: + def worker_name(self) -> str | None: """Return the name of remote worker representing the remote device and ``None`` if no worker name is available.""" return self._worker_name - def rank(self) -> Optional[int]: + def rank(self) -> int | None: """ Returns the rank of remote worker representing the remote device. Returns ``None`` if no rank is available. diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index a65bfa783efc3..f7913341175fb 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -11,7 +11,6 @@ import sys from collections.abc import Callable, Iterator from datetime import timedelta -from typing import Optional from torch.distributed import FileStore, Store, TCPStore @@ -71,7 +70,7 @@ def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool: return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1" -def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs): +def _rendezvous_helper(url: str, rank: int, world_size_opt: int | None, **kwargs): result = urlparse(url) if world_size_opt is None: world_size = -1 diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 7c1e3d4b5a04f..c58a2bf923910 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Optional, Union +from typing import Union import torch @@ -89,10 +89,10 @@ def __init__( num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS, rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC, init_method: str = rpc_contants.DEFAULT_INIT_METHOD, - device_maps: Optional[dict[str, dict[DeviceType, DeviceType]]] = None, - devices: Optional[list[DeviceType]] = None, - _transports: Optional[list] = None, - _channels: Optional[list] = None, + device_maps: dict[str, dict[DeviceType, DeviceType]] | None = None, + devices: list[DeviceType] | None = None, + _transports: list | None = None, + _channels: list | None = None, ): full_device_maps = ( {} diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 2343f7bb9b74c..3d8d0fb64276e 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -375,7 +375,6 @@ def main(): from argparse import ArgumentParser, REMAINDER from collections.abc import Callable from importlib import metadata -from typing import Optional, Union import torch from torch.distributed.argparse_util import check_env, env @@ -798,7 +797,7 @@ def get_use_env(args) -> bool: return args.use_env -def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]: +def _get_logs_specs_class(logs_specs_name: str | None) -> type[LogsSpecs]: """ Attempts to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. Provides plugin mechanism to provide custom implementation of LogsSpecs. @@ -827,7 +826,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]: return logs_specs_cls -def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]: +def config_from_args(args) -> tuple[LaunchConfig, Callable | str, list[str]]: # If ``args`` not passed, defaults to ``sys.argv[:1]`` min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) if not (0 < min_nodes <= max_nodes): @@ -871,7 +870,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str rdzv_endpoint = get_rdzv_endpoint(args) - ranks: Optional[set[int]] = None + ranks: set[int] | None = None if args.local_ranks_filter: try: ranks = set(map(int, args.local_ranks_filter.split(","))) @@ -920,7 +919,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str ) with_python = not args.no_python - cmd: Union[Callable, str] + cmd: Callable | str cmd_args = [] use_env = get_use_env(args) if args.run_path: diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index fb072d8dce629..070d8625f50e0 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -5,7 +5,7 @@ import inspect import warnings from collections.abc import Callable, Sequence -from typing import Any, cast, Optional +from typing import Any, cast from typing_extensions import deprecated import torch @@ -74,7 +74,7 @@ class _ToTorchTensor(torch.autograd.Function): def forward( # type: ignore[override] ctx, input: "DTensor", - grad_placements: Optional[Sequence[Placement]], + grad_placements: Sequence[Placement] | None, ): ctx.dtensor_spec = input._spec ctx.grad_placements = grad_placements @@ -135,8 +135,8 @@ def forward( # type: ignore[override] device_mesh: DeviceMesh, placements: tuple[Placement, ...], run_check: bool, - shape: Optional[torch.Size] = None, - stride: Optional[tuple[int, ...]] = None, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, ) -> "DTensor": ctx.previous_placement = placements ctx.previous_device_mesh = device_mesh @@ -359,12 +359,12 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[ @staticmethod def from_local( local_tensor: torch.Tensor, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, *, run_check: bool = False, - shape: Optional[torch.Size] = None, - stride: Optional[tuple[int, ...]] = None, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, ) -> "DTensor": """ Create a :class:`DTensor` from a local torch.Tensor on each rank @@ -448,7 +448,7 @@ def from_local( ) def to_local( - self, *, grad_placements: Optional[Sequence[Placement]] = None + self, *, grad_placements: Sequence[Placement] | None = None ) -> torch.Tensor: """ Get the local tensor of this DTensor on its current rank. For sharding it returns @@ -486,12 +486,12 @@ def to_local( def redistribute( self, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, *, async_op: bool = False, - forward_dtype: Optional[torch.dtype] = None, - backward_dtype: Optional[torch.dtype] = None, + forward_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, ) -> "DTensor": """ ``redistribute`` performs necessary collective operations that redistribute the current @@ -568,7 +568,7 @@ def redistribute( ) def full_tensor( - self, *, grad_placements: Optional[Sequence[Placement]] = None + self, *, grad_placements: Sequence[Placement] | None = None ) -> torch.Tensor: """ Return the full tensor of this DTensor. It will perform necessary collectives @@ -691,10 +691,10 @@ def __metadata_guard__( def distribute_tensor( tensor: torch.Tensor, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, *, - src_data_rank: Optional[int] = 0, + src_data_rank: int | None = 0, ) -> DTensor: """ Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according @@ -858,7 +858,7 @@ def distribute_tensor( def _shard_tensor( full_tensor: torch.Tensor, placements: Sequence[Shard], - device_mesh: Optional[DeviceMesh] = None, + device_mesh: DeviceMesh | None = None, ) -> "DTensor": """ Locally shards a full tensor based on indicated sharding arrangement, and @@ -894,10 +894,10 @@ def _shard_tensor( def distribute_module( module: nn.Module, - device_mesh: Optional[DeviceMesh] = None, - partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, - input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, - output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, + device_mesh: DeviceMesh | None = None, + partition_fn: Callable[[str, nn.Module, DeviceMesh], None] | None = None, + input_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None, + output_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None, ) -> nn.Module: """ This function expose three functions to control the parameters/inputs/outputs of the module: @@ -1050,8 +1050,8 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: def _dtensor_init_helper( # type: ignore[no-untyped-def] init_op, size: torch.Size, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, **kwargs, ) -> DTensor: # if device_mesh is None, use the one from mesh resources @@ -1116,11 +1116,11 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] def ones( # type: ignore[no-untyped-def] *size, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, layout: torch.layout = torch.strided, requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, ) -> DTensor: """ Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined @@ -1159,11 +1159,11 @@ def ones( # type: ignore[no-untyped-def] def empty( # type: ignore[no-untyped-def] *size, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, layout: torch.layout = torch.strided, requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, ) -> DTensor: """ Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` @@ -1204,11 +1204,11 @@ def full( # type: ignore[no-untyped-def] size, fill_value, *, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, layout: torch.layout = torch.strided, requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, ) -> DTensor: """ Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and @@ -1250,10 +1250,10 @@ def full( # type: ignore[no-untyped-def] def rand( # type: ignore[no-untyped-def] *size, requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, ) -> DTensor: """ Returns a :class:`DTensor` filled with random numbers from a uniform distribution @@ -1294,10 +1294,10 @@ def rand( # type: ignore[no-untyped-def] def randn( # type: ignore[no-untyped-def] *size, requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, ) -> DTensor: """ Returns a :class:`DTensor` filled with random numbers from a normal distribution @@ -1338,10 +1338,10 @@ def randn( # type: ignore[no-untyped-def] def zeros( # type: ignore[no-untyped-def] *size, requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, ) -> DTensor: """ Returns a :class:`DTensor` filled with the scalar value 0. diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index dff426a6d5e5a..1d2690ccba38d 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -74,7 +74,7 @@ def mesh_scatter( async_op: bool = False, *, group_src: int = 0, -) -> Optional[Work]: +) -> Work | None: """ scatter a list of tensors to a device mesh dimension. We by default use the first rank of the mesh dimension as the source of truth, i.e @@ -135,7 +135,7 @@ def mesh_broadcast( async_op: bool = False, *, group_src: int = 0, -) -> Optional[Work]: +) -> Work | None: """ broadcast the tensor to a device mesh dimension. We by default use the first rank of the mesh dimension as the source of truth, i.e diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index aaa5d25c79ba7..56c9cb1a94783 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -3,7 +3,7 @@ import logging import warnings from collections.abc import Sequence -from typing import cast, Optional +from typing import cast import torch import torch.distributed as dist @@ -518,7 +518,7 @@ def _unwrap_to_op_info_impl( kwargs_schema: dict[str, object] = {} local_args: list[object] = [] local_kwargs: dict[str, object] = {} - compute_mesh: Optional[DeviceMesh] = None + compute_mesh: DeviceMesh | None = None for arg in args_list: if isinstance(arg, dtensor.DTensor): diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index ca51cdf70c058..0499fc696799b 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -2,7 +2,7 @@ import math from collections import defaultdict from dataclasses import dataclass -from typing import Any, cast, NamedTuple, Optional +from typing import Any, cast, NamedTuple import torch from torch.distributed.device_mesh import DeviceMesh @@ -71,7 +71,7 @@ class DTensorSpec: placements: tuple[Placement, ...] # tensor meta will only be set during sharding propagation - tensor_meta: Optional[TensorMeta] = None + tensor_meta: TensorMeta | None = None # When a tensor dimension is sharded across multiple mesh axes, # `shard_order` specifies the sequence in which these shardings are applied. @@ -206,7 +206,7 @@ def _convert_shard_order_to_StridedShard( @staticmethod def _maybe_convert_StridedShard_to_shard_order( placements: tuple[Placement, ...], mesh: DeviceMesh - ) -> Optional[ShardOrder]: + ) -> ShardOrder | None: """ Try to convert _StridedShard placements to ShardOrder. @@ -441,7 +441,7 @@ def is_default_device_order(shard_order: ShardOrder) -> bool: @staticmethod def format_shard_order_str( placements: tuple[Placement, ...], - shard_order: Optional[ShardOrder] = None, + shard_order: ShardOrder | None = None, ) -> str: """ Format DTensor sharding information as a human-readable string. @@ -617,7 +617,7 @@ def from_dim_map( mesh: DeviceMesh, dim_map: list[int], sums: list[int], - tensor_meta: Optional[TensorMeta] = None, + tensor_meta: TensorMeta | None = None, ) -> "DTensorSpec": """ Construct a DTensorSpec from dim_map list and pending sum. @@ -669,7 +669,7 @@ def is_sharded(self) -> bool: return any(placement.is_shard() for placement in self.placements) def shallow_copy_with_tensor_meta( - self, tensor_meta: Optional[TensorMeta] + self, tensor_meta: TensorMeta | None ) -> "DTensorSpec": """ Shallow copy the DTensorSpec with a new tensor_meta. diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 283eaf4a06db8..4fec0293554ac 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -26,7 +26,7 @@ from collections.abc import Sequence from dataclasses import dataclass from functools import cached_property -from typing import Any, Optional, Union +from typing import Any from typing_extensions import deprecated import torch @@ -60,11 +60,11 @@ ArgsType = tuple[object, ...] KwargsType = dict[str, object] -PlacementList = list[Optional[Placement]] +PlacementList = list[Placement | None] # ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should # be the same set of possibilities. -OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]] +OutputSpecType = DTensorSpec | Sequence[DTensorSpec | None] | None def _rebuild_tensor_from_dtensor_meta(arg) -> object: @@ -109,8 +109,8 @@ class OpSpec: # output_specs and input_specs are related: for this op, given these input_specs, # this is the way the output would look - output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]] - input_specs: Optional[Sequence[DTensorSpec]] = None + output_specs: DTensorSpec | tuple[DTensorSpec | None, ...] + input_specs: Sequence[DTensorSpec] | None = None """ redistribute_cost tells how expensive it is to redistribute a given input into the @@ -138,7 +138,7 @@ class OpSpec: K, # cost of redistributing tensor_a from 'Shard(0)' ], """ - redistribute_cost: Optional[list[list[float]]] = None + redistribute_cost: list[list[float]] | None = None @cached_property def output_spec(self) -> DTensorSpec: @@ -301,7 +301,7 @@ class RuntimeSchemaInfo: # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc. static_argnum: int = 100 # This static_kwargkey records static kwarg names which would affect sharding prop - static_kwargkey: Optional[list[str]] = None + static_kwargkey: list[str] | None = None # each op can decide if it wants to use pytree flatten/unflatten during operator # eager execution, by default we don't need to do flatten/unflatten, only if the # op indicate it needs to, this is to accelerate eager performance. @@ -331,9 +331,9 @@ class OpSchema: args_schema: ArgsType kwargs_schema: KwargsType - schema_info: Optional[RuntimeSchemaInfo] = None + schema_info: RuntimeSchemaInfo | None = None - _comparison_key: Optional[tuple[object, ...]] = None + _comparison_key: tuple[object, ...] | None = None @property def args_spec(self) -> tuple[DTensorSpec, ...]: @@ -570,7 +570,7 @@ class OutputSharding: # specifies the output sharding pattern output_spec: OutputSpecType # schema for redistribution if needed - redistribute_schema: Optional[OpSchema] = None + redistribute_schema: OpSchema | None = None # flag indicating if inputs need redistribution needs_redistribute: bool = False # flag to use values from `redistribute_schema` @@ -606,7 +606,7 @@ class OpInfo: flat_args_schema: list[object] local_args: Sequence[object] local_kwargs: dict[str, object] - args_tree_spec: Optional[TreeSpec] = None + args_tree_spec: TreeSpec | None = None # the output sharding info - output_sharding: Optional[OutputSharding] = None + output_sharding: OutputSharding | None = None diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 1e7ff648f7fbd..88a6e4298d246 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import string -from typing import cast, Optional +from typing import cast import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta @@ -44,7 +44,7 @@ def einop_rule( op_schema: OpSchema, *, linearity: bool = False, - enforce_sharding: Optional[dict[str, int]] = None, + enforce_sharding: dict[str, int] | None = None, ) -> OutputSharding: """ Propagate the sharding of inputs to output for ops whose data moves according to einsum notation. diff --git a/torch/distributed/tensor/_ops/_mask_buffer.py b/torch/distributed/tensor/_ops/_mask_buffer.py index 7fe06c11aea9d..26b0a713db42c 100644 --- a/torch/distributed/tensor/_ops/_mask_buffer.py +++ b/torch/distributed/tensor/_ops/_mask_buffer.py @@ -1,14 +1,13 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass -from typing import Optional import torch @dataclass class MaskBuffer: - data: Optional[torch.Tensor] = None + data: torch.Tensor | None = None # refcount allows shared usage of the MaskBuffer, as long as all users have the same data refcount: int = 0 diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index ac0180f07d05e..7721ec3bc090f 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import cast, Optional, Union +from typing import cast, Union import torch from torch.distributed.device_mesh import DeviceMesh @@ -47,7 +47,7 @@ class Reduction(Enum): @dataclass(frozen=True) class NormReduction: - norm_type: Union[int, float, str] + norm_type: int | float | str ReductionOpType = Union[NormReduction, str] @@ -71,9 +71,9 @@ class _NormPartial(Partial): similarly for inf and -inf norm. For 0-norm, the reduction op is sum. """ - norm_type: Union[int, float, str] = 2 + norm_type: int | float | str = 2 - def __init__(self, norm_type: Union[int, float, str] = 2): + def __init__(self, norm_type: int | float | str = 2): reduce_op = None if norm_type in (float("inf"), "inf"): reduce_op = "max" @@ -174,7 +174,7 @@ def __str__(self) -> str: return f"_NormP({self.reduce_op}, {self.norm_type})" -def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]: +def _infer_reduction_dims(dims_arg: object, ndim: int) -> list[int] | None: if dims_arg is None: return None dims = cast(list[int], as_list(dims_arg)) @@ -1096,7 +1096,7 @@ def _common_norm_backward_strategy( out_tuple_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): # args for OpSpec - output_specs_list: list[Optional[DTensorSpec]] = [] + output_specs_list: list[DTensorSpec | None] = [] input_specs_list: list[DTensorSpec] = [] redistribute_costs = [] diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 81b9e328f0604..30498a95e29d6 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -2,8 +2,6 @@ # implement matrix related ops for distributed tensor -from typing import Optional - import torch from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta @@ -708,7 +706,7 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate ) = op_schema.args_schema return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2] has_attn_bias = attn_bias_strategy is not None - debug_attn_mask_sharding: Optional[Placement] = ( + debug_attn_mask_sharding: Placement | None = ( Replicate() if return_debug_mask else None ) @@ -1073,7 +1071,7 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy: ) def valid_grouped_mm_strides( - input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...] + input_specs: list[DTensorSpec], output_specs: tuple[DTensorSpec | None, ...] ) -> bool: # 1. compute the local-tensor shape/strides given this sharding proposal # 2. apply the logic from the groped_mm meta function diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 011a1ec667fb4..2fa8fabd8f08a 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from collections.abc import Sequence -from typing import cast, Optional +from typing import cast import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -493,7 +493,7 @@ def common_pointwise_strategy( followed_strategy: OpStrategy, followed_strategy_index: int, linearity: int = -1, - scalar_tensor_idx: Optional[int] = None, + scalar_tensor_idx: int | None = None, ) -> OpStrategy: """ Common strategy for pointwise operations. @@ -713,11 +713,11 @@ def list_pointwise_strategy( def args_tuple_strategies( args_schema: tuple[object, ...], - ) -> list[Optional[TupleStrategy]]: + ) -> list[TupleStrategy | None]: first_arg = args_schema[0] assert isinstance(first_arg, TupleStrategy) strategy_len = len(first_arg.children) - tuple_strategies: list[Optional[TupleStrategy]] = [] + tuple_strategies: list[TupleStrategy | None] = [] for arg_idx, arg in enumerate(args_schema): if isinstance(arg, TupleStrategy): # every tuple strategy should have the same length @@ -743,7 +743,7 @@ def args_tuple_strategies( for child_idx, child_strtgy in enumerate(follow_strategy.children): assert isinstance(child_strtgy, OpStrategy) - args_schema: list[Optional[OpStrategy]] = [ + args_schema: list[OpStrategy | None] = [ cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None for arg_strategy in args_strategies ] diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index cb336486785af..a6ff33a12a189 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from collections.abc import Sequence, Sized -from typing import cast, Optional +from typing import cast import torch from torch._prims_common import IntLike @@ -723,7 +723,7 @@ def merge_placement( # current replicate, just follow new placement return new_placement - follow_placements: Optional[list[Placement]] = None + follow_placements: list[Placement] | None = None mesh = tuple_strategy.child_mesh(0) for arg_strategy in tuple_strategy.children: if not isinstance(arg_strategy, OpStrategy): @@ -889,7 +889,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding: if not isinstance(indices_spec, DTensorSpec): raise AssertionError(f"Expected DTensorSpec, got {type(indices_spec)}") - all_indices_spec: list[Optional[DTensorSpec]] = [ + all_indices_spec: list[DTensorSpec | None] = [ indices_spec if dim == i else None for i in range(values_spec.ndim) ] @@ -936,7 +936,7 @@ def prop_index_put(op_schema: OpSchema) -> StrategyType: op_strategy = OpStrategy([]) # 1. `indices` should all be replicated first. indices_redistribute_costs = [] - new_indices_spec: list[Optional[DTensorSpec]] = [] + new_indices_spec: list[DTensorSpec | None] = [] for indices_spec_child in indices_spec.children: if not isinstance(indices_spec_child, OpStrategy): raise AssertionError(f"Expected OpStrategy, got {type(indices_spec_child)}") @@ -1046,7 +1046,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding: raise AssertionError(f"Expected DTensorSpec, got {type(values_spec)}") if not isinstance(multi_indices_spec, list): raise AssertionError(f"Expected list, got {type(multi_indices_spec)}") - multi_indices_spec = cast(list[Optional[DTensorSpec]], multi_indices_spec) + multi_indices_spec = cast(list[DTensorSpec | None], multi_indices_spec) valid_indices_spec: list[tuple[int, DTensorSpec]] = [ (i, a) for i, a in enumerate(multi_indices_spec) if a is not None ] diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 6c8954729b976..32e2e43c5d255 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -216,7 +216,7 @@ def expand(input_shape: Shape, shape: Shape) -> DimMap: return tuple(mapping) -def normalize_sizes(sizes: Union[Shape, tuple[Shape]]) -> Shape: +def normalize_sizes(sizes: Shape | tuple[Shape]) -> Shape: if isinstance(sizes[0], int): return cast(Shape, sizes) elif len(sizes) == 1: @@ -428,7 +428,7 @@ def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: return tuple(dimmap) -def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: +def dim_squeeze(shape: Shape, dim: int | None = None) -> DimMap: # FIXME: this is wrong when dim=None and one of the dimensions # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to @@ -457,7 +457,7 @@ def dim_view_as_real(shape: Shape) -> DimMap: return tuple(results) -def dim_reduction(ndim: int, dim_or_dims: Optional[DimsType], keepdim: bool) -> DimMap: +def dim_reduction(ndim: int, dim_or_dims: DimsType | None, keepdim: bool) -> DimMap: """ General fallback for reduction ops where Partial() does not apply. @@ -542,7 +542,7 @@ def collect_used_inputs(cmd: DimSpec) -> None: def maybe_get_shard_mesh_dim_and_placement( input_dim: InputDim, - ) -> tuple[Optional[int], Optional[Shard]]: + ) -> tuple[int | None, Shard | None]: # if input_dim is sharded, return the mesh_dim and shard placement for i, placement in enumerate(input_src_placements): if isinstance(placement, Shard) and placement.dim == input_dim.input_dim: @@ -556,7 +556,7 @@ def maybe_get_shard_mesh_dim_and_placement( # 1 and 2 doesn't require the info of whether current input is sharded. # 3 requires that info, to decide whether we can error out. Maybe we can refactor # to make this function purely "theoretical". - def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: + def get_in_dim_to_shard(cmd: DimSpec) -> InputDim | None: if isinstance(cmd, InputDim): return cmd elif isinstance(cmd, Flatten): @@ -692,7 +692,7 @@ def _rewrite_shard_dim(p: Shard): def register_op_strategy_map( aten_op_overload: torch._ops.OpOverload, local_op_name: Callable[..., torch.Tensor], - schema_info: Optional[RuntimeSchemaInfo] = None, + schema_info: RuntimeSchemaInfo | None = None, strict_view: bool = False, ) -> None: """ diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index d117df2d67e2e..40415947be9a0 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -3,7 +3,7 @@ import contextlib import warnings from logging import getLogger -from typing import Optional, Union +from typing import Optional import torch from torch.distributed.device_mesh import _get_device_handle, DeviceMesh @@ -174,7 +174,7 @@ def distribute_region_enabled(self, value) -> None: self._use_distribute_region = value def _distribute_region( - self, spec: DTensorSpec, generator: Optional[torch.Generator] = None + self, spec: DTensorSpec, generator: torch.Generator | None = None ): pass @@ -240,7 +240,7 @@ def _set_device_state(self, state: torch.Tensor): @contextlib.contextmanager def _distribute_region( - self, spec: DTensorSpec, generator: Optional[torch.Generator] = None + self, spec: DTensorSpec, generator: torch.Generator | None = None ): from torch.distributed._local_tensor import maybe_enable_local_tracker @@ -340,7 +340,7 @@ def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: mesh = spec.mesh # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP # case. Replace the custom logic with dim_map once we support it. - dim_map: list[Union[int, list[int]]] = [-1] * spec.ndim + dim_map: list[int | list[int]] = [-1] * spec.ndim for i, placement in enumerate(spec.placements): if isinstance(placement, Shard): shard_dim = placement.dim diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index a407ba6ca91df..f38ca7acebbb4 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -8,7 +8,7 @@ from collections import defaultdict from collections.abc import Sequence from functools import cache -from typing import cast, NamedTuple, Optional +from typing import cast, NamedTuple import torch import torch.distributed._functional_collectives as funcol @@ -88,7 +88,7 @@ class DTensorRedistributePlanner: class DistState: placements: tuple[Placement, ...] tensor_dim_to_mesh_dim: ShardOrder - _hash: Optional[int] = dataclasses.field( + _hash: int | None = dataclasses.field( default=None, init=False, repr=False, compare=False ) @@ -161,7 +161,7 @@ def stringify_transform_infos( mesh: DeviceMesh, transform_infos: Sequence[_TransformInfo], src_placement: tuple[Placement, ...], - src_shard_order: Optional[ShardOrder] = None, + src_shard_order: ShardOrder | None = None, ) -> str: """ Generate a string representation of the sequence of state transitions @@ -646,7 +646,7 @@ def generate_greedy_transform_infos( def _gen_transform_infos_non_cached( src_spec: DTensorSpec, dst_spec: DTensorSpec, - use_graph_based_transform: Optional[bool] = None, + use_graph_based_transform: bool | None = None, ) -> list[_TransformInfo]: transform_infos: list[_TransformInfo] = [] device_mesh = src_spec.device_mesh @@ -678,7 +678,7 @@ def _gen_transform_infos_non_cached( def _gen_transform_infos( src_spec: DTensorSpec, dst_spec: DTensorSpec, - use_graph_based_transform: Optional[bool] = None, + use_graph_based_transform: bool | None = None, ) -> list[_TransformInfo]: return _gen_transform_infos_non_cached( src_spec, dst_spec, use_graph_based_transform @@ -692,7 +692,7 @@ def redistribute_local_tensor( *, async_op: bool = False, is_backward: bool = False, - use_graph_based_transform: Optional[bool] = None, + use_graph_based_transform: bool | None = None, ) -> torch.Tensor: """ This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to @@ -846,8 +846,8 @@ def forward( # type: ignore[override] device_mesh: DeviceMesh, placements: tuple[Placement, ...], async_op: bool = False, - forward_dtype: Optional[torch.dtype] = None, - backward_dtype: Optional[torch.dtype] = None, + forward_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, ): ctx.async_op = async_op ctx.backward_dtype = backward_dtype diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index f3dc04ef10f97..f3cbb90dc8f04 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Sequence from functools import lru_cache from itertools import chain -from typing import cast, Optional, Union +from typing import cast import torch from torch._guards import detect_fake_mode @@ -69,9 +69,7 @@ def __init__(self) -> None: ) # op map to save indices of shape (and stride) args which may need to be # modified in sharding prop - self.op_to_shape_and_stride_idx: dict[ - OpOverload, Union[int, tuple[int, int]] - ] = { + self.op_to_shape_and_stride_idx: dict[OpOverload, int | tuple[int, int]] = { # new factory ops aten.new_empty.default: 1, aten.new_full.default: 1, @@ -91,7 +89,7 @@ def register_sharding_prop_rule( self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding], - schema_info: Optional[RuntimeSchemaInfo] = None, + schema_info: RuntimeSchemaInfo | None = None, ): """ Register a sharding propagation rule for an operator. @@ -104,7 +102,7 @@ def register_op_strategy( self, op_overload: OpOverload, strategy_func: Callable[[OpSchema], StrategyType], - schema_info: Optional[RuntimeSchemaInfo] = None, + schema_info: RuntimeSchemaInfo | None = None, ): """ Register a :class:`OpStrategy` generator for an operator. @@ -157,7 +155,7 @@ def register_op_strategy( def _propagate_tensor_meta_non_cached( self, op_schema: OpSchema - ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + ) -> None | TensorMeta | Sequence[TensorMeta | None]: """ Propagate the tensor metadata, it could either return a TensorMeta or a list/tuple of TensorMetas @@ -191,7 +189,7 @@ def _propagate_tensor_meta_non_cached( ) elif isinstance(fake_out, (tuple, list)): - tensor_meta_list: list[Optional[TensorMeta]] = [] + tensor_meta_list: list[TensorMeta | None] = [] for fake_out_item in fake_out: if isinstance(fake_out_item, torch.Tensor): tensor_meta_list.append( @@ -215,7 +213,7 @@ def _propagate_tensor_meta_non_cached( @lru_cache # noqa: B019 def _propagate_tensor_meta( self, op_schema: OpSchema - ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + ) -> None | TensorMeta | Sequence[TensorMeta | None]: """ Cached version of _propagate_tensor_meta_non_cached This is a private API. Use propagate_tensor_meta instead. @@ -224,7 +222,7 @@ def _propagate_tensor_meta( def propagate_tensor_meta( self, op_schema: OpSchema - ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + ) -> None | TensorMeta | Sequence[TensorMeta | None]: """ Propagate the tensor metadata, it could either return a TensorMeta or a list/tuple of TensorMetas. This is a public API that should be @@ -239,7 +237,7 @@ def _create_output_spec_with_new_tensor_meta( self, op: OpOverload, output_specs: OutputSpecType, - output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]], + output_tensor_meta: None | TensorMeta | Sequence[TensorMeta | None], ) -> OutputSpecType: """ Wrap the output_specs with the tensor metadata from the output. @@ -260,7 +258,7 @@ def _create_output_spec_with_new_tensor_meta( ) return output_specs.shallow_copy_with_tensor_meta(output_tensor_meta) elif isinstance(output_specs, (tuple, list)): - new_specs: list[Optional[DTensorSpec]] = [] + new_specs: list[DTensorSpec | None] = [] if not isinstance(output_tensor_meta, (tuple, list)) or len( output_specs ) != len(output_tensor_meta): @@ -593,7 +591,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin ) def _select_strategy( - self, strategy: OpStrategy, op_schema: Optional[OpSchema] = None + self, strategy: OpStrategy, op_schema: OpSchema | None = None ) -> OpSpec: if len(strategy.strategies) == 1: # short cut with only one possible OpSpec diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 74ad2aaa80434..adf0e8e8069a6 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,7 +1,7 @@ import threading from collections import defaultdict from collections.abc import Sequence -from typing import cast, Optional +from typing import cast import torch import torch.distributed._functional_collectives as funcol @@ -159,7 +159,7 @@ def compute_local_shape_and_global_offset( def _compute_local_shape_and_global_offset( global_shape: ShapeType, mesh_shape: ShapeType, - my_coordinate: Optional[list[int]], + my_coordinate: list[int] | None, placements: Sequence[Placement], ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ diff --git a/torch/distributed/tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py index 6744448527821..3f5cf80f36a1c 100644 --- a/torch/distributed/tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -5,7 +5,7 @@ import argparse import os -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import torch import torch.nn as nn @@ -55,7 +55,7 @@ def __init__(self, world_size: int, rank: int) -> None: self.device_type = get_device_type() def _MLP_model_setup( - self, model_type: type, parallelize_plan: Union[None, dict] = None + self, model_type: type, parallelize_plan: None | dict = None ) -> tuple[nn.Module, torch.Tensor]: """ Creates MLP or MLPStacked model for examples diff --git a/torch/distributed/tensor/examples/flex_attention_cp.py b/torch/distributed/tensor/examples/flex_attention_cp.py index 5de92579b25b6..8b309a6d2646e 100644 --- a/torch/distributed/tensor/examples/flex_attention_cp.py +++ b/torch/distributed/tensor/examples/flex_attention_cp.py @@ -5,7 +5,6 @@ import os from functools import lru_cache -from typing import Optional import torch import torch.distributed as dist @@ -27,8 +26,8 @@ def get_device_type() -> str: @lru_cache def create_block_mask_cached( score_mod: _mask_mod_signature, - B: Optional[int], - H: Optional[int], + B: int | None, + H: int | None, M: int, N: int, device: str = "cuda", diff --git a/torch/distributed/tensor/experimental/_context_parallel/_attention.py b/torch/distributed/tensor/experimental/_context_parallel/_attention.py index b1903e211a1c1..f3d06b4fd274d 100644 --- a/torch/distributed/tensor/experimental/_context_parallel/_attention.py +++ b/torch/distributed/tensor/experimental/_context_parallel/_attention.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import auto, Enum from functools import partial -from typing import Any, cast, Optional, Protocol, TypeAlias +from typing import Any, cast, Protocol, TypeAlias import torch import torch.distributed as dist @@ -140,8 +140,8 @@ class _SDPAMerger: def __init__(self, convert_to_f32: bool, seq_dim: int): self._seq_dim = seq_dim - self._out: Optional[torch.Tensor] = None - self._lse: Optional[torch.Tensor] = None + self._out: torch.Tensor | None = None + self._lse: torch.Tensor | None = None self._should_lse_squeeze = False self._convert_to_f32 = convert_to_f32 self._out_dtype = torch.float32 @@ -250,7 +250,7 @@ class _AllToAllRotater(_RingRotater): def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: self._pg = pg self._seq_dim = seq_dim - self._buffer: Optional[torch.Tensor] = None + self._buffer: torch.Tensor | None = None def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: curr_buffer = curr_buffer.contiguous() @@ -272,7 +272,7 @@ class _AllGatherRotater(_RingRotater): def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: self._pg = pg self._seq_dim = seq_dim - self._aggregated_buffer: Optional[torch.Tensor] = None + self._aggregated_buffer: torch.Tensor | None = None self._idx = 0 def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: @@ -293,7 +293,7 @@ def next_buffer(self) -> torch.Tensor: def _create_rotater( - pg: dist.ProcessGroup, seq_dim: int, method: Optional[_RotateMethod] = None + pg: dist.ProcessGroup, seq_dim: int, method: _RotateMethod | None = None ) -> _RingRotater: if method is None: method = _cp_options.rotate_method @@ -655,7 +655,7 @@ def _scaled_dot_product_ring_flash_attention( is_causal: bool = False, return_debug_mask: bool = False, *, - scale: Optional[float] = None, + scale: float | None = None, ) -> tuple[torch.Tensor, ...]: if return_debug_mask: raise NotImplementedError("return_debug_mask is not supported yet") @@ -681,12 +681,12 @@ def _scaled_dot_product_ring_efficient_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, + attn_bias: torch.Tensor | None = None, compute_log_sumexp: bool = True, dropout_p: float = 0.0, is_causal: bool = False, *, - scale: Optional[float] = None, + scale: float | None = None, ) -> tuple[torch.Tensor, ...]: if attn_bias is not None: raise NotImplementedError("attn_bias is not supported yet") @@ -718,13 +718,13 @@ def _scaled_dot_product_ring_cudnn_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, + attn_bias: torch.Tensor | None = None, compute_log_sumexp: bool = True, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, *, - scale: Optional[float] = None, + scale: float | None = None, ) -> tuple[torch.Tensor, ...]: if attn_bias is not None: raise NotImplementedError("attn_bias is not supported yet") @@ -769,7 +769,7 @@ def _scaled_dot_product_ring_flash_attention_backward( philox_seed: torch.Tensor, philox_offset: torch.Tensor, *, - scale: Optional[float] = None, + scale: float | None = None, ) -> tuple[torch.Tensor, ...]: # TODO: remove this hardcoding seq_dim = 2 @@ -812,7 +812,7 @@ def _scaled_dot_product_ring_efficient_attention_backward( grad_input_mask: tuple[bool, ...], is_causal: bool = False, *, - scale: Optional[float] = None, + scale: float | None = None, ) -> tuple[torch.Tensor, ...]: # TODO: remove this hardcoding seq_dim = 2 @@ -856,7 +856,7 @@ def _scaled_dot_product_ring_cudnn_attention_backward( dropout_p: float, is_causal: bool, *, - scale: Optional[float] = None, + scale: float | None = None, ) -> tuple[torch.Tensor, ...]: # TODO: remove this hardcoding seq_dim = 2 @@ -938,8 +938,8 @@ def _sdpa_handler( ArgsType = tuple[Any, ...] KwargsType = dict[str, Any] -InputFnType = Callable[[Optional[nn.Module], ArgsType, KwargsType, DeviceMesh], Any] -OutputFnType = Callable[[Optional[nn.Module], Any, Any, DeviceMesh], Any] +InputFnType = Callable[[nn.Module | None, ArgsType, KwargsType, DeviceMesh], Any] +OutputFnType = Callable[[nn.Module | None, Any, Any, DeviceMesh], Any] _replaced_functions: dict[Callable, tuple[str, Callable]] = {} @@ -1039,7 +1039,7 @@ def _context_parallel_buffers( mesh: DeviceMesh, buffers: list[torch.Tensor | BlockMask], buffer_seq_dims: list[int], - load_balancer: Optional[_LoadBalancer] = None, + load_balancer: _LoadBalancer | None = None, ) -> list[torch.Tensor | BlockMask]: """ Shard the buffers along the sequence dimensions according to CP rules. @@ -1136,7 +1136,7 @@ def _create_cp_block_mask( Q_LEN: int, KV_LEN: int, device_mesh: DeviceMesh, - load_balancer: Optional[_LoadBalancer] = None, + load_balancer: _LoadBalancer | None = None, ) -> BlockMask: """ Creates a specialized BlockMask for Context Parallel FlexAttention. @@ -1197,7 +1197,7 @@ def _rewrite_mask_mod( rank: int, block_size: int, local_q_size: int, - qkv_rearrange_indices: Optional[torch.Tensor] = None, + qkv_rearrange_indices: torch.Tensor | None = None, ) -> _mask_mod_signature: assert qkv_rearrange_indices is None or qkv_rearrange_indices.ndim == 2, ( "load balance index expects shape (1, seq_len) or (B, seq_len) " @@ -1301,7 +1301,7 @@ def _apply(self, module: nn.Module, mesh: DeviceMesh) -> nn.Module: raise ValueError(f"Unknown attention type: {self.attention_type}") def flex_input_fn( - self, module: Optional[nn.Module], args: Any, kwargs: Any, mesh: DeviceMesh + self, module: nn.Module | None, args: Any, kwargs: Any, mesh: DeviceMesh ) -> Any: args_list = list(args) for idx, name in enumerate( @@ -1329,7 +1329,7 @@ def flex_input_fn( def sdpa_input_fn( self, - module: Optional[nn.Module], + module: nn.Module | None, args: tuple[Any, ...], kwargs: dict[str, Any], mesh: DeviceMesh, @@ -1351,7 +1351,7 @@ def sdpa_input_fn( return new_args, new_kwargs def sdpa_output_fn( - self, module: Optional[nn.Module], inputs: Any, outputs: Any, mesh: DeviceMesh + self, module: nn.Module | None, inputs: Any, outputs: Any, mesh: DeviceMesh ) -> Any: new_outputs = [] for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs: @@ -1373,7 +1373,7 @@ def _context_parallel_shard( mesh: DeviceMesh, buffers: CPBufferContainer, seq_dims: CPBufferSeqDims, - load_balancer: Optional[_LoadBalancer] = None, + load_balancer: _LoadBalancer | None = None, ) -> list[torch.Tensor | BlockMask]: """ Shard the buffers along the specified sequence dimensions (`seq_dims`), so that each @@ -1464,9 +1464,9 @@ def _disable_context_parallel_dispatcher() -> None: def context_parallel( mesh: DeviceMesh, *, - buffers: Optional[list[torch.Tensor]] = None, - buffer_seq_dims: Optional[list[int]] = None, - no_restore_buffers: Optional[set[torch.Tensor]] = None, + buffers: list[torch.Tensor] | None = None, + buffer_seq_dims: list[int] | None = None, + no_restore_buffers: set[torch.Tensor] | None = None, ) -> Generator[None, None, None]: """ @@ -1554,7 +1554,7 @@ def context_parallel_unshard( mesh: DeviceMesh, buffers: list[torch.Tensor], seq_dims: list[int], - load_balancer: Optional[_LoadBalancer] = None, + load_balancer: _LoadBalancer | None = None, ) -> list[torch.Tensor]: """ Unshard the tensors (e.g., output) that are sharded due to context parallelism. diff --git a/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py b/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py index e5230092b41d7..4b293b0e260ef 100644 --- a/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py +++ b/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py @@ -2,7 +2,6 @@ # for different load-balancing strategies in tensor sharding. import functools from abc import ABC, abstractmethod -from typing import Optional import torch from torch import Tensor @@ -12,7 +11,7 @@ # make it private since it's still a prototype class _LoadBalancer(ABC): @abstractmethod - def _generate_indices(self, restore: bool = False) -> Optional[Tensor]: + def _generate_indices(self, restore: bool = False) -> Tensor | None: """ Generate indices for load balancing. Args: @@ -478,7 +477,7 @@ def _generate_indices(self, restore: bool = False) -> Tensor: def _create_default_load_balancer( seq_length: int, world_size: int, device: str | torch.device -) -> Optional[_LoadBalancer]: +) -> _LoadBalancer | None: from ._attention import _cp_options if _cp_options.enable_load_balance: diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py index cf0e9df1ab332..759841a40aaa1 100644 --- a/torch/distributed/tensor/experimental/_func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -24,11 +24,11 @@ def local_map( - func: Optional[Callable] = None, + func: Callable | None = None, out_placements: OutputPlacements = None, in_placements: InputPlacements = None, in_grad_placements: InputPlacements = None, - device_mesh: Optional[DeviceMesh] = None, + device_mesh: DeviceMesh | None = None, *, redistribute_inputs: bool = False, ): @@ -163,7 +163,7 @@ def _local_map_wrapped( out_placements: OutputPlacements, in_placements: InputPlacements, in_grad_placements: InputPlacements, - device_mesh: Optional[DeviceMesh], + device_mesh: DeviceMesh | None, redistribute_inputs: bool, *args, **kwargs, diff --git a/torch/distributed/tensor/experimental/_register_sharding.py b/torch/distributed/tensor/experimental/_register_sharding.py index 9879946f54bc1..7b365dcf286d0 100644 --- a/torch/distributed/tensor/experimental/_register_sharding.py +++ b/torch/distributed/tensor/experimental/_register_sharding.py @@ -2,7 +2,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from collections.abc import Callable, Sequence from functools import partial -from typing import Union import torch from torch._ops import OpOverload @@ -21,7 +20,7 @@ __all__ = ["register_sharding"] -def register_sharding(op: Union[OpOverload, list[OpOverload]]): +def register_sharding(op: OpOverload | list[OpOverload]): """ :meth:`register_sharding` is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensor. diff --git a/torch/distributed/tensor/experimental/_tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py index 426eb2ac83b38..1075df79f3395 100644 --- a/torch/distributed/tensor/experimental/_tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -2,7 +2,7 @@ import copy import operator from collections.abc import Sequence -from typing import Any, cast, Optional +from typing import Any, cast import torch from torch._subclasses.fake_tensor import FakeTensor @@ -273,7 +273,7 @@ def _create_placement_strategy( node: Node, mesh: DeviceMesh, placements: tuple[Placement, ...], - input_specs: Optional[Sequence[DTensorSpec]] = None, + input_specs: Sequence[DTensorSpec] | None = None, ) -> OpSpec: """ Util function to construct an OpSpec for a given node. diff --git a/torch/distributed/tensor/parallel/_data_parallel_utils.py b/torch/distributed/tensor/parallel/_data_parallel_utils.py index c41da260a02f9..735b74e099478 100644 --- a/torch/distributed/tensor/parallel/_data_parallel_utils.py +++ b/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -1,5 +1,5 @@ from functools import partial -from typing import no_type_check, Optional +from typing import no_type_check import torch from torch.distributed._functional_collectives import AsyncCollectiveTensor @@ -21,7 +21,7 @@ def sync_grad_hook(grad, *, device_handle=None, compute_stream=None): def _flatten_tensor( tensor: torch.Tensor, -) -> tuple[torch.Tensor, Optional[DTensorSpec]]: +) -> tuple[torch.Tensor, DTensorSpec | None]: if isinstance(tensor, DTensor): tensor._local_tensor.requires_grad_() return tensor._local_tensor, tensor._spec diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index 51cfd0f144b3f..954b62327808d 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -1,7 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import warnings from fnmatch import fnmatch -from typing import Optional, Union import torch import torch.nn as nn @@ -14,10 +13,10 @@ def parallelize_module( # type: ignore[return] module: nn.Module, - device_mesh: Optional[DeviceMesh] = None, - parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None, + device_mesh: DeviceMesh | None = None, + parallelize_plan: ParallelStyle | dict[str, ParallelStyle] | None = None, *, - src_data_rank: Optional[int] = 0, + src_data_rank: int | None = 0, ) -> nn.Module: """ Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan. diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index 7b19f97675197..19c1d3ca5477e 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Optional +from typing import Any import torch.nn as nn from torch.distributed.tensor.parallel._data_parallel_utils import ( @@ -48,7 +48,7 @@ def _reconstruct_dtensor(module: nn.Module, _input: Any): def _localize_dtensor( - module: nn.Module, *_: Any, ignored_params: Optional[set[nn.Parameter]] = None + module: nn.Module, *_: Any, ignored_params: set[nn.Parameter] | None = None ): """ Convert DTensor parameters to local tensors diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index f491624b5aaea..9e68ed6b1dba5 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import copy -from typing import Any, cast, Optional +from typing import Any, cast import torch import torch.distributed as dist @@ -297,7 +297,7 @@ def _pre_load_state_dict( def _all_gather_dtensor( tensor: DTensor, - parent_mesh: Optional[DeviceMesh], + parent_mesh: DeviceMesh | None, ) -> torch.Tensor: """All gather a DTensor in its FSDP dimension and return the local tensor.""" assert parent_mesh == tensor.device_mesh @@ -336,7 +336,7 @@ def __init__(self, device_handle) -> None: def pre_flatten_transform( self, tensor: torch.Tensor, - ) -> tuple[torch.Tensor, Optional[Any]]: + ) -> tuple[torch.Tensor, Any | None]: return _flatten_tensor(tensor) def post_unflatten_transform( @@ -365,7 +365,7 @@ def chunk_tensor( world_size: int, num_devices_per_node: int, pg: dist.ProcessGroup, - device: Optional[torch.device] = None, + device: torch.device | None = None, ) -> torch.Tensor: return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg) @@ -386,6 +386,6 @@ def pre_load_state_dict_transform( def all_gather_dtensor( self, tensor: DTensor, - parent_mesh: Optional[DeviceMesh], + parent_mesh: DeviceMesh | None, ) -> torch.Tensor: return _all_gather_dtensor(tensor, parent_mesh) diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index de003c5994684..81e25621e040a 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from functools import partial -from typing import Any, Optional +from typing import Any import torch from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard @@ -14,7 +14,7 @@ def input_reshard( module: torch.nn.Module, tp_device_mesh: DeviceMesh, - input_reshard_dim: Optional[int] = None, + input_reshard_dim: int | None = None, ) -> torch.nn.Module: """ Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation. @@ -42,7 +42,7 @@ def input_reshard( if input_reshard_dim is None: return module - cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None + cx: torch.autograd.graph.saved_tensors_hooks | None = None def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: tuple[Any, ...]) -> None: saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks( diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index 7cb26bf699650..9c1adbf2a672a 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import contextlib -from typing import cast, Optional +from typing import cast import torch import torch._prims_common as utils @@ -201,8 +201,8 @@ def _log_softmax_backward_handler( def _nll_loss_forward( x: Tensor, target: Tensor, - weight: Optional[Tensor], - local_weight: Optional[Tensor], + weight: Tensor | None, + local_weight: Tensor | None, reduction: int, ignore_index: int, input_shape: torch.Size, @@ -356,7 +356,7 @@ def _nll_loss_and_log_softmax_backward( grad_output: Tensor, x: Tensor, target: Tensor, - weight: Optional[Tensor], + weight: Tensor | None, reduction: int, ignore_index: int, total_weight: Tensor, diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 182a3fbcafebf..9eed832eabe86 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod from functools import partial -from typing import Any, Optional, Union +from typing import Any import torch import torch.nn as nn @@ -36,7 +36,7 @@ class ParallelStyle(ABC): flexibility for different kind of style implementations. """ - src_data_rank: Optional[int] = 0 + src_data_rank: int | None = 0 @abstractmethod def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ... @@ -82,8 +82,8 @@ class ColwiseParallel(ParallelStyle): def __init__( self, *, - input_layouts: Optional[Placement] = None, - output_layouts: Optional[Placement] = None, + input_layouts: Placement | None = None, + output_layouts: Placement | None = None, use_local_output: bool = True, ): super().__init__() @@ -212,8 +212,8 @@ class RowwiseParallel(ParallelStyle): def __init__( self, *, - input_layouts: Optional[Placement] = None, - output_layouts: Optional[Placement] = None, + input_layouts: Placement | None = None, + output_layouts: Placement | None = None, use_local_output: bool = True, ): super().__init__() @@ -473,14 +473,10 @@ class PrepareModuleInput(ParallelStyle): def __init__( self, *, - input_layouts: Optional[ - Union[Placement, tuple[Optional[Placement], ...]] - ] = None, - desired_input_layouts: Optional[ - Union[Placement, tuple[Optional[Placement], ...]] - ] = None, - input_kwarg_layouts: Optional[dict[str, Placement]] = None, - desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None, + input_layouts: Placement | tuple[Placement | None, ...] | None = None, + desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None, + input_kwarg_layouts: dict[str, Placement] | None = None, + desired_input_kwarg_layouts: dict[str, Placement] | None = None, use_local_output: bool = False, ): self.input_layouts = ( @@ -513,8 +509,8 @@ def _prepare_input_arg( self, input: Any, mesh: DeviceMesh, - input_layout: Optional[Placement], - desired_layout: Optional[Placement], + input_layout: Placement | None, + desired_layout: Placement | None, ): if input_layout is not None: if isinstance(input, DTensor): @@ -637,8 +633,8 @@ class PrepareModuleOutput(ParallelStyle): def __init__( self, *, - output_layouts: Union[Placement, tuple[Optional[Placement], ...]], - desired_output_layouts: Union[Placement, tuple[Placement, ...]], + output_layouts: Placement | tuple[Placement | None, ...], + desired_output_layouts: Placement | tuple[Placement, ...], use_local_output: bool = True, ): self.output_layouts = ( @@ -768,17 +764,13 @@ class PrepareModuleInputOutput(ParallelStyle): def __init__( self, *, - input_layouts: Optional[ - Union[Placement, tuple[Optional[Placement], ...]] - ] = None, - desired_input_layouts: Optional[ - Union[Placement, tuple[Optional[Placement], ...]] - ] = None, - input_kwarg_layouts: Optional[dict[str, Placement]] = None, - desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None, + input_layouts: Placement | tuple[Placement | None, ...] | None = None, + desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None, + input_kwarg_layouts: dict[str, Placement] | None = None, + desired_input_kwarg_layouts: dict[str, Placement] | None = None, use_local_input: bool = False, - output_layouts: Union[Placement, tuple[Optional[Placement], ...]], - desired_output_layouts: Union[Placement, tuple[Placement, ...]], + output_layouts: Placement | tuple[Placement | None, ...], + desired_output_layouts: Placement | tuple[Placement, ...], use_local_output: bool = True, ): self.prepare_module_input = PrepareModuleInput( diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 726abc5971376..a9f253c177ef2 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass, field -from typing import cast, Optional +from typing import cast import torch import torch._C @@ -129,7 +129,7 @@ def _local_shard_size_and_offset( curr_local_size: int, num_chunks: int, rank: int, - ) -> tuple[int, Optional[int]]: + ) -> tuple[int, int | None]: return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank) @staticmethod @@ -151,7 +151,7 @@ def _shard_tensor( tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - src_data_rank: Optional[int] = 0, + src_data_rank: int | None = 0, ) -> torch.Tensor: """ shard and scatter a tensor on a mesh dimension (use coordinate @@ -203,7 +203,7 @@ def _make_shard_tensor( tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - src_data_rank: Optional[int] = 0, + src_data_rank: int | None = 0, ) -> torch.Tensor: shard_placement = cls(dim) return shard_placement._shard_tensor(tensor, mesh, mesh_dim, src_data_rank) @@ -566,7 +566,7 @@ def _make_shard_tensor( tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - src_data_rank: Optional[int] = 0, + src_data_rank: int | None = 0, split_factor: int = 1, ) -> torch.Tensor: strided_shard_placement = cls(dim=dim, split_factor=split_factor) @@ -689,7 +689,7 @@ def _local_shard_size_and_offset( curr_local_size: int, num_chunks: int, rank: int, - ) -> tuple[int, Optional[int]]: + ) -> tuple[int, int | None]: # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [curr_local_size] @@ -742,7 +742,7 @@ def _make_replicate_tensor( tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - src_data_rank: Optional[int] = 0, + src_data_rank: int | None = 0, ) -> torch.Tensor: """ Replicate (broadcast) a torch.Tensor on a mesh dimension (use @@ -765,7 +765,7 @@ def _replicate_tensor( tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - src_data_rank: Optional[int] = 0, + src_data_rank: int | None = 0, ) -> torch.Tensor: return Replicate._make_replicate_tensor(tensor, mesh, mesh_dim, src_data_rank) @@ -863,7 +863,7 @@ class MaskPartial(Partial): mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) # required fields for computing the local offset and deriving the mask - offset_shape: Optional[torch.Size] = None + offset_shape: torch.Size | None = None offset_dim: int = 0 def __init__( diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 275814693354f..9422d05bf7e7d 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -44,7 +44,7 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str, def _cast_forward_inputs( - dtype: Optional[torch.dtype], + dtype: torch.dtype | None, *args: Any, **kwargs: Any, ) -> tuple[Any, Any]: @@ -257,7 +257,7 @@ def apply(x): def _to_kwargs( inputs: tuple[Any, ...], - kwargs: Optional[dict[str, Any]], + kwargs: dict[str, Any] | None, target_device: torch.device, use_side_stream_for_tensor_copies: bool, ) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: From d9d5e91b43f70eb8637af55db6856d49be391ffd Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Sat, 29 Nov 2025 03:09:54 +0000 Subject: [PATCH 061/338] [dynamo, 3.14] Ensure `typing.Union` is correctly traced in Dynamo. (#169084) Fixes `test_vector_norm_decom_unbacked_checks_cpu` from `test/test_linalg.py` on Python 3.14 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169084 Approved by: https://github.com/rtimpe, https://github.com/williamwen42 --- torch/_dynamo/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ec8f83c33d333..c6825737ec994 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1065,6 +1065,10 @@ def istype(obj: object, allowed_types: Any) -> bool: ) +if sys.version_info >= (3, 14): + _builtin_final_typing_classes += (typing.Union,) + + def is_typing(value: Any) -> bool: # _Final catches most of typing classes: # - Any From 90b27e7e8352cde97d32ddad24740ef819633f38 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 29 Nov 2025 00:44:10 -0800 Subject: [PATCH 062/338] [dynamo] Skip Dynamo wrapped tests on cpython tests (#169233) Running Dynamo-wrapped tests on Cpython tests does not make sense. It often leads to some unusual unexpected successes and failures. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169233 Approved by: https://github.com/guilhermeleobas --- torch/_dynamo/test_case.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 0706e55abd8fa..ad2637b3b124b 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -25,6 +25,7 @@ from torch._logging._internal import trace_log from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] IS_WINDOWS, + skipIfTorchDynamo, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO, TestCase as TorchTestCase, @@ -130,6 +131,7 @@ def tearDown(self) -> None: torch._dynamo.config.nested_graph_breaks = self.prev_nested_graph_breaks +@skipIfTorchDynamo("Not a suitable dynamo wrapped test") class CPythonTestCase(TestCase): """ Test class for CPython tests located in "test/dynamo/CPython/Py_version/*". From f7c0d03819ebed05c4038f095d66d1b8c54aca17 Mon Sep 17 00:00:00 2001 From: Firoz Syed Date: Sat, 29 Nov 2025 23:22:55 +0000 Subject: [PATCH 063/338] [Durin] Bump Kineto Submodule to latest (#169098) Summary: Bump kineto Submodule to latest commit Context : D87020156 [Kineto] Remove superfluous logging Differential Revision: D87904965 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169098 Approved by: https://github.com/wdvr, https://github.com/yaoyj11, https://github.com/cyyever --- third_party/kineto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/kineto b/third_party/kineto index 6fcbc53d33dd2..31f85df8fbd89 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 6fcbc53d33dd275c0aba1e5d7701d471b7f6eeb3 +Subproject commit 31f85df8fbd89c188f14ef10f1ec65379786b943 From 79d7b178225e5ed24d4e1db74e5abbff848f5fb7 Mon Sep 17 00:00:00 2001 From: hipudding Date: Sat, 29 Nov 2025 23:33:28 +0000 Subject: [PATCH 064/338] [openreg] Expand autocast test coverage for custom device (#169029) Add comprehensive test cases for autocast functionality including: - bfloat16 dtype support - Nested context management - Autograd integration and gradient flow - Mixed input dtypes and edge cases - Docstrings for all test methods Pull Request resolved: https://github.com/pytorch/pytorch/pull/169029 Approved by: https://github.com/fffrog, https://github.com/cyyever Co-authored-by: Jiawei Li --- .../torch_openreg/tests/test_autocast.py | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py index 6474a349ab430..25eb9cf3c570c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py @@ -6,6 +6,7 @@ class TestAutocast(TestCase): def test_autocast_with_unsupported_type(self): + """Test autocast with unsupported dtype (float32)""" with self.assertWarnsRegex( UserWarning, "In openreg autocast, but the target dtype is not supported. Disabling autocast.\n" @@ -15,6 +16,7 @@ def test_autocast_with_unsupported_type(self): _ = torch.ones(10) def test_autocast_operator_not_supported(self): + """Test that binary_cross_entropy is not supported in autocast""" with self.assertRaisesRegex( RuntimeError, "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.", @@ -25,6 +27,7 @@ def test_autocast_operator_not_supported(self): _ = torch.nn.functional.binary_cross_entropy(x, y) def test_autocast_low_precision(self): + """Test low precision operations (mm) in autocast""" with torch.amp.autocast(device_type="openreg", dtype=torch.float16): x = torch.randn(2, 3, device="openreg") y = torch.randn(3, 3, device="openreg") @@ -32,20 +35,162 @@ def test_autocast_low_precision(self): self.assertEqual(result.dtype, torch.float16) def test_autocast_fp32(self): + """Test fp32 operations (asin) in autocast""" with torch.amp.autocast(device_type="openreg"): x = torch.randn(2, device="openreg", dtype=torch.float16) result = torch.asin(x) self.assertEqual(result.dtype, torch.float32) def test_autocast_default_dtype(self): + """Test default autocast dtype""" openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg") self.assertEqual(openreg_fast_dtype, torch.half) def test_autocast_set_dtype(self): + """Test setting autocast dtype""" for dtype in [torch.float16, torch.bfloat16]: torch.set_autocast_dtype("openreg", dtype) self.assertEqual(torch.get_autocast_dtype("openreg"), dtype) + def test_autocast_bfloat16(self): + """Test autocast with bfloat16 dtype""" + with torch.amp.autocast(device_type="openreg", dtype=torch.bfloat16): + x = torch.randn(2, 3, device="openreg", dtype=torch.float32) + y = torch.randn(3, 3, device="openreg", dtype=torch.float32) + result = torch.mm(x, y) + self.assertEqual(result.dtype, torch.bfloat16) + + def test_autocast_low_precision_bfloat16(self): + """Test low precision operations with bfloat16""" + with torch.amp.autocast(device_type="openreg", dtype=torch.bfloat16): + x = torch.randn(2, 3, device="openreg") + y = torch.randn(3, 3, device="openreg") + result = torch.mm(x, y) + self.assertEqual(result.dtype, torch.bfloat16) + + def test_autocast_fp32_with_bfloat16(self): + """Test fp32 operations with bfloat16 autocast""" + with torch.amp.autocast(device_type="openreg", dtype=torch.bfloat16): + x = torch.randn(2, device="openreg", dtype=torch.bfloat16) + result = torch.asin(x) + self.assertEqual(result.dtype, torch.float32) + + def test_autocast_nested_context(self): + """Test nested autocast contexts""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, 3, device="openreg") + y = torch.randn(3, 3, device="openreg") + result1 = torch.mm(x, y) + self.assertEqual(result1.dtype, torch.float16) + + # Nested autocast context with bfloat16 + with torch.amp.autocast(device_type="openreg", dtype=torch.bfloat16): + result2 = torch.mm(x, y) + self.assertEqual(result2.dtype, torch.bfloat16) + + # After exiting nested context, should restore to float16 + result3 = torch.mm(x, y) + self.assertEqual(result3.dtype, torch.float16) + + def test_autocast_fallthrough_operation(self): + """Test fallthrough operations (operations not specially registered)""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, 3, device="openreg", dtype=torch.float32) + # add operation is not specially registered, should fallthrough + result = torch.add(x, x) + # fallthrough operations should preserve input type or use default behavior + self.assertEqual(result.dtype, torch.float32) + + def test_autocast_with_requires_grad(self): + """Test autocast interaction with requires_grad""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, 3, device="openreg", requires_grad=True) + y = torch.randn(3, 3, device="openreg", requires_grad=True) + result = torch.mm(x, y) + self.assertEqual(result.dtype, torch.float16) + self.assertTrue(result.requires_grad) + + # Test backward propagation + loss = result.sum() + loss.backward() + self.assertIsNotNone(x.grad) + self.assertIsNotNone(y.grad) + + def test_autocast_mixed_input_dtypes(self): + """Test combinations of different input dtypes""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, 3, device="openreg", dtype=torch.float32) + y = torch.randn(3, 3, device="openreg", dtype=torch.float16) + # mm operation should convert inputs to low precision + result = torch.mm(x, y) + self.assertEqual(result.dtype, torch.float16) + + def test_autocast_already_target_dtype(self): + """Test when inputs are already in target dtype""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, 3, device="openreg", dtype=torch.float16) + y = torch.randn(3, 3, device="openreg", dtype=torch.float16) + result = torch.mm(x, y) + self.assertEqual(result.dtype, torch.float16) + + def test_autocast_combination_operations(self): + """Test multiple operations combination under autocast""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, 3, device="openreg") + y = torch.randn(3, 3, device="openreg") + z = torch.randn(2, device="openreg") + + # Low precision operation + result1 = torch.mm(x, y) + self.assertEqual(result1.dtype, torch.float16) + + # fp32 operation + result2 = torch.asin(z) + self.assertEqual(result2.dtype, torch.float32) + + # Combined operations + result3 = torch.mm(result1, y) + self.assertEqual(result3.dtype, torch.float16) + + def test_autocast_disable(self): + """Test disabling autocast""" + with torch.amp.autocast( + device_type="openreg", dtype=torch.float16, enabled=False + ): + x = torch.randn(2, 3, device="openreg", dtype=torch.float32) + y = torch.randn(3, 3, device="openreg", dtype=torch.float32) + result = torch.mm(x, y) + # When autocast is disabled, should preserve original dtype + self.assertEqual(result.dtype, torch.float32) + + def test_autocast_cache_enabled(self): + """Test autocast caching""" + with torch.amp.autocast( + device_type="openreg", dtype=torch.float16, cache_enabled=True + ): + x = torch.randn(2, 3, device="openreg") + y = torch.randn(3, 3, device="openreg") + result1 = torch.mm(x, y) + result2 = torch.mm(x, y) + self.assertEqual(result1.dtype, torch.float16) + self.assertEqual(result2.dtype, torch.float16) + + def test_autocast_fp32_operation_with_float16_input(self): + """Test fp32 operations receiving float16 input""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, device="openreg", dtype=torch.float16) + result = torch.asin(x) + # asin should output float32 + self.assertEqual(result.dtype, torch.float32) + + def test_autocast_fp32_operation_with_float32_input(self): + """Test fp32 operations receiving float32 input""" + with torch.amp.autocast(device_type="openreg", dtype=torch.float16): + x = torch.randn(2, device="openreg", dtype=torch.float32) + result = torch.asin(x) + # asin should output float32 + self.assertEqual(result.dtype, torch.float32) + if __name__ == "__main__": run_tests() From 4bebc827c47d2f1f0fa1a417a5201a97aef3d985 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Sat, 29 Nov 2025 17:17:55 +0000 Subject: [PATCH 065/338] Add `torch.Tensor.__annotate__` to the `testing_ignore` list (#169013) Skip `torch.Tensor.__annotate__` when testing for `__torch_override__` Pull Request resolved: https://github.com/pytorch/pytorch/pull/169013 Approved by: https://github.com/rtimpe, https://github.com/williamwen42 --- torch/overrides.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/overrides.py b/torch/overrides.py index 22dfb67b825cc..e0597eafd8107 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -25,6 +25,7 @@ import collections import contextlib import functools +import sys import types import warnings from collections.abc import Callable, Iterable @@ -119,7 +120,7 @@ def get_ignored_functions() -> set[Callable]: False """ Tensor = torch.Tensor - return { + functions = { torch.typename, torch.is_tensor, torch.is_storage, @@ -384,6 +385,11 @@ def get_ignored_functions() -> set[Callable]: Tensor._use_count, } + if sys.version_info >= (3, 14): + functions.add(Tensor.__annotate__) + + return functions + @functools.cache def get_default_nowrap_functions() -> set[Callable]: From 84149583d483e9c973c9a0feda70e4f3964947b0 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 29 Nov 2025 09:48:49 -0800 Subject: [PATCH 066/338] [dynamo][dicts] Decentralize and Improve key hash implementation for Dict variable tracker (#169204) Fixes https://github.com/pytorch/pytorch/issues/167956 ## Summary This PR decentralizes and improves the hash implementation for dictionary keys in Dynamo's ConstDictVariable tracker. Instead of maintaining a centralized list of hashable types and custom equality logic in _HashableTracker, we now delegate hashability checks, hash computation, and equality comparison to individual VariableTracker subclasses. ## Motivation The previous implementation had several issues: 1. Centralized logic: All hashability checks and hash computations were centralized in dicts.py, making it difficult to add support for new hashable types 2. Maintainability: Adding a new hashable type required modifying multiple locations in _HashableTracker (underlying_value, _eq_impl, and the is_hashable function) 3. Scattered knowledge: Type-specific hashing logic was separated from the type's own implementation 4. Limited extensibility: No clear protocol for VariableTracker subclasses to declare themselves as hashable ## Changes New Protocol Methods Added three new methods to the VariableTracker base class: 1. is_python_hashable(): Returns whether the underlying Python object is hashable 2. get_python_hash(): Computes the hash value for the underlying Python object 3. is_python_equal(other): Checks Python-level equality between two VariableTrackers The base implementation raises unimplemented() with helpful error messages, and subclasses override these methods as appropriate. ## Simplified _HashableTracker The _HashableTracker class in ConstDictVariable is now much simpler: - Removed underlying_value property (centralized type handling) - Removed _eq_impl static method (centralized equality logic) - Simplified __hash__() to delegate to vt.get_python_hash() - Simplified __eq__() to delegate to vt.is_python_equal() ## Decentralized Implementations Implemented the new protocol methods across relevant VariableTracker subclasses: - ConstantVariable, TensorVariable, TupleVariable, ListVariable - FrozensetVariable, FrozenDataClassVariable - BuiltinVariable, UserFunctionVariable, SkipFunctionVariable - FunctoolsPartialVariable, WeakRefVariable - NumpyVariable, NNModuleVariable, MethodWrapperVariable - TorchInGraphFunctionVariable, TorchHigherOrderOperatorVariable - TypingVariable, UserDefinedObjectVariable, UserDefinedClassVariable - SymNodeVariable, EnumVariable ## Enhanced Test Coverage Added 14 new test cases covering various hashable types as dictionary keys: - range, tuples, enums, frozensets - Typing constructs (e.g., typing.Union) - NumPy dtypes, method wrappers - Torch builtin functions, frozen dataclasses - Custom objects with __hash__ - Negative test for unhashable types (lists) ## Improved Error Messages Updated error messages to be more informative when encountering unhashable types, showing both the Python type and the VariableTracker type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169204 Approved by: https://github.com/jansel ghstack dependencies: #169233 --- test/dynamo/test_dicts.py | 210 +++++++++++++++++- .../TestCustomOp.test_impl_device_cpu | 0 torch/_dynamo/graph_break_registry.json | 44 ++++ torch/_dynamo/utils.py | 18 ++ torch/_dynamo/variables/base.py | 51 +++++ torch/_dynamo/variables/builtin.py | 9 + torch/_dynamo/variables/constant.py | 25 +++ torch/_dynamo/variables/dicts.py | 206 ++++++----------- torch/_dynamo/variables/functions.py | 46 ++++ torch/_dynamo/variables/higher_order_ops.py | 9 + torch/_dynamo/variables/lists.py | 34 +++ torch/_dynamo/variables/misc.py | 37 +++ torch/_dynamo/variables/tensor.py | 28 +++ torch/_dynamo/variables/torch.py | 9 + torch/_dynamo/variables/user_defined.py | 55 +++-- 15 files changed, 627 insertions(+), 154 deletions(-) create mode 100644 test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index cdaeb2d91fbfb..4c233ea9458f3 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -19,6 +19,7 @@ import torch._functorch.config import torch.nn import torch.utils.checkpoint +from torch._dynamo.exc import Unsupported from torch._dynamo.testing import same from torch._dynamo.utils import dict_items from torch.testing._internal.common_utils import ( @@ -89,7 +90,7 @@ def forward(self, x): inp = torch.randn(4, 4) mod = Foo() - opt_f = torch.compile(mod) + opt_f = torch.compile(mod, backend="eager", fullgraph=True) self.assertEqual(mod(inp), opt_f(inp)) def test_dict_subclass_local_with_non_dict_method(self): @@ -518,7 +519,7 @@ def fn(d): args1 = {namedtuple: None, 3: torch.randn(3)} cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) self.assertEqual(fn(args1), opt_fn(args1)) self.assertEqual(cnts.frame_count, 1) # Test a failing namedtuple guard @@ -538,7 +539,7 @@ def fn(d, x): args1[3] = z cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) self.assertEqual(fn(args1, x), opt_fn(args1, x)) self.assertEqual(cnts.frame_count, 1) @@ -1062,8 +1063,6 @@ def fn(b: Any): a = {"one": torch.ones(1)} return a | b - from torch._dynamo.exc import Unsupported - for arg in args: with self.assertRaisesRegex(Unsupported, "Observed exception"): _ = fn(arg) @@ -1204,6 +1203,156 @@ def f(): opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) + def test_range_as_dict_key(self): + def fn(x): + d = {range(5): x * 2, range(10, 15): x * 3} + return d[range(0, 5, 1)] + d[range(10, 15)] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_tuple_as_dict_key(self): + def fn(x): + d = {(1, 2): x * 2, (3, 4, 5): x * 3} + return d[(1, 2)] + d[(3, 4, 5)] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_enum_as_dict_key(self): + class Color(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + def fn(x): + d = {Color.RED: x * 2, Color.GREEN: x * 3, Color.BLUE: x * 4} + return d[Color.RED] + d[Color.GREEN] + d[Color.BLUE] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_intenum_as_dict_key(self): + class Priority(enum.IntEnum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + def fn(x): + d = {Priority.LOW: x * 2, Priority.MEDIUM: x * 3, Priority.HIGH: x * 4} + return d[Priority.LOW] + d[Priority.MEDIUM] + d[Priority.HIGH] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_frozenset_as_dict_key(self): + def fn(x): + d = {frozenset([1, 2]): x * 2, frozenset([3, 4, 5]): x * 3} + return d[frozenset([1, 2])] + d[frozenset([3, 4, 5])] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_typing_union_as_dict_key(self): + from typing import Union + + def fn(x): + d = {Union[int, str]: x * 2, Union[float, bool]: x * 3} + return d[Union[int, str]] + d[Union[float, bool]] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_numpy_dtype_as_dict_key(self): + import numpy as np + + def fn(x): + d = {np.float32: x * 2, np.int64: x * 3, np.bool_: x * 4} + return d[np.float32] + d[np.int64] + d[np.bool_] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_method_wrapper_as_dict_key(self): + add_method = list.__add__ + mul_method = list.__mul__ + + def fn(x): + # Method wrappers are the type of bound methods on built-in types + d = {add_method: x * 2, mul_method: x * 3} + return d[add_method] + d[mul_method] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_torch_builtin_function_as_dict_key(self): + def fn(x, y): + # Using torch built-in functions as dictionary keys + d = {torch.add: x * 2, torch.mul: y * 3, torch.sub: x + y} + return d[torch.add] + d[torch.mul] + d[torch.sub] + + x = torch.randn(4) + y = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x, y), opt_fn(x, y)) + + def test_frozen_dataclass_as_dict_key(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class Point: + x: int + y: int + + def fn(tensor): + p1 = Point(1, 2) + p2 = Point(3, 4) + d = {p1: tensor * 2, p2: tensor * 3} + return d[Point(1, 2)] + d[Point(3, 4)] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_list_as_dict_key_raises_typeerror(self): + def fn(x): + d = {[1, 2, 3]: x * 2} + return d[[1, 2, 3]] + + x = torch.randn(4) + + # First check that eager execution raises TypeError + with self.assertRaises(TypeError): + fn(x) + + # Also check that compiled version raises TypeError + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex(Unsupported, "Observed exception"): + opt_fn(x) + + def test_get_default_nowrap_functions_as_dict_key(self): + def fn(x): + # Get the set of default nowrap functions + nowrap_funcs = torch.overrides.get_default_nowrap_functions() + # Use the set as a dict key and search for Tensor.grad.__get__ in it + d = {frozenset(nowrap_funcs): x * 2} + # Check if Tensor.grad.__get__ is in the set + if torch.Tensor.grad.__get__ in nowrap_funcs: + return d[frozenset(nowrap_funcs)] + x + return x + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + instantiate_parametrized_tests(DictTests) @@ -1738,7 +1887,9 @@ def fn(x): new_gn = partial(gn, x=1) key = Container(new_gn, 4) new_dict[key] = 5 - return x * new_dict[key] + # Make another key that should hash to the same value + key1 = Container(new_gn, 4) + return x * new_dict[key1] x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -1747,6 +1898,53 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) + def test_custom_object_as_dict_key(self): + """Test that custom objects with __hash__ as dict keys are properly handled. + + This test verifies that when using custom objects with overridden __hash__ + and __eq__ as dictionary keys, two instances with the same hash and equality + should be recognized as the same key. + """ + + class CustomKey: + def __init__(self, value, name): + self.value = value + self.name = name + + def fn(x): + d = {} + # Create first instance + key1 = CustomKey(42, "test") + d[key1] = x * 2 + + # Create second instance with same values - should hash to same value + key2 = CustomKey(42, "test") + d[key2] = x * 3 # This should overwrite the first value + + return d[key1] * d[key2] + + x = torch.randn(4) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertTrue(same(opt_fn(x), fn(x))) + + def test_user_defined_object(self): + class A: + def __init__(self): + self.x = {} + REF[self] = {} + + REF = {} + + def f(a, x): + REF[a]["foo"] = x + return x + 1 + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + + x = torch.randn(4) + self.assertTrue(same(f(A(), x), opt_f(A(), x))) + class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu b/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 5f967971005f6..9bfe593417699 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3667,5 +3667,49 @@ "Use custom operators instead of direct attribute/method access." ] } + ], + "GB0363": [ + { + "Gb_type": "User-defined object with overridden __hash__", + "Context": "hashing object of type={type(obj)} and variable tracker {vt}", + "Explanation": "Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", + "Hints": [ + "Dynamo does not support hashing user-defined objects with overridden __hash__", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0364": [ + { + "Gb_type": "Dynamo cannot determine whether the underlying object is hashable", + "Context": "is_python_hashable {self}", + "Explanation": "Dynamo does not know whether the underlying python object for {self} is hashable", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0365": [ + { + "Gb_type": "Dynamo cannot determine the hash of an object", + "Context": "get_python_hash {self}", + "Explanation": "Dynamo does not know the hash of the underlying python object for {self}", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0366": [ + { + "Gb_type": "Dynamo cannot determine the equality comparison of an object", + "Context": "is_python_equal {self}", + "Explanation": "Dynamo does not know the equality comparison of the underlying python object for {self}", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] } diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index c6825737ec994..5b1070aad5ad6 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4956,3 +4956,21 @@ def get_traced_code() -> Optional[list[CodeType]]: from torch._guards import TracingContext return TracingContext.get_traced_code() + + +def raise_on_overridden_hash(obj: Any, vt: VariableTracker) -> None: + from . import graph_break_hints + from .exc import unimplemented + + is_overridden = type(obj).__dict__.get("__hash__", False) + + if is_overridden: + unimplemented( + gb_type="User-defined object with overridden __hash__", + context=f"hashing object of type={type(obj)} and variable tracker {vt}", + explanation=f"Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", + hints=[ + "Dynamo does not support hashing user-defined objects with overridden __hash__", + *graph_break_hints.SUPPORTABLE, + ], + ) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 617f787e43d8a..0dcf75d344060 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -683,6 +683,57 @@ def build( else: return variables.LazyVariableTracker.create(value, source) + def is_python_hashable(self): + """ + Unlike the variable tracker's own __hash__, this method checks whether + the underlying Python object referenced by this variable tracker is hashable. + """ + unimplemented( + gb_type="Dynamo cannot determine whether the underlying object is hashable", + context=f"is_python_hashable {self}", + explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def get_python_hash(self): + """ + Unlike the variable tracker’s own __hash__, this method is used by + ConstDictVariableTracker to compute the hash of the underlying key object. + """ + unimplemented( + gb_type="Dynamo cannot determine the hash of an object", + context=f"get_python_hash {self}", + explanation=f"Dynamo does not know the hash of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def is_python_equal(self, other): + """ + NB - Deliberately not overriding the __eq__ method because that can + disable the __hash__ for the vt itself. + """ + unimplemented( + gb_type="Dynamo cannot determine the equality comparison of an object", + context=f"is_python_equal {self}", + explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + def __init__( self, *, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ae6678628634a..8fdaefea56f89 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -3243,6 +3243,15 @@ def call_contains( ) -> VariableTracker: return a.call_method(tx, "__contains__", [b], {}) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn + @contextlib.contextmanager def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 672fa1d804383..0b2eaaea80826 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -23,6 +23,7 @@ istype, np, raise_args_mismatch, + raise_on_overridden_hash, ) from .base import ValueMutationNew, VariableTracker @@ -340,6 +341,20 @@ def call_obj_hasattr( result = hasattr(self.value, name) return variables.ConstantVariable.create(result) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + # Could be an EnumVariable as well + from .tensor import SymNodeVariable + + if isinstance(other, SymNodeVariable): + return self.as_python_constant() == other.evaluate_expr() + return self.as_python_constant() == other.as_python_constant() + class EnumVariable(VariableTracker): """VariableTracker for enum.Enum and enum.IntEnum instances @@ -388,3 +403,13 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker member = getattr(self.value, name) source = self.source and AttrSource(self.source, name) return VariableTracker.build(tx, member, source=source) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 422cae7c4d3f1..9b98c91723063 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -20,14 +20,11 @@ import collections import functools -import inspect import operator import types -from collections.abc import Hashable as py_Hashable, Sequence +from collections.abc import Sequence from typing import Any, Optional, TYPE_CHECKING, Union -from torch._subclasses.fake_tensor import is_fake - from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import raise_observed_exception, unimplemented @@ -55,8 +52,8 @@ # [Adding a new supported class within the keys of ConstDictVariable] -# - Add its tracker type to is_hashable -# - (perhaps) Define how it is compared in _HashableTracker._eq_impl +# - Implement is_python_hashable() method in the VariableTracker subclass +# - Implement get_python_hash() and is_python_equal() methods for hashable types def was_instancecheck_override(obj: Any) -> bool: @@ -73,7 +70,7 @@ def raise_unhashable( raise_observed_exception( TypeError, tx, - args=[ConstantVariable(f"unhashable type: {type(arg.realize())}")], + msg=f"Unhashable type: {arg.python_type()!r} and variable tracker = {type(arg.realize())}", ) @@ -88,52 +85,7 @@ def is_hashable(x: VariableTracker) -> bool: and x.is_hashable() ): return True - - if isinstance(x, variables.TensorVariable): - # Tensors are hashable if they have an example_value (a fake tensor) - # Most VT's should have one. - # It'd be nice if at some point we could assert that they all have one - return x.as_proxy().node.meta.get("example_value") is not None - elif isinstance(x, variables.TupleVariable): - return all(is_hashable(e) for e in x.items) - elif isinstance(x, variables.FrozenDataClassVariable): - return all(is_hashable(e) for e in x.fields.values()) - elif ( - isinstance(x, variables.UserDefinedObjectVariable) - and not was_instancecheck_override(x.value) - and inspect.getattr_static(x.value, "__hash__") is int.__hash__ - and isinstance(x.value, int) - ): - return isinstance(x.value, py_Hashable) - elif isinstance(x, variables.FunctoolsPartialVariable): - return ( - is_hashable(x.func) - and all(is_hashable(arg) for arg in x.args) - and all(is_hashable(value) for value in x.keywords.values()) - ) - else: - return isinstance( - x, - ( - variables.BuiltinVariable, - variables.SymNodeVariable, - variables.ConstantVariable, - variables.EnumVariable, - variables.FrozensetVariable, - variables.UserDefinedClassVariable, - variables.UserFunctionVariable, - variables.SkipFunctionVariable, - variables.misc.NumpyVariable, - variables.NNModuleVariable, - variables.UnspecializedNNModuleVariable, - variables.MethodWrapperVariable, - variables.TorchInGraphFunctionVariable, - variables.TypingVariable, - variables.FunctoolsPartialVariable, - variables.WeakRefVariable, - variables.TorchHigherOrderOperatorVariable, - ), - ) + return x.is_python_hashable() class ConstDictVariable(VariableTracker): @@ -154,88 +106,47 @@ class _HashableTracker: def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) - # TODO Temporarily remove to figure out what keys are we breaking on - # and add proper support for them + + # If Dynamo does not know the hashability of the vt, it will raise unsupported here if not is_hashable(vt): raise_unhashable(vt) self.vt = vt - @property - def underlying_value(self) -> Any: + def __hash__(self) -> int: + """ + Computes the hash value for the wrapped VariableTracker. + + For unrealized LazyVariableTrackers, uses the hash of the original value + to avoid realizing the tracker and inserting unnecessary guards. + For all other cases, delegates to the VariableTracker's get_python_hash method. + + Returns: + The hash value of the underlying variable tracker + """ if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() and self.vt.is_hashable() ): - return self.vt.original_value() - if isinstance(self.vt, variables.TensorVariable): - x = self.vt.as_proxy().node.meta["example_value"] - elif isinstance(self.vt, variables.TupleVariable): - Hashable = ConstDictVariable._HashableTracker - x = tuple(Hashable(e).underlying_value for e in self.vt.items) - elif isinstance(self.vt, variables.NNModuleVariable): - return self.vt.value - elif isinstance(self.vt, variables.UnspecializedNNModuleVariable): - return self.vt.value - elif isinstance(self.vt, variables.UserFunctionVariable): - return self.vt.get_function() - elif isinstance(self.vt, variables.WeakRefVariable): - # Access the underlying value inside the referent_vt for the key representation - Hashable = ConstDictVariable._HashableTracker - return Hashable(self.vt.referent_vt).underlying_value - elif isinstance(self.vt, variables.FrozenDataClassVariable): - Hashable = ConstDictVariable._HashableTracker - fields_values = { - k: Hashable(v).underlying_value - for k, v in self.vt.fields.items() # type: ignore[attr-defined] - } - return variables.FrozenDataClassVariable.HashWrapper( - self.vt.python_type(), fields_values - ) - elif isinstance(self.vt, variables.UserDefinedObjectVariable): - # The re module in Python 3.13+ has a dictionary (_cache2) with - # an object as key (`class _ZeroSentinel(int): ...`): - # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual - return self.vt.value # type: ignore[attr-defined,union-attr] - elif isinstance(self.vt, variables.FunctoolsPartialVariable): - Hashable = ConstDictVariable._HashableTracker - items = (self.vt.func, *self.vt.args, *self.vt.keywords.values()) - x = tuple(Hashable(e).underlying_value for e in items) - return x - else: - x = self.vt.as_python_constant() - return x + return hash(self.vt.original_value()) + return self.vt.get_python_hash() - def __hash__(self) -> int: - return hash(self.underlying_value) - - @staticmethod - def _eq_impl(a: Any, b: Any) -> bool: - # TODO: Put this in utils and share it between variables/builtin.py and here - type_a, type_b = type(a), type(b) - if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)): - return False - - if isinstance(a, tuple): - Hashable = ConstDictVariable._HashableTracker - return len(a) == len(b) and all( - Hashable._eq_impl(u, v) for u, v in zip(a, b) - ) - elif is_fake(a): - return a is b - else: - return a == b + def __eq__(self, other) -> bool: + """ + Checks equality between two _HashableTracker instances. - def __eq__(self, other: object) -> bool: - Hashable = ConstDictVariable._HashableTracker - assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( - type(other) - ) - if isinstance(other, Hashable): - return Hashable._eq_impl(self.underlying_value, other.underlying_value) + Delegates to the VariableTracker's is_python_equal method to compare + the underlying variable trackers for Python-level equality. + + Args: + other: Another _HashableTracker instance to compare with - # constant - return Hashable._eq_impl(self.underlying_value, other) + Returns: + True if the underlying variable trackers are Python-equal, False otherwise + """ + if self.vt is other.vt: + return True + return self.vt.is_python_equal(other.vt) def __init__( self, @@ -324,7 +235,7 @@ def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( - is_hashable(vt) + vt.is_python_hashable() and Hashable(vt) in self.items and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) @@ -536,8 +447,6 @@ def call_method( Hashable = ConstDictVariable._HashableTracker - arg_hashable = args and is_hashable(args[0]) - if name == "__init__": temp_dict_vt = variables.BuiltinVariable(dict).call_dict( tx, *args, **kwargs @@ -606,6 +515,7 @@ def call_method( self.install_dict_keys_match_guard() return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -620,16 +530,21 @@ def call_method( tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) - elif name == "__delitem__" and arg_hashable and self.is_mutable(): - self.install_dict_keys_match_guard() - self.should_reconstruct_all = True - tx.output.side_effects.mutation(self) - self.items.__delitem__(Hashable(args[0])) - return ConstantVariable.create(None) + elif name == "__delitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) + if arg_hashable: + self.install_dict_keys_match_guard() + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.__delitem__(Hashable(args[0])) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) elif name == "get": if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -645,6 +560,7 @@ def call_method( if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -736,6 +652,7 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -751,6 +668,7 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -903,6 +821,12 @@ def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) + def is_python_hashable(self): + """ + Dictionaries are mutable and therefore not hashable in Python. + """ + return False + class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt @@ -1416,6 +1340,18 @@ def call_method( return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) + def is_python_hashable(self): + """ + Frozensets are immutable and hashable in Python. + """ + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class DictKeySetVariable(SetVariable): def debug_repr(self) -> str: @@ -1605,3 +1541,9 @@ def call_method( return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) return ConstantVariable.create(False) return super().call_method(tx, name, args, kwargs) + + def is_python_hashable(self): + """ + Dictionary item views are not hashable in Python. + """ + return False diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index deee9bcec42de..360c0fdd94488 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -807,6 +807,15 @@ def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: return collected return None + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn + class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): _nonvar_fields = { @@ -1963,6 +1972,15 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return fn_var_getattr(tx, self.value, self.source, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class WrappedSkipFunctionVariable(SkipFunctionVariable): def __init__( @@ -2349,6 +2367,34 @@ def guard_as_python_constant(self) -> Any: **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, ) + def is_python_hashable(self) -> bool: + return ( + self.func.is_python_hashable() + and all(arg.is_python_hashable() for arg in self.args) + and all(value.is_python_hashable() for value in self.keywords.values()) + ) + + def get_python_hash(self): + func_hash = self.func.get_python_hash() + args_hash = (arg.get_python_hash() for arg in self.args) + values_hash = (value.get_python_hash() for value in self.keywords.values()) + return hash((func_hash, *args_hash, *values_hash)) + + def is_python_equal(self, other): + return ( + self.func.is_python_equal(other.func) + and all( + arg_a.is_python_equal(arg_b) + for (arg_a, arg_b) in zip(self.args, other.args) + ) + and all( + value_a.is_python_equal(value_b) + for (value_a, value_b) in zip( + self.keywords.values(), other.keywords.values() + ) + ) + ) + class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index afb6522ac0e5c..8b178b3be1ac3 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1738,6 +1738,15 @@ def _call_function( def as_python_constant(self): return self.value + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 4f21e35479fb8..a97c284f9516c 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -620,6 +620,25 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return self.items[fields.index(name)] return super().var_getattr(tx, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + l = self.range_length() + start = self.start() + step = self.step() + return hash((l, start, step)) + + def is_python_equal(self, other): + if not isinstance(other, variables.RangeVariable): + return False + + return ( + self.start() == other.start() + and self.step() == other.step() + and self.stop() == other.stop() + ) + class CommonListMethodsVariable(BaseListVariable): """ @@ -981,6 +1000,9 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) + def is_python_hashable(self): + return False + class DequeVariable(CommonListMethodsVariable): def __init__( @@ -1170,6 +1192,18 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) + def is_python_hashable(self): + return all(item.is_python_hashable() for item in self.items) + + def get_python_hash(self): + items = tuple(x.get_python_hash() for x in self.items) + return hash(items) + + def is_python_equal(self, other): + return isinstance(other, variables.TupleVariable) and all( + a.is_python_equal(b) for (a, b) in zip(self.items, other.items) + ) + class SizeVariable(TupleVariable): """torch.Size(...)""" diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 8d074f913dbf5..5bd8ad5d075e6 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1306,6 +1306,15 @@ def is_python_constant(self): def as_python_constant(self): return self.method_wrapper + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class GetSetDescriptorVariable(VariableTracker): def __init__(self, desc, **kwargs) -> None: @@ -1440,6 +1449,15 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # codegen.append_output(codegen.create_load_const(self.value)) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + @functools.lru_cache(maxsize=1) def get_np_to_tnp_map(): @@ -1618,6 +1636,15 @@ def as_proxy(self): return super().as_proxy() + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls class NullVariable(VariableTracker): @@ -2097,3 +2124,13 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(self.referent_vt) codegen(self.callback_vt) codegen.extend_output(create_call_function(2, False)) + + def is_python_hashable(self): + return self.referent_vt.is_python_hashable() + + def get_python_hash(self): + # weakref relies on the referent's hash + return self.referent_vt.get_python_hash() + + def is_python_equal(self, other): + return self.referent_vt.is_python_equal(other.referent_vt) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 0787ef7c49b57..548e69ef0262d 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1428,6 +1428,20 @@ def set_name_hint(self, name: str): self.proxy.node._rename(name) self._is_name_set = True + def is_python_hashable(self): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return self.as_proxy().node.meta["example_value"] is not None + + def get_python_hash(self): + return hash(self.as_proxy().node.meta["example_value"]) + + def is_python_equal(self, other): + a = self.as_proxy().node.meta["example_value"] + b = other.as_proxy().node.meta["example_value"] + return a is b + class SymNodeVariable(VariableTracker): """ @@ -1516,6 +1530,20 @@ def call_method( ), ) + def is_python_hashable(self): + return True + + def get_python_hash(self): + # Essentially convert the SymNode to a constant variable whenever its + # searched for a dict key. + return hash(self.evaluate_expr()) + + def is_python_equal(self, other): + if isinstance(other, SymNodeVariable): + return self.evaluate_expr() == other.evaluate_expr() + # could be constant variable as well + return self.evaluate_expr() == other.as_python_constant() + class NumpyNdarrayVariable(TensorVariable): """ diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 76da71f6fb323..78d87a09713ab 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -2075,6 +2075,15 @@ def torch_function_override_enabled(self, tx, args, kwargs): ) ) and can_dispatch_torch_function(tx, args, kwargs) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class DispatchKeySetVariable(BaseTorchVariable): """represents torch.DispatchKeySet""" diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index e87af5b87a75a..9de51061cbe31 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -89,6 +89,7 @@ object_has_getattribute, proxy_args_kwargs, raise_args_mismatch, + raise_on_overridden_hash, set_methods, tensortype_to_dtype, tuple_methods, @@ -927,6 +928,18 @@ def const_getattr(self, tx: "InstructionTranslator", name): return self.value.__name__ return super().const_getattr(tx, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return ( + isinstance(other, variables.UserDefinedClassVariable) + and self.value is other.value + ) + class UserDefinedExceptionClassVariable(UserDefinedClassVariable): @property @@ -1743,26 +1756,20 @@ def call_obj_hasattr( handle_observed_exception(tx) return variables.ConstantVariable.create(False) + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True -class FrozenDataClassVariable(UserDefinedObjectVariable): - class HashWrapper: - """This class is hashed if a dataclass is used as a key in a dict. - It's necessary to avoid side effects from calling the __init__ of the dataclass class when hashing""" + def get_python_hash(self): + # default hash + return hash(self.value) - def __init__(self, c, fields): - self.cls = c - self.fields = tuple(fields.items()) + def is_python_equal(self, other): + # id check + return self.value is other.value - def __eq__(self, other): - return ( - type(self) is type(other) - and self.cls == other.cls - and self.fields == other.fields - ) - - def __hash__(self): - return hash((self.cls, self.fields)) +class FrozenDataClassVariable(UserDefinedObjectVariable): @staticmethod def create(tx, value, source): from dataclasses import fields @@ -1860,6 +1867,22 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value_type.__name__})" + def is_python_hashable(self): + # TODO - Check corner cases like eq=False, hash=False etc + return True + + def get_python_hash(self): + return hash(tuple(arg.get_python_hash() for arg in self.fields.values())) + + def is_python_equal(self, other): + is_class_same = self.python_type() is other.python_type() + is_field_name_same = self.fields.keys() == other.fields.keys() + is_field_value_same = all( + value_a.is_python_equal(value_b) + for value_a, value_b in zip(self.fields.values(), other.fields.values()) + ) + return is_class_same and is_field_name_same and is_field_value_same + class SourcelessGraphModuleVariable(UserDefinedObjectVariable): def __init__( From 076e7b19fa1d481ad778d06d2b49ba57d3ce8c88 Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Fri, 28 Nov 2025 18:31:56 +0000 Subject: [PATCH 067/338] [3.14] Skip broken numpy test (#169030) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169030 Approved by: https://github.com/guilhermeleobas, https://github.com/williamwen42 --- test/test_numpy_interop.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 6ed34f2559a18..bc4742e88841e 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -4,6 +4,7 @@ import sys from itertools import product +from unittest import skipIf import numpy as np @@ -32,6 +33,11 @@ def test_numpy_non_writeable(self, device): self.assertWarns(UserWarning, lambda: torch.from_numpy(arr)) @onlyCPU + @skipIf( + sys.version_info[:2] == (3, 14) + and np.lib.NumpyVersion(np.__version__) < "2.4.0", + "Broken in older numpy versions, see https://github.com/numpy/numpy/issues/30265", + ) def test_numpy_unresizable(self, device) -> None: x = np.zeros((2, 2)) y = torch.from_numpy(x) # noqa: F841 From 45d310ad84854dff730c0b12e577d7998d978686 Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Fri, 28 Nov 2025 17:28:02 +0000 Subject: [PATCH 068/338] [3.14] Fix dynamo error on np.broadcast_shapes (#168888) Not strictly related to 3.14, but this is exposed by upgrading scipy versions. np.broadcast_shapes seems to have been untested previously. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168888 Approved by: https://github.com/guilhermeleobas, https://github.com/williamwen42 --- test/torch_np/test_function_base.py | 6 ++++++ torch/_numpy/_funcs_impl.py | 6 ++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/torch_np/test_function_base.py b/test/torch_np/test_function_base.py index 2514856761802..bd5d0ae39e4b8 100644 --- a/test/torch_np/test_function_base.py +++ b/test/torch_np/test_function_base.py @@ -34,5 +34,11 @@ def test_basic(self): np.append([[1, 2, 3], [4, 5, 6]], [7, 8, 9], axis=0) +class TestMisc(TestCase): + def test_broadcast_shapes(self): + result = np.broadcast_shapes((1, 2), (2, 2)) + assert_equal(result, (2, 2)) + + if __name__ == "__main__": run_tests() diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py index f57e7fb001fb0..3417a401acb05 100644 --- a/torch/_numpy/_funcs_impl.py +++ b/torch/_numpy/_funcs_impl.py @@ -714,8 +714,10 @@ def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False): return torch.broadcast_to(array, size=shape) -# This is a function from tuples to tuples, so we just reuse it -from torch import broadcast_shapes +# This is a function from tuples to tuples, so we just reuse it. However, +# dynamo expects its __module__ to be torch._numpy +def broadcast_shapes(*args): + return torch.broadcast_shapes(*args) def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False): From d5038950bacfe36bbf24a47a455fe76901deb8e8 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 1 Dec 2025 00:42:45 +0000 Subject: [PATCH 069/338] Avoid std::tie and returning value constructions in qconv_unpack.cpp (#169207) This PR avoids returning value construction in `qconv_unpack.cpp`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169207 Approved by: https://github.com/Skylion007 --- .../ATen/native/quantized/qconv_unpack.cpp | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/native/quantized/qconv_unpack.cpp b/aten/src/ATen/native/quantized/qconv_unpack.cpp index 4c2352a396177..dcbfa7fdcf3f1 100644 --- a/aten/src/ATen/native/quantized/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/qconv_unpack.cpp @@ -82,32 +82,28 @@ class QConv1dUnpackWeightsInt8 final { static std::tuple> run( const c10::intrusive_ptr>& packed_weight) { auto& ctx = at::globalContext(); - at::Tensor weight; - std::optional bias; #ifdef USE_FBGEMM if (ctx.qEngine() == at::QEngine::FBGEMM || ctx.qEngine() == at::QEngine::X86) { - std::tie(weight, bias) = packed_weight->unpack(); - weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return std::tuple>(weight, bias); + auto result = packed_weight->unpack(); + std::get<0>(result).squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return result; } #endif #ifdef USE_PYTORCH_QNNPACK if (ctx.qEngine() == at::QEngine::QNNPACK) { - std::tie(weight, bias) = packed_weight->unpack(); - at::Tensor new_weight = weight.clone(); - new_weight = new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return std::tuple>(new_weight, bias); + auto result = packed_weight->unpack(); + std::get<0>(result).squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return result; } #endif #if AT_MKLDNN_ENABLED() if (ctx.qEngine() == at::QEngine::ONEDNN) { - std::tie(weight, bias) = packed_weight->unpack(); - at::Tensor new_weight = weight.clone(); - new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return std::tuple>(new_weight, bias); + auto result = packed_weight->unpack(); + std::get<0>(result).squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return result; } #endif From 42e9005cda22da3f1c559c3649218cebd671027c Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 1 Dec 2025 02:33:27 +0000 Subject: [PATCH 070/338] Remove unused TVM code in dynamo (#169247) The returned `log_file` is always opened, it's unclear why the branch was created to check the file existence. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169247 Approved by: https://github.com/yewentao256, https://github.com/jansel --- torch/_dynamo/backends/tvm.py | 37 +++++++---------------------------- 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index 92258d55d48c6..02dde50de0fe0 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -82,37 +82,14 @@ def tvm( # pyrefly: ignore [import-error] from tvm import auto_scheduler - log_file = tempfile.NamedTemporaryFile() - - # pyrefly: ignore [bad-argument-type] - if not os.path.exists(log_file): - tasks, task_weights = auto_scheduler.extract_tasks( - mod["main"], params, target - ) - if len(tasks) != 0: - tuner = auto_scheduler.TaskScheduler(tasks, task_weights) - # pyrefly: ignore [bad-argument-type] - if not os.path.exists(log_file): - assert trials > 0 - tune_option = auto_scheduler.TuningOptions( - num_measure_trials=trials, - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], - early_stopping=2000, - ) - try: - tuner.tune(tune_option) - except Exception: - # pyrefly: ignore [bad-argument-type] - if os.path.exists(log_file): - # pyrefly: ignore [bad-argument-type] - os.unlink(log_file) - raise - - with auto_scheduler.ApplyHistoryBest(log_file): - with tvm.transform.PassContext( + with ( + tempfile.NamedTemporaryFile() as log_file, + auto_scheduler.ApplyHistoryBest(log_file), + tvm.transform.PassContext( opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True} - ): - lib = relay.build(mod, target=target, params=params) + ), + ): + lib = relay.build(mod, target=target, params=params) elif scheduler == "meta_schedule": # pyrefly: ignore [import-error] from tvm import meta_schedule as ms From 94ca8d5f1e81fea3ae488650a0fb6795049a9f87 Mon Sep 17 00:00:00 2001 From: Tomasz Bohutyn Date: Mon, 1 Dec 2025 02:44:42 +0000 Subject: [PATCH 071/338] [xpu][feature] enable Sycl CPP extension on Windows (#162579) PR for enabling #153265 on Windows Pull Request resolved: https://github.com/pytorch/pytorch/pull/162579 Approved by: https://github.com/dvrogozh, https://github.com/EikanWang, https://github.com/guangyey, https://github.com/albanD --- test/test_cpp_extensions_jit.py | 8 ++- torch/utils/cpp_extension.py | 106 +++++++++++++++++++++++++++----- 2 files changed, 95 insertions(+), 19 deletions(-) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index bacff3c396569..541aef8499b6b 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -141,7 +141,6 @@ def _test_jit_xpu_extension(self, extra_sycl_cflags): sources=[sycl_file], extra_sycl_cflags=extra_sycl_cflags, verbose=True, - keep_intermediates=True, build_directory=temp_dir, ) @@ -155,7 +154,12 @@ def _test_jit_xpu_extension(self, extra_sycl_cflags): # 2 * sigmoid(0) = 2 * 0.5 = 1 self.assertEqual(z, torch.ones_like(z)) finally: - shutil.rmtree(temp_dir) + if IS_WINDOWS: + # rmtree returns permission error: [WinError 5] Access is denied + # on Windows, this is a workaround + subprocess.run(["rd", "/s", "/q", temp_dir], stdout=subprocess.PIPE) + else: + shutil.rmtree(temp_dir) @unittest.skipIf(not (TEST_XPU), "XPU not found") def test_jit_xpu_extension(self): diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index dd0e42a4ae0cd..14ddcbf732b91 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -576,12 +576,37 @@ def _append_sycl_std_if_no_std_present(cflags) -> None: def _wrap_sycl_host_flags(cflags): + host_cflags = [] host_cxx = get_cxx_compiler() - host_cflags = [ - f'-fsycl-host-compiler={host_cxx}', - shlex.quote(f'-fsycl-host-compiler-options={cflags}'), - ] - return host_cflags + if IS_WINDOWS: + for flag in cflags: + if flag.startswith("-I"): + flag = flag.replace("\\", "\\\\").replace("-I", "/I") + else: + flag = flag.replace("-D", "/D") + flag = flag.replace('"', '\\"') + host_cflags.append(flag) + joined_host_cflags = ' '.join(host_cflags) + + external_include = _join_sycl_home("include").replace("\\", "\\\\") + + # Some versions of DPC++ compiler pass paths to SYCL headers as user include paths (`-I`) rather + # than system paths (`-isystem`). This makes host compiler to report warnings encountered in the + # SYCL headers, such as deprecated warnings, even if warmed API is not actually used in the program. + # We expect that this issue will be addressed in the later version of DPC++ compiler. To workaround the + # issue now we wrap paths to SYCL headers in `/external:I`. Warning free compilation is especially important + # for Windows build as `/sdl` compilation flag assumes that and we will fail compilation otherwise. + wrapped_host_cflags = [ + f"-fsycl-host-compiler={host_cxx}", + f'-fsycl-host-compiler-options="\\"/external:I{external_include}\\" /external:W0 {joined_host_cflags}"', + ] + else: + joined_host_cflags = ' '.join(cflags) + wrapped_host_cflags = [ + f"-fsycl-host-compiler={host_cxx}", + shlex.quote(f"-fsycl-host-compiler-options={joined_host_cflags}"), + ] + return wrapped_host_cflags class BuildExtension(build_ext): @@ -807,6 +832,7 @@ def unix_wrap_ninja_compile(sources, extra_cc_cflags = self.compiler.compiler_so[1:] with_cuda = any(map(_is_cuda_file, sources)) with_sycl = any(map(_is_sycl_file, sources)) + assert not (with_sycl and with_cuda) # extra_postargs can be either: # - a dict mapping cxx/nvcc/sycl to extra flags @@ -862,7 +888,6 @@ def unix_wrap_ninja_compile(sources, host_cflags = [item.replace('"', '\\"') for item in host_cflags] else: host_cflags = [item.replace('"', '\\\\"') for item in host_cflags] - host_cflags = ' '.join(host_cflags) # Note the order: shlex.quote sycl_flags first, _wrap_sycl_host_flags # second. Reason is that sycl host flags are quoted, space containing # strings passed to SYCL compiler. @@ -1015,6 +1040,8 @@ def win_wrap_ninja_compile(sources, else: common_cflags.extend(COMMON_MSVC_FLAGS) with_cuda = any(map(_is_cuda_file, sources)) + with_sycl = any(map(_is_sycl_file, sources)) + assert not (with_sycl and with_cuda) # extra_postargs can be either: # - a dict mapping cxx/nvcc to extra flags @@ -1058,6 +1085,30 @@ def win_wrap_ninja_compile(sources, else: cuda_dlink_post_cflags = None + sycl_cflags = None + sycl_post_cflags = None + sycl_dlink_post_cflags = None + if with_sycl: + sycl_cflags = common_cflags + pp_opts + _COMMON_SYCL_FLAGS + if isinstance(extra_postargs, dict): + sycl_post_cflags = extra_postargs['sycl'] + else: + sycl_post_cflags = list(extra_postargs) + _append_sycl_targets_if_missing(sycl_post_cflags) + append_std17_if_no_std_present(sycl_cflags) + _append_sycl_std_if_no_std_present(sycl_cflags) + host_cflags = common_cflags + pp_opts + post_cflags + append_std17_if_no_std_present(host_cflags) + + sycl_cflags = _nt_quote_args(sycl_cflags) + host_cflags = _nt_quote_args(host_cflags) + + sycl_cflags += _wrap_sycl_host_flags(host_cflags) + sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS.copy() + sycl_dlink_post_cflags += _get_sycl_device_flags(sycl_post_cflags) + sycl_post_cflags = _nt_quote_args(sycl_post_cflags) + + _write_ninja_file_and_compile_objects( sources=sources, objects=objects, @@ -1066,13 +1117,13 @@ def win_wrap_ninja_compile(sources, cuda_cflags=cuda_cflags, cuda_post_cflags=cuda_post_cflags, cuda_dlink_post_cflags=cuda_dlink_post_cflags, - sycl_cflags=None, - sycl_post_cflags=None, - sycl_dlink_post_cflags=None, + sycl_cflags=sycl_cflags, + sycl_post_cflags=sycl_post_cflags, + sycl_dlink_post_cflags=sycl_dlink_post_cflags, build_directory=output_dir, verbose=True, with_cuda=with_cuda, - with_sycl=False) + with_sycl=with_sycl) # Return *all* object filenames, not just the ones we just built. return objects @@ -1492,6 +1543,7 @@ def SyclExtension(name, sources, *args, **kwargs): libraries.append("c10_xpu") libraries.append("torch") libraries.append("torch_cpu") + libraries.append("sycl") if not kwargs.get('py_limited_api', False): # torch_python uses more than the python limited api libraries.append("torch_python") @@ -2107,6 +2159,7 @@ def _jit_compile(name, with_cudnn = any('cudnn' in f for f in extra_ldflags or []) if with_sycl is None: with_sycl = any(map(_is_sycl_file, sources)) + assert not (with_sycl and with_cuda) old_version = JIT_EXTENSION_VERSIONER.get_version(name) version = JIT_EXTENSION_VERSIONER.bump_version_if_changed( name, @@ -2211,6 +2264,7 @@ def _write_ninja_file_and_compile_objects( with_cuda = any(map(_is_cuda_file, sources)) if with_sycl is None: with_sycl = any(map(_is_sycl_file, sources)) + assert not (with_sycl and with_cuda) build_file_path = os.path.join(build_directory, 'build.ninja') if verbose: logger.debug('Emitting ninja build file %s...', build_file_path) @@ -2270,9 +2324,11 @@ def _write_ninja_file_and_build_library( with_cuda = any(map(_is_cuda_file, sources)) if with_sycl is None: with_sycl = any(map(_is_sycl_file, sources)) + assert not (with_sycl and with_cuda) extra_ldflags = _prepare_ldflags( extra_ldflags or [], with_cuda, + with_sycl, verbose, is_standalone) build_file_path = os.path.join(build_directory, 'build.ninja') @@ -2325,7 +2381,7 @@ def verify_ninja_availability() -> None: raise RuntimeError("Ninja is required to load C++ extensions (pip install ninja to get it)") -def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): +def _prepare_ldflags(extra_ldflags, with_cuda, with_sycl, verbose, is_standalone): if IS_WINDOWS: python_lib_path = os.path.join(sys.base_exec_prefix, 'libs') @@ -2385,6 +2441,12 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): else: extra_ldflags.append(f'-L{_join_rocm_home("lib")}') extra_ldflags.append('-lamdhip64') + if with_sycl: + if IS_WINDOWS: + extra_ldflags.append('c10_xpu.lib') + extra_ldflags.append('torch_xpu.lib') + extra_ldflags.append(f'/LIBPATH:{_join_sycl_home("lib")}') + extra_ldflags.append('sycl.lib') return extra_ldflags @@ -2759,7 +2821,7 @@ def _write_ninja_file_to_build_library(path, icpx_version = _get_icpx_version() if int(icpx_version) < 20250200: host_cflags = [item.replace('\\"', '\\\\"') for item in host_cflags] - host_cflags = ' '.join(host_cflags) + sycl_cflags += _wrap_sycl_host_flags(host_cflags) sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS.copy() sycl_dlink_post_cflags += _get_sycl_device_flags(sycl_cflags) @@ -2969,11 +3031,21 @@ def sanitize_flags(flags): cuda_devlink_rule, cuda_devlink = [], [] if sycl_dlink_post_cflags: - sycl_devlink_out = os.path.join(os.path.dirname(objects[0]), 'sycl_dlink.o') - sycl_devlink_rule = ['rule sycl_devlink'] - sycl_devlink_rule.append(' command = $sycl $in -o $out $sycl_dlink_post_cflags') - sycl_devlink = [f'build {sycl_devlink_out}: sycl_devlink {" ".join(objects)}'] - objects += [sycl_devlink_out] + sycl_devlink_out = os.path.join(os.path.dirname(objects[0]), "sycl_dlink.o") + if IS_WINDOWS: + sycl_devlink_objects = [obj.replace(":", "$:") for obj in objects] + objects += [sycl_devlink_out] + sycl_devlink_out = sycl_devlink_out.replace(":", "$:") + else: + sycl_devlink_objects = list(objects) + objects += [sycl_devlink_out] + sycl_devlink_rule = ["rule sycl_devlink"] + sycl_devlink_rule.append( + " command = $sycl $in -o $out $sycl_dlink_post_cflags" + ) + sycl_devlink = [ + f"build {sycl_devlink_out}: sycl_devlink {' '.join(sycl_devlink_objects)}" + ] else: sycl_devlink_rule, sycl_devlink = [], [] From c8210e7d94bad5ae21ac389fa4ba8a463c76c4d0 Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Mon, 1 Dec 2025 06:09:15 +0000 Subject: [PATCH 072/338] [Accelerator] Add Accelerator Capabilities API (#165631) # Motivation There are several issues related to the data type and precision that an accelerator supports (see #165038 and #143112). Sometimes, we have to check for these capabilities in the document, and then hard-code. This PR proposes a new unified API for users to check their accelerator capabilities. # Changes This PR creates a new data structure `DeviceCapability` containing the capabilities that an accelerator commonly has: - Supporting DataType (set to be supported as default): - `fp16`, `int32`, `complex` ... etc - Other capabilities (need to be discussed) To access the structure, this PR defines a new Python API in the Accelerator module -- `get_device_capability`. It takes `device` as an input and returns a dictionary containing the capabilities (now we have `supported_dtypes` as the key). # Usage ```python >>> import torch >>> import torch_openreg >>> torch.accelerator.get_device_capability('openreg:0') {'supported_dtypes': [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, torch.complex32, torch.complex64, torch.complex128, torch.bool, torch.qint8, torch.quint8, torch.qint32, torch.bfloat16, torch.quint4x2, torch.quint2x4, torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16, torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e5m2fnuz, torch.float8_e4m3fnuz, torch.uint16, torch.uint32, torch.uint64, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.int1, torch.int2, torch.int3, torch.int4, torch.int5, torch.int6, torch.int7, torch.float8_e8m0fnu, torch.float4_e2m1fn_x2]} ``` # TODO - So far, precision is the only capability to track, based on my knowledge. But we can find more capabilities in common, and the API should be designed for good extension. - It will support other in-tree accelerators, such as **cuda** and **mps**. - Clarify whether the capabilities are software or hardware supported. (By @guangyey ) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165631 Approved by: https://github.com/fffrog, https://github.com/guangyey, https://github.com/albanD Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> Co-authored-by: Jiawei Li --- aten/src/ATen/DeviceAccelerator.cpp | 6 ++ aten/src/ATen/DeviceAccelerator.h | 5 ++ c10/core/DeviceCapability.h | 74 +++++++++++++++++++ c10/core/impl/DeviceGuardImplInterface.h | 26 +++++++ c10/core/impl/VirtualGuardImpl.h | 4 + .../torch_openreg/csrc/runtime/OpenRegGuard.h | 9 +++ .../torch_openreg/tests/test_device.py | 9 ++- torch/_C/__init__.pyi.in | 1 + torch/accelerator/__init__.py | 26 ++++++- torch/csrc/DeviceAccelerator.cpp | 19 +++++ 10 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 c10/core/DeviceCapability.h diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index aa9d6e6b1ce9b..efab9ec9c5927 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -130,6 +130,12 @@ c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) { impl.uncheckedSetDevice({device_type, device_index}); return impl.getDevice().index(); } + +c10::DeviceCapability getDeviceCapability(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + c10::impl::VirtualGuardImpl impl(device_type); + return impl.getDeviceCapability({device_type, device_index}); +} // NOLINTEND(bugprone-unchecked-optional-access) } // namespace at::accelerator diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index 2cc4cff7cd1f2..d24b42ca459e7 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -73,6 +74,10 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +// Get the device capability of the given device index. +TORCH_API c10::DeviceCapability getDeviceCapability( + c10::DeviceIndex device_index); + TORCH_API inline void emptyCache() { const auto device_type = getAccelerator(true).value(); at::getDeviceAllocator(device_type)->emptyCache(); diff --git a/c10/core/DeviceCapability.h b/c10/core/DeviceCapability.h new file mode 100644 index 0000000000000..e24f12614978a --- /dev/null +++ b/c10/core/DeviceCapability.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +constexpr size_t NUMBER_OF_DEVICE_CAPABILITIES = NumScalarTypes; + +// Generate bitfields for each scalar type +#define DEFINE_SCALAR_TYPE(_1, n) unsigned int has_##n : 1; + +// Generate enum indices for each scalar type +#define DEFINE_SCALAR_ENUM(_1, name) kIndex_##name, + +enum ScalarTypeIndex { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_ENUM) +}; + +/** + * @brief DeviceCapability represents the the common capabilities that all + * devices should support. + * + * This struct provides a compact way to represent the common capabilities that + * all devices should support. Includes the following capabilities: + * - Supported data types + * + * Purpose + * - Enable device-specific optimizations based on supported capabilities + * + * Contract + * + * Supported data types: + * - Each bitfield represents support for one device capability + * - Bit value 1 means the capability is supported, 0 means not supported + * - The struct is initialized with all capabilities enabled by default + * + * @note Adding New Capabilities + * + * 1. Define the new capability in the `DeviceCapability` struct + * 2. Update the support of the new capability in each accelerator + * implementation + * 3. Add the new capability to the returned PyObject Dictionary + */ +struct C10_API DeviceCapability { + union { + struct { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE) + }; + uint64_t capability_bits; // Allow direct bit manipulation + }; + + // Default constructor with all capabilities enabled. + DeviceCapability() + : capability_bits((1ULL << NUMBER_OF_DEVICE_CAPABILITIES) - 1) {} + + // Iterate supported ScalarTypes without allocating a vector + template + void forEachSupportedScalarType(F&& visitor) const { +#define VISIT_SCALAR_TYPE(_1, n) \ + if (has_##n) { \ + visitor(ScalarType::n); \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(VISIT_SCALAR_TYPE) + +#undef VISIT_SCALAR_TYPE + } +}; + +#undef DEFINE_SCALAR_ENUM +#undef DEFINE_SCALAR_TYPE +} // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index f9f67497c6315..00096584b9229 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -191,6 +192,15 @@ struct C10_API DeviceGuardImplInterface { */ virtual DeviceIndex deviceCount() const noexcept = 0; + /** + * Get the following capabilities of the current device: + * (1) Data type support + * Returns DeviceCapability object. + */ + virtual DeviceCapability getDeviceCapability(Device /*unused*/) const { + TORCH_CHECK(false, "Backend doesn't support getting device capabilities."); + } + /** * Return true if all the work previously enqueued on the stream for * asynchronous execution has completed running on the device. @@ -291,6 +301,22 @@ struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface { return 1; } + DeviceCapability getDeviceCapability(Device /*unused*/) const override { + DeviceCapability cap; + if constexpr (D == DeviceType::Meta) { + cap.capability_bits = 0; + // Meta only supports basic types for shape inference + // Byte, Char, Short, Int, Long, Float, Double, + // Bool, ComplexFloat, ComplexDouble + cap.capability_bits = (1ULL << kIndex_Byte) | (1ULL << kIndex_Char) | + (1ULL << kIndex_Short) | (1ULL << kIndex_Int) | + (1ULL << kIndex_Long) | (1ULL << kIndex_Float) | + (1ULL << kIndex_Double) | (1ULL << kIndex_ComplexFloat) | + (1ULL << kIndex_ComplexDouble) | (1ULL << kIndex_Bool); + } + return cap; + } + // Event-related functions void record( void** /*event*/, diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 3d259f5e390e3..0254c69baba00 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -57,6 +57,10 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { return impl_->deviceCount(); } + DeviceCapability getDeviceCapability(Device d) const override { + return impl_->getDeviceCapability(d); + } + // Event functions void record( void** event, diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h index 59bc2d5cdbff5..3c1c1193d3cdb 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -50,6 +51,14 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { return c10::Device(static_type, device_index); } + /** + * Get the device capability for a given device. + * By default, OpenReg has 2 same devices with the same capability. + */ + c10::DeviceCapability getDeviceCapability(c10::Device /*unused*/) const override { + return c10::DeviceCapability(); + } + /** * Set the current device to c10::Device. */ diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py index f925f15600ce7..9cb4a785d36e7 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py @@ -1,7 +1,7 @@ # Owner(s): ["module: PrivateUse1"] import torch -import torch_openreg # noqa: F401 +from torch.testing._internal.common_dtype import get_all_dtypes from torch.testing._internal.common_utils import run_tests, TestCase @@ -31,6 +31,13 @@ def test_invalid_device_index(self): with self.assertRaisesRegex(RuntimeError, "The device index is out of range"): torch.accelerator.set_device_index(2) + def test_device_capability(self): + capability = torch.accelerator.get_device_capability("openreg:0") + supported_dtypes = capability["supported_dtypes"] + expected_dtypes = get_all_dtypes(include_complex32=True, include_qint=True) + + self.assertTrue(all(dtype in supported_dtypes for dtype in expected_dtypes)) + if __name__ == "__main__": run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index e9b58b9ce71eb..01c4abd6fab76 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2493,6 +2493,7 @@ def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails def _accelerator_getAccelerator() -> _device: ... def _accelerator_setDeviceIndex(device_index: _int) -> None: ... def _accelerator_getDeviceIndex() -> _int: ... +def _accelerator_getDeviceCapability(device_index: _int) -> dict[str, Any]: ... def _accelerator_setStream(Stream) -> None: ... def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e1a82aa63ce22..a1335d2ad03bd 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -2,7 +2,8 @@ This package introduces support for the current :ref:`accelerator` in python. """ -from typing import Optional +from functools import cache +from typing import Any from typing_extensions import deprecated import torch @@ -25,6 +26,7 @@ "current_accelerator", "current_device_idx", # deprecated "current_device_index", + "get_device_capability", "current_stream", "device_count", "device_index", @@ -152,6 +154,28 @@ def current_device_index() -> int: """ +@cache +def get_device_capability(device: _device_t = None, /) -> dict[str, Any]: + r"""Return the capability of the currently selected device. + + Args: + device (:class:`torch.device`, str, int, optional): The device to query capabilities for + :ref:`accelerator` device type. If not given, + use :func:`torch.accelerator.current_device_index` by default. + + Returns: + dict[str, Any]: A dictionary containing device capability information. The dictionary includes: + - ``supported_dtypes`` (set(torch.dtype)): Set of PyTorch data types supported by the device + + Examples: + >>> # Query capabilities for current device + >>> capabilities = torch.accelerator.get_device_capability("cuda:0") + >>> print("Supported dtypes:", capabilities["supported_dtypes"]) + """ + device_index = _get_device_index(device, optional=True) + return torch._C._accelerator_getDeviceCapability(device_index) + + def set_device_index(device: _device_t, /) -> None: r"""Set the current device index to a given device. diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 14e54851178f5..c6ffa893d95ae 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -33,6 +33,25 @@ void initModule(PyObject* module) { return at::accelerator::getDeviceIndex(); }); + m.def("_accelerator_getDeviceCapability", [](c10::DeviceIndex device_index) { + const auto device_type = at::accelerator::getAccelerator(true).value(); + torch::utils::maybe_initialize_device(device_type); + auto caps = at::accelerator::getDeviceCapability(device_index); + + py::dict dict; + + py::set dtype_set; + caps.forEachSupportedScalarType([&](c10::ScalarType dtype) { + THPDtype* thp_dtype = torch::getTHPDtype(dtype); + py::object dtype_obj = + py::reinterpret_borrow((PyObject*)thp_dtype); + dtype_set.add(dtype_obj); + }); + + dict["supported_dtypes"] = dtype_set; + return dict; + }); + m.def("_accelerator_setStream", [](c10::Stream stream) { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); From 539ba711b029de9f191070f4f0d12f18f5b7f292 Mon Sep 17 00:00:00 2001 From: Hari Krishna Sai Kodali Date: Mon, 1 Dec 2025 06:45:09 +0000 Subject: [PATCH 073/338] add device generalization support for distributed tests (#165067) ## MOTIVATION To generalize Distributed test cases for non-CUDA devices ## CHANGES - Replaced hard coded device/backends with torch.accelerator.current_accelerator() and dist.get_default_backend_for_device - Use DistributedTestBase instead of MultiProcessTestCase to use common utilities - Remove instantiate_device_tests and make use of torch.accelerator.current_accelerator for test/distributed/test_c10d_object_collectives.py - fix deterministic context issue for non-cuda devices in test/distributed/optim/test_zero_redundancy_optimizer.py - use torch.accelerator.device_count() for multi-gpu check in torch/testing/_internal/distributed/_tensor/common_dtensor.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/165067 Approved by: https://github.com/albanD --- .../test_2d_composability.py | 33 +++++----- .../test_pp_composability.py | 33 +++++----- .../ddp_comm_hooks/test_ddp_hooks.py | 62 +++++++------------ .../checkpoint/test_state_dict_utils.py | 36 ++++++----- .../optim/test_zero_redundancy_optimizer.py | 21 +++++-- .../test_c10d_functional_native.py | 21 ++----- .../test_c10d_object_collectives.py | 48 +++++--------- test/distributed/test_device_mesh.py | 10 ++- .../distributed/_tensor/common_dtensor.py | 12 ++-- 9 files changed, 129 insertions(+), 147 deletions(-) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 9375c86d35584..0da7a86d06754 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -64,7 +64,12 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) +curr_backend = dist.get_default_backend_for_device(device_type) class SimpleModel(nn.Module): @@ -422,10 +427,10 @@ class TestFullyShard2DStateDict(DTensorTestBase): @property def backend(self): # need to specify gloo backend for testing cpu offload - return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl" + return f"cpu:gloo,{device_type}:{curr_backend}" - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_fully_shard_tp_2d_set_full_state_dict(self): dummy_model = SimpleModel().to(device_type) mesh_2d = init_device_mesh( @@ -514,8 +519,8 @@ def _check_module(self, m1, m2, check_grad=False): ).to_local() self.assertEqual(param_m2, param_m1) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_2d_ddp_integration_functionality(self) -> None: model, twod_model, dp_pg = self.init_model(self.device_type) optim = torch.optim.Adam(model.parameters(), lr=3e-5) @@ -566,8 +571,8 @@ def _compare_params(self, m1, m2): p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local() self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}") - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_2d_fsdp_state_enable_extension(self): mesh_2d = init_device_mesh( self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp") @@ -642,18 +647,18 @@ def _test_2d_e2e_training( # Ensure all params are still the same after optimizer update. self._compare_params(model, model_2d) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_2d_e2e_training_default(self): self._test_2d_e2e_training() - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_2d_e2e_training_use_orig_params(self): self._test_2d_e2e_training(use_orig_params=True) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_2d_e2e_training_not_use_orig_params(self): # TODO: need to revisit input_reshard API about why it failed multi-gpu tests. # self._test_2d_e2e_training(recompute_activation=True) @@ -666,10 +671,10 @@ class TestNew2dParallelStateDict(DTensorTestBase): @property def backend(self): # need to specify gloo backend for testing cpu offload - return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl" + return f"cpu:gloo,{device_type}:{curr_backend}" - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_fsdp_2d_extension(self): """ Test whether _fsdp_extension from FSDPstate has been set correctly. @@ -700,8 +705,8 @@ def test_fsdp_2d_extension(self): model_1d_fsdp_state = _get_module_fsdp_state(model_1d) self.assertEqual(model_1d_fsdp_state._fsdp_extension, None) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms @parametrize("is_even_sharded_model", [True, False]) def test_2d_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven @@ -756,8 +761,8 @@ def test_2d_state_dict(self, is_even_sharded_model): torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True ) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms @parametrize("is_even_sharded_model", [True, False]) def test_2d_load_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven @@ -811,8 +816,8 @@ def test_2d_load_state_dict(self, is_even_sharded_model): self.assertEqual(v1.device_mesh, v2.device_mesh) self.assertEqual(v1.placements, v2.placements) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms @parametrize("is_even_sharded_model", [True, False]) def test_2d_optim_state_dict(self, is_even_sharded_model): simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven @@ -899,9 +904,9 @@ def test_2d_optim_state_dict(self, is_even_sharded_model): else: self.assertEqual(new_state, state) + @skip_if_lt_x_gpu(4) @with_comms @with_temp_dir - @skip_if_lt_x_gpu(4) def test_fsdp1_tp_2d_set_full_state_dict(self): """ This is a workaround for loading full state dict into a FSDP1+TP 2D model. diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py index a66518fc0ef0f..3a221bf91a4d6 100644 --- a/test/distributed/_composable/test_composability/test_pp_composability.py +++ b/test/distributed/_composable/test_composability/test_pp_composability.py @@ -29,8 +29,8 @@ parallelize_module, RowwiseParallel, ) -from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( + at_least_x_gpu, MultiProcessTestCase, requires_accelerator_dist_backend, skip_if_lt_x_gpu, @@ -40,7 +40,6 @@ parametrize, run_tests, skip_but_pass_in_sandcastle_if, - TEST_XPU, ) from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir @@ -49,7 +48,11 @@ from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) backend = torch.distributed.get_default_backend_for_device(device_type) @@ -107,11 +110,9 @@ def world_size(self): def device(self): return self.rank - @requires_accelerator_dist_backend(["nccl", "xccl"]) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs" - ) + @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") def test_pp_and_dcp(self): """ Test that pipeline parallelism and distributed checkpointing can be used together and @@ -201,11 +202,9 @@ def _dcp_test(self): _dcp_test(self) - @requires_accelerator_dist_backend(["nccl", "xccl"]) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs" - ) + @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") @parametrize( "ScheduleClass", [ @@ -355,11 +354,9 @@ def apply_tp( torch.distributed.destroy_process_group() - @requires_accelerator_dist_backend(["nccl", "xccl"]) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs" - ) + @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") @parametrize( "ScheduleClass", [ @@ -550,11 +547,9 @@ def apply_same_precision(partial_model): torch.distributed.destroy_process_group() - @requires_accelerator_dist_backend(["nccl", "xccl"]) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(8) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs" - ) + @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs") @parametrize( "ScheduleClass", [ diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py index 89a893037c3b5..ee800f73b29d5 100644 --- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import os import sys import torch @@ -18,8 +17,8 @@ ) from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, - requires_nccl, + DistributedTestBase, + requires_accelerator_dist_backend, skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN @@ -30,9 +29,16 @@ sys.exit(0) +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) + + def gpus_for_rank(world_size): - visible_devices = list(range(torch.cuda.device_count())) - gpus_per_process = torch.cuda.device_count() // world_size + visible_devices = list(range(torch.accelerator.device_count())) + gpus_per_process = torch.accelerator.device_count() // world_size gpus_for_rank = [] for rank in range(world_size): gpus_for_rank.append( @@ -60,27 +66,7 @@ def forward(self, x, rank): return self.t0(x ** (1 + rank)) -class DistributedDataParallelCommHookTest(MultiProcessTestCase): - def setUp(self): - super().setUp() - self._spawn_processes() - - def tearDown(self): - try: - os.remove(self.file_name) - except OSError: - pass - - def _get_process_group_nccl(self): - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="nccl", - world_size=self.world_size, - rank=self.rank, - store=store, - ) - return dist.distributed_c10d._get_default_group() - +class DistributedDataParallelCommHookTest(DistributedTestBase): @property def world_size(self): return 2 @@ -119,14 +105,14 @@ def _run_and_get_grads(self, model): param = next(model.parameters()) return param.grad - @requires_nccl() + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_allreduce_hook(self): """ This unit test verifies the ``allreduce`` hook registered case gives same result with no hook registered case. """ - process_group = self._get_process_group_nccl() + process_group = self.create_pg(device_type) # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -135,14 +121,14 @@ def test_ddp_comm_hook_allreduce_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0) - @requires_nccl() + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_fp16compress_hook(self): """ This unit test verifies the ``fp16 compress`` hook registered case gives close result with no hook registered case. """ - process_group = self._get_process_group_nccl() + process_group = self.create_pg(device_type) # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -151,14 +137,14 @@ def test_ddp_comm_hook_fp16compress_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) - @requires_nccl() + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_quantize_per_tensor_hook(self): """ This unit test verifies the ``quantize per tensor`` hook registered case gives close result with no hook registered case. """ - process_group = self._get_process_group_nccl() + process_group = self.create_pg(device_type) # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -167,14 +153,14 @@ def test_ddp_comm_hook_quantize_per_tensor_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) - @requires_nccl() + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_quantize_per_channel_hook(self): """ This unit test verifies the ``quantize per channel`` hook registered case gives close result with no hook registered case. """ - process_group = self._get_process_group_nccl() + process_group = self.create_pg(device_type) # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -185,14 +171,14 @@ def test_ddp_comm_hook_quantize_per_channel_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) - @requires_nccl() + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) def test_ddp_comm_hook_noop_hook(self): """ This unit test verifies the ``noop`` hook registered case and a subsequent allreduce gives same result with no hook registered case. """ - process_group = self._get_process_group_nccl() + process_group = self.create_pg(device_type) # No hook registered case, get the reference grads. reference_grads = self._get_grads(process_group, None) @@ -204,10 +190,10 @@ def test_ddp_comm_hook_noop_hook(self): torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0) - @requires_nccl() + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) def test_is_last_hook(self): - process_group = self._get_process_group_nccl() + process_group = self.create_pg(device_type) def hook(flags, bucket): flags.append(bucket.is_last()) diff --git a/test/distributed/checkpoint/test_state_dict_utils.py b/test/distributed/checkpoint/test_state_dict_utils.py index 76e9aeb9e3302..c0f850cf95c9c 100644 --- a/test/distributed/checkpoint/test_state_dict_utils.py +++ b/test/distributed/checkpoint/test_state_dict_utils.py @@ -32,7 +32,7 @@ class TestStateDictUtils(DTensorTestBase): @property def world_size(self): - return min(4, torch.cuda.device_count()) + return min(4, torch.accelerator.device_count()) @with_comms @skip_if_lt_x_gpu(2) @@ -49,7 +49,7 @@ def test_gather_state_dict_dtensor(self): dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"]) - self.assertTrue(gathered_state_dict["dtensor"].is_cuda) + self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type) @with_comms @skip_if_lt_x_gpu(4) @@ -69,14 +69,16 @@ def test_gather_with_cpu_and_ranks_only(self): ) if dist.get_rank() in (0, 2): self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"]) - self.assertFalse(gathered_state_dict["dtensor"].is_cuda) + self.assertNotEqual( + gathered_state_dict["dtensor"].device.type, self.device_type + ) else: self.assertEqual(gathered_state_dict, {}) @with_comms @skip_if_lt_x_gpu(4) def test_cpu_and_ranks_only(self): - device = torch.device("cuda") + device = torch.device(self.device_type) state_dict = { "tensor1": torch.arange(10, device=device), "tensor2": torch.ones(10, device=device), @@ -85,7 +87,7 @@ def test_cpu_and_ranks_only(self): cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2)) if dist.get_rank() in (0, 2): for v in cpu_state_dict.values(): - self.assertFalse(v.is_cuda) + self.assertNotEqual(v.device.type, self.device_type) self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10)) self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10)) else: @@ -109,27 +111,27 @@ def create_dtensor(): for _ in range(10): tensor, dtensor = create_dtensor() ltensor.append(tensor) - ltensor.append(torch.ones(10, device=torch.device("cuda"))) + ltensor.append(torch.ones(10, device=torch.device(self.device_type))) ldtensor.append(dtensor) - ldtensor.append(torch.ones(10, device=torch.device("cuda"))) + ldtensor.append(torch.ones(10, device=torch.device(self.device_type))) tensor, dtensor = create_dtensor() dist_state_dict = { "local": dtensor, "list": ldtensor, - "arange": torch.arange(10, device=torch.device("cuda")), + "arange": torch.arange(10, device=torch.device(self.device_type)), } state_dict = { "local": tensor, "list": ltensor, - "arange": torch.arange(10, device=torch.device("cuda")), + "arange": torch.arange(10, device=torch.device(self.device_type)), } self.assertEqual(state_dict, _gather_state_dict(dist_state_dict)) @with_comms @skip_if_lt_x_gpu(2) def test_create_cpu_state_dict(self): - device = torch.device("cuda") + device = torch.device(self.device_type) rank = dist.get_rank() # Scale tensors based on world size # to fit in the tensor shards accurately. @@ -149,7 +151,7 @@ def test_create_cpu_state_dict(self): metadata=ShardMetadata( shard_offsets=[5 * rank, 0], shard_sizes=[5, 10], - placement=f"rank:{rank}/cuda:{rank}", + placement=f"rank:{rank}/{self.device_type}:{rank}", ), ) ], @@ -159,7 +161,7 @@ def test_create_cpu_state_dict(self): torch.arange(50 * scale_factor, device=device).reshape( 5 * scale_factor, 10 ), - init_device_mesh("cuda", mesh_shape=(self.world_size,)), + init_device_mesh(self.device_type, mesh_shape=(self.world_size,)), [Shard(0)], ), "non_tensor_bytes_io": copy.deepcopy(buffer), @@ -245,7 +247,7 @@ def test_state_dict_util_distribute_tensors(self): even_tensor = torch.randn(self.world_size, 2) uneven_tensor = torch.randn(1, 2) - mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,)) + mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) even_dtensor = distribute_tensor( torch.randn(self.world_size, 2), mesh, [Shard(0)] ) @@ -273,10 +275,10 @@ def test_state_dict_util_distribute_tensors(self): @with_comms @skip_if_lt_x_gpu(2) def test_cpu_offload_for_dtensor(self): - device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,)) + device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) sd = { "k": DTensor.from_local( - torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)] + torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)] ) } cpu_sd = _create_cpu_state_dict(sd) @@ -290,12 +292,12 @@ def test_cpu_offload_for_dtensor(self): self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"])) _copy_state_dict(sd, cpu_sd, non_blocking=True) - torch.cuda.synchronize() + torch.accelerator.synchronize() self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"])) sd["k"] += 1 self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"])) _copy_state_dict(sd, cpu_sd, non_blocking=True) - torch.cuda.synchronize() + torch.accelerator.synchronize() self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"])) diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 35eefdad512e6..e26d67a1d9f1f 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -7,7 +7,7 @@ import copy import sys -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from typing import Any, cast import numpy as np @@ -40,7 +40,6 @@ skip_if_rocm_multiprocess, skip_if_win32, ) -from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -57,7 +56,21 @@ HAS_TORCHVISION = False -device_type = str(get_devtype()) +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) + + +@contextmanager +def deterministic_algorithms(enabled=True): + prev_state = torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(enabled) + try: + yield + finally: + torch.use_deterministic_algorithms(prev_state) class TestZeroRedundancyOptimizer(DistributedTestBase): @@ -1241,7 +1254,7 @@ def _test_ddp_zero_overlap( enabled=True, deterministic=True, benchmark=False ) if "cuda" in device - else torch.use_deterministic_algorithms(True) + else deterministic_algorithms(True) ) with det_ctx: device_ids = [rank] if requires_ddp_rank(device) else None diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 0877eb53cd6f5..473198e5421c5 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -24,7 +24,7 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 from torch.testing._internal.common_device_type import e4m3_type from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, + DistributedTestBase, requires_accelerator_dist_backend, skip_if_lt_x_gpu, ) @@ -59,12 +59,8 @@ def load_test_module(name): sys.exit(0) -@requires_accelerator_dist_backend(["nccl", "xccl"]) -class TestWithNCCL(MultiProcessTestCase): - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - +@requires_accelerator_dist_backend() +class TestWithNCCL(DistributedTestBase): @property def world_size(self) -> int: return 2 @@ -78,16 +74,7 @@ def device(self) -> torch.device: return torch.device(self.rank) def _init_process_group(self) -> None: - torch.accelerator.set_device_index(self.rank) - store = dist.FileStore(self.file_name, self.world_size) - backend = dist.get_default_backend_for_device(self.device.type) - - dist.init_process_group( - backend=backend, - world_size=self.world_size, - rank=self.rank, - store=store, - ) + self.create_pg(self.device.type) torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) @skip_if_lt_x_gpu(2) diff --git a/test/distributed/test_c10d_object_collectives.py b/test/distributed/test_c10d_object_collectives.py index 594564c456068..7b97614c8c0ac 100644 --- a/test/distributed/test_c10d_object_collectives.py +++ b/test/distributed/test_c10d_object_collectives.py @@ -11,13 +11,10 @@ print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) -from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS from torch.testing._internal.common_utils import ( run_tests, skipIfHpu, - TEST_CUDA, - TEST_HPU, TEST_WITH_DEV_DBG_ASAN, ) @@ -29,16 +26,12 @@ ) sys.exit(0) -if TEST_HPU: - DEVICE = "hpu" -elif TEST_CUDA: - DEVICE = "cuda" -else: - DEVICE = "cpu" - -device_module = torch.get_device_module(DEVICE) -device_count = device_module.device_count() -BACKEND = dist.get_default_backend_for_device(DEVICE) +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) +device_count = torch.accelerator.device_count() def with_comms(func=None): @@ -49,11 +42,10 @@ def with_comms(func=None): @wraps(func) def wrapper(self, *args, **kwargs): - if DEVICE != "cpu" and device_count < self.world_size: + if device_type != "cpu" and device_count < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) - kwargs["device"] = DEVICE - self.pg = self.create_pg(device=DEVICE) + self.pg = self.create_pg(device=device_type) try: return func(self, *args, **kwargs) finally: @@ -64,7 +56,7 @@ def wrapper(self, *args, **kwargs): class TestObjectCollectives(DistributedTestBase): @with_comms() - def test_all_gather_object(self, device): + def test_all_gather_object(self): output = [None] * dist.get_world_size() dist.all_gather_object(object_list=output, obj=self.rank) @@ -72,7 +64,7 @@ def test_all_gather_object(self, device): self.assertEqual(i, v, f"rank: {self.rank}") @with_comms() - def test_gather_object(self, device): + def test_gather_object(self): output = [None] * dist.get_world_size() if self.rank == 0 else None dist.gather_object(obj=self.rank, object_gather_list=output) @@ -82,7 +74,7 @@ def test_gather_object(self, device): @skipIfHpu @with_comms() - def test_send_recv_object_list(self, device): + def test_send_recv_object_list(self): val = 99 if self.rank == 0 else None object_list = [val] * dist.get_world_size() if self.rank == 0: @@ -96,7 +88,7 @@ def test_send_recv_object_list(self, device): self.assertEqual(None, object_list[0]) @with_comms() - def test_broadcast_object_list(self, device): + def test_broadcast_object_list(self): val = 99 if self.rank == 0 else None object_list = [val] * dist.get_world_size() # TODO test with broadcast_object_list's device argument @@ -105,7 +97,7 @@ def test_broadcast_object_list(self, device): self.assertEqual(99, object_list[0]) @with_comms() - def test_scatter_object_list(self, device): + def test_scatter_object_list(self): input_list = list(range(dist.get_world_size())) if self.rank == 0 else None output_list = [None] dist.scatter_object_list( @@ -123,34 +115,30 @@ def setup_sub_pg(self): my_pg = dist.new_group(ranks, use_local_synchronization=True) return rank, ranks, my_pg - @skipIfHpu @with_comms() - def test_subpg_scatter_object(self, device): + def test_subpg_scatter_object(self): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg) self.assertEqual(rank, out_list[0]) - @skipIfHpu @with_comms() - def test_subpg_all_gather_object(self, device): + def test_subpg_all_gather_object(self): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] * len(ranks) dist.all_gather_object(out_list, rank, group=my_pg) self.assertEqual(ranks, out_list) - @skipIfHpu @with_comms() - def test_subpg_gather_object(self, device): + def test_subpg_gather_object(self): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] * len(ranks) if rank == ranks[0] else None dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg) if rank == ranks[0]: self.assertEqual(ranks, out_list) - @skipIfHpu @with_comms() - def test_subpg_broadcast_object(self, device): + def test_subpg_broadcast_object(self): rank, ranks, my_pg = self.setup_sub_pg() out_list = [None] if rank == ranks[0]: @@ -159,7 +147,5 @@ def test_subpg_broadcast_object(self, device): self.assertEqual(ranks[0], out_list[0]) -devices = ("cpu", "cuda", "hpu") -instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices) if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index a0de1b13c6161..6a49f989ac3ad 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -29,7 +29,7 @@ ) from torch.distributed.tensor.placement_types import _Partial, Shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase +from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -38,7 +38,11 @@ from torch.utils._typing_utils import not_none -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) device_count = torch.accelerator.device_count() try: @@ -58,7 +62,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran os.environ["LOCAL_RANK"] = f"{local_rank}" -@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.") +@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.") class DeviceMeshTestGlooBackend(DTensorTestBase): @property def backend(self): diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 1f6c4aece1e80..54bc65bc93365 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -43,6 +43,7 @@ SequenceParallel, ) from torch.testing._internal.common_distributed import ( + ACCELERATOR_DIST_BACKENDS, MultiProcContinuousTest, MultiProcessTestCase, MultiThreadedTestCase, @@ -396,14 +397,17 @@ def build_device_mesh(self) -> DeviceMesh: return init_device_mesh(self.device_type, (self.world_size,)) def init_pg(self, eager_init, backend: Optional[str] = None) -> None: - if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: + if backend is None: + backend = self.backend + + requires_gpu = any( + gpu_backend in backend for gpu_backend in ACCELERATOR_DIST_BACKENDS + ) + if requires_gpu and torch.accelerator.device_count() < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) curr_backend = dist.get_default_backend_for_device(self.device_type) - if backend is None: - backend = self.backend - if backend not in [ "nccl", "gloo", From 9fb52efc797b47a1f425a03aa5e47b866d8b1098 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sun, 30 Nov 2025 18:12:24 -0800 Subject: [PATCH 074/338] [user-streams] Handle the record_stream problem (#168061) This PR implements handling for properly deallocating a tensor allocated on one stream and used in a side stream in the backward pass. If the tensor's last usage is on the side stream, then syncrhonization is needed to ensure the memory is not reused on the allocating stream until after the side stream is finished. Accordingly, we estimate the runtime of both streams, record an event on the side stream at the last usage of the tensor, and then wait on that event in the allocating stream. In order to find the right location to wait, we estimate the runtime of all ops on both streams, and find the corresponding location in the allocating stream where we think the side stream has finished its work based on the runtime estimates. We then insert a `synced_deallocation` op at that location. This op waits on an event and takes a dummy tensor argument. This final usage of the tensor as a dummy argument serves the function of delaying the deallocation (and any subsequent reuse) of the memory on the allocating stream until after the wait event has completed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168061 Approved by: https://github.com/ngimel, https://github.com/eellison --- test/dynamo/test_streams.py | 187 +++++++++++++++++- torch/_dynamo/variables/streams.py | 30 +++ .../_functorch/_aot_autograd/graph_capture.py | 7 +- .../_functorch/_aot_autograd/indexed_dict.py | 54 +++++ torch/_functorch/_aot_autograd/streams.py | 167 +++++++++++++++- 5 files changed, 437 insertions(+), 8 deletions(-) create mode 100644 torch/_functorch/_aot_autograd/indexed_dict.py diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 7a40ae926a527..c594c87b7f1b7 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -525,7 +525,11 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None # Annotation: {'stream': 0} - add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = None + + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(2, 0); record_event_default = None + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(2, 1, mul_3); mul_3 = sync_dealloc_default = None return (add_3, add_2) """, ) @@ -590,7 +594,11 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None # Annotation: {'stream': 0} - add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = None + + # No stacktrace found for following nodes + record_event_default_1 = torch.ops.streams.record_event.default(3, 0); record_event_default_1 = None + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(3, 1, mul_3); mul_3 = sync_dealloc_default = None return (add_3, add_2) """, ) @@ -689,6 +697,181 @@ def test_run_opcheck_wait_record_stream(self): for args in sample_inputs: opcheck(wait_stream, args) + @requires_cuda + def test_record_stream_problem_basic(self): + # see https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream + # for what this tests/solves for + # We expect there to be a sync_dealloc op added to the graph for y + # synchronizing the first stream w/ the second stream after the second stream is finished + def fn(x): + e = torch.Event() + with torch.Stream(device="cuda:0"): + y = torch.ones(2, 2, device="cuda:0") + e.record() + z = y * x + + with torch.Stream(device="cuda:0"): + e.wait() + z0 = y * 2 * x + + return z0, z + + inp = (torch.ones(2, 2, device="cuda", requires_grad=True),) + ( + actual, + _, + fw_graphs, + bw_graphs, + ) = extract_graph(fn, *inp) + + actual[1].sum().backward() + + self.assertExpectedInline( + print_graph(bw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): + # Annotation: {'stream': 1} + ones: "f32[2, 2]" = torch.ops.aten.ones.default([2, 2], device = device(type='cuda', index=0), pin_memory = False) + + # Annotation: {'stream': 2} + mul_1: "f32[2, 2]" = torch.ops.aten.mul.Tensor(ones, 2) + mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, mul_1); tangents_1 = mul_1 = None + + # Annotation: {'stream': 1} + mul_4: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, ones); tangents_2 = ones = None + + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(3, 1); record_event_default = None + wait_event_default = torch.ops.streams.wait_event.default(3, 2); wait_event_default = None + + # Annotation: {'stream': 2} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_3, mul_4); mul_3 = None + + # No stacktrace found for following nodes + record_event_default_1 = torch.ops.streams.record_event.default(4, 2); record_event_default_1 = None + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(4, 1, mul_4); mul_4 = sync_dealloc_default = None + return (add,) +""", + ) + + @requires_cuda + def test_record_stream_problem_interleaved(self): + # see https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream + # for what this tests/solves for + # This will have interleaved computation where y is + # first allocated on the first stream used on the second stream + # used on the first stream again then finally used on the last stream + def fn(x): + e = torch.Event() + with torch.Stream(device="cuda:0"): + y = torch.ones(2, 2, device="cuda:0") + z = y * x + e.record() + + with torch.Stream(device="cuda:0"): + e.wait() + z0 = y * 2 * z + e.record() + + with torch.Stream(device="cuda:0"): + e.wait() + z1 = y * x * z0 + e.record() + + with torch.Stream(device="cuda:0"): + e.wait() + z2 = y * 4 * z1 + e.record() + + e.wait() + return z, z1, z2 + + inp = (torch.ones(2, 2, device="cuda", requires_grad=True),) + ( + actual, + _, + fw_graphs, + bw_graphs, + ) = extract_graph(fn, *inp) + + actual[1].sum().backward() + + self.assertExpectedInline( + print_graph(bw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[2, 2]", mul: "f32[2, 2]", tangents_1: "f32[2, 2]", \ +tangents_2: "f32[2, 2]", tangents_3: "f32[2, 2]"): + # Annotation: {'stream': 1} + ones: "f32[2, 2]" = torch.ops.aten.ones.default([2, 2], device = device(type='cuda', index=0), pin_memory = False) + + # Annotation: {'stream': 4} + mul_5: "f32[2, 2]" = torch.ops.aten.mul.Tensor(ones, 4) + mul_7: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_3, mul_5); tangents_3 = mul_5 = None + + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(6, 4); record_event_default = None + wait_event_default = torch.ops.streams.wait_event.default(6, 3); wait_event_default = None + + # Annotation: {'stream': 3} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, mul_7); tangents_2 = None + + # No stacktrace found for following nodes + record_event_default_4 = torch.ops.streams.record_event.default(10, 3); record_event_default_4 = None + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(10, 4, mul_7); mul_7 = sync_dealloc_default = None + + # Annotation: {'stream': 3} + mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(ones, primals_1); primals_1 = None + mul_8: "f32[2, 2]" = torch.ops.aten.mul.Tensor(add, mul_3); mul_3 = None + + # No stacktrace found for following nodes + record_event_default_1 = torch.ops.streams.record_event.default(7, 3); record_event_default_1 = None + + # Annotation: {'stream': 2} + mul_1: "f32[2, 2]" = torch.ops.aten.mul.Tensor(ones, 2) + mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(mul_1, mul); mul = None + + # Annotation: {'stream': 3} + mul_9: "f32[2, 2]" = torch.ops.aten.mul.Tensor(add, mul_2); add = mul_2 = None + mul_10: "f32[2, 2]" = torch.ops.aten.mul.Tensor(mul_9, ones); mul_9 = None + + # No stacktrace found for following nodes + wait_event_default_1 = torch.ops.streams.wait_event.default(7, 2); wait_event_default_1 = None + + # Annotation: {'stream': 2} + mul_11: "f32[2, 2]" = torch.ops.aten.mul.Tensor(mul_8, mul_1); mul_1 = None + + # No stacktrace found for following nodes + record_event_default_5 = torch.ops.streams.record_event.default(11, 2); record_event_default_5 = None + sync_dealloc_default_1 = torch.ops.streams.sync_dealloc.default(11, 3, mul_8); mul_8 = sync_dealloc_default_1 = None + record_event_default_2 = torch.ops.streams.record_event.default(8, 2); record_event_default_2 = None + wait_event_default_2 = torch.ops.streams.wait_event.default(8, 1); wait_event_default_2 = None + + # Annotation: {'stream': 1} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_1, mul_11); tangents_1 = None + + # No stacktrace found for following nodes + record_event_default_6 = torch.ops.streams.record_event.default(12, 1); record_event_default_6 = None + sync_dealloc_default_2 = torch.ops.streams.sync_dealloc.default(12, 2, mul_11); mul_11 = sync_dealloc_default_2 = None + + # Annotation: {'stream': 1} + mul_12: "f32[2, 2]" = torch.ops.aten.mul.Tensor(add_1, ones); add_1 = ones = None + + # No stacktrace found for following nodes + record_event_default_3 = torch.ops.streams.record_event.default(9, 1); record_event_default_3 = None + wait_event_default_3 = torch.ops.streams.wait_event.default(9, 3); wait_event_default_3 = None + + # Annotation: {'stream': 3} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_10, mul_12); mul_10 = None + + # No stacktrace found for following nodes + record_event_default_7 = torch.ops.streams.record_event.default(13, 3); record_event_default_7 = None + sync_dealloc_default_3 = torch.ops.streams.sync_dealloc.default(13, 1, mul_12); mul_12 = sync_dealloc_default_3 = None + return (add_2,) +""", + ) + @requires_cuda def test_inductor_lowering(self): with patch("torch._inductor.config.implicit_fallbacks", False): diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index 38da38a8cfc18..426f50e76d6ab 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -175,6 +175,36 @@ def _( has_side_effect(torch.ops.streams.wait_stream.default) +@custom_op("streams::sync_dealloc", mutates_args=()) +def sync_dealloc( + wait_event_index: int, src_stream_index: int, to_dealloc: torch.Tensor +) -> None: + """An op which waits on an event and moves the last usage of to_dealloc + after the wait, so that after the sync occurs, the deallocation or + subsequent reuse of the tensor's memory will be guaranteed to happen + after a side stream is finished using it. + See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream + for more details""" + torch.ops.streams.wait_event.default(wait_event_index, src_stream_index) + + +has_side_effect(torch.ops.streams.sync_dealloc.default) + + +@custom_op("streams::record_stream", mutates_args=()) +def record_stream(tensor: torch.Tensor, stream_index: int) -> None: + tensor.record_stream(_get_stream_by_index(stream_index)) + + +@record_stream.register_fake +def _( + src_stream_index: int, + wait_event_index: int, + to_dealloc: torch.Tensor, +) -> None: + pass + + class SymbolicStreamState: """Track the currently entered stream if any""" diff --git a/torch/_functorch/_aot_autograd/graph_capture.py b/torch/_functorch/_aot_autograd/graph_capture.py index f17a516183975..7dceaee3dacb2 100644 --- a/torch/_functorch/_aot_autograd/graph_capture.py +++ b/torch/_functorch/_aot_autograd/graph_capture.py @@ -33,7 +33,7 @@ handle_effect_tokens_fn, ) from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta -from .streams import assign_backward_streams, insert_backward_syncs +from .streams import assign_backward_streams, insert_backward_syncs, sync_deallocations from .utils import ( call_and_expect_output_descs, copy_fwd_metadata_to_bw_nodes, @@ -477,8 +477,13 @@ def aot_dispatch_autograd_graph( # After copying metadata, assign streams to gradient accumulation nodes assign_backward_streams(fx_g) + # Insert syncs for newly assigned backward streams insert_backward_syncs(fx_g) + # Sync deallocations for tensors where the stream w/ their last usage + # is distinct from their allocation strea + sync_deallocations(fx_g) + fx_g.graph.eliminate_dead_code() if not aot_config.disable_functionalization: # There should be *NO* mutating ops in the graph at this point. diff --git a/torch/_functorch/_aot_autograd/indexed_dict.py b/torch/_functorch/_aot_autograd/indexed_dict.py new file mode 100644 index 0000000000000..39a06996c6e08 --- /dev/null +++ b/torch/_functorch/_aot_autograd/indexed_dict.py @@ -0,0 +1,54 @@ +from collections.abc import Iterator, MutableMapping +from typing import Generic, Optional, TypeVar + + +K = TypeVar("K") +V = TypeVar("V") + + +# Used for fast next key access (using the fact that the dict is ordered) +# Note: doesn't support deletion but we don't need it! +class IndexedDict(MutableMapping[K, V], Generic[K, V]): + """A dict that maintains insertion order with O(1) index access.""" + + __slots__ = ("_dict", "_keys", "_key_to_index") + + def __init__(self) -> None: + self._dict: dict[K, V] = {} + self._keys: list[K] = [] # typing: ignore[bad-override] + self._key_to_index: dict[K, int] = {} + + def __setitem__(self, key: K, value: V) -> None: + if key not in self._dict: + self._key_to_index[key] = len(self._keys) + self._keys.append(key) + self._dict[key] = value + + def __getitem__(self, key: K) -> V: + return self._dict[key] + + def __delitem__(self, key: K) -> None: + raise NotImplementedError("Deletion not supported for IndexedDict") + + def __len__(self) -> int: + return len(self._dict) + + def __iter__(self) -> Iterator[K]: + return iter(self._keys) + + def __contains__(self, key: object) -> bool: + return key in self._dict + + def next_key(self, key: K) -> Optional[K]: + """Get the next key in insertion order. O(1).""" + idx = self._key_to_index.get(key) + if idx is not None and idx + 1 < len(self._keys): + return self._keys[idx + 1] + return None + + def prev_key(self, key: K) -> Optional[K]: + """Get the previous key in insertion order. O(1).""" + idx = self._key_to_index.get(key) + if idx is not None and idx > 0: + return self._keys[idx - 1] + return None diff --git a/torch/_functorch/_aot_autograd/streams.py b/torch/_functorch/_aot_autograd/streams.py index 1fc8a965740fd..1eb76a637bf71 100644 --- a/torch/_functorch/_aot_autograd/streams.py +++ b/torch/_functorch/_aot_autograd/streams.py @@ -1,21 +1,61 @@ -from typing import Optional, TypeAlias +from typing import Any, Optional, TypeAlias import torch.fx import torch.fx.traceback +import torch.utils._pytree as pytree from torch._dynamo.graph_utils import _get_flat_args from torch._dynamo.variables.streams import get_current_stream, new_event +from torch.utils._runtime_estimation import ( + _FLOAT_TYPES, + _IGNORE_OPS, + get_compute_time, + get_transfer_time, +) + +from .indexed_dict import IndexedDict Node: TypeAlias = torch.fx.Node Graph: TypeAlias = torch.fx.Graph +def get_roofline_estimate(node: Node) -> float: + assert node.op == "call_function", "non-func node in roofline estimate" + + def map_value(x: Any) -> Any: + return x.meta.get("value", x) if isinstance(x, Node) else x + + func = node.target + if func in _IGNORE_OPS: + return 0.0 + + mapped_args = torch.fx.map_arg(node.args, map_value) + mapped_kwargs = torch.fx.map_arg(node.kwargs, map_value) + flat_args_kwargs = [map_value(x) for x in _get_flat_args(node, {})] + flat_outs, _ = pytree.tree_flatten(node.meta.get("value", node)) + out = node.meta.get("value", node) + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in _FLOAT_TYPES + } + + return ( + max( + get_transfer_time(flat_args_kwargs, flat_outs), + get_compute_time(func, mapped_args, mapped_kwargs, out, out_dtypes), + ) + / 1e6 + ) + + def is_gradient_acc(node: Node) -> bool: return node.meta.get("is_gradient_acc", False) def is_bwd_node(node: Node) -> bool: - return node.meta.get("partitioner_tag") == "is_backward" + tag = node.meta.get("partitioner_tag") + return tag == "is_backward" or tag == "must_be_in_backward" def get_device(node: Node) -> torch.device: @@ -44,7 +84,7 @@ def set_stream(node: Node, ind: int) -> None: node.meta["custom"] = {"stream": ind} -def insert_record_event_after_node(graph: Graph, node: Node, event_ind: int) -> None: +def insert_record_event_after_node(graph: Graph, node: Node, event_ind: int) -> Node: with graph.inserting_after(node): node = graph.call_function( torch.ops.streams.record_event.default, @@ -55,8 +95,10 @@ def insert_record_event_after_node(graph: Graph, node: Node, event_ind: int) -> ) node.meta["partitioner_tag"] = "must_be_in_backward" + return node -def insert_wait_event_before_node(graph: Graph, node: Node, event_ind: int) -> None: + +def insert_wait_event_before_node(graph: Graph, node: Node, event_ind: int) -> Node: with graph.inserting_before(node): node = graph.call_function( torch.ops.streams.wait_event.default, @@ -67,6 +109,95 @@ def insert_wait_event_before_node(graph: Graph, node: Node, event_ind: int) -> N ) node.meta["partitioner_tag"] = "must_be_in_backward" + return node + + +def populate_stream_timeline( + stream_to_timeline: dict[Optional[int], IndexedDict[Node, float]], + graph: Graph, + stream_index: Optional[int], +) -> IndexedDict[Node, float]: + if stream_index not in stream_to_timeline: + stream_to_timeline[stream_index] = IndexedDict() + total_time = 0.0 + for node in graph.nodes: + # mlazos: not sure if we should include forward here too but don't think it matters + if is_bwd_node(node) and get_stream(node) == stream_index: + total_time += get_roofline_estimate(node) + stream_to_timeline[stream_index][node] = ( + total_time # NB: total time includes the node's runtime + ) + + return stream_to_timeline[stream_index] + + +# NB: we start all estimates at 0, estimating the total runtime of each stream with timestamps at each node +# we then try and use these timestamps to estimate when to deallocate tensors used in side streams +# See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream +# for details on the problem being addressed. Rather than using the automatic memory management approach of record_stream +# we attempt to find the point which to deallocate based on the estimated timestamps. +def handle_synced_deallocation( + graph: Graph, + stream_to_exec_trace: dict[Optional[int], IndexedDict[Node, float]], + node: Node, + last_usage: Node, +) -> None: + assert is_bwd_node(node), ( + "synced allocations should only be handled on backward nodes" + ) + assert is_bwd_node(last_usage), ( + "synced allocations should only be handled on backward nodes" + ) + allocating_stream = get_stream(node) + side_stream = get_stream(last_usage) + assert allocating_stream != side_stream, ( + "allocating and side stream should be different for synced deallocations" + ) + if not torch.cuda.is_available(): + # fallback to record_stream in this case + with graph.inserting_after(node): + graph.call_function( + torch.ops.streams.record_stream.default, + ( + node, + get_stream_or_current_stream(last_usage), + ), + {}, + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + allocating_stream_trace = populate_stream_timeline( + stream_to_exec_trace, graph, allocating_stream + ) + side_stream_trace = populate_stream_timeline( + stream_to_exec_trace, graph, side_stream + ) + + alloc_ptr = node + target_side_stream_time = side_stream_trace[last_usage] + # linear search from first usage of tensor to a point in time after the side stream has finished + while alloc_ptr is not None: + alloc_time = allocating_stream_trace[alloc_ptr] + + if alloc_time >= target_side_stream_time: + break + elif alloc_time < target_side_stream_time: + next_ptr = allocating_stream_trace.next_key(alloc_ptr) + if next_ptr is not None: + alloc_ptr = next_ptr + else: + break + + wait_event = new_event() + record_node = insert_record_event_after_node(graph, last_usage, wait_event) + with graph.inserting_after(max(alloc_ptr, record_node)): + graph.call_function( + torch.ops.streams.sync_dealloc.default, + (wait_event, get_stream_or_current_stream(alloc_ptr), node), + {}, + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + def insert_sync( graph: Graph, @@ -111,7 +242,7 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None: def insert_backward_syncs(gm: torch.fx.GraphModule) -> None: """Inserts stream syncs for backward nodes if consumer and producer are on different streams""" - node_to_wait_event_ind = {} + node_to_wait_event_ind: dict[Node, int] = {} for node in gm.graph.nodes: if is_bwd_node(node): flat_args = _get_flat_args(node, {}) @@ -122,3 +253,29 @@ def insert_backward_syncs(gm: torch.fx.GraphModule) -> None: arg_stream = get_stream(arg) if arg_stream != cur_node_stream and get_device(arg).type != "cpu": insert_sync(gm.graph, node, arg, node_to_wait_event_ind) + + +def sync_deallocations(gm: torch.fx.GraphModule) -> None: + """Handles https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream""" + # Note: this is only needed if the last usage of a tensor is on a stream other than + # the stream the tensor was allocated on + + # an estimated timestamp from the beginning of graph execution (assuming 0 CPU overhead) + # I think this is fine because you should have large tensors if you're using streams + # although perhaps I could add a constant 10us per op ahead of the first stream op? + # a trace of all the nodes running in a given stream + stream_to_exec_trace: dict[Optional[int], IndexedDict[Node, float]] = {} + for node in gm.graph.nodes: + if is_bwd_node(node): + allocating_stream = get_stream(node) + users = list(node.users.keys()) + if not users: + continue + last_user = max(user for user in users) + if last_user.op == "output": + continue + side_stream = get_stream(last_user) + if allocating_stream != side_stream: + handle_synced_deallocation( + gm.graph, stream_to_exec_trace, node, last_user + ) From 55c4ab554845481d0a69a3811937575fe8bb1a66 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 1 Dec 2025 07:13:56 +0000 Subject: [PATCH 075/338] Revert "[dynamo][dicts] Decentralize and Improve key hash implementation for Dict variable tracker (#169204)" This reverts commit 84149583d483e9c973c9a0feda70e4f3964947b0. Reverted https://github.com/pytorch/pytorch/pull/169204 on behalf of https://github.com/wdvr due to failing test/test_fx.py::TestFXAPIBackwardCompatibility::test_public_api_surface [GH job link](https://github.com/pytorch/pytorch/actions/runs/19803784913/job/56735267443) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/84149583d483e9c973c9a0feda70e4f3964947b0) consistently ([comment](https://github.com/pytorch/pytorch/pull/169204#issuecomment-3594928195)) --- test/dynamo/test_dicts.py | 210 +----------------- .../TestCustomOp.test_impl_device_cpu | 0 torch/_dynamo/graph_break_registry.json | 44 ---- torch/_dynamo/utils.py | 18 -- torch/_dynamo/variables/base.py | 51 ----- torch/_dynamo/variables/builtin.py | 9 - torch/_dynamo/variables/constant.py | 25 --- torch/_dynamo/variables/dicts.py | 206 +++++++++++------ torch/_dynamo/variables/functions.py | 46 ---- torch/_dynamo/variables/higher_order_ops.py | 9 - torch/_dynamo/variables/lists.py | 34 --- torch/_dynamo/variables/misc.py | 37 --- torch/_dynamo/variables/tensor.py | 28 --- torch/_dynamo/variables/torch.py | 9 - torch/_dynamo/variables/user_defined.py | 55 ++--- 15 files changed, 154 insertions(+), 627 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 4c233ea9458f3..cdaeb2d91fbfb 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -19,7 +19,6 @@ import torch._functorch.config import torch.nn import torch.utils.checkpoint -from torch._dynamo.exc import Unsupported from torch._dynamo.testing import same from torch._dynamo.utils import dict_items from torch.testing._internal.common_utils import ( @@ -90,7 +89,7 @@ def forward(self, x): inp = torch.randn(4, 4) mod = Foo() - opt_f = torch.compile(mod, backend="eager", fullgraph=True) + opt_f = torch.compile(mod) self.assertEqual(mod(inp), opt_f(inp)) def test_dict_subclass_local_with_non_dict_method(self): @@ -519,7 +518,7 @@ def fn(d): args1 = {namedtuple: None, 3: torch.randn(3)} cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1), opt_fn(args1)) self.assertEqual(cnts.frame_count, 1) # Test a failing namedtuple guard @@ -539,7 +538,7 @@ def fn(d, x): args1[3] = z cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1, x), opt_fn(args1, x)) self.assertEqual(cnts.frame_count, 1) @@ -1063,6 +1062,8 @@ def fn(b: Any): a = {"one": torch.ones(1)} return a | b + from torch._dynamo.exc import Unsupported + for arg in args: with self.assertRaisesRegex(Unsupported, "Observed exception"): _ = fn(arg) @@ -1203,156 +1204,6 @@ def f(): opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) - def test_range_as_dict_key(self): - def fn(x): - d = {range(5): x * 2, range(10, 15): x * 3} - return d[range(0, 5, 1)] + d[range(10, 15)] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_tuple_as_dict_key(self): - def fn(x): - d = {(1, 2): x * 2, (3, 4, 5): x * 3} - return d[(1, 2)] + d[(3, 4, 5)] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_enum_as_dict_key(self): - class Color(enum.Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - def fn(x): - d = {Color.RED: x * 2, Color.GREEN: x * 3, Color.BLUE: x * 4} - return d[Color.RED] + d[Color.GREEN] + d[Color.BLUE] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_intenum_as_dict_key(self): - class Priority(enum.IntEnum): - LOW = 1 - MEDIUM = 2 - HIGH = 3 - - def fn(x): - d = {Priority.LOW: x * 2, Priority.MEDIUM: x * 3, Priority.HIGH: x * 4} - return d[Priority.LOW] + d[Priority.MEDIUM] + d[Priority.HIGH] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_frozenset_as_dict_key(self): - def fn(x): - d = {frozenset([1, 2]): x * 2, frozenset([3, 4, 5]): x * 3} - return d[frozenset([1, 2])] + d[frozenset([3, 4, 5])] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_typing_union_as_dict_key(self): - from typing import Union - - def fn(x): - d = {Union[int, str]: x * 2, Union[float, bool]: x * 3} - return d[Union[int, str]] + d[Union[float, bool]] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_numpy_dtype_as_dict_key(self): - import numpy as np - - def fn(x): - d = {np.float32: x * 2, np.int64: x * 3, np.bool_: x * 4} - return d[np.float32] + d[np.int64] + d[np.bool_] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_method_wrapper_as_dict_key(self): - add_method = list.__add__ - mul_method = list.__mul__ - - def fn(x): - # Method wrappers are the type of bound methods on built-in types - d = {add_method: x * 2, mul_method: x * 3} - return d[add_method] + d[mul_method] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_torch_builtin_function_as_dict_key(self): - def fn(x, y): - # Using torch built-in functions as dictionary keys - d = {torch.add: x * 2, torch.mul: y * 3, torch.sub: x + y} - return d[torch.add] + d[torch.mul] + d[torch.sub] - - x = torch.randn(4) - y = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x, y), opt_fn(x, y)) - - def test_frozen_dataclass_as_dict_key(self): - from dataclasses import dataclass - - @dataclass(frozen=True) - class Point: - x: int - y: int - - def fn(tensor): - p1 = Point(1, 2) - p2 = Point(3, 4) - d = {p1: tensor * 2, p2: tensor * 3} - return d[Point(1, 2)] + d[Point(3, 4)] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_list_as_dict_key_raises_typeerror(self): - def fn(x): - d = {[1, 2, 3]: x * 2} - return d[[1, 2, 3]] - - x = torch.randn(4) - - # First check that eager execution raises TypeError - with self.assertRaises(TypeError): - fn(x) - - # Also check that compiled version raises TypeError - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - with self.assertRaisesRegex(Unsupported, "Observed exception"): - opt_fn(x) - - def test_get_default_nowrap_functions_as_dict_key(self): - def fn(x): - # Get the set of default nowrap functions - nowrap_funcs = torch.overrides.get_default_nowrap_functions() - # Use the set as a dict key and search for Tensor.grad.__get__ in it - d = {frozenset(nowrap_funcs): x * 2} - # Check if Tensor.grad.__get__ is in the set - if torch.Tensor.grad.__get__ in nowrap_funcs: - return d[frozenset(nowrap_funcs)] + x - return x - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - instantiate_parametrized_tests(DictTests) @@ -1887,9 +1738,7 @@ def fn(x): new_gn = partial(gn, x=1) key = Container(new_gn, 4) new_dict[key] = 5 - # Make another key that should hash to the same value - key1 = Container(new_gn, 4) - return x * new_dict[key1] + return x * new_dict[key] x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -1898,53 +1747,6 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) - def test_custom_object_as_dict_key(self): - """Test that custom objects with __hash__ as dict keys are properly handled. - - This test verifies that when using custom objects with overridden __hash__ - and __eq__ as dictionary keys, two instances with the same hash and equality - should be recognized as the same key. - """ - - class CustomKey: - def __init__(self, value, name): - self.value = value - self.name = name - - def fn(x): - d = {} - # Create first instance - key1 = CustomKey(42, "test") - d[key1] = x * 2 - - # Create second instance with same values - should hash to same value - key2 = CustomKey(42, "test") - d[key2] = x * 3 # This should overwrite the first value - - return d[key1] * d[key2] - - x = torch.randn(4) - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertTrue(same(opt_fn(x), fn(x))) - - def test_user_defined_object(self): - class A: - def __init__(self): - self.x = {} - REF[self] = {} - - REF = {} - - def f(a, x): - REF[a]["foo"] = x - return x + 1 - - opt_f = torch.compile(f, backend="eager", fullgraph=True) - - x = torch.randn(4) - self.assertTrue(same(f(A(), x), opt_f(A(), x))) - class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu b/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 9bfe593417699..5f967971005f6 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3667,49 +3667,5 @@ "Use custom operators instead of direct attribute/method access." ] } - ], - "GB0363": [ - { - "Gb_type": "User-defined object with overridden __hash__", - "Context": "hashing object of type={type(obj)} and variable tracker {vt}", - "Explanation": "Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", - "Hints": [ - "Dynamo does not support hashing user-defined objects with overridden __hash__", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], - "GB0364": [ - { - "Gb_type": "Dynamo cannot determine whether the underlying object is hashable", - "Context": "is_python_hashable {self}", - "Explanation": "Dynamo does not know whether the underlying python object for {self} is hashable", - "Hints": [ - "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], - "GB0365": [ - { - "Gb_type": "Dynamo cannot determine the hash of an object", - "Context": "get_python_hash {self}", - "Explanation": "Dynamo does not know the hash of the underlying python object for {self}", - "Hints": [ - "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], - "GB0366": [ - { - "Gb_type": "Dynamo cannot determine the equality comparison of an object", - "Context": "is_python_equal {self}", - "Explanation": "Dynamo does not know the equality comparison of the underlying python object for {self}", - "Hints": [ - "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } ] } diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 5b1070aad5ad6..c6825737ec994 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4956,21 +4956,3 @@ def get_traced_code() -> Optional[list[CodeType]]: from torch._guards import TracingContext return TracingContext.get_traced_code() - - -def raise_on_overridden_hash(obj: Any, vt: VariableTracker) -> None: - from . import graph_break_hints - from .exc import unimplemented - - is_overridden = type(obj).__dict__.get("__hash__", False) - - if is_overridden: - unimplemented( - gb_type="User-defined object with overridden __hash__", - context=f"hashing object of type={type(obj)} and variable tracker {vt}", - explanation=f"Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", - hints=[ - "Dynamo does not support hashing user-defined objects with overridden __hash__", - *graph_break_hints.SUPPORTABLE, - ], - ) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 0dcf75d344060..617f787e43d8a 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -683,57 +683,6 @@ def build( else: return variables.LazyVariableTracker.create(value, source) - def is_python_hashable(self): - """ - Unlike the variable tracker's own __hash__, this method checks whether - the underlying Python object referenced by this variable tracker is hashable. - """ - unimplemented( - gb_type="Dynamo cannot determine whether the underlying object is hashable", - context=f"is_python_hashable {self}", - explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable", - hints=[ - ( - f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." - ), - *graph_break_hints.SUPPORTABLE, - ], - ) - - def get_python_hash(self): - """ - Unlike the variable tracker’s own __hash__, this method is used by - ConstDictVariableTracker to compute the hash of the underlying key object. - """ - unimplemented( - gb_type="Dynamo cannot determine the hash of an object", - context=f"get_python_hash {self}", - explanation=f"Dynamo does not know the hash of the underlying python object for {self}", - hints=[ - ( - f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." - ), - *graph_break_hints.SUPPORTABLE, - ], - ) - - def is_python_equal(self, other): - """ - NB - Deliberately not overriding the __eq__ method because that can - disable the __hash__ for the vt itself. - """ - unimplemented( - gb_type="Dynamo cannot determine the equality comparison of an object", - context=f"is_python_equal {self}", - explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}", - hints=[ - ( - f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." - ), - *graph_break_hints.SUPPORTABLE, - ], - ) - def __init__( self, *, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 8fdaefea56f89..ae6678628634a 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -3243,15 +3243,6 @@ def call_contains( ) -> VariableTracker: return a.call_method(tx, "__contains__", [b], {}) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.fn) - - def is_python_equal(self, other): - return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn - @contextlib.contextmanager def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 0b2eaaea80826..672fa1d804383 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -23,7 +23,6 @@ istype, np, raise_args_mismatch, - raise_on_overridden_hash, ) from .base import ValueMutationNew, VariableTracker @@ -341,20 +340,6 @@ def call_obj_hasattr( result = hasattr(self.value, name) return variables.ConstantVariable.create(result) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - # Could be an EnumVariable as well - from .tensor import SymNodeVariable - - if isinstance(other, SymNodeVariable): - return self.as_python_constant() == other.evaluate_expr() - return self.as_python_constant() == other.as_python_constant() - class EnumVariable(VariableTracker): """VariableTracker for enum.Enum and enum.IntEnum instances @@ -403,13 +388,3 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker member = getattr(self.value, name) source = self.source and AttrSource(self.source, name) return VariableTracker.build(tx, member, source=source) - - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 9b98c91723063..422cae7c4d3f1 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -20,11 +20,14 @@ import collections import functools +import inspect import operator import types -from collections.abc import Sequence +from collections.abc import Hashable as py_Hashable, Sequence from typing import Any, Optional, TYPE_CHECKING, Union +from torch._subclasses.fake_tensor import is_fake + from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import raise_observed_exception, unimplemented @@ -52,8 +55,8 @@ # [Adding a new supported class within the keys of ConstDictVariable] -# - Implement is_python_hashable() method in the VariableTracker subclass -# - Implement get_python_hash() and is_python_equal() methods for hashable types +# - Add its tracker type to is_hashable +# - (perhaps) Define how it is compared in _HashableTracker._eq_impl def was_instancecheck_override(obj: Any) -> bool: @@ -70,7 +73,7 @@ def raise_unhashable( raise_observed_exception( TypeError, tx, - msg=f"Unhashable type: {arg.python_type()!r} and variable tracker = {type(arg.realize())}", + args=[ConstantVariable(f"unhashable type: {type(arg.realize())}")], ) @@ -85,7 +88,52 @@ def is_hashable(x: VariableTracker) -> bool: and x.is_hashable() ): return True - return x.is_python_hashable() + + if isinstance(x, variables.TensorVariable): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return x.as_proxy().node.meta.get("example_value") is not None + elif isinstance(x, variables.TupleVariable): + return all(is_hashable(e) for e in x.items) + elif isinstance(x, variables.FrozenDataClassVariable): + return all(is_hashable(e) for e in x.fields.values()) + elif ( + isinstance(x, variables.UserDefinedObjectVariable) + and not was_instancecheck_override(x.value) + and inspect.getattr_static(x.value, "__hash__") is int.__hash__ + and isinstance(x.value, int) + ): + return isinstance(x.value, py_Hashable) + elif isinstance(x, variables.FunctoolsPartialVariable): + return ( + is_hashable(x.func) + and all(is_hashable(arg) for arg in x.args) + and all(is_hashable(value) for value in x.keywords.values()) + ) + else: + return isinstance( + x, + ( + variables.BuiltinVariable, + variables.SymNodeVariable, + variables.ConstantVariable, + variables.EnumVariable, + variables.FrozensetVariable, + variables.UserDefinedClassVariable, + variables.UserFunctionVariable, + variables.SkipFunctionVariable, + variables.misc.NumpyVariable, + variables.NNModuleVariable, + variables.UnspecializedNNModuleVariable, + variables.MethodWrapperVariable, + variables.TorchInGraphFunctionVariable, + variables.TypingVariable, + variables.FunctoolsPartialVariable, + variables.WeakRefVariable, + variables.TorchHigherOrderOperatorVariable, + ), + ) class ConstDictVariable(VariableTracker): @@ -106,47 +154,88 @@ class _HashableTracker: def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) - - # If Dynamo does not know the hashability of the vt, it will raise unsupported here + # TODO Temporarily remove to figure out what keys are we breaking on + # and add proper support for them if not is_hashable(vt): raise_unhashable(vt) self.vt = vt - def __hash__(self) -> int: - """ - Computes the hash value for the wrapped VariableTracker. - - For unrealized LazyVariableTrackers, uses the hash of the original value - to avoid realizing the tracker and inserting unnecessary guards. - For all other cases, delegates to the VariableTracker's get_python_hash method. - - Returns: - The hash value of the underlying variable tracker - """ + @property + def underlying_value(self) -> Any: if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() and self.vt.is_hashable() ): - return hash(self.vt.original_value()) - return self.vt.get_python_hash() - - def __eq__(self, other) -> bool: - """ - Checks equality between two _HashableTracker instances. + return self.vt.original_value() + if isinstance(self.vt, variables.TensorVariable): + x = self.vt.as_proxy().node.meta["example_value"] + elif isinstance(self.vt, variables.TupleVariable): + Hashable = ConstDictVariable._HashableTracker + x = tuple(Hashable(e).underlying_value for e in self.vt.items) + elif isinstance(self.vt, variables.NNModuleVariable): + return self.vt.value + elif isinstance(self.vt, variables.UnspecializedNNModuleVariable): + return self.vt.value + elif isinstance(self.vt, variables.UserFunctionVariable): + return self.vt.get_function() + elif isinstance(self.vt, variables.WeakRefVariable): + # Access the underlying value inside the referent_vt for the key representation + Hashable = ConstDictVariable._HashableTracker + return Hashable(self.vt.referent_vt).underlying_value + elif isinstance(self.vt, variables.FrozenDataClassVariable): + Hashable = ConstDictVariable._HashableTracker + fields_values = { + k: Hashable(v).underlying_value + for k, v in self.vt.fields.items() # type: ignore[attr-defined] + } + return variables.FrozenDataClassVariable.HashWrapper( + self.vt.python_type(), fields_values + ) + elif isinstance(self.vt, variables.UserDefinedObjectVariable): + # The re module in Python 3.13+ has a dictionary (_cache2) with + # an object as key (`class _ZeroSentinel(int): ...`): + # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual + return self.vt.value # type: ignore[attr-defined,union-attr] + elif isinstance(self.vt, variables.FunctoolsPartialVariable): + Hashable = ConstDictVariable._HashableTracker + items = (self.vt.func, *self.vt.args, *self.vt.keywords.values()) + x = tuple(Hashable(e).underlying_value for e in items) + return x + else: + x = self.vt.as_python_constant() + return x - Delegates to the VariableTracker's is_python_equal method to compare - the underlying variable trackers for Python-level equality. + def __hash__(self) -> int: + return hash(self.underlying_value) + + @staticmethod + def _eq_impl(a: Any, b: Any) -> bool: + # TODO: Put this in utils and share it between variables/builtin.py and here + type_a, type_b = type(a), type(b) + if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)): + return False + + if isinstance(a, tuple): + Hashable = ConstDictVariable._HashableTracker + return len(a) == len(b) and all( + Hashable._eq_impl(u, v) for u, v in zip(a, b) + ) + elif is_fake(a): + return a is b + else: + return a == b - Args: - other: Another _HashableTracker instance to compare with + def __eq__(self, other: object) -> bool: + Hashable = ConstDictVariable._HashableTracker + assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( + type(other) + ) + if isinstance(other, Hashable): + return Hashable._eq_impl(self.underlying_value, other.underlying_value) - Returns: - True if the underlying variable trackers are Python-equal, False otherwise - """ - if self.vt is other.vt: - return True - return self.vt.is_python_equal(other.vt) + # constant + return Hashable._eq_impl(self.underlying_value, other) def __init__( self, @@ -235,7 +324,7 @@ def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( - vt.is_python_hashable() + is_hashable(vt) and Hashable(vt) in self.items and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) @@ -447,6 +536,8 @@ def call_method( Hashable = ConstDictVariable._HashableTracker + arg_hashable = args and is_hashable(args[0]) + if name == "__init__": temp_dict_vt = variables.BuiltinVariable(dict).call_dict( tx, *args, **kwargs @@ -515,7 +606,6 @@ def call_method( self.install_dict_keys_match_guard() return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -530,21 +620,16 @@ def call_method( tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) - elif name == "__delitem__" and self.is_mutable(): - arg_hashable = args and is_hashable(args[0]) - if arg_hashable: - self.install_dict_keys_match_guard() - self.should_reconstruct_all = True - tx.output.side_effects.mutation(self) - self.items.__delitem__(Hashable(args[0])) - return ConstantVariable.create(None) - else: - return super().call_method(tx, name, args, kwargs) + elif name == "__delitem__" and arg_hashable and self.is_mutable(): + self.install_dict_keys_match_guard() + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.__delitem__(Hashable(args[0])) + return ConstantVariable.create(None) elif name == "get": if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -560,7 +645,6 @@ def call_method( if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -652,7 +736,6 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -668,7 +751,6 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -821,12 +903,6 @@ def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) - def is_python_hashable(self): - """ - Dictionaries are mutable and therefore not hashable in Python. - """ - return False - class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt @@ -1340,18 +1416,6 @@ def call_method( return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) - def is_python_hashable(self): - """ - Frozensets are immutable and hashable in Python. - """ - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class DictKeySetVariable(SetVariable): def debug_repr(self) -> str: @@ -1541,9 +1605,3 @@ def call_method( return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) return ConstantVariable.create(False) return super().call_method(tx, name, args, kwargs) - - def is_python_hashable(self): - """ - Dictionary item views are not hashable in Python. - """ - return False diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 360c0fdd94488..deee9bcec42de 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -807,15 +807,6 @@ def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: return collected return None - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.fn) - - def is_python_equal(self, other): - return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn - class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): _nonvar_fields = { @@ -1972,15 +1963,6 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return fn_var_getattr(tx, self.value, self.source, name) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class WrappedSkipFunctionVariable(SkipFunctionVariable): def __init__( @@ -2367,34 +2349,6 @@ def guard_as_python_constant(self) -> Any: **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, ) - def is_python_hashable(self) -> bool: - return ( - self.func.is_python_hashable() - and all(arg.is_python_hashable() for arg in self.args) - and all(value.is_python_hashable() for value in self.keywords.values()) - ) - - def get_python_hash(self): - func_hash = self.func.get_python_hash() - args_hash = (arg.get_python_hash() for arg in self.args) - values_hash = (value.get_python_hash() for value in self.keywords.values()) - return hash((func_hash, *args_hash, *values_hash)) - - def is_python_equal(self, other): - return ( - self.func.is_python_equal(other.func) - and all( - arg_a.is_python_equal(arg_b) - for (arg_a, arg_b) in zip(self.args, other.args) - ) - and all( - value_a.is_python_equal(value_b) - for (value_a, value_b) in zip( - self.keywords.values(), other.keywords.values() - ) - ) - ) - class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 8b178b3be1ac3..afb6522ac0e5c 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1738,15 +1738,6 @@ def _call_function( def as_python_constant(self): return self.value - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index a97c284f9516c..4f21e35479fb8 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -620,25 +620,6 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return self.items[fields.index(name)] return super().var_getattr(tx, name) - def is_python_hashable(self): - return True - - def get_python_hash(self): - l = self.range_length() - start = self.start() - step = self.step() - return hash((l, start, step)) - - def is_python_equal(self, other): - if not isinstance(other, variables.RangeVariable): - return False - - return ( - self.start() == other.start() - and self.step() == other.step() - and self.stop() == other.stop() - ) - class CommonListMethodsVariable(BaseListVariable): """ @@ -1000,9 +981,6 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) - def is_python_hashable(self): - return False - class DequeVariable(CommonListMethodsVariable): def __init__( @@ -1192,18 +1170,6 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) - def is_python_hashable(self): - return all(item.is_python_hashable() for item in self.items) - - def get_python_hash(self): - items = tuple(x.get_python_hash() for x in self.items) - return hash(items) - - def is_python_equal(self, other): - return isinstance(other, variables.TupleVariable) and all( - a.is_python_equal(b) for (a, b) in zip(self.items, other.items) - ) - class SizeVariable(TupleVariable): """torch.Size(...)""" diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 5bd8ad5d075e6..8d074f913dbf5 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1306,15 +1306,6 @@ def is_python_constant(self): def as_python_constant(self): return self.method_wrapper - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class GetSetDescriptorVariable(VariableTracker): def __init__(self, desc, **kwargs) -> None: @@ -1449,15 +1440,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # codegen.append_output(codegen.create_load_const(self.value)) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - @functools.lru_cache(maxsize=1) def get_np_to_tnp_map(): @@ -1636,15 +1618,6 @@ def as_proxy(self): return super().as_proxy() - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls class NullVariable(VariableTracker): @@ -2124,13 +2097,3 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(self.referent_vt) codegen(self.callback_vt) codegen.extend_output(create_call_function(2, False)) - - def is_python_hashable(self): - return self.referent_vt.is_python_hashable() - - def get_python_hash(self): - # weakref relies on the referent's hash - return self.referent_vt.get_python_hash() - - def is_python_equal(self, other): - return self.referent_vt.is_python_equal(other.referent_vt) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 548e69ef0262d..0787ef7c49b57 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1428,20 +1428,6 @@ def set_name_hint(self, name: str): self.proxy.node._rename(name) self._is_name_set = True - def is_python_hashable(self): - # Tensors are hashable if they have an example_value (a fake tensor) - # Most VT's should have one. - # It'd be nice if at some point we could assert that they all have one - return self.as_proxy().node.meta["example_value"] is not None - - def get_python_hash(self): - return hash(self.as_proxy().node.meta["example_value"]) - - def is_python_equal(self, other): - a = self.as_proxy().node.meta["example_value"] - b = other.as_proxy().node.meta["example_value"] - return a is b - class SymNodeVariable(VariableTracker): """ @@ -1530,20 +1516,6 @@ def call_method( ), ) - def is_python_hashable(self): - return True - - def get_python_hash(self): - # Essentially convert the SymNode to a constant variable whenever its - # searched for a dict key. - return hash(self.evaluate_expr()) - - def is_python_equal(self, other): - if isinstance(other, SymNodeVariable): - return self.evaluate_expr() == other.evaluate_expr() - # could be constant variable as well - return self.evaluate_expr() == other.as_python_constant() - class NumpyNdarrayVariable(TensorVariable): """ diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 78d87a09713ab..76da71f6fb323 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -2075,15 +2075,6 @@ def torch_function_override_enabled(self, tx, args, kwargs): ) ) and can_dispatch_torch_function(tx, args, kwargs) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class DispatchKeySetVariable(BaseTorchVariable): """represents torch.DispatchKeySet""" diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 9de51061cbe31..e87af5b87a75a 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -89,7 +89,6 @@ object_has_getattribute, proxy_args_kwargs, raise_args_mismatch, - raise_on_overridden_hash, set_methods, tensortype_to_dtype, tuple_methods, @@ -928,18 +927,6 @@ def const_getattr(self, tx: "InstructionTranslator", name): return self.value.__name__ return super().const_getattr(tx, name) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - return ( - isinstance(other, variables.UserDefinedClassVariable) - and self.value is other.value - ) - class UserDefinedExceptionClassVariable(UserDefinedClassVariable): @property @@ -1756,20 +1743,26 @@ def call_obj_hasattr( handle_observed_exception(tx) return variables.ConstantVariable.create(False) - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return True - def get_python_hash(self): - # default hash - return hash(self.value) +class FrozenDataClassVariable(UserDefinedObjectVariable): + class HashWrapper: + """This class is hashed if a dataclass is used as a key in a dict. + It's necessary to avoid side effects from calling the __init__ of the dataclass class when hashing""" - def is_python_equal(self, other): - # id check - return self.value is other.value + def __init__(self, c, fields): + self.cls = c + self.fields = tuple(fields.items()) + def __eq__(self, other): + return ( + type(self) is type(other) + and self.cls == other.cls + and self.fields == other.fields + ) + + def __hash__(self): + return hash((self.cls, self.fields)) -class FrozenDataClassVariable(UserDefinedObjectVariable): @staticmethod def create(tx, value, source): from dataclasses import fields @@ -1867,22 +1860,6 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value_type.__name__})" - def is_python_hashable(self): - # TODO - Check corner cases like eq=False, hash=False etc - return True - - def get_python_hash(self): - return hash(tuple(arg.get_python_hash() for arg in self.fields.values())) - - def is_python_equal(self, other): - is_class_same = self.python_type() is other.python_type() - is_field_name_same = self.fields.keys() == other.fields.keys() - is_field_value_same = all( - value_a.is_python_equal(value_b) - for value_a, value_b in zip(self.fields.values(), other.fields.values()) - ) - return is_class_same and is_field_name_same and is_field_value_same - class SourcelessGraphModuleVariable(UserDefinedObjectVariable): def __init__( From 7d2a33e4ebf60b217a3cd77feae19231eb996fc8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 1 Dec 2025 09:35:13 +0000 Subject: [PATCH 076/338] Revert "[Accelerator] Add Accelerator Capabilities API (#165631)" This reverts commit c8210e7d94bad5ae21ac389fa4ba8a463c76c4d0. Reverted https://github.com/pytorch/pytorch/pull/165631 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165631#issuecomment-3595505779)) --- aten/src/ATen/DeviceAccelerator.cpp | 6 -- aten/src/ATen/DeviceAccelerator.h | 5 -- c10/core/DeviceCapability.h | 74 ------------------- c10/core/impl/DeviceGuardImplInterface.h | 26 ------- c10/core/impl/VirtualGuardImpl.h | 4 - .../torch_openreg/csrc/runtime/OpenRegGuard.h | 9 --- .../torch_openreg/tests/test_device.py | 9 +-- torch/_C/__init__.pyi.in | 1 - torch/accelerator/__init__.py | 26 +------ torch/csrc/DeviceAccelerator.cpp | 19 ----- 10 files changed, 2 insertions(+), 177 deletions(-) delete mode 100644 c10/core/DeviceCapability.h diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index efab9ec9c5927..aa9d6e6b1ce9b 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -130,12 +130,6 @@ c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) { impl.uncheckedSetDevice({device_type, device_index}); return impl.getDevice().index(); } - -c10::DeviceCapability getDeviceCapability(c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - c10::impl::VirtualGuardImpl impl(device_type); - return impl.getDeviceCapability({device_type, device_index}); -} // NOLINTEND(bugprone-unchecked-optional-access) } // namespace at::accelerator diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index d24b42ca459e7..2cc4cff7cd1f2 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -74,10 +73,6 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); -// Get the device capability of the given device index. -TORCH_API c10::DeviceCapability getDeviceCapability( - c10::DeviceIndex device_index); - TORCH_API inline void emptyCache() { const auto device_type = getAccelerator(true).value(); at::getDeviceAllocator(device_type)->emptyCache(); diff --git a/c10/core/DeviceCapability.h b/c10/core/DeviceCapability.h deleted file mode 100644 index e24f12614978a..0000000000000 --- a/c10/core/DeviceCapability.h +++ /dev/null @@ -1,74 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace c10 { - -constexpr size_t NUMBER_OF_DEVICE_CAPABILITIES = NumScalarTypes; - -// Generate bitfields for each scalar type -#define DEFINE_SCALAR_TYPE(_1, n) unsigned int has_##n : 1; - -// Generate enum indices for each scalar type -#define DEFINE_SCALAR_ENUM(_1, name) kIndex_##name, - -enum ScalarTypeIndex { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_ENUM) -}; - -/** - * @brief DeviceCapability represents the the common capabilities that all - * devices should support. - * - * This struct provides a compact way to represent the common capabilities that - * all devices should support. Includes the following capabilities: - * - Supported data types - * - * Purpose - * - Enable device-specific optimizations based on supported capabilities - * - * Contract - * - * Supported data types: - * - Each bitfield represents support for one device capability - * - Bit value 1 means the capability is supported, 0 means not supported - * - The struct is initialized with all capabilities enabled by default - * - * @note Adding New Capabilities - * - * 1. Define the new capability in the `DeviceCapability` struct - * 2. Update the support of the new capability in each accelerator - * implementation - * 3. Add the new capability to the returned PyObject Dictionary - */ -struct C10_API DeviceCapability { - union { - struct { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE) - }; - uint64_t capability_bits; // Allow direct bit manipulation - }; - - // Default constructor with all capabilities enabled. - DeviceCapability() - : capability_bits((1ULL << NUMBER_OF_DEVICE_CAPABILITIES) - 1) {} - - // Iterate supported ScalarTypes without allocating a vector - template - void forEachSupportedScalarType(F&& visitor) const { -#define VISIT_SCALAR_TYPE(_1, n) \ - if (has_##n) { \ - visitor(ScalarType::n); \ - } - - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(VISIT_SCALAR_TYPE) - -#undef VISIT_SCALAR_TYPE - } -}; - -#undef DEFINE_SCALAR_ENUM -#undef DEFINE_SCALAR_TYPE -} // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index 00096584b9229..f9f67497c6315 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -192,15 +191,6 @@ struct C10_API DeviceGuardImplInterface { */ virtual DeviceIndex deviceCount() const noexcept = 0; - /** - * Get the following capabilities of the current device: - * (1) Data type support - * Returns DeviceCapability object. - */ - virtual DeviceCapability getDeviceCapability(Device /*unused*/) const { - TORCH_CHECK(false, "Backend doesn't support getting device capabilities."); - } - /** * Return true if all the work previously enqueued on the stream for * asynchronous execution has completed running on the device. @@ -301,22 +291,6 @@ struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface { return 1; } - DeviceCapability getDeviceCapability(Device /*unused*/) const override { - DeviceCapability cap; - if constexpr (D == DeviceType::Meta) { - cap.capability_bits = 0; - // Meta only supports basic types for shape inference - // Byte, Char, Short, Int, Long, Float, Double, - // Bool, ComplexFloat, ComplexDouble - cap.capability_bits = (1ULL << kIndex_Byte) | (1ULL << kIndex_Char) | - (1ULL << kIndex_Short) | (1ULL << kIndex_Int) | - (1ULL << kIndex_Long) | (1ULL << kIndex_Float) | - (1ULL << kIndex_Double) | (1ULL << kIndex_ComplexFloat) | - (1ULL << kIndex_ComplexDouble) | (1ULL << kIndex_Bool); - } - return cap; - } - // Event-related functions void record( void** /*event*/, diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 0254c69baba00..3d259f5e390e3 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -57,10 +57,6 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { return impl_->deviceCount(); } - DeviceCapability getDeviceCapability(Device d) const override { - return impl_->getDeviceCapability(d); - } - // Event functions void record( void** event, diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h index 3c1c1193d3cdb..59bc2d5cdbff5 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -51,14 +50,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { return c10::Device(static_type, device_index); } - /** - * Get the device capability for a given device. - * By default, OpenReg has 2 same devices with the same capability. - */ - c10::DeviceCapability getDeviceCapability(c10::Device /*unused*/) const override { - return c10::DeviceCapability(); - } - /** * Set the current device to c10::Device. */ diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py index 9cb4a785d36e7..f925f15600ce7 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py @@ -1,7 +1,7 @@ # Owner(s): ["module: PrivateUse1"] import torch -from torch.testing._internal.common_dtype import get_all_dtypes +import torch_openreg # noqa: F401 from torch.testing._internal.common_utils import run_tests, TestCase @@ -31,13 +31,6 @@ def test_invalid_device_index(self): with self.assertRaisesRegex(RuntimeError, "The device index is out of range"): torch.accelerator.set_device_index(2) - def test_device_capability(self): - capability = torch.accelerator.get_device_capability("openreg:0") - supported_dtypes = capability["supported_dtypes"] - expected_dtypes = get_all_dtypes(include_complex32=True, include_qint=True) - - self.assertTrue(all(dtype in supported_dtypes for dtype in expected_dtypes)) - if __name__ == "__main__": run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 01c4abd6fab76..e9b58b9ce71eb 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2493,7 +2493,6 @@ def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails def _accelerator_getAccelerator() -> _device: ... def _accelerator_setDeviceIndex(device_index: _int) -> None: ... def _accelerator_getDeviceIndex() -> _int: ... -def _accelerator_getDeviceCapability(device_index: _int) -> dict[str, Any]: ... def _accelerator_setStream(Stream) -> None: ... def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index a1335d2ad03bd..e1a82aa63ce22 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -2,8 +2,7 @@ This package introduces support for the current :ref:`accelerator` in python. """ -from functools import cache -from typing import Any +from typing import Optional from typing_extensions import deprecated import torch @@ -26,7 +25,6 @@ "current_accelerator", "current_device_idx", # deprecated "current_device_index", - "get_device_capability", "current_stream", "device_count", "device_index", @@ -154,28 +152,6 @@ def current_device_index() -> int: """ -@cache -def get_device_capability(device: _device_t = None, /) -> dict[str, Any]: - r"""Return the capability of the currently selected device. - - Args: - device (:class:`torch.device`, str, int, optional): The device to query capabilities for - :ref:`accelerator` device type. If not given, - use :func:`torch.accelerator.current_device_index` by default. - - Returns: - dict[str, Any]: A dictionary containing device capability information. The dictionary includes: - - ``supported_dtypes`` (set(torch.dtype)): Set of PyTorch data types supported by the device - - Examples: - >>> # Query capabilities for current device - >>> capabilities = torch.accelerator.get_device_capability("cuda:0") - >>> print("Supported dtypes:", capabilities["supported_dtypes"]) - """ - device_index = _get_device_index(device, optional=True) - return torch._C._accelerator_getDeviceCapability(device_index) - - def set_device_index(device: _device_t, /) -> None: r"""Set the current device index to a given device. diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index c6ffa893d95ae..14e54851178f5 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -33,25 +33,6 @@ void initModule(PyObject* module) { return at::accelerator::getDeviceIndex(); }); - m.def("_accelerator_getDeviceCapability", [](c10::DeviceIndex device_index) { - const auto device_type = at::accelerator::getAccelerator(true).value(); - torch::utils::maybe_initialize_device(device_type); - auto caps = at::accelerator::getDeviceCapability(device_index); - - py::dict dict; - - py::set dtype_set; - caps.forEachSupportedScalarType([&](c10::ScalarType dtype) { - THPDtype* thp_dtype = torch::getTHPDtype(dtype); - py::object dtype_obj = - py::reinterpret_borrow((PyObject*)thp_dtype); - dtype_set.add(dtype_obj); - }); - - dict["supported_dtypes"] = dtype_set; - return dict; - }); - m.def("_accelerator_setStream", [](c10::Stream stream) { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); From 29e5455a4740c326ab187c7aa7b5ef98034ea563 Mon Sep 17 00:00:00 2001 From: kundaMwiza Date: Mon, 1 Dec 2025 09:43:13 +0000 Subject: [PATCH 077/338] [inductor] Add option to transpose tensor descriptors on load / store (#165541) Currently block descriptor index matches fail being represented by tensor descriptors if: - the innermost stride is not 1 - Outer strides are 16 byte aligned This PR adds a transpose option to increase the possibility for more tensor descriptor matches by first reordering block parameters so that dimensions are in descending stride order, and then inserting transpose operations after load / before store operations. This can be considered a general follow up for the ND case to https://github.com/pytorch/pytorch/pull/160480#discussion_r2342953799 which added transpose support for 2D tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165541 Approved by: https://github.com/njriasan, https://github.com/jansel --- test/inductor/test_max_autotune.py | 9 +- .../test_torchinductor_strided_blocks.py | 47 +- torch/_inductor/codegen/triton.py | 414 +++++++++++++----- torch/_inductor/config.py | 5 + torch/_inductor/runtime/triton_heuristics.py | 2 +- torch/_inductor/select_algorithm.py | 48 +- torch/_inductor/template_heuristics/triton.py | 1 + torch/_inductor/utils.py | 22 +- 8 files changed, 419 insertions(+), 129 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index db34336aeda99..b1c4d1b61659a 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1947,7 +1947,7 @@ def test_triton_template_generated_code_cache_key(self): # Make sure all args of generate_and_load_args are passed to make_key_args (Except generate_with_caching) # update this function each time new arg added to generate_and_load and make sure arg is added to make_key self.assertEqual(generate_and_load_args - 1, make_key_args) - self.assertEqual(generate_and_load_args, 18) + self.assertEqual(generate_and_load_args, 19) @fresh_cache() @config.patch( @@ -2036,6 +2036,7 @@ def func_test1(x, y, z, m): 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[10,30], 'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]", 'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False, + 'transpose_discontiguous_tensor_descriptors_override':None, 'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32', 'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True},'hint_override':None}""" @@ -2075,8 +2076,10 @@ def func_test1(x, y, z, m): "[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"], 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94], 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0, - 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False,'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False, - 'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True},'hint_override':None}""" + 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False, + 'transpose_discontiguous_tensor_descriptors_override':None, + 'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32, + 'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True},'hint_override':None}""" expected = expected.replace("cuda", GPU_TYPE) self.assertExpectedInline( remove_white_space(cache_key), diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index d70375ebc3345..bea7b667ccd78 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -81,7 +81,6 @@ def xfail_if_use_tensor_descriptor(fn): "test_broadcast_prefer_nd_tiling_False_x_size2_y_size2", "test_broadcast_prefer_nd_tiling_True_x_size0_y_size0", "test_broadcast_prefer_nd_tiling_True_x_size2_y_size2", - "test_broadcast_with_singleton_dims", ), TMA_XFAIL, ) @@ -168,8 +167,6 @@ def count_code(substr: str, expected: Optional[int]): self.assertEqual(len(code), expected_num_programs) count_code("@triton.jit", expected_num_triton_kernels) count_code(self.block_descriptor_constructor_str, expected_num_block_pointers) - # Verify that 1D shapes aren't being transposed for the TMA store. - count_code("tl.trans", 0) return result, code @@ -912,7 +909,6 @@ def test_reduction_multiple_discontiguous_dims(self): msg="AssertionError: Scalars are not equal!, " "https://github.com/intel/torch-xpu-ops/issues/2332" ) - @xfail_if_use_tensor_descriptor # Cannot use TMA API for store with no x dimension. @test_torchinductor.skip_if_triton_cpu # Illegal instruction File; cannot xfail because it crashes process def test_2d_reduction_multi_kernel(self): """ @@ -1023,7 +1019,6 @@ def test_enable_tiled_reductions(self, tile_reductions: bool): # Check the code for multiple Rn_BLOCK's self._assert_reduction_ndims(code, 2 if tile_reductions else 1) - @xfail_if_use_tensor_descriptor def test_complex_reshape_block_ptr(self): def func(x, y): add_ = x + y @@ -1242,7 +1237,6 @@ def foo(x, y, z): # dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0 # } # This is now fixed by ensuring that that wild symbols only match integers - @xfail_if_use_tensor_descriptor @skipIfXpu( msg="Triton issue exposed by new driver, will be resolved after next triton update." ) @@ -1412,10 +1406,51 @@ class TritonBlockPointerTestGPU(BlockDescriptorTestBase): "Requires Triton CUDA backend and CUDA compute capability >= 9.0", ) @config.patch({"triton.use_tensor_descriptor": True, "assume_aligned_inputs": True}) +@instantiate_parametrized_tests class TritonTensorDescriptorTestCUDA(BlockDescriptorTestBase): block_descriptor_constructor_str = "tl.make_tensor_descriptor" device = GPU_TYPE + @config.patch({"triton.transpose_discontiguous_tensor_descriptor": True}) + @parametrize( + "view_size,permute_order,num_tensor_descriptors,expect_transpose", + [ + ((128,), (0,), 3, False), + ((128, 128), (0, 1), 3, False), + ((128, 64), (1, 0), 3, True), + ((256, 32, 16), (2, 0, 1), 3, True), + ((16, 32, 256), (2, 0, 1), 3, True), + ], + ) + def test_match_with_transpose( + self, + view_size: tuple[int], + permute_order: tuple[int], + num_tensor_descriptors: int, + expect_transpose: bool, + ): + a = self._discontiguous_tensor(view_size, self.device) + pre_permute_size = [1] * len(view_size) + for i, value in zip(permute_order, view_size): + pre_permute_size[i] = value + b = self._discontiguous_tensor(pre_permute_size, self.device) + b = b.permute(permute_order) + + def fn(a, b): + return a * b + + result, (code,) = self._run_and_compare( + fn, + a, + b, + expected_num_block_pointers=num_tensor_descriptors, + expected_num_triton_kernels=1, + config_patches=tiled_reduction_config, + ) + + transpose_count = code.count("tl.trans") + self.assertEqual(transpose_count, 1 if expect_transpose else 0) + test_torchinductor.copy_tests( CommonTemplate, diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 9b718f0c780c1..cba36a25aad8d 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -11,9 +11,10 @@ import operator import os import textwrap +from abc import abstractmethod from collections.abc import Callable, Iterable, Sequence from functools import lru_cache -from typing import Any, cast, Optional, TYPE_CHECKING, Union +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union import sympy from sympy.printing.precedence import PRECEDENCE @@ -30,7 +31,7 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges -from .. import config, ir, metrics +from .. import config, ir, metrics, utils from ..async_compile import AsyncCompile from ..codecache import code_hash, get_path, PyCodeCache, write_atomic from ..debug import set_kernel_post_grad_provenance_tracing @@ -105,9 +106,9 @@ if TYPE_CHECKING: from types import ModuleType - from typing import TypeVar from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + from torch.fx.experimental.symbolic_shapes import ShapeEnv from ..ir import IRNode from .common import BlockShapeType @@ -273,14 +274,6 @@ def get_block_shape(cls, expr: sympy.Expr) -> BlockShapeType: assert expr_shape is not None - # Below logic handles when index symbols does not match with convention range tree order. - # Mainly, it is for TMA template where TMA indices are expected to be in (x,y), not (y,x). - # so in such case, the get_block_shape(yindex) should be (1,YBLOCK), not (YBLOCK,1). - if isinstance(V.kernel, torch._inductor.select_algorithm.TritonTemplateKernel): - out_shape = V.kernel.template_out_shape - if out_shape == ("XBLOCK", "YBLOCK") and V.kernel.tma_store: - expr_shape = (expr_shape[1], expr_shape[0], *expr_shape[2:]) - return expr_shape @classmethod @@ -341,6 +334,10 @@ class BlockDescriptorOptions: broadcast_shape: Sequence[sympy.Expr] broadcasting_dims: list[bool] final_shape: Sequence[sympy.Expr] + # If the BlockParameters have been sorted using a particular stride order + # transpose load / store blocks at runtime using the information in + # stride_sorter. + stride_sorter: BlockParameters.StrideSorter _boundary_check: Optional[list[int]] = None # Can we safely lift the constructor # to the top of the kernel? @@ -371,8 +368,8 @@ def create( range_trees: list[IterationRangesRoot], mask_vars: OrderedSet[str], get_max_block: Callable[[str], int], - can_lift=False, - transpose_contiguous=False, + stride_sorter_cls: type[BlockParameters.StrideSorter], + can_lift: bool = False, ) -> BlockDescriptorOptions: """Helper to create a BlockDescriptorOptions instance""" @@ -385,14 +382,10 @@ def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]: params.shape = lookup_size(params.shape) params.strides = lookup_size(params.strides) - # Strip out dimensions of stride 0. - # These will be restored with tl.broadcast_to. - broadcasting_dims = [ - sizevars.statically_known_equals(stride, 0) for stride in params.strides - ] - # Strip out dimensions of size 1. - # These will be restored by tl.reshape. + # Size 1 dimensions are redundant since the triton kernel shape + # will be e.g. [YBLOCK, XBLOCK], so tl.reshape would just remove these + # dimensions anyway singleton_dims = [ sizevars.statically_known_equals(dim, 1) for dim in params.block_shape ] @@ -400,44 +393,28 @@ def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]: # Handle a pure singletons, e.g. [1, 1] singleton_dims[-1] = False - # Record the post-broadcast shape before broadcasting dims are removed. - # The pre-broadcast shape is identical to this, except broadcasting dims are - # replaced with 1. - broadcast_shape = [ - dim - for dim, is_singleton in zip(params.block_shape, singleton_dims) - if not is_singleton - ] + # Drop singleton dimensions from the block descriptor. + params = params.remove_dims(singleton_dims) - # Combine all removable dims. - removable_dims = [any(dims) for dims in zip(singleton_dims, broadcasting_dims)] + # Maybe reorder dimensions based on strides + # with tl.trans applied at load / store time + params, stride_sorter = params.maybe_sort_with_stride_order( + stride_sorter_cls=stride_sorter_cls, shape_env=V.graph._shape_env + ) - # Remove singleton_dims from broadcasting_dims so that - # broadcast_shape and broadcasting_dims have the same length + # Strip out dimensions of stride 0. + # These will be restored with tl.broadcast_to. broadcasting_dims = [ - dim - for dim, is_singleton in zip(broadcasting_dims, singleton_dims) - if not is_singleton + sizevars.statically_known_equals(stride, 0) for stride in params.strides ] - def remove_dims(it): - """Removes any broadcasting or singleton dims from a given sequence""" - return [ - item - for item, is_removable in zip(it, removable_dims) - if not is_removable - ] + # Record the post-broadcast shape before broadcasting dims are removed. + # The pre-broadcast shape is identical to this, except broadcasting dims are + # replaced with 1. + broadcast_shape = params.block_shape - # Drop removable dimensions from the input. - params = BlockParameters( - **{ - key: remove_dims(val) for key, val in dataclasses.asdict(params).items() - }, - ) - # TODO: Generalize to ND tensors. - transpose = transpose_contiguous and params.strides[-1] != 1 - if transpose: - params = params.transpose() + # Drop broadcasting dims from the block descriptor. + params = params.remove_dims(broadcasting_dims) # Compute the final shape, adjusting for special kernel types. final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees] @@ -445,12 +422,6 @@ def remove_dims(it): assert range_trees[0].prefix == "x" final_shape.pop(0) - # Check for when BlockParams have been transposed. - order = list(reversed(range(len(params.shape)))) - if transpose: - final_shape.reverse() - order.reverse() - reduction_ndim = V.kernel.num_reduction_dims if ( not V.kernel.inside_reduction @@ -460,6 +431,14 @@ def remove_dims(it): # Need to expand rank to match the rank used inside the reduction loop final_shape += [sympy.S.One] * reduction_ndim + try: + # Get permutation to sort strides in ascending order. + # This is used as the order argument in tl.make_block_ptr + order = utils.argsort_sym(V.graph._shape_env, params.strides) + except AssertionError: + # Symbolic shapes, failed to evaluate comparison expression + order = list(reversed(range(len(params.strides)))) + result = cls( params=params, constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), @@ -468,6 +447,7 @@ def remove_dims(it): final_shape=final_shape, broadcast_shape=broadcast_shape, broadcasting_dims=broadcasting_dims, + stride_sorter=stride_sorter, can_lift=can_lift, ) result.compute_boundary_check(get_max_block, range_trees) @@ -567,21 +547,55 @@ def codegen_broadcast_and_reshape( initial_shape: Sequence[sympy.Expr], final_shape: Sequence[sympy.Expr], allow_implicit: bool, + for_store: bool, ) -> str: """ Generate a broadcast and a reshape for the block descriptor. This restores stride-0 dimensions which were removed from the block descriptor. + + Transposes are also applied to the input using self.stride_sorter: + if for_store is True: + - First Broadcast the value. Since self.broadcast_shape is stored in + descending stride order, it must be reverted to the original order + since the input value does not have dims with descending strides + - After, transpose the broadcasted value so that dimensions are in + descending stride order + - Finally reshape to the block shape + else (for load): + - First broadcast the value to self.broadcast_shape (strides are descending) + - Then transpose the value so that dimensions no longer have descending strides + - Finally reshape the block to the final kernel tile shape """ + broadcast_shape = self.broadcast_shape + broadcasting_dims = self.broadcasting_dims + + # If the block parameters have been sorted by descending strides, + # permute the broadcasting parameters so that they are compatible + # with the value being stored. This is because the dimensions + # of the value being stored are not sorted in descending stride order, + # but the broadcasting parameters are based on the dims in sorted order + if for_store: + broadcast_shape = self.stride_sorter.revert(self.broadcast_shape) + broadcasting_dims = self.stride_sorter.revert(self.broadcasting_dims) # Reshape to add singletons. pre_broadcast_shape = [ sympy.S.One if is_broadcasting else dim - for dim, is_broadcasting in zip( - self.broadcast_shape, self.broadcasting_dims - ) + for dim, is_broadcasting in zip(broadcast_shape, broadcasting_dims) ] value = triton_reshape(value, initial_shape, pre_broadcast_shape) + if ( + not self.stride_sorter.is_identity + and not for_store + and len(pre_broadcast_shape) == len(final_shape) + ): + # If all we need to do is transpose to match the final shape + # with implicit broadcasting then we don't need an explicit broadcast + # unless the caller requests it. So just test implicit broadcast support + # with the transposed pre broadcast shape + pre_broadcast_shape = self.stride_sorter.revert(pre_broadcast_shape) + # Broadcast singletons. # For loads, we can often implicitly broadcast singleton dimensions. # We need an explicit broadcast for stores, or if the final reshape does more @@ -597,10 +611,32 @@ def codegen_broadcast_and_reshape( ) if any(self.broadcasting_dims) and not supports_implicit_broadcast: - value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})" + value = ( + f"tl.broadcast_to({value}, {V.kernel.index_to_str(broadcast_shape)})" + ) + + old_shape = self.broadcast_shape + if not self.stride_sorter.is_identity: + # if for_store the transform is + # (non-descending strides) broadcasted kernel tile shape + # -> (descending strides) block descriptor shape + # o/w if loading the transform is + # (descending strides) ((maybe implicitly) broadcasted block shape + # -> (non-descending) (maybe implicitly) broadcasted kernel tile shape + permute_dims = ( + self.stride_sorter.sort_idx + if for_store + else self.stride_sorter.revert_sort_idx + ) + value = f"tl.trans({value}, {permute_dims})" + old_shape = ( + self.broadcast_shape + if for_store + else self.stride_sorter.revert(self.broadcast_shape) + ) # Reshape to the final shape. - value = triton_reshape(value, self.broadcast_shape, final_shape) + value = triton_reshape(value, old_shape, final_shape) return value @@ -1984,6 +2020,99 @@ class BlockParameters: strides: list[sympy.Expr] = dataclasses.field(default_factory=list) offsets: list[sympy.Expr] = dataclasses.field(default_factory=list) + @dataclasses.dataclass + class StrideSorter: + original_strides: list[int] + sort_idx: list[int] + revert_sort_idx: list[int] = dataclasses.field(init=False) + + def __post_init__(self): + assert len(self.original_strides) > 0 + assert len(self.sort_idx) == len(self.original_strides) + + identity_sort_idx = list(range(len(self.original_strides))) + self._is_identity = self.sort_idx == identity_sort_idx + + # Set revert_sort_idx + sorted_dims_by_strides_map = {k: i for i, k in enumerate(self.sort_idx)} + self.revert_sort_idx = [ + sorted_dims_by_strides_map[i] + for i in range(len(sorted_dims_by_strides_map)) + ] + + @property + def is_identity(self): + return self._is_identity + + @classmethod + @abstractmethod + def create( + cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv + ) -> BlockParameters.StrideSorter: + """Create a `StrideSorter` that can be used to sort block parameters.""" + + def sort(self, attr): + if not self.is_identity: + return [attr[i] for i in self.sort_idx] + return attr + + def revert(self, attr): + if not self.is_identity: + return [attr[i] for i in self.sort_idx] + return attr + + @dataclasses.dataclass + class IdentityStrideSorter(StrideSorter): + def __post_init__(self): + super().__post_init__() + + @classmethod + def create( + cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv + ) -> BlockParameters.StrideSorter: + return cls( + original_strides=original_strides, + sort_idx=list(range(len(original_strides))), + ) + + @dataclasses.dataclass + class TensorDecriptorStrideSorter(StrideSorter): + """ + Sorts BlockParameters dimensions with strides in descending order. + """ + + def __post_init__(self): + super().__post_init__() + + @classmethod + def create( + cls, original_strides: list[Union[int, sympy.Expr]], shape_env: ShapeEnv + ) -> BlockParameters.StrideSorter: + """ + If the strides are not all known constants or if the strides are already + sorted in descending order, return identity sort. + + For example if block_shape @ strides is [ZBLOCK, XBLOCK, YBLOCK] @ [8, 1, 16] + The indices to sort the strides in descending order will be [2, 0, 1]. + The indices to revert back to the original order will be [1, 2, 0]. + """ + identity_sort = list(range(len(original_strides))) + try: + # TODO: even if the strides are not in descending order the strides + # may be tensor descriptor compliant + # i.e. innermost stride == 1 and outer strides 16 byte aligned + # We should benchmark the effect of applying a transpose to these + # cases vs leaving them unsorted. + sort_idx = utils.argsort_sym(shape_env, original_strides, reverse=True) + except AssertionError: + # Symbolic shapes, failed to evaluate comparison expression + sort_idx = identity_sort + + return cls( + original_strides=original_strides, + sort_idx=sort_idx, + ) + def __add__(self, other: BlockParameters) -> BlockParameters: """ Concatenates block parameters. @@ -1992,12 +2121,37 @@ def __add__(self, other: BlockParameters) -> BlockParameters: a, b = tuple(dataclasses.asdict(x) for x in (self, other)) return cls(**{key: a[key] + b[key] for key in a}) - def transpose(self) -> BlockParameters: + def maybe_sort_with_stride_order( + self, stride_sorter_cls: type[StrideSorter], shape_env: ShapeEnv + ) -> tuple[BlockParameters, BlockParameters.StrideSorter]: + """ + Sort `BlockParameter` with stride_sorter_cls. Returns block parameters + as well as a `StrideSorter` which contains information on how the sort + can be reverted. + """ + stride_sorter = stride_sorter_cls.create(self.strides, shape_env=shape_env) + params = BlockParameters( + **{ + key: stride_sorter.sort(val) + for key, val in dataclasses.asdict(self).items() + } + ) + return params, stride_sorter + + def remove_dims(self, removable_dims: list[bool]) -> BlockParameters: + """ + Remove dimensions where removable_dims is True. + """ + + def filter_dims(it): + return [ + item + for item, is_removable in zip(it, removable_dims) + if not is_removable + ] + return BlockParameters( - self.shape[::-1], - self.block_shape[::-1], - self.strides[::-1], - self.offsets[::-1], + **{key: filter_dims(val) for key, val in dataclasses.asdict(self).items()}, ) @@ -2131,8 +2285,9 @@ def are_block_parameters_compatible( # and that the outer strides are 16 byte aligned if not V.graph.sizevars.statically_known_equals(strides[-1], sympy.Integer(1)): log.debug( - "%s TMA API requires innermost stride to be 1.", + "%s TMA API requires innermost stride to be 1. Strides are: %s", self.failed_debug_prefix, + strides, ) return False @@ -2143,8 +2298,10 @@ def are_block_parameters_compatible( sympy.Integer(0), ): log.debug( - "%s TMA API requires outer strides to be 16 byte aligned.", + "%s TMA API requires outer strides to be 16 byte aligned. Dtype bytes: %d, strides: %s", self.failed_debug_prefix, + element_size, + strides, ) return False @@ -2153,6 +2310,18 @@ def are_block_parameters_compatible( # can be loaded / stored. # Start with finding the innermost block type innermost_block_shape = block_params.block_shape[-1] + + # Pure singleton case + if V.graph.sizevars.statically_known_equals( + innermost_block_shape, sympy.Integer(1) + ): + log.debug( + "%s innermost block shape cannot load 16 bytes. Block shape: %s", + self.failed_debug_prefix, + block_params.block_shape, + ) + return False + innermost_block_type = None innermost_block_symt = None for block_type_str in innermost_block_shape.free_symbols: @@ -2161,6 +2330,7 @@ def are_block_parameters_compatible( innermost_block_type = block_type_str innermost_block_symt = block_symt break + assert innermost_block_type and innermost_block_symt, ( f"{innermost_block_shape} expr must contain a single block type from {TritonSymbols.block_types}" ) @@ -2189,8 +2359,10 @@ def are_block_parameters_compatible( innermost_block_bytes, sympy.Integer(16) ): log.debug( - "%s persistent reduction innermost block shape cannot load 16 bytes.", + "%s persistent reduction innermost block shape cannot load 16 bytes. Block shape: %s, persistent RBLOCK: %d", self.failed_debug_prefix, + block_params.block_shape, + persistent_rblock, ) return False @@ -2199,17 +2371,45 @@ def are_block_parameters_compatible( # then the TMA API can only be used if the dtype has an 8 byte element # size so that 16 bytes of data can be loaded in the innermost dimension try: + + def indexing_div_rep( + x: sympy.Expr, + y: sympy.Expr, + z: Optional[sympy.Expr] = None, + ) -> sympy.Expr: + div = x / y + if z: + div = div % z + return div + + solve_expr = innermost_block_shape * element_size - 16 + # Sympy cannot handle FloorDiv and ModularIndexing well, so simplify + solve_expr_simplified = solve_expr.replace( + FloorDiv, indexing_div_rep + ).replace(ModularIndexing, indexing_div_rep) min_block_size = next_power_of_2( int( sympy.nsolve( - innermost_block_shape * element_size - 16, + solve_expr_simplified, innermost_block_type, 1, ) ) ) - block_type_str = V.kernel.index_to_str(innermost_block_type) + # TODO: min block size may be too large / introduce redundancy + if min_block_size > self.kernel.max_block( + prefix_str[innermost_block_symt] + ): + log.debug( + "%s the minimum block size to satisfy expression %s is too large: %d", + self.failed_debug_prefix, + solve_expr_simplified, + min_block_size, + ) + return False + + block_type_str = self.kernel.index_to_str(innermost_block_type) # Check block sizes if the user has provided a fixed triton config if self.kernel.fixed_config: if min_block_size > self.kernel.fixed_config[block_type_str]: @@ -2232,8 +2432,9 @@ def are_block_parameters_compatible( except ValueError: log.debug( - "%s innermost block shape cannot load 16 bytes.", + "%s innermost block shape cannot load 16 bytes. Block params: %s", self.failed_debug_prefix, + block_params.block_shape, ) return False @@ -2262,6 +2463,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): kexpr: Callable[[sympy.Expr], str] = texpr allow_block_ptr = True tma_compatibility_checker_cls = TMACompatibilityChecker + transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None def __init__( self, @@ -2732,17 +2934,39 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: else TensorDescriptorOptions ) nonlocal tma_compatibility_checker + stride_sorter_cls: type[BlockParameters.StrideSorter] if config.triton.use_block_ptr: can_lift = False - transpose_contiguous = False + stride_sorter_cls = BlockParameters.IdentityStrideSorter else: tma_compatibility_checker = cast( TMACompatibilityChecker, tma_compatibility_checker ) can_lift = tma_compatibility_checker.can_lift() + + if ( + self.transpose_discontiguous_tensor_descriptors_override + is not None + ): + transpose_contiguous = ( + self.transpose_discontiguous_tensor_descriptors_override + ) + else: + transpose_contiguous = ( + config.triton.transpose_discontiguous_tensor_descriptor + ) + + # For templates: # Only try transpose if we know the output shape # in case we need to transpose the data. - transpose_contiguous = copy_shape is not None + if hasattr(self, "template_out_shape"): + transpose_contiguous &= copy_shape is not None + + stride_sorter_cls = ( + BlockParameters.TensorDecriptorStrideSorter + if transpose_contiguous + else BlockParameters.IdentityStrideSorter + ) options = options_class.create( params=block_params, @@ -2751,7 +2975,7 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: mask_vars=mask_vars, get_max_block=self.max_block, can_lift=can_lift, - transpose_contiguous=transpose_contiguous, + stride_sorter_cls=stride_sorter_cls, ) if options_class == TensorDescriptorOptions: tma_compatibility_checker = cast( @@ -3001,30 +3225,6 @@ def codegen_block_ptr( return block_descriptor, other def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""): - def stringify_shape(shape): - return tuple( - symt.name if isinstance(symt, sympy.Symbol) else str(symt) - for symt in shape - ) - - if value.shape: - value_forward_shape = stringify_shape(value.shape) - value_reverse_shape = stringify_shape(value.shape[::-1]) - else: - value_forward_shape = None - value_reverse_shape = None - final_shape = stringify_shape(indexing.final_shape) - # TODO: Generalize to N Dimensions - if ( - value_forward_shape != final_shape - and value_reverse_shape == final_shape - and len(final_shape) == 2 - ): - # TMA stores may require transposing the data to ensure we are contiguous along - # the final dimension. This applies to Block-pointers generally, but should only practically - # be reached with TMA. - value = f"tl.trans({value})" - # Stores require an explicit broadcast. We do this in two phases: # 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK, # YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads. @@ -3040,7 +3240,11 @@ def stringify_shape(shape): indexing.broadcasting_dims[idx] = False value = indexing.codegen_broadcast_and_reshape( - value, indexing.final_shape, indexing.block_shape, False + value, + indexing.final_shape, + indexing.block_shape, + allow_implicit=False, + for_store=True, ) # workaround https://github.com/triton-lang/triton/issues/2814 @@ -3232,7 +3436,11 @@ def decide_later(): else: line = f"{block_descriptor}.load({V.kernel.index_to_str(indexing.offsets)})" line = indexing.codegen_broadcast_and_reshape( - line, indexing.block_shape, indexing.final_shape, True + line, + indexing.block_shape, + indexing.final_shape, + allow_implicit=True, + for_store=False, ) shape = indexing.final_shape elif is_sympy_integer_like(original_index): diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 7048990692da0..7ba93575ce8bf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1592,6 +1592,11 @@ class triton: # can be satisfied, along with any existing requirements for index expressions use_tensor_descriptor = False + # (Experimental) + # Whether to allow reordering tensor descriptor matches with descending + # strides, at the expense of transposing values after load / before store. + transpose_discontiguous_tensor_descriptor = True + # Inject a bug into our relu implementation; useful for testing our repro # extraction and minification functionality. # Valid values: "compile_error", "runtime_error", "accuracy" diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 175bf76bfc740..ce3cd317934fe 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2571,7 +2571,7 @@ def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Conf if inductor_meta.get("persistent_reduction"): tma_min_block_sizes = { block_type: block_size - for block_type, block_size in tma_min_block_sizes + for block_type, block_size in tma_min_block_sizes.items() if not prefix_is_reduction(block_type.lower()) } diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 493ca1179fad8..eb1bbf42f8c37 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -390,6 +390,7 @@ def __init__( num_buffers_warp_spec=0, use_jit=False, tma_store=False, + transpose_discontiguous_tensor_descriptors_override=None, prefix_args=0, suffix_args=0, epilogue_fn=identity, @@ -420,6 +421,29 @@ def __init__( features=SIMDKernelFeatures([], numel), hint_override=hint_override, ) + if tma_store: + # By default `construct_range_trees` will return the range_trees in the order + # ["z", "y", "x", "r0_", "r1_"] (see simd.py:all_prefixes) + # and this order defines what the kernel block shape will be. So if the template + # input / output has requested e.g. ["x", "y"], `construct_range_trees` will still return the + # trees in the order ["y", "x"]. This would mean that the template would need to transpose + # the loaded value. + # The below sorts the range trees according to that required by the caller + prefix_to_range_tree = {rt.prefix: rt for rt in self.range_trees} + pw_sorted_range_trees = [] + reduction_idx = None + for i, prefix in enumerate(tiling): + rt = prefix_to_range_tree[prefix] + # pyrefly: ignore # missing-argument + if rt.is_reduction: + reduction_idx = i + break + rt.index = i + rt.grid_dim = i + rt.tensor_dim = i + pw_sorted_range_trees.append(rt) + self.range_trees = pw_sorted_range_trees + self.range_trees[reduction_idx:] + self.input_nodes = input_nodes self.output_node = output_node self.named_input_nodes = {} # type: ignore[var-annotated] @@ -427,6 +451,9 @@ def __init__( self.kernel_name = kernel_name self.use_jit = use_jit self.tma_store = tma_store + self.transpose_discontiguous_tensor_descriptors_override = ( + transpose_discontiguous_tensor_descriptors_override + ) self.num_stages = num_stages self.num_warps = num_warps self.num_consumer_groups = num_consumer_groups @@ -1169,13 +1196,8 @@ def store_output( intermediate_lines: list[str] = [] epilogue_index_symbols: list[sympy.Symbol] = [] if self.tma_store: - # Generate the expected indexing symbols. - # Note: TMA indices are expected to be in the - # format (x, y), but the range_tree is always - # (yindex, xindex). - index_order = [1, 0] val_shape_copy = list(val_shape) - for i, range_tree in zip(index_order, self.range_trees[:-1]): + for i, range_tree in enumerate(self.range_trees[:-1]): name = range_tree.name symbol = range_tree.symbol() epilogue_index_symbols.append(symbol) @@ -1196,7 +1218,7 @@ def store_output( index_symbols[i], val_shape[i], i, - len(index_order), + len(val_shape), # pyrefly: ignore [missing-argument] block_name=range_tree.symt.name, ) @@ -1213,10 +1235,6 @@ def store_output( # after the remapping. # pyrefly: ignore [missing-argument] val_shape_copy[i] = range_tree.symt.name - # Reverse the index symbols because TMA is indexed - # as (x, y) whereas the variables will naturally be indexed - # as (y, x) - epilogue_index_symbols.reverse() val_shape = tuple(val_shape_copy) else: mask_vars: list[str] = [] @@ -1564,6 +1582,7 @@ def make_key( epilogue_fn: Optional[Callable[..., Any]], epilogue_fn_hash: Optional[str], tma_store: bool, + transpose_discontiguous_tensor_descriptors_override: Optional[bool], subgraphs: Optional[list[ir.Buffer]], # has to be none to cache workspace_arg: Optional[WorkspaceArg], # has to be none to cache layout: ir.Layout, @@ -1621,6 +1640,7 @@ def has_flexible_layout() -> bool: "num_buffers_warp_spec": num_buffers_warp_spec, "epilogue_fn_hash": epilogue_fn_hash, "tma_store": tma_store, + "transpose_discontiguous_tensor_descriptors_override": transpose_discontiguous_tensor_descriptors_override, "kwargs": kwargs, "hint_override": hint_override, } @@ -1736,6 +1756,7 @@ def generate_and_load( generate_with_caching, hint_override: Optional[int] = None, tma_store: bool = False, + transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None, ) -> Optional[GenerateAndLoadResult]: """Generate the python code and load it into the current process""" caching_enabled = ( @@ -1755,6 +1776,7 @@ def generate_and_load( epilogue_fn, epilogue_fn_hash, tma_store, + transpose_discontiguous_tensor_descriptors_override, subgraphs, workspace_arg, layout, @@ -1815,6 +1837,7 @@ def make_kernel(): use_jit=False, hint_override=hint_override, tma_store=tma_store, + transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override, **kernel_options, ) @@ -1936,6 +1959,7 @@ def generate( # type: ignore[override] generate_with_caching=False, hint_override: Optional[int] = None, tma_store: bool = False, + transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None, **kwargs, ): """This function generates a TritonTemplateCaller @@ -1982,6 +2006,7 @@ def generate( # type: ignore[override] generate_with_caching and self._cache_codegen_enabled_for_template, hint_override=hint_override, tma_store=tma_store, + transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override, ) # May happen as result of dev by 0. @@ -2045,6 +2070,7 @@ def make_kernel_render(out_node, hint_override: Optional[int] = None): use_jit=False, hint_override=hint_override, tma_store=tma_store, + transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override, **options, ) render = functools.partial( diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 9df8d114ef67b..68a34f5d1d2f1 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -1777,6 +1777,7 @@ def _get_template_configs_impl( "TMA_SIZE": TMA_DESCRIPTOR_SIZE, "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), "tma_store": config.triton.enable_template_tma_store, + "transpose_discontiguous_tensor_descriptors_override": True, } # Get base template configs from superclass for template_kwargs in super()._get_template_configs_impl( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f029a2e73f038..a45d9c0275b73 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1370,15 +1370,27 @@ def fresh_cache( fresh_inductor_cache = fresh_cache -def argsort(seq: Sequence[Any]) -> list[int]: - # preserve original order for equal strides +def argsort(seq: Sequence[Any], *, reverse: bool = False) -> list[int]: getter = seq.__getitem__ a_r = range(len(seq)) - return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 + # preserve original order for equal strides + # e.g. if strides are [32, 8, 8, 1] + # argsort -> [3, 2, 1, 0], rather than + # [3, 1, 2, 0] + # i.e. for equal strides in ascending order (reverse=False) an + # inner dimension should come before an outer dimension, and vice versa + # for descending + sort_idx = list(sorted(a_r, key=getter, reverse=True)) # noqa: C413 + if not reverse: + return list(reversed(sort_idx)) + return sort_idx def argsort_sym( - shape_env: ShapeEnv, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]] + shape_env: ShapeEnv, + seq: Sequence[Union[int, torch.SymInt, sympy.Expr]], + *, + reverse: bool = False, ) -> list[int]: def cmp(a: tuple[int, sympy.Expr], b: tuple[int, sympy.Expr]) -> int: a_idx, a_val = a @@ -1408,7 +1420,7 @@ def evaluate(expr: Union[bool, torch.SymInt, sympy.Expr]) -> bool: (idx, s.node.expr if isinstance(s, torch.SymInt) else s) for idx, s in enumerate(seq) ] - exprs = sorted(exprs, key=functools.cmp_to_key(cmp)) + exprs = sorted(exprs, key=functools.cmp_to_key(cmp), reverse=reverse) result = [idx for idx, _ in exprs] return result From ded9bcd61a059bf723e6e84689552962b480ea77 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 1 Dec 2025 09:49:57 +0000 Subject: [PATCH 078/338] Remove the CUPTI CMake check for kineto (#161370) This PR removes the CUPTI check because kineto has always linked to `CUDA::cupti`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161370 Approved by: https://github.com/Skylion007 --- cmake/Dependencies.cmake | 70 ---------------------------------------- 1 file changed, 70 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 733183ef50bd5..4df8ba4a784b4 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1637,76 +1637,6 @@ if(USE_KINETO) message(STATUS " KINETO_BUILD_TESTS = ${KINETO_BUILD_TESTS}") message(STATUS " KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}") - if(NOT LIBKINETO_NOCUPTI) - set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "") - message(STATUS " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}") - message(STATUS " CUDA_INCLUDE_DIRS = ${CUDA_INCLUDE_DIRS}") - - if(NOT MSVC) - if(USE_CUPTI_SO) - set(CUPTI_LIB_NAME "libcupti.so") - else() - set(CUPTI_LIB_NAME "libcupti_static.a") - endif() - else() - set(CUPTI_LIB_NAME "cupti.lib") - endif() - - find_library(CUPTI_LIBRARY_PATH ${CUPTI_LIB_NAME} PATHS - ${CUDA_SOURCE_DIR} - ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64 - ${CUDA_SOURCE_DIR}/lib - ${CUDA_SOURCE_DIR}/lib64 - NO_DEFAULT_PATH) - - find_path(CUPTI_INCLUDE_DIR cupti.h PATHS - ${CUDA_SOURCE_DIR}/extras/CUPTI/include - ${CUDA_INCLUDE_DIRS} - ${CUDA_SOURCE_DIR} - ${CUDA_SOURCE_DIR}/include - NO_DEFAULT_PATH) - - if(CUPTI_LIBRARY_PATH AND CUPTI_INCLUDE_DIR) - message(STATUS " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}") - set(CUDA_cupti_LIBRARY ${CUPTI_LIBRARY_PATH}) - message(STATUS " CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}") - message(STATUS "Found CUPTI") - set(LIBKINETO_NOCUPTI OFF CACHE STRING "" FORCE) - - # I've only tested this sanity check on Linux; if someone - # runs into this bug on another platform feel free to - # generalize it accordingly - if(NOT USE_CUPTI_SO AND UNIX) - include(CheckCXXSourceRuns) - # rt is handled by the CMAKE_REQUIRED_LIBRARIES set above - if(NOT APPLE) - set(CMAKE_REQUIRED_LIBRARIES ${CMAKE_REQUIRED_LIBRARIES} "dl" "pthread") - endif() - set(CMAKE_REQUIRED_LINK_OPTIONS "-Wl,--whole-archive,${CUPTI_LIBRARY_PATH},--no-whole-archive") - check_cxx_source_runs("#include - int main() { - try { - throw std::runtime_error(\"error\"); - } catch (...) { - return 0; - } - return 1; - }" EXCEPTIONS_WORK) - set(CMAKE_REQUIRED_LINK_OPTIONS "") - if(NOT EXCEPTIONS_WORK) - message(FATAL_ERROR - "Detected that statically linking against CUPTI causes exceptions to stop working. " - "See https://github.com/pytorch/pytorch/issues/57744 for more details. " - "Perhaps try: USE_CUPTI_SO=1 CMAKE_FRESH=1 python -m pip install -e . -v --no-build-isolation") - endif() - endif() - - else() - message(STATUS "Could not find CUPTI library, using CPU-only Kineto build") - set(LIBKINETO_NOCUPTI ON CACHE STRING "" FORCE) - endif() - endif() - if(NOT LIBKINETO_NOROCTRACER) if("$ENV{ROCM_SOURCE_DIR}" STREQUAL "") set(ENV{ROCM_SOURCE_DIR} "/opt/rocm") From 5f0030ba63d334d7e8c93a09e41403b89e4c573c Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 28 Nov 2025 09:16:16 +0000 Subject: [PATCH 079/338] [xpu][fix] Support xpu custom raw_alloc/delete in caching allocator (#168957) # Motivation Memory Pool needs to support the custom `raw_alloc` and `raw_delete` from a custom allocator. # Solution When the custom allocator is provided in the memory pool, use its `raw_alloc` and `raw_delete`. Otherwise, use the `sycl::aligned_alloc_device` and `sycl::free` from SYCL runtime. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168957 Approved by: https://github.com/EikanWang, https://github.com/gujinghui ghstack dependencies: #168956 --- c10/xpu/XPUCachingAllocator.cpp | 59 ++++++++++++++++++++++----------- c10/xpu/XPUCachingAllocator.h | 8 +++-- 2 files changed, 46 insertions(+), 21 deletions(-) diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index d97388c8703be..dfcccc94c9e32 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -15,8 +15,6 @@ using namespace c10::CachingDeviceAllocator; // newly allocated memory with 512-byte alignment. constexpr size_t kDeviceAlignment = 512; -class XPUAllocator; - namespace { using stream_set = ska::flat_hash_set; @@ -393,6 +391,26 @@ struct MempoolIdHash { } }; +void allocPrimitive(void** ptr, size_t size, AllocParams& p) { + if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) { + *ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size); + } else { + *ptr = sycl::aligned_alloc_device( + kDeviceAlignment, + size, + xpu::get_raw_device(p.device()), + xpu::get_device_context()); + } +} + +void deletePrimitive(void* ptr, BlockPool* pool) { + if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) { + pool->owner_PrivatePool->allocator()->raw_delete(ptr); + } else { + sycl::free(ptr, xpu::get_device_context()); + } +} + } // anonymous namespace class DeviceCachingAllocator { @@ -713,7 +731,8 @@ class DeviceCachingAllocator { bool alloc_block(AllocParams& p, bool isRetry) { auto size = p.alloc_size; - auto device = p.device(); + void* ptr = nullptr; + if (isRetry) { stats.num_alloc_retries += 1; } @@ -728,27 +747,24 @@ class DeviceCachingAllocator { TORCH_CHECK( !active_pool, "torch.xpu.MemPool doesn't currently support expandable_segments."); - p.block = - try_allocate_expandable_block(device, p.queue(), p.pool, p.size()); + p.block = try_allocate_expandable_block( + p.device(), p.queue(), p.pool, p.size()); if (p.block && p.pool->owner_PrivatePool) { // The block is used only for XPU graph's PrivatePool. p.pool->owner_PrivatePool->allocation_count++; } return bool(p.block); - } - void* ptr = sycl::aligned_alloc_device( - kDeviceAlignment, - size, - xpu::get_raw_device(device), - xpu::get_device_context()); - if (!ptr) { - return false; + } else { + allocPrimitive(&ptr, size, p); + if (!ptr) { + return false; + } } if (p.pool->owner_PrivatePool) { p.pool->owner_PrivatePool->allocation_count++; } - p.block = new Block(device, p.queue(), size, p.pool, ptr); + p.block = new Block(p.device(), p.queue(), size, p.pool, ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { stats.reserved_bytes[stat_type].increase(size); }); @@ -806,8 +822,13 @@ class DeviceCachingAllocator { * guarantee that all kernels can access to the blocks have finished. */ TORCH_INTERNAL_ASSERT(!block->expandable_segment); - sycl::free(block->ptr, xpu::get_device_context()); auto* pool = block->pool; + deletePrimitive(block->ptr, pool); + + if (pool->owner_PrivatePool) { + TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->allocation_count > 0); + pool->owner_PrivatePool->allocation_count--; + } pool->blocks.erase(block); StatTypes stat_types = get_stat_types_for_pool(*pool); @@ -1297,7 +1318,7 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); -class XPUAllocator : public DeviceAllocator { +class NativeCachingAllocator : public XPUAllocator { private: alignas(hardware_destructive_interference_size) std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -1413,7 +1434,7 @@ class XPUAllocator : public DeviceAllocator { return &local_raw_delete; } - void* raw_alloc(size_t size) { + void* raw_alloc(size_t size) override { if (size == 0) { return nullptr; } @@ -1433,7 +1454,7 @@ class XPUAllocator : public DeviceAllocator { return r; } - void raw_delete(void* ptr) { + void raw_delete(void* ptr) override { this->free(ptr); } @@ -1517,7 +1538,7 @@ class XPUAllocator : public DeviceAllocator { } }; -static XPUAllocator allocator; +static NativeCachingAllocator allocator; void local_raw_delete(void* ptr) { allocator.free(ptr); diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index c55de309032e0..0054e359e77fe 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -6,6 +6,12 @@ namespace c10::xpu::XPUCachingAllocator { +class XPUAllocator : public DeviceAllocator { + public: + virtual void* raw_alloc(size_t nbytes) = 0; + virtual void raw_delete(void* ptr) = 0; +}; + C10_XPU_API Allocator* get(); C10_XPU_API void init(DeviceIndex device_count); @@ -33,8 +39,6 @@ C10_XPU_API double getMemoryFraction(DeviceIndex device); C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device); -class XPUAllocator; - C10_XPU_API void createOrIncrefPool( c10::DeviceIndex device, c10::MempoolId_t mempool_id, From 481e5ab336275bd3acd5fa8a611b05b4469012af Mon Sep 17 00:00:00 2001 From: Aaron Pollack Date: Mon, 1 Dec 2025 17:40:55 +0000 Subject: [PATCH 080/338] Replace vscode recommendation for type checker (#169021) Now that Pyrefly type checks instead of mypy we should recommend the correct vscode extension image Pull Request resolved: https://github.com/pytorch/pytorch/pull/169021 Approved by: https://github.com/maggiemoss, https://github.com/albanD, https://github.com/malfet --- .vscode/extensions.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/extensions.json b/.vscode/extensions.json index e6d0ebc6afc1e..b52a56ab6833a 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -3,7 +3,7 @@ "ms-python.python", "charliermarsh.ruff", "ms-python.flake8", - "ms-python.mypy-type-checker", + "meta.pyrefly", "ms-vscode.cmake-tools", "EditorConfig.EditorConfig", "streetsidesoftware.code-spell-checker", From 1ee32a8b1f554a312d79bad01ded24f38cd95543 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Mon, 1 Dec 2025 18:42:00 +0000 Subject: [PATCH 081/338] [reland][inductor] fix the decision of inner reduction (#168391) Summary: reland https://github.com/pytorch/pytorch/pull/167697 from fbcode. Test Plan: CI and new tests. Differential Revision: D87678489 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168391 Approved by: https://github.com/desertfire --- test/inductor/test_mix_order_reduction.py | 19 +++++++++++++++++-- test/inductor/test_torchinductor.py | 15 +++++++++++++++ torch/_inductor/ir.py | 4 +++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 7eee686ea99b0..d1486715ba526 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -270,11 +270,20 @@ def f(x, y): ], ) @parametrize("split_reductions", (False, True)) - @parametrize("shape", ((32768, 2048), (32768, 768), (32768 + 1023, 768))) + @parametrize( + "shape", ((1000000, 256), (32768, 2048), (32768, 768), (32768 + 1023, 768)) + ) @parametrize("max_autotune", (False, True)) @parametrize("initial_xblock", (1, 2)) + @parametrize("add_1dim", (False, True)) def test_rms_norm_bwd( - self, wdtype, split_reductions, shape, max_autotune, initial_xblock + self, + wdtype, + split_reductions, + shape, + max_autotune, + initial_xblock, + add_1dim, ): # max_autotune can be slow and cost resource, trim down the tests # for max autotune @@ -287,6 +296,9 @@ def test_rms_norm_bwd( ): self.skipTest("Skip non-critical tests to save resources.") + if shape != (1000000, 256) and add_1dim: + self.skipTest("Skip non-critical tests to save resources.") + def f(x, w, eps): orig_dtype = x.dtype @@ -307,6 +319,9 @@ def fwd_bwd(f): # M, N = 1152 * 500, 384 M, N = shape x = torch.randn(M, N, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True) + if add_1dim: + x = x[:, None, :] + w = torch.randn(N, dtype=wdtype, device=GPU_TYPE, requires_grad=True) dy = torch.randn_like(x) eps = 1e-5 diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b1cea5eac77d7..d3585bdb1d317 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14733,6 +14733,21 @@ def test_weight_norm_conv2d(self): self.assertTrue(same((ref, ref_grad), (act, act_grad), tol=1e-3)) + @skipIfMPS + def test_inner_reduction_detection(self): + if self.device == "cpu": + self.skipTest("Skip for CPU device") + + x = torch.randn(100000, 1, 256, device=self.device) + + @torch.compile + def f(x): + return x.sum(dim=(0, 1)) + + code = run_and_get_triton_code(f, x) + self.assertTrue("ReductionHint.OUTER" in code) + self.assertFalse("ReductionHint.INNER" in code) + @skip_if_halide @requires_cuda_and_triton @skip_if_cpp_wrapper("skip cpp wrapper") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 6a2183f42886a..b4bc3bbf19e88 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1442,7 +1442,9 @@ def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: strides = V.graph.sizevars.stride_hints( j, reduction_vars, list(ranges1.keys()) ) - outer = all(s > 1 for s in strides) + # A 0 stride does not make a reduction contiguous. + # This can happen when the reduction ranges contains a 1. + outer = all(s == 0 or s > 1 for s in strides) if outer: num_outer += 1 else: From a7dc6dab9ad911259d4801c502907e531594db45 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 1 Dec 2025 18:58:14 +0000 Subject: [PATCH 082/338] bump transformer pin to 4.57.3 (#169226) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169226 Approved by: https://github.com/anijain2305 --- .ci/docker/ci_commit_pins/huggingface-requirements.txt | 2 +- .../ci_expected_accuracy/inductor_huggingface_inference.csv | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/huggingface-requirements.txt b/.ci/docker/ci_commit_pins/huggingface-requirements.txt index f4f3830136eb6..e542372178a16 100644 --- a/.ci/docker/ci_commit_pins/huggingface-requirements.txt +++ b/.ci/docker/ci_commit_pins/huggingface-requirements.txt @@ -1,2 +1,2 @@ -transformers==4.56.0 +transformers==4.57.3 soxr==0.5.0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv index 54914c1395e17..bf3b3c0633a03 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv @@ -122,7 +122,7 @@ google/gemma-3-4b-it,pass_due_to_skip,0 -openai/whisper-tiny,pass,0 +openai/whisper-tiny,pass,5 From cddec6562eabfa390d014fa3741a5659cf9c94c9 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Mon, 1 Dec 2025 19:56:43 +0000 Subject: [PATCH 083/338] [mergebot] Set header for workflow when calling Dr.CI (#169037) Same spirit as https://github.com/pytorch/test-infra/pull/7513 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169037 Approved by: https://github.com/clee2000, https://github.com/yangw-dev --- .github/scripts/trymerge.py | 1 + .github/workflows/trymerge.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 2100f96427574..4f910ccfe0f68 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -1789,6 +1789,7 @@ def get_drci_classifications(pr_num: int, project: str = "pytorch") -> Any: headers={ "Authorization": os.getenv("DRCI_BOT_KEY", ""), "Accept": "application/vnd.github.v3+json", + "x-hud-internal-bot": os.getenv("HUD_API_TOKEN", ""), }, method="POST", reader=json.load, diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index 5c456c607c887..f625ce8b715a3 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -45,6 +45,7 @@ jobs: IGNORE_CURRENT: ${{ github.event.client_payload.ignore_current }} DRCI_BOT_KEY: ${{ secrets.DRCI_BOT_KEY }} GITHUB_RUN_ID: ${{ github.run_id }} + HUD_API_TOKEN: ${{ secrets.HUD_API_TOKEN }} run: | set -x if [ -n "${REBASE}" ]; then From f49d32dfa4730dcfb1b60eeeb369b5889da983c8 Mon Sep 17 00:00:00 2001 From: eellison Date: Mon, 1 Dec 2025 07:09:39 -0800 Subject: [PATCH 084/338] bucketing compile time improve (#168122) Strict compile time improvement. We always maintain that start -> hiding nodes -> wait. Add start, to hiding nodes ancestors, and hiding nodes to wait ancestors, to minimize repeated graph searches by precomputing the dependencies. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168122 Approved by: https://github.com/IvanKobzarev --- .../test_overlap_bucketing_unit.py | 47 ++------------ .../fx_passes/overlap_manual_scheduling.py | 1 - .../fx_passes/overlap_preserving_bucketer.py | 65 +++++++++++++++---- .../_inductor/fx_passes/overlap_scheduling.py | 1 - 4 files changed, 58 insertions(+), 56 deletions(-) diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py index c0c4c31cc1a81..48103c7ada713 100644 --- a/test/distributed/test_overlap_bucketing_unit.py +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -93,28 +93,6 @@ def build_collective_info(graph, hiding_annotations): return collective_info -def compute_ancestors(graph): - """Compute ancestor sets for all nodes in the graph.""" - node_ancestors = {} - - for node in graph.nodes: - ancestors = OrderedSet() - stack = list(node.all_input_nodes) - visited = set() - - while stack: - current = stack.pop() - if current in visited: - continue - visited.add(current) - ancestors.add(current) - stack.extend(current.all_input_nodes) - - node_ancestors[node] = ancestors - - return node_ancestors - - @requires_accelerator_dist_backend() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @instantiate_parametrized_tests @@ -190,9 +168,7 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and ancestors collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -203,7 +179,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -278,9 +253,8 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -291,7 +265,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -381,9 +354,8 @@ def func(a, b, c): if final_mm_hidden: hiding_annotations[rs] = mm2 - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing logic to find buckets (without applying them, which would require process groups) @@ -394,7 +366,6 @@ def func(a, b, c): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) @@ -467,7 +438,6 @@ def func(a, b): # Build collective info collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -478,7 +448,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -550,9 +519,8 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing with multidtype mode @@ -563,7 +531,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, bucket_mode="custom_ops_multidtype", ) @@ -635,9 +602,8 @@ def func(a, b): ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3 } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Verify hiding_nodes are correctly set @@ -656,7 +622,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -729,9 +694,8 @@ def func(a, b, c): ag3: mm, } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -742,7 +706,6 @@ def func(a, b, c): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() diff --git a/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/torch/_inductor/fx_passes/overlap_manual_scheduling.py index c8af70dc598f4..d2c8b588d2011 100644 --- a/torch/_inductor/fx_passes/overlap_manual_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -182,7 +182,6 @@ def __init__( self.bucketer = ManualOverlapPreservingBucketer( graph=self.graph, collective_info=self.collective_info, - node_ancestors=self.node_ancestors, node_users=self.node_users, scheduled=OrderedSet(self.graph.nodes), ) diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index b5ef930b8fa8f..eb239a3a219a6 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -1,3 +1,4 @@ +import itertools import logging from collections import defaultdict from dataclasses import dataclass @@ -130,7 +131,6 @@ def __init__( self, graph: fx.Graph, collective_info: dict[fx.Node, CollectiveInfo], - node_ancestors: dict[fx.Node, OrderedSet[fx.Node]], scheduled: OrderedSet[fx.Node], max_bucket_memory_gb: float = 1.0, max_coll_distance: int = 1000, @@ -139,19 +139,46 @@ def __init__( ): self.graph = graph self.collective_info = collective_info - self.node_ancestors = node_ancestors self.scheduled = scheduled self.max_bucket_memory_gb = max_bucket_memory_gb self.node_idx = {n: i for i, n in enumerate(scheduled)} - self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) self.max_coll_distance = max_coll_distance self.insert_overlap_deps = insert_overlap_deps self.bucket_mode = bucket_mode self.node_to_event: dict[fx.Node, PGEvent] = {} - self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines() + # Compute ancestors including original graph edges and hiding interval dependencies + self.node_ancestors = self._compute_node_ancestors() + self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) + + # Build timelines and add constraints to aug_graph + self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines() self._add_hiding_interval_constraints() + def _compute_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]: + """ + Compute ancestor sets for all nodes including: + 1. Original graph edges + 2. Hiding interval deps: collective_start -> hiding_node -> wait + """ + augmented_inputs: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for start, info in self.collective_info.items(): + if info.is_exposed: + continue + for hiding_node in info.hiding_nodes: + augmented_inputs[hiding_node].add(start) + augmented_inputs[info.wait_node].add(hiding_node) + + node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for node in self.scheduled: + for input_node in itertools.chain( + augmented_inputs[node], node.all_input_nodes + ): + node_ancestors[node].add(input_node) + node_ancestors[node] |= node_ancestors[input_node] + + return node_ancestors + def build_timelines(self) -> dict[str, Optional[PGEvent]]: "Construct each process groups ordered series of event" all_pgs: OrderedSet[str] = OrderedSet() @@ -337,21 +364,30 @@ def _find_buckets( ) processed.add(start_node) + # Greedy optimization: stop after consecutive failures + consecutive_failures = 0 + max_consecutive_failures = 20 + # Check candidates in sorted order, break when beyond max distance for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]: - if candidate in processed: - continue - candidate_bytes = self.collective_info[candidate].size_bytes # proxy on memory use, if we see a too large bucket, # dont look for another, later bucket if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes: break + if candidate in processed: + continue + if self._can_add_to_bucket(bucket_info, candidate): bucket_info.collectives.append(candidate) bucket_info.total_bytes += candidate_bytes processed.add(candidate) + consecutive_failures = 0 # Reset on success + else: + consecutive_failures += 1 + if consecutive_failures >= max_consecutive_failures: + break if len(bucket_info.collectives) > 1: buckets.append(bucket_info) @@ -656,23 +692,28 @@ def _has_ancestor_conflicts( candidate_wait = candidate_info.wait_node for coll in bucket_info.collectives: - # Check if collectives are ancestors of each other - if self._ancestor_dep(coll, candidate): + if ( + coll in self.node_ancestors[candidate] + or candidate in self.node_ancestors[coll] + ): return True # Check if waits are ancestors of each other coll_wait = self.collective_info[coll].wait_node - if self._ancestor_dep(candidate_wait, coll_wait): + if ( + coll_wait in self.node_ancestors[candidate_wait] + or candidate_wait in self.node_ancestors[coll_wait] + ): return True # Check if existing hiding node conflicts with candidate wait for old_hiding_node in self.collective_info[coll].hiding_nodes: - if self._ancestor_dep(old_hiding_node, candidate_wait): + if candidate_wait in self.node_ancestors[old_hiding_node]: return True # Check if candidate hiding node conflicts with existing wait for new_hiding_node in candidate_info.hiding_nodes: - if self._ancestor_dep(new_hiding_node, coll_wait): + if coll_wait in self.node_ancestors[new_hiding_node]: return True return False diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 436a3ab0db81b..14555c84b43ce 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -1125,7 +1125,6 @@ def _bucket_collectives(self) -> None: bucketer = OverlapPreservingBucketer( graph=self.graph, collective_info=self.collective_info, - node_ancestors=self.node_ancestors, scheduled=self.scheduled, max_bucket_memory_gb=2.0, # Could make this configurable max_coll_distance=self.max_node_distance, From 90f0139e64b2951815d524b6a373bed20c4fbf90 Mon Sep 17 00:00:00 2001 From: eellison Date: Mon, 1 Dec 2025 07:09:39 -0800 Subject: [PATCH 085/338] Track overlap per-pg (#169019) Track process groups separately for overlap, and if we were to schedule an exposed wait, schedule collectives on other pgs for overlap. Also, schedule available reduce scatters before pre fetching to reduce memory. The waits of these nodes will still be sunk. Since we limit pre fetching based on original schedule memory, we'd need to be careful about not scheduling reduce scatters bc that can increase memory. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169019 Approved by: https://github.com/IvanKobzarev ghstack dependencies: #168122 --- .../test_aten_comm_compute_reordering.py | 2 +- .../test_overlap_bucketing_unit.py | 103 ++++++++++ .../fx_passes/overlap_preserving_bucketer.py | 16 ++ .../_inductor/fx_passes/overlap_scheduling.py | 192 ++++++++++++------ 4 files changed, 252 insertions(+), 61 deletions(-) diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 966f84ff0ee56..60488496d0ffb 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -1079,7 +1079,7 @@ def func(a): out, aten_graph_str = run_and_get_aten_graph(compiled, inputs) # Verify all three collective types are present - FileCheck().check("all_reduce").check("all_gather").check( + FileCheck().check_dag("all_reduce").check_dag("all_gather").check_dag( "reduce_scatter" ).run(aten_graph_str) diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py index 48103c7ada713..2fe705e0c23b6 100644 --- a/test/distributed/test_overlap_bucketing_unit.py +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -10,6 +10,7 @@ # for some reason importing functional collectives after dynamo breaks collectives handling! from torch._C import FileCheck +from torch._dynamo.utils import counters from torch._inductor.test_case import TestCase as InductorTestCase from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import make_fx @@ -168,6 +169,7 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) scheduled = OrderedSet(traced.graph.nodes) @@ -719,5 +721,106 @@ def func(a, b, c): ).run(graph_str) +@requires_accelerator_dist_backend(["nccl", "xccl"]) +@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") +class TestCrossPGOverlap(InductorTestCase): + """ + Tests for cross-PG overlap scheduling. + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + from torch.testing._internal.distributed.fake_pg import FakeStore + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + cls.device = "cuda" + + # Create two separate process groups for cross-PG testing + cls.pg1 = dist.new_group(ranks=[0, 1]) + cls.pg2 = dist.new_group(ranks=[0, 1]) + cls.pg1_name = cls.pg1.group_name + cls.pg2_name = cls.pg2.group_name + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + dist.destroy_process_group(cls.pg1) + dist.destroy_process_group(cls.pg2) + dist.destroy_process_group() + + def test_cross_pg_prefetch_during_exposed_wait(self): + """ + Test that ag2 on PG2 gets prefetched during exposed wait of ag1 on PG1. + """ + pg1_name = self.pg1_name + pg2_name = self.pg2_name + + def func(a, b): + group_size = 1 + + # First collective on PG1 + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, group_size, pg1_name + ) + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + mm1 = torch.mm(ag1_out[:4, :4], ag1_out[:4, :4]) + + # Second collective on PG2 + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + b, group_size, pg2_name + ) + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + mm2 = torch.mm(ag2_out[:4, :4], ag2_out[:4, :4]) + + return mm1 + mm2 + + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device) + b = torch.ones(4, 4, device=self.device) * 2 + + traced = make_fx(func)(a, b) + + # Find nodes + ag1, ag2 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + wait1, wait2 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.wait_tensor.default, + ) + mm1, mm2 = traced.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + + def custom_runtime(node: fx.Node, override_size: int | None) -> float | None: + if "all_gather" in str(node.target): + return 10.0 + return 0.0 + + # Run overlap scheduler + from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler + + scheduler = OverlapScheduler( + traced, + max_in_flight_gb=5.0, + max_compute_pre_fetch=200, + collective_bucketing=False, + insert_overlap_deps=False, + compute_overlap_multipler=1.0, + max_coll_distance=200, + custom_runtime_estimation=custom_runtime, + collective_estimator="analytical", + ) + out = scheduler.run() + FileCheck().check("%all_gather_into_tensor").check( + "%all_gather_into_tensor" + ).check("%wait_tensor").run(str(out.graph)) + + self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 1) + + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index eb239a3a219a6..7fc456f388deb 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -146,6 +146,7 @@ def __init__( self.insert_overlap_deps = insert_overlap_deps self.bucket_mode = bucket_mode self.node_to_event: dict[fx.Node, PGEvent] = {} + self.all_hiding_nodes: OrderedSet[fx.Node] = OrderedSet() # Compute ancestors including original graph edges and hiding interval dependencies self.node_ancestors = self._compute_node_ancestors() @@ -260,6 +261,8 @@ def _add_hiding_interval_constraints(self) -> None: self.aug_graph.add_extra_dep(n=hn, dep=start) self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn) + self.all_hiding_nodes |= info.hiding_nodes + def bucket_collectives(self) -> None: # Group collectives by PG first pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet) @@ -357,6 +360,12 @@ def _find_buckets( if start_node in processed: continue + if ( + start_node in self.all_hiding_nodes + or self.collective_info[start_node].wait_node in self.all_hiding_nodes + ): + continue + # Initialize bucket with first collective bucket_info = CollBucket( collectives=[start_node], @@ -740,6 +749,13 @@ def _can_add_to_bucket( candidate_info = self.collective_info[candidate] + if ( + candidate in self.all_hiding_nodes + or candidate_info.wait_node in self.all_hiding_nodes + ): + why("nyi: bucketing collective used for overlap") + return False + # Step 1: Quick check using precomputed ancestors # These ancestors are computed prior to adding augmented dependencies and not updated, # so if any of these checks fail then the merge will not be topologically valid diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 14555c84b43ce..6e5971b68e4fb 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -101,6 +101,11 @@ def is_compute_node(n: fx.Node) -> bool: ) +def is_reduce_scatter(n: fx.Node) -> bool: + """Check if node is a reduce_scatter collective.""" + return "reduce_scatter" in str(n.target).lower() + + def get_hint(x: int | torch.SymInt) -> int | None: if isinstance(x, int): return x @@ -446,7 +451,9 @@ def off_compute_path(self, n: fx.Node) -> bool: return self.compute_index_domination[n] == sys.maxsize def _identify_collectives(self) -> None: - """Identify all collective operations.""" + """Identify all collective operations and process groups.""" + self.all_pgs: OrderedSet[str] = OrderedSet() + for node in self.nodes: if _schedulable_wait_node(node): start = node.args[0] @@ -464,6 +471,7 @@ def _identify_collectives(self) -> None: self.collective_info[start] = info self.wait_to_start[node] = start self.unscheduled_collectives.add(start) + self.all_pgs.add(get_group_name(start)) def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]: """ @@ -719,21 +727,40 @@ def get_non_collective_runtime_estimate(self, node: fx.Node) -> float | None: return self.custom_runtime_estimation(node, None) def _reduce_exposed_time_of_in_flight_collectives( - self, node: fx.Node, available_compute: float - ) -> float: - """Reduce exposed time of in-flight collectives using available compute time and return available time""" + self, + node: fx.Node, + available_compute: float, + exclude_pg: str | None = None, + ) -> dict[str, float]: + """ + Reduce exposed time of in-flight collectives using available compute time. + + Collectives on different process groups can overlap simultaneously with the same + compute, so we track remaining time separately per PG. + """ + # Initialize all PGs with full available compute (except excluded) + remaining_time_per_pg: dict[str, float] = { + pg: available_compute for pg in self.all_pgs if pg != exclude_pg + } - # TODO: separate overlap time per process group - for info in self.in_flight.values(): + for start_node, info in self.in_flight.items(): if info.exposed_time_ms == 0: continue - overlap_amount = min(info.exposed_time_ms, available_compute) + + pg_name = get_group_name(start_node) + if pg_name == exclude_pg: + continue + + pg_remaining = remaining_time_per_pg[pg_name] + if pg_remaining <= 0: + continue + + overlap_amount = min(info.exposed_time_ms, pg_remaining) info.exposed_time_ms -= overlap_amount - available_compute -= overlap_amount + remaining_time_per_pg[pg_name] -= overlap_amount info.hiding_nodes.add(node) - if available_compute == 0: - break - return available_compute + + return remaining_time_per_pg def _handle_compute_or_other(self, node: fx.Node) -> None: """Handle scheduling compute or other nodes and attempt to overlap with collectives.""" @@ -747,12 +774,13 @@ def _handle_compute_or_other(self, node: fx.Node) -> None: return available_compute = runtime_estimate * self.compute_overlap_multipler - initial_compute = available_compute # Track initial compute time for wasted compute/path calculations - available_compute = self._reduce_exposed_time_of_in_flight_collectives( + # First, reduce exposed time of in-flight collectives (per PG) + remaining_time_per_pg = self._reduce_exposed_time_of_in_flight_collectives( node, available_compute ) - self._schedule_collectives_for_overlap(node, available_compute, initial_compute) + # Then, schedule new collectives for overlap + self._schedule_collectives_for_overlap(node, remaining_time_per_pg) self._schedule(node) if is_compute_node(node): @@ -871,26 +899,48 @@ def _handle_wait(self, node: fx.Node) -> None: for coll_to_schedule in to_schedule: self._handle_wait(self.collective_info[coll_to_schedule].wait_node) + # If we are waiting on an exposed collective, use this time to + # overlap on other PGs. + info = self.collective_info[coll_start] + if info.exposed_time_ms > 0: + exposed_time = info.exposed_time_ms + exclude_pg = group_name + + remaining_time_per_pg = self._reduce_exposed_time_of_in_flight_collectives( + node, exposed_time, exclude_pg=exclude_pg + ) + self._schedule_collectives_for_overlap( + node, remaining_time_per_pg, exclude_pg=exclude_pg + ) + self.in_flight_bytes -= self.in_flight[coll_start].size_bytes del self.in_flight[coll_start] self._schedule(node) def _schedule_collectives_for_overlap( - self, compute_node: fx.Node, available_compute_time: float, initial_time: float + self, + overlap_node: fx.Node, + remaining_time_per_pg: dict[str, float], + exclude_pg: str | None = None, ) -> None: - """Opportunistically schedule collectives that can be hidden by compute.""" - if available_compute_time == 0: + """Opportunistically schedule collectives that can be hidden by available overlap time.""" + if not remaining_time_per_pg or all( + t <= 0 for t in remaining_time_per_pg.values() + ): return - reduced_time = initial_time - available_compute_time - compute_ancestors = self.node_ancestors[compute_node] + overlap_node_ancestors = self.node_ancestors[overlap_node] - # Compile-time filtering: limit candidates by distance to bound O(compute * collectives) cost + # Compile candidates - limit by distance to bound compile time candidates = [] for i, collective in enumerate(self.unscheduled_collectives): if i > self.max_node_distance: break + pg_name = get_group_name(collective) + if pg_name == exclude_pg: + continue + if ( not self.off_compute_path(collective) and self.compute_index_domination[collective] @@ -901,21 +951,31 @@ def _schedule_collectives_for_overlap( candidates.append(collective) - candidates = sorted( - candidates, - key=lambda n: (self.compute_index_domination[n], self.node_idx[n]), + # Sort candidates prioritizing: + # 1. reduce_scatter operations (reduce memory pressure) + # 2. Earlier domination index + # 3. Original order for stability + candidates.sort( + key=lambda n: ( + not is_reduce_scatter(n), # reduce_scatter first + self.compute_index_domination[n], + self.node_idx[n], + ), ) for collective in candidates: - if available_compute_time == 0: - break + pg_name = get_group_name(collective) + pg_available_time = remaining_time_per_pg[pg_name] + + if pg_available_time <= 0: + continue - why = WhyNoOverlap(compute_node, collective) + why = WhyNoOverlap(overlap_node, collective) info = self.collective_info[collective] if ( - collective in compute_ancestors - or compute_node in self.node_ancestors[collective] + collective in overlap_node_ancestors + or overlap_node in self.node_ancestors[collective] ): why("dependency conflict") continue @@ -925,10 +985,11 @@ def _schedule_collectives_for_overlap( why("prefetch would exceed memory budget") continue + # Try to free memory by forcing hidden waits while ( self.in_flight and (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes - and self._wait_is_hidden(self._get_oldest_wait(), compute_node) + and self._wait_is_hidden(self._get_oldest_wait(), overlap_node) ): self._force_oldest_wait() @@ -937,40 +998,44 @@ def _schedule_collectives_for_overlap( continue # Check if we can reach this collective without scheduling compute, other collectives, or waits - path = self._find_schedulable_path(collective, compute_node, why) + path = self._find_schedulable_path(collective, overlap_node, why) if path is None: continue log.debug( - "Overlapping collective %s with compute %s: coll_domination=%d, current_depth=%d", + "Overlapping collective %s with node %s: coll_domination=%d, current_depth=%d", collective.name, - compute_node.name, + overlap_node.name, self.compute_index_domination[collective], self.current_compute_index, ) - # Track compute runtime of nodes we must schedule to reach collective and - # add back available overlap time corresponding to prior in-flight collectives - path_estimates = [self.get_non_collective_runtime_estimate(p) for p in path] - path_time = sum(p for p in path_estimates if p is not None) - additional_time = min(path_time, reduced_time) - reduced_time -= additional_time - available_compute_time += additional_time + # TODO: We previously tracked path compute time and added it back to available + # overlap time. With per-PG tracking this is complex: if there were in-flight + # collectives on one PG but not another, we can't add path time back to the PG + # that wasn't in-flight - self._schedule_path_to_collective(path, compute_node) + # Schedule path and collective + self._schedule_path_to_collective(path, overlap_node) self._handle_collective_start(collective) self._update_cumulative_prefetch_memory(collective, info) - # Update exposed time - overlap_amount = min(available_compute_time, info.exposed_time_ms) + # Update exposed time for this collective + overlap_amount = min(pg_available_time, info.exposed_time_ms) info.exposed_time_ms -= overlap_amount - info.hiding_nodes.add(compute_node) - available_compute_time -= overlap_amount + info.hiding_nodes.add(overlap_node) + + # Update available time for this PG + remaining_time_per_pg[pg_name] -= overlap_amount + + if sum(remaining_time_per_pg.values()) == 0: + break - self.wasted_compute += available_compute_time + if remaining_time_per_pg: + self.wasted_compute += min(remaining_time_per_pg.values()) def _find_schedulable_path( - self, target: fx.Node, curr_compute_node: fx.Node | None, why: WhyNoOverlap + self, target: fx.Node, curr_overlap_node: fx.Node | None, why: WhyNoOverlap ) -> OrderedSet[fx.Node] | None: """Find path to target by collecting unscheduled dependencies.""" # Get unscheduled ancestors @@ -990,20 +1055,27 @@ def _find_schedulable_path( # current compute node we are scheduling, then we are effectively exposing it. # similarly, dont schedule a wait of a collective that could be otherwise hidden, # thus forcing it to be exposed. - # however, if it is already hidden or it cannot be possible hidden, - # it's fine to schedule it + # however, if it is already hidden it's fine to schedule it if _schedulable_wait_node(node): info = self.collective_info[self.wait_to_start[node]] - if info.hiding_nodes and curr_compute_node not in info.hiding_nodes: - why( - "path blocked by wait node %s with different hiding compute", - node.name, - ) - continue - elif node not in self.potentially_hidden_waits: - why("path blocked by wait node %s that could be hidden", node.name) + # Allow if fully hidden by other nodes + if not info.is_exposed and curr_overlap_node not in info.hiding_nodes: continue + why( + "path blocked by wait node %s (exposed=%s, hiding_nodes=%s)", + node.name, + info.is_exposed, + curr_overlap_node in info.hiding_nodes, + ) + + # Skip c10 ops and dtensor shard ops - they should be scheduled via main loop + target_str = str(node.target) + if "c10" in target_str or "_dtensor" in target_str: + log.debug( + "Skipping c10/dtensor op %s in path to collective", + node.name, + ) return None return unscheduled_ancestors @@ -1031,14 +1103,14 @@ def _get_oldest_wait(self) -> fx.Node: return self.collective_info[oldest_start].wait_node def _wait_is_hidden( - self, wait_node: fx.Node, compute_node: fx.Node | None = None + self, wait_node: fx.Node, overlap_node: fx.Node | None = None ) -> bool: assert is_wait_tensor(wait_node) info = self.collective_info[self.wait_to_start[wait_node]] - return not info.is_exposed and compute_node not in info.hiding_nodes + return not info.is_exposed and overlap_node not in info.hiding_nodes def _schedule_path_to_collective( - self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node + self, path: OrderedSet[fx.Node], curr_overlap_node: fx.Node ) -> None: """Schedule all nodes needed to reach a collective.""" @@ -1054,7 +1126,7 @@ def _schedule_path_to_collective( continue info = self.collective_info[self.wait_to_start[node]] - assert curr_compute_node not in info.hiding_nodes + assert curr_overlap_node not in info.hiding_nodes self._handle_wait(node) continue From 9f7fceb887d0cfa0326a59b887821c63ff11340a Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Mon, 1 Dec 2025 22:14:57 +0000 Subject: [PATCH 086/338] [BE] Add missing method docstrings for pytorch quantization classes (#165199) Summary: This PR documents some apis in torch.ao and reorganizes quantization API documentation with aliases pages. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165199 Approved by: https://github.com/jerryzh168 --- docs/source/conf.py | 65 ----- docs/source/quantization-support.aliases.md | 267 ++++++++++++++++++ docs/source/quantization-support.md | 7 + docs/source/quantization.rst | 2 - torch/ao/nn/intrinsic/modules/fused.py | 2 + .../ao/nn/intrinsic/qat/modules/conv_fused.py | 12 + .../nn/intrinsic/qat/modules/linear_fused.py | 5 +- .../nn/intrinsic/quantized/modules/bn_relu.py | 6 + .../intrinsic/quantized/modules/conv_add.py | 6 + .../intrinsic/quantized/modules/conv_relu.py | 9 + 10 files changed, 312 insertions(+), 69 deletions(-) create mode 100644 docs/source/quantization-support.aliases.md diff --git a/docs/source/conf.py b/docs/source/conf.py index 99ce1e0b8db5d..7a3663ca062df 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2066,71 +2066,6 @@ "Quantize", # torch.utils.backcompat "Warning", - # torch.ao.nn.intrinsic.modules.fused - "ConvAdd2d", - "ConvAddReLU2d", - "LinearBn1d", - "LinearLeakyReLU", - "LinearTanh", - # torch.ao.nn.intrinsic.qat.modules.conv_fused - "ConvBnReLU1d", - "ConvBnReLU2d", - "ConvBnReLU3d", - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - # torch.ao.nn.intrinsic.qat.modules.linear_fused - "LinearBn1d", - # torch.ao.nn.intrinsic.qat.modules.linear_relu - "LinearReLU", - # torch.ao.nn.intrinsic.quantized.dynamic.modules.linear_relu - "LinearReLU", - # torch.ao.nn.intrinsic.quantized.modules.bn_relu - "BNReLU2d", - "BNReLU3d", - # torch.ao.nn.intrinsic.quantized.modules.conv_add - "ConvAdd2d", - "ConvAddReLU2d", - # torch.ao.nn.intrinsic.quantized.modules.conv_relu - "ConvReLU1d", - "ConvReLU2d", - "ConvReLU3d", - # torch.ao.nn.intrinsic.quantized.modules.linear_relu - "LinearLeakyReLU", - "LinearReLU", - "LinearTanh", - # torch.ao.nn.qat.modules.conv - "Conv1d", - "Conv2d", - "Conv3d", - # torch.ao.nn.qat.modules.embedding_ops - "Embedding", - "EmbeddingBag", - # torch.ao.nn.qat.modules.linear - "Linear", - # torch.ao.nn.quantizable.modules.activation - "MultiheadAttention", - # torch.ao.nn.quantizable.modules.rnn - "LSTM", - "LSTMCell", - # torch.ao.nn.quantized.dynamic.modules.conv - "Conv1d", - "Conv2d", - "Conv3d", - "ConvTranspose1d", - "ConvTranspose2d", - "ConvTranspose3d", - # torch.ao.nn.quantized.dynamic.modules.linear - "Linear", - # torch.ao.nn.quantized.dynamic.modules.rnn - "GRU", - "GRUCell", - "LSTM", - "LSTMCell", - "PackedParameter", - "RNNBase", - "RNNCell", - "RNNCellBase", # torch.ao.nn.quantized.modules.activation "ELU", "Hardswish", diff --git a/docs/source/quantization-support.aliases.md b/docs/source/quantization-support.aliases.md new file mode 100644 index 0000000000000..6d9e98c6135cc --- /dev/null +++ b/docs/source/quantization-support.aliases.md @@ -0,0 +1,267 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Aliases in torch.ao +The following are aliases to their counterparts in ``torch.ao`` in nested namespaces. + +## torch.ao.nn.intrinsic.qat.modules +The following are aliases to their counterparts in ``torch.ao.nn.intrinsic.qat`` in the ``torch.ao.nn.intrinsic.qat.module`` namespace. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic.qat.modules +``` + +### torch.ao.nn.intrinsic.qat.modules.conv_fused (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + conv_fused.ConvReLU1d + conv_fused.ConvReLU2d + conv_fused.ConvReLU3d + conv_fused.ConvBnReLU1d + conv_fused.ConvBnReLU2d + conv_fused.ConvBnReLU3d +``` + +### torch.ao.nn.intrinsic.qat.modules.linear_fused (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + linear_fused.LinearBn1d +``` + +### torch.ao.nn.intrinsic.qat.modules.linear_relu (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + linear_relu.LinearReLU +``` + +## torch.ao.nn.intrinsic.quantized.modules +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic.quantized.modules +``` + +The following are aliases to their counterparts in ``torch.ao.nn.intrinsic.quantized`` in the ``torch.ao.nn.intrinsic.quantized.modules`` namespace. + +### torch.ao.nn.intrinsic.quantized.modules.conv_relu (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + conv_relu.ConvReLU1d + conv_relu.ConvReLU2d + conv_relu.ConvReLU3d +``` + +### torch.ao.nn.intrinsic.quantized.modules.bn_relu (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + bn_relu.BNReLU2d + bn_relu.BNReLU3d +``` + +### torch.ao.nn.intrinsic.quantized.modules.conv_add (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + conv_add.ConvAdd2d + conv_add.ConvAddReLU2d +``` + +### torch.ao.nn.intrinsic.quantized.modules.linear_relu (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + linear_relu.LinearLeakyReLU + linear_relu.LinearReLU + linear_relu.LinearTanh +``` + +## torch.ao.nn.intrinsic.quantized.dynamic.modules +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic.quantized.dynamic.modules +``` + +The following are aliases to their counterparts in the ``torch.ao.nn.intrinsic.quantized.dynamic`` namespace. + +### torch.ao.nn.intrinsic.quantized.dynamic.modules.linear_relu (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + linear_relu.LinearReLU +``` + +## torch.ao.nn.intrinsic.modules +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic.modules +``` +The following are aliases to their counterparts in the ``torch.ao.nn.intrinsic`` namespace. + +### torch.ao.nn.intrinsic.modules.fused (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + fused.ConvAdd2d + fused.ConvAddReLU2d + fused.LinearBn1d + fused.LinearLeakyReLU + fused.LinearTanh +``` + +## torch.ao.nn.intrinsic.modules.torch.ao.nn.qat.modules +```{eval-rst} +.. currentmodule:: torch.ao.nn.qat.modules +``` +The following are aliases to their counterparts in the ``torch.ao.nn.qat`` namespace. + +### torch.ao.nn.intrinsic.modules.conv (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + conv.Conv1d + conv.Conv2d + conv.Conv3d +``` + +### torch.ao.nn.intrinsic.modules.embedding_ops (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + embedding_ops.Embedding + embedding_ops.EmbeddingBag +``` + +### torch.ao.nn.intrinsic.modules.linear (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + linear.Linear +``` + +## torch.ao.nn.quantizable.modules +```{eval-rst} +.. currentmodule:: torch.ao.nn.quantizable.modules +``` + +The following are aliases to their counterparts in the ``torch.ao.nn.quantizable`` namespace. + +### torch.ao.nn.quantizable.modules.activation (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + activation.MultiheadAttention +``` + +### torch.ao.nn.quantizable.modules.rnn (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + rnn.LSTM + rnn.LSTMCell +``` + +## torch.ao.nn.quantized.dynamic.modules +```{eval-rst} +.. currentmodule:: torch.ao.nn.quantized.dynamic.modules +``` + +The following are aliases to their counterparts in the ``torch.ao.nn.quantized.dynamic`` namespace. + +### torch.ao.nn.quantized.dynamic.modules.conv (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + + conv.Conv1d + conv.Conv2d + conv.Conv3d + conv.ConvTranspose1d + conv.ConvTranspose2d + conv.ConvTranspose3d +``` + +### torch.ao.nn.quantized.dynamic.modules.linear (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + linear.Linear +``` + +### torch.ao.nn.quantized.dynamic.modules.rnn (Aliases) +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + rnn.GRU + rnn.GRUCell + rnn.LSTM + rnn.LSTMCell + rnn.PackedParameter + rnn.RNNBase + rnn.RNNCell + rnn.RNNCellBase +``` diff --git a/docs/source/quantization-support.md b/docs/source/quantization-support.md index 0b5d338d6f2bb..90721da45860d 100644 --- a/docs/source/quantization-support.md +++ b/docs/source/quantization-support.md @@ -843,3 +843,10 @@ the `custom operator mechanism Date: Sat, 29 Nov 2025 19:27:36 -0800 Subject: [PATCH 087/338] [BE][MPS] Add out-of-bounds checks for embedding_bag (#168930) Using new asynchronous reporting mechanism Fixes https://github.com/pytorch/pytorch/issues/163630 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168930 Approved by: https://github.com/dcci --- .../ATen/native/mps/kernels/EmbeddingBag.h | 1 + .../native/mps/kernels/EmbeddingBag.metal | 23 ++++++++++++++++--- .../native/mps/operations/EmbeddingBag.mm | 4 +++- test/test_mps.py | 9 ++++++++ 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/EmbeddingBag.h b/aten/src/ATen/native/mps/kernels/EmbeddingBag.h index 60485815bea47..b11b89f21471a 100644 --- a/aten/src/ATen/native/mps/kernels/EmbeddingBag.h +++ b/aten/src/ATen/native/mps/kernels/EmbeddingBag.h @@ -20,6 +20,7 @@ struct EmbeddingBagParams { idx_type_t num_indices; idx_type_t num_bags; idx_type_t feature_size; + idx_type_t num_weights; EmbeddingBagMode mode; int64_t padding_idx; diff --git a/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal b/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal index c97650b7f5070..5002b47ccd068 100644 --- a/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal +++ b/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -152,6 +153,7 @@ void embedding_bag_impl( device I* bag_size, device I* max_indices, constant EmbeddingBagParams& params, + device ErrorMessages* error_buf, uint tid) { auto num_indices = params.num_indices; auto num_bags = params.num_bags; @@ -159,6 +161,7 @@ void embedding_bag_impl( auto padding_idx = params.padding_idx; auto use_per_sample_weights = params.use_per_sample_weights; auto per_sample_weights_stride = params.per_sample_weights_stride; + const auto num_weights = params.num_weights; constant auto& output_strides = params.output_strides; constant auto& weight_strides = params.weight_strides; constant auto& max_indices_strides = params.max_indices_strides; @@ -167,10 +170,10 @@ void embedding_bag_impl( auto feature_idx = tid % feature_size; uint32_t offsets_end = min(bag_idx + 1, num_bags - 1); - bool is_last_bag = bag_idx + 1 == num_bags; + const bool is_last_bag = bag_idx + 1 == num_bags; uint32_t indices_start = static_cast(offsets[bag_idx]); - uint32_t indices_end = is_last_bag * (num_indices) + - (!is_last_bag) * (static_cast(offsets[offsets_end])); + uint32_t indices_end = + is_last_bag ? num_indices : static_cast(offsets[offsets_end]); auto out_val = ReductionOpInit()(); @@ -180,6 +183,17 @@ void embedding_bag_impl( for (uint32_t indices_idx = indices_start; indices_idx < indices_end; indices_idx++) { I weight_idx = indices[indices_idx]; + if (weight_idx < 0 || static_cast(weight_idx) > num_weights) { + TORCH_REPORT_ERROR( + error_buf, + "Index ", + indices_idx, + " is out of bounds: ", + weight_idx, + ", range 0 to ", + num_weights); + return; + } bool pad = (weight_idx == padding_idx); auto weight_val = static_cast>( weight @@ -223,6 +237,7 @@ void embedding_bag_impl( bag_size, \ max_indices, \ params, \ + error_buf, \ tid) template @@ -236,6 +251,7 @@ kernel void embedding_bag( device I* bag_size [[buffer(6)]], device I* max_indices [[buffer(7)]], constant EmbeddingBagParams& params [[buffer(8)]], + device ErrorMessages* error_buf [[buffer(9)]], uint tid [[thread_position_in_grid]]) { switch (params.mode) { case EmbeddingBagMode::SUM: @@ -424,6 +440,7 @@ kernel void embedding_bag_per_sample_weights_backward( device I * bag_size [[buffer(6)]], \ device I * max_indices [[buffer(7)]], \ constant EmbeddingBagParams & params [[buffer(8)]], \ + device ErrorMessages * error_buf [[buffer(9)]], \ uint tid [[thread_position_in_grid]]); \ \ template [[host_name("embedding_bag_backward_" #T "_" #I)]] \ diff --git a/aten/src/ATen/native/mps/operations/EmbeddingBag.mm b/aten/src/ATen/native/mps/operations/EmbeddingBag.mm index d7916ccdf875d..2225b93a6aecd 100644 --- a/aten/src/ATen/native/mps/operations/EmbeddingBag.mm +++ b/aten/src/ATen/native/mps/operations/EmbeddingBag.mm @@ -105,6 +105,7 @@ params.feature_size = feature_size; params.mode = static_cast(mode); params.padding_idx = padding_idx; + params.num_weights = weight.size(0); auto num_threads = output.numel(); MPSStream* stream = getCurrentMPSStream(); @@ -126,7 +127,8 @@ offset2bag, bag_size, max_indices, - params); + params, + stream->getErrorBuffer()); mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads); getMPSProfiler().endProfileKernel(pipeline_state); diff --git a/test/test_mps.py b/test/test_mps.py index bf837e788e74c..9030348f11d3a 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12762,6 +12762,15 @@ def test_index_put_out_of_bounds(self, device): y = x[:, [1]] torch.mps.synchronize() + def test_embedding_bag_out_of_bounds(self, device): + inputs = torch.tensor([0, 1, 6], device=device) # Note: 6 is out of bounds for weight with size 4 + weight = torch.randn(4, 2, device=device) + offsets = torch.tensor([0, 3], device=device) + with self.assertRaisesRegex(torch.AcceleratorError, "Index 2 is out of bounds: 6, range 0 to 4"): + torch.nn.functional.embedding_bag(inputs, weight, offsets) + torch.mps.synchronize() + + class TestComplex(TestCase): def test_tensor_scalar_binops(self): # Regression test for https://github.com/pytorch/pytorch/issues/119088 From 641cdb68ae27668eb441d0e49c87a0602c120c2b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 1 Dec 2025 23:58:06 +0000 Subject: [PATCH 088/338] Revert "Avoid std::tie and returning value constructions in qconv_unpack.cpp (#169207)" This reverts commit d5038950bacfe36bbf24a47a455fe76901deb8e8. Reverted https://github.com/pytorch/pytorch/pull/169207 on behalf of https://github.com/huydhn due to I think this breaks some quantization tests, maybe they were wrongly skipped on CI ([comment](https://github.com/pytorch/pytorch/pull/169207#issuecomment-3599494109)) --- .../ATen/native/quantized/qconv_unpack.cpp | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/native/quantized/qconv_unpack.cpp b/aten/src/ATen/native/quantized/qconv_unpack.cpp index dcbfa7fdcf3f1..4c2352a396177 100644 --- a/aten/src/ATen/native/quantized/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/qconv_unpack.cpp @@ -82,28 +82,32 @@ class QConv1dUnpackWeightsInt8 final { static std::tuple> run( const c10::intrusive_ptr>& packed_weight) { auto& ctx = at::globalContext(); + at::Tensor weight; + std::optional bias; #ifdef USE_FBGEMM if (ctx.qEngine() == at::QEngine::FBGEMM || ctx.qEngine() == at::QEngine::X86) { - auto result = packed_weight->unpack(); - std::get<0>(result).squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return result; + std::tie(weight, bias) = packed_weight->unpack(); + weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return std::tuple>(weight, bias); } #endif #ifdef USE_PYTORCH_QNNPACK if (ctx.qEngine() == at::QEngine::QNNPACK) { - auto result = packed_weight->unpack(); - std::get<0>(result).squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return result; + std::tie(weight, bias) = packed_weight->unpack(); + at::Tensor new_weight = weight.clone(); + new_weight = new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return std::tuple>(new_weight, bias); } #endif #if AT_MKLDNN_ENABLED() if (ctx.qEngine() == at::QEngine::ONEDNN) { - auto result = packed_weight->unpack(); - std::get<0>(result).squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return result; + std::tie(weight, bias) = packed_weight->unpack(); + at::Tensor new_weight = weight.clone(); + new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return std::tuple>(new_weight, bias); } #endif From b8c4ba3593761e7b2a3ebd86f040fb07b47c02cf Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Tue, 2 Dec 2025 00:19:45 +0000 Subject: [PATCH 089/338] [xpu][feature][1/N] Enable SDPA XPU FlashAttention backend with SYCL-TLA implementation (#169101) This is a PR to utilize [SYCL-TLA](https://github.com/intel/sycl-tla)-based FlashAttention to accelerate `scaled_dot_product_attention` for Pytorch XPU. PR stacks: - https://github.com/pytorch/pytorch/pull/169101 - https://github.com/pytorch/pytorch/pull/167057 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169101 Approved by: https://github.com/EikanWang, https://github.com/atalman --- aten/src/ATen/CMakeLists.txt | 2 + .../native/transformers/xpu/attention.cpp | 56 ++++++ .../transformers/xpu/attention_backward.cpp | 51 ++++++ .../native/transformers/xpu/sdp_utils.cpp | 172 ++++++++++++++++++ .../ATen/native/transformers/xpu/sdp_utils.h | 17 ++ third_party/xpu.txt | 2 +- 6 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 aten/src/ATen/native/transformers/xpu/attention.cpp create mode 100644 aten/src/ATen/native/transformers/xpu/attention_backward.cpp create mode 100644 aten/src/ATen/native/transformers/xpu/sdp_utils.cpp create mode 100644 aten/src/ATen/native/transformers/xpu/sdp_utils.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index ae762e1def3ec..84dafb8e88cd5 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -171,6 +171,7 @@ file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu") file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp") file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip") file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp") +file(GLOB native_transformers_xpu_cpp "native/transformers/xpu/*.cpp") file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp") file(GLOB native_utils_cpp "native/utils/*.cpp") file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu) @@ -414,6 +415,7 @@ endif() if(USE_XPU) list(APPEND ATen_XPU_SRCS ${mkldnn_xpu_cpp}) + list(APPEND ATen_XPU_SRCS ${native_transformers_xpu_cpp}) list(APPEND ATen_XPU_DEPENDENCY_LIBS xpu_mkldnn) list(APPEND ATen_XPU_DEPENDENCY_LIBS ${OCL_LIBRARY}) diff --git a/aten/src/ATen/native/transformers/xpu/attention.cpp b/aten/src/ATen/native/transformers/xpu/attention.cpp new file mode 100644 index 0000000000000..8a953ef8be7c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/xpu/attention.cpp @@ -0,0 +1,56 @@ +#include +#include +#include + +namespace at { +namespace native { + +std::tuple< + Tensor, + Tensor, + Tensor, + Tensor, + c10::SymInt, + c10::SymInt, + Tensor, + Tensor, + Tensor> +_scaled_dot_product_flash_attention_xpu( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + auto + [attention, + logsumexp, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + philox_seed, + philox_offset] = + sycltla::flash_attention_forward( + query, + key, + value, + dropout_p, + is_causal, + scale.has_value() ? scale.value() + : (1.0 / std::sqrt(query.size(3)))); + return std::make_tuple( + attention, + logsumexp, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + philox_seed, + philox_offset, + /* debug_attn_mask */ at::Tensor()); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/xpu/attention_backward.cpp b/aten/src/ATen/native/transformers/xpu/attention_backward.cpp new file mode 100644 index 0000000000000..4128d0f5c7e25 --- /dev/null +++ b/aten/src/ATen/native/transformers/xpu/attention_backward.cpp @@ -0,0 +1,51 @@ +#include +#include +#include + +namespace at { +namespace native { + +std::tuple +_scaled_dot_product_flash_attention_backward_xpu( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cumulative_sequence_length_q, + const at::Tensor& cumulative_sequence_length_k, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + if (!grad_out.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + + auto [grad_q, grad_k, grad_v] = sycltla::flash_attention_backward( + grad_out, + query, + key, + value, + out, + logsumexp, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3)))); + + return std::make_tuple( + std::move(grad_q), std::move(grad_k), std::move(grad_v)); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/xpu/sdp_utils.cpp b/aten/src/ATen/native/transformers/xpu/sdp_utils.cpp new file mode 100644 index 0000000000000..ee6b47b0e2e69 --- /dev/null +++ b/aten/src/ATen/native/transformers/xpu/sdp_utils.cpp @@ -0,0 +1,172 @@ +#include +#include +#include + +namespace sdp { + +bool is_flash_attention_available() { + return sycltla::is_flash_attention_available(); +} + +inline bool is_flash_attention_available(sdp_params const& params, bool debug) { + if (!is_flash_attention_available()) { + if (debug) { + TORCH_WARN("Torch XPU was not compiled with flash attention."); + } + return false; + } + return true; +} + +bool check_flash_attention_hardware_support( + sdp_params const& params, + bool debug) { + if (!at::xpu::is_available()) { + TORCH_CHECK(false, "FlashAttentionXPU: XPU device is not available."); + } + + constexpr auto supported_architectures = + c10::array_of( + sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc, + sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg, + sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21); + auto* device_prop = at::xpu::getCurrentDeviceProperties(); + auto device_architecture = device_prop->architecture; + + if (std::find( + supported_architectures.begin(), + supported_architectures.end(), + device_architecture) == supported_architectures.end()) { + if (debug) { + TORCH_WARN( + "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21."); + } + return false; + } + + return true; +} + +inline bool check_flash_attention_datatype( + sdp_params const& params, + bool debug) { + constexpr auto supported_dtypes = + c10::array_of(at::kBFloat16, at::kHalf); + + auto query_dtype = params.query.dtype(); + if (!(query_dtype == params.key.dtype() && + query_dtype == params.value.dtype() && + (std::find( + supported_dtypes.begin(), supported_dtypes.end(), query_dtype) != + supported_dtypes.end()))) { + if (debug) { + TORCH_WARN( + "FlashAttentionXPU expected query, key, and value to all be of dtype: {", + "bfloat16, half", + "}. Got ", + "Query dtype: ", + params.query.dtype(), + ", Key dtype: ", + params.key.dtype(), + ", and Value dtype: ", + params.value.dtype(), + " instead."); + } + return false; + } + return true; +} + +inline bool check_flash_attention_head_dim_size( + sdp_params const& params, + bool debug) { + const int query_size_last = params.query.size(3); + const int key_size_last = params.key.size(3); + const int value_size_last = params.value.size(3); + + const bool head_dims_equal = (query_size_last == key_size_last) && + (query_size_last == value_size_last); + if (!head_dims_equal) { + if (debug) { + TORCH_WARN( + "FlashAttentionXPU requires q,k,v to have the same last dimension.", + " Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + key_size_last, + ", Value.size(-1): ", + value_size_last, + " instead."); + } + return false; + } + + constexpr auto max_supported_headdim = 192; + if (query_size_last > max_supported_headdim) { + if (debug) { + TORCH_WARN( + "FlashAttentionXPU supports head dimension up to ", + max_supported_headdim, + ". ", + "Got head dimension: ", + query_size_last, + " instead."); + } + return false; + } + return true; +} + +inline bool check_flash_attention_layout(sdp_params const& params, bool debug) { + return sycltla::check_flash_attention_layout(params, debug); +} + +inline bool check_flash_causal_non_square_seqlens( + sdp_params const& params, + bool debug) { + // FlashAttention 2 updated the default mask meaning for causal in this PR: + // 9e5e8bc91e it is now aligned to lower_right which would be a BC break + // for non-square masks. We will not support non-square masks for causal w/ + // FAV2 + if (params.is_causal && !params.query.is_nested() && + !params.key.is_nested() && + params.query.sym_size(-2) != params.key.sym_size(-2)) { + if (debug) { + TORCH_WARN( + "Flash attention XPU does not support the is_causal flag when seqlen_q != seqlen_k. ", + "Got seqlen_q: ", + params.query.sym_size(-2), + " seqlen_k: ", + params.key.sym_size(-2), + ". If you would like to use causal attention with non-square masks, please see CausalAttnMask."); + } + return false; + } + return true; +} + +bool can_use_flash_attention(sdp_params const& params, bool debug) { + constexpr auto constraints = + std::array{ + is_flash_attention_available, + check_flash_attention_hardware_support, + check_for_attn_mask, + check_for_dropout, + check_nested_tensor, + check_tensor_shapes, + check_batch_size_and_num_heads_dense, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense, + check_flash_causal_non_square_seqlens, + check_flash_attention_datatype, + check_flash_attention_head_dim_size, + check_flash_attention_layout}; + for (auto& constraint : constraints) { + if (!constraint(params, debug)) { + return false; + } + } + return true; +} + +} // namespace sdp diff --git a/aten/src/ATen/native/transformers/xpu/sdp_utils.h b/aten/src/ATen/native/transformers/xpu/sdp_utils.h new file mode 100644 index 0000000000000..14153741298d3 --- /dev/null +++ b/aten/src/ATen/native/transformers/xpu/sdp_utils.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace sdp { + +C10_EXPORT bool is_flash_attention_available(); +C10_EXPORT bool can_use_flash_attention(sdp_params const& params, bool debug); +C10_EXPORT bool check_flash_attention_hardware_support( + sdp_params const& params, + bool debug); + +} // namespace sdp diff --git a/third_party/xpu.txt b/third_party/xpu.txt index f05ce60393d66..423b13180d087 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -1e69f40b3c03492eb3dd7e03462a5566f29674d3 +549347d24e9b509b653a350053d56992fc8436ad From 518c2b1b3dab9a2ef2849e04b3bc2f20c1c41db9 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Tue, 2 Dec 2025 00:36:57 +0000 Subject: [PATCH 090/338] [10/N] Use Python 3.10 typing (#169229) This PR applies Python 3.10 typing syntax to some files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169229 Approved by: https://github.com/Lucaskabela --- torch/__init__.py | 60 ++++----- torch/_compile.py | 6 +- torch/_guards.py | 88 ++++++------- torch/_jit_internal.py | 6 +- torch/_linalg_utils.py | 10 +- torch/_lobpcg.py | 80 ++++++------ torch/_lowrank.py | 19 ++- torch/_meta_registrations.py | 204 +++++++++++++++--------------- torch/_ops.py | 35 +---- torch/_sources.py | 8 +- torch/_tensor.py | 52 ++++---- torch/_tensor_str.py | 4 +- torch/_utils.py | 6 +- torch/_utils_internal.py | 10 +- torch/_vmap_internals.py | 12 +- torch/_weights_only_unpickler.py | 12 +- torch/functional.py | 42 +++--- torch/hub.py | 12 +- torch/library.py | 66 +++++----- torch/masked/_ops.py | 174 +++++++++++++------------ torch/nn/_reduction.py | 9 +- torch/nn/common_types.py | 8 +- torch/nn/init.py | 28 ++-- torch/nn/modules/activation.py | 41 +++--- torch/nn/modules/batchnorm.py | 18 +-- torch/nn/modules/container.py | 14 +- torch/nn/modules/conv.py | 20 +-- torch/nn/modules/lazy.py | 4 +- torch/nn/modules/loss.py | 25 ++-- torch/nn/modules/module.py | 40 +++--- torch/nn/modules/normalization.py | 6 +- torch/nn/modules/pooling.py | 34 +++-- torch/nn/modules/rnn.py | 36 +++--- torch/nn/modules/transformer.py | 82 ++++++------ torch/nn/modules/upsampling.py | 25 ++-- torch/overrides.py | 4 +- torch/quasirandom.py | 11 +- torch/serialization.py | 48 +++---- torch/storage.py | 52 ++++---- torch/types.py | 22 ++-- torch/xpu/__init__.py | 18 ++- torch/xpu/random.py | 7 +- 42 files changed, 703 insertions(+), 755 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index e39e50a1f8409..165ade4f04dcf 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -320,7 +320,7 @@ def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> ctypes.CDLL(lib_path) -def _preload_cuda_deps(err: _Optional[OSError] = None) -> None: +def _preload_cuda_deps(err: OSError | None = None) -> None: cuda_libs: list[tuple[str, str]] = [ ("cublas", "libcublas.so.*[0-9]"), ("cudnn", "libcudnn.so.*[0-9]"), @@ -1277,7 +1277,7 @@ def set_default_device(device: "Device") -> None: _GLOBAL_DEVICE_CONTEXT.device_context = device_context -def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None: +def set_default_tensor_type(t: type["torch.Tensor"] | str, /) -> None: r""" .. warning:: @@ -1525,7 +1525,7 @@ def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: return _C._get_deterministic_algorithms_warn_only() -def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None: +def set_deterministic_debug_mode(debug_mode: builtins.int | str) -> None: r"""Sets the debug mode for deterministic operations. .. note:: This is an alternative interface for @@ -1687,7 +1687,7 @@ def is_warn_always_enabled() -> builtins.bool: def _check_with( error_type, - cond: _Union[builtins.bool, SymBool], + cond: builtins.bool | SymBool, message: _Callable[[], str], ): # noqa: F811 if not isinstance(cond, (builtins.bool, SymBool)): @@ -2093,7 +2093,7 @@ def _dtype(self): return torch.quint2x4 -_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = { +_storage_classes: set[type[TypedStorage | UntypedStorage]] = { UntypedStorage, DoubleStorage, FloatStorage, @@ -2399,13 +2399,13 @@ def __eq__(self, other): and self.dynamic == other.dynamic ) - def apply_mode(self, mode: _Optional[str]): + def apply_mode(self, mode: str | None): if mode and mode != "default": from torch._inductor import list_mode_options self.apply_options(list_mode_options(mode, self.dynamic)) - def apply_options(self, options: _Optional[dict[str, _Any]]): + def apply_options(self, options: dict[str, _Any] | None): if not options: return @@ -2525,12 +2525,10 @@ def compile( model: _Callable[_InputT, _RetT], *, fullgraph: builtins.bool = False, - dynamic: _Optional[builtins.bool] = None, - backend: _Union[str, _Callable] = "inductor", - mode: _Union[str, None] = None, - options: _Optional[ - dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] - ] = None, + dynamic: builtins.bool | None = None, + backend: str | _Callable = "inductor", + mode: str | None = None, + options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, disable: builtins.bool = False, ) -> _Callable[_InputT, _RetT]: ... @@ -2540,31 +2538,27 @@ def compile( model: None = None, *, fullgraph: builtins.bool = False, - dynamic: _Optional[builtins.bool] = None, - backend: _Union[str, _Callable] = "inductor", - mode: _Union[str, None] = None, - options: _Optional[ - dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] - ] = None, + dynamic: builtins.bool | None = None, + backend: str | _Callable = "inductor", + mode: str | None = None, + options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, disable: builtins.bool = False, ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ... def compile( - model: _Optional[_Callable[_InputT, _RetT]] = None, + model: _Callable[_InputT, _RetT] | None = None, *, fullgraph: builtins.bool = False, - dynamic: _Optional[builtins.bool] = None, - backend: _Union[str, _Callable] = "inductor", - mode: _Union[str, None] = None, - options: _Optional[ - dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] - ] = None, + dynamic: builtins.bool | None = None, + backend: str | _Callable = "inductor", + mode: str | None = None, + options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, disable: builtins.bool = False, -) -> _Union[ - _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]], - _Callable[_InputT, _RetT], -]: +) -> ( + _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]] + | _Callable[_InputT, _RetT] +): """ Optimizes given model/function using TorchDynamo and specified backend. If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile` @@ -2872,7 +2866,7 @@ def __getattr__(name): @functools.cache -def get_device_module(device: _Optional[_Union[torch.device, str]] = None): +def get_device_module(device: torch.device | str | None = None): """ Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). If no device is given, return the module for the current accelerator or CPU if none is present. @@ -2898,8 +2892,8 @@ def get_device_module(device: _Optional[_Union[torch.device, str]] = None): def _constrain_as_size( symbol, - min: _Optional[builtins.int] = None, - max: _Optional[builtins.int] = None, + min: builtins.int | None = None, + max: builtins.int | None = None, ): """ This indicates that a given int is size-like, and can be used in any context where a size is expected. diff --git a/torch/_compile.py b/torch/_compile.py index 76ddd3ccb05b4..bf7d715883d58 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -5,7 +5,7 @@ import functools from collections.abc import Callable -from typing import Optional, overload, TypeVar, Union +from typing import overload, TypeVar from typing_extensions import ParamSpec @@ -26,8 +26,8 @@ def _disable_dynamo( def _disable_dynamo( - fn: Optional[Callable[_P, _T]] = None, recursive: bool = True -) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: + fn: Callable[_P, _T] | None = None, recursive: bool = True +) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: """ This API should be only used inside torch, external users should still use torch._dynamo.disable. The main goal of this API is to avoid circular diff --git a/torch/_guards.py b/torch/_guards.py index 1bd32fc7f08ec..c9daab1e69e81 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -14,7 +14,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, Generic, NamedTuple, TYPE_CHECKING, TypeVar import torch from torch.utils import _pytree as pytree @@ -92,7 +92,7 @@ def __str__(self) -> str: return f"{self.frame_id}/{self.frame_compile_id}" @classmethod - def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]: + def from_string(cls, compile_id: str | None) -> CompileId | None: """ Factory method that creates a CompileId from its string representation. Keep this in sync with the __str__ method. @@ -255,14 +255,14 @@ class Guard: create_fn: Callable[[GuardBuilderBase, Guard], None] # Export only. These values are written to at time of guard check_fn creation. - guard_types: Optional[list[str]] = None - code_list: Optional[list[str]] = None - obj_weakref: Optional[object] = None - guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None - - stack: Optional[CapturedTraceback] = None - user_stack: Optional[traceback.StackSummary] = None - _hash: Optional[int] = None + guard_types: list[str] | None = None + code_list: list[str] | None = None + obj_weakref: object | None = None + guarded_class_weakref: weakref.ReferenceType[Any] | None = None + + stack: CapturedTraceback | None = None + user_stack: traceback.StackSummary | None = None + _hash: int | None = None _unserializable: bool = False def __hash__(self) -> int: @@ -379,7 +379,7 @@ def create_fn_name(self) -> str: def set_export_info( self, guard_type: str, - guarded_class: Optional[weakref.ReferenceType[Any]], + guarded_class: weakref.ReferenceType[Any] | None, code_list: list[str], obj_weakref: object, ) -> None: @@ -492,7 +492,7 @@ class GuardsCheckpointState: def __init__(self, dynamo_guards: set[Guard]) -> None: self.dynamo_guards = dynamo_guards - def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]: + def diff(self, other: GuardsCheckpointState) -> set[Guard] | None: """ Produces a delta against another GuardsCheckpointState. @@ -516,7 +516,7 @@ class ModuleContextCheckpointState: def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None: self.nn_modules = nn_modules - def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]: + def diff(self, other: ModuleContextCheckpointState) -> set[str] | None: """ Produces a delta against another ModuleContextCheckpointState. @@ -552,7 +552,7 @@ class GlobalContextCheckpointState: def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None: self.global_state = global_states - def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]: + def diff(self, other: GlobalContextCheckpointState) -> set[str] | None: """ Produces a delta against another GlobalContextCheckpointState. @@ -605,7 +605,7 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: # Like a Set[Guard] but will record the user stack on all guards at the # time they were installed at their destination class GuardsSet: - def __init__(self, inner: Optional[set[Guard]] = None) -> None: + def __init__(self, inner: set[Guard] | None = None) -> None: if inner is None: inner = set() self.inner = inner @@ -683,13 +683,13 @@ def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ... def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod - def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ... + def get_autograd_key_entry(self, identifier: str) -> Callable | None: ... @abstractmethod def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod - def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ... + def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: ... @abstractmethod def add_lazy_bwd_entry( @@ -702,7 +702,7 @@ def add_lazy_bwd_entry( @abstractmethod def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] - ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ... + ) -> tuple[torch.fx.GraphModule | None, int | None]: ... class InvokeSubgraphCache(HopSubgraphCache): @@ -726,13 +726,13 @@ def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: self.autograd_cache[identifier] = key - def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: + def get_autograd_key_entry(self, identifier: str) -> Callable | None: return self.autograd_cache.get(identifier, None) def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: self.proxy_dispatch_cache[identifier] = key - def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: + def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: return self.proxy_dispatch_cache.get(identifier, None) def add_lazy_bwd_entry( @@ -748,7 +748,7 @@ def add_lazy_bwd_entry( def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] - ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: + ) -> tuple[torch.fx.GraphModule | None, int | None]: if identifier not in self.lazy_bwd_cache: return (None, None) @@ -765,7 +765,7 @@ def add_effects(self, identifier: str, effects: set) -> None: ) self.effects_cache[identifier] = effects - def get_effects(self, identifier: str) -> Optional[set]: + def get_effects(self, identifier: str) -> set | None: """Retrieve the effect types for a given invoke_subgraph identifier.""" return self.effects_cache.get(identifier, None) @@ -814,7 +814,7 @@ def get() -> CompileContext: def try_get() -> CompileContext | None: return getattr(_TLS, "compile_context", None) - def __init__(self, compile_id: Optional[CompileId]) -> None: + def __init__(self, compile_id: CompileId | None) -> None: assert compile_id is None or isinstance(compile_id, CompileId) self.compile_id: CompileId | None = compile_id self.attempt = 0 @@ -822,14 +822,14 @@ def __init__(self, compile_id: Optional[CompileId]) -> None: self.shape_env_guards: list[str] = [] @staticmethod - def current_compile_id() -> Optional[CompileId]: + def current_compile_id() -> CompileId | None: self = CompileContext.try_get() if self is None: return None return self.compile_id @staticmethod - def current_trace_id() -> Optional[TraceId]: + def current_trace_id() -> TraceId | None: self = CompileContext.try_get() if self is None: return None @@ -858,13 +858,13 @@ def get() -> TracingContext: "TracingContext.get() must be called within an ongoing trace." ) - def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: + def __init__(self, fake_mode: FakeTensorMode | None) -> None: self.guards_context = GuardsContext() self.module_context = ModuleContext() self.global_context = GlobalContext() self.previously_inlined_functions: dict[Any, Any] = dict() self.previously_cleaned_instructions: dict[Any, Any] = dict() - self.fake_mode: Optional[FakeTensorMode] = fake_mode + self.fake_mode: FakeTensorMode | None = fake_mode self.frame_summary_stack: list[traceback.FrameSummary] = [] # This is morally part of frame_summary_stack, but it is kept separate # for clarity. As we process a frame, this variable gets updated @@ -872,16 +872,16 @@ def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: # function call, this gets cleared and the frame location is pushed # to frame_summary_stack (prepping this variable for the inner frame's # progress) - self.loc_in_frame: Optional[tuple[str, int, str]] = None + self.loc_in_frame: tuple[str, int, str] | None = None # this is only set after aot_autograd - self.fw_metadata: Optional[ViewAndMutationMeta] = None + self.fw_metadata: ViewAndMutationMeta | None = None # this is only set when the DDPOptimizer is used - self.ddp_optimizer_ctx: Optional[DDPOptimizerContext] = None + self.ddp_optimizer_ctx: DDPOptimizerContext | None = None # this is only set after aot_autograd - self.aot_graph_name: Optional[list[str]] = None - self.params_flat: Optional[list[Any]] = None - self.params_flat_unwrap_subclasses: Optional[list[Any]] = None - self.params_unwrapped_to_flat_index: Optional[list[Any]] = None + self.aot_graph_name: list[str] | None = None + self.params_flat: list[Any] | None = None + self.params_flat_unwrap_subclasses: list[Any] | None = None + self.params_unwrapped_to_flat_index: list[Any] | None = None # this is for extended return calling convention from backend # compiler to aot_autograd # Per output, what the compiler specified stride of the output is, @@ -985,7 +985,7 @@ def clear_frame() -> Generator[None, None, None]: @staticmethod @contextlib.contextmanager def current_frame( - frame_summary: Optional[traceback.FrameSummary], + frame_summary: traceback.FrameSummary | None, ) -> Generator[None, None, None]: # frame_summary can be None to solely take advantage of real_stack # attachment to thrown exceptions @@ -1008,7 +1008,7 @@ def current_frame( @staticmethod @contextlib.contextmanager def report_output_strides() -> Generator[ - Optional[list[Optional[tuple[int, ...]]]], None, None + list[tuple[int, ...] | None] | None, None, None ]: tc = TracingContext.try_get() if tc is None: @@ -1028,7 +1028,7 @@ def set_current_loc(filename: str, lineno: int, frame_name: str) -> None: TracingContext.get().loc_in_frame = (filename, lineno, frame_name) @staticmethod - def get_traced_code() -> Optional[list[CodeType]]: + def get_traced_code() -> list[CodeType] | None: tc = TracingContext.try_get() if tc is None: return None @@ -1037,8 +1037,8 @@ def get_traced_code() -> Optional[list[CodeType]]: @contextmanager def compile_context( - context: Optional[CompileContext], -) -> Generator[Optional[CompileContext], None, None]: + context: CompileContext | None, +) -> Generator[CompileContext | None, None, None]: old_context = getattr(_TLS, "compile_context", None) _TLS.compile_context = context try: @@ -1049,8 +1049,8 @@ def compile_context( @contextmanager def tracing( - context: Optional[TracingContext], -) -> Generator[Optional[TracingContext], None, None]: + context: TracingContext | None, +) -> Generator[TracingContext | None, None, None]: """ This function installs the passed in tracing context as a dynamic scoped global variable. @@ -1127,7 +1127,7 @@ def get_base(self) -> Source: return current -def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: +def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None: """ Attempts to "detect" what the current fake mode is. If there is one ambiently available from TracingContext, we preferentially use that. Otherwise, we @@ -1164,7 +1164,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: # pyrefly: ignore [bad-argument-type] fake_modes.append((flat_input.fake_mode, "fake tensor input", i)) if is_traceable_wrapper_subclass(flat_input): - out: list[Union[torch.Tensor, int, torch.SymInt]] = [] + out: list[torch.Tensor | int | torch.SymInt] = [] get_plain_tensors(flat_input, out=out) # type: ignore[arg-type] fake_tensors: list[FakeTensor] = [ x for x in out if isinstance(x, FakeTensor) @@ -1193,7 +1193,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: return None -def active_fake_mode() -> Optional[FakeTensorMode]: +def active_fake_mode() -> FakeTensorMode | None: """ Inspects the dispatch mode stack for an active fake mode and returns it. Returns None if no fake mode is active. diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 9efa0583cdea7..27c5768477dab 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -52,7 +52,7 @@ _P = ParamSpec("_P") _R = TypeVar("_R") -BuiltinUnionType: Union[type, tuple[type, ...]] = types.UnionType +BuiltinUnionType: type | tuple[type, ...] = types.UnionType LockType: type try: @@ -1236,7 +1236,7 @@ def _try_get_dispatched_fn(fn): def _get_named_tuple_properties( obj, - loc: Optional[torch._C._jit_tree_views.SourceRange] = None, + loc: torch._C._jit_tree_views.SourceRange | None = None, rcb=None, ): if loc is None: @@ -1531,7 +1531,7 @@ def _extract_tensors(obj): return tensors -def _get_model_id(obj) -> Optional[str]: +def _get_model_id(obj) -> str | None: if isinstance(obj, torch.jit.ScriptModule): return str(obj._c._type()) elif isinstance(obj, torch.jit.ScriptFunction): diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index 43c8b65767e00..213393da9aa99 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs """Various linear algebra utility methods for internal use.""" -from typing import Optional - import torch from torch import Tensor @@ -29,7 +27,7 @@ def get_floating_dtype(A): return torch.float32 -def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: +def matmul(A: Tensor | None, B: Tensor) -> Tensor: """Multiply two matrices. If A is None, return B. A can be sparse or dense. B is always @@ -42,12 +40,12 @@ def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: return torch.matmul(A, B) -def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: +def bform(X: Tensor, A: Tensor | None, Y: Tensor) -> Tensor: """Return bilinear form of matrices: :math:`X^T A Y`.""" return matmul(X.mT, matmul(A, Y)) -def qform(A: Optional[Tensor], S: Tensor): +def qform(A: Tensor | None, S: Tensor): """Return quadratic form :math:`S^T A S`.""" return bform(S, A, S) @@ -57,7 +55,7 @@ def basis(A): return torch.linalg.qr(A).Q -def symeig(A: Tensor, largest: Optional[bool] = False) -> tuple[Tensor, Tensor]: +def symeig(A: Tensor, largest: bool | None = False) -> tuple[Tensor, Tensor]: """Return eigenpairs of A with specified ordering.""" if largest is None: largest = False diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 1137efdc5f63a..cdc426047c33f 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -3,8 +3,6 @@ # Author: Pearu Peterson # Created: February 2020 -from typing import Optional - import torch from torch import _linalg_utils as _utils, Tensor from torch.overrides import handle_torch_function, has_torch_function @@ -258,19 +256,19 @@ class LOBPCGAutogradFunction(torch.autograd.Function): def forward( # type: ignore[override] ctx, A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, + k: int | None = None, + B: Tensor | None = None, + X: Tensor | None = None, + n: int | None = None, + iK: Tensor | None = None, + niter: int | None = None, + tol: float | None = None, + largest: bool | None = None, + method: str | None = None, tracker: None = None, - ortho_iparams: Optional[dict[str, int]] = None, - ortho_fparams: Optional[dict[str, float]] = None, - ortho_bparams: Optional[dict[str, bool]] = None, + ortho_iparams: dict[str, int] | None = None, + ortho_fparams: dict[str, float] | None = None, + ortho_bparams: dict[str, bool] | None = None, ) -> tuple[Tensor, Tensor]: # makes sure that input is contiguous for efficiency. # Note: autograd does not support dense gradients for sparse input yet. @@ -344,19 +342,19 @@ def backward(ctx, D_grad, U_grad): # pyrefly: ignore # bad-override def lobpcg( A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, + k: int | None = None, + B: Tensor | None = None, + X: Tensor | None = None, + n: int | None = None, + iK: Tensor | None = None, + niter: int | None = None, + tol: float | None = None, + largest: bool | None = None, + method: str | None = None, tracker: None = None, - ortho_iparams: Optional[dict[str, int]] = None, - ortho_fparams: Optional[dict[str, float]] = None, - ortho_bparams: Optional[dict[str, bool]] = None, + ortho_iparams: dict[str, int] | None = None, + ortho_fparams: dict[str, float] | None = None, + ortho_bparams: dict[str, bool] | None = None, ) -> tuple[Tensor, Tensor]: """Find the k largest (or smallest) eigenvalues and the corresponding eigenvectors of a symmetric positive definite generalized @@ -584,19 +582,19 @@ def lobpcg( def _lobpcg( A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, + k: int | None = None, + B: Tensor | None = None, + X: Tensor | None = None, + n: int | None = None, + iK: Tensor | None = None, + niter: int | None = None, + tol: float | None = None, + largest: bool | None = None, + method: str | None = None, tracker: None = None, - ortho_iparams: Optional[dict[str, int]] = None, - ortho_fparams: Optional[dict[str, float]] = None, - ortho_bparams: Optional[dict[str, bool]] = None, + ortho_iparams: dict[str, int] | None = None, + ortho_fparams: dict[str, float] | None = None, + ortho_bparams: dict[str, bool] | None = None, ) -> tuple[Tensor, Tensor]: # A must be square: assert A.shape[-2] == A.shape[-1], A.shape @@ -696,10 +694,10 @@ class LOBPCG: def __init__( self, - A: Optional[Tensor], - B: Optional[Tensor], + A: Tensor | None, + B: Tensor | None, X: Tensor, - iK: Optional[Tensor], + iK: Tensor | None, iparams: dict[str, int], fparams: dict[str, float], bparams: dict[str, bool], diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 182883cfc5e59..25089d66d35ea 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -2,7 +2,6 @@ __all__ = ["svd_lowrank", "pca_lowrank"] -from typing import Optional import torch from torch import _linalg_utils as _utils, Tensor @@ -12,8 +11,8 @@ def get_approximate_basis( A: Tensor, q: int, - niter: Optional[int] = 2, - M: Optional[Tensor] = None, + niter: int | None = 2, + M: Tensor | None = None, ) -> Tensor: """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is @@ -85,9 +84,9 @@ def get_approximate_basis( def svd_lowrank( A: Tensor, - q: Optional[int] = 6, - niter: Optional[int] = 2, - M: Optional[Tensor] = None, + q: int | None = 6, + niter: int | None = 2, + M: Tensor | None = None, ) -> tuple[Tensor, Tensor, Tensor]: r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, batches of matrices, or a sparse matrix :math:`A` such that @@ -149,9 +148,9 @@ def svd_lowrank( def _svd_lowrank( A: Tensor, - q: Optional[int] = 6, - niter: Optional[int] = 2, - M: Optional[Tensor] = None, + q: int | None = 6, + niter: int | None = 2, + M: Tensor | None = None, ) -> tuple[Tensor, Tensor, Tensor]: # Algorithm 5.1 in Halko et al., 2009 @@ -183,7 +182,7 @@ def _svd_lowrank( def pca_lowrank( A: Tensor, - q: Optional[int] = None, + q: int | None = None, center: bool = True, niter: int = 2, ) -> tuple[Tensor, Tensor, Tensor]: diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index cd397a0bc29c9..0055bdd77f315 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Sequence from enum import Enum from functools import wraps -from typing import Optional, TypeVar, Union +from typing import TypeVar from typing_extensions import ParamSpec import torch @@ -547,9 +547,9 @@ def meta_sparse_structured_linear( input: Tensor, weight: Tensor, _meta: Tensor, - bias: Optional[Tensor] = None, - _activation_opt: Optional[str] = None, - out_dtype: Optional[torch.dtype] = None, + bias: Tensor | None = None, + _activation_opt: str | None = None, + out_dtype: torch.dtype | None = None, ): output_sizes = list(input.shape) if bias is not None: @@ -581,7 +581,7 @@ def meta_sparse_structured_mm( mat1: Tensor, mat1_meta: Tensor, mat2: Tensor, - out_dtype: Optional[torch.dtype] = None, + out_dtype: torch.dtype | None = None, ): assert len(mat1.shape) == 2 assert len(mat1_meta.shape) == 2 @@ -610,7 +610,7 @@ def meta_sparse_structured_addmm( *, alpha=1, beta=1, - out_dtype: Optional[torch.dtype] = None, + out_dtype: torch.dtype | None = None, ): assert len(input.shape) == 1, ( "only input broadcasted to columns of mat1 * mat2 product is supported" @@ -640,9 +640,9 @@ def meta_sparse_structured_addmm( def meta__cslt_sparse_mm( compressed_A: torch.Tensor, dense_B: torch.Tensor, - bias: Optional[Tensor] = None, - alpha: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + bias: Tensor | None = None, + alpha: Tensor | None = None, + out_dtype: torch.dtype | None = None, transpose_result: bool = False, alg_id: int = 0, split_k: int = 1, @@ -724,9 +724,9 @@ def meta_segment_reduce( data: Tensor, reduce: str, *, - lengths: Optional[Tensor] = None, - indices: Optional[Tensor] = None, - offsets: Optional[Tensor] = None, + lengths: Tensor | None = None, + indices: Tensor | None = None, + offsets: Tensor | None = None, axis: int = 0, unsafe: bool = False, initial=None, @@ -1468,7 +1468,7 @@ def _linalg_svd_meta( A: Tensor, full_matrices: bool = False, compute_uv: bool = True, - driver: Optional[str] = None, + driver: str | None = None, ): checkIsMatrix(A, "linalg.svd") checkFloatingOrComplex(A, "linalg.svd") @@ -1521,7 +1521,7 @@ def _linalg_broadcast_batch_dims( def _linalg_broadcast_batch_dims_name( arg1: Tensor, arg2: Tensor, - name: Optional[str], + name: str | None, ) -> tuple[Tensor, Tensor]: # If there's no name we assume we don't want to check the errors if name: @@ -1553,10 +1553,10 @@ def _linalg_solve_ex( *, left: bool = True, check_errors: bool = False, - result: Optional[Tensor] = None, - LU: Optional[Tensor] = None, - pivots: Optional[Tensor] = None, - info: Optional[Tensor] = None, + result: Tensor | None = None, + LU: Tensor | None = None, + pivots: Tensor | None = None, + info: Tensor | None = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: checkFloatingOrComplex(A, "linalg.solve") torch._check( @@ -1613,7 +1613,7 @@ def linalg_solve_triangular_meta( upper: bool, left: bool = True, unitriangular: bool = False, - out: Optional[Tensor] = None, + out: Tensor | None = None, ) -> Tensor: if out is None: out = A.new_empty([0]) @@ -2264,7 +2264,7 @@ def meta__fused_moving_avg_obs_fq_helper( @register_meta(aten.mm) @out_wrapper(exact_dtype=True) -def meta_mm(a, b, out_dtype: Optional[torch.dtype] = None): +def meta_mm(a, b, out_dtype: torch.dtype | None = None): torch._check(a.dim() == 2, lambda: "a must be 2D") torch._check(b.dim() == 2, lambda: "b must be 2D") N, M1 = a.shape @@ -2313,12 +2313,12 @@ def device_hint(tensor) -> "str": def calc_conv_nd_return_shape( input_tensor: torch.Tensor, weight: torch.Tensor, - stride: Union[list[int], int], - padding: Union[list[int], int], - dilation: Union[list[int], int], + stride: list[int] | int, + padding: list[int] | int, + dilation: list[int] | int, is_transposed: bool, groups: int, - output_padding: Optional[Union[list[int], int]] = None, + output_padding: list[int] | int | None = None, ): def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ @@ -2384,7 +2384,7 @@ def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int elif len(dilation) == 1: dilation = [dilation[0]] * len(dims) - output_padding_list: Optional[list[int]] = None + output_padding_list: list[int] | None = None if output_padding: if isinstance(output_padding, IntLike): # pyrefly: ignore [bad-assignment] @@ -2435,9 +2435,9 @@ def is_channels_last(ten): def meta_miopen_batch_norm( input_tensor: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor], - running_mean: Optional[torch.Tensor], - running_var: Optional[torch.Tensor], + bias: torch.Tensor | None, + running_mean: torch.Tensor | None, + running_var: torch.Tensor | None, training: bool, exponential_average_factor: float, epsilon: float, @@ -3383,7 +3383,7 @@ def meta_index_Tensor(self, indices): torch._check(bool(indices), lambda: "at least one index must be provided") # aten::index is the internal advanced indexing implementation # checkIndexTensorTypes and expandTensors - result: list[Optional[Tensor]] = [] + result: list[Tensor | None] = [] for i, index in enumerate(indices): if index is not None: torch._check( @@ -3853,7 +3853,7 @@ def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs): @register_meta([aten._dyn_quant_pack_4bit_weight]) def meta__dyn_quant_pack_4bit_weight( - weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features + weights, scales_zeros, bias: Tensor | None, block_size, in_features, out_features ): torch._check( weights.dtype is torch.uint8, @@ -5655,7 +5655,7 @@ def meta__scaled_dot_product_flash_attention( dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -5737,12 +5737,12 @@ def meta__scaled_dot_product_cudnn_attention( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor], + attn_bias: Tensor | None, compute_log_sumexp: bool, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): B = query.size(0) H = query.size(1) @@ -5781,11 +5781,11 @@ def meta__scaled_dot_product_fused_attention_overrideable( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor] = None, + attn_bias: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): B = query.size(0) H_Q = query.size(1) @@ -5839,7 +5839,7 @@ def meta__scaled_dot_product_flash_backward( is_causal: bool, philox_seed: Tensor, philox_offset: Tensor, - scale: Optional[float] = None, + scale: float | None = None, ): grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2) grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2) @@ -5858,8 +5858,8 @@ def meta__scaled_dot_product_flash_attention_for_cpu( value: Tensor, dropout_p: float = 0.0, is_causal: bool = False, - attn_mask: Optional[Tensor] = None, - scale: Optional[float] = None, + attn_mask: Tensor | None = None, + scale: float | None = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -5895,8 +5895,8 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( logsumexp: Tensor, dropout_p: float, is_causal: bool, - attn_mask: Optional[Tensor] = None, - scale: Optional[float] = None, + attn_mask: Tensor | None = None, + scale: float | None = None, ): # cpus's grad layout is different from cuda's, # i.e. (batch_size, seq_len, num_heads, head_dim) @@ -5927,11 +5927,11 @@ def meta__scaled_dot_product_attention_math_for_mps( query: Tensor, key: Tensor, value: Tensor, - attn_mask: Optional[Tensor] = None, + attn_mask: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, - dropout_mask: Optional[Tensor] = None, - scale: Optional[float] = None, + dropout_mask: Tensor | None = None, + scale: float | None = None, ) -> tuple[Tensor, Tensor]: def ensure_4d(x): if x.dim() == 3: @@ -5982,11 +5982,11 @@ def meta__scaled_dot_product_efficient_attention( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor], + attn_bias: Tensor | None, compute_log_sumexp: bool, dropout_p=0.0, is_causal: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -6032,7 +6032,7 @@ def meta__scaled_dot_product_efficient_backward( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor], + attn_bias: Tensor | None, out: Tensor, logsumexp: Tensor, philox_seed: Tensor, @@ -6040,7 +6040,7 @@ def meta__scaled_dot_product_efficient_backward( dropout_p: float, grad_input_mask: list[bool], is_causal: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -6103,7 +6103,7 @@ def meta__scaled_dot_product_cudnn_backward( max_k: int, dropout_p: float, is_causal: bool, - scale: Optional[float] = None, + scale: float | None = None, ): grad_q = torch.empty_like(query) grad_k = torch.empty_like(key) @@ -6120,18 +6120,18 @@ def meta__flash_attention_forward( query: Tensor, key: Tensor, value: Tensor, - cum_seq_q: Optional[Tensor], - cum_seq_k: Optional[Tensor], + cum_seq_q: Tensor | None, + cum_seq_k: Tensor | None, max_q: int, max_k: int, dropout_p: float, is_causal: bool, return_debug_mask: bool, - scale: Optional[float] = None, - window_size_left: Optional[int] = None, - window_size_right: Optional[int] = None, - seqused_k: Optional[Tensor] = None, - alibi_slopes: Optional[Tensor] = None, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, + seqused_k: Tensor | None = None, + alibi_slopes: Tensor | None = None, ): # NB: there are two underlying paths: # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) @@ -6211,9 +6211,9 @@ def meta__flash_attention_backward( is_causal: bool, philox_seed: Tensor, philox_offset: Tensor, - scale: Optional[float] = None, - window_size_left: Optional[int] = None, - window_size_right: Optional[int] = None, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, ): grad_query = torch.empty_like(query) grad_key = torch.empty_like(key) @@ -6231,18 +6231,18 @@ def meta__efficient_attention_forward( query: Tensor, key: Tensor, value: Tensor, - bias: Optional[Tensor], - cu_seqlens_q: Optional[Tensor], - cu_seqlens_k: Optional[Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], + bias: Tensor | None, + cu_seqlens_q: Tensor | None, + cu_seqlens_k: Tensor | None, + max_seqlen_q: int | None, + max_seqlen_k: int | None, dropout_p: float, custom_mask_type: int, compute_log_sumexp: bool = False, - scale: Optional[float] = None, - causal_diagonal: Optional[Tensor] = None, - seqlen_k: Optional[Tensor] = None, - window_size: Optional[int] = None, + scale: float | None = None, + causal_diagonal: Tensor | None = None, + seqlen_k: Tensor | None = None, + window_size: int | None = None, ): B = query.size(0) M = query.size(1) @@ -6284,9 +6284,9 @@ def meta__efficient_attention_backward( query: Tensor, key: Tensor, value: Tensor, - bias: Optional[Tensor], - cu_seqlens_q: Optional[Tensor], - cu_seqlens_k: Optional[Tensor], + bias: Tensor | None, + cu_seqlens_q: Tensor | None, + cu_seqlens_k: Tensor | None, max_seqlen_q: torch.SymInt, max_seqlen_k: torch.SymInt, logsumexp: Tensor, @@ -6295,8 +6295,8 @@ def meta__efficient_attention_backward( philox_offset: Tensor, custom_mask_type: int, bias_requires_grad: bool, - scale: Optional[float] = None, - num_splits_key: Optional[int] = None, + scale: float | None = None, + num_splits_key: int | None = None, shared_storage_dqdkdv: bool = False, ): if shared_storage_dqdkdv: @@ -6339,9 +6339,9 @@ def _check_scaled_mm_sizes( mat2: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - bias: Optional[torch.Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): def is_fp8_or_fp4_type(dtype): @@ -6520,9 +6520,9 @@ def meta_scaled_mm( mat2: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - bias: Optional[torch.Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): return _check_scaled_mm_sizes( @@ -6537,10 +6537,10 @@ def _check_scaled_mm_sizes_v2( scale_recipe_a: list[ScalingType], scale_b: list[torch.Tensor], scale_recipe_b: list[ScalingType], - bias: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - swizzle_a: Optional[list[SwizzleType]] = None, - swizzle_b: Optional[list[SwizzleType]] = None, + bias: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + swizzle_a: list[SwizzleType] | None = None, + swizzle_b: list[SwizzleType] | None = None, use_fast_accum: bool = False, ): def is_fp8_or_fp4_type(dtype): @@ -6872,9 +6872,9 @@ def meta_scaled_mm_v2( scale_b: list[torch.Tensor], scale_recipe_b: list[ScalingType], swizzle_b: list[SwizzleType], - bias: Optional[torch.Tensor] = None, - output_dtype: Optional[torch.dtype] = None, - contraction_dims: Optional[list[int]] = None, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, + contraction_dims: list[int] | None = None, use_fast_accum: bool = False, ): return _check_scaled_mm_sizes_v2( @@ -6997,10 +6997,10 @@ def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None): ) def upsample_nearest2d_backward( grad_output: Tensor, - output_size: Sequence[Union[int, torch.SymInt]], - input_size: Sequence[Union[int, torch.SymInt]], - scales_h: Optional[float] = None, - scales_w: Optional[float] = None, + output_size: Sequence[int | torch.SymInt], + input_size: Sequence[int | torch.SymInt], + scales_h: float | None = None, + scales_w: float | None = None, ): full_output_size = upsample_common_check( input_size, output_size, num_spatial_dims=2 @@ -7842,12 +7842,12 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): def _meta_grouped_mm_common( mat_a: Tensor, mat_b: Tensor, - scale_a: Optional[torch.Tensor], - scale_b: Optional[torch.Tensor], - offs: Optional[Tensor] = None, - bias: Optional[Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + scale_a: torch.Tensor | None, + scale_b: torch.Tensor | None, + offs: Tensor | None = None, + bias: Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): torch._check( @@ -8055,9 +8055,9 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): def meta_grouped_mm( mat_a: Tensor, mat_b: Tensor, - offs: Optional[Tensor] = None, - bias: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + offs: Tensor | None = None, + bias: Tensor | None = None, + out_dtype: torch.dtype | None = None, ) -> Tensor: return _meta_grouped_mm_common( mat_a, @@ -8077,10 +8077,10 @@ def meta_scaled_grouped_mm( mat_b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - offs: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + offs: torch.Tensor | None = None, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): # matching _scaled_grouped_mm_cuda Blas.cpp implementation diff --git a/torch/_ops.py b/torch/_ops.py index 8f8a7328429fa..23108117a9870 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -8,16 +8,7 @@ import types from collections.abc import Callable, Iterator from functools import cached_property -from typing import ( - Any, - ClassVar, - Concatenate, - final, - Generic, - Optional, - TYPE_CHECKING, - Union, -) +from typing import Any, ClassVar, Concatenate, final, Generic, TYPE_CHECKING from typing_extensions import ParamSpec, TypeVar import torch @@ -79,9 +70,7 @@ def __init__(self): # for use with OpOverload; cache lookup is done entirely from C++ # for speed. # TODO: The cache is NOT currently used by HigherOrderOperator, but it should! - self._dispatch_cache: dict[ - DispatchKey, Union[DispatchKey, Callable[..., Any]] - ] = {} + self._dispatch_cache: dict[DispatchKey, DispatchKey | Callable[..., Any]] = {} # This table allows you to override the behavior of a particular # dispatch key to call a custom Python function, rather than the @@ -99,7 +88,7 @@ def __init__(self): # makes sense that you should be able to register them, the same # way you can register dispatch keys. self.python_key_table: dict[ - type[Union[TorchDispatchMode, torch.Tensor]], Callable[..., Any] + type[TorchDispatchMode | torch.Tensor], Callable[..., Any] ] = {} # This table allows you to override the behavior of functorch @@ -121,12 +110,7 @@ def has_kernel_for_any_dispatch_key(self, ks): def py_impl( self, - k: Union[ - type[TorchDispatchMode], - type[torch.Tensor], - TransformType, - DispatchKey, - ], + k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]: if inspect.isclass(k) and ( @@ -185,7 +169,7 @@ def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: return fn(CppFunctionalizeAPI(), *args, **kwargs) def functionalize_dispatch_mode_fn( - mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs + mode: FunctionalTensorMode | None, *args: _P.args, **kwargs: _P.kwargs ) -> _T: return fn(PythonFunctionalizeAPI(mode), *args, **kwargs) @@ -307,12 +291,7 @@ def __init__(self, name, *, cacheable=False): def py_impl( self, - k: Union[ - type[TorchDispatchMode], - type[torch.Tensor], - TransformType, - DispatchKey, - ], + k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k): self.non_fallthrough_keys = self.non_fallthrough_keys.add(k) @@ -894,7 +873,7 @@ def _uncache_dispatch(self, key: DispatchKey) -> None: self._dispatch_cache.pop(key, None) # This implements the pre-computation logic for the Python dispatcher. - def _get_dispatch(self, key: DispatchKey) -> Union[DispatchKey, Callable[_P, _T]]: + def _get_dispatch(self, key: DispatchKey) -> DispatchKey | Callable[_P, _T]: # This is only called upon a cache miss assert key not in self._dispatch_cache, f"{self} {key}" diff --git a/torch/_sources.py b/torch/_sources.py index 1327729a717b1..e0ab883a8b46c 100644 --- a/torch/_sources.py +++ b/torch/_sources.py @@ -3,7 +3,7 @@ import functools import inspect from textwrap import dedent -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple from torch._C import ErrorReport from torch._C._jit_tree_views import SourceRangeFactory @@ -11,8 +11,8 @@ def get_source_lines_and_file( obj: Any, - error_msg: Optional[str] = None, -) -> tuple[list[str], int, Optional[str]]: + error_msg: str | None = None, +) -> tuple[list[str], int, str | None]: """ Wrapper around inspect.getsourcelines and inspect.getsourcefile. @@ -113,7 +113,7 @@ class ParsedDef(NamedTuple): ast: ast.Module ctx: SourceContext source: str - filename: Optional[str] + filename: str | None file_lineno: int diff --git a/torch/_tensor.py b/torch/_tensor.py index c6351ed75ffcb..c1093f35aa984 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -8,7 +8,7 @@ from collections.abc import Callable from copy import deepcopy from numbers import Number -from typing import Any, cast, Concatenate, Optional, TypeVar, Union +from typing import Any, cast, Concatenate, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -180,10 +180,10 @@ def __deepcopy__(self, memo): new_storage = self._typed_storage()._deepcopy(memo) if self.is_quantized: # quantizer_params can be different type based on torch attribute - quantizer_params: Union[ - tuple[torch.qscheme, float, int], - tuple[torch.qscheme, Tensor, Tensor, int], - ] + quantizer_params: ( + tuple[torch.qscheme, float, int] + | tuple[torch.qscheme, Tensor, Tensor, int] + ) if self.qscheme() == torch.per_tensor_affine: quantizer_params = ( self.qscheme(), @@ -366,9 +366,9 @@ def _reduce_ex_internal(self, proto): "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" ) # quantizer_params can be different type based on torch attribute - quantizer_params: Union[ - tuple[torch.qscheme, float, int], tuple[Any, Tensor, Tensor, int] - ] + quantizer_params: ( + tuple[torch.qscheme, float, int] | tuple[Any, Tensor, Tensor, int] + ) if self.qscheme() == torch.per_tensor_affine: quantizer_params = ( torch.per_tensor_affine, @@ -893,7 +893,7 @@ def __reversed__(self): def norm( self, - p: Optional[Union[float, str]] = "fro", + p: float | str | None = "fro", dim=None, keepdim=False, dtype=None, @@ -944,15 +944,15 @@ def lu(self, pivot=True, get_infos=False): def stft( self, n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: "Optional[Tensor]" = None, + hop_length: int | None = None, + win_length: int | None = None, + window: "Tensor | None" = None, center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - onesided: Optional[bool] = None, - return_complex: Optional[bool] = None, - align_to_window: Optional[bool] = None, + onesided: bool | None = None, + return_complex: bool | None = None, + align_to_window: bool | None = None, ): r"""See :func:`torch.stft` @@ -993,13 +993,13 @@ def stft( def istft( self, n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: "Optional[Tensor]" = None, + hop_length: int | None = None, + win_length: int | None = None, + window: "Tensor | None" = None, center: bool = True, normalized: bool = False, - onesided: Optional[bool] = None, - length: Optional[int] = None, + onesided: bool | None = None, + length: int | None = None, return_complex: bool = False, ): r"""See :func:`torch.istft`""" @@ -1528,9 +1528,7 @@ def to_sparse_coo(self): """ return self.to_sparse() - def dim_order( - self, *, ambiguity_check: Union[bool, list[torch.memory_format]] = False - ): + def dim_order(self, *, ambiguity_check: bool | list[torch.memory_format] = False): """ dim_order(ambiguity_check=False) -> tuple @@ -1712,10 +1710,10 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def __dlpack__( self, *, - stream: Optional[Any] = -1, - max_version: Optional[tuple[int, int]] = None, - dl_device: Optional[tuple[enum.IntEnum, int]] = None, - copy: Optional[bool] = None, + stream: Any | None = -1, + max_version: tuple[int, int] | None = None, + dl_device: tuple[enum.IntEnum, int] | None = None, + copy: bool | None = None, ): """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 613fa9ad6ff95..46af738829312 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -3,7 +3,7 @@ import dataclasses import math import textwrap -from typing import Any, Optional +from typing import Any import torch from torch import inf @@ -15,7 +15,7 @@ class __PrinterOptions: threshold: float = 1000 edgeitems: int = 3 linewidth: int = 80 - sci_mode: Optional[bool] = None + sci_mode: bool | None = None PRINT_OPTS = __PrinterOptions() diff --git a/torch/_utils.py b/torch/_utils.py index 01cf9d393188b..70641a7c534d7 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Callable from types import ModuleType -from typing import Any, Generic, Optional, TYPE_CHECKING +from typing import Any, Generic, TYPE_CHECKING from typing_extensions import deprecated, ParamSpec import torch @@ -856,7 +856,7 @@ def _get_device_index( """ if isinstance(device, str): device = torch.device(device) - device_idx: Optional[int] = None + device_idx: int | None = None if isinstance(device, torch.device): if not allow_cpu and device.type == "cpu": raise ValueError(f"Expected a non cpu device, but got: {device}") @@ -1054,7 +1054,7 @@ def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None: ) -def try_import(module_name: str) -> Optional[ModuleType]: +def try_import(module_name: str) -> ModuleType | None: # Implementation based on # https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported if (module := sys.modules.get(module_name, None)) is not None: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 3a172a814e2e5..6f95511b5ce80 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -6,7 +6,7 @@ import tempfile import typing_extensions from collections.abc import Callable -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch @@ -255,7 +255,7 @@ def max_clock_rate(): return 1100 -def get_mast_job_name_version() -> Optional[tuple[str, int]]: +def get_mast_job_name_version() -> tuple[str, int] | None: return None @@ -274,7 +274,7 @@ def get_mast_job_name_version() -> Optional[tuple[str, int]]: REQUIRES_SET_PYTHON_MODULE = False -def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]: +def maybe_upload_prof_stats_to_manifold(profile_path: str) -> str | None: print("Uploading profile stats (fb-only otherwise no-op)") return None @@ -367,11 +367,11 @@ def get_default_numa_options(): return None -def log_triton_builds(fail: Optional[str]): +def log_triton_builds(fail: str | None): pass -def find_compile_subproc_binary() -> Optional[str]: +def find_compile_subproc_binary() -> str | None: """ Allows overriding the binary used for subprocesses """ diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 3f303f78a4713..861d4fd4b4153 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools from collections.abc import Callable -from typing import Any, Optional, Union +from typing import Any from typing_extensions import deprecated import torch @@ -9,13 +9,13 @@ from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten -in_dims_t = Union[int, tuple] -out_dims_t = Union[int, tuple[int, ...]] +in_dims_t = int | tuple +out_dims_t = int | tuple[int, ...] # Checks that all args-to-be-batched have the same batch dim size def _validate_and_get_batch_size( - flat_in_dims: list[Optional[int]], + flat_in_dims: list[int | None], flat_args: list, ) -> int: batch_sizes = [ @@ -31,7 +31,7 @@ def _validate_and_get_batch_size( return batch_sizes[0] -def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int: +def _num_outputs(batched_outputs: Tensor | tuple[Tensor, ...]) -> int: if isinstance(batched_outputs, tuple): return len(batched_outputs) return 1 @@ -115,7 +115,7 @@ def _create_batched_inputs( # Undos the batching (and any batch dimensions) associated with the `vmap_level`. def _unwrap_batched( - batched_outputs: Union[Tensor, tuple[Tensor, ...]], + batched_outputs: Tensor | tuple[Tensor, ...], out_dims: out_dims_t, vmap_level: int, batch_size: int, diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 5aaa77b25697a..a4c8aaafa351b 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -69,7 +69,7 @@ ) from struct import unpack from sys import maxsize -from typing import Any, Union +from typing import Any import torch from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING @@ -84,15 +84,15 @@ "nt", ] -_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set() +_marked_safe_globals_set: set[Callable | tuple[Callable, str]] = set() -def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]): +def _add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]): global _marked_safe_globals_set _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals)) -def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: +def _get_safe_globals() -> list[Callable | tuple[Callable, str]]: global _marked_safe_globals_set return list(_marked_safe_globals_set) @@ -103,14 +103,14 @@ def _clear_safe_globals(): def _remove_safe_globals( - globals_to_remove: list[Union[Callable, tuple[Callable, str]]], + globals_to_remove: list[Callable | tuple[Callable, str]], ): global _marked_safe_globals_set _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove) class _safe_globals: - def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]): + def __init__(self, safe_globals: list[Callable | tuple[Callable, str]]): self.safe_globals = safe_globals def __enter__(self): diff --git a/torch/functional.py b/torch/functional.py index 013832d59cfb3..33b0ada75324c 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -2,7 +2,7 @@ import itertools import operator from collections.abc import Sequence -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import torch import torch.nn.functional as F @@ -120,7 +120,7 @@ def broadcast_shapes(*shapes): def split( tensor: Tensor, - split_size_or_sections: Union[int, list[int]], + split_size_or_sections: int | list[int], dim: int = 0, ) -> tuple[Tensor, ...]: r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. @@ -387,13 +387,13 @@ def parse_subscript(n: int) -> str: if TYPE_CHECKING: # The JIT doesn't understand Union, so only add type annotation for mypy def meshgrid( - *tensors: Union[Tensor, list[Tensor]], indexing: Optional[str] = None + *tensors: Tensor | list[Tensor], indexing: str | None = None ) -> tuple[Tensor, ...]: return _meshgrid(*tensors, indexing=indexing) else: - def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]: + def meshgrid(*tensors, indexing: str | None = None) -> tuple[Tensor, ...]: r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors. This is helpful when you want to visualize data over some @@ -490,7 +490,7 @@ def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]: return _meshgrid(*tensors, indexing=indexing) -def _meshgrid(*tensors, indexing: Optional[str]): +def _meshgrid(*tensors, indexing: str | None): if has_torch_function(tensors): return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing) if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): @@ -508,15 +508,15 @@ def _meshgrid(*tensors, indexing: Optional[str]): def stft( input: Tensor, n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: Optional[Tensor] = None, + hop_length: int | None = None, + win_length: int | None = None, + window: Tensor | None = None, center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - onesided: Optional[bool] = None, - return_complex: Optional[bool] = None, - align_to_window: Optional[bool] = None, + onesided: bool | None = None, + return_complex: bool | None = None, + align_to_window: bool | None = None, ) -> Tensor: r"""Short-time Fourier transform (STFT). @@ -788,7 +788,7 @@ def _unique_impl( sorted: bool = True, return_inverse: bool = False, return_counts: bool = False, - dim: Optional[int] = None, + dim: int | None = None, ) -> _unique_impl_out: r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor] @@ -956,7 +956,7 @@ def _unique_consecutive_impl( input: Tensor, return_inverse: bool = False, return_counts: bool = False, - dim: Optional[int] = None, + dim: int | None = None, ) -> _unique_impl_out: r"""Eliminates all but the first element from every consecutive group of equivalent elements. @@ -1201,7 +1201,7 @@ def tensordot( a, b, dims: int = 2, - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1210,7 +1210,7 @@ def tensordot( # noqa: F811 a, b, dims: tuple[list[int], list[int]], - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1219,7 +1219,7 @@ def tensordot( # noqa: F811 a, b, dims: list[list[int]], - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1228,7 +1228,7 @@ def tensordot( # noqa: F811 a, b, dims: torch.Tensor, - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1237,7 +1237,7 @@ def tensordot( # noqa: F811 a, b, dims=2, - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): r"""Returns a contraction of a and b over multiple dimensions. @@ -1659,7 +1659,7 @@ def norm( # noqa: F811 def norm( # noqa: F811 input, - p: Optional[Union[float, str]] = "fro", + p: float | str | None = "fro", dim=None, keepdim=False, out=None, @@ -1882,7 +1882,7 @@ def norm( # noqa: F811 def unravel_index( indices: Tensor, - shape: Union[int, Sequence[int], torch.Size], + shape: int | Sequence[int] | torch.Size, ) -> tuple[Tensor, ...]: r"""Converts a tensor of flat indices into a tuple of coordinate tensors that index into an arbitrary tensor of the specified shape. @@ -1938,7 +1938,7 @@ def unravel_index( return res_tensor.unbind(-1) -def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: +def _unravel_index(indices: Tensor, shape: int | Sequence[int]) -> Tensor: torch._check_type( not indices.is_complex() and not indices.is_floating_point() diff --git a/torch/hub.py b/torch/hub.py index bf138f7784347..3ec285fcb3a9e 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -12,7 +12,7 @@ import warnings import zipfile from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from typing_extensions import deprecated from urllib.error import HTTPError, URLError from urllib.parse import urlparse # noqa: F401 @@ -91,7 +91,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): VAR_DEPENDENCY = "dependencies" MODULE_HUBCONF = "hubconf.py" READ_DATA_CHUNK = 128 * 1024 -_hub_dir: Optional[str] = None +_hub_dir: str | None = None @contextlib.contextmanager @@ -417,7 +417,7 @@ def get_dir() -> str: return os.path.join(_get_torch_home(), "hub") -def set_dir(d: Union[str, os.PathLike]) -> None: +def set_dir(d: str | os.PathLike) -> None: r""" Optionally set the Torch Hub directory used to save downloaded models & weights. @@ -694,7 +694,7 @@ def _load_local(hubconf_dir, model, *args, **kwargs): def download_url_to_file( url: str, dst: str, - hash_prefix: Optional[str] = None, + hash_prefix: str | None = None, progress: bool = True, ) -> None: r"""Download object at the given URL to a local path. @@ -816,11 +816,11 @@ def _legacy_zip_load( def load_state_dict_from_url( url: str, - model_dir: Optional[str] = None, + model_dir: str | None = None, map_location: MAP_LOCATION = None, progress: bool = True, check_hash: bool = False, - file_name: Optional[str] = None, + file_name: str | None = None, weights_only: bool = False, ) -> dict[str, Any]: r"""Loads the Torch serialized object at the given URL. diff --git a/torch/library.py b/torch/library.py index 76e5d27aae434..5305d647bc613 100644 --- a/torch/library.py +++ b/torch/library.py @@ -7,7 +7,7 @@ import traceback import weakref from collections.abc import Callable, Sequence -from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union +from typing import Any, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import deprecated, ParamSpec import torch @@ -98,7 +98,7 @@ def __init__(self, ns, kind, dispatch_key=""): frame = traceback.extract_stack(limit=2)[0] filename, lineno = frame.filename, frame.lineno - self.m: Optional[Any] = torch._C._dispatch_library( + self.m: Any | None = torch._C._dispatch_library( kind, ns, dispatch_key, filename, lineno ) self.ns = ns @@ -399,7 +399,7 @@ def fallback(self, fn, dispatch_key="", *, with_keyset=False): self.m.fallback(dispatch_key, fn, with_keyset) - def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]): + def _register_effectful_op(self, op_name: str, effect: EffectType | None): """ Registers an effect to an operator. This is used to register an op that has side effects that is not capturable by the schema. @@ -570,20 +570,20 @@ def wrap(f): @overload def impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> Callable[[Callable[..., object]], None]: ... @overload def impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: Callable[..., object], *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> None: ... @@ -599,10 +599,10 @@ def impl( @functools.singledispatch def impl( qualname: str, - types: Union[str, Sequence[str]], - func: Optional[Callable[_P, _T]] = None, + types: str | Sequence[str], + func: Callable[_P, _T] | None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> object: """Register an implementation for a device type for this operator. @@ -683,10 +683,10 @@ def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]: @overload def _impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, disable_dynamo: bool = False, ) -> Callable[[Callable[..., object]], None]: ... @@ -694,22 +694,22 @@ def _impl( @overload def _impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: Callable[..., object], *, - lib: Optional[Library] = None, + lib: Library | None = None, disable_dynamo: bool = False, ) -> None: ... def _impl( qualname: str, - types: Union[str, Sequence[str]], - func: Optional[Callable[..., object]] = None, + types: str | Sequence[str], + func: Callable[..., object] | None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, disable_dynamo: bool = False, -) -> Optional[Callable[[Callable[..., object]], None]]: +) -> Callable[[Callable[..., object]], None] | None: # See impl() if isinstance(types, str): types = (types,) @@ -786,10 +786,10 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): def register_kernel( op: _op_identifier, device_types: device_types_t, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, ): """Register an implementation for a device type for this operator. @@ -857,7 +857,7 @@ def register_autocast( cast_inputs: _dtype, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, ): r"""Register an autocast dispatch rule for this custom op. @@ -948,10 +948,10 @@ def kernel(_, *args, **kwargs): def register_fake( op: _op_identifier, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, _stacklevel: int = 1, allow_override: bool = False, ): @@ -1084,9 +1084,9 @@ def register(func): def _register_effectful_op( op: _op_identifier, - effect: Optional[EffectType], + effect: EffectType | None, *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> None: r""" To specify that an operator has side-effects, we must register an effect @@ -1125,7 +1125,7 @@ def register_autograd( backward: Callable, /, *, - setup_context: Optional[Callable] = None, + setup_context: Callable | None = None, lib=None, ) -> None: r"""Register a backward formula for this custom op. @@ -1253,10 +1253,10 @@ def register_autograd( def register_torch_dispatch( op: _op_identifier, torch_dispatch_class: Any, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, ): r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. @@ -1333,7 +1333,7 @@ def register(func): def register_vmap( op: _op_identifier, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, lib=None, @@ -1525,7 +1525,7 @@ def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": def get_kernel( - op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey] + op: _op_identifier, dispatch_key: str | torch.DispatchKey ) -> torch._C._SafeKernelFunction: """Returns the computed kernel for a given operator and dispatch key. @@ -1607,11 +1607,11 @@ def get_kernel( def opcheck( - op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], + op: torch._ops.OpOverload | torch._ops.OpOverloadPacket | CustomOpDef, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, *, - test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS, + test_utils: str | Sequence[str] = _OPCHECK_DEFAULT_UTILS, raise_exception: bool = True, atol=None, rtol=None, diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 4bae914f0292b..dd3ff69fd6af8 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import warnings from collections.abc import Callable -from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar, Union +from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar from typing_extensions import ParamSpec import torch @@ -16,7 +16,7 @@ from torch._prims_common import DimsType from torch.types import _dtype as DType - DimOrDims: TypeAlias = Optional[DimsType] + DimOrDims: TypeAlias = DimsType | None else: # The JIT doesn't understand Union, nor torch.dtype here DType = int @@ -624,7 +624,7 @@ def _sparse_coo_scatter_reduction_helper( mask_input: Tensor, dims: tuple[int, ...], keepdim: bool, - dtype: Optional[DType] = None, + dtype: DType | None = None, ) -> Tensor: reduce = op.__name__ valid_reductions = ["sum", "prod", "amax", "amin"] @@ -744,7 +744,7 @@ def _sparse_csr_segment_reduction_helper( mask_input: Tensor, dims: tuple[int, ...], keepdim: bool, - dtype: Optional[DType] = None, + dtype: DType | None = None, ) -> Tensor: # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True # FIXME: when dense dimensions are implemented for CSR tensors @@ -869,7 +869,7 @@ def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: ) -def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor: +def _input_mask(input: Tensor | MaskedTensor, *args, **kwargs) -> Tensor: """Return canonical input mask. A canonical input mask is defined as a boolean mask tensor that @@ -1000,9 +1000,7 @@ def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor: ) -def _combine_input_and_mask( - op, input: Union[MaskedTensor, Tensor], mask, *args -) -> Tensor: +def _combine_input_and_mask(op, input: MaskedTensor | Tensor, mask, *args) -> Tensor: def helper(input, mask): if mask is None: return input @@ -1046,12 +1044,12 @@ def backward(ctx, grad_output): @_apply_docstring_templates def sum( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: @@ -1099,12 +1097,12 @@ def sum( @_apply_docstring_templates def prod( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: @@ -1179,8 +1177,8 @@ def cumsum( input: Tensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1199,8 +1197,8 @@ def cumprod( input: Tensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1216,12 +1214,12 @@ def cumprod( @_apply_docstring_templates def amax( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1266,12 +1264,12 @@ def amax( @_apply_docstring_templates def amin( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1316,12 +1314,12 @@ def amin( @_apply_docstring_templates def argmax( - input: Union[Tensor, MaskedTensor], - dim: Optional[int] = None, + input: Tensor | MaskedTensor, + dim: int | None = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1342,12 +1340,12 @@ def argmax( @_apply_docstring_templates def argmin( - input: Union[Tensor, MaskedTensor], - dim: Optional[int] = None, + input: Tensor | MaskedTensor, + dim: int | None = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1368,12 +1366,12 @@ def argmin( @_apply_docstring_templates def mean( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1435,12 +1433,12 @@ def mean( @_apply_docstring_templates def median( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int = -1, *, keepdim: bool = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1482,8 +1480,8 @@ def logsumexp( dim: DimOrDims = None, *, keepdim: bool = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1499,12 +1497,12 @@ def logsumexp( # Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations def logaddexp( - input: Union[Tensor, MaskedTensor], - other: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, + other: Tensor | MaskedTensor, *, - dtype: Optional[DType] = None, - input_mask: Optional[Tensor] = None, - other_mask: Optional[Tensor] = None, + dtype: DType | None = None, + input_mask: Tensor | None = None, + other_mask: Tensor | None = None, ) -> Tensor: """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor @@ -1561,13 +1559,13 @@ def logaddexp( @_apply_docstring_templates def norm( - input: Union[Tensor, MaskedTensor], - ord: Optional[float] = 2.0, + input: Tensor | MaskedTensor, + ord: float | None = 2.0, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1596,15 +1594,15 @@ def norm( def _std_var( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims, - unbiased: Optional[bool], + unbiased: bool | None, *, - correction_opt: Optional[Union[int, float]], - keepdim: Optional[bool], - dtype: Optional[DType], - mask: Optional[Tensor], - take_sqrt: Optional[bool], + correction_opt: int | float | None, + keepdim: bool | None, + dtype: DType | None, + mask: Tensor | None, + take_sqrt: bool | None, ) -> Tensor: assert unbiased is None or correction_opt is None, ( "Only one of unbiased and correction may be given" @@ -1677,14 +1675,14 @@ def _std_var( @_apply_docstring_templates def var( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, - unbiased: Optional[bool] = None, + unbiased: bool | None = None, *, - correction: Optional[Union[int, float]] = None, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + correction: int | float | None = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1708,14 +1706,14 @@ def var( @_apply_docstring_templates def std( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, - unbiased: Optional[bool] = None, + unbiased: bool | None = None, *, - correction: Optional[int] = None, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + correction: int | None = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1739,11 +1737,11 @@ def std( @_apply_docstring_templates def softmax( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1759,11 +1757,11 @@ def softmax( @_apply_docstring_templates def log_softmax( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1779,11 +1777,11 @@ def log_softmax( @_apply_docstring_templates def softmin( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1799,13 +1797,13 @@ def softmin( @_apply_docstring_templates def normalize( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, ord: float, dim: int, *, eps: float = 1e-12, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index 9764f935b7c3d..a3ca62929a3b5 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -1,5 +1,4 @@ import warnings -from typing import Optional # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h @@ -31,8 +30,8 @@ def get_enum(reduction: str) -> int: # We use these functions in torch/legacy as well, in which case we'll silence the warning def legacy_get_string( - size_average: Optional[bool], - reduce: Optional[bool], + size_average: bool | None, + reduce: bool | None, emit_warning: bool = True, ) -> str: warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." @@ -54,8 +53,8 @@ def legacy_get_string( def legacy_get_enum( - size_average: Optional[bool], - reduce: Optional[bool], + size_average: bool | None, + reduce: bool | None, emit_warning: bool = True, ) -> int: return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/common_types.py b/torch/nn/common_types.py index 9262c45472271..e1928414a396e 100644 --- a/torch/nn/common_types.py +++ b/torch/nn/common_types.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeAlias as _TypeAlias, TypeVar +from typing import TypeAlias as _TypeAlias, TypeVar from torch import Tensor @@ -29,9 +29,9 @@ _size_6_t: _TypeAlias = _scalar_or_tuple_6_t[int] # For arguments which represent optional size parameters (eg, adaptive pool parameters) -_size_any_opt_t: _TypeAlias = _scalar_or_tuple_any_t[Optional[int]] -_size_2_opt_t: _TypeAlias = _scalar_or_tuple_2_t[Optional[int]] -_size_3_opt_t: _TypeAlias = _scalar_or_tuple_3_t[Optional[int]] +_size_any_opt_t: _TypeAlias = _scalar_or_tuple_any_t[int | None] +_size_2_opt_t: _TypeAlias = _scalar_or_tuple_2_t[int | None] +_size_3_opt_t: _TypeAlias = _scalar_or_tuple_3_t[int | None] # For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) _ratio_2_t: _TypeAlias = _scalar_or_tuple_2_t[float] diff --git a/torch/nn/init.py b/torch/nn/init.py index 3956d9399876e..900b2d34bc08f 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -3,7 +3,7 @@ import math import warnings from collections.abc import Callable -from typing import Literal, Optional as _Optional, TypeVar +from typing import Literal, TypeVar from typing_extensions import ParamSpec import torch @@ -67,7 +67,7 @@ # managers, so these need to be implemented as builtins. Using these wrappers # lets us keep those builtins small and reusable. def _no_grad_uniform_( - tensor: Tensor, a: float, b: float, generator: _Optional[torch.Generator] = None + tensor: Tensor, a: float, b: float, generator: torch.Generator | None = None ) -> Tensor: with torch.no_grad(): return tensor.uniform_(a, b, generator=generator) @@ -77,7 +77,7 @@ def _no_grad_normal_( tensor: Tensor, mean: float, std: float, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: with torch.no_grad(): return tensor.normal_(mean, std, generator=generator) @@ -89,7 +89,7 @@ def _no_grad_trunc_normal_( std: float, a: float, b: float, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x: float) -> float: @@ -138,7 +138,7 @@ def _no_grad_zero_(tensor: Tensor) -> Tensor: def calculate_gain( - nonlinearity: _NonlinearityType, param: _Optional[int | float] = None + nonlinearity: _NonlinearityType, param: int | float | None = None ) -> float: r"""Return the recommended gain value for the given nonlinearity function. @@ -215,7 +215,7 @@ def uniform_( tensor: Tensor, a: float = 0.0, b: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from the uniform distribution. @@ -242,7 +242,7 @@ def normal_( tensor: Tensor, mean: float = 0.0, std: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from the normal distribution. @@ -271,7 +271,7 @@ def trunc_normal_( std: float = 1.0, a: float = -2.0, b: float = 2.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from a truncated normal distribution. @@ -438,7 +438,7 @@ def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]: def xavier_uniform_( tensor: Tensor, gain: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Xavier uniform distribution. @@ -471,7 +471,7 @@ def xavier_uniform_( def xavier_normal_( tensor: Tensor, gain: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Xavier normal distribution. @@ -515,7 +515,7 @@ def kaiming_uniform_( a: float = 0, mode: _FanMode = "fan_in", nonlinearity: _NonlinearityType = "leaky_relu", - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. @@ -580,7 +580,7 @@ def kaiming_normal_( a: float = 0, mode: _FanMode = "fan_in", nonlinearity: _NonlinearityType = "leaky_relu", - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming normal distribution. @@ -631,7 +631,7 @@ def kaiming_normal_( def orthogonal_( tensor: Tensor, gain: float = 1, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with a (semi) orthogonal matrix. @@ -683,7 +683,7 @@ def sparse_( tensor: Tensor, sparsity: float, std: float = 0.01, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the 2D input `Tensor` as a sparse matrix. diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index edd65601db985..dac27cdb0d246 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import warnings -from typing import Optional import torch import torch.nn.functional as F @@ -261,8 +260,8 @@ def __init__( min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False, - min_value: Optional[float] = None, - max_value: Optional[float] = None, + min_value: float | None = None, + max_value: float | None = None, ) -> None: super().__init__() if min_value is not None: @@ -1053,7 +1052,7 @@ def extra_repr(self) -> str: return str(self.lambd) -def _check_arg_device(x: Optional[torch.Tensor]) -> bool: +def _check_arg_device(x: torch.Tensor | None) -> bool: if x is not None: return x.device.type in [ "cpu", @@ -1063,7 +1062,7 @@ def _check_arg_device(x: Optional[torch.Tensor]) -> bool: return True -def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool: +def _arg_requires_grad(x: torch.Tensor | None) -> bool: if x is not None: return x.requires_grad return False @@ -1156,8 +1155,8 @@ class MultiheadAttention(Module): """ __constants__ = ["batch_first"] - bias_k: Optional[torch.Tensor] - bias_v: Optional[torch.Tensor] + bias_k: torch.Tensor | None + bias_v: torch.Tensor | None def __init__( self, @@ -1258,12 +1257,12 @@ def forward( query: Tensor, key: Tensor, value: Tensor, - key_padding_mask: Optional[Tensor] = None, + key_padding_mask: Tensor | None = None, need_weights: bool = True, - attn_mask: Optional[Tensor] = None, + attn_mask: Tensor | None = None, average_attn_weights: bool = True, is_causal: bool = False, - ) -> tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: r"""Compute attention outputs using query, key, and value embeddings. Supports optional parameters for padding, masks and attention weights. @@ -1517,10 +1516,10 @@ def forward( def merge_masks( self, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, query: Tensor, - ) -> tuple[Optional[Tensor], Optional[int]]: + ) -> tuple[Tensor | None, int | None]: r"""Determine mask type and combine masks if necessary. If only one mask is provided, that mask @@ -1535,8 +1534,8 @@ def merge_masks( merged_mask: merged mask mask_type: merged mask type (0, 1, or 2) """ - mask_type: Optional[int] = None - merged_mask: Optional[Tensor] = None + mask_type: int | None = None + merged_mask: Tensor | None = None if key_padding_mask is not None: mask_type = 1 @@ -1732,9 +1731,9 @@ class Softmin(Module): """ __constants__ = ["dim"] - dim: Optional[int] + dim: int | None - def __init__(self, dim: Optional[int] = None) -> None: + def __init__(self, dim: int | None = None) -> None: super().__init__() self.dim = dim @@ -1797,9 +1796,9 @@ class Softmax(Module): """ __constants__ = ["dim"] - dim: Optional[int] + dim: int | None - def __init__(self, dim: Optional[int] = None) -> None: + def __init__(self, dim: int | None = None) -> None: super().__init__() self.dim = dim @@ -1882,9 +1881,9 @@ class LogSoftmax(Module): """ __constants__ = ["dim"] - dim: Optional[int] + dim: int | None - def __init__(self, dim: Optional[int] = None) -> None: + def __init__(self, dim: int | None = None) -> None: super().__init__() self.dim = dim diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 2ac05f2e8f933..40a912b4f0568 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Optional +from typing import Any import torch from torch import Tensor @@ -29,7 +29,7 @@ class _NormBase(Module): __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] num_features: int eps: float - momentum: Optional[float] + momentum: float | None affine: bool track_running_stats: bool # WARNING: weight and bias purposely not defined here. @@ -39,7 +39,7 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: Optional[float] = 0.1, + momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, @@ -65,8 +65,8 @@ def __init__( self.register_buffer( "running_var", torch.ones(num_features, **factory_kwargs) ) - self.running_mean: Optional[Tensor] - self.running_var: Optional[Tensor] + self.running_mean: Tensor | None + self.running_var: Tensor | None self.register_buffer( "num_batches_tracked", torch.tensor( @@ -76,7 +76,7 @@ def __init__( **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, ), ) - self.num_batches_tracked: Optional[Tensor] + self.num_batches_tracked: Tensor | None else: self.register_buffer("running_mean", None) self.register_buffer("running_var", None) @@ -146,7 +146,7 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: Optional[float] = 0.1, + momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, @@ -718,10 +718,10 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: Optional[float] = 0.1, + momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, - process_group: Optional[Any] = None, + process_group: Any | None = None, device=None, dtype=None, ) -> None: diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index f062c4bcbd12b..d99151369e18e 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -4,7 +4,7 @@ import operator from collections import abc as container_abcs, OrderedDict from itertools import chain, islice -from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar +from typing import Any, overload, TYPE_CHECKING, TypeVar from typing_extensions import deprecated, Self import torch @@ -358,7 +358,7 @@ def forward(self, x): _modules: dict[str, Module] # type: ignore[assignment] - def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + def __init__(self, modules: Iterable[Module] | None = None) -> None: super().__init__() if modules is not None: self += modules @@ -545,7 +545,7 @@ def forward(self, x, choice, act): _modules: dict[str, Module] # type: ignore[assignment] - def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + def __init__(self, modules: Mapping[str, Module] | None = None) -> None: super().__init__() if modules is not None: self.update(modules) @@ -673,7 +673,7 @@ def forward(self, x): return x """ - def __init__(self, values: Optional[Iterable[Any]] = None) -> None: + def __init__(self, values: Iterable[Any] | None = None) -> None: super().__init__() self._size = 0 if values is not None: @@ -888,7 +888,7 @@ def copy(self) -> ParameterDict: def __contains__(self, key: str) -> bool: return key in self._keys - def setdefault(self, key: str, default: Optional[Any] = None) -> Any: + def setdefault(self, key: str, default: Any | None = None) -> Any: """Set the default for a key in the Parameterdict. If key is in the ParameterDict, return its value. @@ -927,7 +927,7 @@ def popitem(self) -> tuple[str, Any]: del self[k] return k, val - def get(self, key: str, default: Optional[Any] = None) -> Any: + def get(self, key: str, default: Any | None = None) -> Any: r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not. Args: @@ -937,7 +937,7 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: return self[key] if key in self else default # noqa: SIM401 def fromkeys( - self, keys: Iterable[str], default: Optional[Any] = None + self, keys: Iterable[str], default: Any | None = None ) -> ParameterDict: r"""Return a new ParameterDict with the keys provided. diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index b539203f6fedd..8b74b6a5a39e8 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -67,7 +67,7 @@ class _ConvNd(Module): __annotations__ = {"bias": Optional[torch.Tensor]} def _conv_forward( # type: ignore[empty-body] - self, input: Tensor, weight: Tensor, bias: Optional[Tensor] + self, input: Tensor, weight: Tensor, bias: Tensor | None ) -> Tensor: ... in_channels: int @@ -82,7 +82,7 @@ def _conv_forward( # type: ignore[empty-body] groups: int padding_mode: Literal["zeros", "reflect", "replicate", "circular"] weight: Tensor - bias: Optional[Tensor] + bias: Tensor | None def __init__( self, @@ -353,7 +353,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): if self.padding_mode != "zeros": return F.conv1d( F.pad( @@ -531,7 +531,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): if self.padding_mode != "zeros": return F.conv2d( F.pad( @@ -701,7 +701,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): if self.padding_mode != "zeros": return F.conv3d( F.pad( @@ -766,12 +766,12 @@ def __init__( def _output_padding( self, input: Tensor, - output_size: Optional[list[int]], + output_size: list[int] | None, stride: list[int], padding: list[int], kernel_size: list[int], num_spatial_dims: int, - dilation: Optional[list[int]] = None, + dilation: list[int] | None = None, ) -> list[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already @@ -965,7 +965,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: if self.padding_mode != "zeros": raise ValueError( "Only `zeros` padding mode is supported for ConvTranspose1d" @@ -1153,7 +1153,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: """ Performs the forward pass. @@ -1344,7 +1344,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: if self.padding_mode != "zeros": raise ValueError( "Only `zeros` padding mode is supported for ConvTranspose3d" diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index d4c192ee8ce4a..72d90d1c10364 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import itertools -from typing import Any, Optional, Protocol +from typing import Any, Protocol import torch from torch.nn.parameter import is_lazy @@ -167,7 +167,7 @@ class LazyModuleMixin: # modules inheriting from this will change their __class__ to the specified # one after they are fully initialized - cls_to_become: Optional[type[Any]] = None + cls_to_become: type[Any] | None = None def __init__(self: _LazyProtocol, *args, **kwargs): # Mypy doesn't like this super call in a mixin diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 05b39ba762f47..00ada62febded 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs from collections.abc import Callable -from typing import Optional from typing_extensions import deprecated from torch import Tensor @@ -50,14 +49,14 @@ def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> N class _WeightedLoss(_Loss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, reduce=None, reduction: str = "mean", ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) - self.weight: Optional[Tensor] + self.weight: Tensor | None class L1Loss(_Loss): @@ -241,7 +240,7 @@ class NLLLoss(_WeightedLoss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -272,7 +271,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: class NLLLoss2d(NLLLoss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -817,17 +816,17 @@ class BCEWithLogitsLoss(_Loss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, reduce=None, reduction: str = "mean", - pos_weight: Optional[Tensor] = None, + pos_weight: Tensor | None = None, ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) self.register_buffer("pos_weight", pos_weight) - self.weight: Optional[Tensor] - self.pos_weight: Optional[Tensor] + self.weight: Tensor | None + self.pos_weight: Tensor | None def forward(self, input: Tensor, target: Tensor) -> Tensor: """Runs the forward pass.""" @@ -1347,7 +1346,7 @@ class probabilities only when a single class label per minibatch item is too res def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -1626,7 +1625,7 @@ def __init__( self, p: int = 1, margin: float = 1.0, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, reduce=None, reduction: str = "mean", @@ -1869,7 +1868,7 @@ class TripletMarginWithDistanceLoss(_Loss): def __init__( self, *, - distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + distance_function: Callable[[Tensor, Tensor], Tensor] | None = None, margin: float = 1.0, swap: bool = False, reduction: str = "mean", @@ -1879,7 +1878,7 @@ def __init__( raise ValueError( f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead" ) - self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ( + self.distance_function: Callable[[Tensor, Tensor], Tensor] | None = ( distance_function if distance_function is not None else PairwiseDistance() ) self.margin = margin diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 6557f60389964..f9795cc1c74aa 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -115,7 +115,7 @@ def __setstate__(self, state: dict): purposes""" _global_backward_pre_hooks: dict[int, Callable] = OrderedDict() _global_backward_hooks: dict[int, Callable] = OrderedDict() -_global_is_full_backward_hook: Optional[bool] = None +_global_is_full_backward_hook: bool | None = None _global_forward_pre_hooks: dict[int, Callable] = OrderedDict() _global_forward_hooks: dict[int, Callable] = OrderedDict() _global_forward_hooks_always_called: dict[int, bool] = OrderedDict() @@ -453,12 +453,12 @@ def forward(self, x): the change.""" training: bool - _parameters: dict[str, Optional[Parameter]] - _buffers: dict[str, Optional[Tensor]] + _parameters: dict[str, Parameter | None] + _buffers: dict[str, Tensor | None] _non_persistent_buffers_set: set[str] _backward_pre_hooks: dict[int, Callable] _backward_hooks: dict[int, Callable] - _is_full_backward_hook: Optional[bool] + _is_full_backward_hook: bool | None _forward_hooks: dict[int, Callable] # Marks whether the corresponding _forward_hooks accept kwargs or not. # As JIT does not support set[int], this dict is used as a set, where all @@ -477,7 +477,7 @@ def forward(self, x): _load_state_dict_post_hooks: dict[int, Callable] _modules: dict[str, Optional["Module"]] call_super_init: bool = False - _compiled_call_impl: Optional[Callable] = None + _compiled_call_impl: Callable | None = None def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" @@ -526,7 +526,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: forward: Callable[..., Any] = _forward_unimplemented def register_buffer( - self, name: str, tensor: Optional[Tensor], persistent: bool = True + self, name: str, tensor: Tensor | None, persistent: bool = True ) -> None: r"""Add a buffer to the module. @@ -589,7 +589,7 @@ def register_buffer( else: self._non_persistent_buffers_set.add(name) - def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + def register_parameter(self, name: str, param: Parameter | None) -> None: r"""Add a parameter to the module. The parameter can be accessed as an attribute using given name. @@ -1073,7 +1073,7 @@ def apply(self, fn: Callable[["Module"], None]) -> Self: fn(self) return self - def cuda(self, device: Optional[int | device] = None) -> Self: + def cuda(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So @@ -1092,7 +1092,7 @@ def cuda(self, device: Optional[int | device] = None) -> Self: """ return self._apply(lambda t: t.cuda(device)) - def ipu(self, device: Optional[int | device] = None) -> Self: + def ipu(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So @@ -1111,7 +1111,7 @@ def ipu(self, device: Optional[int | device] = None) -> Self: """ return self._apply(lambda t: t.ipu(device)) - def xpu(self, device: Optional[int | device] = None) -> Self: + def xpu(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So @@ -1130,7 +1130,7 @@ def xpu(self, device: Optional[int | device] = None) -> Self: """ return self._apply(lambda t: t.xpu(device)) - def mtia(self, device: Optional[int | device] = None) -> Self: + def mtia(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the MTIA. This also makes associated parameters and buffers different objects. So @@ -1218,9 +1218,7 @@ def bfloat16(self) -> Self: """ return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) - def to_empty( - self, *, device: Optional[DeviceLikeType], recurse: bool = True - ) -> Self: + def to_empty(self, *, device: DeviceLikeType | None, recurse: bool = True) -> Self: r"""Move the parameters and buffers to the specified device without copying storage. Args: @@ -1239,8 +1237,8 @@ def to_empty( @overload def to( self, - device: Optional[DeviceLikeType] = ..., - dtype: Optional[dtype] = ..., + device: DeviceLikeType | None = ..., + dtype: dtype | None = ..., non_blocking: bool = ..., ) -> Self: ... @@ -1623,9 +1621,9 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn) -> None: def register_forward_pre_hook( self, - hook: Callable[[T, tuple[Any, ...]], Optional[Any]] + hook: Callable[[T, tuple[Any, ...]], Any | None] | Callable[ - [T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]] + [T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None ], *, prepend: bool = False, @@ -1686,8 +1684,8 @@ def register_forward_pre_hook( def register_forward_hook( self, - hook: Callable[[T, tuple[Any, ...], Any], Optional[Any]] - | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]], + hook: Callable[[T, tuple[Any, ...], Any], Any | None] + | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None], *, prepend: bool = False, with_kwargs: bool = False, @@ -2830,7 +2828,7 @@ def modules(self) -> Iterator["Module"]: def named_modules( self, - memo: Optional[set["Module"]] = None, + memo: set["Module"] | None = None, prefix: str = "", remove_duplicate: bool = True, ): diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 4a7302d5cae33..d492cdb3cf5a0 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import numbers -from typing import Optional, Union +from typing import Union import torch from torch import Size, Tensor @@ -375,13 +375,13 @@ class RMSNorm(Module): __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: tuple[int, ...] - eps: Optional[float] + eps: float | None elementwise_affine: bool def __init__( self, normalized_shape: _shape_t, - eps: Optional[float] = None, + eps: float | None = None, elementwise_affine: bool = True, device=None, dtype=None, diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 777e6b0abd8c4..1dc57c25b1683 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch.nn.functional as F from torch import Tensor from torch.nn.common_types import ( @@ -57,7 +55,7 @@ class _MaxPoolNd(Module): def __init__( self, kernel_size: _size_any_t, - stride: Optional[_size_any_t] = None, + stride: _size_any_t | None = None, padding: _size_any_t = 0, dilation: _size_any_t = 1, return_indices: bool = False, @@ -389,7 +387,7 @@ class MaxUnpool1d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_1_t, - stride: Optional[_size_1_t] = None, + stride: _size_1_t | None = None, padding: _size_1_t = 0, ) -> None: super().__init__() @@ -398,7 +396,7 @@ def __init__( self.padding = _single(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool1d( @@ -485,7 +483,7 @@ class MaxUnpool2d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_2_t, - stride: Optional[_size_2_t] = None, + stride: _size_2_t | None = None, padding: _size_2_t = 0, ) -> None: super().__init__() @@ -494,7 +492,7 @@ def __init__( self.padding = _pair(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool2d( @@ -564,7 +562,7 @@ class MaxUnpool3d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_3_t, - stride: Optional[_size_3_t] = None, + stride: _size_3_t | None = None, padding: _size_3_t = 0, ) -> None: super().__init__() @@ -573,7 +571,7 @@ def __init__( self.padding = _triple(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool3d( @@ -762,11 +760,11 @@ class AvgPool2d(_AvgPoolNd): def __init__( self, kernel_size: _size_2_t, - stride: Optional[_size_2_t] = None, + stride: _size_2_t | None = None, padding: _size_2_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, + divisor_override: int | None = None, ) -> None: super().__init__() self.kernel_size = kernel_size @@ -879,11 +877,11 @@ class AvgPool3d(_AvgPoolNd): def __init__( self, kernel_size: _size_3_t, - stride: Optional[_size_3_t] = None, + stride: _size_3_t | None = None, padding: _size_3_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, + divisor_override: int | None = None, ) -> None: super().__init__() self.kernel_size = kernel_size @@ -964,8 +962,8 @@ class FractionalMaxPool2d(Module): def __init__( self, kernel_size: _size_2_t, - output_size: Optional[_size_2_t] = None, - output_ratio: Optional[_ratio_2_t] = None, + output_size: _size_2_t | None = None, + output_ratio: _ratio_2_t | None = None, return_indices: bool = False, _random_samples=None, ) -> None: @@ -1050,8 +1048,8 @@ class FractionalMaxPool3d(Module): def __init__( self, kernel_size: _size_3_t, - output_size: Optional[_size_3_t] = None, - output_ratio: Optional[_ratio_3_t] = None, + output_size: _size_3_t | None = None, + output_ratio: _ratio_3_t | None = None, return_indices: bool = False, _random_samples=None, ) -> None: @@ -1106,7 +1104,7 @@ def __init__( self, norm_type: float, kernel_size: _size_any_t, - stride: Optional[_size_any_t] = None, + stride: _size_any_t | None = None, ceil_mode: bool = False, ) -> None: super().__init__() diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 13cd9ec08cb55..68e8292870fc8 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -4,7 +4,7 @@ import numbers import warnings import weakref -from typing import Optional, overload +from typing import overload from typing_extensions import deprecated import torch @@ -106,7 +106,7 @@ def __init__( self.dropout = float(dropout) self.bidirectional = bidirectional self.proj_size = proj_size - self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = [] + self._flat_weight_refs: list[weakref.ReferenceType[Parameter] | None] = [] num_directions = 2 if bidirectional else 1 if ( @@ -298,7 +298,7 @@ def reset_parameters(self) -> None: for weight in self.parameters(): init.uniform_(weight, -stdv, stdv) - def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + def check_input(self, input: Tensor, batch_sizes: Tensor | None) -> None: if not torch.jit.is_scripting(): if ( input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr] @@ -318,7 +318,7 @@ def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: ) def get_expected_hidden_size( - self, input: Tensor, batch_sizes: Optional[Tensor] + self, input: Tensor, batch_sizes: Tensor | None ) -> tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) @@ -362,14 +362,14 @@ def _weights_have_changed(self): return weights_changed def check_forward_args( - self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + self, input: Tensor, hidden: Tensor, batch_sizes: Tensor | None ) -> None: self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) self.check_hidden_size(hidden, expected_hidden_size) - def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): + def permute_hidden(self, hx: Tensor, permutation: Tensor | None): if permutation is None: return hx return _apply_permutation(hx, permutation) @@ -645,7 +645,7 @@ def __init__(self, *args, **kwargs): def forward( self, input: Tensor, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[Tensor, Tensor]: pass @@ -654,7 +654,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[PackedSequence, Tensor]: pass @@ -990,7 +990,7 @@ def __init__(self, *args, **kwargs): super().__init__("LSTM", *args, **kwargs) def get_expected_cell_size( - self, input: Tensor, batch_sizes: Optional[Tensor] + self, input: Tensor, batch_sizes: Tensor | None ) -> tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) @@ -1010,7 +1010,7 @@ def check_forward_args( self, input: Tensor, hidden: tuple[Tensor, Tensor], # type: ignore[override] - batch_sizes: Optional[Tensor], + batch_sizes: Tensor | None, ) -> None: self.check_input(input, batch_sizes) self.check_hidden_size( @@ -1028,7 +1028,7 @@ def check_forward_args( def permute_hidden( # type: ignore[override] self, hx: tuple[Tensor, Tensor], - permutation: Optional[Tensor], + permutation: Tensor | None, ) -> tuple[Tensor, Tensor]: if permutation is None: return hx @@ -1042,7 +1042,7 @@ def permute_hidden( # type: ignore[override] def forward( self, input: Tensor, - hx: Optional[tuple[Tensor, Tensor]] = None, + hx: tuple[Tensor, Tensor] | None = None, ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1052,7 +1052,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Optional[tuple[Tensor, Tensor]] = None, + hx: tuple[Tensor, Tensor] | None = None, ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1338,7 +1338,7 @@ def __init__(self, *args, **kwargs): def forward( self, input: Tensor, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[Tensor, Tensor]: # noqa: F811 pass @@ -1347,7 +1347,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[PackedSequence, Tensor]: # noqa: F811 pass @@ -1584,7 +1584,7 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) self.nonlinearity = nonlinearity - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" @@ -1704,7 +1704,7 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) def forward( - self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None ) -> tuple[Tensor, Tensor]: if input.dim() not in (1, 2): raise ValueError( @@ -1815,7 +1815,7 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index abcd7240a742c..f5775f63ff4ad 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -2,7 +2,7 @@ import copy import warnings from collections.abc import Callable -from typing import Any, Optional +from typing import Any import torch import torch.nn.functional as F @@ -28,8 +28,8 @@ def _generate_square_subsequent_mask( sz: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. @@ -41,7 +41,7 @@ def _generate_square_subsequent_mask( ) -def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: +def _get_seq_len(src: Tensor, batch_first: bool) -> int | None: if src.is_nested: return None else: @@ -106,8 +106,8 @@ def __init__( dim_feedforward: int = 2048, dropout: float = 0.1, activation: str | Callable[[Tensor], Tensor] = F.relu, - custom_encoder: Optional[Any] = None, - custom_decoder: Optional[Any] = None, + custom_encoder: Any | None = None, + custom_decoder: Any | None = None, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, @@ -182,14 +182,14 @@ def forward( self, src: Tensor, tgt: Tensor, - src_mask: Optional[Tensor] = None, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - src_is_causal: Optional[bool] = None, - tgt_is_causal: Optional[bool] = None, + src_mask: Tensor | None = None, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + src_is_causal: bool | None = None, + tgt_is_causal: bool | None = None, memory_is_causal: bool = False, ) -> Tensor: r"""Take in and process masked source/target sequences. @@ -301,8 +301,8 @@ def forward( @staticmethod def generate_square_subsequent_mask( sz: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. @@ -354,7 +354,7 @@ def __init__( self, encoder_layer: "TransformerEncoderLayer", num_layers: int, - norm: Optional[Module] = None, + norm: Module | None = None, enable_nested_tensor: bool = True, mask_check: bool = True, ) -> None: @@ -407,9 +407,9 @@ def __init__( def forward( self, src: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - is_causal: Optional[bool] = None, + mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + is_causal: bool | None = None, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -587,7 +587,7 @@ def __init__( self, decoder_layer: "TransformerDecoderLayer", num_layers: int, - norm: Optional[Module] = None, + norm: Module | None = None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") @@ -599,11 +599,11 @@ def forward( self, tgt: Tensor, memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - tgt_is_causal: Optional[bool] = None, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + tgt_is_causal: bool | None = None, memory_is_causal: bool = False, ) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer in turn. @@ -798,8 +798,8 @@ def __setstate__(self, state): def forward( self, src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + src_mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, is_causal: bool = False, ) -> Tensor: r"""Pass the input through the encoder layer. @@ -959,8 +959,8 @@ def forward( def _sa_block( self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, is_causal: bool = False, ) -> Tensor: x = self.self_attn( @@ -1088,10 +1088,10 @@ def forward( self, tgt: Tensor, memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, tgt_is_causal: bool = False, memory_is_causal: bool = False, ) -> Tensor: @@ -1156,8 +1156,8 @@ def forward( def _sa_block( self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, is_causal: bool = False, ) -> Tensor: x = self.self_attn( @@ -1176,8 +1176,8 @@ def _mha_block( self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, is_causal: bool = False, ) -> Tensor: x = self.multihead_attn( @@ -1212,9 +1212,9 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: def _detect_is_causal_mask( - mask: Optional[Tensor], - is_causal: Optional[bool] = None, - size: Optional[int] = None, + mask: Tensor | None, + is_causal: bool | None = None, + size: int | None = None, ) -> bool: """Return whether the given attention mask is causal. diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index 7fd102a768225..29e58bc6a9f37 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch.nn.functional as F from torch import Tensor @@ -143,19 +142,19 @@ class Upsample(Module): "recompute_scale_factor", ] name: str - size: Optional[_size_any_t] - scale_factor: Optional[_ratio_any_t] + size: _size_any_t | None + scale_factor: _ratio_any_t | None mode: str - align_corners: Optional[bool] - recompute_scale_factor: Optional[bool] + align_corners: bool | None + recompute_scale_factor: bool | None def __init__( self, - size: Optional[_size_any_t] = None, - scale_factor: Optional[_ratio_any_t] = None, + size: _size_any_t | None = None, + scale_factor: _ratio_any_t | None = None, mode: str = "nearest", - align_corners: Optional[bool] = None, - recompute_scale_factor: Optional[bool] = None, + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, ) -> None: super().__init__() self.name = type(self).__name__ @@ -242,8 +241,8 @@ class UpsamplingNearest2d(Upsample): def __init__( self, - size: Optional[_size_2_t] = None, - scale_factor: Optional[_ratio_2_t] = None, + size: _size_2_t | None = None, + scale_factor: _ratio_2_t | None = None, ) -> None: super().__init__(size, scale_factor, mode="nearest") @@ -293,7 +292,7 @@ class UpsamplingBilinear2d(Upsample): def __init__( self, - size: Optional[_size_2_t] = None, - scale_factor: Optional[_ratio_2_t] = None, + size: _size_2_t | None = None, + scale_factor: _ratio_2_t | None = None, ) -> None: super().__init__(size, scale_factor, mode="bilinear", align_corners=True) diff --git a/torch/overrides.py b/torch/overrides.py index e0597eafd8107..b1193bab3d6dc 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -30,7 +30,7 @@ import warnings from collections.abc import Callable, Iterable from functools import wraps -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch @@ -1609,7 +1609,7 @@ def wrapped(*args, **kwargs): def _get_overloaded_args( relevant_args: Iterable[Any], - get_type_fn: Optional[Callable[[Any], type]] = None, + get_type_fn: Callable[[Any], type] | None = None, ) -> list[Any]: """Returns a list of arguments on which to call __torch_function__. diff --git a/torch/quasirandom.py b/torch/quasirandom.py index b5d4540e592f1..f9e6619cab180 100644 --- a/torch/quasirandom.py +++ b/torch/quasirandom.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch @@ -78,8 +77,8 @@ def __init__(self, dimension, scramble=False, seed=None): def draw( self, n: int = 1, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, + out: torch.Tensor | None = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: r""" Function to draw a sequence of :attr:`n` points from a Sobol sequence. @@ -131,8 +130,8 @@ def draw( def draw_base2( self, m: int, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, + out: torch.Tensor | None = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: r""" Function to draw a sequence of :attr:`2**m` points from a Sobol sequence. @@ -187,7 +186,7 @@ def fast_forward(self, n): return self def _scramble(self): - g: Optional[torch.Generator] = None + g: torch.Generator | None = None if self.seed is not None: g = torch.Generator() g.manual_seed(self.seed) diff --git a/torch/serialization.py b/torch/serialization.py index 398d011f324b5..1a6acc8010634 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -16,7 +16,7 @@ from collections.abc import Callable from contextlib import closing, contextmanager from enum import Enum -from typing import Any, cast, Generic, IO, Optional, TypeAlias, TypeVar, Union +from typing import Any, cast, Generic, IO, TypeAlias, TypeVar from typing_extensions import TypeIs import torch @@ -66,10 +66,10 @@ PROTOCOL_VERSION = 1001 STORAGE_KEY_SEPARATOR = "," -MAP_LOCATION: TypeAlias = Optional[ - Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]] -] -STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] +MAP_LOCATION: TypeAlias = ( + Callable[[Storage, str], Storage] | torch.device | str | dict[str, str] | None +) +STORAGE: TypeAlias = Storage | torch.storage.TypedStorage | torch.UntypedStorage IS_WINDOWS = sys.platform == "win32" @@ -99,7 +99,7 @@ def _default_to_weights_only(pickle_module): class _SerializationLocal(threading.local): def __init__(self): super().__init__() - self.map_location: Optional[MAP_LOCATION] = None + self.map_location: MAP_LOCATION | None = None self.skip_data: bool = False self.materialize_fake_tensors: bool = False @@ -123,8 +123,8 @@ def mkdtemp(): _package_registry: list[ tuple[ int, - Callable[[STORAGE], Optional[str]], - Callable[[STORAGE, str], Optional[STORAGE]], + Callable[[STORAGE], str | None], + Callable[[STORAGE, str], STORAGE | None], ] ] = [] @@ -135,7 +135,7 @@ class LoadEndianness(Enum): BIG = 3 -def get_default_load_endianness() -> Optional[LoadEndianness]: +def get_default_load_endianness() -> LoadEndianness | None: """ Get fallback byte order for loading files @@ -197,7 +197,7 @@ def set_crc32_options(compute_crc32: bool): config.save.compute_crc32 = compute_crc32 -def get_default_mmap_options() -> Optional[int]: +def get_default_mmap_options() -> int | None: """ Get default mmap options for :func:`torch.load` with ``mmap=True``. @@ -272,14 +272,14 @@ def clear_safe_globals() -> None: _weights_only_unpickler._clear_safe_globals() -def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: +def get_safe_globals() -> list[Callable | tuple[Callable, str]]: """ Returns the list of user-added globals that are safe for ``weights_only`` load. """ return _weights_only_unpickler._get_safe_globals() -def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None: +def add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]) -> None: """ Marks the given globals as safe for ``weights_only`` load. For example, functions added to this list can be called during unpickling, classes could be instantiated @@ -443,8 +443,8 @@ def _is_zipfile(f) -> bool: def register_package( priority: int, - tagger: Callable[[STORAGE], Optional[str]], - deserializer: Callable[[STORAGE, str], Optional[STORAGE]], + tagger: Callable[[STORAGE], str | None], + deserializer: Callable[[STORAGE, str], STORAGE | None], ): """ Registers callables for tagging and deserializing storage objects with an associated priority. @@ -672,7 +672,7 @@ def _deserialize(backend_name, obj, location): def location_tag( - storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage], + storage: Storage | torch.storage.TypedStorage | torch.UntypedStorage, ): for _, tagger, _ in _package_registry: location = tagger(storage) @@ -726,7 +726,7 @@ def storage_to_tensor_type(storage): return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) -def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]: +def _is_path(name_or_buffer: object) -> TypeIs[str | os.PathLike]: return isinstance(name_or_buffer, (str, os.PathLike)) @@ -745,7 +745,7 @@ def __exit__(self, *args): class _open_file(_opener[IO[bytes]]): - def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None: + def __init__(self, name: str | os.PathLike[str], mode: str) -> None: super().__init__(open(name, mode)) # noqa: SIM115 def __exit__(self, *args): @@ -776,7 +776,7 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]: class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]): - def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None: + def __init__(self, name_or_buffer: str | IO[bytes]) -> None: super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) @@ -829,7 +829,7 @@ def __exit__(self, *args) -> None: self.buffer.flush() -def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener: +def _open_zipfile_writer(name_or_buffer: str | IO[bytes]) -> _opener: container: type[_opener] if _is_path(name_or_buffer): container = _open_zipfile_writer_file @@ -1004,7 +1004,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: # TODO: This feature could be added in the future storage_dtypes: dict[int, torch.dtype] = {} - def persistent_id(obj: Any) -> Optional[tuple]: + def persistent_id(obj: Any) -> tuple | None: # FIXME: the docs say that persistent_id should only return a string # but torch store returns tuples. This works only in the binary protocol # see @@ -1064,7 +1064,7 @@ def persistent_id(obj: Any) -> Optional[tuple]: else: storage_dtypes[storage.data_ptr()] = storage_dtype - view_metadata: Optional[tuple[str, int, int]] + view_metadata: tuple[str, int, int] | None # Offset is always 0, but we keep it for backwards compatibility # with the old serialization format (which supported storage views) @@ -1291,8 +1291,8 @@ def load( map_location: MAP_LOCATION = None, pickle_module: Any = None, *, - weights_only: Optional[bool] = None, - mmap: Optional[bool] = None, + weights_only: bool | None = None, + mmap: bool | None = None, **pickle_load_args: Any, ) -> Any: # Reference: https://github.com/pytorch/pytorch/issues/54354 @@ -1852,7 +1852,7 @@ def persistent_load(saved_id): return result -def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: +def _maybe_decode_ascii(bytes_str: bytes | str) -> str: # When using encoding='bytes' in Py3, some **internal** keys stored as # strings in Py2 are loaded as bytes. This function decodes them with # ascii encoding, one that Py3 uses by default. diff --git a/torch/storage.py b/torch/storage.py index 1b9023121ddfb..29847d958523d 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -8,7 +8,7 @@ import io import threading import warnings -from typing import Any, cast, Optional as _Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, cast, TYPE_CHECKING, TypeVar from typing_extensions import Self import torch @@ -35,7 +35,7 @@ _share_memory_lock = threading.Lock() _share_memory_map: dict[int, threading.RLock] = {} -T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]") +T = TypeVar("T", bound="_StorageBase | TypedStorage") class _StorageBase: @@ -46,9 +46,9 @@ class _StorageBase: # Used when # (1) stashing FakeTensor device onto storage in torch.serialization.skip_data # (2) stashing device onto storage to propagate to FakeTensor when torch.load under FakeTensorMode - _fake_device: _Optional[torch.device] = None + _fake_device: torch.device | None = None # Used when loading with FakeTensorMode to give information about offset of storage in torch.saved-file - _checkpoint_offset: _Optional[int] = None + _checkpoint_offset: int | None = None def __init__(self, *args, **kwargs): pass @@ -62,10 +62,10 @@ def __getitem__(self, idx): def __setitem__(self, *args, **kwargs): raise NotImplementedError - def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: + def copy_(self, source: T, non_blocking: _bool | None = None) -> T: raise NotImplementedError - def new(self) -> Union[_StorageBase, TypedStorage]: + def new(self) -> _StorageBase | TypedStorage: raise NotImplementedError def nbytes(self) -> _int: @@ -75,13 +75,11 @@ def size(self) -> _int: return self.nbytes() def type( - self, dtype: _Optional[str] = None, non_blocking: _bool = False - ) -> Union[_StorageBase, TypedStorage]: + self, dtype: str | None = None, non_blocking: _bool = False + ) -> _StorageBase | TypedStorage: return _type(self, dtype, non_blocking) - def cuda( - self, device=None, non_blocking=False - ) -> Union[_StorageBase, TypedStorage]: + def cuda(self, device=None, non_blocking=False) -> _StorageBase | TypedStorage: """Returns a copy of this object in CUDA memory. If this object is already in CUDA memory and on the correct device, then @@ -96,7 +94,7 @@ def cuda( device2 = torch.device("cuda", device) if device else torch.device("cuda") return self.to(device=device2, non_blocking=non_blocking) - def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: + def hpu(self, device=None, non_blocking=False) -> _StorageBase | TypedStorage: """Returns a copy of this object in HPU memory. If this object is already in HPU memory and on the correct device, then @@ -166,7 +164,7 @@ def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self: def _new_with_weak_ptr(cls, *args, **kwargs) -> Self: raise NotImplementedError - def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: + def _shared_decref(self) -> _StorageBase | TypedStorage: raise NotImplementedError def _write_file(self, *args, **kwargs): @@ -175,7 +173,7 @@ def _write_file(self, *args, **kwargs): def resize_(self, size: _int): raise NotImplementedError - def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + def _weak_ref(self, *args, **kwargs) -> _StorageBase | TypedStorage: raise NotImplementedError def _set_from_file(self, *args, **kwargs): @@ -210,17 +208,17 @@ def is_hpu(self): raise NotImplementedError @classmethod - def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: + def from_file(cls, filename, shared, nbytes) -> _StorageBase | TypedStorage: raise NotImplementedError @classmethod - def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + def _expired(cls, *args, **kwargs) -> _StorageBase | TypedStorage: raise NotImplementedError def _byteswap(self, *args, **kwargs): raise NotImplementedError - def _get_filename(self, *args, **kwargs) -> _Optional[str]: + def _get_filename(self, *args, **kwargs) -> str | None: raise NotImplementedError def __repr__(self): @@ -354,7 +352,7 @@ def float8_e4m3fnuz(self): """Casts this storage to float8_e4m3fnuz type""" return self._to(torch.float8_e4m3fnuz) - def is_pinned(self, device: Union[str, torch.device] = "cuda"): + def is_pinned(self, device: str | torch.device = "cuda"): r"""Determine whether the CPU storage is already pinned on device. Args: @@ -370,7 +368,7 @@ def is_pinned(self, device: Union[str, torch.device] = "cuda"): .is_pinned(device) ) - def pin_memory(self, device: Union[str, torch.device] = "cuda"): + def pin_memory(self, device: str | torch.device = "cuda"): r"""Copy the CPU storage to pinned memory, if it's not already pinned. Args: @@ -478,7 +476,7 @@ def is_hpu(self): return self.device.type == "hpu" @property - def filename(self) -> _Optional[str]: + def filename(self) -> str | None: """Returns the file name associated with this storage. The file name will be a string if the storage is on CPU and was created via @@ -671,7 +669,7 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None + _fake_device: torch.device | None = None dtype: torch.dtype @@ -680,7 +678,7 @@ def _dtype(self): return self.dtype @property - def filename(self) -> _Optional[str]: + def filename(self) -> str | None: """Returns the file name associated with this storage if the storage was memory mapped from a file. or ``None`` if the storage was not created by memory mapping a file.""" return self._untyped_storage.filename @@ -1018,7 +1016,7 @@ def _getitem(self, idx): ).set_(self) return tmp_tensor[idx_wrapped].item() - def copy_(self, source: T, non_blocking: _Optional[bool] = None): + def copy_(self, source: T, non_blocking: bool | None = None): _warn_typed_storage_removal() if isinstance(source, TypedStorage): self._untyped_storage.copy_(source._untyped_storage, non_blocking) @@ -1036,9 +1034,9 @@ def _nbytes(self): def type( self, - dtype: _Optional[str] = None, + dtype: str | None = None, non_blocking: bool = False, - ) -> Union[_StorageBase, TypedStorage, str]: + ) -> _StorageBase | TypedStorage | str: _warn_typed_storage_removal() if dtype is None: legacy_class = self._get_legacy_storage_class() @@ -1157,7 +1155,7 @@ def cpu(self): _warn_typed_storage_removal() return self._new_wrapped_storage(self._untyped_storage.cpu()) - def is_pinned(self, device: Union[str, torch.device] = "cuda"): + def is_pinned(self, device: str | torch.device = "cuda"): r"""Determine whether the CPU TypedStorage is already pinned on device. Args: @@ -1170,7 +1168,7 @@ def is_pinned(self, device: Union[str, torch.device] = "cuda"): _warn_typed_storage_removal() return self._untyped_storage.is_pinned(device) - def pin_memory(self, device: Union[str, torch.device] = "cuda"): + def pin_memory(self, device: str | torch.device = "cuda"): r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned. Args: diff --git a/torch/types.py b/torch/types.py index 0388c9c66aefe..9ed69a859b1ee 100644 --- a/torch/types.py +++ b/torch/types.py @@ -38,7 +38,7 @@ # Convenience aliases for common composite types that we need # to talk about in PyTorch -_TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]] # noqa: PYI047 +_TensorOrTensors: TypeAlias = Tensor | Sequence[Tensor] # noqa: PYI047 _TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047 Tensor, Sequence[Tensor], @@ -46,32 +46,32 @@ Sequence["GradientEdge"], ] -_size: TypeAlias = Union[Size, list[int], tuple[int, ...]] # noqa: PYI042,PYI047 -_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047 -_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047 +_size: TypeAlias = Size | list[int] | tuple[int, ...] # noqa: PYI042,PYI047 +_symsize: TypeAlias = Size | Sequence[int | SymInt] # noqa: PYI042,PYI047 +_dispatchkey: TypeAlias = str | DispatchKey # noqa: PYI042,PYI047 # int or SymInt -IntLikeType: TypeAlias = Union[int, SymInt] +IntLikeType: TypeAlias = int | SymInt # float or SymFloat -FloatLikeType: TypeAlias = Union[float, SymFloat] +FloatLikeType: TypeAlias = float | SymFloat # bool or SymBool -BoolLikeType: TypeAlias = Union[bool, SymBool] +BoolLikeType: TypeAlias = bool | SymBool py_sym_types = (SymInt, SymFloat, SymBool) # left un-annotated intentionally -PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool] +PySymType: TypeAlias = SymInt | SymFloat | SymBool # Meta-type for "numeric" things; matches our docs -Number: TypeAlias = Union[int, float, bool] +Number: TypeAlias = int | float | bool # tuple for isinstance(x, Number) checks. # FIXME: refactor once python 3.9 support is dropped. _Number = (int, float, bool) -FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]] +FileLike: TypeAlias = str | os.PathLike[str] | IO[bytes] # Meta-type for "device-like" things. Not to be confused with 'device' (a # literal device object). This nomenclature is consistent with PythonArgParser. # None means use the default device (typically CPU) -Device: TypeAlias = Union[_device, str, int, None] +Device: TypeAlias = _device | str | int | None # Storage protocol implemented by ${Type}StorageBase classes diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 194684e3388e4..6cb4f9b9c012b 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -218,7 +218,7 @@ def set_device(device: _device_t) -> None: torch._C._xpu_setDevice(device) -def get_device_name(device: Optional[_device_t] = None) -> str: +def get_device_name(device: _device_t | None = None) -> str: r"""Get the name of a device. Args: @@ -234,7 +234,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str: @lru_cache(None) -def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: +def get_device_capability(device: _device_t | None = None) -> dict[str, Any]: r"""Get the xpu capability of a device. Args: @@ -259,7 +259,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: def get_device_properties( - device: Optional[_device_t] = None, + device: _device_t | None = None, ) -> _XpuDeviceProperties: # pyrefly: ignore # not-a-type r"""Get the properties of a device. @@ -281,7 +281,7 @@ def current_device() -> int: return torch._C._xpu_getDevice() -def _get_device(device: Union[int, str, torch.device]) -> torch.device: +def _get_device(device: int | str | torch.device) -> torch.device: r"""Return the torch.device type object from the passed in device. Args: @@ -395,7 +395,7 @@ def set_stream(stream: Stream) -> None: ) -def current_stream(device: Optional[_device_t] = None) -> Stream: +def current_stream(device: _device_t | None = None) -> Stream: r"""Return the currently selected :class:`Stream` for a given device. Args: @@ -413,9 +413,7 @@ def current_stream(device: Optional[_device_t] = None) -> Stream: ) -def get_stream_from_external( - data_ptr: int, device: Optional[_device_t] = None -) -> Stream: +def get_stream_from_external(data_ptr: int, device: _device_t | None = None) -> Stream: r"""Return a :class:`Stream` from an external SYCL queue. This function is used to wrap SYCL queue created in other libraries in order @@ -484,7 +482,7 @@ def _get_generator(device: torch.device) -> torch._C.Generator: def _set_rng_state_offset( - offset: int, device: Union[int, str, torch.device] = "xpu" + offset: int, device: int | str | torch.device = "xpu" ) -> None: r"""Set the random number generator state offset of the specified GPU. @@ -502,7 +500,7 @@ def cb() -> None: _lazy_call(cb) -def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: +def _get_rng_state_offset(device: int | str | torch.device = "xpu") -> int: r"""Return the random number generator state offset of the specified GPU. Args: diff --git a/torch/xpu/random.py b/torch/xpu/random.py index ec770225aef39..8b489e871f7c5 100644 --- a/torch/xpu/random.py +++ b/torch/xpu/random.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs from collections.abc import Iterable -from typing import Union import torch from torch import Tensor @@ -8,7 +7,7 @@ from . import _lazy_call, _lazy_init, current_device, device_count -def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor: +def get_rng_state(device: int | str | torch.device = "xpu") -> Tensor: r"""Return the random number generator state of the specified GPU as a ByteTensor. Args: @@ -36,9 +35,7 @@ def get_rng_state_all() -> list[Tensor]: return results -def set_rng_state( - new_state: Tensor, device: Union[int, str, torch.device] = "xpu" -) -> None: +def set_rng_state(new_state: Tensor, device: int | str | torch.device = "xpu") -> None: r"""Set the random number generator state of the specified GPU. Args: From 1902eddfe655a15ebcf2c72bd81ade110fdeef63 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Mon, 1 Dec 2025 09:26:13 -0800 Subject: [PATCH 091/338] [precompile] support serde for torch.nn.attention.SDPBackend (#168988) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously would fail deserialization since ``` SDPBackend.__name__ → "SDPBackend" SDPBackend.__qualname__ → "_SDPBackend" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168988 Approved by: https://github.com/zhxchen17 ghstack dependencies: #168989, #169008 --- test/dynamo/test_guard_serialization.py | 42 +++++++++++++++++++++++++ torch/_dynamo/guards.py | 8 +++++ 2 files changed, 50 insertions(+) diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 9e3a62477db97..ec333ed5b0dc7 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -1725,6 +1725,48 @@ def foo(x): with torch.compiler.set_stance("fail_on_recompile"): self.assertEqual(compiled_fn(x), foo(x)) + def test_sdp_backend_serialization(self): + def fn(x, backend): + # Use the backend enum in a guard-producing way + if backend == torch.nn.attention.SDPBackend.MATH: + return x + 1 + elif backend == torch.nn.attention.SDPBackend.FLASH_ATTENTION: + return x + 2 + elif backend == torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION: + return x + 3 + else: + return x + 4 + + x = torch.randn(3, 2) + backend = torch.nn.attention.SDPBackend.MATH + + ref, loaded = self._test_serialization("EQUALS_MATCH", fn, x, backend) + + # Test with the same backend + self._test_check_fn( + ref, loaded, {"x": x, "backend": torch.nn.attention.SDPBackend.MATH}, True + ) + + # Test with different backends + self._test_check_fn( + ref, + loaded, + {"x": x, "backend": torch.nn.attention.SDPBackend.FLASH_ATTENTION}, + False, + ) + self._test_check_fn( + ref, + loaded, + {"x": x, "backend": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION}, + False, + ) + self._test_check_fn( + ref, + loaded, + {"x": x, "backend": torch.nn.attention.SDPBackend.CUDNN_ATTENTION}, + False, + ) + class SimpleModule(torch.nn.Module): def __init__(self, c): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 335323e638769..71ddfed60df02 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -3326,6 +3326,11 @@ def _unpickle_c_op(cls, name: str) -> Any: def _unpickle_bound_method(cls, func: Any, base: Any) -> Any: return types.MethodType(func, base) + @staticmethod + def _unpickle_sdp_backend(name: str): + # Reconstruct from the Python-facing enum namespace + return getattr(torch.nn.attention.SDPBackend, name) + @classmethod def _unpickle_cell(cls, val: Any) -> Any: def _() -> Any: @@ -3466,6 +3471,9 @@ def reducer_override( if id(obj) not in self.guard_tree_values: return _Missing, ("distributed_c10d.Work",) + if isinstance(obj, torch.nn.attention.SDPBackend): + return type(self)._unpickle_sdp_backend, (obj.name,) + if type(obj).__qualname__ != type(obj).__name__: raise torch._dynamo.exc.PackageError( f"Type {type(obj)} for object {obj} cannot be saved " From 587d63a3e07de5dc91065f9ef70bcacda9989068 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Mon, 1 Dec 2025 09:26:14 -0800 Subject: [PATCH 092/338] [precompile] generate nonce key if enable_aot_compile is enabled (#169244) As discussed offline with @jamesjwu and @aorenste in a precompile world, it's actually not necessary to ensure we can safetely generate a unique cache key since there is no implicit sharing. This PR adds a fallback so in the case where we can't safetly generate a key for caching (eg. certain HOPs), we still generate a random nonce key for precompile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169244 Approved by: https://github.com/Lucaskabela ghstack dependencies: #168989, #169008, #168988 --- test/dynamo/test_aot_compile.py | 24 +++++++ .../_aot_autograd/autograd_cache.py | 64 ++++++++++++------- 2 files changed, 64 insertions(+), 24 deletions(-) diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py index 7fcfbc68599fa..8ea9ca2bb72c0 100644 --- a/test/dynamo/test_aot_compile.py +++ b/test/dynamo/test_aot_compile.py @@ -776,6 +776,30 @@ def make_inputs(): self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor") self.assertEqual(expected, actual) + def test_aot_compile_with_checkpoint(self): + from torch.utils.checkpoint import checkpoint + + def fn(x, y): + def compute(x, y): + return x * 2 + y * 3 + + return checkpoint(compute, x, y, use_reentrant=False) + + compiled_fn = torch.compile(fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4), torch.randn(3, 4)), {}) + ) + inputs = (torch.randn(3, 4), torch.randn(3, 4)) + expected = fn(*inputs) + actual = compiled_fn(*inputs) + self.assertEqual(expected, actual) + compiled_fn.save_compiled_function(self.path()) + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with open(self.path(), "rb") as f: + compiled_fn = torch.compiler.load_compiled_function(f) + actual = compiled_fn(*inputs) + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index e411b4c7f6d86..1a7b4c8973c5d 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -12,6 +12,7 @@ import logging import os import pickle +import random import shutil import time import traceback @@ -474,30 +475,45 @@ def autograd_cache_key( """ Generate a unique hash of the FX graph for caching. """ - check_cacheable(gm) - if has_triton_package(): - # Due to https://github.com/triton-lang/triton/issues/3729, - # if triton is < 3.2.0, AOTAutogradCache may cause us to - # attempt to load a cache entry without initializing - # the CUDA context on the autograd thread. - - # Without caching, we naturally do this initialization when - # tracing through the graph with the autograd engine. - import triton - - if triton.__version__ < "3.2.0": - raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") - details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) - pickler = AOTAutogradCachePickler(gm) - # The prefix distinguishes among the other kinds of objects we cache - key = "a" + pickler.get_hash(details) - debug_lines = pickler.debug_lines(details) - log.debug( - "Autograd graph cache hash details for key %s:\n%s", - key, - LazyString(lambda: "\n".join(debug_lines)), - ) - return key, debug_lines + + try: + check_cacheable(gm) + if has_triton_package(): + # Due to https://github.com/triton-lang/triton/issues/3729, + # if triton is < 3.2.0, AOTAutogradCache may cause us to + # attempt to load a cache entry without initializing + # the CUDA context on the autograd thread. + + # Without caching, we naturally do this initialization when + # tracing through the graph with the autograd engine. + import triton + + if triton.__version__ < "3.2.0": + raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") + details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) + pickler = AOTAutogradCachePickler(gm) + # The prefix distinguishes among the other kinds of objects we cache + key = "a" + pickler.get_hash(details) + debug_lines = pickler.debug_lines(details) + log.debug( + "Autograd graph cache hash details for key %s:\n%s", + key, + LazyString(lambda: "\n".join(debug_lines)), + ) + return key, debug_lines + except Exception: + # If enable_aot_compile is set, we're in AOT precompile mode where we always + # want to use fallback nonce keys. Unlike caching, it's fine if we can't generate + # a proper key because we are guaranteed in an AOT precompile world users are in + # complete control of distributing and loading artifacts. + if torch._dynamo.config.enable_aot_compile: + log.info( + "Failed to generate AOTAutograd cache key; falling back to nonce due to enable_aot_compile", + exc_info=True, + ) + return str(random.random()), [] + else: + raise @contextlib.contextmanager From ce5e7e3bf1f4b69a4f4f93d288ba75b906df492a Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Tue, 2 Dec 2025 01:09:40 +0000 Subject: [PATCH 093/338] [CI][CUDA][Distributed] Update NCCL to 2.28.9 for CUDA13 (#168091) This PR updates the NCCL version for CUDA13 from 2.27.7 to 2.28.9. 2.28.9 release notes: https://github.com/NVIDIA/nccl/releases/tag/v2.28.9-1 2.28.7 release notes: https://github.com/NVIDIA/nccl/releases/tag/v2.28.7-1 2.28.3 release notes: https://github.com/NVIDIA/nccl/releases/tag/v2.28.3-1 CUDA 12 remains at 2.27.5 and is untouched by this PR. Reference PR: https://github.com/pytorch/pytorch/pull/166174 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168091 Approved by: https://github.com/atalman --- .ci/docker/ci_commit_pins/nccl-cu13.txt | 2 +- .github/scripts/generate_binary_build_matrix.py | 2 +- ...ated-linux-aarch64-binary-manywheel-nightly.yml | 14 +++++++------- .../generated-linux-binary-manywheel-nightly.yml | 14 +++++++------- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/.ci/docker/ci_commit_pins/nccl-cu13.txt b/.ci/docker/ci_commit_pins/nccl-cu13.txt index 77202c1566019..7c451d9fad29a 100644 --- a/.ci/docker/ci_commit_pins/nccl-cu13.txt +++ b/.ci/docker/ci_commit_pins/nccl-cu13.txt @@ -1 +1 @@ -v2.27.7-1 +v2.28.9-1 diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index d69db191b9464..7fb1ba1f238f4 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -115,7 +115,7 @@ "nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | " "nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | " "nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | " - "nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | " + "nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | " "nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | " "nvidia-nvtx==13.0.85; platform_system == 'Linux' | " "nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | " diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index ff5ad7e89f99b..dd35e29c2c145 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -346,7 +346,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -696,7 +696,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1046,7 +1046,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1396,7 +1396,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1746,7 +1746,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2096,7 +2096,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2446,7 +2446,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index ac04187e24d8c..754432bf461bf 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -329,7 +329,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1003,7 +1003,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1677,7 +1677,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2351,7 +2351,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3025,7 +3025,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3699,7 +3699,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -4373,7 +4373,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.28.9; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} From 2d1f78fe3ec13820f136a2e0336da12a25f41708 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 2 Dec 2025 01:37:38 +0000 Subject: [PATCH 094/338] [DebugMode] Fix hash for 0 ele tensor; Add more tests (#169027) - When tensor numel is 0, we let the hash be 0 instead of hashing, because torch.hash_tensor doesn't work for 0 numel tensors - Add some tests for distributed Pull Request resolved: https://github.com/pytorch/pytorch/pull/169027 Approved by: https://github.com/xmfan, https://github.com/ngimel --- .../tensor/debug/test_debug_mode.py | 165 +++++++++++++++++- torch/utils/_debug_mode.py | 6 +- 2 files changed, 169 insertions(+), 2 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 801cb0ab64219..dcc50bd268faa 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -1,10 +1,12 @@ # Owner(s): ["oncall: distributed"] import contextlib +import os import unittest import torch import torch.distributed as dist +import torch.distributed._functional_collectives as _functional_collectives from torch._dynamo.testing import CompileCounterWithBackend from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed.tensor import ( @@ -17,6 +19,11 @@ ) from torch.distributed.tensor._dtensor_spec import ShardOrderEntry from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + requires_nccl, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -190,8 +197,8 @@ def test_debug_mode_backward(self): aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu) redistribute_input(t: f32[8, 8], trace: R->S(0)) aten::split.Tensor(t: f32[8, 8], 1) - aten::clone(t: f32[1, 8]) aten::detach(t: f32[8, 1]) + aten::clone(t: f32[1, 8]) aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu) aten::detach(t: f32[1, 8])""", ) @@ -527,6 +534,32 @@ def test_check_hash_mismatches(self): [call["call"] for call in mismatches], ["aten::sin", "aten::sum"] ) + @unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(0).total_memory < 2**26, + "Being conservative, test peak memory is 25MB?", + ) + def test_tensor_hash_redistribute(self): + # test that hashing collectives gives correct results + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + local_tensor = torch.ones(2**18, device=self.device_type) + dt = DTensor.from_local(local_tensor, mesh, [Shard(0)], run_check=False) + + with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(): + dt.redistribute(mesh, [Replicate()]) + + # Find all_gather hash + all_gather_logs = [ + op + for op in debug_mode.logs + if isinstance(op, _OpCall) + and op.op == torch.ops._c10d_functional.all_gather_into_tensor.default + ] + self.assertEqual(len(all_gather_logs), 1) + actual_hash = all_gather_logs[0].log["hash"] + self.assertEqual(actual_hash, float(local_tensor.numel() * self.world_size)) + @unittest.skipIf(not HAS_GPU, "requires GPU") @unittest.skipIf(not has_triton_package(), "requires triton") def test_check_triton_hash_mismatches(self): @@ -608,6 +641,136 @@ def f(dA, dB): self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str) +class TestDebugModeUtils(TestCase): + """Test DebugMode with NCCL backend without using DTensor.""" + + def test_hash_empty_tenor(self): + t = torch.tensor([]) + # hash tensor fn should not error out with empty tensor + out = torch.utils._debug_mode.hash_tensor_fn(t) + self.assertTrue(isinstance(out, torch.Tensor)) + out = torch.utils._debug_mode.hash_tensor_fn(t, use_scalar=True) + self.assertTrue(isinstance(out, int)) + + +class TestDTensorDebugModeNCCLBackend(MultiProcessTestCase): + @property + def world_size(self): + return 2 # Need at least 2 ranks for collectives + + def setUp(self): + super().setUp() + self._spawn_processes() + + def _init_process_group(self): + """Initialize NCCL process group for each spawned process.""" + torch.cuda.set_device(self.rank) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + self.device = f"cuda:{self.rank}" + + def _destroy_process_group(self): + """Destroy the process group.""" + dist.destroy_process_group() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_allgather_base(self): + self._init_process_group() + tensor = torch.ones(10, 10, device=torch.device(self.device)) * (self.rank + 1) + # Output size must be world_size * input_size + output_tensor = torch.zeros( + 10 * self.world_size, 10, device=torch.device(self.device) + ) + + with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_inputs=True): + dist.all_gather_into_tensor(output_tensor, tensor) + + self.assertTrue("c10d::_allgather_base_" in debug_mode.debug_string()) + + hash_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731 + + self.assertEqual(debug_mode.operators[-1].log["hash"][0], hash_(output_tensor)) + + # Verify each rank's contribution + for i in range(self.world_size): + expected_slice = torch.ones(10, 10, device=self.device) * (i + 1) + self.assertEqual(output_tensor[i * 10 : (i + 1) * 10], expected_slice) + + self._destroy_process_group() + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_allgather_base_async_op(self): + """Test all_gather_into_tensor with async_op=True.""" + self._init_process_group() + tensor = torch.ones(10, 10, device=torch.device(self.device)) * (self.rank + 1) + # Output size must be world_size * input_size + output_tensor = torch.zeros( + 10 * self.world_size, 10, device=torch.device(self.device) + ) + + with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_inputs=True): + # Call with async_op=True returns a work handle + work = dist.all_gather_into_tensor(output_tensor, tensor, async_op=True) + # Wait for the async operation to complete + work.wait() + + self.assertTrue("c10d::_allgather_base_" in debug_mode.debug_string()) + hash_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731 + + self.assertEqual(debug_mode.operators[-1].log["hash"][0], hash_(output_tensor)) + + # Verify each rank's contribution + for i in range(self.world_size): + expected_slice = torch.ones(10, 10, device=self.device) * (i + 1) + self.assertEqual(output_tensor[i * 10 : (i + 1) * 10], expected_slice) + + self._destroy_process_group() + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_allgather_functional_with_async_collective_tensor(self): + self._init_process_group() + tensor = torch.ones(10, 10, device=torch.device(self.device)) * (self.rank + 1) + + # Use functional collectives which return AsyncCollectiveTensor + with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(): + result = _functional_collectives.all_gather_tensor( + tensor, gather_dim=0, group=dist.group.WORLD + ) + + result = result.wait() + hash_ = lambda x: norm_hash_fn(x, use_scalar=True) # noqa: E731 + + self.assertEqual(debug_mode.operators[-1].log["hash"], hash_(result)) + + self.assertTrue( + "_c10d_functional::all_gather_into_tensor" in debug_mode.debug_string() + ) + + # Verify the result shape - should be world_size times bigger + self.assertEqual(result.shape[0], tensor.shape[0] * self.world_size) + # Verify each rank's contribution + for i in range(self.world_size): + expected_slice = torch.ones(10, 10, device=self.device) * (i + 1) + self.assertEqual(result[i * 10 : (i + 1) * 10], expected_slice) + + self._destroy_process_group() + + instantiate_parametrized_tests(TestDTensorDebugMode) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 14c1607383e1c..3303f2470e4da 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -204,7 +204,11 @@ def hash_tensor_fn( else: t_clean = t.to(dtype=torch.int64) - out = torch.hash_tensor(t_clean) + if t.numel() > 0: + out = torch.hash_tensor(t_clean) + else: + out = torch.zeros((), device=t_clean.device, dtype=torch.uint64) + if use_scalar: return out.item() # type: ignore[attribute] return out From 503b2640023521f5a35cd9a52fc8033d73a95d0d Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 1 Dec 2025 14:28:22 +0000 Subject: [PATCH 095/338] [AOTI] Fix a GPU memory leak caused by reference circle (#168063) Summary: Fix https://github.com/pytorch/pytorch/issues/167630. There was a reference circle between GraphLowering and CppWrapperCpu due to caching, which makes GraphLowering unnecessarily hold some contant tensors causing GPU memory leaks. This PR fixes that by changing the cache to use the object id of GraphLowering as a part of the key. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168063 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 48 ++++++++++++++++++++++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 24 ++++++++++- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 8cac7b8f929d1..4524332cd28a9 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -7472,6 +7472,54 @@ def forward(self, x): "RAIIAtenTensorHandle buf0(buf0_handle_restrided);" ).run(code) + @unittest.skipIf( + IS_FBCODE, + "different behavior in fbcode", + ) + def test_codegen_int_array_var_fix_memory_leak(self): + """ + Fix https://github.com/pytorch/pytorch/issues/167630 + """ + if self.device != "cuda": + raise unittest.SkipTest("test is only for cuda") + + def make_mlp(in_dim=128, hidden=256, out_dim=64, depth=3): + layers = [] + d = in_dim + for _ in range(depth): + layers += [nn.Linear(d, hidden), nn.ReLU()] + d = hidden + layers += [nn.Linear(d, out_dim)] + return nn.Sequential(*layers) + + batch = 32 + in_dim = 2048 + hidden = 512 + out_dim = 10 + depth = 6 + + import gc + + allocated_memory = [] + for _ in range(3): + torch.cuda.reset_peak_memory_stats() + + model = make_mlp(in_dim, hidden, out_dim, depth).to(self.device) + example_inputs = (torch.randn(batch, in_dim, device=self.device),) + ep = torch.export.export( + model, + example_inputs, + ) + torch._inductor.aoti_compile_and_package(ep) + + del model, example_inputs, ep + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + allocated_memory.append(torch.cuda.memory_allocated()) + + self.assertTrue(allocated_memory[1] == allocated_memory[2]) + @unittest.skipIf(IS_MACOS, "might have no readelf on Mac") def test_libtorch_free_so(self): class Model(torch.nn.Module): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 3a65d1c895d1c..0bb1b40cfad96 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -96,6 +96,7 @@ def __init__(self): self.include_extra_header = functools.lru_cache(None)( # type: ignore[method-assign] self._include_extra_header ) + self.codegen_int_array_var_cache = {} @staticmethod def create( @@ -1637,14 +1638,33 @@ def codegen_memory_format(self, memory_format): self.used_cached_memory_formats.add(memory_format_str) return f"cached_torch_memory_format_{memory_format_str}" - @functools.cache # noqa: B019 def codegen_int_array_var( self, int_array: str, writeline: Callable[..., None], known_statically=False, graph=None, # for per-graph caching - ): + ) -> str: + # Use id(graph) for caching to avoid circular references + cache_key = ( + int_array, + id(writeline), + known_statically, + id(graph) if graph else None, + ) + if cache_key not in self.codegen_int_array_var_cache: + self.codegen_int_array_var_cache[cache_key] = ( + self._codegen_int_array_var_impl(int_array, writeline, known_statically) + ) + + return self.codegen_int_array_var_cache[cache_key] + + def _codegen_int_array_var_impl( + self, + int_array: str, + writeline: Callable[..., None], + known_statically: bool, + ) -> str: # Used for size/stride declaration # # Because the memory planning is done in two passes (see the implementation From 3cd98b4205ada151042cc7ff097a82d4a4b18725 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Mon, 1 Dec 2025 15:20:18 +0000 Subject: [PATCH 096/338] [Inductor] support masked vectorization for the tail_loop for integer and bool datatypes (#165885) **Summary:** Support masked vectorization for the tail_loop for int32, int64 and bool datatypes Pull Request resolved: https://github.com/pytorch/pytorch/pull/165885 Approved by: https://github.com/mingfeima, https://github.com/jansel --- aten/src/ATen/cpu/vec/vec256/vec256_int.h | 16 ++--- aten/src/ATen/cpu/vec/vec512/vec512_int.h | 12 ++-- aten/src/ATen/cpu/vec/vec_mask.h | 13 ++++ aten/src/ATen/cpu/vec/vec_n.h | 13 ++-- test/inductor/test_cpu_repro.py | 87 ++++++++++++++++++++++- torch/_inductor/codegen/cpp.py | 46 +++++------- torch/csrc/inductor/cpp_prefix.h | 45 ++++++++++-- 7 files changed, 178 insertions(+), 54 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 998177758be8d..eac5c710c9002 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -116,10 +116,10 @@ class Vectorized : public Vectorizedi { __at_align__ int64_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See // https://github.com/pytorch/pytorch/issues/32502 for more details. We do - // not initialize arrays to zero using "={0}" because gcc would compile it + // not initialize arrays to one using "={1}" because gcc would compile it // to two instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(size())) { - tmp_values[i] = 0; + tmp_values[i] = 1; } std::memcpy(tmp_values, ptr, count * sizeof(int64_t)); return loadu(tmp_values); @@ -266,10 +266,10 @@ class Vectorized : public Vectorizedi { __at_align__ int32_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See // https://github.com/pytorch/pytorch/issues/32502 for more details. We do - // not initialize arrays to zero using "={0}" because gcc would compile it + // not initialize arrays to one using "={1}" because gcc would compile it // to two instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(size())) { - tmp_values[i] = 0; + tmp_values[i] = 1; } std::memcpy(tmp_values, ptr, count * sizeof(int32_t)); return loadu(tmp_values); @@ -566,10 +566,10 @@ class Vectorized : public Vectorizedi { __at_align__ int16_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See // https://github.com/pytorch/pytorch/issues/32502 for more details. We do - // not initialize arrays to zero using "={0}" because gcc would compile it + // not initialize arrays to one using "={1}" because gcc would compile it // to two instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(size())) { - tmp_values[i] = 0; + tmp_values[i] = 1; } std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); return loadu(tmp_values); @@ -914,10 +914,10 @@ class Vectorized8 : public Vectorizedi { __at_align__ T tmp_values[size()]; // Ensure uninitialized memory does not change the output value See // https://github.com/pytorch/pytorch/issues/32502 for more details. We do - // not initialize arrays to zero using "={0}" because gcc would compile it + // not initialize arrays to one using "={1}" because gcc would compile it // to two instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(size())) { - tmp_values[i] = 0; + tmp_values[i] = 1; } std::memcpy(tmp_values, ptr, count * sizeof(T)); return loadu(tmp_values); diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index 0a2f2c5f94823..236c31e24244d 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -130,7 +130,8 @@ class Vectorized : public Vectorizedi { return _mm512_loadu_si512(reinterpret_cast(ptr)); } else { __mmask8 mask = (1ULL << count) - 1; - return _mm512_maskz_loadu_epi64(mask, ptr); + auto ones = _mm512_set1_epi64(1); + return _mm512_mask_loadu_epi64(ones, mask, ptr); } } void store(void* ptr, int count = size()) const { @@ -332,7 +333,8 @@ class Vectorized : public Vectorizedi { return _mm512_loadu_si512(reinterpret_cast(ptr)); } else { __mmask16 mask = (1ULL << count) - 1; - return _mm512_maskz_loadu_epi32(mask, ptr); + auto ones = _mm512_set1_epi32(1); + return _mm512_mask_loadu_epi32(ones, mask, ptr); } } void store(void* ptr, int count = size()) const { @@ -660,7 +662,8 @@ class Vectorized : public Vectorizedi { return _mm512_loadu_si512(reinterpret_cast(ptr)); } else { __mmask32 mask = (1ULL << count) - 1; - return _mm512_maskz_loadu_epi16(mask, ptr); + auto ones = _mm512_set1_epi16(1); + return _mm512_mask_loadu_epi16(ones, mask, ptr); } } void store(void* ptr, int count = size()) const { @@ -1101,7 +1104,8 @@ class Vectorized8 : public Vectorizedi { return loadu_one_fourth(ptr); } else { __mmask64 mask = (1ULL << count) - 1; - return _mm512_maskz_loadu_epi8(mask, ptr); + auto ones = _mm512_set1_epi8(1); + return _mm512_mask_loadu_epi8(ones, mask, ptr); } } void store(void* ptr, int count = size()) const { diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index e19d7f75388af..2bc20980f496d 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -165,6 +165,19 @@ class VecMask { return VectorizedN(VectorizedN::loadu(mask)); } + template + static VecMask from(U* b, int count) { + using int_t = int_same_size_t; + __at_align__ T mask[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (int i = 0; i < count; i++) { + *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0; + } + return VectorizedN(VectorizedN::loadu(mask, count)); + } + static VecMask blendv( const VecMask& c, const VecMask& b, diff --git a/aten/src/ATen/cpu/vec/vec_n.h b/aten/src/ATen/cpu/vec/vec_n.h index 3de55de6f1b85..9bebd724399ff 100644 --- a/aten/src/ATen/cpu/vec/vec_n.h +++ b/aten/src/ATen/cpu/vec/vec_n.h @@ -187,12 +187,13 @@ class VectorizedN { static VectorizedN loadu(const void* ptr, int64_t count) { VectorizedN result; for (int i = 0; i < N; ++i) { - result.values[i] = Vectorized::loadu( - ptr, std::min(count, (int64_t)Vectorized::size())); - ptr = static_cast(ptr) + Vectorized::size(); - count -= Vectorized::size(); - if (count <= 0) { - break; + if (count > 0) { + result.values[i] = Vectorized::loadu( + ptr, std::min(count, (int64_t)Vectorized::size())); + ptr = static_cast(ptr) + Vectorized::size(); + count -= Vectorized::size(); + } else { + result.values[i] = Vectorized((T)1); } } return result; diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index ba9dc93c651cf..79ae62d4e10ea 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4701,6 +4701,23 @@ def fn(x): self.common(fn, (x,)) check_metrics_vec_kernel_count(1) + # Tail vectorization case + x = torch.rand(37) + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + expected = fn(x) + compiled_fn = torch.compile(fn) + actual, code = run_and_get_cpp_code(compiled_fn, x) + self.assertEqual(expected, actual) + # 1 generated vec kernel + check_metrics_vec_kernel_count(1) + # Check that both main and tail loops are vectorized + if _can_check_vec_metrics(): + FileCheck().check_count( + "at::vec::VecMask::from", 2, exactly=True + ).run(code) + @torch._dynamo.config.patch(dynamic_shapes=True) @torch._dynamo.config.patch(assume_static_by_default=False) def test_symbolic_shape_scalar_value_reduction(self): @@ -4722,6 +4739,23 @@ def fn(x): self.common(fn, (x,)) check_metrics_vec_kernel_count(1) + # Tail vectorization case + x = torch.randint(0, 100, (37, 37), dtype=torch.int32) + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + expected = fn(x) + compiled_fn = torch.compile(fn) + actual, code = run_and_get_cpp_code(compiled_fn, x) + self.assertEqual(expected, actual) + # 1 generated vec kernel + check_metrics_vec_kernel_count(1) + # Check that both main and tail loops are vectorized + if _can_check_vec_metrics(): + FileCheck().check_count( + "at::vec::Vectorized::loadu", 2, exactly=True + ).run(code) + def test_int32_reduction_vec(self): def fn(x): return x.sum(dim=1) @@ -4731,6 +4765,23 @@ def fn(x): self.common(fn, (x,)) check_metrics_vec_kernel_count(1) + # Tail vectorization case + x = torch.randint(0, 100, (37, 37), dtype=torch.int32) + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + expected = fn(x) + compiled_fn = torch.compile(fn) + actual, code = run_and_get_cpp_code(compiled_fn, x) + self.assertEqual(expected, actual) + # 1 generated vec kernel + check_metrics_vec_kernel_count(1) + # Check that both main and tail loops are vectorized + if _can_check_vec_metrics(): + FileCheck().check_count( + "at::vec::Vectorized::loadu", 2, exactly=True + ).run(code) + def test_uint32_pointwise_vec(self): def fn(x): return x * x @@ -4760,6 +4811,23 @@ def fn(x): self.common(fn, (x,)) check_metrics_vec_kernel_count(1) + # Tail vectorization case + x = torch.randint(0, 100, (37, 37), dtype=torch.int64) + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + expected = fn(x) + compiled_fn = torch.compile(fn) + actual, code = run_and_get_cpp_code(compiled_fn, x) + self.assertEqual(expected, actual) + # 1 generated vec kernel + check_metrics_vec_kernel_count(1) + # Check that both main and tail loops are vectorized + if _can_check_vec_metrics(): + FileCheck().check_count( + "at::vec::VectorizedN::loadu", 2, exactly=True + ).run(code) + def test_int64_reduction_vec(self): def fn(x): return x.sum(dim=1) @@ -4769,6 +4837,23 @@ def fn(x): self.common(fn, (x,)) check_metrics_vec_kernel_count(1) + # Tail vectorization case + x = torch.randint(0, 100, (37, 37), dtype=torch.int64) + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + expected = fn(x) + compiled_fn = torch.compile(fn) + actual, code = run_and_get_cpp_code(compiled_fn, x) + self.assertEqual(expected, actual) + # 1 generated vec kernel + check_metrics_vec_kernel_count(1) + # Check that both main and tail loops are vectorized + if _can_check_vec_metrics(): + FileCheck().check_count( + "at::vec::VectorizedN::loadu", 2, exactly=True + ).run(code) + def test_uint64_pointwise_vec(self): def fn(x): return x * x @@ -5379,7 +5464,7 @@ def fn(arg0_1): _, code = run_and_get_cpp_code(opt_fn, x) FileCheck().check_count( "return at::vec::VectorizedN::loadu(tmpbuf.data(),", - 4, + 8, exactly=True, ).run(code) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 18b209de94cb3..a9c45cd329814 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -158,17 +158,6 @@ def get_export_declaration(): torch.float8_e5m2, ] -MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [ - torch.float64, - torch.float, - torch.bfloat16, - torch.float16, - torch.uint8, - torch.int8, - torch.float8_e4m3fn, - torch.float8_e5m2, -] - def reduction_init(reduction_type, dtype): if dtype in DTYPE_LOWP_FP: @@ -1743,6 +1732,7 @@ def maskify_or_vecify(code): V.kernel.compute, code, ) + result.is_vec = True elif result.is_vec: csevar = V.kernel.cse.generate( V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}" @@ -1882,16 +1872,13 @@ def inner(*args, **kwargs): code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") with code.indent(): code.writeline(f"tmpbuf_out[i] = {res};") + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" if output_mask: - assert not kernel.tail_size - load_args = "tmpbuf_out.data()" load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + elif n_vec == 1: + load_fn = f"at::vec::Vectorized<{octype}>::loadu" else: - load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" - if n_vec == 1: - load_fn = f"at::vec::Vectorized<{octype}>::loadu" - else: - load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" + load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" code.writeline(f"return {load_fn}({load_args});") code.writeline("()") return code @@ -2744,7 +2731,7 @@ def _get_vec_load_line( loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var if dtype == torch.bool: # TODO: should we consider load mask here? - line = f"{self._get_mask_type()}::from({loadbuf})" + line = f"{self._get_mask_type()}::from({loadbuf}, {cexpr_index(self.num_elems)})" else: line = ( f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" @@ -2987,7 +2974,10 @@ def store(self, name, index, value, mode=None): cdtype = DTYPE_TO_CPP[dtype] index = ops.index_expr(index, torch.int64).value assert isinstance(index, CppCSEVariable) and index.is_vec - line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + if self.tail_size: + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value}, {cexpr_index(self.tail_size)});" + else: + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" self.stores.writeline(DeferredLine(name, line)) else: raise NotImplementedError(f"store mode={mode}") @@ -3452,7 +3442,10 @@ def reduction_combine_vec( if isinstance(next_value, CppCSEVariable): assert next_value.dtype == torch.bool (next_value,) = unify_mask_base_type(V.kernel.compute, (next_value,)) - return f"{var} | {next_value}" + if self.tail_size: + return f"any_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} | {next_value}" else: raise NotImplementedError @@ -4358,13 +4351,6 @@ def run(kernel): fn_list, var_sizes_list ) assert len(tiling_factors) == len(tiling_indices) - # This should be removed after full support for vectorization is implemented. - could_masked_vec = True - all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list)) - if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): - # can be removed after masked vectorizable dtype are same with vectorizable dtype - could_masked_vec = False - _inner_loop_reduction_outer_not = False _outer_loop = None if tiling_indices: @@ -4391,7 +4377,7 @@ def run(kernel): ) tail_size = loop.size - loop.tiled_size vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)} - if config.cpp.enable_loop_tail_vec and could_masked_vec: + if config.cpp.enable_loop_tail_vec: tail_kernel = codegen_kernel( self.vec_kernel_cls, tiling_factors[0], @@ -4438,7 +4424,7 @@ def run(kernel): inner_loop.var: inner_ranges["main"], } tail_kernel = [] - if config.cpp.enable_loop_tail_vec and could_masked_vec: + if config.cpp.enable_loop_tail_vec: for outer_r, inner_r in ( ("main", "tail"), ("tail", "main"), diff --git a/torch/csrc/inductor/cpp_prefix.h b/torch/csrc/inductor/cpp_prefix.h index 7dc161d13fd52..a51bd74496fe8 100644 --- a/torch/csrc/inductor/cpp_prefix.h +++ b/torch/csrc/inductor/cpp_prefix.h @@ -306,23 +306,50 @@ inline T cascade_sum_combine( } template -T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) { +inline T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) { auto out = at::vec::maximum(a, b); return T::set(a, out, tail_size); } +template <> +inline at::vec::VecMask max_masked_reduce( + const at::vec::VecMask& a, + const at::vec::VecMask& b, + const int64_t tail_size) { + auto out = a | b; + return at::vec::VecMask::set(a, out, tail_size); +} + template -T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) { +inline T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) { auto out = at::vec::minimum(a, b); return T::set(a, out, tail_size); } +template <> +inline at::vec::VecMask min_masked_reduce( + const at::vec::VecMask& a, + const at::vec::VecMask& b, + const int64_t tail_size) { + auto out = a & b; + return at::vec::VecMask::set(a, out, tail_size); +} + template -T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) { +inline T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) { auto out = a + b; return T::set(a, out, tail_size); } +template <> +inline at::vec::VecMask sum_masked_reduce( + const at::vec::VecMask& a, + const at::vec::VecMask& b, + const int64_t tail_size) { + auto out = a | b; + return at::vec::VecMask::set(a, out, tail_size); +} + template T prod_masked_reduce(const T& a, const T& b, const int64_t tail_size) { auto out = a * b; @@ -334,6 +361,12 @@ T xor_sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) { auto out = a ^ b; return T::set(a, out, tail_size); } + +template +T any_masked_reduce(const T& a, const T& b, const int64_t tail_size) { + auto out = a | b; + return T::set(a, out, tail_size); +} #endif // Refer to @@ -869,14 +902,16 @@ template void atomic_add_vec( T* addr, at::vec::VectorizedN index, - at::vec::VectorizedN offset) { + at::vec::VectorizedN offset, + std::optional tail_size = std::nullopt) { constexpr int len = at::vec::VectorizedN::size(); static_assert(len <= at::vec::VectorizedN::size()); __at_align__ std::array tmpbuf; __at_align__ std::array tmpidx; offset.store(tmpbuf.data(), len); index.store(tmpidx.data(), len); - for (int i = 0; i < len; i++) { + int size = tail_size.has_value() ? tail_size.value() : len; + for (int i = 0; i < size; i++) { atomic_add(addr + tmpidx[i], tmpbuf[i]); } } From d9cb8a70833101dbbe16b99520cfbdd70d0a87bf Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 1 Dec 2025 13:06:52 -0800 Subject: [PATCH 097/338] [AOTI] Set device info for subgraphs (#169001) Summary: Fix https://github.com/pytorch/pytorch/issues/168398. When Inductor creates subgraphs, they should inherit device information from the parent graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169001 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 34 ++++++++++++++++++++++++++++++ torch/_inductor/graph.py | 3 +++ 2 files changed, 37 insertions(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 4524332cd28a9..1e71936d5653d 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -21,6 +21,7 @@ from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters +from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass from torch._inductor import config from torch._inductor.codecache import WritableTempFile from torch._inductor.cpp_builder import normalize_path_separator @@ -2229,6 +2230,39 @@ def test_cond_with_reinterpret_view_inputs_outputs(self): dynamic_shapes=dynamic_shapes, ) + @requires_gpu + def test_cond_with_replace_view_ops(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class CondModelWithViewAndLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, cache, x): + def true_fn(cache, x): + return cache + 1.0 + + def false_fn(cache, x): + return self.linear(x).view(1, 2, 4, 4) + + cache_is_initialized = (cache != 0).any() + return torch.cond(cache_is_initialized, false_fn, false_fn, [cache, x]) + + example_input = ( + torch.zeros(1, 2, 4, 4, dtype=torch.float32, device=self.device), + torch.randn(8, 4, dtype=torch.float32, device=self.device), + ) + model = CondModelWithViewAndLinear().to(device=self.device) + exported_program = torch.export.export(model, example_input) + program = exported_program.run_decompositions() + gm = ReplaceViewOpsWithViewCopyOpsPass()(program.graph_module).graph_module + with config.patch( + {"max_autotune": True, "max_autotune_gemm_backends": "TRITON,ATEN"} + ): + _ = torch._inductor.aot_compile(gm, example_input) + def test_cond_with_multiple_outputs(self): inputs = ( torch.randn((10, 20), device=self.device), diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 517d6c3e39d1b..a16e09f3ca5cf 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -2369,6 +2369,9 @@ def codegen_subgraph(self, parent_graph: GraphLowering) -> None: self.wrapper_code = parent_graph.wrapper_code self.device_ops = parent_graph.device_ops self.cpp_wrapper = parent_graph.cpp_wrapper + self.device_types = parent_graph.device_types + self.device_idxs = parent_graph.device_idxs + self.device_type = parent_graph.device_type self._update_scheduler() self.scheduler.codegen() From 1e526fb5b1d93bfc70691c5c3955fdffc1b7b7de Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Mon, 1 Dec 2025 10:11:00 -0800 Subject: [PATCH 098/338] Parity of rng offset compute and ranks subset support for Local Tensor (#169088) Debugging numeric differences for AutoParallel PP between Local Tensor and multi-process setup revealed differences in how rng offsets are computed. This change refactors DTensor implementation so that it can be shared with Local Tensor. The existing Local Tensor implementation was incorrectly computing shard linear index based on number of elements in the tensor instead of shard coordinates. AutoParallel PP slices world mesh into "pp" submeshes for MPMD execution and "dp_mod_ep, ep" submeshes for SPMD execution. Local Tensor uses default process group (corresponding to the world mesh) to compute collective groups and assumes input local tensors have ranks from the world mesh. Local Tensor mode can be created with subset of ranks. This feature is used in AutoParallel PP integration. Therefore this change modifies Local Tensor collectives to execute only if all ranks from the deduced rank groups are present on local tensor inputs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169088 Approved by: https://github.com/dolpm --- torch/distributed/_local_tensor/__init__.py | 67 ++++++---- torch/distributed/_local_tensor/_c10d.py | 54 ++++++++ torch/distributed/tensor/_random.py | 129 +++++++++++--------- 3 files changed, 172 insertions(+), 78 deletions(-) diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index cc4a47f299444..4c8f12c11687b 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -680,28 +680,33 @@ def _set_pre_op_offset(self, state, spec) -> None: coord = (rank // num_chunks_after) % mesh_dim_size mesh_coords.append(coord) - # compute local shape and global offset for this rank - local_shape, global_offset = _compute_local_shape_and_global_offset( - spec.shape, spec.mesh.shape, mesh_coords, spec.placements + # compute shard offset based on placements + from torch.distributed.tensor._random import ( + _calc_first_shard_size, + _calc_shard_info, + _calc_shard_linear_idx, ) - # compute shard offset based on placements - shard_offset = 1 - for idx, placement in enumerate(spec.placements): - if isinstance(placement, Shard): - shard_dim = placement.dim - shard_offset *= global_offset[shard_dim] + 1 + # Compute shard index and total number of shards on each tensor dim + shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info( + mesh_coords, spec + ) + + # compute shard linear index + shard_linear_idx = _calc_shard_linear_idx( + shard_idx_by_dim, total_num_shards_by_dim + ) # get current offset for this rank current_offset = int( state._per_rank_states[rank][8:].view(dtype=torch.int64).item() ) + local_shape = _calc_first_shard_size(spec) # compute local size local_size = prod(local_shape) # compute new offset (must be multiple of 4) - shard_linear_idx = shard_offset - 1 offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 state._per_rank_offsets[rank] = current_offset + offset_incr @@ -753,20 +758,20 @@ def _distribute_region(self, spec, generator=None): if self.distribute_region_enabled: # sync to rank 0's state if no explicit generator if generator is None: - rank_0_state = lm._per_rank_rng_states[0] - rank_0_cpu, rank_0_cuda = rank_0_state + any_rank_state = lm._any_local_rng_state() + any_rank_cpu, any_rank_cuda = any_rank_state if self._device.type == "cuda": - assert self._device.index in rank_0_cuda - rank_0_device_state = rank_0_cuda[self._device.index] + assert self._device.index in any_rank_cuda + any_rank_device_state = any_rank_cuda[self._device.index] else: - rank_0_device_state = rank_0_cpu + any_rank_device_state = any_rank_cpu from torch.distributed.tensor._random import _PhiloxState - rank_0_philox = _PhiloxState(rank_0_device_state) - state.seed = rank_0_philox.seed - state.offset = rank_0_philox.offset + any_rank_philox = _PhiloxState(any_rank_device_state) + state.seed = any_rank_philox.seed + state.offset = any_rank_philox.offset old_offset = state.offset self._set_pre_op_offset(state, spec) @@ -1113,18 +1118,24 @@ def _sync_meta(self) -> None: self._size = shape -_GLOBAL_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] +# If set to `True` the LocalTensorMode stack will be created for the whole process, +# otherwise it will be created for each thread. +_PROCESS_MODE: bool = True +_PROCESS_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] # When running under local runner each thread must create its own local tensor mode # so that they do not interfere with each other. _THREAD_LOCAL_TENSOR_MODE: threading.local = threading.local() def get_local_tensor_mode_list() -> list["LocalTensorMode"]: + global _PROCESS_MODE + if _PROCESS_MODE: + global _PROCESS_LOCAL_TENSOR_MODE + return _PROCESS_LOCAL_TENSOR_MODE + global _THREAD_LOCAL_TENSOR_MODE if not hasattr(_THREAD_LOCAL_TENSOR_MODE, "value"): _THREAD_LOCAL_TENSOR_MODE.value = [] - if len(_THREAD_LOCAL_TENSOR_MODE.value) > 0: - return _THREAD_LOCAL_TENSOR_MODE.value - return _GLOBAL_LOCAL_TENSOR_MODE + return _THREAD_LOCAL_TENSOR_MODE.value class LocalTensorMode(TorchDispatchMode): @@ -1230,7 +1241,7 @@ def __torch_dispatch__( for a in flat_args: if isinstance(a, LocalTensor): assert a._ranks <= self.ranks, ( - f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks" + f"Input LocalTensor {a} must be configured for a subset of the LocalTensorMode ranks {self.ranks}" ) if func.overloadpacket == torch.ops.aten.dim: @@ -1345,6 +1356,9 @@ def tensor_map( # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(results) + def _any_local_rng_state(self) -> tuple[torch.Tensor, dict[int, torch.Tensor]]: + return self._per_rank_rng_states[next(iter(self.ranks))] + def _patch_device_mesh(self) -> None: assert self._old_get_coordinate is None self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment] @@ -1674,12 +1688,16 @@ def __init__( threading.Thread(target=self._run, args=(i,), name="LocalRunnerMode") for i in range(concurrency) ] + self._process_mode = True def __enter__(self) -> "LocalRunnerMode": global _LOCAL_RUNNER_MODE assert _LOCAL_RUNNER_MODE is None, "LocalRunnerMode is already running" _LOCAL_RUNNER_MODE = self + global _PROCESS_MODE + self._process_mode = _PROCESS_MODE + _PROCESS_MODE = False for r in self._runners: r.start() return self @@ -1695,6 +1713,9 @@ def __exit__( global _LOCAL_RUNNER_MODE _LOCAL_RUNNER_MODE = None + global _PROCESS_MODE + _PROCESS_MODE = self._process_mode + def _run(self, id: int) -> None: LocalRunnerMode.runner_context.id = id # Only one thread can run at a time, hence must acquire the lock diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index ab2387af051dc..31f288a9bc85b 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -120,6 +120,9 @@ def _local_functional_all_gather_into_tensor( group_ranks = [group_offset + r for r in ranks] group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + for rank in group_ranks: group_tensors.append(tensor._local_tensors[rank]) @@ -151,6 +154,9 @@ def _local_functional_reduce_scatter_tensor( group_ranks = [group_offset + r for r in ranks] group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + for rank in group_ranks: group_tensors.append(tensor._local_tensors[rank]) @@ -191,6 +197,9 @@ def _local_functional_shard_dim_alltoall( group_ranks = [group_offset + r for r in ranks] group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + for rank in group_ranks: group_tensors.append(tensor._local_tensors[rank]) @@ -256,6 +265,9 @@ def _local_functional_all_to_all_single( for group_offset in group_offsets: group_ranks = [group_offset + r for r in ranks] + if not all(rank in split_local_tensors for rank in group_ranks): + continue + for i, dst in enumerate(group_ranks): splits = [] for j, src in enumerate(group_ranks): @@ -305,6 +317,9 @@ def _local_broadcast_( # For the tensors in this group [group_offset + r for r in ranks] # perform the broadcast on them group_ranks = [group_offset + r for r in ranks] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + source_rank = group_offset + relative_root_rank source_tensor = tensor._local_tensors[source_rank] @@ -375,6 +390,8 @@ def _local_all_reduce_( # For the tensors in this group [group_offset + r for r in ranks] # perform the allreduce on them group_ranks = [group_offset + r for r in ranks] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue # Collect tensors from the specified ranks in this group group_tensors = [] @@ -415,6 +432,8 @@ def _local_allreduce_coalesced_( # For each tensor, perform the reduction operation for tensor in tensors: assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue # Collect tensors from the specified ranks in this group group_tensors = [] for rank in group_ranks: @@ -463,6 +482,11 @@ def _local_reduce_scatter_tensor_coalesced_( assert isinstance(output_tensor, LocalTensor), ( "Output tensor must be a LocalTensor" ) + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + # Collect tensors from the specified ranks in this group group_inputs = [] for rank in group_ranks: @@ -503,6 +527,11 @@ def _local_allgather_base_( for group_offset in group_offsets: group_ranks = [group_offset + r for r in ranks] + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + gathered_tensors = [] for rank_i in group_ranks: gathered_tensors.append(input_tensor._local_tensors[rank_i]) @@ -539,6 +568,10 @@ def _local_reduce_scatter_base_( # type: ignore[no-untyped-def] for group_offset in group_offsets: group_ranks = [group_offset + r for r in ranks] + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue gathered_tensors = [] for rank_i in group_ranks: @@ -639,6 +672,12 @@ def _local_allgather_into_tensor_coalesced_( assert isinstance(output_tensor, LocalTensor), ( "Output tensor must be a LocalTensor" ) + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + # Gather input_tensor from all ranks into output_tensor # The output should be a concatenation of all inputs along the first dimension gathered_tensors = [] @@ -706,6 +745,8 @@ def _local_scatter_( # For the tensors in this group [group_offset + r for r in ranks] # perform the scatter on them group_ranks = [group_offset + r for r in ranks] + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue # Root rank scatters its input tensors to all ranks in this group for i, rank in enumerate(group_ranks): @@ -753,11 +794,19 @@ def _local_alltoall_( assert isinstance(output_tensor, LocalTensor), ( "Output tensor must be a LocalTensor" ) + + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + for j, rank_j in enumerate(group_ranks): input_tensor = input_tensors[j] assert isinstance(input_tensor, LocalTensor), ( "Input tensor must be a LocalTensor" ) + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + # Rank i's j-th input tensor goes to rank j's i-th output tensor source_tensor = input_tensor._local_tensors[rank_i] output_tensor._local_tensors[rank_j].copy_(source_tensor) @@ -796,6 +845,11 @@ def _local_alltoall_base_( # perform the alltoall_base on them group_ranks = [group_offset + r for r in ranks] + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + for i, rank_i in enumerate(group_ranks): # Split input tensor from rank_i according to input_split_sizes rank_tensor = input_tensor._local_tensors[rank_i] diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 40415947be9a0..4c3d51381f541 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -336,44 +336,14 @@ def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: The last value to calculate before obtaining the starting offset is the shard linear index. The starting offset for each rank will be its shard_linear_index * local_tensor_numel. """ - dtensor_shape = spec.shape mesh = spec.mesh - # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP - # case. Replace the custom logic with dim_map once we support it. - dim_map: list[int | list[int]] = [-1] * spec.ndim - for i, placement in enumerate(spec.placements): - if isinstance(placement, Shard): - shard_dim = placement.dim - if dim_map[shard_dim] == -1: - dim_map[shard_dim] = [i] - else: - mesh_dim_list = dim_map[shard_dim] - assert isinstance(mesh_dim_list, list) - mesh_dim_list.append(i) - - # Compute shard coordinate: - # The coordinate on each tensor dim is a tuple (idx, range) - # If a DTensor is partitioned on its dim i into n shards, and the current rank - # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i mesh_coordinate = mesh.get_coordinate() assert mesh_coordinate is not None - mesh_size = mesh.shape - shard_idx_by_dim = [] - total_num_shards_by_dim = [] # total number of shards on each tensor dim - for mesh_dim in dim_map: - shard_idx = 0 - total_num_shards = 1 - # the tensor dim is sharded on more than 1 mesh dim - if isinstance(mesh_dim, list): - rank_coord = [mesh_coordinate[d] for d in mesh_dim] - num_shards = [mesh_size[d] for d in mesh_dim] - # compute the shard idx and total number of shards - for idx, size in zip(rank_coord, num_shards): - shard_idx = shard_idx * size + idx - total_num_shards *= size - - shard_idx_by_dim.append(shard_idx) - total_num_shards_by_dim.append(total_num_shards) + + # Compute shard index and total number of shards on each tensor dim + shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info( + mesh_coordinate, spec + ) # compute shard linear index shard_linear_idx = self._calc_shard_linear_idx( @@ -381,18 +351,7 @@ def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: ) # compute starting offset using the first shard's size - local_size_on_rank_0 = list(dtensor_shape) - for idx, placement in enumerate(spec.placements): - if isinstance(placement, Shard): - mesh_dim_size = mesh.size(idx) - shard_dim = placement.dim - local_size_on_rank_0[shard_dim], _ = ( - placement._local_shard_size_and_offset( - dtensor_shape[shard_dim], - mesh_dim_size, - 0, - ) - ) + local_size_on_rank_0 = _calc_first_shard_size(spec) from torch.distributed.tensor._ops.utils import prod @@ -435,14 +394,74 @@ def _set_post_op_offset( def _calc_shard_linear_idx( self, shard_coord: list[int], shard_size: list[int] ) -> int: - # compute shard linear index - shard_linear_idx = 0 - shard_coord_stride = 1 - for idx, size in zip(reversed(shard_coord), reversed(shard_size)): - shard_linear_idx += idx * shard_coord_stride - shard_coord_stride *= size - - return shard_linear_idx + return _calc_shard_linear_idx(shard_coord, shard_size) + + +def _calc_first_shard_size(spec: DTensorSpec) -> list[int]: + local_size_on_rank_0 = list(spec.shape) + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + mesh_dim_size = spec.mesh.size(idx) + shard_dim = placement.dim + local_size_on_rank_0[shard_dim], _ = placement._local_shard_size_and_offset( + spec.shape[shard_dim], + mesh_dim_size, + 0, + ) + return local_size_on_rank_0 + + +def _calc_shard_info( + mesh_coordinate: list[int], spec: DTensorSpec +) -> tuple[list[int], list[int]]: + mesh = spec.mesh + # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP + # case. Replace the custom logic with dim_map once we support it. + dim_map: list[int | list[int]] = [-1] * spec.ndim + for i, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + if dim_map[shard_dim] == -1: + dim_map[shard_dim] = [i] + else: + mesh_dim_list = dim_map[shard_dim] + assert isinstance(mesh_dim_list, list) + mesh_dim_list.append(i) + + # Compute shard coordinate: + # The coordinate on each tensor dim is a tuple (idx, range) + # If a DTensor is partitioned on its dim i into n shards, and the current rank + # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i + assert mesh_coordinate is not None + mesh_size = mesh.shape + shard_idx_by_dim = [] + total_num_shards_by_dim = [] # total number of shards on each tensor dim + for mesh_dim in dim_map: + shard_idx = 0 + total_num_shards = 1 + # the tensor dim is sharded on more than 1 mesh dim + if isinstance(mesh_dim, list): + rank_coord = [mesh_coordinate[d] for d in mesh_dim] + num_shards = [mesh_size[d] for d in mesh_dim] + # compute the shard idx and total number of shards + for idx, size in zip(rank_coord, num_shards): + shard_idx = shard_idx * size + idx + total_num_shards *= size + + shard_idx_by_dim.append(shard_idx) + total_num_shards_by_dim.append(total_num_shards) + return shard_idx_by_dim, total_num_shards_by_dim + + +def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: + # compute shard linear index + shard_linear_idx = 0 + shard_coord_stride = 1 + for idx, size in zip(reversed(shard_coord), reversed(shard_size)): + shard_linear_idx += idx * shard_coord_stride + shard_coord_stride *= size + + return shard_linear_idx def _resolve_device(device_mesh: DeviceMesh) -> torch.device: From c04e2c656f48d82d1521b867bbbf03967b9b7564 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 1 Dec 2025 12:02:30 -0800 Subject: [PATCH 099/338] [dynamo][dicts] Decentralize and Improve key hash implementation for Dict variable tracker (#169204) Fixes https://github.com/pytorch/pytorch/issues/167956 ## Summary This PR decentralizes and improves the hash implementation for dictionary keys in Dynamo's ConstDictVariable tracker. Instead of maintaining a centralized list of hashable types and custom equality logic in _HashableTracker, we now delegate hashability checks, hash computation, and equality comparison to individual VariableTracker subclasses. ## Motivation The previous implementation had several issues: 1. Centralized logic: All hashability checks and hash computations were centralized in dicts.py, making it difficult to add support for new hashable types 2. Maintainability: Adding a new hashable type required modifying multiple locations in _HashableTracker (underlying_value, _eq_impl, and the is_hashable function) 3. Scattered knowledge: Type-specific hashing logic was separated from the type's own implementation 4. Limited extensibility: No clear protocol for VariableTracker subclasses to declare themselves as hashable ## Changes New Protocol Methods Added three new methods to the VariableTracker base class: 1. is_python_hashable(): Returns whether the underlying Python object is hashable 2. get_python_hash(): Computes the hash value for the underlying Python object 3. is_python_equal(other): Checks Python-level equality between two VariableTrackers The base implementation raises unimplemented() with helpful error messages, and subclasses override these methods as appropriate. ## Simplified _HashableTracker The _HashableTracker class in ConstDictVariable is now much simpler: - Removed underlying_value property (centralized type handling) - Removed _eq_impl static method (centralized equality logic) - Simplified __hash__() to delegate to vt.get_python_hash() - Simplified __eq__() to delegate to vt.is_python_equal() ## Decentralized Implementations Implemented the new protocol methods across relevant VariableTracker subclasses: - ConstantVariable, TensorVariable, TupleVariable, ListVariable - FrozensetVariable, FrozenDataClassVariable - BuiltinVariable, UserFunctionVariable, SkipFunctionVariable - FunctoolsPartialVariable, WeakRefVariable - NumpyVariable, NNModuleVariable, MethodWrapperVariable - TorchInGraphFunctionVariable, TorchHigherOrderOperatorVariable - TypingVariable, UserDefinedObjectVariable, UserDefinedClassVariable - SymNodeVariable, EnumVariable ## Enhanced Test Coverage Added 14 new test cases covering various hashable types as dictionary keys: - range, tuples, enums, frozensets - Typing constructs (e.g., typing.Union) - NumPy dtypes, method wrappers - Torch builtin functions, frozen dataclasses - Custom objects with __hash__ - Negative test for unhashable types (lists) ## Improved Error Messages Updated error messages to be more informative when encountering unhashable types, showing both the Python type and the VariableTracker type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169204 Approved by: https://github.com/jansel --- test/dynamo/test_dicts.py | 210 +++++++++++++++++- .../TestCustomOp.test_impl_device_cpu | 0 torch/_dynamo/graph_break_registry.json | 44 ++++ torch/_dynamo/utils.py | 18 ++ torch/_dynamo/variables/base.py | 56 +++++ torch/_dynamo/variables/builtin.py | 9 + torch/_dynamo/variables/constant.py | 25 +++ torch/_dynamo/variables/dicts.py | 206 ++++++----------- torch/_dynamo/variables/functions.py | 46 ++++ torch/_dynamo/variables/higher_order_ops.py | 9 + torch/_dynamo/variables/lists.py | 34 +++ torch/_dynamo/variables/misc.py | 37 +++ torch/_dynamo/variables/tensor.py | 28 +++ torch/_dynamo/variables/torch.py | 9 + torch/_dynamo/variables/user_defined.py | 87 ++++++-- 15 files changed, 664 insertions(+), 154 deletions(-) create mode 100644 test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index cdaeb2d91fbfb..4c233ea9458f3 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -19,6 +19,7 @@ import torch._functorch.config import torch.nn import torch.utils.checkpoint +from torch._dynamo.exc import Unsupported from torch._dynamo.testing import same from torch._dynamo.utils import dict_items from torch.testing._internal.common_utils import ( @@ -89,7 +90,7 @@ def forward(self, x): inp = torch.randn(4, 4) mod = Foo() - opt_f = torch.compile(mod) + opt_f = torch.compile(mod, backend="eager", fullgraph=True) self.assertEqual(mod(inp), opt_f(inp)) def test_dict_subclass_local_with_non_dict_method(self): @@ -518,7 +519,7 @@ def fn(d): args1 = {namedtuple: None, 3: torch.randn(3)} cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) self.assertEqual(fn(args1), opt_fn(args1)) self.assertEqual(cnts.frame_count, 1) # Test a failing namedtuple guard @@ -538,7 +539,7 @@ def fn(d, x): args1[3] = z cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) self.assertEqual(fn(args1, x), opt_fn(args1, x)) self.assertEqual(cnts.frame_count, 1) @@ -1062,8 +1063,6 @@ def fn(b: Any): a = {"one": torch.ones(1)} return a | b - from torch._dynamo.exc import Unsupported - for arg in args: with self.assertRaisesRegex(Unsupported, "Observed exception"): _ = fn(arg) @@ -1204,6 +1203,156 @@ def f(): opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) + def test_range_as_dict_key(self): + def fn(x): + d = {range(5): x * 2, range(10, 15): x * 3} + return d[range(0, 5, 1)] + d[range(10, 15)] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_tuple_as_dict_key(self): + def fn(x): + d = {(1, 2): x * 2, (3, 4, 5): x * 3} + return d[(1, 2)] + d[(3, 4, 5)] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_enum_as_dict_key(self): + class Color(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + def fn(x): + d = {Color.RED: x * 2, Color.GREEN: x * 3, Color.BLUE: x * 4} + return d[Color.RED] + d[Color.GREEN] + d[Color.BLUE] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_intenum_as_dict_key(self): + class Priority(enum.IntEnum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + def fn(x): + d = {Priority.LOW: x * 2, Priority.MEDIUM: x * 3, Priority.HIGH: x * 4} + return d[Priority.LOW] + d[Priority.MEDIUM] + d[Priority.HIGH] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_frozenset_as_dict_key(self): + def fn(x): + d = {frozenset([1, 2]): x * 2, frozenset([3, 4, 5]): x * 3} + return d[frozenset([1, 2])] + d[frozenset([3, 4, 5])] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_typing_union_as_dict_key(self): + from typing import Union + + def fn(x): + d = {Union[int, str]: x * 2, Union[float, bool]: x * 3} + return d[Union[int, str]] + d[Union[float, bool]] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_numpy_dtype_as_dict_key(self): + import numpy as np + + def fn(x): + d = {np.float32: x * 2, np.int64: x * 3, np.bool_: x * 4} + return d[np.float32] + d[np.int64] + d[np.bool_] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_method_wrapper_as_dict_key(self): + add_method = list.__add__ + mul_method = list.__mul__ + + def fn(x): + # Method wrappers are the type of bound methods on built-in types + d = {add_method: x * 2, mul_method: x * 3} + return d[add_method] + d[mul_method] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_torch_builtin_function_as_dict_key(self): + def fn(x, y): + # Using torch built-in functions as dictionary keys + d = {torch.add: x * 2, torch.mul: y * 3, torch.sub: x + y} + return d[torch.add] + d[torch.mul] + d[torch.sub] + + x = torch.randn(4) + y = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x, y), opt_fn(x, y)) + + def test_frozen_dataclass_as_dict_key(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class Point: + x: int + y: int + + def fn(tensor): + p1 = Point(1, 2) + p2 = Point(3, 4) + d = {p1: tensor * 2, p2: tensor * 3} + return d[Point(1, 2)] + d[Point(3, 4)] + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_list_as_dict_key_raises_typeerror(self): + def fn(x): + d = {[1, 2, 3]: x * 2} + return d[[1, 2, 3]] + + x = torch.randn(4) + + # First check that eager execution raises TypeError + with self.assertRaises(TypeError): + fn(x) + + # Also check that compiled version raises TypeError + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex(Unsupported, "Observed exception"): + opt_fn(x) + + def test_get_default_nowrap_functions_as_dict_key(self): + def fn(x): + # Get the set of default nowrap functions + nowrap_funcs = torch.overrides.get_default_nowrap_functions() + # Use the set as a dict key and search for Tensor.grad.__get__ in it + d = {frozenset(nowrap_funcs): x * 2} + # Check if Tensor.grad.__get__ is in the set + if torch.Tensor.grad.__get__ in nowrap_funcs: + return d[frozenset(nowrap_funcs)] + x + return x + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + instantiate_parametrized_tests(DictTests) @@ -1738,7 +1887,9 @@ def fn(x): new_gn = partial(gn, x=1) key = Container(new_gn, 4) new_dict[key] = 5 - return x * new_dict[key] + # Make another key that should hash to the same value + key1 = Container(new_gn, 4) + return x * new_dict[key1] x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -1747,6 +1898,53 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) + def test_custom_object_as_dict_key(self): + """Test that custom objects with __hash__ as dict keys are properly handled. + + This test verifies that when using custom objects with overridden __hash__ + and __eq__ as dictionary keys, two instances with the same hash and equality + should be recognized as the same key. + """ + + class CustomKey: + def __init__(self, value, name): + self.value = value + self.name = name + + def fn(x): + d = {} + # Create first instance + key1 = CustomKey(42, "test") + d[key1] = x * 2 + + # Create second instance with same values - should hash to same value + key2 = CustomKey(42, "test") + d[key2] = x * 3 # This should overwrite the first value + + return d[key1] * d[key2] + + x = torch.randn(4) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertTrue(same(opt_fn(x), fn(x))) + + def test_user_defined_object(self): + class A: + def __init__(self): + self.x = {} + REF[self] = {} + + REF = {} + + def f(a, x): + REF[a]["foo"] = x + return x + 1 + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + + x = torch.randn(4) + self.assertTrue(same(f(A(), x), opt_f(A(), x))) + class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu b/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 5f967971005f6..7cf8e52d0197d 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3667,5 +3667,49 @@ "Use custom operators instead of direct attribute/method access." ] } + ], + "GB0363": [ + { + "Gb_type": "User-defined object with overridden __hash__", + "Context": "hashing object of type={type(obj)} and variable tracker {vt}", + "Explanation": "Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", + "Hints": [ + "Dynamo does not support hashing user-defined objects with overridden __hash__", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0364": [ + { + "Gb_type": "Dynamo cannot determine whether the underlying object is hashable", + "Context": "is_python_hashable {self}", + "Explanation": "Dynamo does not know whether the underlying python object for {self} is hashable", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {type_self}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0365": [ + { + "Gb_type": "Dynamo cannot determine the hash of an object", + "Context": "get_python_hash {self}", + "Explanation": "Dynamo does not know the hash of the underlying python object for {self}", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0366": [ + { + "Gb_type": "Dynamo cannot determine the equality comparison of an object", + "Context": "is_python_equal {self}", + "Explanation": "Dynamo does not know the equality comparison of the underlying python object for {self}", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] } diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index c6825737ec994..5b1070aad5ad6 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4956,3 +4956,21 @@ def get_traced_code() -> Optional[list[CodeType]]: from torch._guards import TracingContext return TracingContext.get_traced_code() + + +def raise_on_overridden_hash(obj: Any, vt: VariableTracker) -> None: + from . import graph_break_hints + from .exc import unimplemented + + is_overridden = type(obj).__dict__.get("__hash__", False) + + if is_overridden: + unimplemented( + gb_type="User-defined object with overridden __hash__", + context=f"hashing object of type={type(obj)} and variable tracker {vt}", + explanation=f"Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", + hints=[ + "Dynamo does not support hashing user-defined objects with overridden __hash__", + *graph_break_hints.SUPPORTABLE, + ], + ) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 617f787e43d8a..a794010f4083f 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -683,6 +683,62 @@ def build( else: return variables.LazyVariableTracker.create(value, source) + def is_python_hashable(self): + """ + Unlike the variable tracker's own __hash__, this method checks whether + the underlying Python object referenced by this variable tracker is hashable. + """ + try: + type_self = self.python_type() + except NotImplementedError: + type_self = type(self) + + unimplemented( + gb_type="Dynamo cannot determine whether the underlying object is hashable", + context=f"is_python_hashable {self}", + explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {type_self}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def get_python_hash(self): + """ + Unlike the variable tracker’s own __hash__, this method is used by + ConstDictVariableTracker to compute the hash of the underlying key object. + """ + unimplemented( + gb_type="Dynamo cannot determine the hash of an object", + context=f"get_python_hash {self}", + explanation=f"Dynamo does not know the hash of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def is_python_equal(self, other): + """ + NB - Deliberately not overriding the __eq__ method because that can + disable the __hash__ for the vt itself. + """ + unimplemented( + gb_type="Dynamo cannot determine the equality comparison of an object", + context=f"is_python_equal {self}", + explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + def __init__( self, *, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ae6678628634a..8fdaefea56f89 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -3243,6 +3243,15 @@ def call_contains( ) -> VariableTracker: return a.call_method(tx, "__contains__", [b], {}) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn + @contextlib.contextmanager def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 672fa1d804383..0b2eaaea80826 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -23,6 +23,7 @@ istype, np, raise_args_mismatch, + raise_on_overridden_hash, ) from .base import ValueMutationNew, VariableTracker @@ -340,6 +341,20 @@ def call_obj_hasattr( result = hasattr(self.value, name) return variables.ConstantVariable.create(result) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + # Could be an EnumVariable as well + from .tensor import SymNodeVariable + + if isinstance(other, SymNodeVariable): + return self.as_python_constant() == other.evaluate_expr() + return self.as_python_constant() == other.as_python_constant() + class EnumVariable(VariableTracker): """VariableTracker for enum.Enum and enum.IntEnum instances @@ -388,3 +403,13 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker member = getattr(self.value, name) source = self.source and AttrSource(self.source, name) return VariableTracker.build(tx, member, source=source) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 422cae7c4d3f1..9b98c91723063 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -20,14 +20,11 @@ import collections import functools -import inspect import operator import types -from collections.abc import Hashable as py_Hashable, Sequence +from collections.abc import Sequence from typing import Any, Optional, TYPE_CHECKING, Union -from torch._subclasses.fake_tensor import is_fake - from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import raise_observed_exception, unimplemented @@ -55,8 +52,8 @@ # [Adding a new supported class within the keys of ConstDictVariable] -# - Add its tracker type to is_hashable -# - (perhaps) Define how it is compared in _HashableTracker._eq_impl +# - Implement is_python_hashable() method in the VariableTracker subclass +# - Implement get_python_hash() and is_python_equal() methods for hashable types def was_instancecheck_override(obj: Any) -> bool: @@ -73,7 +70,7 @@ def raise_unhashable( raise_observed_exception( TypeError, tx, - args=[ConstantVariable(f"unhashable type: {type(arg.realize())}")], + msg=f"Unhashable type: {arg.python_type()!r} and variable tracker = {type(arg.realize())}", ) @@ -88,52 +85,7 @@ def is_hashable(x: VariableTracker) -> bool: and x.is_hashable() ): return True - - if isinstance(x, variables.TensorVariable): - # Tensors are hashable if they have an example_value (a fake tensor) - # Most VT's should have one. - # It'd be nice if at some point we could assert that they all have one - return x.as_proxy().node.meta.get("example_value") is not None - elif isinstance(x, variables.TupleVariable): - return all(is_hashable(e) for e in x.items) - elif isinstance(x, variables.FrozenDataClassVariable): - return all(is_hashable(e) for e in x.fields.values()) - elif ( - isinstance(x, variables.UserDefinedObjectVariable) - and not was_instancecheck_override(x.value) - and inspect.getattr_static(x.value, "__hash__") is int.__hash__ - and isinstance(x.value, int) - ): - return isinstance(x.value, py_Hashable) - elif isinstance(x, variables.FunctoolsPartialVariable): - return ( - is_hashable(x.func) - and all(is_hashable(arg) for arg in x.args) - and all(is_hashable(value) for value in x.keywords.values()) - ) - else: - return isinstance( - x, - ( - variables.BuiltinVariable, - variables.SymNodeVariable, - variables.ConstantVariable, - variables.EnumVariable, - variables.FrozensetVariable, - variables.UserDefinedClassVariable, - variables.UserFunctionVariable, - variables.SkipFunctionVariable, - variables.misc.NumpyVariable, - variables.NNModuleVariable, - variables.UnspecializedNNModuleVariable, - variables.MethodWrapperVariable, - variables.TorchInGraphFunctionVariable, - variables.TypingVariable, - variables.FunctoolsPartialVariable, - variables.WeakRefVariable, - variables.TorchHigherOrderOperatorVariable, - ), - ) + return x.is_python_hashable() class ConstDictVariable(VariableTracker): @@ -154,88 +106,47 @@ class _HashableTracker: def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) - # TODO Temporarily remove to figure out what keys are we breaking on - # and add proper support for them + + # If Dynamo does not know the hashability of the vt, it will raise unsupported here if not is_hashable(vt): raise_unhashable(vt) self.vt = vt - @property - def underlying_value(self) -> Any: + def __hash__(self) -> int: + """ + Computes the hash value for the wrapped VariableTracker. + + For unrealized LazyVariableTrackers, uses the hash of the original value + to avoid realizing the tracker and inserting unnecessary guards. + For all other cases, delegates to the VariableTracker's get_python_hash method. + + Returns: + The hash value of the underlying variable tracker + """ if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() and self.vt.is_hashable() ): - return self.vt.original_value() - if isinstance(self.vt, variables.TensorVariable): - x = self.vt.as_proxy().node.meta["example_value"] - elif isinstance(self.vt, variables.TupleVariable): - Hashable = ConstDictVariable._HashableTracker - x = tuple(Hashable(e).underlying_value for e in self.vt.items) - elif isinstance(self.vt, variables.NNModuleVariable): - return self.vt.value - elif isinstance(self.vt, variables.UnspecializedNNModuleVariable): - return self.vt.value - elif isinstance(self.vt, variables.UserFunctionVariable): - return self.vt.get_function() - elif isinstance(self.vt, variables.WeakRefVariable): - # Access the underlying value inside the referent_vt for the key representation - Hashable = ConstDictVariable._HashableTracker - return Hashable(self.vt.referent_vt).underlying_value - elif isinstance(self.vt, variables.FrozenDataClassVariable): - Hashable = ConstDictVariable._HashableTracker - fields_values = { - k: Hashable(v).underlying_value - for k, v in self.vt.fields.items() # type: ignore[attr-defined] - } - return variables.FrozenDataClassVariable.HashWrapper( - self.vt.python_type(), fields_values - ) - elif isinstance(self.vt, variables.UserDefinedObjectVariable): - # The re module in Python 3.13+ has a dictionary (_cache2) with - # an object as key (`class _ZeroSentinel(int): ...`): - # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual - return self.vt.value # type: ignore[attr-defined,union-attr] - elif isinstance(self.vt, variables.FunctoolsPartialVariable): - Hashable = ConstDictVariable._HashableTracker - items = (self.vt.func, *self.vt.args, *self.vt.keywords.values()) - x = tuple(Hashable(e).underlying_value for e in items) - return x - else: - x = self.vt.as_python_constant() - return x + return hash(self.vt.original_value()) + return self.vt.get_python_hash() - def __hash__(self) -> int: - return hash(self.underlying_value) - - @staticmethod - def _eq_impl(a: Any, b: Any) -> bool: - # TODO: Put this in utils and share it between variables/builtin.py and here - type_a, type_b = type(a), type(b) - if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)): - return False - - if isinstance(a, tuple): - Hashable = ConstDictVariable._HashableTracker - return len(a) == len(b) and all( - Hashable._eq_impl(u, v) for u, v in zip(a, b) - ) - elif is_fake(a): - return a is b - else: - return a == b + def __eq__(self, other) -> bool: + """ + Checks equality between two _HashableTracker instances. - def __eq__(self, other: object) -> bool: - Hashable = ConstDictVariable._HashableTracker - assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( - type(other) - ) - if isinstance(other, Hashable): - return Hashable._eq_impl(self.underlying_value, other.underlying_value) + Delegates to the VariableTracker's is_python_equal method to compare + the underlying variable trackers for Python-level equality. + + Args: + other: Another _HashableTracker instance to compare with - # constant - return Hashable._eq_impl(self.underlying_value, other) + Returns: + True if the underlying variable trackers are Python-equal, False otherwise + """ + if self.vt is other.vt: + return True + return self.vt.is_python_equal(other.vt) def __init__( self, @@ -324,7 +235,7 @@ def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( - is_hashable(vt) + vt.is_python_hashable() and Hashable(vt) in self.items and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) @@ -536,8 +447,6 @@ def call_method( Hashable = ConstDictVariable._HashableTracker - arg_hashable = args and is_hashable(args[0]) - if name == "__init__": temp_dict_vt = variables.BuiltinVariable(dict).call_dict( tx, *args, **kwargs @@ -606,6 +515,7 @@ def call_method( self.install_dict_keys_match_guard() return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -620,16 +530,21 @@ def call_method( tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) - elif name == "__delitem__" and arg_hashable and self.is_mutable(): - self.install_dict_keys_match_guard() - self.should_reconstruct_all = True - tx.output.side_effects.mutation(self) - self.items.__delitem__(Hashable(args[0])) - return ConstantVariable.create(None) + elif name == "__delitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) + if arg_hashable: + self.install_dict_keys_match_guard() + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.__delitem__(Hashable(args[0])) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) elif name == "get": if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -645,6 +560,7 @@ def call_method( if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -736,6 +652,7 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -751,6 +668,7 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) + arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -903,6 +821,12 @@ def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) + def is_python_hashable(self): + """ + Dictionaries are mutable and therefore not hashable in Python. + """ + return False + class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt @@ -1416,6 +1340,18 @@ def call_method( return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) + def is_python_hashable(self): + """ + Frozensets are immutable and hashable in Python. + """ + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class DictKeySetVariable(SetVariable): def debug_repr(self) -> str: @@ -1605,3 +1541,9 @@ def call_method( return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) return ConstantVariable.create(False) return super().call_method(tx, name, args, kwargs) + + def is_python_hashable(self): + """ + Dictionary item views are not hashable in Python. + """ + return False diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index deee9bcec42de..360c0fdd94488 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -807,6 +807,15 @@ def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: return collected return None + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn + class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): _nonvar_fields = { @@ -1963,6 +1972,15 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return fn_var_getattr(tx, self.value, self.source, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class WrappedSkipFunctionVariable(SkipFunctionVariable): def __init__( @@ -2349,6 +2367,34 @@ def guard_as_python_constant(self) -> Any: **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, ) + def is_python_hashable(self) -> bool: + return ( + self.func.is_python_hashable() + and all(arg.is_python_hashable() for arg in self.args) + and all(value.is_python_hashable() for value in self.keywords.values()) + ) + + def get_python_hash(self): + func_hash = self.func.get_python_hash() + args_hash = (arg.get_python_hash() for arg in self.args) + values_hash = (value.get_python_hash() for value in self.keywords.values()) + return hash((func_hash, *args_hash, *values_hash)) + + def is_python_equal(self, other): + return ( + self.func.is_python_equal(other.func) + and all( + arg_a.is_python_equal(arg_b) + for (arg_a, arg_b) in zip(self.args, other.args) + ) + and all( + value_a.is_python_equal(value_b) + for (value_a, value_b) in zip( + self.keywords.values(), other.keywords.values() + ) + ) + ) + class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index afb6522ac0e5c..8b178b3be1ac3 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1738,6 +1738,15 @@ def _call_function( def as_python_constant(self): return self.value + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 4f21e35479fb8..a97c284f9516c 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -620,6 +620,25 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return self.items[fields.index(name)] return super().var_getattr(tx, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + l = self.range_length() + start = self.start() + step = self.step() + return hash((l, start, step)) + + def is_python_equal(self, other): + if not isinstance(other, variables.RangeVariable): + return False + + return ( + self.start() == other.start() + and self.step() == other.step() + and self.stop() == other.stop() + ) + class CommonListMethodsVariable(BaseListVariable): """ @@ -981,6 +1000,9 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) + def is_python_hashable(self): + return False + class DequeVariable(CommonListMethodsVariable): def __init__( @@ -1170,6 +1192,18 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) + def is_python_hashable(self): + return all(item.is_python_hashable() for item in self.items) + + def get_python_hash(self): + items = tuple(x.get_python_hash() for x in self.items) + return hash(items) + + def is_python_equal(self, other): + return isinstance(other, variables.TupleVariable) and all( + a.is_python_equal(b) for (a, b) in zip(self.items, other.items) + ) + class SizeVariable(TupleVariable): """torch.Size(...)""" diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 8d074f913dbf5..5bd8ad5d075e6 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1306,6 +1306,15 @@ def is_python_constant(self): def as_python_constant(self): return self.method_wrapper + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class GetSetDescriptorVariable(VariableTracker): def __init__(self, desc, **kwargs) -> None: @@ -1440,6 +1449,15 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # codegen.append_output(codegen.create_load_const(self.value)) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + @functools.lru_cache(maxsize=1) def get_np_to_tnp_map(): @@ -1618,6 +1636,15 @@ def as_proxy(self): return super().as_proxy() + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls class NullVariable(VariableTracker): @@ -2097,3 +2124,13 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(self.referent_vt) codegen(self.callback_vt) codegen.extend_output(create_call_function(2, False)) + + def is_python_hashable(self): + return self.referent_vt.is_python_hashable() + + def get_python_hash(self): + # weakref relies on the referent's hash + return self.referent_vt.get_python_hash() + + def is_python_equal(self, other): + return self.referent_vt.is_python_equal(other.referent_vt) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 0787ef7c49b57..548e69ef0262d 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1428,6 +1428,20 @@ def set_name_hint(self, name: str): self.proxy.node._rename(name) self._is_name_set = True + def is_python_hashable(self): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return self.as_proxy().node.meta["example_value"] is not None + + def get_python_hash(self): + return hash(self.as_proxy().node.meta["example_value"]) + + def is_python_equal(self, other): + a = self.as_proxy().node.meta["example_value"] + b = other.as_proxy().node.meta["example_value"] + return a is b + class SymNodeVariable(VariableTracker): """ @@ -1516,6 +1530,20 @@ def call_method( ), ) + def is_python_hashable(self): + return True + + def get_python_hash(self): + # Essentially convert the SymNode to a constant variable whenever its + # searched for a dict key. + return hash(self.evaluate_expr()) + + def is_python_equal(self, other): + if isinstance(other, SymNodeVariable): + return self.evaluate_expr() == other.evaluate_expr() + # could be constant variable as well + return self.evaluate_expr() == other.as_python_constant() + class NumpyNdarrayVariable(TensorVariable): """ diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 76da71f6fb323..78d87a09713ab 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -2075,6 +2075,15 @@ def torch_function_override_enabled(self, tx, args, kwargs): ) ) and can_dispatch_torch_function(tx, args, kwargs) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class DispatchKeySetVariable(BaseTorchVariable): """represents torch.DispatchKeySet""" diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index e87af5b87a75a..012bea32620e9 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -89,6 +89,7 @@ object_has_getattribute, proxy_args_kwargs, raise_args_mismatch, + raise_on_overridden_hash, set_methods, tensortype_to_dtype, tuple_methods, @@ -927,6 +928,18 @@ def const_getattr(self, tx: "InstructionTranslator", name): return self.value.__name__ return super().const_getattr(tx, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return ( + isinstance(other, variables.UserDefinedClassVariable) + and self.value is other.value + ) + class UserDefinedExceptionClassVariable(UserDefinedClassVariable): @property @@ -1743,26 +1756,20 @@ def call_obj_hasattr( handle_observed_exception(tx) return variables.ConstantVariable.create(False) + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True -class FrozenDataClassVariable(UserDefinedObjectVariable): - class HashWrapper: - """This class is hashed if a dataclass is used as a key in a dict. - It's necessary to avoid side effects from calling the __init__ of the dataclass class when hashing""" + def get_python_hash(self): + # default hash + return hash(self.value) - def __init__(self, c, fields): - self.cls = c - self.fields = tuple(fields.items()) + def is_python_equal(self, other): + # id check + return self.value is other.value - def __eq__(self, other): - return ( - type(self) is type(other) - and self.cls == other.cls - and self.fields == other.fields - ) - - def __hash__(self): - return hash((self.cls, self.fields)) +class FrozenDataClassVariable(UserDefinedObjectVariable): @staticmethod def create(tx, value, source): from dataclasses import fields @@ -1860,6 +1867,22 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value_type.__name__})" + def is_python_hashable(self): + # TODO - Check corner cases like eq=False, hash=False etc + return True + + def get_python_hash(self): + return hash(tuple(arg.get_python_hash() for arg in self.fields.values())) + + def is_python_equal(self, other): + is_class_same = self.python_type() is other.python_type() + is_field_name_same = self.fields.keys() == other.fields.keys() + is_field_value_same = all( + value_a.is_python_equal(value_b) + for value_a, value_b in zip(self.fields.values(), other.fields.values()) + ) + return is_class_same and is_field_name_same and is_field_value_same + class SourcelessGraphModuleVariable(UserDefinedObjectVariable): def __init__( @@ -2080,6 +2103,10 @@ def install_dict_keys_match_guard(self): def install_dict_contains_guard(self): return self._dict_vt.install_dict_contains_guard() + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return False + class UserDefinedSetVariable(UserDefinedObjectVariable): """ @@ -2153,6 +2180,18 @@ def install_dict_keys_match_guard(self): def install_dict_contains_guard(self): return self._set_vt.install_dict_contains_guard() + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return self._set_vt.is_python_hashable() + + def get_python_hash(self): + return self._set_vt.get_python_hash() + + def is_python_equal(self, other): + return isinstance( + other, UserDefinedSetVariable + ) and self._set_vt.is_python_equal(other._set_vt) + class UserDefinedListVariable(UserDefinedObjectVariable): """ @@ -2194,6 +2233,10 @@ def unpack_var_sequence(self, tx): def is_underlying_vt_modified(self, side_effects): return side_effects.is_modified(self._list_vt) + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return False + class UserDefinedTupleVariable(UserDefinedObjectVariable): """ @@ -2242,6 +2285,18 @@ def unpack_var_sequence(self, tx): return self._tuple_vt.unpack_var_sequence(tx) raise NotImplementedError + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return self._tuple_vt.is_python_hashable() + + def get_python_hash(self): + return self._tuple_vt.get_python_hash() + + def is_python_equal(self, other): + return isinstance( + other, UserDefinedTupleVariable + ) and self._tuple_vt.is_python_equal(other._tuple_vt) + class MutableMappingVariable(UserDefinedObjectVariable): def __init__(self, value, **kwargs): From d76697633a2d2b9cced1ae21161849b33bfe7e47 Mon Sep 17 00:00:00 2001 From: Nick Riasanovsky Date: Tue, 2 Dec 2025 03:59:23 +0000 Subject: [PATCH 100/338] [Inductor] [Triton] Capture Timeout errors without crashing the job (#169064) Summary: Opts to capture timeout errors during compilation without forcing process failure. Useful to avoid hangs in MAST jobs. We may want to consider a configuration option for this to avoid wasted compute by never pruning bad config options. Test Plan: Tested with local model reproducers. Differential Revision: D87866423 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169064 Approved by: https://github.com/PaulZhang12 --- torch/_inductor/select_algorithm.py | 70 +++++++++++++++++------------ 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index eb1bbf42f8c37..f0101f01f3617 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -3241,23 +3241,17 @@ def wait_on_futures(): log.debug("Waiting on futures") counters["inductor"]["select_algorithm_precompile"] += 1 exceptions: list[tuple[ChoiceCaller, BaseException]] = [] - for future in as_completed( - futures, - timeout=precompilation_timeout_seconds, - ): - if e := future.exception(): - counters["inductor"][ - "select_algorithm_num_precompilation_exceptions" - ] += 1 - exceptions.append((futures[future], e)) - from torch._inductor.codegen.cuda.cuda_kernel import ( - CUDATemplateCaller, - ) - - if isinstance(e, CUDACompileError) and isinstance( - futures[future], CUDATemplateCaller - ): - log.debug( + try: + for future in as_completed( + futures, + timeout=precompilation_timeout_seconds, + ): + if e := future.exception(): + counters["inductor"][ + "select_algorithm_num_precompilation_exceptions" + ] += 1 + exceptions.append((futures[future], e)) + log.exception( # noqa: G202 "Exception %s for benchmark choice %s", e, futures[future], @@ -3265,20 +3259,38 @@ def wait_on_futures(): ) futures[future].mark_failed() else: - log.exception( # noqa: G202 - "Exception %s for benchmark choice %s", - e, - futures[future], - exc_info=e, + counters["inductor"]["select_algorithm_num_precompiles"] += 1 + log.info( + "Precompiling benchmark choice %s took %.02fs", + futures.get(future), + elapsed_times.get(future), ) - futures[future].mark_failed() - else: - counters["inductor"]["select_algorithm_num_precompiles"] += 1 - log.info( - "Precompiling benchmark choice %s took %.02fs", - futures.get(future), - elapsed_times.get(future), + except TimeoutError: + # Don't force the entire process to crash due to a timeout + # in compilation. Just mark those futures as failed. + completed_futures = OrderedSet([f for f in futures if f.done()]) + remaining_futures = OrderedSet(futures.keys()) - completed_futures + + log.warning( + "Precompilation timeout after %ds: %d of %d futures did not complete", + precompilation_timeout_seconds, + len(remaining_futures), + len(futures), + ) + + # Mark remaining futures as failed and log them + for future in remaining_futures: + choice = futures[future] + log.warning( + "Marking choice as failed due to timeout: %s", + choice, + ) + choice.mark_failed() + # Add timeout exception to the exceptions list + timeout_exc = TimeoutError( + f"Precompilation timed out after {precompilation_timeout_seconds}s" ) + exceptions.append((choice, timeout_exc)) if exceptions: _log_autotune_exceptions(exceptions) From bb3034198b459401fabeab254e1b99f0115046e2 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 2 Dec 2025 05:36:44 +0000 Subject: [PATCH 101/338] Avoid std::tie and returning value constructions in qconv_unpack.cpp (#169207) This PR avoids returning value construction in `qconv_unpack.cpp`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169207 Approved by: https://github.com/Skylion007 --- .../ATen/native/quantized/qconv_unpack.cpp | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/quantized/qconv_unpack.cpp b/aten/src/ATen/native/quantized/qconv_unpack.cpp index 4c2352a396177..df66a6087f738 100644 --- a/aten/src/ATen/native/quantized/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/qconv_unpack.cpp @@ -82,32 +82,31 @@ class QConv1dUnpackWeightsInt8 final { static std::tuple> run( const c10::intrusive_ptr>& packed_weight) { auto& ctx = at::globalContext(); - at::Tensor weight; - std::optional bias; #ifdef USE_FBGEMM if (ctx.qEngine() == at::QEngine::FBGEMM || ctx.qEngine() == at::QEngine::X86) { - std::tie(weight, bias) = packed_weight->unpack(); + auto result = packed_weight->unpack(); + auto& weight = std::get<0>(result); weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return std::tuple>(weight, bias); + return result; } #endif #ifdef USE_PYTORCH_QNNPACK if (ctx.qEngine() == at::QEngine::QNNPACK) { - std::tie(weight, bias) = packed_weight->unpack(); - at::Tensor new_weight = weight.clone(); - new_weight = new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return std::tuple>(new_weight, bias); + auto result = packed_weight->unpack(); + auto& weight = std::get<0>(result); + weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return result; } #endif #if AT_MKLDNN_ENABLED() if (ctx.qEngine() == at::QEngine::ONEDNN) { - std::tie(weight, bias) = packed_weight->unpack(); - at::Tensor new_weight = weight.clone(); - new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return std::tuple>(new_weight, bias); + auto result = packed_weight->unpack(); + auto& weight = std::get<0>(result); + weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return result; } #endif From fb5be221a46b51bfc9509013b0d85bc5a9d4f15b Mon Sep 17 00:00:00 2001 From: nandan2003 Date: Tue, 2 Dec 2025 05:40:11 +0000 Subject: [PATCH 102/338] Refactor: Remove unnecessary ConstantVariable wrapping in raise_observed_exception (#168337) Fixes #168291 # Summary Removes `ConstantVariable.create` wrapping in `raise_observed_exception` calls within `torch/_dynamo/variables/functions.py`. # Context The `raise_observed_exception` function handles the exception creation internally. Wrapping the error strings in `ConstantVariable` is unnecessary and can be simplified to passing raw strings. # Test Plan - [x] Verified syntax validity via `python3 -m py_compile torch/_dynamo/variables/functions.py` - [ ] CI/CD (Existing tests should pass as this is a refactor of error reporting paths) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168337 Approved by: https://github.com/williamwen42, https://github.com/guilhermeleobas, https://github.com/cyyever --- torch/_dynamo/variables/functions.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 360c0fdd94488..4f6301b1eb6c5 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -210,11 +210,7 @@ def bind_args_cached( raise_observed_exception( TypeError, tx, - args=[ - ConstantVariable.create( - f"Missing required positional argument: {name}" - ) - ], + args=[f"Missing required positional argument: {name}"], ) # 2) *args @@ -226,9 +222,7 @@ def bind_args_cached( TypeError, tx, args=[ - ConstantVariable.create( - f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" - ) + f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" ], ) @@ -245,11 +239,7 @@ def bind_args_cached( raise_observed_exception( TypeError, tx, - args=[ - ConstantVariable.create( - f"Missing required keyword-only argument: {name}" - ) - ], + args=[f"Missing required keyword-only argument: {name}"], ) # 4) **kwargs @@ -259,9 +249,7 @@ def bind_args_cached( raise_observed_exception( TypeError, tx, - args=[ - ConstantVariable.create(f"Unexpected keyword arguments: {list(rem_kw)}") - ], + args=[f"Unexpected keyword arguments: {list(rem_kw)}"], ) return ba @@ -2994,7 +2982,7 @@ def call_function( if len(args) != 1: raise_type_error_exc( tx, - f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", + f"_get_node_type() takes 1 positional argument but {len(args)} were given", ) type_source = None if args[0].source: From 166efdad2ac827f30fb02504c6017520257f88ec Mon Sep 17 00:00:00 2001 From: "Andy (An) Wang" Date: Tue, 2 Dec 2025 05:49:43 +0000 Subject: [PATCH 103/338] [MTIAGraph][Pytorch] Add the graph_pool_handle api (#169283) Summary: Add the `torch.mtia.graph_pool_handle` API as the counterpart of `torch.cuda.graph_pool_handle`, which is used in vllm, e.g. https://www.internalfb.com/code/fbsource/[f6d024bd45964d71810cbe1ed859f132f7f734cd]/fbcode/vllm/trunk/vllm/compilation/cuda_graph.py?lines=170 Test Plan: ``` buck2 run mtia/host_runtime/torch_mtia/tests:test_mtia_graph_py -- -r test_graph_pool_handle ``` Differential Revision: D88059625 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169283 Approved by: https://github.com/patrick-toulme --- aten/src/ATen/detail/MTIAHooksInterface.h | 4 ++++ docs/source/mtia.mtia_graph.md | 4 ++++ torch/_C/__init__.pyi.in | 1 + torch/csrc/mtia/Module.cpp | 4 ++++ torch/mtia/__init__.py | 1 + torch/mtia/mtia_graph.py | 8 ++++++++ 6 files changed, 22 insertions(+) diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index 58c7a0304181c..a9742a78146e1 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -183,6 +183,10 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { virtual MempoolId_t mtiagraphPool(int64_t handle) const { FAIL_MTIAHOOKS_FUNC(__func__); } + + virtual MempoolId_t graphPoolHandle() const { + FAIL_MTIAHOOKS_FUNC(__func__); + } }; struct TORCH_API MTIAHooksArgs {}; diff --git a/docs/source/mtia.mtia_graph.md b/docs/source/mtia.mtia_graph.md index 1d1560960792c..424171ea863c3 100644 --- a/docs/source/mtia.mtia_graph.md +++ b/docs/source/mtia.mtia_graph.md @@ -10,6 +10,10 @@ The MTIA backend is implemented out of the tree, only interfaces are defined her .. currentmodule:: torch.mtia.mtia_graph ``` +```{eval-rst} +.. autofunction:: graph_pool_handle +``` + ```{eval-rst} .. autoclass:: MTIAGraph :members: diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index e9b58b9ce71eb..532815d535d5e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2009,6 +2009,7 @@ def _mtia_attachOutOfMemoryObserver( ) -> None: ... def _mtia_getDeviceCount() -> _int: ... def _mtia_resetPeakMemoryStats(device: _int) -> None: ... +def _mtia_graphPoolHandle() -> tuple[_int, _int]: ... # Defined in torch/csrc/mtia/Module.cpp class _MTIAGraph: diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 468e4828c4122..dd15864e332b4 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -171,6 +171,10 @@ void initModule(PyObject* module) { at::detail::getMTIAHooks().resetPeakMemoryStats(device_index); }); + m.def("_mtia_graphPoolHandle", []() { + return at::detail::getMTIAHooks().graphPoolHandle(); + }); + py::class_<_MTIAGraph>(m, "_MTIAGraph") .def(py::init(), py::arg("keep_graph") = false) .def("capture_begin", &_MTIAGraph::capture_begin) diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 35ef04a67319d..af3a333bc3d2b 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -427,4 +427,5 @@ def set_rng_state( "is_bf16_supported", "MTIAGraph", "graph", + "graph_pool_handle", ] diff --git a/torch/mtia/mtia_graph.py b/torch/mtia/mtia_graph.py index bc5a8ea49dfea..019f5604c4d95 100644 --- a/torch/mtia/mtia_graph.py +++ b/torch/mtia/mtia_graph.py @@ -9,6 +9,13 @@ _POOL_HANDLE = tuple[int, int] +def graph_pool_handle() -> _POOL_HANDLE: + """ + Return an opaque token representing the id of a graph memory pool. + """ + return torch._C._mtia_graphPoolHandle() + + class MTIAGraph(torch._C._MTIAGraph): """ Wrapper around a MTIA graph. @@ -93,4 +100,5 @@ def __exit__(self, *args: object) -> None: __all__ = [ "MTIAGraph", "graph", + "graph_pool_handle", ] From 62d3ccd71484ed6a760d909b41487101bbc65719 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Tue, 2 Dec 2025 06:04:40 +0000 Subject: [PATCH 104/338] [audio hash update] update the pinned audio hash (#169198) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169198 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index b65b6a7f117ef..a3c4cd801b60c 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -32ce8c011855adb15438ddc9bf6c139d23f8cee5 +e90a3986cbebd57a5ad08b6813e2c7ff199cdbe0 From 7d1bbaf4ba301ea3fba6f3c7bc02d58f6417aaed Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 2 Dec 2025 06:13:04 +0000 Subject: [PATCH 105/338] Add sum support for qlinear_binary templated implementation (#163249) Add sum support for qlinear_binary templated gemm, and also allow sum for the choice of `QLinearPointwiseBinaryPT2E` in the cases of x2 happening to be the output of `QLinearPointwiseBinaryPT2E`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163249 Approved by: https://github.com/Xia-Weiwen, https://github.com/leslie-fang-intel, https://github.com/jansel --- test/inductor/test_cpu_select_algorithm.py | 38 +++++- test/inductor/test_mkldnn_pattern_matcher.py | 63 ++++++++++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 29 +++-- torch/_inductor/codegen/wrapper.py | 118 ++++++++++++++++--- torch/_inductor/fx_passes/mkldnn_fusion.py | 35 +++++- torch/_inductor/fx_passes/quantization.py | 19 ++- torch/_inductor/mkldnn_lowerings.py | 31 ++++- 7 files changed, 288 insertions(+), 45 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index ca520ab66bcc2..d4249e9ab4b6d 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -1954,6 +1954,8 @@ def test_quantized_linear_with_pointwise_binary( return B = (2, batch_size) if input_3d else (batch_size,) input = torch.randn(*B, in_features).to(dtype=torch.float32) + input2 = torch.randn(*B, in_features).to(dtype=torch.float32) + input3 = torch.randn(*B, out_features).to(dtype=torch.float32) other = torch.randn(*B, out_features).to(dtype=dtype) # Avoid hitting qlinear inplace sum fusion @@ -1962,6 +1964,8 @@ def test_quantized_linear_with_pointwise_binary( else: other2 = torch.randn(1, *B, out_features).to(dtype=dtype) + other_clone = other.clone() + class M(torch.nn.Module): def __init__(self, bias, input_3d): super().__init__() @@ -1981,11 +1985,29 @@ def forward(self, x, other, other2): res = self.epilogue2(self.linear2(res) + other2) return res + class M2(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias) + self.epilogue = _get_epilogue(epilogue) + self.linear2 = torch.nn.Linear(out_features, out_features, bias) + self.epilogue2 = _get_epilogue(epilogue) + + def forward(self, x0, x1, other): + # test qlinear sum -> qlinear sum + res = self.epilogue(self.linear(x0) + other) + res = self.epilogue2(self.linear2(x1) + res) + return res + counters.clear() ref_quantized_mod = _generate_qdq_quantized_model( M(bias=bias, input_3d=input_3d).eval(), (input, other, other2), ) + ref_quantized_mod2 = _generate_qdq_quantized_model( + M2(bias=bias).eval(), + (input2, input3, other_clone), + ) atol, rtol = 5e-2, 5e-2 with ( patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)), @@ -1994,6 +2016,9 @@ def forward(self, x, other, other2): ): ref_res = ref_quantized_mod(input, other, other2) cfn = torch.compile(ref_quantized_mod) + ref_res2 = ref_quantized_mod2(input2, input3, other_clone) + cfn2 = torch.compile(ref_quantized_mod2) + res = cfn(input, other, other2) self.assertEqual( res, @@ -2003,7 +2028,18 @@ def forward(self, x, other, other2): equal_nan=True, exact_dtype=True, ) - self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2) + + res2 = cfn2(input2, input3, other_clone) + self.assertEqual( + res2, + ref_res2, + atol=atol, + rtol=rtol, + equal_nan=True, + exact_dtype=True, + ) + + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 4) self.assertEqual( counters["inductor"]["cpp_epilogue_fusion_counter"], 0, diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index a793a052c059d..440ee6a52f553 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -2425,6 +2425,51 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, ) + def _qlinear_sum_test_helper( + self, + inputs, + device="cpu", + int8_mixed_bf16=False, + matcher_check_fn=None, + bias=True, + ): + class M(torch.nn.Module): + def __init__(self, use_bias): + super().__init__() + self.linear = torch.nn.Linear(4, 4, use_bias) + self.linear2 = torch.nn.Linear(4, 4, use_bias) + + def forward(self, x, other): + # test qlinear sum -> qlinear sum + res = self.linear(x) + other + res = self.linear2(x) + res + return res + + mod = M(bias).eval().to(device=device) + assert isinstance(inputs, tuple) + + def __convert_tensor_to_device(input, device): + return input.to(device=device) if isinstance(input, torch.Tensor) else input + + inputs = tuple(__convert_tensor_to_device(input, device) for input in inputs) + + def _default_matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + + self._test_common( + mod, + inputs, + matcher_check_fn=( + matcher_check_fn + if matcher_check_fn is not None + else _default_matcher_check_fn + ), + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_quantization=True, + ) + def _qlinear_test_helper( self, inputs, @@ -3140,6 +3185,24 @@ def test_qlinear_add_int8_mixed_bf16_xpu(self, use_relu, is_qat, is_dynamic): is_dynamic=is_dynamic, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qlinear_sum_cpu(self): + for bias in [True, False]: + use_bf16 = ( + [True, False] + if is_mkldnn_bf16_supported("cpu") + else [ + False, + ] + ) + for int8_mixed_bf16 in use_bf16: + self._qlinear_sum_test_helper( + (torch.randn((2, 2, 4)), torch.randn(2, 2, 4)), + bias=bias, + int8_mixed_bf16=int8_mixed_bf16, + ) + def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"): dtype = torch.float8_e4m3fn qlinear_prepack = torch.ops.onednn.qlinear_prepack diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 0bb1b40cfad96..9ec44c6c2790f 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -29,6 +29,7 @@ from .common import get_device_op_overrides, IndentedBuffer, Kernel from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP from .wrapper import ( + codegen_reinterpret_view_helper, EnterSubgraphLine, ExitSubgraphLine, PythonWrapperCodegen, @@ -1823,6 +1824,11 @@ def codegen_reinterpret_view( """Returns a newly-created, temporary RAII tensor handle containing the reinterpreted tensor data. Callers of this function are responsible for saving the handle if persistent access is needed.""" + + d_size, d_stride, d_offset, d_dtype, collapsible = ( + codegen_reinterpret_view_helper(data) + ) + dim = str(len(size)) original_offset = offset offset = self.codegen_sizevar(offset) @@ -1868,13 +1874,21 @@ def create_new_tensor_handle() -> tuple[str, list[str]]: ] return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs - if ( - size == data.layout.size - and stride == data.layout.stride - and original_offset == data.layout.offset - ): + collapsed = collapsible and original_offset == d_offset + if collapsed: + same_layout = size == d_size and stride == d_stride + base_dtype = d_dtype + else: + same_layout = ( + size == data.layout.size + and stride == data.layout.stride + and original_offset == data.layout.offset + ) + base_dtype = data.dtype + + if same_layout: # pure dtypeview - if dtype is not None and dtype != data.dtype: + if dtype is not None and dtype != base_dtype: final_tensor_str, tmp_call_strs = create_dtypeview_call(data.get_name()) else: final_tensor_str, tmp_call_strs = create_new_tensor_handle() @@ -1882,8 +1896,7 @@ def create_new_tensor_handle() -> tuple[str, list[str]]: else: # firstly create reinterpretview final_tensor_str = create_reinterpret_call() - - if dtype is not None and dtype != data.dtype: + if dtype is not None and dtype != base_dtype: # wrap it with dtypeview final_tensor_str, tmp_call_strs = create_dtypeview_call( final_tensor_str diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0eab3cac9b4a7..86290dee57bd0 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -138,6 +138,40 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike): return False +def codegen_reinterpret_view_helper(data): + """ + Collapse a chain of ReinterpretView <- StorageBox + <- ReinterpretView <- StorageBox.... <- buffer wrappers if every layer + has the same offset as the innermost (base) buffer. + + Returns: + (size, stride, offset, dtype, collapsible: bool) + """ + if isinstance(data, ir.Buffer): + lay = data.get_layout() + return lay.size, lay.stride, lay.offset, lay.dtype, True + + layouts: list[Any] = [] + cur = data + while isinstance(cur, (ir.TensorBox, ir.StorageBox, ir.ReinterpretView)): + lay = cur.get_layout() + if lay is None: + return None, None, None, None, False + layouts.append(lay) + cur = cur.data # unwrap + + if not isinstance(cur, ir.Buffer): + return None, None, None, None, False + + # All wrapper offsets must match base offset to be collapsible + for lay in layouts: + if lay.offset != cur.get_layout().offset: + return None, None, None, None, False + + base_lay = cur.get_layout() + return base_lay.size, base_lay.stride, base_lay.offset, base_lay.dtype, True + + # TODO: Move to a well known place TritonMetaParams = dict[str, int] TritonGrid = Union[ @@ -2022,25 +2056,58 @@ def codegen_reinterpret_view( writeline: Callable[..., None], dtype=None, ) -> str: - if ( - size == data.layout.size - and stride == data.layout.stride - and offset == data.layout.offset + # Get the innermost buffer's layout info to help reinterpret view. + # Consider a chain of (ReinterpretView <- TensorBox| StorageBox)... <- buffer + # If we only use x.data to determine the reinterpret, we may get wrong layout. + # For example: + # x = ReinterpretView( + # Storage( + # ReinterpretView( + # storage( + # Buffer(name='buf0', layout=(size=(2, 5, 10), ...) + # ), + # layout=(10, 10), + # ), + # ), + # layout=(10, 10), + # ) + # In this case, x.data.layout == x.layout is (10, 10), the reinterpret view will return buf0, + # but buf0 need to be viewed from (2, 5, 10) to (10, 10). + # So we need to dig into the chain to find the innermost buffer's layout. + d_size, d_stride, d_offset, d_dtype, collapsible = ( + codegen_reinterpret_view_helper(data) + ) + + def apply_reinterpret( + name, tgt_size, tgt_stride, tgt_offset, cast_dtype, base_dtype ): - if dtype is not None and dtype != data.dtype: - return f"aten.view.dtype({data.get_name()}, {dtype})" - else: - return f"{data.get_name()}" + s = self.codegen_python_shape_tuple(tgt_size) + st = self.codegen_python_shape_tuple(tgt_stride) + off = self.codegen_sizevar(tgt_offset) + expr = f"reinterpret_tensor({name}, {s}, {st}, {off})" + if cast_dtype is not None and cast_dtype != base_dtype: + return f"aten.view.dtype({expr}, {cast_dtype})" + return expr + + name = data.get_name() + collapsed = collapsible and offset == d_offset + if collapsed: + same_layout = size == d_size and stride == d_stride + base_dtype = d_dtype else: - size = self.codegen_python_shape_tuple(size) - stride = self.codegen_python_shape_tuple(stride) - offset = self.codegen_sizevar(offset) - if dtype is not None and dtype != data.dtype: - return f"aten.view.dtype(reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset}), {dtype})" - else: - return ( - f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" - ) + same_layout = ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + ) + base_dtype = data.dtype + + if same_layout: + if dtype is not None and dtype != base_dtype: + return f"aten.view.dtype({name}, {dtype})" + return f"{name}" + + return apply_reinterpret(name, size, stride, offset, dtype, base_dtype) def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): self.writeline(f"{dst}.copy_({src}, {non_blocking})") @@ -3180,7 +3247,7 @@ def codegen_allocation(self, buffer: ir.Buffer): if ( name in V.graph.removed_buffers or name in self.allocated - or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer)) + or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer, ir.InputBuffer)) ): return self.allocated.add(name) @@ -3205,7 +3272,20 @@ def codegen_allocation(self, buffer: ir.Buffer): box = layout.view.data assert isinstance(box, ir.StorageBox), type(box) input_buffer = box.data - assert isinstance(input_buffer, ir.Buffer), type(box) + assert isinstance(input_buffer, (ir.Buffer, ir.ReinterpretView)), type( + input_buffer + ) + if isinstance(input_buffer, ir.ReinterpretView): + + def unwrap_views(target) -> ir.Buffer: + if isinstance(target, ir.BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, ir.MutableBox): + return unwrap_views(target.data) + assert isinstance(target, ir.Buffer), type(target) + return target + + input_buffer = unwrap_views(input_buffer) self.codegen_allocation(input_buffer) self.writeline(ReinterpretLine(self, input_buffer, buffer, layout)) return diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 214d3bf02f7f4..08252e58dd566 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -9,7 +9,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_symbols from torch.utils._ordered_set import OrderedSet -from .. import ir +from .. import ir, mkldnn_ir from ..lowering import lowerings as L from ..pattern_matcher import ( Arg, @@ -765,6 +765,39 @@ def _can_be_inplace(_other): or len(_other.get_inputs_that_alias_output()) > 0 ) + def _qlinear_binary_can_be_inplace(_other): + if isinstance(_other.data, ir.BaseView): + + def unwrap_buffer(data): + if isinstance(data, ir.StorageBox): + return data.data + return data + + data = _other.data.unwrap_view() + if isinstance(unwrap_buffer(data), ir.CppTemplateBuffer): + # It can be inplaced when _other is the 2D to 3D view of + # a CppTemplateBuffer because if there is a view of CppTemplateBuffer, + # CppTemplateBuffer will not be used directly but the view. + return True + else: + # The case of QLinearPointwiseBinaryPT2E(sum) -> QLinearPointwiseBinaryPT2E(sum) + # is similar to CppTemplateBuffer above. + # The output of previous QLinearPointwiseBinaryPT2E is + # the input x2 of current QLinearPointwiseBinaryPT2E. + # Use V.graph.operations to check if _other is a view of the output + # of previous QLinearPointwiseBinaryPT2E (the inputs[6]). + for op in V.graph.operations: + if ( + isinstance(op, mkldnn_ir.QLinearPointwiseBinaryPT2E) + and unwrap_buffer(data) == op.inputs[6] # type: ignore[attr-defined] + ): + return True + return False + elif len(_other.get_inputs_that_alias_output()) > 0: + return False + else: + return True + def _register_binary_unary_maybe_inplace_fusion_lowering( pattern, computation_op, diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index ceb0ce3a2f6e6..951a62acf2276 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -615,22 +615,21 @@ def qlinear_binary(match: Match, *args, **kwargs): o_zero_point = kwargs["output_zero_point"] x2.realize() - from .mkldnn_fusion import _can_be_inplace + from .mkldnn_fusion import _qlinear_binary_can_be_inplace binary_op_name = kwargs["binary_op_name"] alpha = kwargs["alpha"] unary_op_name = kwargs["unary_op_name"] unary_op_args = kwargs["unary_op_args"] unary_op_algorithm = kwargs["unary_op_algorithm"] - - if binary_op_name == "sum" and not _can_be_inplace(x2): - # When we enable the GEMM Template, the output of QLinear - # will be reshaped from 2D back to 3D if the input is 3D. - # This causes _can_be_inplace(x2) to return False if x2 happens - # to be the output of QLinear in this scenario. - # Change the post op from sum to binary add for this case. - # Refer to test case: - # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2 + if ( + # TODO Ensure sum is safe and remove such check, i.e., + # x2 is not used by other operations + # or current qlinear sum is the last user of x2. + # This needs to be ensured when registering + # the lowering pattern of quantized_linear_binary. + binary_op_name == "sum" and (not _qlinear_binary_can_be_inplace(x2)) + ): binary_op_name = "add" computation_args = ( diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 14b492aff35ad..823e2baf12dda 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1018,7 +1018,7 @@ def qlinear_binary( x_size = x.get_size() x2_size = x2.get_size() assert len(x_size) == len(x2_size) - if len(x_size) > 2 and binary_attr == "add": + if len(x_size) > 2 and binary_attr in ["add", "sum"]: # GEMM template needs 2D input, normalize input shape here x = view(x, [-1, x_size[-1]]) x2 = view(x2, [-1, x2_size[-1]]) @@ -1086,9 +1086,10 @@ def qlinear_binary( x2_dtype = x2.get_dtype() bias_dtype = bias.get_dtype() if bias is not None else None choices: list[ChoiceCaller] = [] - if ( - config.max_autotune or config.max_autotune_gemm - ) and binary_attr == "add": # Support inplace sum fusion + if (config.max_autotune or config.max_autotune_gemm) and binary_attr in [ + "add", + "sum", + ]: *_, layout, x, packed_weight, x2 = mm_args( x, packed_weight, x2, layout=layout, out_dtype=output_dtype ) @@ -1316,8 +1317,26 @@ def inner_fn_requant(index, scale, zero_point): layout, input_gen_fns=input_gen_fns, ) - if len(x_size) > 2 and binary_attr == "add": - result = view(result, (*x_size[:-1], result.get_size()[-1])) + if ( + isinstance(result.data.data, ir.CppTemplateBuffer) + and binary_attr == "sum" + and result.data.data.layout == x2.get_layout() + ): + # In this case, since x2 is inplace updated when binary_attr is "sum" + # we update the layout of result to view of x2 + result = ir.TensorBox.create( + ir.CppTemplateBuffer( + layout=ir.NonOwningLayout( + ir.ReinterpretView(data=x2, layout=x2.get_layout()) + ), + inputs=result.data.data.inputs, # type: ignore[arg-type] + make_kernel_render=result.data.data.make_kernel_render, # type: ignore[arg-type] + template=result.data.data.template, + choice=result.data.data.choice, + ) + ) + if len(x_size) > 2 and binary_attr in ["add", "sum"]: + result = view(result, (*x_size[:-1], result.get_size()[-1])) # type: ignore[arg-type] return result if torch._C.has_mkl: From 66004b993744b4106bf8afaba71f3c228a804206 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 1 Dec 2025 22:05:48 -0800 Subject: [PATCH 106/338] [dynamo] Fix test state leakage for test_modes.py (#168928) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168928 Approved by: https://github.com/anijain2305 --- test/dynamo/test_modes.py | 29 +++++++++++++++++++++++++++-- torch/__init__.py | 9 ++++----- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index f163e7169bfa3..476ba716b4ee6 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -138,12 +138,20 @@ def fn(x): class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): - cls.default_device_old = torch.get_default_device() + try: + cls.default_device_old = torch.get_default_device() + except AttributeError: + cls.default_device_old = torch.device("cpu") + global_default_ctx = getattr( + getattr(torch, "_GLOBAL_DEVICE_CONTEXT", None), "device_context", None + ) + cls._had_global_default_device = global_default_ctx is not None super().setUpClass() @classmethod def tearDownClass(cls): - torch.set_default_device(cls.default_device_old) + if cls._had_global_default_device: + torch.set_default_device(cls.default_device_old) super().tearDownClass() def setUp(self): @@ -791,6 +799,23 @@ def test_hop_eager(self): ) +class TorchFunctionModeLifecycleTests(torch._dynamo.test_case.TestCase): + def test_default_device_restored_after_mode_tests(self): + case = TorchFunctionModeTests("test_stack_state_mutation_default_device") + TorchFunctionModeTests.setUpClass() + try: + case.setUp() + try: + case.test_stack_state_mutation_default_device() + finally: + case.tearDown() + finally: + TorchFunctionModeTests.tearDownClass() + + stack = _get_current_function_mode_stack() + self.assertFalse(any(isinstance(mode, DeviceContext) for mode in stack)) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/__init__.py b/torch/__init__.py index 165ade4f04dcf..e6f9cfcb54472 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1208,11 +1208,10 @@ def _get_device_with_index(device): device = device_mode.device return _get_device_with_index(device) - if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): - device = _GLOBAL_DEVICE_CONTEXT.device_context.device - return _get_device_with_index(device) - else: - return torch.device("cpu") + device_context = getattr(_GLOBAL_DEVICE_CONTEXT, "device_context", None) + if device_context is not None: + return _get_device_with_index(device_context.device) + return torch.device("cpu") def set_default_device(device: "Device") -> None: From b9c8f3a4884befb965ff42620ce44a71b04887f5 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 1 Dec 2025 22:05:49 -0800 Subject: [PATCH 107/338] [dynamo] Fix local test failures for logging (#168927) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168927 Approved by: https://github.com/anijain2305 ghstack dependencies: #168928 --- test/dynamo/test_logging.py | 65 +++++++++++++++++++++++++--- test/dynamo/test_structured_trace.py | 22 ++++++++++ test/dynamo/test_utils.py | 5 +++ torch/_dynamo/__init__.py | 2 + torch/_dynamo/utils.py | 6 +++ 5 files changed, 93 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index f472705101e35..be6ce3d172756 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -68,6 +68,35 @@ def munge(s): return "\n".join([line for line, nsubs in lines if nsubs > 0]) +LOG_PREFIX_PATTERNS = [ + re.compile(r"^\[rank\d+\]:\s*"), + re.compile(r"^[A-Z]+:[^:]+:\s*"), + re.compile(r"^[A-Z]\d{2,4}\s+\d{2}:\d{2}:\d{2}(?:\.\d+)?\s+\d+\s+[^\]]+\]\s*"), + re.compile(r"^[A-Z](?:\d{4})?\s+[^:]+:\s*"), +] + + +def normalize_log_line(line: str) -> str: + line = line.rstrip() + for pattern in LOG_PREFIX_PATTERNS: + stripped, count = pattern.subn("", line, count=1) + if count: + line = stripped.lstrip() + break + return line + + +def normalize_rank_prefix(output: str) -> str: + if "[rank" in output: + return output + + def repl(match): + prefix = match.group(1) + return f"{prefix}[rank0]: " + + return re.sub(r"(^|\n)(?:[A-Z]+:[^:]+:)", repl, output) + + def example_fn(a): output = a.mul(torch.ones(1000, 1000)) output = output.add(torch.ones(1000, 1000)) @@ -388,8 +417,17 @@ def test_custom_format(self, records): if torch._logging._internal._is_torch_handler(handler): break self.assertIsNotNone(handler) - self.assertIn("I", handler.format(records[0])) - self.assertEqual("custom format", handler.format(records[1])) + formatted_dynamo = handler.format(records[0]) + self.assertIn("test dynamo", formatted_dynamo) + self.assertEqual(normalize_log_line(formatted_dynamo), "test dynamo") + ci_style_line = ( + "I1124 19:43:23.879000 4928 dynamo/test_logging.py:410] test dynamo" + ) + self.assertEqual(normalize_log_line(ci_style_line), "test dynamo") + + formatted_artifact = handler.format(records[1]) + self.assertIn("custom format", formatted_artifact) + self.assertEqual(normalize_log_line(formatted_artifact), "custom format") @make_logging_test(dynamo=logging.INFO) def test_multiline_format(self, records): @@ -404,10 +442,20 @@ def test_multiline_format(self, records): if torch._logging._internal._is_torch_handler(handler): break self.assertIsNotNone(handler) - for record in records: - r = handler.format(record) - for l in r.splitlines(): - self.assertIn("I", l) + expected_lines = [ + ["test", "dynamo"], + ["test", "dynamo"], + ["test", "test", "dynamo"], + ] + + for record, expected in zip(records, expected_lines): + formatted = handler.format(record) + normalized_lines = [ + line + for line in (normalize_log_line(l) for l in formatted.splitlines()) + if line + ] + self.assertEqual(normalized_lines, expected) test_trace_source_simple = within_range_record_test(1, 100, trace_source=True) @@ -566,7 +614,10 @@ def test_distributed_rank_logging(self): """, env=env, ) - self.assertIn("[rank0]:", stderr.decode("utf-8")) + stderr_text = stderr.decode("utf-8") + normalized = normalize_rank_prefix(stderr_text) + self.assertIn("[rank0]:", normalized) + self.assertIn("woof", normalized) @skipIfNotPy311 @make_logging_test(trace_call=True) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 33715d2cf861b..21cf04cffbf65 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -196,7 +196,24 @@ def tearDown(self): self.raw_file.close() trace_log.setLevel(self.old_level) + def assertExpectedInline(self, actual, expected): + super().assertExpectedInline( + self._normalize_rank_field(actual), + self._normalize_rank_field(expected), + ) + + @staticmethod + def _normalize_rank_field(text): + if not isinstance(text, str): + return text + text = text.replace(', "rank": 0', "") + text = text.replace('"rank": 0, ', "") + text = text.replace('"rank": 0', "") + return text + def assertParses(self): + if not HAS_TLPARSE: + self.skipTest("requires tlparse") out = tempfile.mkdtemp() try: subprocess.check_call( @@ -540,6 +557,11 @@ def throw(x): @requires_distributed() @requires_cuda_and_triton def test_ddp_graphs(self): + import torch._dynamo.convert_frame as convert_frame + + convert_frame.FRAME_COUNTER = 0 + convert_frame.FRAME_COMPILE_COUNTER.clear() + class ToyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index 24573a3a8178b..f0c1f50093f82 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -227,6 +227,11 @@ class TestDynamoTimed(TestCase): Test utilities surrounding dynamo_timed. """ + def setUp(self): + super().setUp() + if hasattr(torch._dynamo, "reset_recompile_user_contexts"): + torch._dynamo.reset_recompile_user_contexts() + def run_forward_backward(self): model = torch.compile(TestModel()) x = torch.rand([3], requires_grad=True) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index de097edf87752..e9a5e8d89d07c 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -68,6 +68,7 @@ orig_code_map, register_hook_for_recompile_user_context, reset_frame_count, + reset_recompile_user_contexts, ) @@ -103,6 +104,7 @@ "register_backend", "replay", "reset", + "reset_recompile_user_contexts", "run", "error_on_graph_break", "set_stance", diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 5b1070aad5ad6..d3c351e0de01a 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -285,6 +285,12 @@ def get_hook_for_recompile_user_context() -> Optional[list[Callable[[], str]]]: return _recompile_user_contexts +def reset_recompile_user_contexts() -> None: + """Clear any registered recompile user-context hooks (test helper).""" + global _recompile_user_contexts + _recompile_user_contexts = None + + op_count = 0 From 4cfb47ff548b6d996641058cf04a70e311a4c3aa Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 1 Dec 2025 14:52:50 -0800 Subject: [PATCH 108/338] [CP] Refactor CP sharding rules into separate module and only register when CP is enabled (#167381) Previously, CP-specific sharding strategies (which shard on the sequence dimension) were directly included in the base sharding strategies for scaled_dot_product_attention operators in `_matrix_ops.py`. This meant these strategies were always available, even when CP was not enabled, which could lead to incorrect sharding behavior as these sharding rules are not mathmetically correct without CP. 1. **Created new module**: `torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py` - Implements `op_strategy_context()` - a context manager for temporarily registering/unregistering strategies - Defines CP-enhanced strategy functions for all 6 scaled_dot_product_attention ops (forward and backward for flash, efficient, and cudnn variants) - Provides `register_cp_sharding_rules()` and `unregister_cp_sharding_rules()` APIs 2. **Updated `_matrix_ops.py`** - Removed all CP-specific sharding rules (sequence dimension sharding strategies) - Base strategies now only contain replicate, tensor parallelism, and batch sharding strategies 3. **Updated `_attention.py`** - `_enable_cp_dtensor_dispatcher()` now calls `register_cp_sharding_rules()` to dynamically add CP strategies - ~`_disable_cp_dtensor_dispatcher()` now calls `unregister_cp_sharding_rules()` to restore original strategies~ This will invalidate all the sharding prop caches. Disable it for now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167381 Approved by: https://github.com/wconstab --- test/distributed/tensor/test_attention.py | 58 +++ torch/distributed/tensor/_ops/_matrix_ops.py | 237 ++++------ .../_context_parallel/_attention.py | 15 + .../_context_parallel/_sharding_rules.py | 406 ++++++++++++++++++ 4 files changed, 571 insertions(+), 145 deletions(-) create mode 100644 torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index 6c3485f9d7025..4febcf82937df 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -34,6 +34,10 @@ from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import ( flex_cp_allgather, ) +from torch.distributed.tensor.experimental._context_parallel._sharding_rules import ( + register_cp_sharding_rules, + unregister_cp_sharding_rules, +) from torch.distributed.tensor.parallel import parallelize_module from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( @@ -813,6 +817,60 @@ def test_context_parallel_shard(self) -> None: ), ) + @skip_if_lt_x_gpu(2) + @with_comms + @unittest.skipIf( + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Does not support flash nor efficient attention", + ) + def test_attention_shard_without_cp(self) -> None: + """Test that sharding on sequence dimension without CP enabled is not supported.""" + from torch.distributed.tensor import distribute_tensor, Replicate, Shard + + B = 2 + nheads = 4 + seq_len = 256 + dim = 32 + + device_mesh = init_device_mesh( + mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type + ) + + for backend in backends: + with sdpa_kernel(backend): + dtype = torch.bfloat16 + if backend == SDPBackend.EFFICIENT_ATTENTION: + dtype = torch.float32 + # Create q, k, v tensors with shape (B, nheads, seq_len, dim) + q = torch.randn( + B, nheads, seq_len, dim, device=self.device_type, dtype=dtype + ) + k = torch.randn( + B, nheads, seq_len, dim, device=self.device_type, dtype=dtype + ) + v = torch.randn( + B, nheads, seq_len, dim, device=self.device_type, dtype=dtype + ) + q_dt = distribute_tensor(q, device_mesh, [Shard(2)]) + k_dt = distribute_tensor(k, device_mesh, [Shard(2)]) + v_dt = distribute_tensor(v, device_mesh, [Shard(2)]) + + register_cp_sharding_rules() + out = F.scaled_dot_product_attention(q_dt, k_dt, v_dt) + unregister_cp_sharding_rules(clear_the_cache=True) + out = F.scaled_dot_product_attention(q_dt, k_dt, v_dt) + # Run SDPA with sequence-sharded tensors WITHOUT enabling CP + # Without CP enabled, DTensor should select a different strategy + # (not sequence-sharded) because Shard(2) strategy is only available with CP + + # Verify the output is NOT sharded on sequence dimension (dim 2) + # This proves that CP sharding rules were not used + self.assertNotEqual( + out.placements[0], Shard(2), f"Placement {out.placements}" + ) + # The output should be replicated or sharded on batch head dimensions. + self.assertIn(out.placements[0], [Replicate(), Shard(0), Shard(1)]) + RingAttentionTestWithLocalTensor = create_local_tensor_test_class( RingAttentionTest, diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 30498a95e29d6..5911e4cef1e7d 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -265,16 +265,10 @@ def scaled_mm_strategy(op_schema: OpSchema) -> OpStrategy: return _scaled_mm_like_strategy("mk,kn->mn", mesh, op_schema) -@register_op_strategy( - aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5) -) -def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy: - # NOTE: currently we only support some simple strategies to support tensor parallelism - # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation - # as it involves: matmul, pointwise, reduction ops together. - - mesh = op_schema.get_mesh_from_args() - +def _scaled_dot_product_flash_attention_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] q_input_strategy = op_schema.args_schema[0] if not isinstance(q_input_strategy, OpStrategy): @@ -347,37 +341,30 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate Shard(0), # v ] ) + return single_mesh_dim_strategies - # Context Parallelism: shards on the sequence dim - debug_attn_mask_sharding = Shard(2) if return_debug_mask else Replicate() - single_mesh_dim_strategies.append( - [ - Shard(2), # output - Shard(2), # logsumexp - None, # cum_seq_q - None, # cum_seq_k - None, # max_q - None, # max_k - Replicate(), # rng_state - None, # unused - debug_attn_mask_sharding, # debugattn - Shard(2), # q - Shard(2), # k - Shard(2), # v - ] + +@register_op_strategy( + aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5) +) +def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation + # as it involves: matmul, pointwise, reduction ops together. + + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies( + op_schema ) return expand_to_full_mesh_op_strategy( mesh, op_schema, single_mesh_dim_strategies, input_index=9 ) -@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default) -def scaled_dot_product_flash_attention_backward_strategy( +def _scaled_dot_product_flash_attention_backward_base_strategies( op_schema: OpSchema, -) -> OpStrategy: - # backward op does not need to validate the mesh since forward op has already done it - mesh = op_schema.get_mesh_from_args(validate=False) - +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" q_input_strategy = op_schema.args_schema[1] if not isinstance(q_input_strategy, OpStrategy): raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}") @@ -442,24 +429,18 @@ def scaled_dot_product_flash_attention_backward_strategy( batch_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) single_mesh_dim_strategies.append(batch_dim_sharding) - # Context Parallelism: shards on the sequence dim - seq_dim_sharding: PlacementList = [ - Shard(2), # grad_q - Shard(2), # grad_k - Shard(2), # grad_v - Shard(2), # grad_output - Shard(2), # q - Shard(2), # k - Shard(2), # v - Shard(2), # output - Shard(2), # logsumexp - ] - # accept replicate on the rest tensor inputs, potentially - # cum_seq_q, cum_seq_k, philox_seed, philox_offset - # at indices 6, 7, 12, 13, respectively - seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) - single_mesh_dim_strategies.append(seq_dim_sharding) + return single_mesh_dim_strategies + +@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default) +def scaled_dot_product_flash_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_flash_attention_backward_base_strategies(op_schema) + ) return expand_to_full_mesh_op_strategy( mesh, op_schema, single_mesh_dim_strategies, input_index=3 ) @@ -484,13 +465,10 @@ def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy: ) -@register_op_strategy( - aten._scaled_dot_product_efficient_attention.default, - schema_info=RuntimeSchemaInfo(4), -) -def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy: - # NOTE: currently we only support some simple strategies to support tensor parallelism - mesh = op_schema.get_mesh_from_args() +def _scaled_dot_product_efficient_attention_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" q_input_strategy = op_schema.args_schema[0] if not isinstance(q_input_strategy, OpStrategy): raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}") @@ -516,19 +494,6 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt if has_attn_bias: all_replicate.append(Replicate()) # attn bias - # Context Parallelism: shards on the sequence dim - single_mesh_dim_strategies.append( - [ - Shard(2), # output - Shard(2), # logsumexp - None, # philox_seed - None, # philox_offset - Shard(2), # q - Shard(2), # k - Shard(2), # v - ] - ) - single_mesh_dim_strategies.append(all_replicate) # second we can accept the sharding pattern of tensor parallelism, which @@ -574,6 +539,19 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt single_mesh_dim_strategies.append(batch_sharding) + return single_mesh_dim_strategies + + +@register_op_strategy( + aten._scaled_dot_product_efficient_attention.default, + schema_info=RuntimeSchemaInfo(4), +) +def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_base_strategies(op_schema) + ) return expand_to_full_mesh_op_strategy( mesh, op_schema, @@ -582,13 +560,10 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt ) -@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default) -def scaled_dot_product_efficient_attention_backward_strategy( +def _scaled_dot_product_efficient_attention_backward_base_strategies( op_schema: OpSchema, -) -> OpStrategy: - # backward op does not need to validate the mesh since forward op has already done it - mesh = op_schema.get_mesh_from_args(validate=False) - +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" q_input_strategy = op_schema.args_schema[1] if not isinstance(q_input_strategy, OpStrategy): raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}") @@ -660,27 +635,18 @@ def scaled_dot_product_efficient_attention_backward_strategy( batch_dim_sharding.extend([Replicate(), Replicate()]) single_mesh_dim_strategies.append(batch_dim_sharding) - # Context Parallelism: shards on the sequence dim - seq_dim_sharding: PlacementList = [ - Shard(2), # grad_q - Shard(2), # grad_k - Shard(2), # grad_v - Shard(1) if has_attn_bias else None, # grad_bias - Shard(2), # grad_output - Shard(2), # q - Shard(2), # k - Shard(2), # v - Shard(2), # output - Shard(2), # logsumexp - ] - # accept replicate on the rest tensor inputs, potentially - # cum_seq_q, cum_seq_k, philox_seed, philox_offset - # at indices 6, 7, 12, 13, respectively - if has_attn_bias: - num_heads_dim_sharding.insert(8, Shard(1)) - seq_dim_sharding.extend([Replicate(), Replicate()]) - single_mesh_dim_strategies.append(seq_dim_sharding) + return single_mesh_dim_strategies + +@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default) +def scaled_dot_product_efficient_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_backward_base_strategies(op_schema) + ) return expand_to_full_mesh_op_strategy( mesh, op_schema, @@ -689,13 +655,10 @@ def scaled_dot_product_efficient_attention_backward_strategy( ) -@register_op_strategy( - aten._scaled_dot_product_cudnn_attention.default, - schema_info=RuntimeSchemaInfo(4), -) -def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy: - mesh = op_schema.get_mesh_from_args() - +def _scaled_dot_product_cudnn_attention_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" ( query_strategy, # query _, # key @@ -783,39 +746,27 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate ] single_mesh_dim_strategies.append(batch_dim_sharding) - # Context Parallelism: shards on the sequence dim - cp_sharding = Shard(2) # seq dim - logsumexp_sharding = cp_sharding if compute_log_sumexp else Replicate() - debug_attn_mask_sharding = cp_sharding if return_debug_mask else None + return single_mesh_dim_strategies - single_mesh_dim_strategies.append( - [ - cp_sharding, # output - logsumexp_sharding, # logsumexp - None, # cum_seq_q - None, # cum_seq_k - None, # max_q - None, # max_k - None, # philox_seed - None, # philox_offset - debug_attn_mask_sharding, # debug_attn_mask - cp_sharding, # q - cp_sharding, # k - cp_sharding, # v - ] + +@register_op_strategy( + aten._scaled_dot_product_cudnn_attention.default, + schema_info=RuntimeSchemaInfo(4), +) +def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies( + op_schema ) return expand_to_full_mesh_op_strategy( mesh, op_schema, single_mesh_dim_strategies, input_index=9 ) -@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default) -def scaled_scaled_dot_product_cudnn_attention_backward_strategy( +def _scaled_dot_product_cudnn_attention_backward_base_strategies( op_schema: OpSchema, -) -> OpStrategy: - # backward op does not need to validate the mesh since forward op has already done it - mesh = op_schema.get_mesh_from_args(validate=False) - +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" if len(op_schema.args_schema) < 15: raise AssertionError( f"Expected at least 15 args_schema, got {len(op_schema.args_schema)}" @@ -890,23 +841,7 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy( num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp single_mesh_dim_strategies.append(num_heads_dim_sharding) - # case 3: Context Parallelism which shards on the sequence dim - context_parallel_sharding_out: PlacementList = [Shard(2)] * 3 - context_parallel_sharding_inp: PlacementList = [Shard(2)] * 6 - context_parallel_sharding_inp += [ - Replicate() - ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor - context_parallel_sharding_inp += [Shard(2) if has_attn_bias else None] - context_parallel_sharding_inp += [None] * 6 - if has_scale: - context_parallel_sharding_inp.append(None) - - context_parallel_sharding = ( - context_parallel_sharding_out + context_parallel_sharding_inp - ) - single_mesh_dim_strategies.append(context_parallel_sharding) - - # case 4: we can accept the sharding pattern of batch parallelism, which + # case 3: we can accept the sharding pattern of batch parallelism, which # shards on the batch dimension qkv_sharding = Shard(0) output_sharding = Shard(0) @@ -927,6 +862,18 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy( batch_dim_sharding = batch_dim_sharding_out + batch_dim_sharding_inp single_mesh_dim_strategies.append(batch_dim_sharding) + return single_mesh_dim_strategies + + +@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default) +def scaled_scaled_dot_product_cudnn_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema) + ) return expand_to_full_mesh_op_strategy( mesh, op_schema, single_mesh_dim_strategies, input_index=3 ) diff --git a/torch/distributed/tensor/experimental/_context_parallel/_attention.py b/torch/distributed/tensor/experimental/_context_parallel/_attention.py index f3d06b4fd274d..9a1c6299dfca4 100644 --- a/torch/distributed/tensor/experimental/_context_parallel/_attention.py +++ b/torch/distributed/tensor/experimental/_context_parallel/_attention.py @@ -989,16 +989,31 @@ def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None: def _enable_cp_dtensor_dispatcher() -> None: """Enables DTensor dispatcher to dispatch SDPA to CP.""" + # Enable custom op handlers for CP DTensor._op_dispatcher._custom_op_handlers = { **exitsing_custom_ops, **custom_ops, } + # Register CP-specific sharding rules + from ._sharding_rules import register_cp_sharding_rules + + register_cp_sharding_rules() def _disable_cp_dtensor_dispatcher() -> None: """Disables DTensor dispatcher to dispatch SDPA to CP.""" + # Restore original custom op handlers DTensor._op_dispatcher._custom_op_handlers = exitsing_custom_ops + # TODO: unregister_cp_sharding_rules(clear_the_cache=True) will cause + # all DTensor sharding propagation cache being invalidated. It is not + # easy to achieve selectively invalidating lru cache without rewriting + # the sharding propagation wrapper. + + from ._sharding_rules import unregister_cp_sharding_rules + + unregister_cp_sharding_rules(clear_the_cache=False) + def _enable_context_parallel_dispatcher_impl(seq_dim: int, mesh: DeviceMesh) -> None: sdpa_cp = _ContextParallel( diff --git a/torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py b/torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py new file mode 100644 index 0000000000000..ebb6eb0cface8 --- /dev/null +++ b/torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py @@ -0,0 +1,406 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +""" +Context Parallelism sharding rules for scaled_dot_product attention operators. + +The sharding rules for CP cannot be embedded by default because Shard(2) is not +a valid sharding for SDPA without CP enabled. This module provides utilities to +dynamically install Shard(2) sharding rules when CP is activated. +""" + +from contextlib import contextmanager + +import torch +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, +) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy +from torch.distributed.tensor.debug import ( + _clear_fast_path_sharding_prop_cache, + _clear_python_sharding_prop_cache, +) +from torch.distributed.tensor.placement_types import Replicate, Shard + + +aten = torch.ops.aten + +SEQ_DIM = 2 + + +@contextmanager +def _op_strategy_context(op_overload, strategy_func, schema_info=None): + """ + Context manager for setting and clearing op strategies for Context Parallelism. + + Args: + op_overload: The operator overload to set or clear the strategy for. + strategy_func: The strategy function to set for the operator overload. + schema_info: Optional schema information for the operator overload. + + Yields: + None + """ + from torch.distributed.tensor import DTensor + + propagator = DTensor._op_dispatcher.sharding_propagator + _origin_op_strategy_funcs = None + _origin_op_strategy_schema = None + try: + # Save original strategy if exists + if op_overload in propagator.op_strategy_funcs: + _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload] + if op_overload in propagator.op_to_schema_info: + _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload] + + # Register the new op strategy + register_op_strategy(op_overload, schema_info=schema_info)(strategy_func) + yield (_origin_op_strategy_funcs, _origin_op_strategy_schema) + finally: + # Restore original strategy + if _origin_op_strategy_funcs is None: + if op_overload in propagator.op_strategy_funcs: + del propagator.op_strategy_funcs[op_overload] + else: + propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs + + if _origin_op_strategy_schema is None: + if op_overload in propagator.op_to_schema_info: + del propagator.op_to_schema_info[op_overload] + else: + propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema + + # Ideally, we should clear the cache, but it is too expensive. + # _clear_python_sharding_prop_cache() + # _clear_fast_path_sharding_prop_cache() + + +# ==================== Flash Attention Strategies ==================== + + +def _scaled_dot_product_flash_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy: + """ + Strategy for flash attention forward with Context Parallelism support. + This includes the base strategies plus CP-specific sequence dimension sharding. + """ + # Import here to avoid circular dependency + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_flash_attention_base_strategies, + ) + + # Get the base strategies (without CP modifications) + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies( + op_schema + ) + + # Add Context Parallelism strategy: shards on the sequence dim + return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] + debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else Replicate() + + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + debug_attn_mask_sharding, # debugattn + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + ] + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +def _scaled_dot_product_flash_attention_backward_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for flash attention backward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_flash_attention_backward_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_flash_attention_backward_base_strategies(op_schema) + ) + + tensor_input_indices = [ + i + for i, arg_spec in enumerate(op_schema.args_schema) + if isinstance(arg_spec, OpStrategy) + ] + num_tensor_inputs = len(tensor_input_indices) + + # Context Parallelism: shards on the sequence dim + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # grad_q + Shard(SEQ_DIM), # grad_k + Shard(SEQ_DIM), # grad_v + Shard(SEQ_DIM), # grad_output + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + ] + cp_strategy.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +# ==================== Efficient Attention Strategies ==================== + + +def _scaled_dot_product_efficient_attention_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for efficient attention forward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_efficient_attention_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_base_strategies(op_schema) + ) + + # Add Context Parallelism strategy + has_attn_bias = op_schema.args_schema[3] is not None + + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + None, # philox_seed + None, # philox_offset + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + ] + if has_attn_bias: + cp_strategy.append(Replicate()) # attn bias - not sharded for CP + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=4 + ) + + +def _scaled_dot_product_efficient_attention_backward_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for efficient attention backward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_efficient_attention_backward_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_backward_base_strategies(op_schema) + ) + + has_attn_bias = op_schema.args_schema[4] is not None + + # Context Parallelism: shards on the sequence dim + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # grad_q + Shard(SEQ_DIM), # grad_k + Shard(SEQ_DIM), # grad_v + Shard(1) if has_attn_bias else None, # grad_bias + Shard(SEQ_DIM), # grad_output + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + ] + if has_attn_bias: + cp_strategy.insert(8, Shard(1)) # attn_bias input + cp_strategy.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=4 + ) + + +# ==================== cuDNN Attention Strategies ==================== + + +def _scaled_dot_product_cudnn_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy: + """ + Strategy for cudnn attention forward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_cudnn_attention_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies( + op_schema + ) + + ( + query_strategy, + _, + _, + attn_bias_strategy, + compute_log_sumexp, + *rest_args, + ) = op_schema.args_schema + return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2] + has_attn_bias = attn_bias_strategy is not None + + # Context Parallelism: shards on the sequence dim + logsumexp_sharding = Shard(SEQ_DIM) if compute_log_sumexp else Replicate() + debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else None + + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # output + logsumexp_sharding, # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, # debug_attn_mask + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + ] + if has_attn_bias: + cp_strategy.append(Replicate()) # attn_bias - not sharded for CP + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +def _scaled_dot_product_cudnn_attention_backward_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for cudnn attention backward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_cudnn_attention_backward_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema) + ) + + has_attn_bias = op_schema.args_schema[8] is not None + has_scale = len(op_schema.args_schema) >= 16 and False + + # Context Parallelism: shards on the sequence dim + cp_sharding_gout: PlacementList = [Shard(SEQ_DIM)] * 3 # grad_q, grad_k, grad_v + cp_sharding_ginp: PlacementList = [ + Shard(SEQ_DIM) + ] * 6 # grad_output, q, k, v, output, logsumexp + cp_sharding_ginp += [Replicate()] * 2 # philox_seed, philox_offset + cp_sharding_ginp += [Shard(SEQ_DIM) if has_attn_bias else None] # attn_bias + cp_sharding_ginp += [ + None + ] * 6 # cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal + if has_scale: + cp_sharding_ginp.append(None) + + cp_sharding = cp_sharding_gout + cp_sharding_ginp + single_mesh_dim_strategies.append(cp_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +# Store context managers and original strategies +_cp_strategy_contexts = {} +_original_strategies = {} + + +def register_cp_sharding_rules(): + """Register Context Parallelism sharding rules for all scaled_dot_product ops.""" + global _cp_strategy_contexts, _original_strategies + + # If already registered, don't register again + if _cp_strategy_contexts: + return + + # Define ops and their corresponding CP strategy functions + cp_strategies = [ + ( + aten._scaled_dot_product_flash_attention.default, + _scaled_dot_product_flash_attention_cp_strategy, + RuntimeSchemaInfo(5), + ), + ( + aten._scaled_dot_product_flash_attention_backward.default, + _scaled_dot_product_flash_attention_backward_cp_strategy, + None, + ), + ( + aten._scaled_dot_product_efficient_attention.default, + _scaled_dot_product_efficient_attention_cp_strategy, + RuntimeSchemaInfo(4), + ), + ( + aten._scaled_dot_product_efficient_attention_backward.default, + _scaled_dot_product_efficient_attention_backward_cp_strategy, + None, + ), + ( + aten._scaled_dot_product_cudnn_attention.default, + _scaled_dot_product_cudnn_attention_cp_strategy, + RuntimeSchemaInfo(4), + ), + ( + aten._scaled_dot_product_cudnn_attention_backward.default, + _scaled_dot_product_cudnn_attention_backward_cp_strategy, + None, + ), + ] + + # Register each strategy + for op_overload, strategy_func, schema_info in cp_strategies: + ctx = _op_strategy_context(op_overload, strategy_func, schema_info) + orig_funcs, orig_schema = ctx.__enter__() + _cp_strategy_contexts[op_overload] = ctx + _original_strategies[op_overload] = (orig_funcs, orig_schema) + + +def unregister_cp_sharding_rules(clear_the_cache=False): + """Unregister Context Parallelism sharding rules and restore original strategies.""" + global _cp_strategy_contexts, _original_strategies + + # Exit all context managers + for ctx in _cp_strategy_contexts.values(): + ctx.__exit__(None, None, None) + + if clear_the_cache: + _clear_fast_path_sharding_prop_cache() + _clear_python_sharding_prop_cache() + + _cp_strategy_contexts = {} + _original_strategies = {} From f2d6a75a00a1d648ca9a0abc6a33e14c3dea6c40 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 2 Dec 2025 06:43:52 +0000 Subject: [PATCH 109/338] bump timm pin (#169227) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169227 Approved by: https://github.com/huydhn --- .ci/docker/ci_commit_pins/timm.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/timm.txt b/.ci/docker/ci_commit_pins/timm.txt index d8ef69d89156a..5d0b717ad4d8e 100644 --- a/.ci/docker/ci_commit_pins/timm.txt +++ b/.ci/docker/ci_commit_pins/timm.txt @@ -1 +1 @@ -5d535d7a2d4b435b1b5c1177fd8f04a12b942b9a +af3732eebe8c1964e5ba5f2769f955e6e0deb980 From b555c39217f765759954a4f9f9bd1e9b87bed11a Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Mon, 1 Dec 2025 14:34:28 -0800 Subject: [PATCH 110/338] [pallas backend] Add multi-output support to Pallas backend (#169323) 80 more passing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/169323 Approved by: https://github.com/yarongmu-google, https://github.com/jansel --- test/inductor/test_pallas.py | 18 ++++++++++ torch/_inductor/codegen/pallas.py | 58 +++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index 369013e1670b6..e8e5a37ace752 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -747,6 +747,24 @@ def fn(x): expected = fn(x) self.assertEqual(result, expected) + def test_arange_multi_output(self): + """Test arange with view and multiple outputs.""" + + def fn(x): + rng1 = torch.arange(8 * 8, dtype=torch.float32, device=x.device).view(8, 8) + rng2 = torch.arange(10, 18, device=x.device) + tmp = x * rng1 + return tmp, tmp + rng2 + + compiled = self._compile(fn) + + x = torch.randn(8, 8, device=self.DEVICE) + result = compiled(x) + expected = fn(x) + self.assertEqual(len(result), len(expected)) + for r, e in zip(result, expected): + self.assertEqual(r, e) + @unittest.skipUnless(has_cuda_pallas(), "requires jax and pallas") class PallasTestsCUDA(PallasTestsMixin, TestCase): diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 23bf0e1bbe31a..f5e64c54e77ba 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -860,13 +860,6 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove Returns: str: Complete Python source code for the Pallas kernel """ - # Ensure one (1) output for now - live_outs = list(self.args.live_output_buffers()) - if len(live_outs) != 1: - raise Unsupported( - "Pallas backend currently supports single-output elementwise kernels only" - ) - code = IndentedBuffer() # Define the Pallas kernel: accepts refs, uses broadcasted expressions @@ -985,9 +978,53 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove f"{mask_var} = jnp.arange(block_size) < {mask_var}_size" ) + # Generate iteration variables as jnp.arange arrays + # These are used by index_expr operations like torch.arange + if self.range_tree_nodes: + code.writeline("# Define iteration variables as JAX arrays") + # Get the first output buffer's shape for reshaping + first_output_shape = None + first_output_numel = None + if output_params: + first_out_param = output_params[0] + first_out_buf_name = output_buffer_lookup.get(first_out_param) + if first_out_buf_name: + try: + buf = V.graph.get_buffer(first_out_buf_name) + size = buf.get_size() + first_output_shape = tuple( + int(s) if hasattr(s, "__int__") else s for s in size + ) + first_output_numel = 1 + for s in first_output_shape: + first_output_numel *= s + except Exception: + pass + + for var_sym, entry in self.range_tree_nodes.items(): + var_name = str(var_sym) + length = entry.length + length_str = self.kexpr(length) + # If the iteration variable length matches the output numel, + # reshape it to match the output shape for proper broadcasting + try: + length_val = int(length) if hasattr(length, "__int__") else None + except (TypeError, ValueError): + length_val = None + + if ( + first_output_shape + and len(first_output_shape) > 1 + and length_val == first_output_numel + ): + shape_str = ", ".join(str(s) for s in first_output_shape) + code.writeline( + f"{var_name} = jnp.arange({length_str}).reshape({shape_str})" + ) + else: + code.writeline(f"{var_name} = jnp.arange({length_str})") + # Emit compute (CSE) and store lines; they reference *_ptr[index] directly. - # Iteration variables are implicitly handled by JAX vectorization, so - # explicit indices should be JAX-traced values. for line in self.compute._lines: code.writeline(str(line)) for line in self.stores._lines: @@ -1064,7 +1101,8 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove else " input_output_aliases={}," ) code.writeline(")(") - code.writeline(f" {', '.join(kernel_input_params)},") + if kernel_input_params: + code.writeline(f" {', '.join(kernel_input_params)},") code.writeline(")") main_name = f"{kernel_name}_main" From afdff7f0325080dedac44d080cb5a3b0e65e6c5e Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Mon, 1 Dec 2025 14:34:32 -0800 Subject: [PATCH 111/338] [pallas backend] support bitcast (#169324) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169324 Approved by: https://github.com/malfet, https://github.com/yarongmu-google, https://github.com/jansel ghstack dependencies: #169323 --- test/inductor/test_pallas.py | 27 +++++++++++++++++++++++++++ torch/_inductor/codegen/pallas.py | 8 ++++++++ 2 files changed, 35 insertions(+) diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index e8e5a37ace752..9384d8de1b491 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -765,6 +765,33 @@ def fn(x): for r, e in zip(result, expected): self.assertEqual(r, e) + def test_dtype_bitcast(self): + """Test dtype bitcast (view tensor as different dtype).""" + + def fn(x): + # View float32 tensor as int32 (same byte size) + return x.view(torch.int32) + + compiled = self._compile(fn) + + x = torch.randn(16, device=self.DEVICE, dtype=torch.float32) + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_dtype_bitcast_float16_to_int16(self): + """Test dtype bitcast from float16 to int16.""" + + def fn(x): + return x.view(torch.int16) + + compiled = self._compile(fn) + + x = torch.randn(16, device=self.DEVICE, dtype=torch.float16) + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + @unittest.skipUnless(has_cuda_pallas(), "requires jax and pallas") class PallasTestsCUDA(PallasTestsMixin, TestCase): diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index f5e64c54e77ba..2ae68dbca575f 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -201,6 +201,14 @@ def to_dtype( # Wrap in jnp.asarray to handle scalars from integer indexing return f"jnp.asarray({x}).astype({jax_dtype})" + @staticmethod + def to_dtype_bitcast(x: str, dtype: torch.dtype, src_dtype: torch.dtype) -> str: + """Bitcast a value from one dtype to another with the same size.""" + jax_dtype = torch_dtype_to_jax(dtype) + jax_src_dtype = torch_dtype_to_jax(src_dtype) + # First ensure the value is the correct source dtype, then bitcast + return f"jax.lax.bitcast_convert_type(jnp.asarray({x}).astype({jax_src_dtype}), {jax_dtype})" + @staticmethod def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str: """Convert a sympy expression to a JAX array indexing expression.""" From a951a9cee65c01660bbc6e6fded90ecb10fa6109 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 1 Dec 2025 17:00:54 -0800 Subject: [PATCH 112/338] [inductor] Increase tolerance for test_emulate_precision_casts_mean_ratio_chain (#169309) Fixes https://www.internalfb.com/tasks/?t=246834114 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169309 Approved by: https://github.com/ezyang --- test/inductor/test_cuda_repro.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 2640f65116f4b..3cd2900051943 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1515,8 +1515,8 @@ def fn(a0, a1, a2, a3): @torch._inductor.config.patch(emulate_precision_casts=True) def test_emulate_precision_casts_mean_ratio_chain(self): - torch.manual_seed(0) - torch.cuda.manual_seed_all(0) + torch.manual_seed(12345) + torch.cuda.manual_seed_all(12345) with dynamo_config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True @@ -1561,7 +1561,7 @@ def fn(a0, a1, a2, a3, a4, a5): torch.testing.assert_close( eager_out, compiled_out, - rtol=5e-3, + rtol=5e-2, atol=1e-1, ) From 556375b55deebebbc56cb7aef81f4d52f031ba28 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 1 Dec 2025 17:00:54 -0800 Subject: [PATCH 113/338] [inductor] Increase tolerance for test_conv3d_binary_broadcast_shapes (#169310) Fixes https://www.internalfb.com/tasks/?t=246782196 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169310 Approved by: https://github.com/williamwen42 ghstack dependencies: #169309 --- test/inductor/test_mkldnn_pattern_matcher.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 440ee6a52f553..e91b7b9339ca4 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -200,10 +200,10 @@ def _test_common( maybe_autocast = torch.amp.autocast( device_type=device, dtype=torch.bfloat16 ) - atol, rtol = 1e-2, 1e-2 + atol, rtol = 5e-2, 5e-2 elif check_autocast == torch.float16 and (is_mkldnn_fp16_supported(device)): maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16) - atol, rtol = 1e-2, 1e-2 + atol, rtol = 5e-2, 5e-2 else: assert check_autocast == torch.float32 maybe_autocast = contextlib.nullcontext() @@ -576,6 +576,7 @@ def test_conv3d_binary(self, device): def _test_conv_binary_broadcast_shapes_base(self, dim=4): assert dim == 4 or dim == 5 + torch.manual_seed(12345) class M(torch.nn.Module): def __init__( @@ -676,7 +677,7 @@ def test_conv2d_binary_broadcast_shapes(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @reduced_f32_on_and_off() + @reduced_f32_on_and_off(bf32_precision=5e-2) def test_conv3d_binary_broadcast_shapes(self, device): self.device = device self._test_conv_binary_broadcast_shapes_base(dim=5) From 1e34fb2550e4aa650314f7a6d9f6daf4da7478a8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 2 Dec 2025 08:06:18 +0000 Subject: [PATCH 114/338] Revert "Refactor: Remove unnecessary ConstantVariable wrapping in raise_observed_exception (#168337)" This reverts commit fb5be221a46b51bfc9509013b0d85bc5a9d4f15b. Reverted https://github.com/pytorch/pytorch/pull/168337 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail some dynamo tests in trunk ([comment](https://github.com/pytorch/pytorch/pull/168337#issuecomment-3600738678)) --- torch/_dynamo/variables/functions.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 4f6301b1eb6c5..360c0fdd94488 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -210,7 +210,11 @@ def bind_args_cached( raise_observed_exception( TypeError, tx, - args=[f"Missing required positional argument: {name}"], + args=[ + ConstantVariable.create( + f"Missing required positional argument: {name}" + ) + ], ) # 2) *args @@ -222,7 +226,9 @@ def bind_args_cached( TypeError, tx, args=[ - f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" + ConstantVariable.create( + f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" + ) ], ) @@ -239,7 +245,11 @@ def bind_args_cached( raise_observed_exception( TypeError, tx, - args=[f"Missing required keyword-only argument: {name}"], + args=[ + ConstantVariable.create( + f"Missing required keyword-only argument: {name}" + ) + ], ) # 4) **kwargs @@ -249,7 +259,9 @@ def bind_args_cached( raise_observed_exception( TypeError, tx, - args=[f"Unexpected keyword arguments: {list(rem_kw)}"], + args=[ + ConstantVariable.create(f"Unexpected keyword arguments: {list(rem_kw)}") + ], ) return ba @@ -2982,7 +2994,7 @@ def call_function( if len(args) != 1: raise_type_error_exc( tx, - f"_get_node_type() takes 1 positional argument but {len(args)} were given", + f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", ) type_source = None if args[0].source: From 491731647f6b8a9345dcfb3bc9416aea254a7d96 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Mon, 1 Dec 2025 17:57:31 -0800 Subject: [PATCH 115/338] [Dynamo][Guards]Fix TLParse CPP guard message with sorting get_leaf_guards and verbose_code_parts (#169102) Fix #168379. 1. The results are validated in the improved testing that the ``___dict_contains`` will be sorted based on the verbose part. The first solution was also suggested in https://fb.workplace.com/groups/1075192433118967/permalink/1650742858897252/ by sorting the ``get_leaf_guards()`` in ``construct_manager_string``. 2. The second solution will be adopted the ``OrderedSet`` in setGuards during guards construction to make sure the ``contain_dict`` are displayed as the order of being added. We decided to pursuit the second options to reduce the sorting time overhead and simplicity. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169102 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 31 +++++++++++++++---------------- torch/_dynamo/guards.py | 4 ++-- torch/_guards.py | 24 +++++++++++++----------- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index a03537ad7d186..842355b57b94a 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1225,30 +1225,29 @@ def fn(x, y): # Filter out id-matches that won't reproduce run to run guard_code = filter( lambda line: "id" not in line and "lookup_backend" not in line, - sorted(guard_code), + guard_code, ) guard_code_str = "\n".join(guard_code) - for line in """\ -2 <= L['x'].size()[0] -L['x'] is L['y'] -L['x'].ndimension() == 2 -L['x'].requires_grad == False + # Make sure that the dict_contains are present in the order of added + self.assertExpectedInline( + guard_code_str, + """\ L['x'].size()[1] == L['x'].size()[0] L['x'].storage_offset() == 0 -___dict_contains('operator', G['sys'].modules) -___dict_contains('operator', G['sys'].modules) +2 <= L['x'].size()[0] +utils_device.CURRENT_DEVICE == None +str(L['x'].dtype) == 'torch.float32' +str(L['x'].device) == 'cpu' +L['x'].requires_grad == False +L['x'].ndimension() == 2 hasattr(L['x'], '_dynamo_dynamic_indices') == False +L['x'] is L['y'] not ___dict_contains('aaaaaaaa', G['sys'].modules) not ___dict_contains('bbbbbbbb', G['sys'].modules) -not ___dict_contains('cccccccc', G['sys'].modules) -str(L['x'].device) == 'cpu' -str(L['x'].dtype) == 'torch.float32' -utils_device.CURRENT_DEVICE == None""".split("\n"): - self.assertIn( - line, - guard_code_str, - ) +___dict_contains('operator', G['sys'].modules) +not ___dict_contains('cccccccc', G['sys'].modules)""", + ) def test_fold(self): def fn(a): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 71ddfed60df02..756996fb3f0f5 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -3871,7 +3871,7 @@ def _ref(x: Any) -> Any: }, global_scope=global_scope_state, _guards=torch._guards.GuardsSet( - { + OrderedSet( dataclasses.replace( guard, obj_weakref=None, @@ -3879,7 +3879,7 @@ def _ref(x: Any) -> Any: create_fn=normalize_create_fn(guard.create_fn), ) for guard in sorted_guards - } + ) ), input_source_to_sizes_strides=pytree.tree_map( convert_int_to_concrete_values, diff --git a/torch/_guards.py b/torch/_guards.py index c9daab1e69e81..386872c4eecfb 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -14,10 +14,11 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, TYPE_CHECKING, TypeVar +from typing import Any, Generic, NamedTuple, Optional, TYPE_CHECKING, TypeVar import torch from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._traceback import CapturedTraceback, format_frame from torch.utils.weak import WeakTensorKeyDictionary @@ -487,16 +488,16 @@ class GuardsCheckpointState: The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext """ - dynamo_guards: set[Guard] = set() + dynamo_guards: OrderedSet[Guard] - def __init__(self, dynamo_guards: set[Guard]) -> None: + def __init__(self, dynamo_guards: OrderedSet[Guard]) -> None: self.dynamo_guards = dynamo_guards - def diff(self, other: GuardsCheckpointState) -> set[Guard] | None: + def diff(self, other: GuardsCheckpointState) -> Optional[OrderedSet[Guard]]: """ Produces a delta against another GuardsCheckpointState. - Returns None if no delta is found, otherwise, return a set() of mismatched + Returns None if no delta is found, otherwise, return an OrderedSet() of mismatched Guard type objects. """ r = self.dynamo_guards.difference(other.dynamo_guards) @@ -605,10 +606,11 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: # Like a Set[Guard] but will record the user stack on all guards at the # time they were installed at their destination class GuardsSet: - def __init__(self, inner: set[Guard] | None = None) -> None: + def __init__(self, inner: Optional[OrderedSet[Guard]] = None) -> None: if inner is None: - inner = set() - self.inner = inner + self.inner: OrderedSet[Guard] = OrderedSet() + else: + self.inner = inner def __iter__(self) -> Iterator[Guard]: return iter(self.inner) @@ -645,9 +647,9 @@ def remove_guards_with_source(self, source: Source) -> None: """Delete all guards that contains a given source""" from ._dynamo.source import is_from_source - self.inner = { + self.inner = OrderedSet( g for g in self.inner if not is_from_source(g.originating_source, source) - } + ) """ @@ -664,7 +666,7 @@ def __init__(self) -> None: self.aotautograd_guards: list[GuardEnvExpr] = [] def copy_graphstate(self) -> GuardsCheckpointState: - return GuardsCheckpointState(set(self.dynamo_guards.inner)) + return GuardsCheckpointState(OrderedSet(self.dynamo_guards.inner)) def restore_graphstate(self, state: GuardsCheckpointState) -> None: # NB: "steals" the passed in state From 5778f6ff894686a975a9a23645178ae4c87ad5dc Mon Sep 17 00:00:00 2001 From: Nikolay Beloborodov Date: Tue, 2 Dec 2025 10:44:57 +0000 Subject: [PATCH 116/338] [strobelight][gpusnoop] compress aoti stack (#169291) Summary: Compress aoti stack (replace full paths with filenames). Test Plan: ``` [nbeloborodov@devgpu031]~/fbsource/fbcode% strobe gpuevent --duration-ms=60000 --collect-kernel-events --kernel-sample-interval=0 --pids 1016951 Running "gpuevent" with run id -4456078642709746 and group_trace_id "" on hosts: ["::1"] Press Ctrl-C to stop the run > Queuing... (00:00:00.001) > Preparing... (00:00:04.055) > Profiling... (00:01:00.383) > Processing... (00:00:00.643) > Logging... (00:00:00.025) > Finished | Host | Return Code | Samples | Result Links | |------|-------------|---------|------------------------------------------------------------| | ::1 | SUCCESS | 4 | Raw samples: | | | | | https://fburl.com/scuba/strobelight_gpu/on_demand/zsglu6sc | | | | | | | | | | Run Details: | | | | | https://fburl.com/scuba/strobelight_runs/hmcuaz8u | ``` Differential Revision: D88005763 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169291 Approved by: https://github.com/yushangdi --- .../aoti_runtime/kernel_context_tls.h | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h index 3489494d77e4e..1001dac9cc68d 100644 --- a/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h +++ b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h @@ -1,5 +1,8 @@ #pragma once +#include +#include +#include #include #include @@ -8,9 +11,42 @@ namespace torch::aot_inductor { struct KernelContext { std::string kernel_name; std::string python_stack; + std::string compressed_python_stack; KernelContext(std::string name, std::string stack) - : kernel_name(std::move(name)), python_stack(std::move(stack)) {} + : kernel_name(std::move(name)), python_stack(std::move(stack)) { + compressed_python_stack = compress_python_stack(python_stack); + } + + KernelContext(const KernelContext&) = default; + KernelContext& operator=(const KernelContext&) = default; + KernelContext(KernelContext&&) = default; + KernelContext& operator=(KernelContext&&) = default; + + private: + static std::string compress_python_stack(const std::string& stack) { + namespace fs = std::filesystem; + char func[129]; + char path[1025]; + uint32_t line; + int ret; + std::string compressed_stack; + std::stringstream stream{stack}; + std::string str; + std::string fmt = "File \"%1024[^\"]\", line %u, in %128[^\n]\n"; + while (std::getline(stream, str)) { + ret = sscanf(str.c_str(), fmt.c_str(), path, &line, func); + if (ret == 3) { + compressed_stack += func; + compressed_stack += ' '; + compressed_stack += fs::path{path}.filename(); + compressed_stack += ':'; + compressed_stack += std::to_string(line); + compressed_stack += '\n'; + } + } + return compressed_stack; + } }; // Thread-local pointer From 892640e25aeefa8007c5af837214b4502b6b62a6 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 1 Dec 2025 11:38:23 -0800 Subject: [PATCH 117/338] Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496) Fixes https://github.com/pytorch/pytorch/issues/161660 This extends the `TORCH_STABLE_ONLY` stopgap added in https://github.com/pytorch/pytorch/pull/161658 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167496 Approved by: https://github.com/janeyx99, https://github.com/malfet, https://github.com/atalman --- setup.py | 63 +++++++++++++++++ .../libtorch_agnostic_2_10_extension/setup.py | 1 - .../torch_stable_test_extension/setup.py | 67 ------------------- .../torch_stable_test/__init__.py | 0 .../torch_stable_test/csrc/test_extension.cpp | 1 - .../torch_stable_test/test_torch_stable.py | 22 ------ torch/csrc/inductor/aoti_torch/c/shim.h | 6 +- 7 files changed, 66 insertions(+), 94 deletions(-) delete mode 100644 test/cpp_extensions/torch_stable_test_extension/setup.py delete mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py delete mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp delete mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py diff --git a/setup.py b/setup.py index 314f719ea67f0..f15e7bbdd0ac4 100644 --- a/setup.py +++ b/setup.py @@ -1089,6 +1089,60 @@ def check_pydep(importname: str, module: str) -> None: class build_ext(setuptools.command.build_ext.build_ext): + def _wrap_headers_with_macro(self, include_dir: Path) -> None: + """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). + + Excludes: + - torch/headeronly/* + - torch/csrc/stable/* + - torch/csrc/inductor/aoti_torch/c/ (only shim headers) + - torch/csrc/inductor/aoti_torch/generated/ + + This method is idempotent - it will not wrap headers that are already wrapped. + """ + header_extensions = (".h", ".hpp", ".cuh") + header_files = [ + f for ext in header_extensions for f in include_dir.rglob(f"*{ext}") + ] + + # Paths to exclude from wrapping (relative to include_dir) + exclude_dir_patterns = [ + "torch/headeronly/", + "torch/csrc/stable/", + "torch/csrc/inductor/aoti_torch/c/", + "torch/csrc/inductor/aoti_torch/generated/", + ] + + # Marker to detect if a header is already wrapped + wrap_start_marker = ( + "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + for header_file in header_files: + rel_path = header_file.relative_to(include_dir).as_posix() + + if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): + report(f"Skipping header: {rel_path}") + continue + + original_content = header_file.read_text(encoding="utf-8") + + # Check if already wrapped (idempotency check) + if original_content.startswith(wrap_start_marker): + report(f"Already wrapped, skipping: {rel_path}") + continue + + wrapped_content = ( + wrap_start_marker + + f"{original_content}" + + "\n#else\n" + + '#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."\n' + + "#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + header_file.write_text(wrapped_content, encoding="utf-8") + report(f"Wrapped header: {rel_path}") + def _embed_libomp(self) -> None: # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS build_lib = Path(self.build_lib) @@ -1256,6 +1310,15 @@ def run(self) -> None: super().run() + # Wrap headers with TORCH_STABLE_ONLY and TORCH_TARGET_VERSION guards + build_lib = Path(self.build_lib) + build_torch_include_dir = build_lib / "torch" / "include" + if build_torch_include_dir.exists(): + report( + "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" + ) + self._wrap_headers_with_macro(build_torch_include_dir) + if IS_DARWIN: self._embed_libomp() diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py index 7bc37ba238139..af3bccf33be03 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py @@ -35,7 +35,6 @@ def get_extension(): extra_compile_args = { "cxx": [ "-fdiagnostics-color=always", - "-DTORCH_STABLE_ONLY", "-DTORCH_TARGET_VERSION=0x020a000000000000", ], } diff --git a/test/cpp_extensions/torch_stable_test_extension/setup.py b/test/cpp_extensions/torch_stable_test_extension/setup.py deleted file mode 100644 index 062d466e7ae98..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/setup.py +++ /dev/null @@ -1,67 +0,0 @@ -import distutils.command.clean -import shutil -from pathlib import Path - -from setuptools import find_packages, setup - -from torch.utils.cpp_extension import BuildExtension, CppExtension - - -ROOT_DIR = Path(__file__).parent -CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc" - - -class clean(distutils.command.clean.clean): - def run(self): - # Run default behavior first - distutils.command.clean.clean.run(self) - - # Remove extension - for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"): - path.unlink() - # Remove build and dist and egg-info directories - dirs = [ - ROOT_DIR / "build", - ROOT_DIR / "dist", - ROOT_DIR / "torch_stable_test.egg-info", - ] - for path in dirs: - if path.exists(): - shutil.rmtree(str(path), ignore_errors=True) - - -def get_extension(): - extra_compile_args = { - "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"], - } - - sources = list(CSRC_DIR.glob("**/*.cpp")) - - return [ - CppExtension( - "torch_stable_test._C", - sources=sorted(str(s) for s in sources), - py_limited_api=True, - extra_compile_args=extra_compile_args, - extra_link_args=[], - ) - ] - - -setup( - name="torch_stable_test", - version="0.0", - author="PyTorch Core Team", - description="Test extension to verify TORCH_STABLE_ONLY flag", - packages=find_packages(exclude=("test",)), - package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]}, - install_requires=[ - "torch", - ], - ext_modules=get_extension(), - cmdclass={ - "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), - "clean": clean, - }, - options={"bdist_wheel": {"py_limited_api": "cp39"}}, -) diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp deleted file mode 100644 index c92d56da11ba3..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp +++ /dev/null @@ -1 +0,0 @@ -#include // This should trigger the TORCH_STABLE_ONLY error diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py deleted file mode 100644 index 5c5613bb5484e..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py +++ /dev/null @@ -1,22 +0,0 @@ -# Owner(s): ["module: cpp"] - -from pathlib import Path - -from torch.testing._internal.common_utils import ( - install_cpp_extension, - IS_WINDOWS, - run_tests, - TestCase, -) - - -if not IS_WINDOWS: - - class TestTorchStable(TestCase): - def test_setup_fails(self): - with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"): - install_cpp_extension(extension_root=Path(__file__).parent.parent) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 4fb746ea15271..2eda2b218e705 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -38,9 +38,9 @@ // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check -#include -#include -#include +#include +#include +#include #ifdef __cplusplus extern "C" { From 70d797a5fc109b20a517646fcaa819477cd0d485 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 2 Dec 2025 03:20:11 -0800 Subject: [PATCH 118/338] [dist] add reduce_scatter_out (#168260) Adding reduce_scatter_tensor_out to use in fx passes to efficiently decompose reduce_scatter without concatenation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168260 Approved by: https://github.com/wconstab --- .../test_c10d_functional_native.py | 16 +++++ torch/_inductor/comm_lowering.py | 12 ++++ torch/csrc/distributed/c10d/Functional.cpp | 58 ++++++++++++++++++- torch/csrc/distributed/c10d/Functional.hpp | 7 +++ torch/distributed/_functional_collectives.py | 11 ++++ torch/distributed/_tools/fake_collectives.py | 1 + 6 files changed, 104 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 473198e5421c5..b124315208af7 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -329,6 +329,22 @@ def test_reduce_scatter_tensor_single(self) -> None: assert output.eq(self.rank).all() assert output.completed + @skip_if_lt_x_gpu(2) + def test_reduce_scatter_tensor_out(self) -> None: + self._init_process_group() + + input = torch.tensor(self.ranks, device=self.device) + out = torch.tensor([-1], device=self.device) + w = torch.ops._c10d_functional.reduce_scatter_tensor_out( + input, + "avg", + self.world_size, + "default", + out=out, + ) + torch.ops._c10d_functional.wait_tensor(w) + assert out.eq(self.rank).all() + @skip_if_lt_x_gpu(2) def test_reduce_scatter_tensor_coalesced(self) -> None: self._init_process_group() diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index 5ec3d2bba7908..1f6cc5ee3e726 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -311,6 +311,18 @@ def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): group_name, ) + @register_comm_lowering(c10d.reduce_scatter_tensor_out) + def _reduce_scatter_tensor_out(inp, reduce_op, group_size, group_name, *, out): + ir._CollectiveKernel.create_inplace( + c10d.reduce_scatter_tensor_out.default, + inp, + reduce_op, + group_size, + group_name, + out=out, + ) + return out + @register_comm_lowering(c10d.reduce_scatter_tensor_coalesced) def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): return pytree.tree_map( diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index 16530f0e65028..c21c5f9129acb 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -203,6 +203,25 @@ std::vector reduce_scatter_tensor_coalesced( return outputs; } +static std::vector reduce_scatter_tensor_coalesced_out( + std::vector inputs, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string reduce_op, + int64_t group_size, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string group_name, + std::vector& outputs) { + c10d::ReduceScatterOptions opts; + opts.reduceOp = to_reduce_op(reduce_op); + + auto group = c10d::resolve_process_group(std::move(group_name)); + auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts); + for (const auto& tensor : outputs) { + c10d::register_work(tensor, work); + } + return outputs; +} + at::Tensor reduce_scatter_tensor( const at::Tensor& input, std::string reduce_op, @@ -220,6 +239,36 @@ at::Tensor reduce_scatter_tensor( inputs, std::move(reduce_op), group_size, std::move(group_name))[0]; } +at::Tensor reduce_scatter_tensor_out( + const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + std::string group_name, + at::Tensor& output) { + TORCH_CHECK(input.is_contiguous()); + if (input.is_complex()) { + TORCH_CHECK(output.is_complex()) + auto real_input = at::view_as_real(input); + std::vector inputs{std::move(real_input)}; + auto real_output = at::view_as_real(output); + std::vector outputs{std::move(real_output)}; + return at::view_as_complex(reduce_scatter_tensor_coalesced_out( + inputs, + std::move(reduce_op), + group_size, + std::move(group_name), + outputs)[0]); + } + std::vector inputs{std::move(input)}; + std::vector outputs{std::move(output)}; + return reduce_scatter_tensor_coalesced_out( + inputs, + std::move(reduce_op), + group_size, + std::move(group_name), + outputs)[0]; +} + at::Tensor all_to_all_single( const at::Tensor& input, c10::SymIntArrayRef _output_split_sizes, @@ -243,7 +292,7 @@ at::Tensor all_to_all_single( output_split_sizes.begin(), output_split_sizes.end(), int64_t(0)); auto output = input.new_empty(output_sizes); - auto group = c10d::resolve_process_group(group_name); + auto group = c10d::resolve_process_group(std::move(group_name)); auto work = group->alltoall_base( output, // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -332,6 +381,13 @@ TORCH_LIBRARY(_c10d_functional, m) { c10d::reduce_scatter_tensor), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); + m.def( + "reduce_scatter_tensor_out(Tensor input, str reduce_op, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)", + torch::dispatch( + c10::DispatchKey::CompositeExplicitAutograd, + c10d::reduce_scatter_tensor_out), + {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); + m.def( "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]", torch::dispatch( diff --git a/torch/csrc/distributed/c10d/Functional.hpp b/torch/csrc/distributed/c10d/Functional.hpp index 553ba296cc52c..9c0ccbe1b0f2c 100644 --- a/torch/csrc/distributed/c10d/Functional.hpp +++ b/torch/csrc/distributed/c10d/Functional.hpp @@ -58,6 +58,13 @@ C10_EXPORT at::Tensor reduce_scatter_tensor( int64_t group_size, std::string group_name); +C10_EXPORT at::Tensor reduce_scatter_tensor_out( + const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + std::string group_name, + at::Tensor& output); + C10_EXPORT at::Tensor all_to_all_single( const at::Tensor& input, at::SymIntArrayRef output_split_sizes, diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 9308a63d9e7c2..391facb10f508 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -1000,6 +1000,14 @@ def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name): return inp.new_empty(shape) +def _reduce_scatter_tensor_out_native_meta( + inp, reduce_op, group_size, group_name, *, out +): + shape = list(inp.size()) + shape[0] //= group_size + return inp.new_empty(shape) + + def _reduce_scatter_tensor_coalesced_native_meta( inputs, reduce_op, group_size, group_name ): @@ -1026,6 +1034,9 @@ def _reduce_scatter_tensor_coalesced_native_meta( "Meta", ) lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") +lib_impl.impl( + "reduce_scatter_tensor_out", _reduce_scatter_tensor_out_native_meta, "Meta" +) lib_impl.impl( "reduce_scatter_tensor_coalesced", _reduce_scatter_tensor_coalesced_native_meta, diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index 18bb1a02a0055..0ac0f8a764d3e 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -98,6 +98,7 @@ def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def _c10d_functional.all_reduce.default, _c10d_functional.all_gather_into_tensor.default, _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional.reduce_scatter_tensor_out.default, _c10d_functional.all_to_all_single.default, _c10d_functional_autograd.all_to_all_single.default, _c10d_functional.wait_tensor.default, From 285779b1621cf9f073a062b0889a642d200308d9 Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Tue, 2 Dec 2025 15:14:26 +0000 Subject: [PATCH 119/338] [Accelerator] Add Accelerator Capabilities API (#165631) # Motivation There are several issues related to the data type and precision that an accelerator supports (see #165038 and #143112). Sometimes, we have to check for these capabilities in the document, and then hard-code. This PR proposes a new unified API for users to check their accelerator capabilities. # Changes This PR creates a new data structure `DeviceCapability` containing the capabilities that an accelerator commonly has: - Supporting DataType (set to be supported as default): - `fp16`, `int32`, `complex` ... etc - Other capabilities (need to be discussed) To access the structure, this PR defines a new Python API in the Accelerator module -- `get_device_capability`. It takes `device` as an input and returns a dictionary containing the capabilities (now we have `supported_dtypes` as the key). # Usage ```python >>> import torch >>> import torch_openreg >>> torch.accelerator.get_device_capability('openreg:0') {'supported_dtypes': [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, torch.complex32, torch.complex64, torch.complex128, torch.bool, torch.qint8, torch.quint8, torch.qint32, torch.bfloat16, torch.quint4x2, torch.quint2x4, torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16, torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e5m2fnuz, torch.float8_e4m3fnuz, torch.uint16, torch.uint32, torch.uint64, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.int1, torch.int2, torch.int3, torch.int4, torch.int5, torch.int6, torch.int7, torch.float8_e8m0fnu, torch.float4_e2m1fn_x2]} ``` # TODO - So far, precision is the only capability to track, based on my knowledge. But we can find more capabilities in common, and the API should be designed for good extension. - It will support other in-tree accelerators, such as **cuda** and **mps**. - Clarify whether the capabilities are software or hardware supported. (By @guangyey ) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165631 Approved by: https://github.com/guangyey, https://github.com/albanD Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> Co-authored-by: Jiawei Li --- aten/src/ATen/DeviceAccelerator.cpp | 6 ++ aten/src/ATen/DeviceAccelerator.h | 5 ++ c10/core/DeviceCapability.h | 74 +++++++++++++++++++ c10/core/impl/DeviceGuardImplInterface.h | 26 +++++++ c10/core/impl/VirtualGuardImpl.h | 4 + .../torch_openreg/csrc/runtime/OpenRegGuard.h | 9 +++ .../torch_openreg/tests/test_device.py | 9 ++- torch/_C/__init__.pyi.in | 1 + torch/accelerator/__init__.py | 27 ++++++- torch/csrc/DeviceAccelerator.cpp | 19 +++++ 10 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 c10/core/DeviceCapability.h diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index aa9d6e6b1ce9b..efab9ec9c5927 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -130,6 +130,12 @@ c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) { impl.uncheckedSetDevice({device_type, device_index}); return impl.getDevice().index(); } + +c10::DeviceCapability getDeviceCapability(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + c10::impl::VirtualGuardImpl impl(device_type); + return impl.getDeviceCapability({device_type, device_index}); +} // NOLINTEND(bugprone-unchecked-optional-access) } // namespace at::accelerator diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index 2cc4cff7cd1f2..d24b42ca459e7 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -73,6 +74,10 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +// Get the device capability of the given device index. +TORCH_API c10::DeviceCapability getDeviceCapability( + c10::DeviceIndex device_index); + TORCH_API inline void emptyCache() { const auto device_type = getAccelerator(true).value(); at::getDeviceAllocator(device_type)->emptyCache(); diff --git a/c10/core/DeviceCapability.h b/c10/core/DeviceCapability.h new file mode 100644 index 0000000000000..e24f12614978a --- /dev/null +++ b/c10/core/DeviceCapability.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +constexpr size_t NUMBER_OF_DEVICE_CAPABILITIES = NumScalarTypes; + +// Generate bitfields for each scalar type +#define DEFINE_SCALAR_TYPE(_1, n) unsigned int has_##n : 1; + +// Generate enum indices for each scalar type +#define DEFINE_SCALAR_ENUM(_1, name) kIndex_##name, + +enum ScalarTypeIndex { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_ENUM) +}; + +/** + * @brief DeviceCapability represents the the common capabilities that all + * devices should support. + * + * This struct provides a compact way to represent the common capabilities that + * all devices should support. Includes the following capabilities: + * - Supported data types + * + * Purpose + * - Enable device-specific optimizations based on supported capabilities + * + * Contract + * + * Supported data types: + * - Each bitfield represents support for one device capability + * - Bit value 1 means the capability is supported, 0 means not supported + * - The struct is initialized with all capabilities enabled by default + * + * @note Adding New Capabilities + * + * 1. Define the new capability in the `DeviceCapability` struct + * 2. Update the support of the new capability in each accelerator + * implementation + * 3. Add the new capability to the returned PyObject Dictionary + */ +struct C10_API DeviceCapability { + union { + struct { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE) + }; + uint64_t capability_bits; // Allow direct bit manipulation + }; + + // Default constructor with all capabilities enabled. + DeviceCapability() + : capability_bits((1ULL << NUMBER_OF_DEVICE_CAPABILITIES) - 1) {} + + // Iterate supported ScalarTypes without allocating a vector + template + void forEachSupportedScalarType(F&& visitor) const { +#define VISIT_SCALAR_TYPE(_1, n) \ + if (has_##n) { \ + visitor(ScalarType::n); \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(VISIT_SCALAR_TYPE) + +#undef VISIT_SCALAR_TYPE + } +}; + +#undef DEFINE_SCALAR_ENUM +#undef DEFINE_SCALAR_TYPE +} // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index f9f67497c6315..00096584b9229 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -191,6 +192,15 @@ struct C10_API DeviceGuardImplInterface { */ virtual DeviceIndex deviceCount() const noexcept = 0; + /** + * Get the following capabilities of the current device: + * (1) Data type support + * Returns DeviceCapability object. + */ + virtual DeviceCapability getDeviceCapability(Device /*unused*/) const { + TORCH_CHECK(false, "Backend doesn't support getting device capabilities."); + } + /** * Return true if all the work previously enqueued on the stream for * asynchronous execution has completed running on the device. @@ -291,6 +301,22 @@ struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface { return 1; } + DeviceCapability getDeviceCapability(Device /*unused*/) const override { + DeviceCapability cap; + if constexpr (D == DeviceType::Meta) { + cap.capability_bits = 0; + // Meta only supports basic types for shape inference + // Byte, Char, Short, Int, Long, Float, Double, + // Bool, ComplexFloat, ComplexDouble + cap.capability_bits = (1ULL << kIndex_Byte) | (1ULL << kIndex_Char) | + (1ULL << kIndex_Short) | (1ULL << kIndex_Int) | + (1ULL << kIndex_Long) | (1ULL << kIndex_Float) | + (1ULL << kIndex_Double) | (1ULL << kIndex_ComplexFloat) | + (1ULL << kIndex_ComplexDouble) | (1ULL << kIndex_Bool); + } + return cap; + } + // Event-related functions void record( void** /*event*/, diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 3d259f5e390e3..0254c69baba00 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -57,6 +57,10 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { return impl_->deviceCount(); } + DeviceCapability getDeviceCapability(Device d) const override { + return impl_->getDeviceCapability(d); + } + // Event functions void record( void** event, diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h index 59bc2d5cdbff5..3c1c1193d3cdb 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -50,6 +51,14 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { return c10::Device(static_type, device_index); } + /** + * Get the device capability for a given device. + * By default, OpenReg has 2 same devices with the same capability. + */ + c10::DeviceCapability getDeviceCapability(c10::Device /*unused*/) const override { + return c10::DeviceCapability(); + } + /** * Set the current device to c10::Device. */ diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py index f925f15600ce7..9cb4a785d36e7 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py @@ -1,7 +1,7 @@ # Owner(s): ["module: PrivateUse1"] import torch -import torch_openreg # noqa: F401 +from torch.testing._internal.common_dtype import get_all_dtypes from torch.testing._internal.common_utils import run_tests, TestCase @@ -31,6 +31,13 @@ def test_invalid_device_index(self): with self.assertRaisesRegex(RuntimeError, "The device index is out of range"): torch.accelerator.set_device_index(2) + def test_device_capability(self): + capability = torch.accelerator.get_device_capability("openreg:0") + supported_dtypes = capability["supported_dtypes"] + expected_dtypes = get_all_dtypes(include_complex32=True, include_qint=True) + + self.assertTrue(all(dtype in supported_dtypes for dtype in expected_dtypes)) + if __name__ == "__main__": run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 532815d535d5e..520d07d487270 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2494,6 +2494,7 @@ def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails def _accelerator_getAccelerator() -> _device: ... def _accelerator_setDeviceIndex(device_index: _int) -> None: ... def _accelerator_getDeviceIndex() -> _int: ... +def _accelerator_getDeviceCapability(device_index: _int) -> dict[str, Any]: ... def _accelerator_setStream(Stream) -> None: ... def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e1a82aa63ce22..b0dfbe400bfbc 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -2,7 +2,8 @@ This package introduces support for the current :ref:`accelerator` in python. """ -from typing import Optional +from functools import cache +from typing import Any from typing_extensions import deprecated import torch @@ -25,6 +26,7 @@ "current_accelerator", "current_device_idx", # deprecated "current_device_index", + "get_device_capability", "current_stream", "device_count", "device_index", @@ -152,6 +154,29 @@ def current_device_index() -> int: """ +@cache +def get_device_capability(device: _device_t = None, /) -> dict[str, Any]: + r"""Return the capability of the currently selected device. + + Args: + device (:class:`torch.device`, str, int, optional): The device to query capabilities for + :ref:`accelerator` device type. If not given, + use :func:`torch.accelerator.current_device_index` by default. + + Returns: + dict[str, Any]: A dictionary containing device capability information. The dictionary includes: + - ``supported_dtypes`` (set(torch.dtype)): Set of PyTorch data types supported by the device + + Examples: + >>> # xdoctest: +SKIP("requires cuda") + >>> # Query capabilities for current device + >>> capabilities = torch.accelerator.get_device_capability("cuda:0") + >>> print("Supported dtypes:", capabilities["supported_dtypes"]) + """ + device_index = _get_device_index(device, optional=True) + return torch._C._accelerator_getDeviceCapability(device_index) + + def set_device_index(device: _device_t, /) -> None: r"""Set the current device index to a given device. diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 14e54851178f5..c6ffa893d95ae 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -33,6 +33,25 @@ void initModule(PyObject* module) { return at::accelerator::getDeviceIndex(); }); + m.def("_accelerator_getDeviceCapability", [](c10::DeviceIndex device_index) { + const auto device_type = at::accelerator::getAccelerator(true).value(); + torch::utils::maybe_initialize_device(device_type); + auto caps = at::accelerator::getDeviceCapability(device_index); + + py::dict dict; + + py::set dtype_set; + caps.forEachSupportedScalarType([&](c10::ScalarType dtype) { + THPDtype* thp_dtype = torch::getTHPDtype(dtype); + py::object dtype_obj = + py::reinterpret_borrow((PyObject*)thp_dtype); + dtype_set.add(dtype_obj); + }); + + dict["supported_dtypes"] = dtype_set; + return dict; + }); + m.def("_accelerator_setStream", [](c10::Stream stream) { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); From 174272c15fae553d8488140af931f7d8050a313f Mon Sep 17 00:00:00 2001 From: zhxchen17 Date: Tue, 2 Dec 2025 15:31:07 +0000 Subject: [PATCH 120/338] [precompile] Include inductor-generated Dynamo guards in AOT (#169239) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/169239 Approved by: https://github.com/jamesjwu, https://github.com/laithsakka --- torch/_dynamo/aot_compile.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py index 196c073d6df99..14309dbe15541 100644 --- a/torch/_dynamo/aot_compile.py +++ b/torch/_dynamo/aot_compile.py @@ -211,16 +211,15 @@ def new_guard_filter_fn( hooks.guard_filter_fn = new_guard_filter_fn fn, _ = convert_frame.get_traced_fn(model) - check_fn = graph_capture_output.build_guards( - fn.__code__, hooks=hooks, save=True, strict_error=True - ) - - assert check_fn.guards_state is not None backend_input = capture_output.backend_input assert backend_input is not None backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment] device_type = _graph_device_type(backend_input.graph_module.graph) + assert ( + backend_input.fake_mode.shape_env + is graph_capture_output.output_graph.shape_env + ) tracing_context = TracingContext(backend_input.fake_mode) tracing_context.tensor_to_context = backend_input.tensor_to_context with ( @@ -250,6 +249,12 @@ def new_guard_filter_fn( + f"from backend {compiler_fn}) does not implement SerializableCallable." ) + check_fn = graph_capture_output.build_guards( + fn.__code__, hooks=hooks, save=True, strict_error=True + ) + + assert check_fn.guards_state is not None + source_info = SourceInfo(inlined_sources=set()) for traced_code in graph_capture_output.traced_code: source_info.add_code(traced_code) From 9ff4a2ebc5762d46c73e46b1b523d7ff349fedfa Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 1 Dec 2025 14:33:33 -0800 Subject: [PATCH 121/338] Support AC in default partitioner when functionalization is enabled (#166610) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166610 Approved by: https://github.com/SherlockNoMad --- .../distributed/tensor/test_dtensor_export.py | 2 - test/dynamo/test_activation_checkpointing.py | 267 ++++++++++--- test/functorch/test_aotdispatch.py | 15 +- test/higher_order_ops/test_local_map.py | 4 +- .../_aot_autograd/functional_utils.py | 20 +- .../_aot_autograd/graph_capture_wrappers.py | 5 + torch/_functorch/partitioners.py | 368 +++++++++++------- torch/_higher_order_ops/local_map.py | 7 + 8 files changed, 484 insertions(+), 204 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index bd75668ab4856..4a88cf9a6e0b1 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import contextlib -import unittest import torch import torch.distributed as dist @@ -357,7 +356,6 @@ def test_export_parallelize_module_with_dtensor_input( # aot_export_joint_with_descriptors on strict-exported exported_program.module() # is producing a joint graph with backward region missing - @unittest.expectedFailure def test_strict_export_parallelize_module_with_dtensor_input(self): self._run_test(strict_export_and_aot_export_joint_with_descriptors) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 0d32a9e4917f5..768555efd1d4c 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -15,7 +15,7 @@ import torch.distributed as dist import torch.nn as nn import torch.utils.checkpoint -from functorch.compile import min_cut_rematerialization_partition +from functorch.compile import default_partition, min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch._dynamo.testing import ( AotEagerAndRecordGraphs, @@ -24,7 +24,7 @@ ) from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu +from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor @@ -281,7 +281,14 @@ def runtime_wrapper(*runtime_args): run(export_compiler) - def test_tags_function(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -297,11 +304,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_via_global_checkpoint(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_via_global_checkpoint(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -316,17 +334,28 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_with_kwargs(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_with_kwargs(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): return torch.utils.checkpoint.checkpoint( - gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False + gn, torch.sin(x), y, use_reentrant=False ) x = torch.randn(4, 4, device=device, requires_grad=True) @@ -336,11 +365,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_sequential_layers(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_sequential_layers(self, device, partition_fn): def gn(x): x = x.cos() for _ in range(3): @@ -361,11 +401,22 @@ def fn(x): freqs=[2, 18], ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default], ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_multiple_checkpoints(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_multiple_checkpoints(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -383,11 +434,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=6, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_module(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_module(self, device, partition_fn): class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -411,11 +473,22 @@ def fn(x): bw_compiler = functools.partial( count_ops, freq=1, op=torch.ops.aten.sigmoid.default ) - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_decomps(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_decomps(self, device, partition_fn): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -443,6 +516,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, + partition_fn=partition_fn, decompositions=lambda: import_module( "torch._inductor.compile_fx" ).select_decomp_table(), @@ -702,7 +776,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_recompute(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn): def context_fn_must_recompute_mm(): must_recompute_list = [ torch.ops.aten.mm.default, @@ -723,9 +804,9 @@ def context_fn_no_recompute_mm(): ), ) - def _test(context_fn, bw_compiler): + def _test(context_fn, bw_compiler, partition_fn): def gn(x): - return torch.sigmoid(torch.matmul(x, x)) + return torch.cos(torch.sin(torch.matmul(x, x) @ x)) def fn(x): return torch.utils.checkpoint.checkpoint( @@ -739,14 +820,14 @@ def fn(x): fw_compiler = functools.partial( count_ops, - freq=1, + freq=2, op=torch.ops.aten.mm.default, ) backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x) @@ -754,17 +835,19 @@ def fn(x): context_fn=context_fn_must_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3) + freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6) op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) _test( context_fn=context_fn_no_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=2, # 2 bwd mm ops per fwd matmul + freq=4, # 2 bwd mm ops per fwd matmul op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) def test_sac_with_partial_context_fn(self): @@ -801,7 +884,16 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_not_recompute_gemm( + self, device, partition_fn + ): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -841,15 +933,22 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization( - self, device + self, device, partition_fn ): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -889,7 +988,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, disable_functionalization=True, ) self._validate(fn, backend, x, y) @@ -897,7 +996,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_triton_kernel(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn): # Copy of the above test, but make sure that having a triton kernel in the # region does not error. def add_one(x): @@ -957,14 +1063,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_tensor_subclass(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1007,14 +1120,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_custom_rule(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn): def _get_custom_policy(meta): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1072,14 +1192,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_partial_ctx_fn(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn): def selective_checkpointing_context_fn(no_recompute_list): return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) @@ -1118,14 +1245,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_outplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1163,14 +1297,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_list_ops(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_list_ops(self, device, partition_fn): def selective_checkpointing_context_fn(): # recompute everything no_recompute_list = [] @@ -1206,7 +1347,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1217,7 +1358,14 @@ def fn(x, y): "requires TorchDispatchMode + torch.compile work to complete" ) @requires_cuda_and_triton - def test_compile_selective_checkpoint_inplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1257,7 +1405,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1265,7 +1413,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) - def test_compile_selective_checkpoint_random_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_random_op(self, device, partition_fn): for preserve_rng_state in [True, False]: def selective_checkpointing_context_fn(): @@ -1312,7 +1467,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, @@ -1324,7 +1479,14 @@ def fn(x): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_invalid_context(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_invalid_context(self, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) * y @@ -1353,7 +1515,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) with self.assertRaisesRegex( Exception, "must generate a tuple of two `TorchDispatchMode`s" @@ -1362,7 +1524,14 @@ def fn(x, y): @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) - def test_compile_selective_checkpoint_parametrization(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_parametrization(self, partition_fn): def sac_policy(): def _recomp_policy(): def _custom_policy(ctx, func, *args, **kwargs): @@ -1425,7 +1594,9 @@ def reset_parameters(self): bw_compiler = functools.partial( count_ops, freqs=[ - 2, # 1 from mul recompute, 1 from mul backward + # 1 from mul recompute, 1 from mul backward + # w/o CSE, we have one extra mul + 3 if partition_fn is default_partition else 2, 1, ], ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], @@ -1434,7 +1605,7 @@ def reset_parameters(self): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) model = MLPModule() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 6cae42d8929da..c452f18e95d75 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2640,7 +2640,7 @@ def backward(ctx, grad_output): return grad_output * x, grad_output * x def f(a, b): - return FwBwMutation.apply(a, b) + return FwBwMutation.apply(a, b).sin_().clone() inps = [ torch.ones(3, 3, requires_grad=True), @@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2): add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None - return (mul, add)""", + clone = torch.ops.aten.clone.default(mul) + sin_ = torch.ops.aten.sin_.default(mul); mul = None + clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None + return (clone_1, add, clone)""", ) # important bit: there is 1 mutation in the bw self.assertExpectedInline( bw_graph[0].code.strip(), """\ -def forward(self, add, tangents_1): +def forward(self, add, clone, tangents_1): + cos = torch.ops.aten.cos.default(clone); clone = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None - mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None - return (mul_1, None)""", + mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None + return (mul_2, None)""", ) def test_fw_bw_mutation_no_functionalization2(self): diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index a585f2055e89f..7b5f01d236e7f 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -911,8 +911,8 @@ def inputs_fn(): op="call_function", target=torch.ops.aten.mm.default ) self.assertEqual(len(mm_nodes), 4) - self.assertNotIn("partitioner_tag", mm_nodes[0].meta) - self.assertNotIn("partitioner_tag", mm_nodes[1].meta) + self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward") + self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward") self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index fcbf861e537db..5af4fc9ee1195 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -10,6 +10,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional import torch from torch import Tensor @@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg): # Returns the number of detected copy_ -def assert_functional_graph(fx_g: torch.fx.Graph) -> int: +def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: allowed_mutation_ops = [ torch.ops.aten.copy_.default, torch.ops.aten.set_.source_Tensor, @@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # NB: It would also be nice to verify that the mutations all happen at the # end, but we also do some administrative views after mutations so this # isn't actually true. (TODO: Could this cause problems for Inductor?) + error = None for n in fx_g.nodes: if n.op == "placeholder": placeholders.add(n) @@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # this is mostly a hack to avoid failing XLA tests. # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 if "set_buffer_donor_" not in str(n.args[0]): - assert n.args[0] in placeholders, ( - f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" - ) + if n.args[0] not in placeholders: + error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" mutation_count += 1 else: - assert not n.target._schema.is_mutable, ( - f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" - ) + if n.target._schema.is_mutable: + error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + return error, mutation_count + + +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: + error, mutation_count = _is_functional_graph(fx_g) + assert error is None, error return mutation_count diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py index bc4dc87ddeced..2ef84cb488604 100644 --- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -27,6 +27,7 @@ from torch._prims_common import CUDARngStateHelper from torch.fx.experimental.proxy_tensor import ( _proxy_tensor_disable_update_tensor_tracker, + get_proxy_mode, maybe_disable_thunkify, maybe_enable_thunkify, ) @@ -295,6 +296,10 @@ def inner_fn( (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs( fn, primals ) + mode = get_proxy_mode() + assert mode is not None, "Expected non-None proxy mode" + for node in mode.tracer.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" # TODO: I think this hook can also be eliminated now if joint_fn_handle and joint_fn_handle.post_forward: diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index c273ba39ce167..f98aca82fe328 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -10,6 +10,7 @@ import os import os.path import re +import warnings from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass, replace @@ -52,6 +53,7 @@ ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput +from ._aot_autograd.functional_utils import _is_functional_graph from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems @@ -298,6 +300,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "is_backward" +def _has_tag_is_forward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "is_forward" + + def _has_tag_must_be_in_forward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_forward" @@ -1022,69 +1028,91 @@ def default_partition( Returns: Returns the generated forward and backward Fx graph modules. """ - if has_recomputable_ops(joint_module): - return min_cut_rematerialization_partition( - joint_module, - _joint_inputs, - num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) + # Respect the original placement of ops rather than rely on dataflow. + forward_nodes = [] + last_node = None + for node in joint_module.graph.nodes: + if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node): + last_node = node + assert last_node is not None + for node in joint_module.graph.nodes: + if not _is_tangent(node): + forward_nodes.append(node) + if node is last_node: + break forward_node_names = OrderedSet( - node.name for node in forward_only_graph.nodes if node.op != "output" + node.name for node in forward_nodes if node.op != "output" + ) + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + if _is_functional_graph(joint_module.graph)[0] is not None: + # Fall-back to previous behavior to avoid bc-breaking, although can + # eventually flip the switch to make this a hard error. + warnings.warn( + "Trying to unsafely apply AC to a non-functional graph with the " + "default partitioner. Falling back to min-cut partitioner." + ) + return min_cut_rematerialization_partition( + joint_module, + _joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True) + + if not config.unsafe_allow_optimization_of_collectives: + force_save_collectives(joint_module) + + force_save_bw_mutation_src(joint_module) + + if static_lifetime_input_indices is None: + static_lifetime_input_indices = [] + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs ) - order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + saved_values = [] saved_sym_nodes = [] - def is_mutated_later_in_fw(node): - if _has_tag_is_backward(node): - return False - tensor_arg_aliases = [ - x - for x in node.args - if isinstance(x, fx.Node) - and "val" in x.meta - and isinstance(x.meta["val"], torch.Tensor) - ] - while len(tensor_arg_aliases) > 0: - a = tensor_arg_aliases.pop() - for u in a.users: - if not isinstance(u.target, torch._ops.OpOverload): - continue - # If we witness a mutation on our node later, and that mutation is not "must be in backward", - # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values) - if ( - # one of the args was mutated - u.target._schema.is_mutable - # and the mutation happens "later" - and order[u] > order[node] - # and the mutation happened during the forward - and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u)) - ): - for idx, alias_info in enumerate(u.target._schema.arguments): - if alias_info.is_write and u.args[idx] is a: - return True - elif u.target.is_view: - tensor_arg_aliases.append(u) - return False + distributed_enabled = torch.distributed.is_available() + + def is_tensor(node): + return "tensor_meta" in node.meta or isinstance( + node.meta.get("val"), torch._subclasses.FakeTensor + ) + + def is_multi_output(node): + return ( + all(user.target == operator.getitem for user in node.users) + and len(node.users) > 0 + ) + + def is_impure(node): + # wait tensor is an "impure" op according to DCE's definition of impure + # (see is_impure in torch/fx/node.py), but it survives past + # functionalization and can be safely dup'd and reordered under the + # assumption SPMD. + return ( + node.is_impure(impure_random=False) + and node.op + not in ( + "placeholder", + "output", + ) + and ( + not distributed_enabled + or node.target is not torch.ops._c10d_functional.wait_tensor.default + ) + ) for node in joint_module.graph.nodes: if node.name not in forward_node_names: - # if a node isn't "required" to be in the forward, but any of its arguments - # are later mutated in the forward, then it must have been run in the forward - # (if not, and the node's arg was saved for backward, we would have mutated a saved value) - # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated - if is_mutated_later_in_fw(node): - saved_values.append(node) + continue + if node.op == "get_attr" and node.name in ( + k for k, v in joint_module.named_modules() + ): continue if node.target is torch.ops.aten._assert_scalar.default: continue @@ -1092,37 +1120,48 @@ def is_mutated_later_in_fw(node): # Symints must be kept separate from tensors so that PythonFunction only calls # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes.append(node) - elif ( - "tensor_meta" not in node.meta - and node.op == "call_function" - and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) - ): - # Since we can't save tuple of tensor values, we need to flatten out what we're saving - users = node.users - assert all(user.target is operator.getitem for user in users) - saved_values.extend(users) - else: - backward_usages = [ - n for n in node.users if n.name not in forward_node_names - ] - if "tensor_meta" in node.meta and all( - is_sym_node(n) for n in backward_usages - ): - # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, - # and not the actual tensor data, - # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. - # - # Note that saving the tensor could also cause compilation problems: - # If the user mutated an input in the forward and uses its sizes/strides in the backward, - # then we would be obligated to clone the input before saving it to appease autograd. - # (This is how we originally found this bug). - saved_sym_nodes.extend(backward_usages) - else: - saved_values.append(node) + continue + if is_multi_output(node): + # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE. + continue + if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE: + saved_values.append(node) + continue + if is_impure(node): + assert not graph_has_recomputable_ops, ( + "Trying to apply AC on a graph with impure op", + node, + node.target, + ) + saved_values.append(node) + continue + assert is_tensor(node) or node.op != "call_function", ( + f"Expected {node} to be a tensor" + ) + backward_usages = [n for n in node.users if n.name not in forward_node_names] + if all(is_sym_node(n) for n in backward_usages): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). + saved_sym_nodes.extend(backward_usages) + continue + if not must_recompute(node): + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) - return _extract_fwd_bwd_modules( + if config._sync_decision_cross_ranks: + saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values) + + if static_lifetime_input_nodes is None: + static_lifetime_input_nodes = node_info.static_lifetime_input_nodes + fw_module, bw_module = _extract_fwd_bwd_modules( joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, @@ -1130,6 +1169,31 @@ def is_mutated_later_in_fw(node): static_lifetime_input_nodes=static_lifetime_input_nodes, ) + # Run DCE while overriding the definition of is_impure_node + def is_not_collective(node): + return getattr(node.target, "namespace", None) != "_c10d_functional" + + fw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) + bw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) + + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + # raise all getitem ops to as early as possible + # this is helpful for memory, especially in the case of aot_eager backend + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + + fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) + if len(node_info.required_bw_nodes) > 0: + bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) + + return fw_module, bw_module + INT_INF = int(1e6) @@ -1624,7 +1688,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: break -def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: +def is_getitem_of_multi_output(node): + if node.target != operator.getitem: + return False + parent = node.args[0] + return "tensor_meta" not in parent.meta and node.op == "call_function" + + +def cleanup_recompute_tags( + joint_module: fx.GraphModule, *, is_default_partition: bool +) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in between, we would still want to stash the tensor at the boundary of @@ -1661,6 +1734,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` # in forward graph outputs. With this, we can break the above circular dependency. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + elif ( + "ac_graph_id" not in node.meta + and any(must_recompute(user) for user in node.users) + and not ( + # Avoid saving getitem nodes which are not labeled with "ac_graph_id" + is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta + ) + and is_default_partition + ): + # This node is not part of the AC region and a user is marked as recompute. + # This means it's an input to the AC region and we should save it. + # For ease of landing, gate this to default partitioner only, but we should think + # about flipping the switch in general as well. + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE return joint_module @@ -2770,6 +2857,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward): return module +def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + elif _must_be_in_backward(node): + required_bw_nodes.add(node) + + if node in required_bw_nodes: + required_bw_nodes.update(node.users) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( + _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" + ) + required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + ) + unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + ) + static_lifetime_input_nodes = OrderedSet( + p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices + ) + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, + required_fw_nodes, + required_bw_nodes, + unclaimed_nodes, + fw_order, + static_lifetime_input_nodes, + ) + + def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, @@ -2818,68 +2958,16 @@ def min_cut_rematerialization_partition( graph_has_recomputable_ops = has_recomputable_ops(joint_module) graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module) + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False) if not config.unsafe_allow_optimization_of_collectives: force_save_collectives(joint_module) force_save_bw_mutation_src(joint_module) - def classify_nodes(joint_module, static_lifetime_input_indices): - name_to_node = get_name_to_node(joint_module.graph) - required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() - for node in joint_module.graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - elif _must_be_in_backward(node): - required_bw_nodes.add(node) - - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list( - filter(_is_fwd_seed_offset, joint_module.graph.nodes) - ) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) - required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - ) - unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - ) - static_lifetime_input_nodes = OrderedSet( - p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices - ) - fw_cnt = 0 - fw_order = {} - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - fw_order[node] = fw_cnt - fw_cnt += 1 - return NodeInfo( - inputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - fw_order, - static_lifetime_input_nodes, - ) - if static_lifetime_input_indices is None: static_lifetime_input_indices = [] - node_info = classify_nodes(joint_module, static_lifetime_input_indices) + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs + ) # networkx blows up on graphs with no required backward nodes # Since there's nothing to partition anyway, and the default partitioner can "handle" diff --git a/torch/_higher_order_ops/local_map.py b/torch/_higher_order_ops/local_map.py index 7970acbc5d6ad..1d4ad631ea102 100644 --- a/torch/_higher_order_ops/local_map.py +++ b/torch/_higher_order_ops/local_map.py @@ -334,6 +334,13 @@ def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]: static_lifetime_input_indices=[], ) + # Fix tags because min-cut does not respect fw/bw boundary, breaking + # default partitioner's assumptions. + for node in new_fw_gm.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" + for node in new_bw_gm.graph.nodes: + node.meta["partitioner_tag"] = "is_backward" + # Propagate meta onto fw/bw graphs, later will be set on proxied nodes new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs} From 44ac69388a4a5eb463dbd2a13f00d1e3b924566c Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Mon, 1 Dec 2025 10:50:18 -0800 Subject: [PATCH 122/338] Add DCE pass to remove unused intermediates (#169131) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169131 Approved by: https://github.com/anijain2305 --- test/dynamo/test_activation_checkpointing.py | 18 +- test/export/test_export.py | 22 +-- test/higher_order_ops/test_invoke_subgraph.py | 55 ++++++ torch/_dynamo/dce_extra_outputs.py | 187 ++++++++++++++++++ torch/_dynamo/output_graph.py | 4 + 5 files changed, 263 insertions(+), 23 deletions(-) create mode 100644 torch/_dynamo/dce_extra_outputs.py diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 768555efd1d4c..064cf606182f9 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -1953,24 +1953,24 @@ def forward(self, L_x_: "f32[4, 4]"): wrap_body_0 = self.wrap_body_0 tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None - out1: "f32[4, 4]" = tag_activation_checkpoint[0] - out2: "f32[4, 4]" = tag_activation_checkpoint[1] - getitem_4: "f32[4, 4]" = tag_activation_checkpoint[4]; tag_activation_checkpoint = None + getitem_6: "f32[4, 4]" = tag_activation_checkpoint[0] + getitem_7: "f32[4, 4]" = tag_activation_checkpoint[1] + getitem_8: "f32[4, 4]" = tag_activation_checkpoint[2]; tag_activation_checkpoint = None - add: "f32[4, 4]" = out1 + out2; out1 = out2 = None - return (add, getitem_4) + add: "f32[4, 4]" = getitem_6 + getitem_7; getitem_6 = getitem_7 = None + return (add, getitem_8) class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[4, 4]"): matmul: "f32[4, 4]" = torch.matmul(l_x_, l_x_) - o: "f32[4, 4]" = matmul @ l_x_ + o: "f32[4, 4]" = matmul @ l_x_; matmul = None out: "f32[4, 4]" = l_x_.sin() - sin_1: "f32[4, 4]" = torch.sin(o) - cos: "f32[4, 4]" = torch.cos(sin_1) + sin_1: "f32[4, 4]" = torch.sin(o); o = None + cos: "f32[4, 4]" = torch.cos(sin_1); sin_1 = None sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None - return (cos, sin_2, matmul, o, out, sin_1) + return (cos, sin_2, out) """, ) diff --git a/test/export/test_export.py b/test/export/test_export.py index 6ebed4f224643..1e1f40fba99df 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1235,14 +1235,8 @@ def forward(self, x): %p_block_linear2_bias : [num_users=1] = placeholder[target=p_block_linear2_bias] %x : [num_users=1] = placeholder[target=x] %wrap_body0 : [num_users=1] = get_attr[target=wrap_body0] - %tag_activation_checkpoint : [num_users=7] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {}) + %tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {}) - %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 1), kwargs = {}) - %getitem_2 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 2), kwargs = {}) - %getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 3), kwargs = {}) - %getitem_4 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 4), kwargs = {}) - %getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 5), kwargs = {}) - %getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 6), kwargs = {}) return (getitem,)""", ) @@ -1251,14 +1245,14 @@ def forward(self, x): """\ graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] - %arg1_1 : [num_users=2] = placeholder[target=arg1_1] - %arg2_1 : [num_users=2] = placeholder[target=arg2_1] - %arg3_1 : [num_users=2] = placeholder[target=arg3_1] - %arg4_1 : [num_users=2] = placeholder[target=arg4_1] - %linear : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {}) - %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {}) + %arg1_1 : [num_users=1] = placeholder[target=arg1_1] + %arg2_1 : [num_users=1] = placeholder[target=arg2_1] + %arg3_1 : [num_users=1] = placeholder[target=arg3_1] + %arg4_1 : [num_users=1] = placeholder[target=arg4_1] + %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {}) + %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {}) %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %arg3_1, %arg4_1), kwargs = {}) - return (linear_1, arg1_1, arg2_1, linear, relu, arg3_1, arg4_1)""", + return (linear_1,)""", ) stack = contextlib.ExitStack() diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index c8a4ac1b67a84..67c4fa0757769 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -844,6 +844,61 @@ def forward(self, arg0_1: "f32[4]"): """, ) + def test_dce_recursive(self): + def fn1(x): + a = torch.sin(x) + _ = torch.cos(x) # unused intermediate + return a + + @nested_compile_region + def fn1_checkpoint(x): + return torch.utils.checkpoint.checkpoint(fn1, x, use_reentrant=False) + + def fn(x): + return fn1_checkpoint(x).detach() + + x = torch.randn(8, requires_grad=True) + + with torch._dynamo.config.patch( + skip_fwd_side_effects_in_bwd_under_checkpoint=True + ): + backend = EagerAndRecordGraphs() + torch.compile(fn, backend=backend, fullgraph=True)(x) + + if not TEST_WITH_CROSSREF: + # Verify that DCE applied recursively: + # - invoke_subgraph subgraph should be DCE'd + # - nested tag_activation_checkpoint subgraph should also be DCE'd (requires recursion) + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8]"): + l_x_ = L_x_ + + subgraph_0 = self.subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_); subgraph_0 = l_x_ = None + getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + + detach: "f32[8]" = getitem.detach(); getitem = None + return (detach,) + + class subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8]"): + wrap_body_0 = self.wrap_body_0 + tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None + getitem_2: "f32[8]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None + return (getitem_2,) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[8]"): + a: "f32[8]" = torch.sin(l_x_) + + _: "f32[8]" = torch.cos(l_x_); l_x_ = _ = None + return (a,) +""", + ) + def test_nonlocal_update(self): counter = 2 diff --git a/torch/_dynamo/dce_extra_outputs.py b/torch/_dynamo/dce_extra_outputs.py new file mode 100644 index 0000000000000..0c9342902ab2e --- /dev/null +++ b/torch/_dynamo/dce_extra_outputs.py @@ -0,0 +1,187 @@ +""" +DCE pass for unused extra outputs in HOP subgraphs. + +When enable_side_effects_with_extra_outputs is True, HOPs like invoke_subgraph, +checkpoint (tag_activation_checkpoint), and autograd.Function (autograd_function_apply) +return all intermediate tensors/symints as extra outputs to support side effects. +However, many of these extra outputs may not actually be used in the parent graph. + +Special handling for autograd_function_apply: +- The forward subgraph MUST return (output, saved_values, ...) where indices 0 and 1 + are always required by the runtime +- Only indices 2+ (extra intermediates) can be removed by DCE + +This pass removes unused extra outputs by: +1. Identifying which outputs of HOP calls are actually used +2. Removing unused outputs from the subgraph's output node +3. Updating the HOP call to reflect the new output arity +4. Updating getitem indices to account for removed outputs +""" + +import collections +import operator + +import torch + + +# HOPs that may have extra outputs that can be DCE'd +_HOPS_WITH_EXTRA_OUTPUTS = { + torch.ops.higher_order.invoke_subgraph, + torch.ops.higher_order.tag_activation_checkpoint, + # torch.ops.higher_order.autograd_function_apply, +} + + +def dce_hop_extra_outputs(gm: torch.fx.GraphModule) -> bool: + """ + Remove unused extra outputs from HOP calls recursively. + + Processes graphs top-down: first DCE the current graph's HOP outputs, + then recursively process nested subgraphs. This ensures that when we + process a nested subgraph, the parent has already removed unused getitems, + so the nested subgraph sees the correct usage information. + + Args: + gm: The GraphModule to optimize + + Returns: + True if any modifications were made, False otherwise + """ + modified = False + + # Group HOP nodes by subgraph name + # Multiple invocations may share the same subgraph, so we need to check + # which indices are used across ALL invocations before removing any + subgraph_to_nodes: dict[str, list[torch.fx.Node]] = collections.defaultdict(list) + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in _HOPS_WITH_EXTRA_OUTPUTS: + subgraph_attr = node.args[0] + if ( + isinstance(subgraph_attr, torch.fx.Node) + and subgraph_attr.op == "get_attr" + ): + subgraph_name = subgraph_attr.target + assert isinstance(subgraph_name, str) + subgraph_to_nodes[subgraph_name].append(node) + + # STEP 1: DCE this graph's HOP outputs first (top-down) + for subgraph_name, hop_nodes in subgraph_to_nodes.items(): + if _dce_subgraph(gm, subgraph_name, hop_nodes): + modified = True + + if modified: + gm.graph.lint() + gm.recompile() + + # STEP 2: Recursively process nested subgraphs + # After we've removed unused getitems from this graph, nested subgraphs + # will see the correct usage information + for subgraph_name in subgraph_to_nodes: + subgraph = getattr(gm, subgraph_name) + if isinstance(subgraph, torch.fx.GraphModule): + if dce_hop_extra_outputs(subgraph): + modified = True + + return modified + + +def _dce_subgraph( + gm: torch.fx.GraphModule, subgraph_name: str, hop_nodes: list[torch.fx.Node] +) -> bool: + """ + DCE a single subgraph by removing unused output indices. + """ + subgraph = getattr(gm, subgraph_name) + + if not isinstance(subgraph, torch.fx.GraphModule): + return False + + # Collect used indices for THIS subgraph + used_indices: set[int] = set() + + # Check if this is the forward subgraph of autograd_function_apply + # For autograd_function_apply, the fwd subgraph must return (output, saved_values, ...) + # where indices 0 and 1 are ALWAYS required by the runtime + # is_autograd_fwd = any( + # node.target == torch.ops.higher_order.autograd_function_apply + # for node in hop_nodes + # ) + is_autograd_fwd = False + + for hop_node in hop_nodes: + for user in list(hop_node.users): + if user.op == "call_function" and user.target == operator.getitem: + if len(list(user.users)) > 0: + idx = user.args[1] + assert isinstance(idx, int) + used_indices.add(idx) + + output_node = next(n for n in subgraph.graph.nodes if n.op == "output") + old_outputs = list(output_node.args[0]) + + # For autograd_function_apply forward subgraph, indices 0 (output) and 1 (saved_values) + # are ALWAYS used by the runtime, even if not explicitly accessed via getitem + if is_autograd_fwd and len(old_outputs) >= 2: + used_indices.add(0) # output + used_indices.add(1) # saved_values + + # Nothing to DCE if all outputs are used or no outputs are used + if len(used_indices) >= len(old_outputs) or len(used_indices) == 0: + return False + + # Build mapping from old indices to new indices + old_to_new: dict[int, int] = {} + new_outputs = [] + new_idx = 0 + + for old_idx in range(len(old_outputs)): + if old_idx in used_indices: + old_to_new[old_idx] = new_idx + new_outputs.append(old_outputs[old_idx]) + new_idx += 1 + + # Update subgraph output node + # Create a new output node with the filtered outputs + with subgraph.graph.inserting_before(output_node): + new_output_node = subgraph.graph.output(tuple(new_outputs)) + output_node.replace_all_uses_with(new_output_node) + subgraph.graph.erase_node(output_node) + + for hop_node in hop_nodes: + # Update getitem nodes to use new indices + for user in list(hop_node.users): + if user.op == "call_function" and user.target == operator.getitem: + old_idx = user.args[1] + assert isinstance(old_idx, int) + if old_idx not in old_to_new: + assert len(list(user.users)) == 0 + gm.graph.erase_node(user) + continue + + new_idx = old_to_new[old_idx] + # Create a new getitem node with the new index + with gm.graph.inserting_before(user): + new_getitem = gm.graph.call_function( + operator.getitem, args=(user.args[0], new_idx) + ) + # Copy metadata from old node + new_getitem.meta = user.meta.copy() + user.replace_all_uses_with(new_getitem) + gm.graph.erase_node(user) + + # Update example_value metadata on hop_node + if "example_value" in hop_node.meta: + old_example = hop_node.meta["example_value"] + assert isinstance(old_example, (tuple, list)) + new_example = tuple( + old_example[old_idx] + for old_idx in range(len(old_outputs)) + if old_idx in used_indices + ) + hop_node.meta["example_value"] = new_example + + subgraph.graph.lint() + subgraph.recompile() + + return True diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 67c29e9f9c62c..b5e662cefec59 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -2142,6 +2142,10 @@ def compile_and_call_fx_graph( gm = _make_graph_module(root, self.graph) + from .dce_extra_outputs import dce_hop_extra_outputs + + dce_hop_extra_outputs(gm) + # Saved tensors hooks are not used by the graph. # GraphModule by default only copies used in the graph submodules. # Copying them into the result graph manually. From b4cc1329c86acaef6d42c1fac7169b8d870ab0d7 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 2 Dec 2025 16:21:04 +0000 Subject: [PATCH 123/338] [BE][Typing][Dynamo] Type torch/_dynamo/variables/nn_module.py (#167342) Provides type coverage to torch/_dynamo/variables/nn_module.py Coverage report: `mypy torch/_dynamo/variables/nn_module.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered to 1378 lines and 31 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/167342 Approved by: https://github.com/williamwen42 --- torch/_dynamo/output_graph.py | 1 + torch/_dynamo/variables/nn_module.py | 318 ++++++++++++++++----------- 2 files changed, 185 insertions(+), 134 deletions(-) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index b5e662cefec59..414051bcaa1d9 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1183,6 +1183,7 @@ def wrap_name(module_key: str) -> VariableTracker: # sourceless, so let's return a unspecializedNNModule variable # tracker. def wrap_name(module_key: str) -> VariableTracker: + # pyrefly: ignore[bad-argument-type] return variables.UnspecializedNNModuleVariable(target, **options) elif isinstance(target, (torch.SymInt, torch.SymFloat)): diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 4b5198ffe8533..525c42a009c1d 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This module implements variable tracking for PyTorch nn.Module instances during Dynamo tracing. @@ -28,10 +26,12 @@ import itertools import re import types +from collections.abc import Iterable, Sequence from contextlib import contextmanager, nullcontext -from typing import TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import torch.nn +from torch._guards import Source from .. import graph_break_hints, trace_rules, variables from ..exc import raise_observed_exception, unimplemented, UnspecializeRestartAnalysis @@ -75,7 +75,12 @@ from .constant import ConstantVariable -def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs): +def initialize_lazy_module( + tx: "InstructionTranslator", + mod: torch.nn.Module, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], +) -> None: """ Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable. @@ -85,11 +90,11 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs): """ if hasattr(mod, "_initialize_hook"): - def convert_to_fake(x): + def convert_to_fake(x: Any) -> Any: if is_namedtuple(x): return type(x)(*(convert_to_fake(elem) for elem in x)) elif isinstance(x, dict): - return {k: convert_to_fake(v) for k, v in x.items()} + return {k: convert_to_fake(v) for k, v in x.items()} # type: ignore[misc] elif isinstance(x, (list, tuple, set)): return type(x)(convert_to_fake(elem) for elem in x) elif isinstance(x, torch.fx.Proxy): @@ -101,7 +106,7 @@ def convert_to_fake(x): fake_args = [convert_to_fake(arg) for arg in proxy_args] fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()} try: - mod._infer_parameters(mod, fake_args, fake_kwargs) + mod._infer_parameters(mod, fake_args, fake_kwargs) # type: ignore[operator] except AttributeError as e: # Re-raise with the original error message from the AttributeError raise_observed_exception( @@ -114,7 +119,9 @@ def convert_to_fake(x): @contextmanager -def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): +def record_nn_module_stack( + module_key: str, source: Source, tx: "InstructionTranslator", mod: torch.nn.Module +) -> Any: fully_qualified_name = source.name() # Remove redundant namings fully_qualified_name = re.sub( @@ -132,7 +139,9 @@ def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): del tx.nn_module_stack[module_key] -def guard_to_detect_forward_monkeypatching(source, mod): +def guard_to_detect_forward_monkeypatching( + source: Optional[Source], mod: torch.nn.Module +) -> None: # Users sometimes patch the forward method of a nn module instance to # perform optimizations like quantization. Though this is not a good # software practice, but python allows this and Dynamo needs to detect @@ -175,41 +184,51 @@ class NNModuleVariable(VariableTracker): } def __init__( - self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs + self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs: Any ) -> None: super().__init__(**kwargs) self.module_type = module_type self.module_key = module_key self.value = value - assert self.source + # pyrefly: ignore[bad-override] + # NOTE: Don't remove this; better than adding suppressions + # everywhere else with asserts + self.source: Source = self.source self.nn_module_stack_source = self.source - def get_nn_module_stack_source(self): - return self.nn_module_stack_source or self.source + def get_nn_module_stack_source(self) -> Source: + res = self.nn_module_stack_source or self.source + assert res + return res - def set_nn_module_stack_source(self, source): + def set_nn_module_stack_source(self, source: Source) -> None: self.nn_module_stack_source = source - def python_type(self): + def python_type(self) -> type: return self.module_type def _wrap_submodule( - self, tx: "InstructionTranslator", source, submod, *key_extra, **options - ): + self, + tx: "InstructionTranslator", + source: Source, + submod: torch.nn.Module, + *key_extra: Any, + **options: Any, + ) -> None: return - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: # implement list/iter/tuple/etc calls base = tx.output.get_submodule(self.module_key) + result: list[VariableTracker] = [] if isinstance(base, torch.nn.ModuleDict): - result = [] for name, submod in base.items(): name_var = variables.ConstantVariable.create(name) tx.output.register_attr_or_module( submod, self.module_key, name, - source=NNModuleSource(GetItemSource(self.source, name)), + source=NNModuleSource(GetItemSource(self.source, name)), # type: ignore[arg-type] ) result.append(name_var) return result @@ -217,8 +236,6 @@ def unpack_var_sequence(self, tx): assert isinstance( base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential) ), typestr(base) - assert self.source - result = [] for idx, submod in enumerate(base): result.append( tx.output.register_attr_or_module( @@ -242,11 +259,11 @@ def call_obj_hasattr( ) return variables.ConstantVariable.create(result) - def is_training(self, tx): + def is_training(self, tx: "InstructionTranslator") -> bool: mod = tx.output.get_submodule(self.module_key) return getattr(mod, "training", False) - def convert_to_unspecialized(self, tx): + def convert_to_unspecialized(self, tx: "InstructionTranslator") -> None: """Restart analysis treating this module as an UnspecializedNNModuleVariable""" mod = tx.output.get_submodule(self.module_key) GenerationTracker.tag(mod) @@ -256,7 +273,7 @@ def convert_to_unspecialized(self, tx): GenerationTracker.mark_class_dynamic(type(mod)) raise UnspecializeRestartAnalysis - def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key: str) -> bool: base = tx.output.get_submodule(self.module_key) if object_has_getattribute(base): @@ -279,7 +296,13 @@ def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): base_dict = object.__getattribute__(base, "__dict__") return key in base_dict - def _custom_getattr_fallback(self, base, tx, name, obj_source): + def _custom_getattr_fallback( + self, + base: torch.nn.Module, + tx: "InstructionTranslator", + name: str, + obj_source: Source, + ) -> Optional[VariableTracker]: """Check for a __getattr__ and handle it specially if it is implemented""" if object_has_getattribute(base): unimplemented( @@ -318,11 +341,12 @@ def _custom_getattr_fallback(self, base, tx, name, obj_source): ) options = {"source": AttrSource(obj_source, "__getattr__")} + # pyrefly: ignore[bad-argument-type] return variables.UserMethodVariable(getattr_fn, self, **options).call_function( tx, [variables.ConstantVariable.create(name)], {} ) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: source = self.source and AttrSource(self.source, name) base = tx.output.get_submodule(self.module_key) @@ -345,6 +369,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): if name == "__dict__": return variables.GetAttrVariable(self, name, source=source) + subobj = None if name in base_dict: subobj = base_dict[name] elif ( @@ -382,7 +407,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): return variables.UserDefinedClassVariable(base.__class__, source=source) if object_member: - out = VariableTracker.build(tx, subobj, NNModuleSource(source)) + out = VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type] if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)): # nn_module_stack source is BC surface area. Ensure that @@ -401,7 +426,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): # Get the getter function source = AttrSource(source, "fget") return variables.UserFunctionVariable( - subobj.fget, + subobj.fget, # pyrefly: ignore[bad-argument-type] source=source, ).call_function(tx, [(self)], {}) elif istype(subobj, classmethod): @@ -412,13 +437,15 @@ def var_getattr(self, tx: "InstructionTranslator", name): ) elif istype(subobj, staticmethod): return variables.UserFunctionVariable( - subobj.__get__(base), source=source + # pyrefly: ignore[bad-argument-type] + subobj.__get__(base), + source=source, ) elif istype(subobj, types.FunctionType): return variables.UserMethodVariable(subobj, self, source=source) elif is_safe_constant(subobj) or istensor(subobj): # Support possibly common cases of class members - return VariableTracker.build(tx, subobj, NNModuleSource(source)) + return VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type] else: unimplemented( gb_type="Unsupported nn.Module attribute type", @@ -436,10 +463,10 @@ def var_getattr(self, tx: "InstructionTranslator", name): def call_function( self, - tx, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: mod = tx.output.get_submodule(self.module_key) with record_nn_module_stack( @@ -475,7 +502,7 @@ def call_function( submod, self.module_key, child_name, - source=NNModuleSource(AttrSource(self.source, child_name)), + source=NNModuleSource(AttrSource(self.source, child_name)), # type: ignore[arg-type] ), [arg], {}, @@ -486,7 +513,7 @@ def call_function( if is_lazy: # The module type will change after it is called if mod.cls_to_become is not None: - self.module_type = mod.cls_to_become + self.module_type = mod.cls_to_become # type: ignore[assignment] # The pre-hook runs to initialize the module shapes, then deletes itself. After this, # the module is more or less not lazy and can be treated as a normal module regardless of @@ -527,10 +554,6 @@ def call_function( ), ) else: - assert self.source, ( - "Must provide a valid source in order to inline, " - "since inlined function may have default args which must be guarded." - ) if isinstance(mod, torch.fx.GraphModule): # TODO: do we want to support __call__ for GM's? # If so at least some changes are needed, we don't allow inlining @@ -543,10 +566,11 @@ def call_function( if istype(fn, types.MethodType): fn = fn.__func__ fn_source = AttrSource(fn_source, "__func__") - args = [self] + args + args = [self] + list(args) else: assert istype(fn, types.FunctionType) return tx.inline_user_function_return( + # pyrefly: ignore[bad-argument-type] variables.UserFunctionVariable(fn, source=fn_source), args, kwargs, @@ -554,18 +578,18 @@ def call_function( def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - constant=False, - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + constant: bool = False, + ) -> VariableTracker: from . import ConstantVariable, ListIteratorVariable, TupleVariable key = self.module_key module = tx.output.get_submodule(key) - def generic_call_method_helper(name): + def generic_call_method_helper(name: str) -> VariableTracker: # Helper function to put a `call_method` node in FX graph, # with nn.Module as the first arg. mod_proxy = tx.output.create_proxy( @@ -605,7 +629,7 @@ def generic_call_method_helper(name): return generic_call_method_helper(name) if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( - inspect.getfile(module.__class__._check_input_dim) + inspect.getfile(module.__class__._check_input_dim) # type: ignore[union-attr] ): return ConstantVariable.create(True) @@ -620,10 +644,10 @@ def generic_call_method_helper(name): tx, f"``nn.Module`` {module}'s call method {name} requires a tuple as first argument", ) - mod_var = args[0].items[args[1].value] + mod_var = args[0].items[args[1].value] # type: ignore[attr-defined] if isinstance(mod_var, UnspecializedNNModuleVariable): return mod_var - key = mod_var.module_key + key = mod_var.module_key # type: ignore[attr-defined] submod = tx.output.get_submodule(key) return tx.output.register_attr_or_module( submod, @@ -637,7 +661,7 @@ def generic_call_method_helper(name): name = f"{module.__class__.__name__}_{name}_result" return invoke_and_store_as_constant(tx, fn, name, args, kwargs) - def assert_all_args_kwargs_const(): + def assert_all_args_kwargs_const() -> None: if not all( x.is_python_constant() for x in itertools.chain(args, kwargs.values()) ): @@ -649,7 +673,7 @@ def assert_all_args_kwargs_const(): hints=[], ) - def get_kwargs(*names): + def get_kwargs(*names: str) -> dict[str, Any]: assert_all_args_kwargs_const() fn = getattr(module, name) bound_args = inspect.signature(fn).bind( @@ -660,7 +684,9 @@ def get_kwargs(*names): bound_args = bound_args.arguments return {k: bound_args[k] for k in names} - def wrap_values(items): + def wrap_values( + items: Iterable[tuple[Any, Any]], + ) -> "variables.ListIteratorVariable": result = [] for name, submod in items: result.append( @@ -671,9 +697,11 @@ def wrap_values(items): source=NNModuleSource(gen_source(self.source, name)), ) ) - return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + return ListIteratorVariable( + named_children, mutation_type=ValueMutationNew() + ) - def named_embed(name, obj): + def named_embed(name: str, obj: Any) -> "variables.TupleVariable": return TupleVariable( [ ConstantVariable.create(name), @@ -686,7 +714,7 @@ def named_embed(name, obj): ] ) - def gen_source(source, name): + def gen_source(source: Source, name: str) -> Source: name_split = name.split(".") if name_split[0] == "": return source @@ -704,34 +732,40 @@ def gen_source(source, name): "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - result = [] + named_children: list[VariableTracker] = [] for name, submod in module.named_children(): - result.append(named_embed(name, submod)) - return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + named_children.append(named_embed(name, submod)) + return ListIteratorVariable( + named_children, mutation_type=ValueMutationNew() + ) elif name == "named_parameters": tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) - result = [] + named_parameters: list[VariableTracker] = [] for name, param in module.named_parameters( **get_kwargs("prefix", "recurse") ): - result.append(named_embed(name, param)) - return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + named_parameters.append(named_embed(name, param)) + return ListIteratorVariable( + named_parameters, mutation_type=ValueMutationNew() + ) elif name == "named_buffers": tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) - result = [] + named_buffers: list[VariableTracker] = [] for name, buffer in module.named_buffers( **get_kwargs("prefix", "recurse", "remove_duplicate") ): - result.append(named_embed(name, buffer)) - return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + named_buffers.append(named_embed(name, buffer)) + return ListIteratorVariable(named_buffers, mutation_type=ValueMutationNew()) elif name == "named_modules": tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) - result = [] + named_modules_list: list[VariableTracker] = [] for name, submod in module.named_modules( **get_kwargs("memo", "prefix", "remove_duplicate") ): - result.append(named_embed(name, submod)) - return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + named_modules_list.append(named_embed(name, submod)) + return ListIteratorVariable( + named_modules_list, mutation_type=ValueMutationNew() + ) elif name == "children": tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) if args or kwargs: @@ -760,8 +794,9 @@ def gen_source(source, name): f"{len(args)} args and {len(kwargs)} kwargs", ) result = [] - for name in module: - result.append(ConstantVariable.create(name)) + # pyrefly: ignore[not-iterable] + for tmp in module: + result.append(ConstantVariable.create(tmp)) return ListIteratorVariable(result, mutation_type=ValueMutationNew()) elif name == "values": if args or kwargs: @@ -771,7 +806,7 @@ def gen_source(source, name): "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - return wrap_values(module.items()) + return wrap_values(module.items()) # type: ignore[operator] elif name == "items": if args or kwargs: raise_args_mismatch( @@ -780,10 +815,10 @@ def gen_source(source, name): "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - result = [] - for name, submod in module.items(): - result.append(named_embed(name, submod)) - return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + items_result: list[VariableTracker] = [] + for name, submod in module.items(): # type: ignore[operator] + items_result.append(named_embed(name, submod)) + return ListIteratorVariable(items_result, mutation_type=ValueMutationNew()) elif name == "__len__": if args or kwargs: raise_args_mismatch( @@ -792,7 +827,7 @@ def gen_source(source, name): "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - return ConstantVariable.create(len(module)) + return ConstantVariable.create(len(module)) # type: ignore[arg-type] elif name == "__iter__": return ListIteratorVariable( self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() @@ -821,7 +856,7 @@ def gen_source(source, name): torch.nn.ParameterList.__getitem__, torch.nn.Sequential.__getitem__, ) - + # pyrefly: ignore[missing-attribute] if type(module).__getitem__ not in builtin_supported: if not ( isinstance(args[0], variables.ConstantVariable) @@ -840,15 +875,13 @@ def gen_source(source, name): assert isinstance(fn, types.FunctionType) - src = AttrSource(AttrSource(self.source, name), "__func__") + src = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type] return tx.inline_user_function_return( variables.UserFunctionVariable(fn, source=src), [self] + list(args), kwargs, ) - assert self.source - if isinstance(args[0], SliceVariable): # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is # enabled for export. @@ -857,8 +890,8 @@ def gen_source(source, name): result = [] # Turn the slice into the list of integers - keys = list(range(len(module)))[args[0].as_python_constant()] - for idx, submod in enumerate(module[args[0].as_python_constant()]): + keys = list(range(len(module)))[args[0].as_python_constant()] # type: ignore[arg-type] + for idx, submod in enumerate(module[args[0].as_python_constant()]): # type: ignore[arg-type] key = keys[idx] src = NNModuleSource(GetItemSource(self.source, key)) result.append( @@ -869,7 +902,7 @@ def gen_source(source, name): ) ) - new_module = module[args[0].as_python_constant()] + new_module = module[args[0].as_python_constant()] # type: ignore[index] new_module_variable = tx.output.register_attr_or_module( new_module, f"{self}.__getitem__(slice)", @@ -885,10 +918,11 @@ def gen_source(source, name): from .tensor import SymNodeVariable + key_value = 0 if isinstance(args[0], SymNodeVariable): - key = args[0].evaluate_expr(tx.output) + key_value = args[0].evaluate_expr(tx.output) elif args[0].is_python_constant(): - key = args[0].as_python_constant() + key_value = args[0].as_python_constant() else: unimplemented( gb_type="Unsupported key type for nn.Module.__getitem__", @@ -898,12 +932,12 @@ def gen_source(source, name): hints=[], ) - submod = module[key] + submod = module[key_value] # type: ignore[index] return tx.output.register_attr_or_module( submod, self.module_key, - key, - source=NNModuleSource(GetItemSource(self.source, key)), + key_value, + source=NNModuleSource(GetItemSource(self.source, key_value)), ) elif ( name == "_get_abs_string_index" @@ -918,10 +952,10 @@ def gen_source(source, name): ): # Inline the function fn = getattr(module, name).__func__ - fn_source = AttrSource(AttrSource(self.source, name), "__func__") + fn_source = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type] return tx.inline_user_function_return( variables.UserFunctionVariable(fn, source=fn_source), - [self] + args, + [self] + list(args), kwargs, ) # A loose heuristic, but seems to be generally good before we drop into the @@ -936,7 +970,7 @@ def gen_source(source, name): ): return generic_call_method_helper(name) else: - return super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, list(args), kwargs) class UnspecializedNNModuleVariable(UserDefinedObjectVariable): @@ -955,7 +989,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable): Giving one graph per module class. """ - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None: if type(value) is torch.jit._script.RecursiveScriptModule: unimplemented( gb_type="UnspecializedNNModuleVariable wrapped around ScriptModules unsupported", @@ -985,19 +1019,21 @@ def __init__(self, value, **kwargs) -> None: # nn_module_stack_source appropriately to resemble mod.linear. self.nn_module_stack_source = self.source - def _wrap_source(self, attr_source): + def _wrap_source(self, attr_source: Source) -> Source: # the vt is already wrapped with UnspecializedNNModuleSource return attr_source - def get_nn_module_stack_source(self): - return self.nn_module_stack_source or self.source + def get_nn_module_stack_source(self) -> Source: + res = self.nn_module_stack_source or self.source + assert res + return res - def set_nn_module_stack_source(self, source): + def set_nn_module_stack_source(self, source: Source) -> None: self.nn_module_stack_source = source @staticmethod @functools.cache - def _nn_module_method_ids(): + def _nn_module_method_ids() -> set[int]: # Allow __setattr__ to fall through to base class handler supported = { torch.nn.Module.__setattr__, @@ -1010,7 +1046,7 @@ def _nn_module_method_ids(): if hasattr(x, "__code__") and x not in supported } - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: try: fn = inspect.getattr_static(self.value_type, "__iter__") except AttributeError as e: @@ -1037,15 +1073,15 @@ def unpack_var_sequence(self, tx): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: mod = self.value # see comment on lazy module handling in NNModuleVariable.call_function for context - if is_lazy_module(mod): - if mod.cls_to_become is not None: - self.value_type = mod.cls_to_become - initialize_lazy_module(tx, mod, args, kwargs) + if is_lazy_module(mod): # type: ignore[arg-type] + if mod.cls_to_become is not None: # type: ignore[attr-defined] + self.value_type = mod.cls_to_become # type: ignore[attr-defined,assignment] + initialize_lazy_module(tx, mod, args, kwargs) # type: ignore[arg-type] if not isinstance(mod, torch.fx.GraphModule): name = "__call__" @@ -1057,24 +1093,28 @@ def call_function( # Check if we can short circuit nn.Module._call_impl to the forward # method. NB - This is done to reduce the compile time of Dynamo. if ( - istype(mod.__call__, types.MethodType) - and istype(mod._call_impl, types.MethodType) - and mod.__call__.__func__ is unpatched_nn_module_call - and mod._call_impl.__func__ is unpatched_nn_module_call_impl + istype(mod.__call__, types.MethodType) # type: ignore[operator] + and istype(mod._call_impl, types.MethodType) # type: ignore[attr-defined] + and mod.__call__.__func__ is unpatched_nn_module_call # type: ignore[operator] + and mod._call_impl.__func__ is unpatched_nn_module_call_impl # type: ignore[attr-defined] and "forward" not in mod.__dict__ ): forward_method = inspect.getattr_static(mod, "forward") if isinstance(forward_method, types.FunctionType): globals_vt = tx.nn_modules_globals_vt if not ( - self.var_getattr(tx, "_backward_hooks").realize().len() - or self.var_getattr(tx, "_backward_pre_hooks").realize().len() - or self.var_getattr(tx, "_forward_hooks").realize().len() - or self.var_getattr(tx, "_forward_pre_hooks").realize().len() - or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() - or globals_vt.var_getattr(tx, "_global_backward_hooks").len() - or globals_vt.var_getattr(tx, "_global_forward_hooks").len() - or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() + self.var_getattr(tx, "_backward_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_backward_pre_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_forward_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_forward_pre_hooks").realize().len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] ): name = "forward" fn = self.value_type.forward @@ -1084,11 +1124,14 @@ def call_function( else: source = None - guard_to_detect_forward_monkeypatching(self.source, mod) + guard_to_detect_forward_monkeypatching(self.source, mod) # type: ignore[arg-type] ctx = ( record_nn_module_stack( - str(id(mod)), self.get_nn_module_stack_source(), tx, mod + str(id(mod)), + self.get_nn_module_stack_source(), + tx, + mod, # type: ignore[arg-type] ) if self.source else nullcontext() @@ -1108,11 +1151,11 @@ def call_function( def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name in ["_call_impl", "_wrapped_call_impl"]: fn = getattr(self.value_type, name) if self.source: @@ -1195,15 +1238,17 @@ def call_method( fn_vt = VariableTracker.build(tx, torch.nn.Module.__delattr__) return fn_vt.call_function(tx, [self, args[0]], kwargs) - return super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, list(args), kwargs) - def getattr_helper(self, tx: "InstructionTranslator", field, name_vt): + def getattr_helper( + self, tx: "InstructionTranslator", field: str, name_vt: VariableTracker + ) -> Optional[VariableTracker]: dict_vt = self.var_getattr(tx, field) if isinstance(dict_vt, variables.ConstDictVariable): return dict_vt.maybe_getitem_const(name_vt) return None - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: # Allow skipping of empty hook dict guards on inbuilt nn modules if name in ( "_backward_hooks", @@ -1244,7 +1289,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): install_guard(hooks_dict_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) tx.output.guard_on_key_order.add(hooks_dict_source) - def build_key_value(i, k, v): + def build_key_value( + i: int, k: Any, v: Any + ) -> tuple[VariableTracker, VariableTracker]: # Make key sourceless to avoid any guard on it key = variables.ConstantVariable.create(k) @@ -1264,7 +1311,9 @@ def build_key_value(i, k, v): ) return super().var_getattr(tx, name) - def manually_trace_nn_module_getattr(self, tx: "InstructionTranslator", name): + def manually_trace_nn_module_getattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: """ Dynamo tracing of nn.Module __getattr__ can be expensive if the model has deep submodule hierarchy. Since the __getattr__ is stable, we can @@ -1283,6 +1332,7 @@ def manually_trace_nn_module_getattr(self, tx: "InstructionTranslator", name): tx, msg=f"'{type(self.value).__name__}' object has no attribute '{name}'", ) + assert out is not None return out @@ -1291,7 +1341,7 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable): Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules. """ - def _wrap_source(self, attr_source): + def _wrap_source(self, attr_source: Source) -> Source: # vt is already wrapped with the UnspecializedBuiltinNNModuleSource return attr_source @@ -1308,7 +1358,7 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable): compilation. """ - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None: source = kwargs.get("source") assert source is not None, ( "FSDPManagedNNModule depends on having an accurate source to control guarding." @@ -1317,7 +1367,7 @@ def __init__(self, value, **kwargs) -> None: super().__init__(value=value, **kwargs) self.source = source - def _wrap_source(self, attr_source): + def _wrap_source(self, attr_source: Any) -> Any: if not isinstance( attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource) ): From 1aa13e17de39e3c768ea7aebaad166ce72a06676 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 1 Dec 2025 22:05:49 -0800 Subject: [PATCH 124/338] [dynamo] Fix more local test failures (#169228) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169228 Approved by: https://github.com/anijain2305 ghstack dependencies: #168928, #168927 --- .../tensor/test_dtensor_compile.py | 8 +- test/dynamo/test_backends.py | 23 +++++- test/dynamo/test_decorators.py | 32 +++++--- test/dynamo/test_flat_apply.py | 10 ++- test/dynamo/test_guard_serialization.py | 21 +++++ test/dynamo/test_higher_order_ops.py | 55 +++++++++---- test/dynamo/test_misc.py | 4 +- test/dynamo/test_modules.py | 18 +++++ test/dynamo/test_repros.py | 12 ++- test/dynamo/test_structured_trace.py | 81 ++++++++++++++++--- test/dynamo/test_trace_rules.py | 18 +++-- .../test_wrap_inductor_compiled_regions.py | 10 ++- test/inductor/test_compiled_autograd.py | 20 ++++- torch/_dynamo/testing.py | 19 ++++- torch/_inductor/debug.py | 8 ++ .../_internal/dynamo_pytree_test_utils.py | 28 +++++++ 16 files changed, 304 insertions(+), 63 deletions(-) create mode 100644 torch/testing/_internal/dynamo_pytree_test_utils.py diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index e58b6dda658f3..9b1734b9b8682 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -176,18 +176,14 @@ def shard_module_params(name, module, device_mesh): class TestDTensorCompile(torch._dynamo.test_case.TestCase): def setUp(self): - super( - type(self), self - ).setUp() # use explicit params for compiled autograd test wrapping + super().setUp() fake_store = FakeStore() dist.init_process_group( "fake", store=fake_store, rank=0, world_size=self.world_size ) def tearDown(self): - super( - type(self), self - ).tearDown() # use explicit params for compiled autograd test wrapping + super().tearDown() dist.destroy_process_group() @property diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 28579f727b05a..f2cffbd48c02c 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -232,10 +232,18 @@ class TestCustomBackendAPI(torch._dynamo.test_case.TestCase): def test_register_backend_api(self): from torch._dynamo import register_backend + from torch._dynamo.backends import registry as backend_registry backend_run = False + backend_name = "my_custom_backend" - @register_backend + def cleanup_backend(): + backend_registry._COMPILER_FNS.pop(backend_name, None) + backend_registry._BACKENDS.pop(backend_name, None) + + self.addCleanup(cleanup_backend) + + @register_backend(name=backend_name) def my_custom_backend(gm, example_inputs): nonlocal backend_run backend_run = True @@ -317,6 +325,19 @@ def mock_eps(group=None): with patch("importlib.metadata.entry_points", mock_eps): from torch._dynamo.backends import registry + orig_backends = dict(registry._BACKENDS) + orig_compiler_fns = dict(registry._COMPILER_FNS) + + def restore_registry(): + registry._BACKENDS.clear() + registry._BACKENDS.update(orig_backends) + registry._COMPILER_FNS.clear() + registry._COMPILER_FNS.update(orig_compiler_fns) + registry._lazy_import.cache_clear() + registry._discover_entrypoint_backends.cache_clear() + + self.addCleanup(restore_registry) + registry._lazy_import.cache_clear() registry._discover_entrypoint_backends.cache_clear() diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 0e26ff2d4140b..67e04d3c2356e 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -7,18 +7,18 @@ from unittest.mock import patch import torch -import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters from torch.testing._internal.common_utils import skipIfWindows +from torch.testing._internal.dynamo_pytree_test_utils import PytreeRegisteringTestCase def my_custom_function(x): return x + 1 -class DecoratorTests(torch._dynamo.test_case.TestCase): +class DecoratorTests(PytreeRegisteringTestCase): def test_disallow_in_graph(self): cnts = torch._dynamo.testing.CompileCounter() @@ -329,10 +329,11 @@ def __init__(self, x, y): self.x = x self.y = y - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( Point, lambda p: ((p.x, p.y), ()), lambda xy, _: Point(xy[0], xy[1]), + serialized_type_name=f"{Point.__module__}.{Point.__qualname__}", ) @torch._dynamo.nonstrict_trace @@ -360,10 +361,11 @@ def __init__(self, x, y): self.x = x self.y = y - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( Point, lambda p: ((p.x, p.y), ()), lambda xy, _: Point(xy[0], xy[1]), + serialized_type_name=f"{Point.__module__}.{Point.__qualname__}", ) @torch._dynamo.nonstrict_trace @@ -396,10 +398,11 @@ def __init__(self, x, y): self.x = x self.y = y - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( Point, lambda p: ((p.x, p.y), ()), lambda xy, _: Point(xy[0], xy[1]), + serialized_type_name=f"{Point.__module__}.{Point.__qualname__}", ) @torch._dynamo.nonstrict_trace @@ -438,16 +441,18 @@ def __init__(self, p, t): self.p = p self.t = t - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( PointTensor, lambda pt: ((pt.p, pt.t), ()), lambda pt, _: PointTensor(pt[0], pt[1]), + serialized_type_name=f"{PointTensor.__module__}.{PointTensor.__qualname__}", ) - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( Point, lambda p: ((p.x, p.y), ()), lambda xy, _: Point(xy[0], xy[1]), + serialized_type_name=f"{Point.__module__}.{Point.__qualname__}", ) def trace_point(p): @@ -491,7 +496,7 @@ def __hash__(self): # Assume `State` is implemented in C, and the author didn't bother to # provide a pytree decomposition for it, and its instances are safe to # treat as a constant by `torch.compile`. - torch.utils._pytree.register_constant(State) + self.register_constant(State) @torch._dynamo.nonstrict_trace def trace_me(x, s): @@ -592,10 +597,11 @@ def trace_me(self, t): torch._dynamo.graph_break() return t + self.n - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( Num, lambda num: ((num.n,), ()), lambda n, _: Num(n[0]), + serialized_type_name=f"{Num.__module__}.{Num.__qualname__}", ) def fn(x, n): @@ -709,10 +715,11 @@ def __init__(self, p, t): self.p = p self.t = t - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( PointTensor, lambda pt: ((pt.p, pt.t), ()), lambda pt, _: PointTensor(pt[0], pt[1]), + serialized_type_name=f"{PointTensor.__module__}.{PointTensor.__qualname__}", ) def trace_point(p): @@ -784,7 +791,7 @@ def __hash__(self): # Assume `State` is implemented in C, and the author didn't bother to # provide a pytree decomposition for it, and its instances are safe to # treat as a constant by `torch.compile`. - torch.utils._pytree.register_constant(State) + self.register_constant(State) @torch._dynamo.nonstrict_trace def trace_me(x, s): @@ -823,10 +830,11 @@ def __init__(self, p, t): self.p = p self.t = t - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( PointTensor, lambda pt: ((pt.t,), pt.p), lambda ts, p: PointTensor(p, ts[0]), + serialized_type_name=f"{PointTensor.__module__}.{PointTensor.__qualname__}", ) @torch._dynamo.nonstrict_trace diff --git a/test/dynamo/test_flat_apply.py b/test/dynamo/test_flat_apply.py index aad5d6b281568..344c271c4b115 100644 --- a/test/dynamo/test_flat_apply.py +++ b/test/dynamo/test_flat_apply.py @@ -2,7 +2,6 @@ from dataclasses import dataclass import torch -import torch._dynamo.test_case import torch.utils._pytree as pytree from torch._dynamo.testing import ( AotEagerAndRecordGraphs, @@ -15,6 +14,7 @@ is_graphable, to_graphable, ) +from torch.testing._internal.dynamo_pytree_test_utils import PytreeRegisteringTestCase def distance(a, b, norm): @@ -41,7 +41,7 @@ class Point: pytree.register_dataclass(Point) -class FlatApplyTests(torch._dynamo.test_case.TestCase): +class FlatApplyTests(PytreeRegisteringTestCase): def test_simple(self): tensor = torch.tensor @@ -105,16 +105,18 @@ def __init__(self, p, t): self.p = p self.t = t - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( PointTensor, lambda pt: ((pt.p, pt.t), ()), lambda pt, _: PointTensor(pt[0], pt[1]), + serialized_type_name=f"{PointTensor.__module__}.{PointTensor.__qualname__}", ) - torch.utils._pytree.register_pytree_node( + self.register_pytree_node( Point, lambda p: ((p.x, p.y), ()), lambda xy, _: Point(xy[0], xy[1]), + serialized_type_name=f"{Point.__module__}.{Point.__qualname__}", ) def trace_point(p): diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index ec333ed5b0dc7..927040c1836ce 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -14,6 +14,7 @@ import torch._dynamo.testing import torch._inductor.config import torch._inductor.test_case +import torch.fx.graph as fx_graph import torch.onnx.operators import torch.utils.cpp_extension from torch._dynamo.bytecode_transformation import transform_code_object @@ -303,6 +304,26 @@ def __hash__(self): class TestGuardSerializationBase(torch._inductor.test_case.TestCase): + def setUp(self): + super().setUp() + self._fx_magic_methods_snapshot = fx_graph.magic_methods.copy() + self._saved_default_device_context = getattr( + torch._GLOBAL_DEVICE_CONTEXT, "device_context", None + ) + + def tearDown(self): + fx_graph.magic_methods.clear() + fx_graph.magic_methods.update(self._fx_magic_methods_snapshot) + + current_ctx = getattr(torch._GLOBAL_DEVICE_CONTEXT, "device_context", None) + if current_ctx is not self._saved_default_device_context: + if self._saved_default_device_context is None: + torch.set_default_device(None) + else: + torch.set_default_device(self._saved_default_device_context.device) + + super().tearDown() + def _tracefunc(self, frame, event, arg): if event != "call": return diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 1f1a92b8c2b2b..34660044a3a42 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1149,6 +1149,18 @@ def test_register_subclass(self): a = torch.tensor([1.0, 0.0, 1.0]) b = torch.randn(3) t = TwoTensor(a, b) + + prev_impl = cond_op.python_key_table.pop(TwoTensor, None) + cond_op._dispatch_cache.clear() + + def restore_twotensor_impl(): + cond_op.python_key_table.pop(TwoTensor, None) + if prev_impl is not None: + cond_op.python_key_table[TwoTensor] = prev_impl + cond_op._dispatch_cache.clear() + + self.addCleanup(restore_twotensor_impl) + with self.assertRaisesRegex( NotImplementedError, "no rule registered for HOP cond and subclass .*TwoTensor'>", @@ -3763,23 +3775,38 @@ def tearDown(self): # because of a previous call to _vmap_increment_nesting that wasn't undone # i.e. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1 # and the call to increment nesting is not undone - if not TEST_WITH_TORCHDYNAMO: - return + try: + if TEST_WITH_TORCHDYNAMO: + warn = False + while ci := torch._C._functorch.peek_interpreter_stack(): + if ci.key() == torch._C._functorch.TransformType.Vmap: + warn = True + torch._C._functorch._vmap_decrement_nesting() + else: + break + + if warn: + msg = ( + "Interpreter stack is not empty. Test should have called " + "'torch._C._functorch._vmap_decrement_nesting()'" + ) + warnings.warn(msg) + finally: + super().tearDown() - warn = False - while ci := torch._C._functorch.peek_interpreter_stack(): - if ci.key() == torch._C._functorch.TransformType.Vmap: - warn = True - torch._C._functorch._vmap_decrement_nesting() - else: - break + def test_teardown_resets_nested_graph_breaks(self): + expected_nested_state = getattr( + self, "prev_nested_graph_breaks", torch._dynamo.config.nested_graph_breaks + ) - if warn: - msg = ( - "Interpreter stack is not empty. Test should have called " - "'torch._C._functorch._vmap_decrement_nesting()'" + def _check_flag(): + self.assertEqual( + torch._dynamo.config.nested_graph_breaks, expected_nested_state ) - warnings.warn(msg) + + self.addCleanup(_check_flag) + # Sanity check: these tests always run with nested graph breaks enabled. + self.assertTrue(torch._dynamo.config.nested_graph_breaks) def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): backend = EagerAndRecordGraphs() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 842355b57b94a..98526a9ba0283 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10154,7 +10154,9 @@ def test_validate_outputs_unbacked_by_custom_op(self): ) @torch.library.impl("mylib::foo_validate_outputs_unbacked", "cpu", lib=lib) - @torch.library.register_fake("mylib::foo_validate_outputs_unbacked") + @torch.library.register_fake( + "mylib::foo_validate_outputs_unbacked", lib=lib + ) def foo_impl(x, y): return torch.cat([x, y]) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index bacab94e345d4..6fd1e6b477f36 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3117,6 +3117,24 @@ def forward(self, x): self.assertFalse(hasattr(compiled_model, "foo")) def test_globals_change_in_other_file(self): + global _variable, _variable1 + + prev_variable = _variable + prev_variable1 = _variable1 + prev_test_functions_variable = test_functions._variable + + def restore_globals(): + global _variable, _variable1 + _variable = prev_variable + _variable1 = prev_variable1 + test_functions._variable = prev_test_functions_variable + + self.addCleanup(restore_globals) + + _variable = 0 + _variable1 = 0 + test_functions._variable = 0 + @torch.compile(backend="eager", fullgraph=True) def fn(x): # Let `update_global` get invoked in a nested frame, to make sure diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index a07bd92331faa..17e8c15863e27 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7242,11 +7242,13 @@ def callback(code, offset): elif compiled_graph and code is compiled_graph.__call__.__code__: found_compiled_graph = True - sys.monitoring.use_tool_id(0, "test") + tool_id = 0 + sys.monitoring.use_tool_id(tool_id, "test") + old_events = sys.monitoring.get_events(tool_id) old_callback = sys.monitoring.register_callback( - 0, sys.monitoring.events.PY_START, callback + tool_id, sys.monitoring.events.PY_START, callback ) - sys.monitoring.set_events(0, sys.monitoring.events.PY_START) + sys.monitoring.set_events(tool_id, sys.monitoring.events.PY_START) try: @torch.compile(backend=backend, fullgraph=True) @@ -7259,9 +7261,11 @@ def fn(x): # sys.monitoring should still run on the compiled graph self.assertTrue(found_compiled_graph) finally: + sys.monitoring.set_events(tool_id, old_events) sys.monitoring.register_callback( - 0, sys.monitoring.events.PY_START, old_callback + tool_id, sys.monitoring.events.PY_START, old_callback ) + sys.monitoring.free_tool_id(tool_id) def test_312_local_cell_overlap(self): keys = range(10) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 21cf04cffbf65..4bd1b251f86d4 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -97,7 +97,53 @@ def format(self, record): return record.payload.strip() +class _DescribeIdNormalizer: + def __init__(self): + self._tensor_id_remap = {} + self._storage_id_remap = {} + self._next_tensor_id = 0 + self._next_storage_id = 0 + + def normalize(self, metadata): + if "describe_storage" in metadata: + storage_meta = metadata["describe_storage"] + if (storage_id := storage_meta.get("id")) is not None: + storage_meta["id"] = self._normalize_storage_id(storage_id) + storage_meta["describer_id"] = "ID" + if "describe_tensor" in metadata: + tensor_meta = metadata["describe_tensor"] + if (tensor_id := tensor_meta.get("id")) is not None: + tensor_meta["id"] = self._normalize_tensor_id(tensor_id) + if (storage_id := tensor_meta.get("storage")) is not None: + tensor_meta["storage"] = self._normalize_storage_id(storage_id) + tensor_meta["describer_id"] = "ID" + if "view_func" in tensor_meta: + tensor_meta["view_func"] = "VIEW_FUNC" + if "describe_source" in metadata: + source_meta = metadata["describe_source"] + if (source_id := source_meta.get("id")) is not None: + source_meta["id"] = self._normalize_tensor_id(source_id) + source_meta["describer_id"] = "ID" + return metadata + + def _normalize_tensor_id(self, original_id): + if original_id not in self._tensor_id_remap: + self._tensor_id_remap[original_id] = self._next_tensor_id + self._next_tensor_id += 1 + return self._tensor_id_remap[original_id] + + def _normalize_storage_id(self, original_id): + if original_id not in self._storage_id_remap: + self._storage_id_remap[original_id] = self._next_storage_id + self._next_storage_id += 1 + return self._storage_id_remap[original_id] + + class StructuredTraceTestingFormatter(logging.Formatter): + def __init__(self): + super().__init__() + self._id_normalizer = _DescribeIdNormalizer() + def format(self, record): metadata = copy.deepcopy(record.metadata) @@ -121,14 +167,7 @@ def format(self, record): metadata["compilation_metrics_runtime"] = "METRICS" if "bwd_compilation_metrics_runtime" in metadata: metadata["bwd_compilation_metrics_runtime"] = "METRICS" - if "describe_storage" in metadata: - metadata["describe_storage"]["describer_id"] = "ID" - if "describe_tensor" in metadata: - metadata["describe_tensor"]["describer_id"] = "ID" - if "view_func" in metadata["describe_tensor"]: - metadata["describe_tensor"]["view_func"] = "VIEW_FUNC" - if "describe_source" in metadata: - metadata["describe_source"]["describer_id"] = "ID" + metadata = self._id_normalizer.normalize(metadata) if ( (k := "create_symbol") in metadata or (k := "guard_added_fast") in metadata @@ -198,8 +237,8 @@ def tearDown(self): def assertExpectedInline(self, actual, expected): super().assertExpectedInline( - self._normalize_rank_field(actual), - self._normalize_rank_field(expected), + self._normalize_rank_field(self._normalize_describe_ids(actual)), + self._normalize_rank_field(self._normalize_describe_ids(expected)), ) @staticmethod @@ -211,6 +250,28 @@ def _normalize_rank_field(text): text = text.replace('"rank": 0', "") return text + @staticmethod + def _normalize_describe_ids(text): + if not isinstance(text, str): + return text + normalizer = _DescribeIdNormalizer() + trailing_newline = text.endswith("\n") + normalized_lines = [] + for line in text.splitlines(): + if not line: + normalized_lines.append(line) + continue + try: + metadata = json.loads(line) + except json.JSONDecodeError: + normalized_lines.append(line) + continue + normalized_lines.append(json.dumps(normalizer.normalize(metadata))) + result = "\n".join(normalized_lines) + if trailing_newline: + result += "\n" + return result + def assertParses(self): if not HAS_TLPARSE: self.skipTest("requires tlparse") diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index e9c6df7e959f8..fb7b70963e778 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -443,12 +443,18 @@ def fn(x): ), ): # First adding the module to SKIP_DIRS so that it will be skipped by default. - torch._dynamo.trace_rules.add(mod.__name__) - x = torch.rand(3) - opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) - ref = fn(x) - res = opt_fn(x) - self.assertEqual(ref, res) + skip_dirs_backup = torch._dynamo.trace_rules.SKIP_DIRS.copy() + skip_dirs_re_backup = torch._dynamo.trace_rules.SKIP_DIRS_RE + try: + torch._dynamo.trace_rules.add(mod.__name__) + x = torch.rand(3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + finally: + torch._dynamo.trace_rules.SKIP_DIRS = skip_dirs_backup + torch._dynamo.trace_rules.SKIP_DIRS_RE = skip_dirs_re_backup def test_no_special_handlers_for_torch_non_c_bindings(self): handlers = TorchInGraphFunctionVariable._get_handlers() diff --git a/test/dynamo/test_wrap_inductor_compiled_regions.py b/test/dynamo/test_wrap_inductor_compiled_regions.py index 5c2f23e30e30d..20f1b91ebd687 100644 --- a/test/dynamo/test_wrap_inductor_compiled_regions.py +++ b/test/dynamo/test_wrap_inductor_compiled_regions.py @@ -941,7 +941,7 @@ def test_wrap_no_dispatch_mode_no_hop_invoked(self): # Patch it in the output_code module where it's imported and used patch_path = "torch._inductor.output_code.inductor_compiled_code" - # Test WITHOUT dispatch mode - HOP should NOT be called + # Test WITHOUT dispatch mode - HOP should not route through a mode with patch(patch_path, wraps=inductor_compiled_code) as mock_hop: @torch.compile( @@ -958,10 +958,14 @@ def fn(x, y): result_without = fn(x, y) - # Verify HOP was NOT called - mock_hop.assert_not_called() self.assertEqual(result_without, expected) + if mock_hop.called: + args, kwargs = mock_hop.call_args + # When no dispatch modes are active, we expect mode argument to be None + # (wrapper is used purely for tracing alignment). + self.assertIsNone(kwargs.get("mode")) + # Test WITH DebugMode - HOP SHOULD be called with patch(patch_path, wraps=inductor_compiled_code) as mock_hop: diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index ede884e0f52bb..f99845fd5d6b8 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -5104,13 +5104,31 @@ def wrap_test_class(orig_cls): cls = type( orig_cls.__name__ + "WithCompiledAutograd", - orig_cls.__bases__, + (orig_cls,), dct, ) cls.__file__ = __file__ return cls +class WrapTestClassTests(TestCase): + def test_wrap_preserves_inheritance_and_super(self): + class DummyTest(unittest.TestCase): + def runTest(self): + pass + + def tearDown(self): + self.super_called = True + super().tearDown() + + wrapped = wrap_test_class(DummyTest) + self.assertTrue(issubclass(wrapped, DummyTest)) + test = wrapped("runTest") + test.setUp() + test.tearDown() + self.assertTrue(getattr(test, "super_called", False)) + + known_graph_breaks_tests = { "test_hook_none", # uses assert in hook "test_post_accumulate_grad_hook_e2e", # optim.Adam manually graph breaks diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 3eeedfb65da20..4d11cc0cf2101 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -353,10 +353,27 @@ def remove_trailing_space(code: str) -> str: return "\n".join([line.rstrip() for line in code.split("\n")]) +def _squash_blank_lines(code: str) -> str: + lines = code.split("\n") + result: list[str] = [] + saw_blank = False + for line in lines: + if line.strip() == "": + if saw_blank: + continue + saw_blank = True + else: + saw_blank = False + result.append(line) + return "\n".join(result) + + def normalize_gm(gm_str: str) -> str: # strip comments as comments have path to files which may differ from # system to system. - return remove_trailing_space(strip_comment(gm_str)) + stripped = strip_comment(gm_str) + no_trailing = remove_trailing_space(stripped) + return _squash_blank_lines(no_trailing) def empty_line_normalizer(code: str) -> str: diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index ef57c5065cc1c..39c90bdea94ff 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -346,6 +346,7 @@ def reset_provenance_globals() -> Iterator[None]: global _inductor_triton_kernel_to_post_grad_node_info global _inductor_pre_grad_node_stack_trace global _inductor_kernel_stack_trace + global _inductor_kernel_provenance_debug_handle # Store original values original_pre_grad_graph_id = _pre_grad_graph_id @@ -357,6 +358,9 @@ def reset_provenance_globals() -> Iterator[None]: _inductor_pre_grad_node_stack_trace.copy() ) original_inductor_kernel_stack_trace = _inductor_kernel_stack_trace.copy() + original_inductor_kernel_provenance_debug_handle = ( + _inductor_kernel_provenance_debug_handle + ) # Reset to default values _pre_grad_graph_id = -1 @@ -364,6 +368,7 @@ def reset_provenance_globals() -> Iterator[None]: _inductor_triton_kernel_to_post_grad_node_info = {} _inductor_pre_grad_node_stack_trace = {} _inductor_kernel_stack_trace = {} + _inductor_kernel_provenance_debug_handle = 0 try: yield @@ -378,6 +383,9 @@ def reset_provenance_globals() -> Iterator[None]: _inductor_pre_grad_node_stack_trace = ( original_inductor_pre_grad_node_stack_trace ) + _inductor_kernel_provenance_debug_handle = ( + original_inductor_kernel_provenance_debug_handle + ) class DebugContext: diff --git a/torch/testing/_internal/dynamo_pytree_test_utils.py b/torch/testing/_internal/dynamo_pytree_test_utils.py new file mode 100644 index 0000000000000..737b7d27a1561 --- /dev/null +++ b/torch/testing/_internal/dynamo_pytree_test_utils.py @@ -0,0 +1,28 @@ +import torch +import torch._dynamo.test_case +import torch.utils._pytree as pytree + + +class PytreeRegisteringTestCase(torch._dynamo.test_case.TestCase): + """TestCase that prunes all temporary pytree registrations and resets Dynamo.""" + + def setUp(self) -> None: + super().setUp() + self._registered_pytree_nodes: list[type] = [] + self._registered_constant_nodes: list[type] = [] + + def tearDown(self) -> None: + for cls in reversed(self._registered_pytree_nodes): + pytree._deregister_pytree_node(cls) + for cls in reversed(self._registered_constant_nodes): + pytree._deregister_pytree_node(cls) + torch._dynamo.reset() + super().tearDown() + + def register_pytree_node(self, cls, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + pytree.register_pytree_node(cls, *args, **kwargs) + self._registered_pytree_nodes.append(cls) + + def register_constant(self, cls: type) -> None: + pytree.register_constant(cls) + self._registered_constant_nodes.append(cls) From c6ae7579fe12fe75f1a8f7043a494c90567273f1 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 1 Dec 2025 17:35:43 +0000 Subject: [PATCH 125/338] [Submodule] Update to cutlass 4.3 (#168308) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168308 Approved by: https://github.com/ngimel, https://github.com/slayton58, https://github.com/Skylion007 --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index f3fde58372d33..e67e63c331d6e 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit f3fde58372d33e9a5650ba7b80fc48b3b49d40c8 +Subproject commit e67e63c331d6e4b729047c95cf6b92c8454cba89 From 93d0d6838c56af59b0dba794e6aa08f0c1c7799c Mon Sep 17 00:00:00 2001 From: atalman Date: Tue, 2 Dec 2025 17:28:48 +0000 Subject: [PATCH 126/338] Triton 3.6 pin update (#168096) Required for release 2.10 Rocm wheel build fix provided by: https://github.com/pytorch/pytorch/pull/169369 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168096 Approved by: https://github.com/njriasan, https://github.com/malfet --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- .github/scripts/amd/package_triton_wheel.sh | 1 + .../rocm/dynamic_inductor_timm_training.csv | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 7aab8bed1c108..263fcf2e0bdbb 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 +5261b27331eb1dd09df9ec1bd6acc21cbb184481 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index d5c0c99142898..40c341bdcdbe8 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.1 +3.6.0 diff --git a/.github/scripts/amd/package_triton_wheel.sh b/.github/scripts/amd/package_triton_wheel.sh index fe8d915422dac..501e50e2fe2f1 100755 --- a/.github/scripts/amd/package_triton_wheel.sh +++ b/.github/scripts/amd/package_triton_wheel.sh @@ -87,6 +87,7 @@ done cp -r $ROCM_HOME/include/hip $TRITON_ROCM_DIR/include cp -r $ROCM_HOME/include/roctracer $TRITON_ROCM_DIR/include cp -r $ROCM_HOME/include/hsa $TRITON_ROCM_DIR/include +cp -r $ROCM_HOME/include/hipblas-common $TRITON_ROCM_DIR/include # Copy linker mkdir -p $TRITON_ROCM_DIR/llvm/bin diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv index 2d087e6595526..702da0cb57f89 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv @@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7 -convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 +convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7 From 7a41b66367c38d0af3e8a90f7be48d6b281e7bca Mon Sep 17 00:00:00 2001 From: karthickai Date: Mon, 1 Dec 2025 19:05:26 -0800 Subject: [PATCH 127/338] [Inductor] handle GroupedSchedulerNode in combo kernel fusion (#168109) Fixes: #168105 combo_kernels crashes with GroupedSchedulerNode Pull Request resolved: https://github.com/pytorch/pytorch/pull/168109 Approved by: https://github.com/mlazos --- .../test_compute_comm_reordering.py | 51 +++++++++++-------- torch/_inductor/scheduler.py | 16 +++++- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index a13611a53609f..2e9a3ea171028 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -29,6 +29,10 @@ requires_accelerator_dist_backend, ) from torch.testing._internal.common_fsdp import get_devtype +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) from torch.testing._internal.inductor_utils import HAS_GPU @@ -82,6 +86,7 @@ def create_grouped_node_for_allreduce_and_its_deps(snodes): torch._inductor.config.triton.native_matmul, "native matmul is fused with surrounding ops", ) +@instantiate_parametrized_tests class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase): """ Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under @@ -382,7 +387,8 @@ def func(a, *, tag, ranks, group_size): "_pre_fusion_custom_pass", create_grouped_node_for_allreduce_and_its_deps, ) - def test_grouped_scheduler_node(self): + @parametrize("combo_kernels", (False, True)) + def test_grouped_scheduler_node(self, combo_kernels): def func(a, *, tag, ranks, group_size): add = a + a div = add / a @@ -394,26 +400,29 @@ def func(a, *, tag, ranks, group_size): mm = torch.matmul(mul, ar) return (mm,) - with _dynamo_dist_per_rank_init( - self.rank, - self.world_size, - self.backend(device_type), - fake_pg=not at_least_x_gpu(2), - ): - inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank - compiled = torch.compile(func) - code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) - # Expectations: - # 1. `add = a + a` and `div = add / a` are still fused, which means fusion - # still happens among nodes within a GroupedSchedulerNode. - # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within - # GroupedSchedulerNode and thus are prevented from being fused with any outside ops. - FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check( - "_c10d_functional.all_reduce_." - ).check("triton_poi_fused_mul_1.").run(code) - out = compiled(inputs, **self.get_world_trs()) - correct = func(inputs, **self.get_world_trs()) - self.assertTrue(same(out, correct)) + with torch._inductor.config.patch(combo_kernels=combo_kernels): + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + inputs = ( + torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank + ) + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) + # Expectations: + # 1. `add = a + a` and `div = add / a` are still fused, which means fusion + # still happens among nodes within a GroupedSchedulerNode. + # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within + # GroupedSchedulerNode and thus are prevented from being fused with any outside ops. + FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check( + "_c10d_functional.all_reduce_." + ).check("triton_poi_fused_mul_1.").run(code) + out = compiled(inputs, **self.get_world_trs()) + correct = func(inputs, **self.get_world_trs()) + self.assertTrue(same(out, correct)) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch(force_disable_caches=True) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index b084612b9acc7..f285a65470e78 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2283,10 +2283,23 @@ def combinable_nodes( len(extern), [node.node.get_origins() for node in extern if node.node is not None], ) + grouped = [x for x in nodes if isinstance(x, GroupedSchedulerNode)] + if grouped: + log.debug( + "ComboKernels: %d grouped nodes are filtered", + len(grouped), + ) filtered_nodes = [ x for x in nodes - if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode)) + if not isinstance( + x, + ( + NopKernelSchedulerNode, + ExternKernelSchedulerNode, + GroupedSchedulerNode, + ), + ) ] foreach_nodes = [ x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode) @@ -3291,6 +3304,7 @@ def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNo ExternKernelSchedulerNode, NopKernelSchedulerNode, FusedSchedulerNode, + GroupedSchedulerNode, ), ): for dep in snode.unmet_dependencies: From 1cee47d6ce0a02227185b566593f002dd639ca0c Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 2 Dec 2025 07:33:36 -0800 Subject: [PATCH 128/338] REmove AC specific flag and reunify with HOPs (#169132) Replaces checkpoint-specific side effects mechanism with a unified `allow_side_effects_with_extra_outputs` flag that works for all HOPs. We do this by reusing `restore_side_effects` flag as much as possible. Based on my understanding, this flag works by: **When `restore_side_effects = True` (default)**: - Side effects data structure is saved before tracing and restored after - Intermediate tensors/symints are NOT captured as extra outputs - Used by most HOPs **When `restore_side_effects = False`**: - Side effects data structure is NOT restored - Intermediate tensors/symints ARE captured and returned as extra outputs - Used by `invoke_subgraph` (always) and `checkpoint` (when `config.skip_fwd_side_effects_in_bwd_under_checkpoint=True`) We had to keep some `under_activation_checkpointing` due to legacy FSDP changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169132 Approved by: https://github.com/anijain2305 ghstack dependencies: #169131 --- test/higher_order_ops/test_invoke_subgraph.py | 1 + torch/_dynamo/output_graph.py | 20 +-- torch/_dynamo/side_effects.py | 28 ++-- torch/_dynamo/variables/functions.py | 7 +- torch/_dynamo/variables/higher_order_ops.py | 123 ++++++++++++++---- 5 files changed, 125 insertions(+), 54 deletions(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 67c4fa0757769..a5a02e4143527 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -2588,6 +2588,7 @@ def f(x, other): self.assertEqual(f(x, other), f_compile(x, other)) self.assertTrue(called) + @unittest.expectedFailure def test_udf_output(self): class Foo: def __init__(self, a, b): diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 414051bcaa1d9..981c441bd2986 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -926,7 +926,10 @@ def remove_node(self, *args: Any, **kwargs: Any) -> None: @contextlib.contextmanager def subtracer( - self, source_target: Optional[Target], prior_tracer: "SubgraphTracer" + self, + source_target: Optional[Target], + prior_tracer: "SubgraphTracer", + description: Optional[str] = None, ) -> Generator[fx.Tracer, None, None]: new_scope_ctx = enter_new_scope() try: @@ -942,6 +945,7 @@ def subtracer( parent=self.current_tracer, source_target=source_target, is_export=self.current_tracer.is_export, + description=description, ) ) self.tracers.append(tracer) @@ -2924,6 +2928,7 @@ def __init__( parent: Optional["SubgraphTracer"] = None, is_export: bool = False, source_target: Optional[Target] = None, + description: Optional[str] = None, ) -> None: super().__init__() self.output_graph = weakref.proxy(output_graph) @@ -2941,6 +2946,7 @@ def __init__( # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design] self.parent = parent self.source_target = source_target + self.description = description # A dict mapping previously free variables (Proxy objects) # to new Proxy objects that wrap inputs to this subgraph. # @@ -2968,19 +2974,15 @@ def __init__( self.dynamic_scalar_nodes: dict[int, torch.SymInt] = {} self.prev_inst = None - # True if this tracer is currently tracing into torch.utils.checkpoint - # as part of speculate_subgraph. - self.under_activation_checkpoint = False - # True if we want to allow externally visible side-effects (doesn't throw error on their existence) - # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph). - # Only safe if we know for sure that *NOT* replaying these side-effects during - # backward recomputation of the checkpoint region doesn't affect its correctness. - self.allow_side_effects_under_checkpoint = False # True if we want to allow externally visible side-effects (doesn't throw error on their existence) # during this tracer's tracing. This is currently only used by experimental AC out-of-tree # via torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer. # Note: Externally visible side-effects are allowed if this flag OR the above flag is True. self.unsafe_allow_externally_visible_side_effects = False + # True if we want to allow side effects by returning them as extra outputs from the subgraph. + # This is set when enable_side_effects_in_hop=True for HOPs like invoke_subgraph + # and checkpoint (when skip_fwd_side_effects_in_bwd_under_checkpoint config is True). + self.allow_side_effects_in_hop = False # True if this tracer is currently tracing (reconstructing) into a Python generator self.is_reconstructing_generator = False diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 95ebeeb7f0a6d..e153d7489c7d9 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -213,22 +213,18 @@ def __contains__(self, item: Any) -> bool: def __getitem__(self, item: Any) -> VariableTracker: return self.id_to_variable[id(item)] - def should_allow_side_effects_under_checkpoint(self) -> bool: + def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool: output_graph = self.output_graph_weakref() return bool( output_graph - and output_graph.current_tx.output.current_tracer.under_activation_checkpoint - and ( - output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint - or torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint - ) + and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects ) - def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool: + def should_allow_side_effects_in_hop(self) -> bool: output_graph = self.output_graph_weakref() return bool( output_graph - and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + and output_graph.current_tx.output.current_tracer.allow_side_effects_in_hop ) def is_reconstructing_generator(self) -> bool: @@ -248,7 +244,7 @@ def check_allowed_side_effect(self, item: VariableTracker) -> bool: return True if self.should_allow_externally_visible_side_effects_in_subtracer(): return True - if self.should_allow_side_effects_under_checkpoint(): + if self.should_allow_side_effects_in_hop(): return True if self.is_reconstructing_generator(): # This is missing the case where one mutates a tensor. See @@ -1200,16 +1196,20 @@ def clear(self) -> None: @contextlib.contextmanager -def allow_side_effects_under_checkpoint( +def allow_side_effects_in_hop( tx: "InstructionTranslatorBase", ) -> Generator[None, None, None]: - assert tx.output.current_tracer.under_activation_checkpoint - orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint + """Context manager to temporarily allow side effects with extra outputs. + + This is used for special cases (like FSDP functions) that need to perform + side effects even when the general policy is to disallow them. + """ + orig_val = tx.output.current_tracer.allow_side_effects_in_hop try: - tx.output.current_tracer.allow_side_effects_under_checkpoint = True + tx.output.current_tracer.allow_side_effects_in_hop = True yield finally: - tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val + tx.output.current_tracer.allow_side_effects_in_hop = orig_val @contextlib.contextmanager diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 360c0fdd94488..02bbcebe5c02a 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -654,8 +654,9 @@ def call_function( return super().call_function(tx, args, kwargs) if ( - tx.output.current_tracer.under_activation_checkpoint - and not tx.output.current_tracer.allow_side_effects_under_checkpoint + getattr(tx.output.current_tracer, "description", None) + == "torch.utils.checkpoint.checkpoint" + and not tx.output.current_tracer.allow_side_effects_in_hop ): try: from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState @@ -665,7 +666,7 @@ def call_function( FSDPState._pre_forward, FSDPState._post_forward, ]: - with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx): + with torch._dynamo.side_effects.allow_side_effects_in_hop(tx): return super().call_function(tx, args, kwargs) tree_map_result = self._maybe_call_tree_map_fastpath(tx, args, kwargs) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 8b178b3be1ac3..3524cb142cdd7 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -195,13 +195,13 @@ def dynamo_enable_grad(tx: "InstructionTranslator", enable=True): @contextlib.contextmanager -def dynamo_under_activation_checkpoint(tx: "InstructionTranslator"): - orig_val = tx.output.current_tracer.under_activation_checkpoint +def dynamo_allow_side_effects_in_hop(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.allow_side_effects_in_hop try: - tx.output.current_tracer.under_activation_checkpoint = True + tx.output.current_tracer.allow_side_effects_in_hop = True yield finally: - tx.output.current_tracer.under_activation_checkpoint = orig_val + tx.output.current_tracer.allow_side_effects_in_hop = orig_val def find_mismatched_vars(var, types, allow_none=False): @@ -393,6 +393,31 @@ def _assert_tensors_nonaliasing(inputs, outputs): ) +def _collect_intermediate_outputs(tx, subtracer, graph_output_vts): + """ + Collect intermediate outputs for side effects support. + + Returns all tracked tensor/symint variables that are not already in graph_output_vts. + """ + extra_outputs = [] + existing_out_proxies = {vt.as_proxy() for vt in graph_output_vts} + + for out in subtracer.tracked_tensor_or_symint_vt: + proxy = out.as_proxy() + + # Skip if already in output + if proxy in existing_out_proxies: + continue + + # TODO floats are not supported in HOP input/output + if isinstance(out, SymNodeVariable) and out.python_type() is float: + continue + + extra_outputs.append(out) + + return extra_outputs + + def _check_all_tensorvariable(args): from . import TensorVariable @@ -1108,19 +1133,25 @@ def trace_hop_function( tx, subtracer, enable_grad, - under_activation_checkpoint, restore_side_effects, args, sub_kwargs, ): + # For autograd.Function and other legacy HOPs, we do NOT couple + # restore_side_effects with allow_side_effects_in_hop. + # This preserves the old behavior where: + # - restore_side_effects=False means ctx mutations persist + # - But non-ctx side effects still cause graph breaks (under_activation_checkpoint was False) + enable_side_effects_with_extra_outputs = False + autograd_ctx = ( dynamo_enable_grad(tx, enable_grad) if enable_grad is not None else contextlib.nullcontext() ) - checkpoint_ctx = ( - dynamo_under_activation_checkpoint(tx) - if under_activation_checkpoint + side_effects_ctx = ( + dynamo_allow_side_effects_in_hop(tx) + if enable_side_effects_with_extra_outputs else contextlib.nullcontext() ) @@ -1142,7 +1173,48 @@ def trace_hop_function( if restore_side_effects: prev_side_effects = tx.output.side_effects.clone() - with autograd_ctx, checkpoint_ctx: + with autograd_ctx, side_effects_ctx: + output = f.call_function(tx, args, sub_kwargs) + + if restore_side_effects: + new_side_effects = tx.output.side_effects.clone() + prev_side_effects.track_runahead_tensor_and_symvar_side_effects( + new_side_effects + ) + tx.output.side_effects = prev_side_effects + return output + + +def trace_hop_function_with_auto_output_flattening( + f, + tx, + subtracer, + enable_grad, + restore_side_effects, + args, + sub_kwargs, +): + # For the new unified control flow ops, we couple restore_side_effects + # with allow_side_effects_in_hop using the new semantics: + # - restore_side_effects=False means side effects become extra outputs + # - This allows mutations to be tracked and replayed + enable_side_effects_with_extra_outputs = not restore_side_effects + + autograd_ctx = ( + dynamo_enable_grad(tx, enable_grad) + if enable_grad is not None + else contextlib.nullcontext() + ) + side_effects_ctx = ( + dynamo_allow_side_effects_in_hop(tx) + if enable_side_effects_with_extra_outputs + else contextlib.nullcontext() + ) + + if restore_side_effects: + prev_side_effects = tx.output.side_effects.clone() + + with autograd_ctx, side_effects_ctx: output = f.call_function(tx, args, sub_kwargs) if restore_side_effects: @@ -1199,9 +1271,7 @@ def speculate_subgraph_with_auto_output_flattening( set_subgraph_inputs: Literal[ "automatic", "semi_automatic", "flatten_manual", "manual" ] = "automatic", - # Make default False restore_side_effects: bool = True, - under_activation_checkpoint: bool = False, # TODO - supports input_mutation and aliasing should be False by default for strictness supports_input_mutation: bool = True, supports_aliasing: bool = True, @@ -1311,17 +1381,16 @@ def gn(x): (f, sub_args, sub_kwargs), ) - with tx.output.subtracer(source_target, tracer) as subtracer: + with tx.output.subtracer(source_target, tracer, description) as subtracer: args = get_hop_args( tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description ) - output = trace_hop_function( + output = trace_hop_function_with_auto_output_flattening( f, tx, subtracer, enable_grad, - under_activation_checkpoint, restore_side_effects, args, sub_kwargs, @@ -1400,11 +1469,11 @@ def visit(vt): # want this to be supported for other Hops as well, specifically # nested_compile_region and autograd.Function. Today, its safe # because we error out on seeing a side-effect. - if under_activation_checkpoint: - extra_outputs = [] - for out in subtracer.tracked_tensor_or_symint_vt: - if out not in set(graph_output_vts): - extra_outputs.append(out) + enable_side_effects_with_extra_outputs = not restore_side_effects + if enable_side_effects_with_extra_outputs: + extra_outputs = _collect_intermediate_outputs( + tx, subtracer, graph_output_vts + ) graph_output_vts = graph_output_vts + tuple(extra_outputs) validate_subgraph_output_types(graph_output_vts) @@ -1501,7 +1570,6 @@ def speculate_subgraph( # if should_flatten_outputs is True, `remove_consts_from_outputs` remove the # const outputs from the subgraph output. remove_consts_from_outputs=True, - under_activation_checkpoint=False, # TODO - supports input_mutation and aliasing should be False by default for strictness supports_input_mutation=True, supports_aliasing=True, @@ -1537,7 +1605,7 @@ def speculate_subgraph( (f, sub_args, sub_kwargs), ) - with tx.output.subtracer(source_target, tracer) as subtracer: + with tx.output.subtracer(source_target, tracer, description) as subtracer: args = get_hop_args( tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description ) @@ -1547,7 +1615,6 @@ def speculate_subgraph( tx, subtracer, enable_grad, - under_activation_checkpoint, restore_side_effects, args, sub_kwargs, @@ -2829,7 +2896,6 @@ def create_wrapped_node( fn_args_vt, kwargs, description, - under_activation_checkpoint=False, *, subgraph_name="wrap_body", ): @@ -2848,7 +2914,6 @@ def create_wrapped_node( description, source_target=self.value, restore_side_effects=self.restore_side_effects, - under_activation_checkpoint=under_activation_checkpoint, supports_input_mutation=self.supports_input_mutation, supports_aliasing=self.supports_aliasing, ) @@ -3307,8 +3372,8 @@ def _call_function( class CheckpointHigherOrderVariable(WrapHigherOrderVariable): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - # If side effects are allowed under checkpoint, we should not restore - # the side effects after speculate subgraph. + # When skip_fwd_side_effects_in_bwd is True, we allow side effects by NOT restoring them. + # This enables collecting intermediate outputs for side effects. self.restore_side_effects = ( not torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint ) @@ -3354,7 +3419,6 @@ def _call_function( args[1:], gmod_kwargs, "torch.utils.checkpoint.checkpoint", - under_activation_checkpoint=True, ) if context_fn is not None: checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn @@ -4176,7 +4240,10 @@ def _call_function( class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): supports_input_mutation = True supports_aliasing = False - restore_side_effects = False + # TODO (tmanlaibaatar) This is in preparation for supporting side effects in invoke_subgraph. + # invoke_subgraph does not support side effects, so we restore them (default behavior). + # This means enable_side_effects_with_extra_outputs will be False. + restore_side_effects = True def install_subgraph_in_output_graph( self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name From 7741edd4ed665f3988052e260863efb508d61a03 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 1 Dec 2025 13:19:45 -0800 Subject: [PATCH 129/338] [MPS] Fix dlpack exports/imports for sliced tensors (#169272) For MPS tensor, one must pass both `id` (which is `t.storage().data()` and `t.storage_offset()`) Luckily, DLTensor already has `byte_offset` field, which feels natural to use as product of `storage_offset` and element_size. Partially extends https://github.com/pytorch/pytorch/pull/168193, but instead of writing a completely new test, fix both export and import paths of sliced tensor and unskip test_from_dlpack_noncontinguous for MPS Error out if one is attempting to create tensor with non-zero `byte_offsets` and no strides, as there are no `at::from_blob` variant that could be used Fixes https://github.com/pytorch/pytorch/issues/168177 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169272 Approved by: https://github.com/ngimel --- aten/src/ATen/DLConvertor.cpp | 21 +++++++++++++++++++-- test/test_dlpack.py | 7 +++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index ccb0ae15a11e6..b39f3eafa32df 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -356,8 +356,18 @@ ScalarType toScalarType(const DLDataType& dtype) { return stype; } + namespace { +int64_t toStorageOffset(int64_t byte_offset, ScalarType stype) { + if (byte_offset == 0) { + return 0; + } + const auto element_size = c10::elementSize(stype); + TORCH_CHECK_VALUE(byte_offset % element_size == 0, "byte offset must be multiple of element size"); + return byte_offset / element_size; +} + // The templated classes below are needed for supporting both: // - DLManagedTensor // - DLManagedTensorVersioned @@ -393,13 +403,18 @@ T* toDLPackImpl(const Tensor& src) { atDLMTensor->handle = src; atDLMTensor->tensor.manager_ctx = atDLMTensor; atDLMTensor->tensor.deleter = &deleter; - atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); + if (src.device().type() == kMPS) { + atDLMTensor->tensor.dl_tensor.data = src.storage().mutable_data(); + atDLMTensor->tensor.dl_tensor.byte_offset = src.storage_offset() * c10::elementSize(src.scalar_type()); + } else { + atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); + atDLMTensor->tensor.dl_tensor.byte_offset = 0; + } atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device()); atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); atDLMTensor->tensor.dl_tensor.shape = const_cast(src.sizes().data()); atDLMTensor->tensor.dl_tensor.strides = const_cast(src.strides().data()); - atDLMTensor->tensor.dl_tensor.byte_offset = 0; fillVersion(&atDLMTensor->tensor); return &(atDLMTensor->tensor); @@ -426,6 +441,7 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { ScalarType stype = toScalarType(dl_tensor.dtype); if (!dl_tensor.strides) { + TORCH_CHECK_VALUE(dl_tensor.byte_offset == 0, "Expected zero byte_offset"); return at::from_blob( dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), @@ -437,6 +453,7 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), IntArrayRef(dl_tensor.strides, dl_tensor.ndim), + toStorageOffset(dl_tensor.byte_offset, stype), deleter, at::device(device).dtype(stype), {device}); diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 3d6c4ae7484cb..3d27678b5864a 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -21,7 +21,6 @@ from torch.testing._internal.common_utils import ( IS_JETSON, run_tests, - skipIfMPS, skipIfTorchDynamo, TestCase, ) @@ -157,7 +156,6 @@ def test_from_dlpack(self, device, dtype): self.assertEqual(x, y) @skipMeta - @skipIfMPS # MPS crashes with noncontiguous now @onlyNativeDeviceTypes @dtypes( *all_types_and_complex_and( @@ -169,6 +167,11 @@ def test_from_dlpack(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS( + *all_mps_types_and( + torch.bool, torch.cfloat, torch.chalf, torch.uint16, torch.uint32 + ) + ) def test_from_dlpack_noncontinguous(self, device, dtype): x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5) From 4e0061c1aa52f606dda8cfab0bd7591e588faf2c Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Mon, 1 Dec 2025 19:13:22 -0800 Subject: [PATCH 130/338] [ez] add return type to _unpickle_sdp_backend (#169383) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169383 Approved by: https://github.com/Lucaskabela --- torch/_dynamo/guards.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 756996fb3f0f5..69197f44054a3 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -3327,7 +3327,7 @@ def _unpickle_bound_method(cls, func: Any, base: Any) -> Any: return types.MethodType(func, base) @staticmethod - def _unpickle_sdp_backend(name: str): + def _unpickle_sdp_backend(name: str) -> torch.nn.attention.SDPBackend: # Reconstruct from the Python-facing enum namespace return getattr(torch.nn.attention.SDPBackend, name) From 70076464a63ab218a7ceefb0e76ccd7131deb8f8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 2 Dec 2025 19:26:15 +0000 Subject: [PATCH 131/338] Revert "Remove unnecessary uses of thrust::tuple (#168936)" This reverts commit 265397e178dab071294f6a10e35226fe333b2983. Reverted https://github.com/pytorch/pytorch/pull/168936 on behalf of https://github.com/huydhn due to Sorry for reverting your change but we failed to land this due to the mismatch of ROCm version on OSS and internal ([comment](https://github.com/pytorch/pytorch/pull/168936#issuecomment-3603608569)) --- aten/src/ATen/native/cuda/ActivationEluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationGeluKernel.cu | 1 + aten/src/ATen/native/cuda/ActivationGluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardswishKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationMishKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationSiluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationThresholdKernel.cu | 2 ++ aten/src/ATen/native/cuda/Loops.cuh | 2 +- aten/src/ATen/native/cuda/group_norm_kernel.cu | 1 + aten/src/ATen/native/cuda/layer_norm_kernel.cu | 3 ++- 17 files changed, 31 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/cuda/ActivationEluKernel.cu b/aten/src/ATen/native/cuda/ActivationEluKernel.cu index 9fc29aa5539b5..5ad1f806f9ba5 100644 --- a/aten/src/ATen/native/cuda/ActivationEluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationEluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu index 87781c44e3348..cd5a0ae85e61c 100644 --- a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu @@ -5,6 +5,7 @@ #include +#include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGluKernel.cu b/aten/src/ATen/native/cuda/ActivationGluKernel.cu index 8a782a129c9fb..e28a6d61ea152 100644 --- a/aten/src/ATen/native/cuda/ActivationGluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu index f0968b957aa6d..2a0be3f5d27bf 100644 --- a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu index 813a8c07ccfac..fcacef37ceaf0 100644 --- a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu index 651cdef82543b..1642d0909f7f0 100644 --- a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu index 85aa7ccd22a9e..a18072f7a27bc 100644 --- a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu index 340a6f97d00de..72130739898fe 100644 --- a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu index 2175920917852..9a1d672428b48 100644 --- a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationMishKernel.cu b/aten/src/ATen/native/cuda/ActivationMishKernel.cu index 25ba9810e37cf..0db0e96bb180a 100644 --- a/aten/src/ATen/native/cuda/ActivationMishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationMishKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu index ebdfe245b6166..f7ddfd8502a18 100644 --- a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu index 65f4f3679f862..64ffc21123707 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu index 712c86e0e5216..0c2dc63dbcf45 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu index 430f9cbfa78bb..2d1cb4a47d7d8 100644 --- a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index e739d7d2ecee2..a80c51fa6a9cb 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -282,7 +282,7 @@ void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) using traits = function_traits; using output_t = typename traits::result_type; static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); - constexpr int num_outputs = std::tuple_size::value; + constexpr int num_outputs = thrust::tuple_size::value; constexpr int num_inputs = traits::arity; constexpr int ntensors = num_outputs + num_inputs; diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index 0ef6434f909de..77d26e915b65a 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -3,6 +3,7 @@ #include +#include #include #include diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 6f5112c605fab..84812eb22125f 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -1,9 +1,10 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include #include +#include + #include #include #include From 6f7dcf51e46d0c880db1a2f5c70de57adb576f4a Mon Sep 17 00:00:00 2001 From: Nick Riasanovsky Date: Tue, 2 Dec 2025 19:27:24 +0000 Subject: [PATCH 132/338] Add shape logging to autotuning to debug timeout issues (#169062) Summary: Add logging at the beginning of the autotuning process to capture the shapes being benchmarked before timeouts occur. This helps identify which specific tensor shapes cause compilation or benchmarking to hang, enabling faster debugging of timeout issues during kernel selection. Test Plan: Ran the code with a model that previously timed out during autotuning and verified that the shape information is now logged before the timeout occurs, making it easy to identify problematic shapes. Differential Revision: D87804069 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169062 Approved by: https://github.com/PaulZhang12 --- torch/_inductor/select_algorithm.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index f0101f01f3617..77448c914df80 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2964,6 +2964,29 @@ def do_autotuning( NoValidChoicesError: When all choices fail to compile or benchmark, or when all timing results are non-finite. """ + if log.isEnabledFor(logging.DEBUG): + # Log shape information for debugging timeout issues + sizevars = V.graph.sizevars + shapes = [ + "x".join( + map( + str, + sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ), + ) + ) + for node in input_nodes + ] + log.debug( + "[BENCHMARK DEBUG] Starting autotuning for '%s' with shapes: %s, device: %s", + name, + shapes, + layout.device.type if layout else "unknown", + ) + precompile_start_ts = time.time() with dynamo_timed( f"{name}_template_precompiling", From 066997fb38ade71e00d78e9d572e380b5f02bd3e Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 2 Dec 2025 19:34:59 +0000 Subject: [PATCH 133/338] [CUDA][FlexAttention] Use `sm8x` configs for `sm12x` for backward (#168367) Otherwise we seem to see failures in e.g., `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_non_pow_2_headdim_head_dim_94_cuda_float16` Shared memory limit of `sm12x` is close to that of `sm8x` than `sm10x` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168367 Approved by: https://github.com/drisspg --- torch/_inductor/template_heuristics/triton.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 68a34f5d1d2f1..a9925db292a36 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -906,7 +906,6 @@ class CUDAConfigHeuristic(BaseConfigHeuristic): def __init__(self) -> None: super().__init__() - self.sm_120_default_flex_config = { (torch.float32, 64): FlexConfig(128, 32, 2, 4), (torch.float32, 128): FlexConfig(128, 32, 2, 4), @@ -981,7 +980,7 @@ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfi if dtype == torch.float32: default_config = FlexConfig(64, 64, 3, 4) else: - default_config = FlexConfig(128, 64, 3, 4) + default_config = FlexConfig(64, 64, 3, 4) if capability >= (12, 0): default_config = self.sm_120_default_flex_config.get( (dtype, head_dim), default_config @@ -1014,7 +1013,6 @@ def get_flex_attn_bwd_configs( ) -> list[FlexBwDConfig]: capability = torch.cuda.get_device_capability() flex_attn_bwd_configs: list[FlexBwDConfig] = [] - if config.max_autotune: if config.max_autotune_flex_search_space == "EXHAUSTIVE": return self.exhaustive_flex_attn_bwd_configs @@ -1023,6 +1021,8 @@ def get_flex_attn_bwd_configs( major, minor = capability if dtype == torch.float32: capability_class = "float32" + elif major == 12: + capability_class = "sm12x" elif major >= 10: capability_class = "sm10x" elif capability == (9, 0): @@ -1053,6 +1053,13 @@ def get_flex_attn_bwd_configs( 64, 64, 64, 64, 3 if minor == 6 and h == 128 else 2, 4 ) ), + "sm12x": lambda h: ( + FlexBwDConfig(32, 128, 128, 32, 3, 4) + if h < 64 + else FlexBwDConfig( + 64, 64, 64, 64, 3 if minor == 6 and h == 128 else 2, 4 + ) + ), } # fmt: on From d973dc6b87d763859fe1c5bd1287e3b6b1c49d1b Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 2 Dec 2025 19:48:34 +0000 Subject: [PATCH 134/338] [ez] Remove maybe unused var in common_utils.py (#166455) I think we don't actually need this anymore since we don't use inspect.getfile to get the filename anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/166455 Approved by: https://github.com/huydhn --- torch/testing/_internal/common_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index ef199e07d6a04..df3ca03b76242 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -114,7 +114,6 @@ class ProfilingMode(Enum): PROFILING = 3 # Set by parse_cmd_line_args() if called -CI_TEST_PREFIX = "" DISABLED_TESTS_FILE = "" GRAPH_EXECUTOR : Optional[ProfilingMode] = None LOG_SUFFIX = "" @@ -957,7 +956,6 @@ def _get_test_report_path(): return os.path.join('test-reports', test_source) def parse_cmd_line_args(): - global CI_TEST_PREFIX global DISABLED_TESTS_FILE global GRAPH_EXECUTOR global LOG_SUFFIX @@ -1035,8 +1033,6 @@ def run_unittest_help(argv): set_rng_seed() - # CI Prefix path used only on CI environment - CI_TEST_PREFIX = str(Path(os.getcwd())) def wait_for_process(p, timeout=None): try: @@ -1160,9 +1156,6 @@ def chunk_list(lst, nchunks): # sanitize filename e.g., distributed/pipeline/sync/skip/test_api.py -> distributed.pipeline.sync.skip.test_api def sanitize_test_filename(filename): - # inspect.getfile returns absolute path in some CI jobs, converting it to relative path if needed - if filename.startswith(CI_TEST_PREFIX): - filename = filename[len(CI_TEST_PREFIX) + 1:] strip_py = re.sub(r'.py$', '', filename) return re.sub('/', r'.', strip_py) From 2c87367e6f88662cd5cedbd1537748b7948c38e1 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 1 Dec 2025 16:15:29 -0800 Subject: [PATCH 135/338] dist/debug: add TCPStore debug page (#169095) This adds a TCPStore debug page. Test plan: run debug server [ 20251125_17h23m00s_grim ](url) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169095 Approved by: https://github.com/fduwjj --- test/distributed/test_debug.py | 7 +++++++ torch/distributed/debug/_frontend.py | 22 ++++++++++++++++++++++ torch/distributed/debug/_store.py | 5 +++-- 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py index e1612d7639a13..ebd208d4c5b35 100644 --- a/test/distributed/test_debug.py +++ b/test/distributed/test_debug.py @@ -50,6 +50,13 @@ def fetch(path: str) -> str: self.assertEqual(resp.status_code, 404) self.assertIn("Handler not found: /blah", resp.text) + with self.subTest("tcpstore"): + store.set("test", "value") + store.set("test2", "a" * 1000) + out = fetch("/tcpstore") + self.assertIn("test: b'value'", out) + self.assertIn("test2: b'" + "a" * 95 + "...", out) + stop_debug_server() diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py index 10dae4c2802cd..c2d57606c4d45 100644 --- a/torch/distributed/debug/_frontend.py +++ b/torch/distributed/debug/_frontend.py @@ -97,6 +97,7 @@ def format_json(blob: str): FlightRecorder NCCL torch profiler Wait Counters + TCPStore
@@ -209,6 +210,19 @@ def format_json(blob: str): {% endif %} {% endfor %} +{% endblock %} + """, + "tcpstore.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}TCPStore Keys{% endblock %}

+{% endblock %} +{% block content %} +
+    {% for k, v in zip(keys, values) -%}
+{{ k }}: {{ v | truncate(100) }}
+    {% endfor %}
+    
{% endblock %} """, } @@ -259,6 +273,7 @@ def __init__(self, port: int): "/fr_trace_nccl": self._handle_fr_trace_nccl, "/profile": self._handle_profiler, "/wait_counters": self._handle_wait_counters, + "/tcpstore": self._handle_tcpstore, } # Create HTTP server @@ -354,6 +369,13 @@ def _handle_wait_counters(self, req: HTTPRequestHandler) -> bytes: "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps ) + def _handle_tcpstore(self, req: HTTPRequestHandler) -> bytes: + store = tcpstore_client(prefix="") + keys = store.list_keys() + keys.sort() + values = [repr(v) for v in store.multi_get(keys)] + return self._render_template("tcpstore.html", keys=keys, values=values) + def main(port: int) -> None: server = FrontendServer(port=port) diff --git a/torch/distributed/debug/_store.py b/torch/distributed/debug/_store.py index 70c6cd0f3dde1..487dd30abd6af 100644 --- a/torch/distributed/debug/_store.py +++ b/torch/distributed/debug/_store.py @@ -11,7 +11,7 @@ def get_world_size() -> int: return int(os.environ["WORLD_SIZE"]) -def tcpstore_client() -> dist.Store: +def tcpstore_client(prefix: str = "debug_server") -> dist.Store: MASTER_ADDR = os.environ["MASTER_ADDR"] MASTER_PORT = int(os.environ["MASTER_PORT"]) @@ -20,5 +20,6 @@ def tcpstore_client() -> dist.Store: port=MASTER_PORT, is_master=False, ) - store = dist.PrefixStore("debug_server", store) + if prefix: + store = dist.PrefixStore(prefix, store) return store From d998c03304cb6ede76e1ed535b4ddeb6c2bf40ec Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 1 Dec 2025 16:15:29 -0800 Subject: [PATCH 136/338] dist/debug: use aiohttp to scale to 100k workers (#169096) This uses `aiohttp` to run all requests concurrently. This cuts the latency at 10k from `15s -> 5s` and is `50s` at 100k. I expect that 100k number is a little sus given I was running this on a single machine with only 4 workers. Test plan: patch fetch_all to do 100k requests instead Pull Request resolved: https://github.com/pytorch/pytorch/pull/169096 Approved by: https://github.com/fduwjj ghstack dependencies: #169095 --- .ci/docker/requirements-ci.txt | 1 + torch/distributed/debug/__init__.py | 6 ++++ torch/distributed/debug/_frontend.py | 53 ++++++++++++++++++++++++---- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 242cbaafa059e..f00516ccf1293 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -404,4 +404,5 @@ tabulate==0.9.0 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI Jinja2==3.1.6 +aiohttp==3.13.2 #Description: required for torch.distributed.debug diff --git a/torch/distributed/debug/__init__.py b/torch/distributed/debug/__init__.py index 46267a686e86d..93295802ae847 100644 --- a/torch/distributed/debug/__init__.py +++ b/torch/distributed/debug/__init__.py @@ -29,6 +29,12 @@ def start_debug_server(port: int = 25999, worker_port: int = 0) -> None: deadlocked distributed jobs across all ranks simultaneously. This collects data such as stack traces, FlightRecorder events, and performance profiles. + This depends on dependencies which are not installed by default. + + Dependencies: + - Jinja2 + - aiohttp + WARNING: This is intended to only be used in trusted network environments. The debug server is not designed to be secure and should not be exposed to the public internet. See SECURITY.md for more details. diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py index c2d57606c4d45..58389abfe97da 100644 --- a/torch/distributed/debug/_frontend.py +++ b/torch/distributed/debug/_frontend.py @@ -1,13 +1,14 @@ +import asyncio import json import logging import socket import threading -from collections.abc import Iterator +from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from urllib.parse import parse_qs, urlparse -import requests from jinja2 import DictLoader, Environment from torch.distributed.debug._store import get_world_size, tcpstore_client @@ -16,16 +17,54 @@ logger: logging.Logger = logging.getLogger(__name__) -def fetch_all( - endpoint: str, args: str = "" -) -> tuple[list[str], Iterator[requests.Response]]: +@dataclass(slots=True) +class Response: + status_code: int + text: str + + +def fetch_thread_pool(urls: list[str]) -> Iterable[Response]: + # late import for optional dependency + import requests + + max_workers = 20 + + def get(url: str) -> Response: + resp = requests.post(url) + return Response(resp.status_code, resp.text) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + resps = executor.map(get, urls) + + return resps + + +def fetch_aiohttp(urls: list[str]) -> Iterable[Response]: + # late import for optional dependency + import aiohttp + + async def fetch(session: aiohttp.ClientSession, url: str) -> Response: + async with session.post(url) as resp: + text = await resp.text() + return Response(resp.status, text) + + async def gather(urls: list[str]) -> Iterable[Response]: + async with aiohttp.ClientSession() as session: + return await asyncio.gather(*[fetch(session, url) for url in urls]) + + return asyncio.run(gather(urls)) + + +def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]: store = tcpstore_client() keys = [f"rank{r}" for r in range(get_world_size())] addrs = store.multi_get(keys) addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] - with ThreadPoolExecutor(max_workers=10) as executor: - resps = executor.map(requests.post, addrs) + try: + resps = fetch_aiohttp(addrs) + except ImportError: + resps = fetch_thread_pool(addrs) return addrs, resps From 082e96b68dfcd16cab7cfafc4d3d055767dab3eb Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 1 Dec 2025 16:15:30 -0800 Subject: [PATCH 137/338] distributed/debug: add fr trace analysis (#169144) This adds FlightRecorder trace analysis using frtrace to the debug server. Test plan: 20251126_14h58m19s_grim Pull Request resolved: https://github.com/pytorch/pytorch/pull/169144 Approved by: https://github.com/fduwjj ghstack dependencies: #169095, #169096 --- test/distributed/test_debug.py | 35 +++--- torch/distributed/debug/_frontend.py | 102 +++++++++++++++++- .../flight_recorder/components/utils.py | 5 +- 3 files changed, 122 insertions(+), 20 deletions(-) diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py index ebd208d4c5b35..1c9dfcf96b83f 100644 --- a/test/distributed/test_debug.py +++ b/test/distributed/test_debug.py @@ -20,7 +20,7 @@ class TestDebug(TestCase): - def test_basics(self) -> None: + def test_all(self) -> None: store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False) os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(store.port) @@ -36,19 +36,30 @@ def fetch(path: str) -> str: start_debug_server(port=port) - self.assertIn("torch profiler", fetch("/")) - self.assertIn("View 0", fetch("/profile?duration=0.01")) - self.assertIn("test_basics", fetch("/stacks")) - self.assertIn("pg_status", fetch("/fr_trace")) - self.assertIn("Rank 0", fetch("/wait_counters")) + with self.subTest("index"): + self.assertIn("torch profiler", fetch("/")) - if torch.cuda.is_available(): - self.assertIn("pg_status", fetch("/fr_trace_nccl")) + with self.subTest("profile"): + self.assertIn("View 0", fetch("/profile?duration=0.01")) - # test errors - resp = session.get(f"http://localhost:{port}/blah") - self.assertEqual(resp.status_code, 404) - self.assertIn("Handler not found: /blah", resp.text) + with self.subTest("stacks"): + self.assertIn("test_all", fetch("/stacks")) + + with self.subTest("wait_counters"): + self.assertIn("Rank 0", fetch("/wait_counters")) + + with self.subTest("fr_trace"): + self.assertIn("Memberships", fetch("/fr_trace")) + self.assertIn("pg_status", fetch("/fr_trace_json")) + + if torch.cuda.is_available(): + self.assertIn("Memberships", fetch("/fr_trace_nccl")) + self.assertIn("pg_status", fetch("/fr_trace_nccl_json")) + + with self.subTest("error codes"): + resp = session.get(f"http://localhost:{port}/blah") + self.assertEqual(resp.status_code, 404) + self.assertIn("Handler not found: /blah", resp.text) with self.subTest("tcpstore"): store.set("test", "value") diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py index 58389abfe97da..d31d3e734c28b 100644 --- a/torch/distributed/debug/_frontend.py +++ b/torch/distributed/debug/_frontend.py @@ -10,8 +10,17 @@ from urllib.parse import parse_qs, urlparse from jinja2 import DictLoader, Environment +from tabulate import tabulate from torch.distributed.debug._store import get_world_size, tcpstore_client +from torch.distributed.flight_recorder.components.builder import build_db +from torch.distributed.flight_recorder.components.config_manager import JobConfig +from torch.distributed.flight_recorder.components.types import ( + Collective, + Group, + Membership, + NCCLCall, +) logger: logging.Logger = logging.getLogger(__name__) @@ -22,6 +31,13 @@ class Response: status_code: int text: str + def raise_for_status(self): + if self.status_code != 200: + raise RuntimeError(f"HTTP {self.status_code}: {self.text}") + + def json(self): + return json.loads(self.text) + def fetch_thread_pool(urls: list[str]) -> Iterable[Response]: # late import for optional dependency @@ -132,8 +148,10 @@ def format_json(blob: str): Home Python Stack Traces - FlightRecorder + FlightRecorder CPU + (JSON) FlightRecorder NCCL + (JSON) torch profiler Wait Counters TCPStore @@ -262,6 +280,22 @@ def format_json(blob: str): {{ k }}: {{ v | truncate(100) }} {% endfor %} +{% endblock %} + """, + "fr_trace.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} +

Groups

+ {{ groups | safe }} +

Memberships

+ {{ memberships | safe }} +

Collectives

+ {{ collectives | safe }} +

NCCL Calls

+ {{ ncclcalls | safe }} {% endblock %} """, } @@ -275,6 +309,13 @@ class _IPv6HTTPServer(ThreadingHTTPServer): class HTTPRequestHandler(BaseHTTPRequestHandler): frontend: "FrontendServer" + def log_message(self, format, *args): + logger.info( + "%s %s", + self.client_address[0], + format % args, + ) + def do_GET(self): self.frontend._handle_request(self) @@ -309,7 +350,9 @@ def __init__(self, port: int): "/": self._handle_index, "/stacks": self._handle_stacks, "/fr_trace": self._handle_fr_trace, + "/fr_trace_json": self._handle_fr_trace_json, "/fr_trace_nccl": self._handle_fr_trace_nccl, + "/fr_trace_nccl_json": self._handle_fr_trace_nccl_json, "/profile": self._handle_profiler, "/wait_counters": self._handle_wait_counters, "/tcpstore": self._handle_tcpstore, @@ -336,7 +379,7 @@ def _serve(self) -> None: try: self._server.serve_forever() except Exception: - logger.exception("got exception in checkpoint server") + logger.exception("got exception in frontend server") def join(self) -> None: self._thread.join() @@ -350,12 +393,13 @@ def _handle_request(self, req: HTTPRequestHandler) -> None: handler = self._routes[path] try: resp = handler(req) - except Exception as e: + # Catch SystemExit to not crash when FlightRecorder errors. + except (Exception, SystemExit) as e: logger.exception( - "Exception in checkpoint server when handling %s", + "Exception in frontend server when handling %s", path, ) - req.send_error(500, str(e)) + req.send_error(500, f"Exception: {repr(e)}") return req.send_response(200) @@ -375,9 +419,50 @@ def _handle_stacks(self, req: HTTPRequestHandler) -> bytes: "raw_resp.html", title="Stacks", addrs=addrs, resps=resps ) + def _render_fr_trace(self, addrs: list[str], resps: list[Response]) -> bytes: + config = JobConfig() + # pyrefly: ignore [bad-assignment] + args = config.parse_args(args=[]) + args.allow_incomplete_ranks = True + args.verbose = True + + details = {} + for rank, resp in enumerate(resps): + resp.raise_for_status() + dump = { + "rank": rank, + "host_name": addrs[rank], + **resp.json(), + } + if "entries" not in dump: + dump["entries"] = [] + details[f"rank{rank}.json"] = dump + + version = next(iter(details.values()))["version"] + + db = build_db(details, args, version) + + return self._render_template( + "fr_trace.html", + title="FlightRecorder", + groups=tabulate(db.groups, headers=Group._fields, tablefmt="html"), + memberships=tabulate( + db.memberships, headers=Membership._fields, tablefmt="html" + ), + collectives=tabulate( + db.collectives, headers=Collective._fields, tablefmt="html" + ), + ncclcalls=tabulate(db.ncclcalls, headers=NCCLCall._fields, tablefmt="html"), + ) + def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes: addrs, resps = fetch_all("fr_trace_json") + return self._render_fr_trace(addrs, list(resps)) + + def _handle_fr_trace_json(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + return self._render_template( "json_resp.html", title="FlightRecorder", @@ -388,6 +473,11 @@ def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes: def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes: addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + return self._render_fr_trace(addrs, list(resps)) + + def _handle_fr_trace_nccl_json(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + return self._render_template( "json_resp.html", title="FlightRecorder NCCL", @@ -417,6 +507,8 @@ def _handle_tcpstore(self, req: HTTPRequestHandler) -> bytes: def main(port: int) -> None: + logger.setLevel(logging.INFO) + server = FrontendServer(port=port) logger.info("Frontend server started on port %d", server._server.server_port) server.join() diff --git a/torch/distributed/flight_recorder/components/utils.py b/torch/distributed/flight_recorder/components/utils.py index 25c5350381187..6ab7919a2a24d 100644 --- a/torch/distributed/flight_recorder/components/utils.py +++ b/torch/distributed/flight_recorder/components/utils.py @@ -702,9 +702,8 @@ def check_no_missing_dump_files( for membership in memberships: all_ranks.add(int(membership.global_rank)) dumps_ranks = {int(key) for key in entries} - assert dumps_ranks == all_ranks, ( - f"Missing dump files from ranks {all_ranks - dumps_ranks}" - ) + missing = all_ranks - dumps_ranks + assert len(missing) == 0, f"Missing dump files from ranks {missing}" def check_version(version_by_ranks: dict[str, str], version: str) -> None: From 6f2783a6c08e1db34275ff25176ffe9aebc30a71 Mon Sep 17 00:00:00 2001 From: "Patrick C. Toulme" Date: Tue, 2 Dec 2025 20:19:50 +0000 Subject: [PATCH 138/338] [MTIA] Support bfloat16 in half_to_float (#168938) Summary: Support bfloat16 in half_to_float when tracing a softmax op that has half_to_float=True. Test Plan: CI Differential Revision: D87607452 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168938 Approved by: https://github.com/zou3519, https://github.com/cyyever --- test/test_nn.py | 10 ++++++++++ torch/_meta_registrations.py | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/test_nn.py b/test/test_nn.py index 176516713feb1..2b1a8166ef5e7 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -12149,6 +12149,16 @@ def test_softmax_bfloat16(self, device): # test softmax with large input value which causes exp() to overflow _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0) + def test_softmax_bfloat16_half_to_float(self): + # half_to_float is only supported on MTIA + # Test meta tensors - both dtypes work for meta regardless of target device + for dtype in [torch.half, torch.bfloat16]: + x_meta = torch.randn(8, 16, device='meta', dtype=dtype) + result_meta = torch._softmax(x_meta, dim=1, half_to_float=True) + # Meta tensor result should also be float32 + self.assertEqual(result_meta.dtype, torch.float32) + self.assertEqual(result_meta.shape, (8, 16)) + def test_nll_loss_1d_input_1d_target_invalid_size(self, device): x = torch.randn(10, device=device) t = torch.randint(0, 10, (3,), dtype=torch.int64, device=device) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 0055bdd77f315..d48b421f105c7 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -8103,7 +8103,8 @@ def meta_scaled_grouped_mm( @out_wrapper() def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor: if half_to_float: - assert x.dtype == torch.half + assert x.dtype in [torch.half, torch.bfloat16] + computation_dtype, result_dtype = utils.elementwise_dtypes( x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ) From 8d9dd9603e5ee26c01007f0cd4f018e584840922 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Mon, 1 Dec 2025 14:17:55 -0800 Subject: [PATCH 139/338] ensure that regional inductor uses a boxed calling convention (#168277) I noticed when running regional inductor that it would emit the unboxed compiler warnings: ``` UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments ``` I tried tweaking the regional_inductor backend to use boxed codegen. Two followup changes I needed were: (1) making it pickleable (2) Ensuring that `invoke_subgraph()` knows how to handle the case where the subgraph it calls expects boxed inputs Pull Request resolved: https://github.com/pytorch/pytorch/pull/168277 Approved by: https://github.com/anijain2305, https://github.com/ezyang --- test/dynamo/test_regional_inductor.py | 31 ++++++++++++++++++++++ torch/_dynamo/backends/common.py | 6 +++++ torch/_higher_order_ops/invoke_subgraph.py | 5 +++- torch/fx/passes/regional_inductor.py | 5 ++++ 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_regional_inductor.py b/test/dynamo/test_regional_inductor.py index 524d7fa499c39..b087d44dec606 100644 --- a/test/dynamo/test_regional_inductor.py +++ b/test/dynamo/test_regional_inductor.py @@ -1,6 +1,7 @@ # Owner(s): ["module: dynamo"] import functools +import warnings from typing import TYPE_CHECKING import torch @@ -102,6 +103,36 @@ def fn(x, y): _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y)) self.assertEqual(len(codes), 2) + def test_boxed_calling_convention(self): + def fn(x, y): + sin = torch.sin(x) + + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + + return torch.sin(add) + + opt_fn = torch.compile( + fn, backend=aot_eager_regional_inductor(serialize=False), fullgraph=True + ) + x = torch.randn(10, requires_grad=True) + y = torch.randn(10, requires_grad=True) + + # Check that inductor compilation is called twice + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y)) + + msgs = [str(warn.message) for warn in w] + self.assertTrue( + not any( + "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments" + in m + for m in msgs + ) + ) + @parametrize("serialize", [False, True]) def test_repeated_blocks(self, serialize): def fn(x, y): diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index 2ffd9523bdf15..0d2b6ecff0c17 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -78,6 +78,7 @@ def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: # The two disables here: # - stop TorchDynamo from trying to compile the bw_compiler function itself # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces + return disable( disable( bw_compiler_fn, reason="do not trace backward compiler function" @@ -85,12 +86,17 @@ def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: reason="do not trace generated backwards pass", ) + _wrapped_bw_compiler._is_wrapped_bw_compiler = ( # pyrefly: ignore [missing-attribute] + True + ) return _wrapped_bw_compiler bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"] if isinstance(bw_compiler, SerializableAOTDispatchCompiler): bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn) + elif getattr(bw_compiler, "_is_wrapped_bw_compiler", False): + bw_compiler.compiler_fn = bw_compiler else: bw_compiler = wrap_bw_compiler(bw_compiler) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index bb0d6cef3ee6f..8eb3901ab0734 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -605,7 +605,10 @@ def _(subgraph, identifier, *operands): mode = _get_current_dispatch_mode() assert mode is None, "Mode should never be enabled for CPU/CUDA key" - return subgraph(*operands) + if getattr(subgraph, "_boxed_call", False): + return subgraph(list(operands)) + else: + return subgraph(*operands) @invoke_subgraph.py_functionalize_impl diff --git a/torch/fx/passes/regional_inductor.py b/torch/fx/passes/regional_inductor.py index c3f9c22d252d3..4146fd6c967bf 100644 --- a/torch/fx/passes/regional_inductor.py +++ b/torch/fx/passes/regional_inductor.py @@ -122,6 +122,7 @@ def _needs_inductor_compile(node: torch.fx.Node): def _compile_fx_annotated_nodes_with_inductor(gm): + from torch.fx.graph import _BoxedCodeGen from torch.fx.passes.operator_support import OperatorSupport found_marked_node = False @@ -141,6 +142,10 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: marked_nodes = InductorMarkedNodes() gm = _partition_by_supported_nodes(gm, marked_nodes, "__marked_inductor_submod") gm = _compile_submod(gm, "__marked_inductor_submod") + + gm.graph.set_codegen(_BoxedCodeGen()) + gm.recompile() + return gm From b7d60685f8cbc939b68a20871e90db67e729329b Mon Sep 17 00:00:00 2001 From: eellison Date: Tue, 2 Dec 2025 09:13:17 -0800 Subject: [PATCH 140/338] fix for using cross-pg overlap (#169384) Now that we do cross-pg overlap, when a wait is a on a different pg 1, and used for overlapping a node on pg 2, we need to include it in our dependencies. because we were entering into `_schedulable_wait_node` the check below for overlap wasnt firing. repro sort of long for test., but verified it fixed error Pull Request resolved: https://github.com/pytorch/pytorch/pull/169384 Approved by: https://github.com/IvanKobzarev --- torch/_inductor/fx_passes/overlap_preserving_bucketer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index 7fc456f388deb..7c819f37a1a83 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -219,6 +219,9 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]: wait_input = node.args[0] if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg: node_type = "waits" + # Wait for a different PG but hiding a collective on this PG + elif node in hiding_nodes: + node_type = "compute" elif is_compute_node(node) or node in hiding_nodes: node_type = "compute" From d8fd5c6eed28e5004150691d048a3f6785e19a8e Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Tue, 2 Dec 2025 21:26:05 +0000 Subject: [PATCH 141/338] [MPS] Modify eps for gradcheck for float32 dtypes (#168902) Modify eps for gradcheck for float32 dtypes. Followup to: #168156 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168902 Approved by: https://github.com/malfet --- test/test_sparse.py | 105 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 92 insertions(+), 13 deletions(-) diff --git a/test/test_sparse.py b/test/test_sparse.py index 58398a915ff17..25d46892de258 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -61,8 +61,69 @@ def _op_supports_any_sparse(op): # sharding on sandcastle. This line silences flake warnings load_tests = load_tests # noqa: PLW0127 +def _make_lowp_aware_gradcheck(gradcheck_fn): + """ + Wraps a gradcheck function to handle low precision dtypes + + For float64/complex128 inputs: runs gradcheck directly + For lower precision inputs: compares backward() on device against + backward() on CPU in float64/complex128 + """ + HIGHP_DTYPES = (torch.float64, torch.complex128) + + def needs_backward_comparison(inputs): + return any(inp.dtype not in HIGHP_DTYPES for inp in inputs) + + def clone_inputs_cpu(inputs): + cloned = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + cloned.append(inp) + continue + gradcheck_dtype = torch.complex128 if inp.dtype.is_complex else torch.float64 + c = inp.detach().clone().to("cpu").to(gradcheck_dtype) + if c.is_sparse: + c = c.coalesce() + c = c.requires_grad_(inp.requires_grad) + cloned.append(c) + return tuple(cloned) + + def compute_grads(fn, inputs): + grad_inputs = [x for x in inputs if isinstance(x, torch.Tensor) and x.requires_grad] + out = fn(*inputs) + grads = torch.autograd.grad(out, grad_inputs, torch.ones_like(out), allow_unused=True) + return grads, grad_inputs + + @functools.wraps(gradcheck_fn) + def wrapped(fn, inputs, *args, **kwargs): + inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) + if not needs_backward_comparison(inputs): + return gradcheck_fn(fn, inputs, *args, **kwargs) + + ref_grads, ref_inputs = compute_grads(fn, clone_inputs_cpu(inputs)) + orig_grads, orig_inputs = compute_grads(fn, inputs) + + for i, (og, rg, o_inp, r_inp) in enumerate(zip(orig_grads, ref_grads, orig_inputs, ref_inputs)): + og_dense = og.to_dense() if og.is_sparse else og + rg_dense = rg.to_dense() if rg.is_sparse else rg + og_dense = og_dense.to('cpu') + rg_dense = rg_dense.to(device='cpu', dtype=og_dense.dtype) + if not torch.allclose(og_dense, rg_dense): + max_diff = (og_dense - rg_dense).abs().max() + raise AssertionError( + f"Gradient mismatch for input {i}:\n" + f" input dtype/device: orig={o_inp.dtype}/{o_inp.device}, ref={r_inp.dtype}/{r_inp.device}\n" + f" shapes: {tuple(og_dense.shape)} vs {tuple(rg_dense.shape)}\n" + f" max abs diff: {max_diff}" + ) + return True + if hasattr(gradcheck_fn, 'masked'): + wrapped.masked = gradcheck_fn.masked + return wrapped + # batched grad doesn't support sparse gradcheck = functools.partial(gradcheck, check_batched_grad=False) +gradcheck = _make_lowp_aware_gradcheck(gradcheck) CUSPARSE_SPMM_COMPLEX128_SUPPORTED = ( IS_WINDOWS and torch.version.cuda @@ -646,8 +707,7 @@ def test_tensor(x, res): def fn(x): return x.to_dense(masked_grad=gradcheck.masked) x.requires_grad_(True) - kwargs = {"eps": 1e-4} if device == "mps:0" else {} - gradcheck(fn, (x,), **kwargs) + gradcheck(fn, (x,)) i = self.index_tensor([ [0, 1, 2, 2], @@ -1034,8 +1094,7 @@ def test_shape(sparse_dims, nnz, with_size): else: self.assertFalse(s_permuted.is_coalesced()) - kwargs = {"eps": 1e-4} if device == "mps:0" else {} - gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_(), **kwargs) + gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_()) else: # otherwise check if exception is thrown fail_message = "transpositions between sparse and dense dimensions are not allowed" @@ -1698,8 +1757,7 @@ def test_shape(d1, d2, d3, nnz, transposed): def fn(S, D): return torch.sparse.mm(S, D) - kwargs = {"eps": 1e-4, "atol": 2e-5} if device == "mps:0" else {} - gradcheck(fn, (S, D), masked=True, **kwargs) + gradcheck(fn, (S, D), masked=True) test_shape(7, 8, 9, 20, False) test_shape(7, 8, 9, 20, True) @@ -1713,16 +1771,16 @@ def test_sparse_mul(self, device, dtype, coalesced, gradcheck): # https://github.com/pytorch/pytorch/issues/79914 a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) - gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(masked_grad=gradcheck.masked), [a, b], eps=1e-4) + gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(masked_grad=gradcheck.masked), [a, b]) def test_shape(sparse_dims, nnz, with_shape): a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense()) - gradcheck(lambda x, y: (x * y).to_dense(), [a, b], eps=1e-4) + gradcheck(lambda x, y: (x * y).to_dense(), [a, b]) # Issues with 0-dim indices/values - gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True, eps=3e-4, atol=5e-5) + gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True) test_shape(2, 3, [2, 3, 4, 5]) test_shape(2, 3, [2, 2, 0]) @@ -2246,7 +2304,6 @@ def test_sparse_mask_backward(self, device, dtype): nnzs = (0, 5, 15, 25) lhs_data = torch.arange(1, 26, device=device).reshape(shape).to(dtype).to_sparse(sparse_dims) - for nnz in nnzs: for lhs_is_coalesced, rhs_is_coalesced in product(*repeat((True, False), 2)): lhs = torch.sparse_coo_tensor( @@ -2265,9 +2322,31 @@ def test_sparse_mask_backward(self, device, dtype): # sparsity_pattern(lhs) == sparsity_pattern(lhs.grad). # lhs.sparse_mask(lhs_mask) accomplishes that. lhs_mask = lhs.detach().clone() - gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,), - masked=True, eps=3e-4, atol=5e-5) - gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False, eps=3e-4, atol=5e-5) + + def op_masked(x): + m, r = lhs_mask, rhs + if x.device != m.device: + m = m.to(device=x.device) + r = r.to(device=x.device) + return x.sparse_mask(m).sparse_mask(r).to_dense(masked_grad=True) + + gradcheck( + op_masked, + (lhs,), + masked=True + ) + + def op_unmasked(x): + r = rhs + if x.device != r.device: + r = r.to(device=x.device) + return x.sparse_mask(r).to_dense(masked_grad=False) + + gradcheck( + op_unmasked, + (lhs, ), + masked=False + ) @coalescedonoff @dtypes(torch.double, torch.cdouble) From fec710bf89173f5355468a7ce1afe9157c3d9009 Mon Sep 17 00:00:00 2001 From: Agron Tsai Date: Tue, 2 Dec 2025 21:27:37 +0000 Subject: [PATCH 142/338] Triton 3.5+fb: Drop the legacy support for autoWS (#169089) Summary: https://www.internalfb.com/intern/testinfra/diagnostics/11258999198456358.562950213489292.1763806758/ Test Plan: buck test fbcode//mode/opt fbcode//caffe2/test/inductor:triton_heuristics -- --exact 'fbcode//caffe2/test/inductor:triton_heuristics - test_template_function_ws (caffe2.test.inductor.test_triton_heuristics.TestTritonHeuristics)' Differential Revision: D87881729 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169089 Approved by: https://github.com/htyu --- torch/_inductor/runtime/triton_compat.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/runtime/triton_compat.py b/torch/_inductor/runtime/triton_compat.py index faae38ea46dc1..49ceacb50bc3d 100644 --- a/torch/_inductor/runtime/triton_compat.py +++ b/torch/_inductor/runtime/triton_compat.py @@ -76,11 +76,8 @@ def _triton_config_has(param_name: str) -> bool: return False return param_name in inspect.signature(triton.Config.__init__).parameters - HAS_WARP_SPEC = ( - hasattr(tl, "async_task") - and _triton_config_has("num_consumer_groups") - and _triton_config_has("num_buffers_warp_spec") - ) + # Drop the legacy support of autoWS + HAS_WARP_SPEC = False try: from triton import knobs From 15da21026cb13cd20257dc9e96830db108743c10 Mon Sep 17 00:00:00 2001 From: Malay Bag Date: Tue, 2 Dec 2025 21:27:47 +0000 Subject: [PATCH 143/338] [torch.export] Copy common custom metadata of children node to parent (call_module) node (#167952) Summary: Same module can be used with and without dynamo disabled flag. So instead of marking the module as dynamo disabled, marking the call module which dynamo disabled. This can introduce a new submodule to encapsulate dynamo disabled nodes. For that reason, prune_pytree_flatten_unflatten need to be updated. Test Plan: ``` buck test mode/opt caffe2/test:test_export -- 'test_uplift_common_custom_meta' ``` Differential Revision: D86354586 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167952 Approved by: https://github.com/angelayi --- test/export/test_export.py | 103 ++++++++++++++++++++++++++++++++++--- torch/export/unflatten.py | 22 ++++++++ 2 files changed, 118 insertions(+), 7 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 1e1f40fba99df..92ea28c077e52 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -285,6 +285,28 @@ def get_hop_schema(ep: torch.export.ExportedProgram): return torch._library.utils.hop_schema_from_fx_node(hop_node) +def cleanup_dynamo_metadata(ep: torch.export.ExportedProgram) -> None: + for node in ep.graph.nodes: + if "custom" in node.meta: + node.meta["custom"] = { + k: v + for k, v in node.meta["custom"].items() + if "_torchdynamo_disable" not in k + } + + +def cleanup_dispatch_trace_metadata(mod: torch.export.ExportedProgram) -> None: + for node in mod.graph.nodes: + if ( + "custom" not in node.meta + or "_torchdynamo_disable_method" not in node.meta["custom"] + or node.meta["custom"]["_torchdynamo_disable_method"] + not in ["dispatch_trace", "trace"] + ): + continue + del node.meta["custom"] + + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestDynamismExpression(TestCase): def test_export_inline_constraints(self): @@ -742,13 +764,7 @@ def forward(self, x, y): # clean up _torchdynamo related meta data as it could vary depending on the caller # https://github.com/pytorch/pytorch/issues/167432 - for node in ep.graph.nodes: - if "custom" in node.meta: - node.meta["custom"] = { - k: v - for k, v in node.meta["custom"].items() - if "_torchdynamo_disable" not in k - } + cleanup_dynamo_metadata(ep) custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module()) @@ -762,6 +778,79 @@ def forward(self, x, y): ('call_function', 'mul', {'moo': 0})""", ) + def test_uplift_common_custom_meta(self) -> None: + class N(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + 2 + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.n = N() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + with torch.fx.traceback.annotate({"moo": 1}): + z = self.n(x) + 1 + return z @ y + + inp = (torch.rand(2, 2), torch.rand(2, 2)) + with torch.fx.traceback.preserve_node_meta(): + ep = torch.export.export(M(), inp) + cleanup_dynamo_metadata(ep) + unf = unflatten(ep) + unf_node_map = {node.name: node for node in unf.graph.nodes} + self.assertTrue("custom" in unf_node_map["n"].meta) + self.assertEqual(unf_node_map["n"].meta["custom"], {"moo": 1}) + for node in unf.n.graph.nodes: + self.assertTrue("custom" not in node.meta or not node.meta["custom"]) + + def test_uplift_common_custom_meta_with_multiple_calls(self) -> None: + class N(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(2, 2)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.buffer + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.n = N() + + @torch._dynamo.disable() + def foo1(self, x: torch.Tensor) -> torch.Tensor: + return self.n(x) @ x + + def foo2(self, x: torch.Tensor) -> torch.Tensor: + return self.n(x) * x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.foo1(x) + self.foo2(x) + self.foo1(x) + + m = M() + x = (torch.randn(2, 2),) + with torch.fx.traceback.preserve_node_meta(): + ep = torch.export.export(m, x) + cleanup_dispatch_trace_metadata(ep) + unf = torch.export.unflatten(ep) + unf_node_map = {node.name: node for node in unf.graph.nodes} + self.assertTrue("custom" in unf_node_map["n"].meta) + self.assertFalse("custom" in unf_node_map["n_1"].meta) + self.assertTrue("custom" in unf_node_map["n_2"].meta) + self.assertTrue("_torchdynamo_disable_method", unf_node_map["n"].meta["custom"]) + self.assertTrue( + "_torchdynamo_disable_method", unf_node_map["n_2"].meta["custom"] + ) + self.assertEqual( + unf_node_map["n"].meta["custom"]["_torchdynamo_disable_method"], "foo1" + ) + self.assertEqual( + unf_node_map["n_2"].meta["custom"]["_torchdynamo_disable_method"], "foo1" + ) + for node in unf.n.graph.nodes: + self.assertTrue("custom" not in node.meta or not node.meta["custom"]) + @requires_gpu def test_flex_attention_export(self): from torch.nn.attention.flex_attention import create_block_mask, flex_attention diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index a3f86fabceb7b..1af396e6bd29d 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -1278,6 +1278,27 @@ def remap_input(self, x): f"Could not run remap_input() on op type: {x.op} for node {x}" ) + def uplift_common_custom_metadata(self) -> None: + # Copy custom metadata if all nodes have same custom metadata + custom_meta = None + for node in self.node_map.values(): + curr_meta = node.meta.get("custom", {}) + if custom_meta is None: + # first node + custom_meta = curr_meta + continue + + if curr_meta != custom_meta: + custom_meta = {} + break + + if custom_meta: + # Lift common custom metadata to parent node and clear children node's custom metadata + assert self.parent_call_module is not None + self.parent_call_module.meta["custom"] = custom_meta + for node in self.node_map.values(): + del node.meta["custom"] + def finalize_outputs(self): self.created_modules.pop(self.fqn, None) @@ -1356,6 +1377,7 @@ def get_actual_output_node(output): if isinstance(graph_outputs, torch.fx.Node) else [o.meta.get("val") for o in graph_outputs] ) + self.uplift_common_custom_metadata() if len(orig_outputs) == 1 and signature is None: self.parent.node_map[orig_outputs[0]] = parent_out From c0660bcee27e7d7731634e274576a7081882bede Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 2 Dec 2025 21:44:00 +0000 Subject: [PATCH 144/338] [CI] Run inductor-unittests if workflow file is modified (#169398) Not sure how folks were testing it beforehand Pull Request resolved: https://github.com/pytorch/pytorch/pull/169398 Approved by: https://github.com/yangw-dev, https://github.com/atalman, https://github.com/seemethere, https://github.com/oulgen --- .github/workflows/inductor-unittest.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 0902026adb8ce..9c1dd3d82769d 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -7,9 +7,12 @@ on: workflow_call: schedule: - cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests. + pull_request: + paths: + - .github/workflows/inductor-unittest.yml concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-unittest + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-unittest cancel-in-progress: true permissions: From 6ec30b490aee1db6bcdc7340abddef25784f08ec Mon Sep 17 00:00:00 2001 From: Nolan O'Brien Date: Tue, 2 Dec 2025 21:50:41 +0000 Subject: [PATCH 145/338] [xplat][caffe2] Fix -Wswitch-default issues (#169022) Summary: **Context:** https://fburl.com/switch-enum ---- Now that `-Wswitch-enum` is an error in all of `fbobjc`, making it such that all enum values need a case within a switch regardless of having a default case, we want to also enable `-Wswitch-default` as an error (ensuring we cover all values outside the given enum's values). This will reduce SEVs as it will eliminate undefined behavior when a value outside the enum range is switched upon. Test Plan: CI Pass Differential Revision: D87817056 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169022 Approved by: https://github.com/shoumikhin --- torch/csrc/autograd/profiler_legacy.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index 7753deec04a63..32bf4502330d5 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -96,8 +96,9 @@ struct TORCH_API LegacyEvent { return "pop"; case EventKind::MemoryAlloc: return "memory_alloc"; + default: + TORCH_CHECK(false, "unknown event kind"); } - TORCH_CHECK(false, "unknown event kind"); } EventKind kind() const { From 81af382128efa094d8702e18f2c133760904c718 Mon Sep 17 00:00:00 2001 From: Nick Riasanovsky Date: Tue, 2 Dec 2025 21:55:18 +0000 Subject: [PATCH 146/338] [GB300] [Triton] Disable experimental Triton API inside PyTorch when specified (#169014) Adds an extra import check to allow disabling the experimental API inside PyTorch for newer internal Triton versions that must maintain backwards capability due to existing user code. This should hopefully help lead to eventual deprecation and removal soon. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169014 Approved by: https://github.com/Sibylau, https://github.com/jananisriram --- torch/utils/_triton.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index f062f7e7508cb..98de7bbcc5868 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -45,7 +45,12 @@ def has_triton_experimental_host_tma() -> bool: create_2d_tma_descriptor, ) - return True + try: + from triton.tools.experimental_descriptor import enable_in_pytorch + + return enable_in_pytorch() + except ImportError: + return True except ImportError: pass From 45d14e2497292be06ad36eaa1aaaf7c630a2586a Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Tue, 2 Dec 2025 21:59:29 +0000 Subject: [PATCH 147/338] [ROCm] Enable ZerO Optimizer UTs (#169077) Fixes #168436 Fixes #168437 Fixes #168438 Fixes #168439 Fixes #168440 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169077 Approved by: https://github.com/jeffdaily --- test/distributed/optim/test_zero_redundancy_optimizer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index e26d67a1d9f1f..96d8d1d3225d1 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -37,7 +37,6 @@ requires_gloo, skip_if_lt_x_gpu, skip_if_no_gpu, - skip_if_rocm_multiprocess, skip_if_win32, ) from torch.testing._internal.common_utils import ( @@ -372,7 +371,6 @@ def _check_same_model_params( ) @skip_if_no_gpu - @skip_if_rocm_multiprocess def test_step(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step()`` interface.""" @@ -412,7 +410,6 @@ def test_step(self): self.assertEqual(m.bias, m_zero.bias) @skip_if_no_gpu - @skip_if_rocm_multiprocess def test_step_with_closure(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step(closure)`` interface.""" @@ -631,7 +628,6 @@ def test_multiple_param_groups(self): torch.testing.assert_close(layer1.bias, layer3.bias) @skip_if_no_gpu - @skip_if_rocm_multiprocess def test_collect_shards(self): """Check the state consolidation mechanism and the state dict exposed by ZeroRedundancyOptimizer.""" @@ -1357,7 +1353,6 @@ def _test_ddp_zero_overlap( @skip_if_win32() @requires_accelerator_dist_backend() @skip_if_no_gpu - @skip_if_rocm_multiprocess @parametrize( "use_gpu", [True], From 8ef0c0b02b062d75e7c9be2594914a3e784d23ca Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 2 Dec 2025 12:42:38 -0800 Subject: [PATCH 148/338] dist/debug: support py-spy (native+subprocess) stacks (#169147) This adds support for getting subprocess+native+Python stack traces if `py-spy` is installed in the Python environment. This handler is implemented in C++ and doesn't depend on Python GIL Test plan: run simple debug server script 20251202_12h42m17s_grim Pull Request resolved: https://github.com/pytorch/pytorch/pull/169147 Approved by: https://github.com/fduwjj --- test/distributed/test_debug.py | 9 +++- .../c10d/control_plane/Handlers.cpp | 38 ++++++++++++++++ torch/distributed/debug/_frontend.py | 43 ++++++++++++++++++- 3 files changed, 87 insertions(+), 3 deletions(-) diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py index 1c9dfcf96b83f..9a4f9eebfc159 100644 --- a/test/distributed/test_debug.py +++ b/test/distributed/test_debug.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import os +import shutil import requests from requests.adapters import HTTPAdapter @@ -27,7 +28,7 @@ def test_all(self) -> None: os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" - port = 25999 + port = 25998 def fetch(path: str) -> str: resp = session.get(f"http://localhost:{port}{path}") @@ -68,6 +69,12 @@ def fetch(path: str) -> str: self.assertIn("test: b'value'", out) self.assertIn("test2: b'" + "a" * 95 + "...", out) + with self.subTest("pyspy"): + if shutil.which("py-spy"): + self.assertIn("test_all", fetch("/pyspy_dump")) + self.assertIn("_frontend", fetch("/pyspy_dump?subprocesses=1")) + self.assertIn("libc.so", fetch("/pyspy_dump?native=1")) + stop_debug_server() diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 5e5c3195046cb..138dc9b0fe2c5 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -91,6 +91,44 @@ RegisterHandler waitCounterHandler{ }(); #endif +#ifndef _WIN32 +RegisterHandler pyspyHandler{ + "pyspy_dump", + [](const Request& req, Response& res) { + pid_t target = getpid(); + std::string cmd = "py-spy dump"; + cmd += " --pid " + std::to_string(target); + if (req.getParam("native") != "") { + cmd += " --native"; + } + if (req.getParam("subprocesses") != "") { + cmd += " --subprocesses"; + } + if (req.getParam("nonblocking") != "") { + cmd += " --nonblocking"; + } + cmd += " 2>&1"; + std::array buf{}; + std::string output; + FILE* pipe = popen(cmd.c_str(), "r"); + if (!pipe) { + throw std::runtime_error("Failed to start py-spy, not installed?"); + } + while (fgets(buf.data(), buf.size(), pipe)) { + output.append(buf.data()); + } + int rc = pclose(pipe); + + // Get all wait counter values from our tracking backend + res.setContent(std::move(output), "text/plain"); + if (rc != 0) { + res.setStatus(500); + } else { + res.setStatus(200); + } + }}; +#endif + } // namespace void registerHandler(const std::string& name, HandlerFunc f) { diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py index d31d3e734c28b..16cccb88632f0 100644 --- a/torch/distributed/debug/_frontend.py +++ b/torch/distributed/debug/_frontend.py @@ -148,6 +148,7 @@ def format_json(blob: str): Home Python Stack Traces + py-spy Stacks FlightRecorder CPU (JSON) FlightRecorder NCCL @@ -212,7 +213,7 @@ def format_json(blob: str): {% endblock %} {% block content %} -
+ @@ -296,6 +297,31 @@ def format_json(blob: str): {{ collectives | safe }}

NCCL Calls

{{ ncclcalls | safe }} +{% endblock %} + """, + "pyspy_dump.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}py-spy Stack Traces{% endblock %}

+{% endblock %} +{% block content %} + + + + + + +
+ + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} {% endblock %} """, } @@ -323,7 +349,10 @@ def get_path(self) -> str: return urlparse(self.path).path def get_query(self) -> dict[str, list[str]]: - return parse_qs(urlparse(self.path).query) + return parse_qs(self.get_raw_query()) + + def get_raw_query(self) -> str: + return urlparse(self.path).query def get_query_arg( self, name: str, default: object = None, type: type = str @@ -349,6 +378,7 @@ def __init__(self, port: int): self._routes = { "/": self._handle_index, "/stacks": self._handle_stacks, + "/pyspy_dump": self._handle_pyspy_dump, "/fr_trace": self._handle_fr_trace, "/fr_trace_json": self._handle_fr_trace_json, "/fr_trace_nccl": self._handle_fr_trace_nccl, @@ -372,6 +402,7 @@ def __init__(self, port: int): target=self._serve, args=(), daemon=True, + name="distributed.debug.FrontendServer", ) self._thread.start() @@ -419,6 +450,14 @@ def _handle_stacks(self, req: HTTPRequestHandler) -> bytes: "raw_resp.html", title="Stacks", addrs=addrs, resps=resps ) + def _handle_pyspy_dump(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("pyspy_dump", req.get_raw_query()) + return self._render_template( + "pyspy_dump.html", + addrs=addrs, + resps=resps, + ) + def _render_fr_trace(self, addrs: list[str], resps: list[Response]) -> bytes: config = JobConfig() # pyrefly: ignore [bad-assignment] From 31fc12773026e8e00f054dd79ad9b2491e693b48 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 1 Dec 2025 23:09:31 +0000 Subject: [PATCH 149/338] [dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems (#167888) Fixes https://github.com/pytorch/pytorch/issues/167789 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167888 Approved by: https://github.com/malfet, https://github.com/colesbury --- test/dynamo/test_repros.py | 62 ++++++++++++++++++++++++++++ torch/_C/_dynamo/eval_frame.pyi | 2 + torch/_dynamo/__init__.py | 30 ++++++++++++++ torch/csrc/dynamo/eval_frame.c | 5 ++- torch/csrc/dynamo/eval_frame_cpp.cpp | 61 ++++++++++++++++++++++++++- torch/csrc/dynamo/eval_frame_cpp.h | 7 +++- torch/csrc/dynamo/init.cpp | 4 ++ 7 files changed, 166 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 17e8c15863e27..900d712ccf70b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7468,6 +7468,68 @@ def forward(self, x): msg, ) + @unittest.skipIf( + sys.version_info < (3, 12) or sys.version_info >= (3, 14), + "only 3.12, 3.13 affected by c recursion limit", + ) + def test_dynamo_set_recursion_limit(self): + old_recursion_limit = sys.getrecursionlimit() + old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() + try: + + def fn(x, n): + if n == 0: + return x + return fn(x, n - 1) + 1 + + sys.setrecursionlimit(100) + + with self.assertRaises(RecursionError): + fn(torch.ones(3), 500) + + sys.setrecursionlimit(1000) + + fn(torch.ones(3), 500) + opt_fn = torch.compile(fn, backend="eager", dynamic=False) + sys.setrecursionlimit(20000) + with self.assertRaises(Exception): + opt_fn(torch.ones(3), 500) + + torch._dynamo.set_recursion_limit(20000) + self.assertEqual(fn(torch.ones(3), 500), opt_fn(torch.ones(3), 500)) + finally: + if old_dynamo_recursion_limit > 0: + torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) + sys.setrecursionlimit(old_recursion_limit) + + @unittest.skipIf( + sys.version_info < (3, 12) or sys.version_info >= (3, 14), + "only 3.12, 3.13 affected by c recursion limit", + ) + def test_dynamo_set_recursion_limit_usage(self): + old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() + try: + torch._dynamo.set_recursion_limit(500) + self.assertEqual(torch._dynamo.get_recursion_limit(), 500) + + @torch.compile(backend="eager", dynamic=False) + def fn(x, n): + if n == 0: + return x + return fn(x, n - 1) + 1 + + # a limit of 500 should be lower than the default limit + with self.assertWarnsRegex(RuntimeWarning, "new c_recursion limit"): + fn(torch.ones(3), 5) + + with self.assertRaisesRegex(ValueError, "recursion limit"): + torch._dynamo.set_recursion_limit(0) + + self.assertEqual(torch._dynamo.get_recursion_limit(), 500) + finally: + if old_dynamo_recursion_limit > 0: + torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) + @expectedFailureDynamic def test_dynamo_default_lru_cache_behavior(self): @torch.compile(backend="eager") diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 3c3a18ed4e063..060bf2638e096 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -20,6 +20,8 @@ def set_guard_complete_hook( hook: Optional[DynamoGuardCompleteHook], ) -> Optional[DynamoGuardCompleteHook]: ... def raise_sigtrap() -> None: ... +def set_c_recursion_limit(limit: int) -> None: ... +def get_c_recursion_limit() -> int: ... class _CacheEntry: def check_fn(self, *args: object, **kwargs: object) -> bool: ... diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index e9a5e8d89d07c..532659d4e7fbb 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -107,6 +107,7 @@ "reset_recompile_user_contexts", "run", "error_on_graph_break", + "set_recursion_limit", "set_stance", "skip_frame", "step_unsupported", @@ -183,3 +184,32 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() + + +def get_recursion_limit() -> int: + """ + Returns the internal dynamo recursion limit set by `torch._dynamo.set_recursion_limit`. + + Returns -1 if no c recursion limit has been set. + """ + return torch._C._dynamo.eval_frame.get_c_recursion_limit() + + +def set_recursion_limit(limit: int) -> None: + """ + Sets an internal dynamo recursion limit. The limit must be >= 1. + + This is possibly needed in Python 3.12-3.13 since there is a separate C recursion limit + that is not visible at the Python level. If you are getting RecursionErrors during + Dynamo compilation and `sys.setrecursionlimit()` doesn't help, this function may alleviate + the issue. + + NOTE: this function does NOT call `sys.setrecursionlimit()` - the user is expected to manually + call this if required. This is because the 2 recursion limits are not sync'd up - e.g. in + Python 3.12, functions can be inline-evaluated, which apparently doesn't use up the C stack. + + WARNING: increasing the recursion limit to an arbitrary large value may cause segfaults + due to stack overflows! You can try also try to manually increase the stack size, e.g. + with `$ ulimit -s ...` + """ + torch._C._dynamo.eval_frame.set_c_recursion_limit(limit) diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index b08fffedaa014..58cb48de664d5 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -733,7 +733,10 @@ static PyMethodDef _methods[] = { {"get_eval_frame_callback", get_eval_frame_callback_py, METH_NOARGS, NULL}, {"reset_code", reset_code, METH_O, NULL}, {"unsupported", unsupported, METH_VARARGS, NULL}, - {"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL}, + {"set_code_exec_strategy", + dynamo_set_code_exec_strategy, + METH_VARARGS, + NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL}, {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL}, diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index e678bc7bad04a..eb99cd7d067a4 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -50,6 +50,56 @@ static py::handle _callback_from_action( return callback; } +// c_recursion_remaining only defined in 3.12 and 3.13 + +static int32_t c_recursion_limit = -1; + +void dynamo_set_c_recursion_limit(int32_t limit) { + if (limit < 1) { + throw std::range_error("recursion limit must be greater or equal than 1"); + } + c_recursion_limit = limit; +} + +int32_t dynamo_get_c_recursion_limit() { + return c_recursion_limit; +} + +#if IS_PYTHON_3_12_PLUS && !IS_PYTHON_3_14_PLUS + +struct CRecursionLimitRAII { + PyThreadState* tstate; + int32_t old_recursion_remaining; + CRecursionLimitRAII(PyThreadState* tstate) : tstate{tstate} { + auto limit = dynamo_get_c_recursion_limit(); + auto& remaining = tstate->c_recursion_remaining; + this->old_recursion_remaining = remaining; + if (limit < 0) { + // no change to limit + return; + } + if (limit < remaining) { + std::stringstream ss; + ss << "new c_recursion limit (" << limit + << ") is lower than thread's current c_recursion_remaining (" + << remaining << ")."; + PyErr_WarnEx(PyExc_RuntimeWarning, ss.str().c_str(), 1); + } + remaining = limit; + } + ~CRecursionLimitRAII() { + this->tstate->c_recursion_remaining = this->old_recursion_remaining; + } +}; + +#else + +struct CRecursionLimitRAII { + CRecursionLimitRAII(PyThreadState* tstate) {} +}; + +#endif + // frame and callback are borrowed references. // Returns new reference. PyObject* dynamo__custom_eval_frame( @@ -258,6 +308,13 @@ PyObject* dynamo__custom_eval_frame( bool apply_to_code = false; PyObject* guarded_code = nullptr; try { + CRecursionLimitRAII tmp(tstate); // increase C recursion limit to the given + // value during compilation + // C recursion limit failure + if (PyErr_Occurred()) { + fail(); + return eval_result; + } callback_result = dynamo_call_callback( callback, frame, locals.get(), cache_entry, frame_state); new_strategy = @@ -320,7 +377,7 @@ PyObject* dynamo__custom_eval_frame( return eval_result; } -PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { +PyObject* dynamo_set_code_exec_strategy(PyObject* dummy, PyObject* args) { PyObject* code_obj = nullptr; PyObject* strategy_obj = nullptr; if (!PyArg_ParseTuple(args, "OO", &code_obj, &strategy_obj)) { @@ -344,7 +401,7 @@ PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { Py_RETURN_NONE; } -void skip_code_recursive(PyCodeObject* code) { +void dynamo_skip_code_recursive(PyCodeObject* code) { ExtraState* extra = get_extra_state(code); if (extra == nullptr) { extra = init_and_set_extra_state(code); diff --git a/torch/csrc/dynamo/eval_frame_cpp.h b/torch/csrc/dynamo/eval_frame_cpp.h index 2f3587094f763..8cc1ab7618b3d 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.h +++ b/torch/csrc/dynamo/eval_frame_cpp.h @@ -16,8 +16,11 @@ PyObject* dynamo__custom_eval_frame( int throw_flag, PyObject* callback); -PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj); -void skip_code_recursive(PyCodeObject* code); +PyObject* dynamo_set_code_exec_strategy(PyObject* dummy, PyObject* obj); +void dynamo_skip_code_recursive(PyCodeObject* code); + +void dynamo_set_c_recursion_limit(int32_t limit); +int32_t dynamo_get_c_recursion_limit(); #ifdef __cplusplus diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 9ed9a465642c3..69d6e0555ceb4 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -251,6 +252,9 @@ void initDynamoBindings(PyObject* torch) { .def_readwrite("cur_action", &FrameExecStrategy::cur_action) .def_readwrite("recursive_action", &FrameExecStrategy::recursive_action); + m.def("set_c_recursion_limit", &dynamo_set_c_recursion_limit); + m.def("get_c_recursion_limit", &dynamo_get_c_recursion_limit); + m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list); m.def("_reset_precompile_entries", &_reset_precompile_entries); m.def("_load_precompile_entry", &_load_precompile_entry); From 7f55ba19c456a3d6cc443dd9edb6bb7cca677ead Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 1 Dec 2025 23:09:32 +0000 Subject: [PATCH 150/338] [dynamo, guards] add guard builder microbenchmark (#169087) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169087 Approved by: https://github.com/anijain2305 ghstack dependencies: #167888 --- .../microbenchmarks/dynamo_guard_build.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 benchmarks/dynamo/microbenchmarks/dynamo_guard_build.py diff --git a/benchmarks/dynamo/microbenchmarks/dynamo_guard_build.py b/benchmarks/dynamo/microbenchmarks/dynamo_guard_build.py new file mode 100644 index 0000000000000..b61a2bd3b3465 --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/dynamo_guard_build.py @@ -0,0 +1,50 @@ +import sys +import time + +import torch + + +class Foo: + pass + + +obj = Foo() + +DEPTH = 2000 + +attrs = [f"attr{i}" for i in range(DEPTH)] + +for i, attr in enumerate(attrs): + setattr(obj, attr, i) + +lst = obj + +for _ in range(DEPTH): + lst = [lst] + +sys.setrecursionlimit(100000) +torch._dynamo.set_recursion_limit(1000000) + + +@torch.compile(backend="eager") +def fn(x): + unpacked = lst + for _ in range(DEPTH): + unpacked = unpacked[0] + for i in range(DEPTH): + x = x + getattr(unpacked, f"attr{i}") + return x + + +def main(): + opt_fn = torch.compile(fn, backend="eager") + + start = time.perf_counter() + opt_fn(torch.randn(3)) + end = time.perf_counter() + + print(f"total time: {end - start:.2f}s") + + +if __name__ == "__main__": + main() From ef8ecc13830a86c4b231f1aad9aba7851db61b53 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 3 Dec 2025 00:07:01 +0000 Subject: [PATCH 151/338] [11/N] Use Python 3.10 typing (#169335) This PR applies Python 3.10 typing syntax to some files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169335 Approved by: https://github.com/Lucaskabela --- torch/optim/_adafactor.py | 64 +++++++++++++++++----------------- torch/optim/_muon.py | 13 +++---- torch/optim/adadelta.py | 8 ++--- torch/optim/adagrad.py | 28 +++++++-------- torch/optim/adam.py | 68 ++++++++++++++++++------------------- torch/optim/adamax.py | 10 +++--- torch/optim/adamw.py | 23 ++++++------- torch/optim/asgd.py | 14 ++++---- torch/optim/lbfgs.py | 7 ++-- torch/optim/lr_scheduler.py | 47 ++++++++++--------------- torch/optim/nadam.py | 14 ++++---- torch/optim/optimizer.py | 47 ++++++++++++------------- torch/optim/radam.py | 14 ++++---- torch/optim/rmsprop.py | 8 ++--- torch/optim/rprop.py | 8 ++--- torch/optim/sgd.py | 40 +++++++++++----------- torch/optim/sparse_adam.py | 3 +- torch/optim/swa_utils.py | 21 ++++++------ 18 files changed, 210 insertions(+), 227 deletions(-) diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index c417b354429b5..6aed25a36aa82 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -from typing import cast, Optional, TYPE_CHECKING, Union +from typing import cast, TYPE_CHECKING import torch from torch import Tensor @@ -24,13 +24,13 @@ class Adafactor(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-2, + lr: float | Tensor = 1e-2, beta2_decay: float = -0.8, - eps: tuple[Optional[float], float] = (None, 1e-3), + eps: tuple[float | None, float] = (None, 1e-3), d: float = 1.0, weight_decay: float = 0.0, *, - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: @@ -136,9 +136,9 @@ def step(self, closure=None): for group in self.param_groups: params_with_grad: list[Tensor] = [] grads: list[Tensor] = [] - row_vars: list[Optional[Tensor]] = [] - col_vars: list[Optional[Tensor]] = [] - variances: list[Optional[Tensor]] = [] + row_vars: list[Tensor | None] = [] + col_vars: list[Tensor | None] = [] + variances: list[Tensor | None] = [] state_steps: list[Tensor] = [] eps1, eps2 = group["eps"] @@ -334,18 +334,18 @@ def _single_tensor_adafactor( # so row_var and col_var will be None while variance will be filled. # Contrarily, for a grad with multiple dimensions, we will factor along the last # 2 dimensions, and so row_var and col_var will be filled and variance will be None. - row_vars: list[Optional[Tensor]], - col_vars: list[Optional[Tensor]], - variances: list[Optional[Tensor]], + row_vars: list[Tensor | None], + col_vars: list[Tensor | None], + variances: list[Tensor | None], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, d: float, - lr: Union[Tensor, float], + lr: Tensor | float, beta2_decay: float, weight_decay: float, - eps1: Optional[float], + eps1: float | None, eps2: float, maximize: bool, has_complex: bool, @@ -419,16 +419,16 @@ def _single_tensor_adafactor( def _group_tensors_by_device_dtype_and_is_multidim( tensorlists: TensorListList, ) -> dict[ - tuple[Optional[torch.device], Optional[torch.dtype], bool], - list[list[Optional[Tensor]]], + tuple[torch.device | None, torch.dtype | None, bool], + list[list[Tensor | None]], ]: """Groups tensors by device, dtype, AND multidimensionality -- whether the tensor has multiple dims or just one dim (is a vector). This allows the foreach impl of Adafactor to assume that every group of params will either be factored or not.""" grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists) ultra_grouped_tensors: dict[ - tuple[Optional[torch.device], Optional[torch.dtype], bool], - list[list[Optional[Tensor]]], + tuple[torch.device | None, torch.dtype | None, bool], + list[list[Tensor | None]], ] = {} for (device, dtype), (tensorlists, _) in grouped_tensors.items(): matrix_key = (device, dtype, True) @@ -458,18 +458,18 @@ def _multi_tensor_adafactor( # so row_var and col_var will be None while variance will be filled. # Contrarily, for a grad with multiple dimensions, we will factor along the last # 2 dimensions, and so row_var and col_var will be filled and variance will be None. - row_vars: list[Optional[Tensor]], - col_vars: list[Optional[Tensor]], - variances: list[Optional[Tensor]], + row_vars: list[Tensor | None], + col_vars: list[Tensor | None], + variances: list[Tensor | None], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, d: float, - lr: Union[Tensor, float], + lr: Tensor | float, beta2_decay: float, weight_decay: float, - eps1: Optional[float], + eps1: float | None, eps2: float, maximize: bool, has_complex: bool, @@ -606,19 +606,19 @@ def _multi_tensor_adafactor( def adafactor( params: list[Tensor], grads: list[Tensor], - row_vars: list[Optional[Tensor]], - col_vars: list[Optional[Tensor]], - variances: list[Optional[Tensor]], + row_vars: list[Tensor | None], + col_vars: list[Tensor | None], + variances: list[Tensor | None], state_steps: list[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, + foreach: bool | None = None, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, has_complex: bool = False, *, d: float, - lr: Union[float, Tensor], + lr: float | Tensor, beta2_decay: float, weight_decay: float, eps1: float, diff --git a/torch/optim/_muon.py b/torch/optim/_muon.py index 5b7b9892daf3a..e441c8b911b2f 100644 --- a/torch/optim/_muon.py +++ b/torch/optim/_muon.py @@ -4,7 +4,6 @@ import math from collections.abc import MutableMapping -from typing import Optional import torch from torch import Tensor @@ -71,9 +70,7 @@ def _zeropower_via_newtonschulz( return ortho_grad -def _adjust_lr( - lr: float, adjust_lr_fn: Optional[str], param_shape: torch.Size -) -> float: +def _adjust_lr(lr: float, adjust_lr_fn: str | None, param_shape: torch.Size) -> float: """Default learning rate adjustment used by Muon.""" A, B = param_shape[:2] @@ -98,7 +95,7 @@ def __init__( ns_coefficients: tuple[float, float, float] = (DEFAULT_A, DEFAULT_B, DEFAULT_C), eps: float = EPS, ns_steps: int = DEFAULT_NS_STEPS, - adjust_lr_fn: Optional[str] = None, + adjust_lr_fn: str | None = None, ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") @@ -297,7 +294,7 @@ def _single_tensor_muon( ns_coefficients: tuple[float, float, float], ns_steps: int, eps: float, - adjust_lr_fn: Optional[str], + adjust_lr_fn: str | None, has_complex: bool, ) -> None: lr = _to_scalar(lr) @@ -327,7 +324,7 @@ def muon( grads: list[Tensor], muon_momentum_bufs: list[Tensor], *, - foreach: Optional[bool] = None, + foreach: bool | None = None, lr: float, weight_decay: float, momentum: float, @@ -335,7 +332,7 @@ def muon( ns_coefficients: tuple[float, float, float], ns_steps: int, eps: float, - adjust_lr_fn: Optional[str], + adjust_lr_fn: str | None, has_complex: bool, ) -> None: r"""Functional API that performs Muon algorithm computation. diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 75ac77790e309..1ee27f46f194d 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, cast, Optional, Union +from typing import Any, cast import torch from torch import Tensor @@ -29,11 +29,11 @@ class Adadelta(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1.0, + lr: float | Tensor = 1.0, rho: float = 0.9, eps: float = 1e-6, weight_decay: float = 0, - foreach: Optional[bool] = None, + foreach: bool | None = None, *, capturable: bool = False, maximize: bool = False, @@ -418,7 +418,7 @@ def adadelta( # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim capturable: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, differentiable: bool = False, has_complex: bool = False, *, diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 519900ab5da63..a6a57fb61b8ba 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -28,16 +28,16 @@ class Adagrad(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-2, + lr: float | Tensor = 1e-2, lr_decay: float = 0, weight_decay: float = 0, initial_accumulator_value: float = 0, eps: float = 1e-10, - foreach: Optional[bool] = None, + foreach: bool | None = None, *, maximize: bool = False, differentiable: bool = False, - fused: Optional[bool] = None, + fused: bool | None = None, ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") @@ -246,13 +246,13 @@ def adagrad( grads: list[Tensor], state_sums: list[Tensor], state_steps: list[Tensor], - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, + fused: bool | None = None, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting these as kwargs for now as functional API is compiled by torch/distributed/optim has_sparse_grad: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, differentiable: bool = False, has_complex: bool = False, *, @@ -325,8 +325,8 @@ def _single_tensor_adagrad( grads: list[Tensor], state_sums: list[Tensor], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, lr: float, weight_decay: float, @@ -393,8 +393,8 @@ def _multi_tensor_adagrad( grads: list[Tensor], state_sums: list[Tensor], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, lr: float, weight_decay: float, @@ -504,8 +504,8 @@ def _fused_adagrad( grads: list[Tensor], state_sums: list[Tensor], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, lr: float, weight_decay: float, diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 6b8fd5b7e70f6..64c23e7ddf391 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -35,17 +35,17 @@ class Adam(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-3, - betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), + lr: float | Tensor = 1e-3, + betas: tuple[float | Tensor, float | Tensor] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0, amsgrad: bool = False, *, - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, - fused: Optional[bool] = None, + fused: bool | None = None, decoupled_weight_decay: bool = False, ) -> None: if isinstance(lr, Tensor): @@ -351,14 +351,14 @@ def _single_tensor_adam( exp_avg_sqs: list[Tensor], max_exp_avg_sqs: list[Tensor], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, amsgrad: bool, has_complex: bool, - beta1: Union[float, Tensor], - beta2: Union[float, Tensor], - lr: Union[float, Tensor], + beta1: float | Tensor, + beta2: float | Tensor, + lr: float | Tensor, weight_decay: float, eps: float, maximize: bool, @@ -389,7 +389,7 @@ def _single_tensor_adam( # Note: ensure type declaration is under conditional check for isinstance # or else torchscript will get cranky about the DeviceDict type. if isinstance(beta1, Tensor): - beta1_dict: Optional[DeviceDtypeDict] = {(beta1.device, beta1.dtype): beta1} + beta1_dict: DeviceDtypeDict | None = {(beta1.device, beta1.dtype): beta1} else: beta1_dict = None @@ -448,7 +448,7 @@ def _single_tensor_adam( device=device, dtype=dtype, non_blocking=True ) - device_beta1: Union[float, Tensor] = beta1_dict[key] + device_beta1: float | Tensor = beta1_dict[key] else: device_beta1 = beta1 @@ -558,14 +558,14 @@ def _multi_tensor_adam( exp_avg_sqs: list[Tensor], max_exp_avg_sqs: list[Tensor], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, amsgrad: bool, has_complex: bool, - beta1: Union[float, Tensor], - beta2: Union[float, Tensor], - lr: Union[float, Tensor], + beta1: float | Tensor, + beta2: float | Tensor, + lr: float | Tensor, weight_decay: float, eps: float, maximize: bool, @@ -630,7 +630,7 @@ def _multi_tensor_adam( # We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - beta1_dict: Optional[DeviceDict] = ( # type: ignore[attr-defined] + beta1_dict: DeviceDict | None = ( # type: ignore[attr-defined] {beta1.device: beta1} if isinstance(beta1, Tensor) and str(beta1.device) != "cpu" else None @@ -727,9 +727,9 @@ def _multi_tensor_adam( del device_grads del scaled_device_grads - bias_correction1: Union[tuple[Tensor, ...], list[Tensor]] - bias_correction2: Union[tuple[Tensor, ...], list[Tensor]] - bias_correction2_sqrt: Union[tuple[Tensor, ...], list[Tensor]] + bias_correction1: tuple[Tensor, ...] | list[Tensor] + bias_correction2: tuple[Tensor, ...] | list[Tensor] + bias_correction2_sqrt: tuple[Tensor, ...] | list[Tensor] if capturable: bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type] @@ -807,14 +807,14 @@ def _fused_adam( exp_avg_sqs: list[Tensor], max_exp_avg_sqs: list[Tensor], state_steps: list[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, amsgrad: bool, has_complex: bool, # Needed for consistency. - beta1: Union[float, Tensor], - beta2: Union[float, Tensor], - lr: Union[float, Tensor], + beta1: float | Tensor, + beta2: float | Tensor, + lr: float | Tensor, weight_decay: float, eps: float, maximize: bool, @@ -839,7 +839,7 @@ def _fused_adam( # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ( + lr_dict: DeviceDict | None = ( {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None ) grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( @@ -909,19 +909,19 @@ def adam( state_steps: list[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, + foreach: bool | None = None, capturable: bool = False, differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, + fused: bool | None = None, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, has_complex: bool = False, decoupled_weight_decay: bool = False, *, amsgrad: bool, - beta1: Union[float, Tensor], - beta2: Union[float, Tensor], - lr: Union[float, Tensor], + beta1: float | Tensor, + beta2: float | Tensor, + lr: float | Tensor, weight_decay: float, eps: float, maximize: bool, diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 264451dbb4091..320ee97d14e5a 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -30,11 +30,11 @@ class Adamax(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 2e-3, + lr: float | Tensor = 2e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0, - foreach: Optional[bool] = None, + foreach: bool | None = None, *, maximize: bool = False, differentiable: bool = False, @@ -402,7 +402,7 @@ def _multi_tensor_adamax( torch._foreach_add_(grouped_grads, eps) torch._foreach_maximum_(grouped_exp_infs, grouped_grads) - bias_corrections: Union[tuple[Tensor, ...], list[Tensor]] + bias_corrections: tuple[Tensor, ...] | list[Tensor] if capturable: bias_corrections = torch._foreach_pow(beta1, grouped_state_steps) # foreach_sub doesn't allow a scalar as the first arg @@ -430,7 +430,7 @@ def adamax( state_steps: list[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, capturable: bool = False, diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 2c968fabb698c..aa3b922cf90b4 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional, Union from torch import Tensor @@ -22,17 +21,17 @@ class AdamW(Adam): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-3, - betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999), + lr: float | Tensor = 1e-3, + betas: tuple[float | Tensor, float | Tensor] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, amsgrad: bool = False, *, maximize: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, capturable: bool = False, differentiable: bool = False, - fused: Optional[bool] = None, + fused: bool | None = None, ) -> None: super().__init__( params, @@ -137,18 +136,18 @@ def adamw( state_steps: list[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, + foreach: bool | None = None, capturable: bool = False, differentiable: bool = False, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, + fused: bool | None = None, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, has_complex: bool = False, *, amsgrad: bool, - beta1: Union[float, Tensor], - beta2: Union[float, Tensor], - lr: Union[float, Tensor], + beta1: float | Tensor, + beta2: float | Tensor, + lr: float | Tensor, weight_decay: float, eps: float, maximize: bool, diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 0af7f9b4e6f6d..19f2e6e25beba 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -30,12 +30,12 @@ class ASGD(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-2, + lr: float | Tensor = 1e-2, lambd: float = 1e-4, alpha: float = 0.75, t0: float = 1e6, weight_decay: float = 0, - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, capturable: bool = False, @@ -355,7 +355,7 @@ def _multi_tensor_asgd( torch._foreach_add_(grouped_state_steps, 1) # intermediate = grad + param * lambd - intermediate: Union[tuple[Tensor, ...], list[Tensor]] + intermediate: tuple[Tensor, ...] | list[Tensor] if weight_decay != 0: if maximize: torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) @@ -390,8 +390,8 @@ def _multi_tensor_asgd( torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus) del intermediate - new_etas: Union[tuple[Tensor, ...], list[Tensor]] - new_mus: Union[tuple[Tensor, ...], list[Tensor]] + new_etas: tuple[Tensor, ...] | list[Tensor] + new_mus: tuple[Tensor, ...] | list[Tensor] if capturable: # update grouped_mus new_mus = torch._foreach_sub(grouped_state_steps, t0) @@ -431,7 +431,7 @@ def asgd( state_steps: list[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, capturable: bool = False, diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index 3d138f6a43f76..ed4cf1a8b2e88 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional, Union import torch from torch import Tensor @@ -247,13 +246,13 @@ class LBFGS(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1, + lr: float | Tensor = 1, max_iter: int = 20, - max_eval: Optional[int] = None, + max_eval: int | None = None, tolerance_grad: float = 1e-7, tolerance_change: float = 1e-9, history_size: int = 100, - line_search_fn: Optional[str] = None, + line_search_fn: str | None = None, ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 6426283e6542c..208a182bb1770 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -9,16 +9,7 @@ from bisect import bisect_right from collections import Counter from functools import partial, wraps -from typing import ( - Any, - cast, - Literal, - Optional, - SupportsFloat, - TYPE_CHECKING, - TypedDict, - Union, -) +from typing import Any, cast, Literal, SupportsFloat, TYPE_CHECKING, TypedDict from typing_extensions import override, Self from weakref import ref @@ -244,7 +235,7 @@ def get_lr(self) -> list[float | Tensor]: """ raise NotImplementedError - def step(self, epoch: Optional[int] = None) -> None: + def step(self, epoch: int | None = None) -> None: """Step the scheduler. Args: @@ -290,7 +281,7 @@ def step(self, epoch: Optional[int] = None) -> None: warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2) self._update_lr(epoch) - def _update_lr(self, epoch: Optional[int] = None) -> None: + def _update_lr(self, epoch: int | None = None) -> None: with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 @@ -298,9 +289,7 @@ def _update_lr(self, epoch: Optional[int] = None) -> None: else: self.last_epoch = epoch if hasattr(self, "_get_closed_form_lr"): - values = cast( - list[Union[float, Tensor]], self._get_closed_form_lr() - ) + values = cast(list[float | Tensor], self._get_closed_form_lr()) else: values = self.get_lr() @@ -389,7 +378,7 @@ class LambdaLR(LRScheduler): def __init__( self, optimizer: Optimizer, - lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]], + lr_lambda: Callable[[int], float] | list[Callable[[int], float]], last_epoch: int = -1, ) -> None: # noqa: D107 self.optimizer = optimizer @@ -505,7 +494,7 @@ class MultiplicativeLR(LRScheduler): def __init__( self, optimizer: Optimizer, - lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]], + lr_lambda: Callable[[int], float] | list[Callable[[int], float]], last_epoch: int = -1, ) -> None: # noqa: D107 self.optimizer = optimizer @@ -1519,7 +1508,7 @@ class ChainedScheduler(LRScheduler): """ def __init__( - self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None + self, schedulers: Sequence[LRScheduler], optimizer: Optimizer | None = None ) -> None: # noqa: D107 if len(schedulers) < 1: raise ValueError( @@ -1659,7 +1648,7 @@ def __init__( threshold: float = 1e-4, threshold_mode: Literal["rel", "abs"] = "rel", cooldown: int = 0, - min_lr: Union[list[float], float] = 0, + min_lr: list[float] | float = 0, eps: float = 1e-8, ) -> None: # noqa: D107 if factor >= 1.0: @@ -1894,13 +1883,13 @@ class CyclicLR(LRScheduler): def __init__( self, optimizer: Optimizer, - base_lr: Union[float, list[float]], - max_lr: Union[float, list[float]], + base_lr: float | list[float], + max_lr: float | list[float], step_size_up: int = 2000, - step_size_down: Optional[int] = None, + step_size_down: int | None = None, mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", gamma: float = 1.0, - scale_fn: Optional[Callable[[float], float]] = None, + scale_fn: Callable[[float], float] | None = None, scale_mode: Literal["cycle", "iterations"] = "cycle", cycle_momentum: bool = True, base_momentum: float = 0.8, @@ -2396,15 +2385,15 @@ class OneCycleLR(LRScheduler): def __init__( self, optimizer: Optimizer, - max_lr: Union[float, list[float]], - total_steps: Optional[int] = None, - epochs: Optional[int] = None, - steps_per_epoch: Optional[int] = None, + max_lr: float | list[float], + total_steps: int | None = None, + epochs: int | None = None, + steps_per_epoch: int | None = None, pct_start: float = 0.3, anneal_strategy: Literal["cos", "linear"] = "cos", cycle_momentum: bool = True, - base_momentum: Union[float, list[float]] = 0.85, - max_momentum: Union[float, list[float]] = 0.95, + base_momentum: float | list[float] = 0.85, + max_momentum: float | list[float] = 0.95, div_factor: float = 25.0, final_div_factor: float = 1e4, three_phase: bool = False, diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index f83cd4b85d02f..46a9bd47ddc81 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Implementation for the NAdam algorithm.""" -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -33,14 +33,14 @@ class NAdam(Optimizer): # noqa: D101 def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 2e-3, + lr: float | Tensor = 2e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0, momentum_decay: float = 4e-3, decoupled_weight_decay: bool = False, *, - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, @@ -485,9 +485,9 @@ def _multi_tensor_nadam( exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs) - bias_correction_sqrt: Union[tuple[Tensor, ...], list[Tensor]] - mus: Union[tuple[Tensor, ...], list[Tensor]] - mu_nexts: Union[tuple[Tensor, ...], list[Tensor]] + bias_correction_sqrt: tuple[Tensor, ...] | list[Tensor] + mus: tuple[Tensor, ...] | list[Tensor] + mu_nexts: tuple[Tensor, ...] | list[Tensor] if capturable: # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay)) exponent = torch._foreach_mul(grouped_state_steps, momentum_decay) @@ -612,7 +612,7 @@ def nadam( # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim decoupled_weight_decay: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, capturable: bool = False, differentiable: bool = False, has_complex: bool = False, diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index c42ea3cfb02d5..8e691389ea50e 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Hashable, Iterable, Sequence from copy import deepcopy from itertools import chain -from typing import Any, cast, Optional, overload, TypeAlias, TypeVar, Union +from typing import Any, cast, overload, TypeAlias, TypeVar from typing_extensions import ParamSpec, Self import torch @@ -28,13 +28,11 @@ Args: TypeAlias = tuple[Any, ...] Kwargs: TypeAlias = dict[str, Any] StateDict: TypeAlias = dict[str, Any] -DeviceDict: TypeAlias = dict[Optional[torch.device], torch.Tensor] -DeviceDtypeDict: TypeAlias = dict[ - Optional[tuple[torch.device, torch.dtype]], torch.Tensor -] +DeviceDict: TypeAlias = dict[torch.device | None, torch.Tensor] +DeviceDtypeDict: TypeAlias = dict[tuple[torch.device, torch.dtype] | None, torch.Tensor] GlobalOptimizerPreHook: TypeAlias = Callable[ - ["Optimizer", Args, Kwargs], Optional[tuple[Args, Kwargs]] + ["Optimizer", Args, Kwargs], tuple[Args, Kwargs] | None ] GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None] @@ -106,7 +104,7 @@ def _stack_if_compiling(x): def _disable_dynamo_if_unsupported( - single_tensor_fn: Optional[Callable[..., object]] = None, + single_tensor_fn: Callable[..., object] | None = None, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # workaround for torchscript BC # it requires all called functions to be in the @@ -230,7 +228,7 @@ def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]: return capturable_supported_devices -def _to_scalar(x: Union[float, torch.Tensor]): +def _to_scalar(x: float | torch.Tensor): r"""This function converts a hyperparameter to a 0-dimension (scalar) tensor if it is a nonzero-dimensions 1-element tensor. If it is not a tensor, it is kept as is. @@ -331,9 +329,11 @@ def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> Removabl return handle -ParamsT: TypeAlias = Union[ - Iterable[torch.Tensor], Iterable[dict[str, Any]], Iterable[tuple[str, torch.Tensor]] -] +ParamsT: TypeAlias = ( + Iterable[torch.Tensor] + | Iterable[dict[str, Any]] + | Iterable[tuple[str, torch.Tensor]] +) R = TypeVar("R") T = TypeVar("T") @@ -356,7 +356,7 @@ class Optimizer: OptimizerPreHook: TypeAlias = Callable[ [Self, Args, Kwargs], # type: ignore[misc] - Optional[tuple[Args, Kwargs]], + tuple[Args, Kwargs] | None, ] OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] @@ -366,11 +366,11 @@ class Optimizer: _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' _optimizer_state_dict_post_hooks: ( # pyrefly: ignore [not-a-type] - 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + 'OrderedDict[int, Callable[["Optimizer", StateDict], StateDict | None]]' ) _optimizer_load_state_dict_pre_hooks: ( # pyrefly: ignore [not-a-type] - 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + 'OrderedDict[int, Callable[["Optimizer", StateDict], StateDict | None]]' ) _optimizer_load_state_dict_post_hooks: ( # pyrefly: ignore [not-a-type] @@ -541,10 +541,10 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R: def _group_tensors_by_device_and_dtype( tensorlistlist: TensorListList, with_indices: bool = False, - ) -> Union[ - dict[tuple[None, None], tuple[TensorListList, Indices]], - dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]], - ]: + ) -> ( + dict[tuple[None, None], tuple[TensorListList, Indices]] + | dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]] + ): """Group a list of lists of tensors by device and dtype. Skips this step if we are compiling since this will occur during inductor lowering. @@ -641,7 +641,7 @@ def register_state_dict_pre_hook( def register_state_dict_post_hook( self, - hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + hook: Callable[["Optimizer", StateDict], StateDict | None], prepend: bool = False, ) -> RemovableHandle: r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called. @@ -800,7 +800,7 @@ def _process_value_according_to_param_policy( def register_load_state_dict_pre_hook( self, - hook: Callable[["Optimizer", StateDict], Optional[StateDict]], + hook: Callable[["Optimizer", StateDict], StateDict | None], prepend: bool = False, ) -> RemovableHandle: # noqa: D205 D400 r"""Register a load_state_dict pre-hook which will be called before @@ -1041,9 +1041,10 @@ def zero_grad(self, set_to_none: bool = True) -> None: if not hasattr(self, "_zero_grad_profile_name"): self._patch_step_function() - per_device_and_dtype_grads: Optional[ + per_device_and_dtype_grads: ( defaultdict[torch.device, defaultdict[torch.dtype, list[torch.Tensor]]] - ] + | None + ) if foreach: per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) else: @@ -1085,7 +1086,7 @@ def step(self, closure: None = None) -> None: ... @overload def step(self, closure: Callable[[], float]) -> float: ... - def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + def step(self, closure: Callable[[], float] | None = None) -> float | None: r"""Perform a single optimization step to update parameter. Args: diff --git a/torch/optim/radam.py b/torch/optim/radam.py index db69bbb01a042..c54b2bb83db31 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Implementation for the RAdam algorithm.""" -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -32,13 +32,13 @@ class RAdam(Optimizer): # noqa: D101 def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-3, + lr: float | Tensor = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0, decoupled_weight_decay: bool = False, *, - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, @@ -438,9 +438,9 @@ def _multi_tensor_radam( # maximum length of the approximated SMA rho_inf = 2 / (1 - beta2) - 1 # compute the length of the approximated SMA - bias_correction1: Union[tuple[Tensor, ...], list[Tensor]] - bias_correction2: Union[tuple[Tensor, ...], list[Tensor]] - rho_t_list: Union[tuple[Tensor, ...], list[Tensor]] + bias_correction1: tuple[Tensor, ...] | list[Tensor] + bias_correction2: tuple[Tensor, ...] | list[Tensor] + rho_t_list: tuple[Tensor, ...] | list[Tensor] if capturable: bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps) torch._foreach_neg_(bias_correction1) @@ -575,7 +575,7 @@ def radam( # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim decoupled_weight_decay: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, differentiable: bool = False, capturable: bool = False, has_complex: bool = False, diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 364068ecc9ab3..f8e6da5489d74 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Implementation for the RMSprop algorithm.""" -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -31,14 +31,14 @@ class RMSprop(Optimizer): # noqa: D101 def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-2, + lr: float | Tensor = 1e-2, alpha: float = 0.99, eps: float = 1e-8, weight_decay: float = 0, momentum: float = 0, centered: bool = False, capturable: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, ) -> None: # noqa: D107 @@ -483,7 +483,7 @@ def rmsprop( state_steps: list[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, capturable: bool = False, diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index c9e1d5eabaeee..dcdc91692b7d3 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Implementation for the Resilient backpropagation.""" -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -31,12 +31,12 @@ class Rprop(Optimizer): # noqa: D101 def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-2, + lr: float | Tensor = 1e-2, etas: tuple[float, float] = (0.5, 1.2), step_sizes: tuple[float, float] = (1e-6, 50), *, capturable: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, ) -> None: # noqa: D107 @@ -418,7 +418,7 @@ def rprop( state_steps: list[Tensor], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim - foreach: Optional[bool] = None, + foreach: bool | None = None, capturable: bool = False, maximize: bool = False, differentiable: bool = False, diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 63c80d645cd08..8044d853f0b4e 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Implementation for Stochastic Gradient Descent optimizer.""" -from typing import cast, Optional, Union +from typing import cast import torch from torch import Tensor @@ -29,16 +29,16 @@ class SGD(Optimizer): # noqa: D101 def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-3, + lr: float | Tensor = 1e-3, momentum: float = 0, dampening: float = 0, - weight_decay: Union[float, Tensor] = 0, + weight_decay: float | Tensor = 0, nesterov: bool = False, *, maximize: bool = False, - foreach: Optional[bool] = None, + foreach: bool | None = None, differentiable: bool = False, - fused: Optional[bool] = None, + fused: bool | None = None, ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") @@ -118,7 +118,7 @@ def step(self, closure=None): for group in self.param_groups: params: list[Tensor] = [] grads: list[Tensor] = [] - momentum_buffer_list: list[Optional[Tensor]] = [] + momentum_buffer_list: list[Tensor | None] = [] has_sparse_grad = self._init_group( group, params, grads, momentum_buffer_list @@ -252,14 +252,14 @@ def step(self, closure=None): def sgd( params: list[Tensor], d_p_list: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], + momentum_buffer_list: list[Tensor | None], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim has_sparse_grad: bool = False, - foreach: Optional[bool] = None, - fused: Optional[bool] = None, - grad_scale: Optional[Tensor] = None, - found_inf: Optional[Tensor] = None, + foreach: bool | None = None, + fused: bool | None = None, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, *, weight_decay: float, momentum: float, @@ -322,9 +322,9 @@ def sgd( def _single_tensor_sgd( params: list[Tensor], grads: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + momentum_buffer_list: list[Tensor | None], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, weight_decay: float, momentum: float, @@ -383,9 +383,9 @@ def _single_tensor_sgd( def _multi_tensor_sgd( params: list[Tensor], grads: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + momentum_buffer_list: list[Tensor | None], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, weight_decay: float, momentum: float, @@ -480,9 +480,9 @@ def _multi_tensor_sgd( def _fused_sgd( params: list[Tensor], grads: list[Tensor], - momentum_buffer_list: list[Optional[Tensor]], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], + momentum_buffer_list: list[Tensor | None], + grad_scale: Tensor | None, + found_inf: Tensor | None, *, weight_decay: float, momentum: float, diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index ed58c93181ae2..d6196cb20cd4e 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Union import torch from torch import Tensor @@ -15,7 +14,7 @@ class SparseAdam(Optimizer): def __init__( self, params: ParamsT, - lr: Union[float, Tensor] = 1e-3, + lr: float | Tensor = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, maximize: bool = False, diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index ebe3e07025957..260292d23afc0 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable, Iterable from copy import deepcopy -from typing import Any, cast, Literal, Optional, Union +from typing import Any, cast, Literal, Union from typing_extensions import override import torch @@ -65,7 +65,7 @@ def get_swa_multi_avg_fn(): def swa_update( averaged_param_list: PARAM_LIST, current_param_list: PARAM_LIST, - num_averaged: Union[Tensor, int], + num_averaged: Tensor | int, ) -> None: # foreach lerp only handles float and complex if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( @@ -112,7 +112,7 @@ def get_swa_avg_fn(): @torch.no_grad() def swa_update( - averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int] + averaged_param: Tensor, current_param: Tensor, num_averaged: Tensor | int ): return averaged_param + (current_param - averaged_param) / (num_averaged + 1) @@ -223,11 +223,10 @@ class AveragedModel(Module): def __init__( self, model: Module, - device: Optional[Union[int, torch.device]] = None, - avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None, - multi_avg_fn: Optional[ - Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] - ] = None, + device: int | torch.device | None = None, + avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None, + multi_avg_fn: Callable[[PARAM_LIST, PARAM_LIST, Tensor | int], None] + | None = None, use_buffers=False, ) -> None: # noqa: D107 super().__init__() @@ -263,8 +262,8 @@ def update_parameters(self, model: Module) -> None: if self.use_buffers else model.parameters() ) - self_param_detached: list[Optional[Tensor]] = [] - model_param_detached: list[Optional[Tensor]] = [] + self_param_detached: list[Tensor | None] = [] + model_param_detached: list[Tensor | None] = [] copy_param = bool(self.n_averaged == 0) for p_averaged, p_model in zip(self_param, model_param, strict=False): p_model_ = p_model.detach().to(p_averaged.device) @@ -330,7 +329,7 @@ def update_parameters(self, model: Module) -> None: def update_bn( loader: Iterable[Any], model: Module, - device: Optional[Union[int, torch.device]] = None, + device: int | torch.device | None = None, ) -> None: r"""Update BatchNorm running_mean, running_var buffers in the model. From 0c281dd78773b2bc17c58ead0e4cd4ac46e775c5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 3 Dec 2025 00:36:30 +0000 Subject: [PATCH 152/338] Revert "[MPS] Fix dlpack exports/imports for sliced tensors (#169272)" This reverts commit 7741edd4ed665f3988052e260863efb508d61a03. Reverted https://github.com/pytorch/pytorch/pull/169272 on behalf of https://github.com/huydhn due to I am seeing some ROCm failures in trunk after this lands ([comment](https://github.com/pytorch/pytorch/pull/169272#issuecomment-3604521392)) --- aten/src/ATen/DLConvertor.cpp | 21 ++------------------- test/test_dlpack.py | 7 ++----- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index b39f3eafa32df..ccb0ae15a11e6 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -356,18 +356,8 @@ ScalarType toScalarType(const DLDataType& dtype) { return stype; } - namespace { -int64_t toStorageOffset(int64_t byte_offset, ScalarType stype) { - if (byte_offset == 0) { - return 0; - } - const auto element_size = c10::elementSize(stype); - TORCH_CHECK_VALUE(byte_offset % element_size == 0, "byte offset must be multiple of element size"); - return byte_offset / element_size; -} - // The templated classes below are needed for supporting both: // - DLManagedTensor // - DLManagedTensorVersioned @@ -403,18 +393,13 @@ T* toDLPackImpl(const Tensor& src) { atDLMTensor->handle = src; atDLMTensor->tensor.manager_ctx = atDLMTensor; atDLMTensor->tensor.deleter = &deleter; - if (src.device().type() == kMPS) { - atDLMTensor->tensor.dl_tensor.data = src.storage().mutable_data(); - atDLMTensor->tensor.dl_tensor.byte_offset = src.storage_offset() * c10::elementSize(src.scalar_type()); - } else { - atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); - atDLMTensor->tensor.dl_tensor.byte_offset = 0; - } + atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device()); atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); atDLMTensor->tensor.dl_tensor.shape = const_cast(src.sizes().data()); atDLMTensor->tensor.dl_tensor.strides = const_cast(src.strides().data()); + atDLMTensor->tensor.dl_tensor.byte_offset = 0; fillVersion(&atDLMTensor->tensor); return &(atDLMTensor->tensor); @@ -441,7 +426,6 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { ScalarType stype = toScalarType(dl_tensor.dtype); if (!dl_tensor.strides) { - TORCH_CHECK_VALUE(dl_tensor.byte_offset == 0, "Expected zero byte_offset"); return at::from_blob( dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), @@ -453,7 +437,6 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), IntArrayRef(dl_tensor.strides, dl_tensor.ndim), - toStorageOffset(dl_tensor.byte_offset, stype), deleter, at::device(device).dtype(stype), {device}); diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 3d27678b5864a..3d6c4ae7484cb 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -21,6 +21,7 @@ from torch.testing._internal.common_utils import ( IS_JETSON, run_tests, + skipIfMPS, skipIfTorchDynamo, TestCase, ) @@ -156,6 +157,7 @@ def test_from_dlpack(self, device, dtype): self.assertEqual(x, y) @skipMeta + @skipIfMPS # MPS crashes with noncontiguous now @onlyNativeDeviceTypes @dtypes( *all_types_and_complex_and( @@ -167,11 +169,6 @@ def test_from_dlpack(self, device, dtype): torch.uint64, ) ) - @dtypesIfMPS( - *all_mps_types_and( - torch.bool, torch.cfloat, torch.chalf, torch.uint16, torch.uint32 - ) - ) def test_from_dlpack_noncontinguous(self, device, dtype): x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5) From d1c9f03b2a5af4104721712f8cdffe9b4f340c01 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 3 Dec 2025 00:52:10 +0000 Subject: [PATCH 153/338] [4/N] Use context managers (#169169) This PR uses context managers in tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169169 Approved by: https://github.com/Lucaskabela --- .../ao/sparsity/test_qlinear_packed_params.py | 2 +- test/profiler/test_execution_trace.py | 23 +++++----- .../numpy_tests/core/test_multiarray.py | 11 +++-- torch/_dynamo/repro/after_aot.py | 45 ++++++++++--------- torch/_inductor/codegen/common.py | 2 +- .../_inductor/compile_worker/subproc_pool.py | 2 +- torch/_inductor/runtime/caching/interfaces.py | 4 +- torch/_logging/_internal.py | 2 +- torch/distributed/debug/_handlers.py | 3 +- .../elastic/multiprocessing/api.py | 4 +- .../subprocess_handler/subprocess_handler.py | 4 +- torch/multiprocessing/spawn.py | 2 +- torch/profiler/_memory_profiler.py | 14 +++--- torch/profiler/profiler.py | 5 +-- torch/testing/_internal/common_distributed.py | 6 ++- 15 files changed, 66 insertions(+), 63 deletions(-) diff --git a/test/ao/sparsity/test_qlinear_packed_params.py b/test/ao/sparsity/test_qlinear_packed_params.py index 1c4c58a93667a..7968e57eb3775 100644 --- a/test/ao/sparsity/test_qlinear_packed_params.py +++ b/test/ao/sparsity/test_qlinear_packed_params.py @@ -226,7 +226,7 @@ def make_lin_get_state_weight_bias_and_save(): state = lin._packed_params._packed_params.__getstate__() weight_bias = lin._weight_bias() - file_buff = tempfile.TemporaryFile() + file_buff = tempfile.TemporaryFile() # noqa:SIM115 torch.save(lin, file_buff) file_buff.seek(0) diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index 26c0ab42905de..1bc07c187fdc3 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -33,6 +33,7 @@ run_tests, skipIfHpu, skipIfTorchDynamo, + TemporaryFileName, TEST_HPU, TEST_XPU, TestCase, @@ -669,18 +670,18 @@ def test_execution_trace_repeat_in_loop(self, device): assert event_count == expected_loop_events def test_execution_trace_no_capture(self): - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) - fp.close() - et = ExecutionTraceObserver().register_callback(fp.name) + with TemporaryFileName("w+t", suffix=".et.json") as file_name: + et = ExecutionTraceObserver().register_callback(file_name) - assert fp.name == et.get_output_file_path() - et.unregister_callback() - nodes = self.get_execution_trace_root(fp.name) - for n in nodes: - assert "name" in n - if "[pytorch|profiler|execution_trace|process]" in n["name"]: - found_root_node = True - assert found_root_node + assert file_name == et.get_output_file_path() + et.unregister_callback() + nodes = self.get_execution_trace_root(file_name) + found_root_node = False + for n in nodes: + assert "name" in n + if "[pytorch|profiler|execution_trace|process]" in n["name"]: + found_root_node = True + assert found_root_node @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500") def test_execution_trace_nested_tensor(self): diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index 4f4bc16f53221..e73378817570f 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -4163,12 +4163,11 @@ def test_big_binary(self): fourgbplus = 2**32 + 2**16 testbytes = np.arange(8, dtype=np.int8) n = len(testbytes) - flike = tempfile.NamedTemporaryFile() - f = flike.file - np.tile(testbytes, fourgbplus // testbytes.nbytes).tofile(f) - flike.seek(0) - a = np.fromfile(f, dtype=np.int8) - flike.close() + with tempfile.NamedTemporaryFile() as flike: + f = flike.file + np.tile(testbytes, fourgbplus // testbytes.nbytes).tofile(f) + flike.seek(0) + a = np.fromfile(f, dtype=np.int8) assert_(len(a) == fourgbplus) # check only start and end for speed: assert_((a[:n] == testbytes).all()) diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index d8465541cdfa3..25ef68a111080 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -664,32 +664,35 @@ def isolate_fails( # print(fd.read()) new_env = os.environ.copy() new_env = {**new_env, **env} - stdout, stderr = TemporaryFile(), TemporaryFile() - if use_buck: cmd = BuckTargetWriter(file_name).write(print_msg=False) else: cmd = [sys.executable, file_name] + with ( + TemporaryFile() as stdout, + TemporaryFile() as stderr, + subprocess.Popen( + cmd, + cwd=subdir, + stdout=stdout, + stderr=stderr, + env=new_env, + ) as p, + ): + p.wait() - p = subprocess.Popen( - cmd, - cwd=subdir, - stdout=stdout, - stderr=stderr, - env=new_env, - ) - p.wait() - - stdout.seek(0) - stderr.seek(0) - print( - textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), file=sys.stdout - ) - print( - textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), file=sys.stderr - ) - # print(f"Isolated test failed - {file_name}") - return p.returncode != 0 + stdout.seek(0) + stderr.seek(0) + print( + textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), + file=sys.stdout, + ) + print( + textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), + file=sys.stderr, + ) + # print(f"Isolated test failed - {file_name}") + return p.returncode != 0 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 8b5e68780cb28..e27336af8eab9 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -112,7 +112,7 @@ class FileBackedGraphModule: def __post_init__(self) -> None: # Write the code to a file for compatibility with debugging utilities. # The file is deleted upon program termination. - self.tempfile = tempfile.NamedTemporaryFile( + self.tempfile = tempfile.NamedTemporaryFile( # noqa: SIM115 mode="w+", suffix=".py", delete=False ) atexit.register(os.remove, self.tempfile.name) diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index b0e0d4ba58495..07c59b8cbb860 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -175,7 +175,7 @@ def __init__( if log_path: # pyrefly: ignore [bad-assignment] - self.log_file = open(log_path, "w") + self.log_file = open(log_path, "w") # noqa:SIM115 self.process = subprocess.Popen( cmd, diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index d0c1011200e43..4c0972268e6f0 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -574,7 +574,7 @@ def _dump_imc_to_disk(self) -> Path | None: with odc.lock(): r_fp, w_fp = None, None try: - w_fp = open(fpath, "x") + w_fp = open(fpath, "x") # noqa:SIM115 except FileExistsError: with open(fpath) as r_fp: existing_dump = json.load(r_fp) @@ -585,7 +585,7 @@ def _dump_imc_to_disk(self) -> Path | None: elif to_dump[key] != value: raise exceptions.DeterministicCachingIMCDumpConflictError from None - w_fp = open(fpath, "w") + w_fp = open(fpath, "w") # noqa:SIM115 finally: assert w_fp is not None try: diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index e0af21614cb55..23dc6f46576b5 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1177,7 +1177,7 @@ def emit(self, record) -> None: ranksuffix = "" if dist.is_available() and dist.is_initialized(): ranksuffix = f"rank_{dist.get_rank()}_" - self.stream = tempfile.NamedTemporaryFile( + self.stream = tempfile.NamedTemporaryFile( # noqa: SIM115 mode="w+", suffix=".log", prefix=f"dedicated_log_torch_trace_{ranksuffix}", diff --git a/torch/distributed/debug/_handlers.py b/torch/distributed/debug/_handlers.py index ba951b7bda075..b8095c5b34bea 100644 --- a/torch/distributed/debug/_handlers.py +++ b/torch/distributed/debug/_handlers.py @@ -1,3 +1,4 @@ +import pathlib import tempfile import time @@ -15,7 +16,7 @@ def _torch_profile(req: _Request, resp: _Response) -> None: with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f: prof.export_chrome_trace(f.name) - resp.set_content(open(f.name, "rb").read(), "application/json") + resp.set_content(pathlib.Path(f.name).read_bytes(), "application/json") resp.set_status(200) diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 41252bc35e00b..719f6d5a1fef1 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -498,7 +498,7 @@ def __init__( ] if duplicate_stdout_filters: - self.filtered_stdout = open( + self.filtered_stdout = open( # noqa: SIM115 logs_dest.filtered_stdout, mode="w", errors="replace", buffering=1 ) self._tail_logs.append( @@ -514,7 +514,7 @@ def __init__( ) if duplicate_stderr_filters: - self.filtered_stderr = open( + self.filtered_stderr = open( # noqa: SIM115 logs_dest.filtered_stderr, mode="w", errors="replace", buffering=1 ) self._tail_logs.append( diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index d4642541a191c..268817108d8cd 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -43,8 +43,8 @@ def __init__( local_rank_id: int, numa_options: NumaOptions | None, ): - self._stdout = open(stdout, "w") if stdout else None - self._stderr = open(stderr, "w") if stderr else None + self._stdout = open(stdout, "w") if stdout else None # noqa: SIM115 + self._stderr = open(stderr, "w") if stderr else None # noqa: SIM115 # inherit parent environment vars env_vars = os.environ.copy() env_vars.update(env) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index f553f7cacd753..f53be2ebe0392 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -265,7 +265,7 @@ def start_process(i): # used a multiprocessing.Queue but that can be prone to # deadlocks, so we went with a simpler solution for a one-shot # message between processes. - tf = tempfile.NamedTemporaryFile( + tf = tempfile.NamedTemporaryFile( # noqa: SIM115 prefix="pytorch-errorfile-", suffix=".pickle", delete=False ) tf.close() diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index dfa83f7467cd6..94ef747621a5e 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -1152,7 +1152,6 @@ def export_memory_timeline_html( return from base64 import b64encode - from os import remove from tempfile import NamedTemporaryFile import matplotlib.pyplot as plt @@ -1190,12 +1189,12 @@ def export_memory_timeline_html( axes.set_title(title) # Embed the memory timeline image into the HTML file - tmpfile = NamedTemporaryFile("wb", suffix=".png", delete=False) - tmpfile.close() - fig.savefig(tmpfile.name, format="png") + with NamedTemporaryFile("wb", suffix=".png") as tmpfile: + fig.savefig(tmpfile, format="png") - with open(tmpfile.name, "rb") as tmp: - encoded = b64encode(tmp.read()).decode("utf-8") + tmpfile.seek(0, 0) + encoded = b64encode(tmpfile.read()).decode("utf-8") + assert encoded html = f""" GPU Memory Timeline HTML @@ -1203,6 +1202,5 @@ def export_memory_timeline_html( """ - with open(path, "w") as f: + with open(path, "w", encoding="utf-8") as f: f.write(html) - remove(tmpfile.name) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index c52bd0f9ce2bb..151a41af919e4 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -1003,9 +1003,8 @@ def register_callback(self, output_file_path: str) -> Self: """ def get_temp_uncompressed_file() -> str: - fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) - fp.close() - return fp.name + with tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) as fp: + return fp.name if not self._registered: self.output_file_path = output_file_path diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index c2b4dd57055a6..0df79fa00f81b 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -809,7 +809,8 @@ def setUp(self) -> None: self.processes = [] # type: ignore[var-annotated] self.rank = self.MAIN_PROCESS_RANK - self.file_name = tempfile.NamedTemporaryFile(delete=False).name + with tempfile.NamedTemporaryFile(delete=False) as f: + self.file_name = f.name # pid to pipe consisting of error message from process. self.pid_to_pipe = {} # type: ignore[var-annotated] @@ -1811,7 +1812,8 @@ def _spawn_processes(cls, world_size) -> None: cls.task_queues = [] cls.completion_queues = [] # Need a rendezvous file for `init_process_group` purpose. - cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name + with tempfile.NamedTemporaryFile(delete=False) as f: + cls.rdvz_file = f.name # CUDA multiprocessing requires spawn instead of fork, to make sure # child processes have their own memory space. From d19f1e8cab6810bb2e99141f9976665954c67a50 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 3 Dec 2025 01:10:05 +0000 Subject: [PATCH 154/338] Remove unnecessary uses of thrust::tuple (#168936) This PR removes unnecessary uses of thrust::tuple before moving to CCCL. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168936 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/ActivationEluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationGeluKernel.cu | 1 - aten/src/ATen/native/cuda/ActivationGluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardswishKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationMishKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationSiluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationThresholdKernel.cu | 2 -- aten/src/ATen/native/cuda/Loops.cuh | 2 +- aten/src/ATen/native/cuda/group_norm_kernel.cu | 1 - aten/src/ATen/native/cuda/layer_norm_kernel.cu | 3 +-- 17 files changed, 2 insertions(+), 31 deletions(-) diff --git a/aten/src/ATen/native/cuda/ActivationEluKernel.cu b/aten/src/ATen/native/cuda/ActivationEluKernel.cu index 5ad1f806f9ba5..9fc29aa5539b5 100644 --- a/aten/src/ATen/native/cuda/ActivationEluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationEluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu index cd5a0ae85e61c..87781c44e3348 100644 --- a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu @@ -5,7 +5,6 @@ #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGluKernel.cu b/aten/src/ATen/native/cuda/ActivationGluKernel.cu index e28a6d61ea152..8a782a129c9fb 100644 --- a/aten/src/ATen/native/cuda/ActivationGluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu index 2a0be3f5d27bf..f0968b957aa6d 100644 --- a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu index fcacef37ceaf0..813a8c07ccfac 100644 --- a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu index 1642d0909f7f0..651cdef82543b 100644 --- a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu index a18072f7a27bc..85aa7ccd22a9e 100644 --- a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu index 72130739898fe..340a6f97d00de 100644 --- a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu index 9a1d672428b48..2175920917852 100644 --- a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationMishKernel.cu b/aten/src/ATen/native/cuda/ActivationMishKernel.cu index 0db0e96bb180a..25ba9810e37cf 100644 --- a/aten/src/ATen/native/cuda/ActivationMishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationMishKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu index f7ddfd8502a18..ebdfe245b6166 100644 --- a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu index 64ffc21123707..65f4f3679f862 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu index 0c2dc63dbcf45..712c86e0e5216 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu index 2d1cb4a47d7d8..430f9cbfa78bb 100644 --- a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index a80c51fa6a9cb..e739d7d2ecee2 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -282,7 +282,7 @@ void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) using traits = function_traits; using output_t = typename traits::result_type; static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); - constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_outputs = std::tuple_size::value; constexpr int num_inputs = traits::arity; constexpr int ntensors = num_outputs + num_inputs; diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index 77d26e915b65a..0ef6434f909de 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -3,7 +3,6 @@ #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 84812eb22125f..6f5112c605fab 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -1,10 +1,9 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include -#include - #include #include #include From 6c261c6cb07892c90ca19ed51c9705b1659a3f7d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 3 Dec 2025 01:16:33 +0000 Subject: [PATCH 155/338] Revert "[Accelerator] Add Accelerator Capabilities API (#165631)" This reverts commit 285779b1621cf9f073a062b0889a642d200308d9. Reverted https://github.com/pytorch/pytorch/pull/165631 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it has a small bug when building this internally ([comment](https://github.com/pytorch/pytorch/pull/165631#issuecomment-3604616720)) --- aten/src/ATen/DeviceAccelerator.cpp | 6 -- aten/src/ATen/DeviceAccelerator.h | 5 -- c10/core/DeviceCapability.h | 74 ------------------- c10/core/impl/DeviceGuardImplInterface.h | 26 ------- c10/core/impl/VirtualGuardImpl.h | 4 - .../torch_openreg/csrc/runtime/OpenRegGuard.h | 9 --- .../torch_openreg/tests/test_device.py | 9 +-- torch/_C/__init__.pyi.in | 1 - torch/accelerator/__init__.py | 27 +------ torch/csrc/DeviceAccelerator.cpp | 19 ----- 10 files changed, 2 insertions(+), 178 deletions(-) delete mode 100644 c10/core/DeviceCapability.h diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index efab9ec9c5927..aa9d6e6b1ce9b 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -130,12 +130,6 @@ c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) { impl.uncheckedSetDevice({device_type, device_index}); return impl.getDevice().index(); } - -c10::DeviceCapability getDeviceCapability(c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - c10::impl::VirtualGuardImpl impl(device_type); - return impl.getDeviceCapability({device_type, device_index}); -} // NOLINTEND(bugprone-unchecked-optional-access) } // namespace at::accelerator diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index d24b42ca459e7..2cc4cff7cd1f2 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -74,10 +73,6 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); -// Get the device capability of the given device index. -TORCH_API c10::DeviceCapability getDeviceCapability( - c10::DeviceIndex device_index); - TORCH_API inline void emptyCache() { const auto device_type = getAccelerator(true).value(); at::getDeviceAllocator(device_type)->emptyCache(); diff --git a/c10/core/DeviceCapability.h b/c10/core/DeviceCapability.h deleted file mode 100644 index e24f12614978a..0000000000000 --- a/c10/core/DeviceCapability.h +++ /dev/null @@ -1,74 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace c10 { - -constexpr size_t NUMBER_OF_DEVICE_CAPABILITIES = NumScalarTypes; - -// Generate bitfields for each scalar type -#define DEFINE_SCALAR_TYPE(_1, n) unsigned int has_##n : 1; - -// Generate enum indices for each scalar type -#define DEFINE_SCALAR_ENUM(_1, name) kIndex_##name, - -enum ScalarTypeIndex { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_ENUM) -}; - -/** - * @brief DeviceCapability represents the the common capabilities that all - * devices should support. - * - * This struct provides a compact way to represent the common capabilities that - * all devices should support. Includes the following capabilities: - * - Supported data types - * - * Purpose - * - Enable device-specific optimizations based on supported capabilities - * - * Contract - * - * Supported data types: - * - Each bitfield represents support for one device capability - * - Bit value 1 means the capability is supported, 0 means not supported - * - The struct is initialized with all capabilities enabled by default - * - * @note Adding New Capabilities - * - * 1. Define the new capability in the `DeviceCapability` struct - * 2. Update the support of the new capability in each accelerator - * implementation - * 3. Add the new capability to the returned PyObject Dictionary - */ -struct C10_API DeviceCapability { - union { - struct { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE) - }; - uint64_t capability_bits; // Allow direct bit manipulation - }; - - // Default constructor with all capabilities enabled. - DeviceCapability() - : capability_bits((1ULL << NUMBER_OF_DEVICE_CAPABILITIES) - 1) {} - - // Iterate supported ScalarTypes without allocating a vector - template - void forEachSupportedScalarType(F&& visitor) const { -#define VISIT_SCALAR_TYPE(_1, n) \ - if (has_##n) { \ - visitor(ScalarType::n); \ - } - - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(VISIT_SCALAR_TYPE) - -#undef VISIT_SCALAR_TYPE - } -}; - -#undef DEFINE_SCALAR_ENUM -#undef DEFINE_SCALAR_TYPE -} // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index 00096584b9229..f9f67497c6315 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -192,15 +191,6 @@ struct C10_API DeviceGuardImplInterface { */ virtual DeviceIndex deviceCount() const noexcept = 0; - /** - * Get the following capabilities of the current device: - * (1) Data type support - * Returns DeviceCapability object. - */ - virtual DeviceCapability getDeviceCapability(Device /*unused*/) const { - TORCH_CHECK(false, "Backend doesn't support getting device capabilities."); - } - /** * Return true if all the work previously enqueued on the stream for * asynchronous execution has completed running on the device. @@ -301,22 +291,6 @@ struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface { return 1; } - DeviceCapability getDeviceCapability(Device /*unused*/) const override { - DeviceCapability cap; - if constexpr (D == DeviceType::Meta) { - cap.capability_bits = 0; - // Meta only supports basic types for shape inference - // Byte, Char, Short, Int, Long, Float, Double, - // Bool, ComplexFloat, ComplexDouble - cap.capability_bits = (1ULL << kIndex_Byte) | (1ULL << kIndex_Char) | - (1ULL << kIndex_Short) | (1ULL << kIndex_Int) | - (1ULL << kIndex_Long) | (1ULL << kIndex_Float) | - (1ULL << kIndex_Double) | (1ULL << kIndex_ComplexFloat) | - (1ULL << kIndex_ComplexDouble) | (1ULL << kIndex_Bool); - } - return cap; - } - // Event-related functions void record( void** /*event*/, diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 0254c69baba00..3d259f5e390e3 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -57,10 +57,6 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { return impl_->deviceCount(); } - DeviceCapability getDeviceCapability(Device d) const override { - return impl_->getDeviceCapability(d); - } - // Event functions void record( void** event, diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h index 3c1c1193d3cdb..59bc2d5cdbff5 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -51,14 +50,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { return c10::Device(static_type, device_index); } - /** - * Get the device capability for a given device. - * By default, OpenReg has 2 same devices with the same capability. - */ - c10::DeviceCapability getDeviceCapability(c10::Device /*unused*/) const override { - return c10::DeviceCapability(); - } - /** * Set the current device to c10::Device. */ diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py index 9cb4a785d36e7..f925f15600ce7 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py @@ -1,7 +1,7 @@ # Owner(s): ["module: PrivateUse1"] import torch -from torch.testing._internal.common_dtype import get_all_dtypes +import torch_openreg # noqa: F401 from torch.testing._internal.common_utils import run_tests, TestCase @@ -31,13 +31,6 @@ def test_invalid_device_index(self): with self.assertRaisesRegex(RuntimeError, "The device index is out of range"): torch.accelerator.set_device_index(2) - def test_device_capability(self): - capability = torch.accelerator.get_device_capability("openreg:0") - supported_dtypes = capability["supported_dtypes"] - expected_dtypes = get_all_dtypes(include_complex32=True, include_qint=True) - - self.assertTrue(all(dtype in supported_dtypes for dtype in expected_dtypes)) - if __name__ == "__main__": run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 520d07d487270..532815d535d5e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2494,7 +2494,6 @@ def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails def _accelerator_getAccelerator() -> _device: ... def _accelerator_setDeviceIndex(device_index: _int) -> None: ... def _accelerator_getDeviceIndex() -> _int: ... -def _accelerator_getDeviceCapability(device_index: _int) -> dict[str, Any]: ... def _accelerator_setStream(Stream) -> None: ... def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index b0dfbe400bfbc..e1a82aa63ce22 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -2,8 +2,7 @@ This package introduces support for the current :ref:`accelerator` in python. """ -from functools import cache -from typing import Any +from typing import Optional from typing_extensions import deprecated import torch @@ -26,7 +25,6 @@ "current_accelerator", "current_device_idx", # deprecated "current_device_index", - "get_device_capability", "current_stream", "device_count", "device_index", @@ -154,29 +152,6 @@ def current_device_index() -> int: """ -@cache -def get_device_capability(device: _device_t = None, /) -> dict[str, Any]: - r"""Return the capability of the currently selected device. - - Args: - device (:class:`torch.device`, str, int, optional): The device to query capabilities for - :ref:`accelerator` device type. If not given, - use :func:`torch.accelerator.current_device_index` by default. - - Returns: - dict[str, Any]: A dictionary containing device capability information. The dictionary includes: - - ``supported_dtypes`` (set(torch.dtype)): Set of PyTorch data types supported by the device - - Examples: - >>> # xdoctest: +SKIP("requires cuda") - >>> # Query capabilities for current device - >>> capabilities = torch.accelerator.get_device_capability("cuda:0") - >>> print("Supported dtypes:", capabilities["supported_dtypes"]) - """ - device_index = _get_device_index(device, optional=True) - return torch._C._accelerator_getDeviceCapability(device_index) - - def set_device_index(device: _device_t, /) -> None: r"""Set the current device index to a given device. diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index c6ffa893d95ae..14e54851178f5 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -33,25 +33,6 @@ void initModule(PyObject* module) { return at::accelerator::getDeviceIndex(); }); - m.def("_accelerator_getDeviceCapability", [](c10::DeviceIndex device_index) { - const auto device_type = at::accelerator::getAccelerator(true).value(); - torch::utils::maybe_initialize_device(device_type); - auto caps = at::accelerator::getDeviceCapability(device_index); - - py::dict dict; - - py::set dtype_set; - caps.forEachSupportedScalarType([&](c10::ScalarType dtype) { - THPDtype* thp_dtype = torch::getTHPDtype(dtype); - py::object dtype_obj = - py::reinterpret_borrow((PyObject*)thp_dtype); - dtype_set.add(dtype_obj); - }); - - dict["supported_dtypes"] = dtype_set; - return dict; - }); - m.def("_accelerator_setStream", [](c10::Stream stream) { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); From 0b80a4c62b94402844bf221791c096b0035c6d75 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 2 Dec 2025 18:37:57 +0000 Subject: [PATCH 156/338] [dynamo, ci] Update 3.13 pull CI tests to Python 3.14 (#169032) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169032 Approved by: https://github.com/williamwen42, https://github.com/atalman, https://github.com/malfet --- .github/workflows/pull.yml | 20 ++++++++++---------- test/inductor/test_caching.py | 3 +++ test/nn/test_parametrization.py | 8 ++++++++ torch/testing/_internal/common_utils.py | 4 ++++ 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index f2483dff9a94c..e1d46de9110b4 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -221,14 +221,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }} secrets: inherit - linux-jammy-py3_13-clang12-build: - name: linux-jammy-py3.13-clang12 + linux-jammy-py3_14-clang12-build: + name: linux-jammy-py3.14-clang12 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.13-clang12 - docker-image-name: ci-image:pytorch-linux-jammy-py3.13-clang12 + build-environment: linux-jammy-py3.14-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-py3.14-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -245,14 +245,14 @@ jobs: ]} secrets: inherit - linux-jammy-py3_13-clang12-test: - name: linux-jammy-py3.13-clang12 + linux-jammy-py3_14-clang12-test: + name: linux-jammy-py3.14-clang12 uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-py3_13-clang12-build + needs: linux-jammy-py3_14-clang12-build with: - build-environment: linux-jammy-py3.13-clang12 - docker-image: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.14-clang12 + docker-image: ${{ needs.linux-jammy-py3_14-clang12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_14-clang12-build.outputs.test-matrix }} secrets: inherit linux-jammy-cuda12_8-cudnn9-py3_10-clang12-build: diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index aa4c3a1f229f1..17527ffb79c1d 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -33,6 +33,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + xfailIfPy314Plus, ) @@ -374,6 +375,7 @@ def test_isolation_key_is_repeatable(self) -> None: """ self.assertEqual(context._isolation_key(), context._isolation_key()) + @xfailIfPy314Plus def test_select_runtime_context_matches_forms_of_context(self) -> None: """ Tests that the selected runtime context matches the forms of context. @@ -387,6 +389,7 @@ def test_select_runtime_context_matches_forms_of_context(self) -> None: set(context._RuntimeContext.forms_of_context()), ) + @xfailIfPy314Plus def test_select_compile_context_matches_forms_of_context(self) -> None: """ Tests that the selected compile context matches the forms of context. diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py index 5dca91f0d2c80..a06202ebf861a 100644 --- a/test/nn/test_parametrization.py +++ b/test/nn/test_parametrization.py @@ -1,5 +1,7 @@ # Owner(s): ["module: nn"] import pickle +import sys +import unittest from copy import deepcopy from itertools import product @@ -669,6 +671,7 @@ def right_inverse(self, w): self.assertFalse(parametrize.is_parametrized(module)) self.assertEqual(module.weight, weight_init) + @unittest.skipIf(sys.version_info >= (3, 14), "Failing on Python 3.14+") @swap([True, False]) def test_errors_parametrized_tensor_parametrization(self): # Test errors when registering a parametrization on a parametrized tensor @@ -853,6 +856,7 @@ def right_inverse(self, w): # FIXME: Rewrite this test using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @unittest.skipIf(sys.version_info >= (3, 14), "Failing on Python 3.14+") @swap([True, False]) def test_caching_parametrization(self): r"""Test the caching system of a parametrization""" @@ -881,6 +885,7 @@ def forward(self, X): # FIXME: Rewrite this test using functions not depending on LAPACK # and remove the `@skipIfNoLapack` (see #70995) @skipIfNoLapack + @unittest.skipIf(sys.version_info >= (3, 14), "Failing on Python 3.14+") @swap([True, False]) def test_caching_parametrization_with_transfer_parametrizations_and_params(self): r"""Test that transferring parametrizations doesn't cause issues with caching""" @@ -914,6 +919,7 @@ def forward(self, X): # test that the results are distinct objects for each module self.assertNotEqual(id(A), id(X)) + @unittest.skipIf(sys.version_info >= (3, 14), "Failing on Python 3.14+") @swap([True, False]) def test_parametrization_same_training_mode(self): r"""Test training mode updated on parametrization registration""" @@ -931,6 +937,7 @@ def forward(self, X): self.assertTrue(module.parametrizations.weight[0].training) self.assertTrue(module.parametrizations.weight[1].training) + @unittest.skipIf(sys.version_info >= (3, 14), "Failing on Python 3.14+") @swap([True, False]) def test_type_before_parametrizations(self): r"""Test that type_before_parametrizations always retrieves original type""" @@ -1546,6 +1553,7 @@ def test_new_spectral_norm_dim(self): snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape ) + @unittest.skipIf(sys.version_info >= (3, 14), "Failing on Python 3.14+") @swap([True, False]) def test_new_spectral_norm_forward(self): input = torch.randn(3, 5) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index df3ca03b76242..b6904fd760982 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1709,6 +1709,10 @@ def xfailIfPy312Plus(func): return unittest.expectedFailure(func) if sys.version_info >= (3, 12) else func +def xfailIfPy314Plus(func): + return unittest.expectedFailure(func) if sys.version_info >= (3, 14) else func + + def xfailIfLinux(func): return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM and not IS_FBCODE else func From ef019d1d431c4c5a95b594cb90d40a50cd00f5e4 Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Tue, 2 Dec 2025 15:18:45 +0000 Subject: [PATCH 157/338] Silent XNNPACK GCC14 warnings (#166873) Fixes: #149828,#167642 and allows us to update to GCC14 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166873 Approved by: https://github.com/malfet --- cmake/Dependencies.cmake | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 4df8ba4a784b4..dfff1f2ad833a 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -581,6 +581,12 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) "${XNNPACK_SOURCE_DIR}" "${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK") + if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "14") + foreach(xnn_tgt IN ITEMS XNNPACK microkernels-prod microkernels-all) + target_compile_options(${xnn_tgt} PRIVATE -Wno-error=incompatible-pointer-types) + endforeach() + endif() + # Revert to whatever it was before set(CMAKE_POSITION_INDEPENDENT_CODE ${__caffe2_CMAKE_POSITION_INDEPENDENT_CODE_FLAG}) endif() From 7a1e316115fc6996b3f2336822ba5d5f6179f0c3 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 3 Dec 2025 02:36:20 +0000 Subject: [PATCH 158/338] Fix missing ConstantPooling header in passes.cpp (#169420) Summary: The previous diff (D88000002) removed the `constant_pooling.h` header from `/data/users/wouterdevriendt/fbsource/xplat/caffe2/torch/csrc/jit/runtime/static/passes.cpp`, but the file still uses the `ConstantPooling` function in the `SplitOutPrecomputeOpsForSparseNN` function. This caused a compilation error: ``` error: use of undeclared identifier 'ConstantPooling' ``` This diff adds back the necessary header to fix the compilation error. Test Plan: Buck build should now pass: ``` buck2 build fbcode//caffe2:_libtorch --mode=opt ``` Differential Revision: D88060085 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169420 Approved by: https://github.com/yangw-dev, https://github.com/atalman, https://github.com/cyyever --- torch/csrc/jit/runtime/static/passes.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 1029dd7019f8c..4d2cb6336bbdf 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include From c178ed43d3d99cbefe84fbfb21d6f282b20d62ac Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 3 Dec 2025 04:36:48 +0000 Subject: [PATCH 159/338] Revert "[Dynamo][Guards]Fix TLParse CPP guard message with sorting get_leaf_guards and verbose_code_parts (#169102)" This reverts commit 491731647f6b8a9345dcfb3bc9416aea254a7d96. Reverted https://github.com/pytorch/pytorch/pull/169102 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert this in order to revert https://github.com/pytorch/pytorch/pull/169229 ([comment](https://github.com/pytorch/pytorch/pull/169102#issuecomment-3605053944)) --- test/dynamo/test_misc.py | 31 ++++++++++++++++--------------- torch/_dynamo/guards.py | 4 ++-- torch/_guards.py | 24 +++++++++++------------- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 98526a9ba0283..78b5c7e4553da 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1225,29 +1225,30 @@ def fn(x, y): # Filter out id-matches that won't reproduce run to run guard_code = filter( lambda line: "id" not in line and "lookup_backend" not in line, - guard_code, + sorted(guard_code), ) guard_code_str = "\n".join(guard_code) - # Make sure that the dict_contains are present in the order of added - self.assertExpectedInline( - guard_code_str, - """\ -L['x'].size()[1] == L['x'].size()[0] -L['x'].storage_offset() == 0 + for line in """\ 2 <= L['x'].size()[0] -utils_device.CURRENT_DEVICE == None -str(L['x'].dtype) == 'torch.float32' -str(L['x'].device) == 'cpu' -L['x'].requires_grad == False +L['x'] is L['y'] L['x'].ndimension() == 2 +L['x'].requires_grad == False +L['x'].size()[1] == L['x'].size()[0] +L['x'].storage_offset() == 0 +___dict_contains('operator', G['sys'].modules) +___dict_contains('operator', G['sys'].modules) hasattr(L['x'], '_dynamo_dynamic_indices') == False -L['x'] is L['y'] not ___dict_contains('aaaaaaaa', G['sys'].modules) not ___dict_contains('bbbbbbbb', G['sys'].modules) -___dict_contains('operator', G['sys'].modules) -not ___dict_contains('cccccccc', G['sys'].modules)""", - ) +not ___dict_contains('cccccccc', G['sys'].modules) +str(L['x'].device) == 'cpu' +str(L['x'].dtype) == 'torch.float32' +utils_device.CURRENT_DEVICE == None""".split("\n"): + self.assertIn( + line, + guard_code_str, + ) def test_fold(self): def fn(a): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 69197f44054a3..1a5f235ad916b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -3871,7 +3871,7 @@ def _ref(x: Any) -> Any: }, global_scope=global_scope_state, _guards=torch._guards.GuardsSet( - OrderedSet( + { dataclasses.replace( guard, obj_weakref=None, @@ -3879,7 +3879,7 @@ def _ref(x: Any) -> Any: create_fn=normalize_create_fn(guard.create_fn), ) for guard in sorted_guards - ) + } ), input_source_to_sizes_strides=pytree.tree_map( convert_int_to_concrete_values, diff --git a/torch/_guards.py b/torch/_guards.py index 386872c4eecfb..c9daab1e69e81 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -14,11 +14,10 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Generic, NamedTuple, TYPE_CHECKING, TypeVar import torch from torch.utils import _pytree as pytree -from torch.utils._ordered_set import OrderedSet from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._traceback import CapturedTraceback, format_frame from torch.utils.weak import WeakTensorKeyDictionary @@ -488,16 +487,16 @@ class GuardsCheckpointState: The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext """ - dynamo_guards: OrderedSet[Guard] + dynamo_guards: set[Guard] = set() - def __init__(self, dynamo_guards: OrderedSet[Guard]) -> None: + def __init__(self, dynamo_guards: set[Guard]) -> None: self.dynamo_guards = dynamo_guards - def diff(self, other: GuardsCheckpointState) -> Optional[OrderedSet[Guard]]: + def diff(self, other: GuardsCheckpointState) -> set[Guard] | None: """ Produces a delta against another GuardsCheckpointState. - Returns None if no delta is found, otherwise, return an OrderedSet() of mismatched + Returns None if no delta is found, otherwise, return a set() of mismatched Guard type objects. """ r = self.dynamo_guards.difference(other.dynamo_guards) @@ -606,11 +605,10 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: # Like a Set[Guard] but will record the user stack on all guards at the # time they were installed at their destination class GuardsSet: - def __init__(self, inner: Optional[OrderedSet[Guard]] = None) -> None: + def __init__(self, inner: set[Guard] | None = None) -> None: if inner is None: - self.inner: OrderedSet[Guard] = OrderedSet() - else: - self.inner = inner + inner = set() + self.inner = inner def __iter__(self) -> Iterator[Guard]: return iter(self.inner) @@ -647,9 +645,9 @@ def remove_guards_with_source(self, source: Source) -> None: """Delete all guards that contains a given source""" from ._dynamo.source import is_from_source - self.inner = OrderedSet( + self.inner = { g for g in self.inner if not is_from_source(g.originating_source, source) - ) + } """ @@ -666,7 +664,7 @@ def __init__(self) -> None: self.aotautograd_guards: list[GuardEnvExpr] = [] def copy_graphstate(self) -> GuardsCheckpointState: - return GuardsCheckpointState(OrderedSet(self.dynamo_guards.inner)) + return GuardsCheckpointState(set(self.dynamo_guards.inner)) def restore_graphstate(self, state: GuardsCheckpointState) -> None: # NB: "steals" the passed in state From 59abd50e931f4efb21b053f7a2911f5d8a49d883 Mon Sep 17 00:00:00 2001 From: "Andy (An) Wang" Date: Wed, 3 Dec 2025 04:43:14 +0000 Subject: [PATCH 160/338] [reland][Full Inductor][Pytorch] Prevent decomposition to support fallback of aten.native_layer_norm for MTIA (#168986) Differential Revision: D87665337 ## Context **Original context of [#168290](https://github.com/pytorch/pytorch/pull/168290):** MTIA-Triton currently doesn't support aten.native_layer_norm and we need Inductor to fallback it to Aten. Currently `make_fallback` doesn't work for aten.native_layer_norm due to decomposition. This PR prevents the decomposition, following the PR [#151637](https://github.com/pytorch/pytorch/pull/151637) where XPU enabled fallback for embedding_dense_backward. **This PR:** https://github.com/pytorch/pytorch/pull/168290 was reverted because it caused runtime failures for mtia tests ([T245927718](https://www.internalfb.com/intern/tasks/?t=245927718)). Not sure why the internal CI didn't prevent the diff landing though. After investigation, I think the failures were caused by `torch.mtia.is_available()`: 1. The API didn't run lazy initialization of mtia device, which caused error when this API was run by Inductor while mtia device hasn't been initialized. This issue should be fixed by D88058718. 2. I also found some failing tests were run without mtia device, which causes `torch.mtia.is_available()` returns false. ## This PR Exactly copy of https://github.com/pytorch/pytorch/pull/168290, except for changing `torch.mtia.is_available()` to `torch.mtia._is_compiled()`, which doesn't require mtia device to be initialized. And I verified the failed tests can pass with `torch.mtia._is_compiled()`. `torch.mtia._is_compiled()` checks whether MTIA hook is registered, so should not affect non-MTIA program. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168986 Approved by: https://github.com/eellison --- torch/_inductor/decomposition.py | 16 ++++++++++++++++ torch/_inductor/lowering.py | 4 ++++ 2 files changed, 20 insertions(+) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 3cedad185c3f2..db9c8f5f0333c 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -35,6 +35,7 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype, ) +from torch._refs import native_layer_norm as decomp_native_layer_norm from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true from . import config, inductor_prims @@ -118,6 +119,7 @@ aten.clamp_max, aten.clamp_min, aten.embedding_dense_backward, # we fall back on xpu + aten.native_layer_norm, # we fall back on mtia aten.index_add, # we conditionally call this decomp aten.glu, # inductor lowers this directly aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass @@ -159,6 +161,20 @@ def _embedding_dense_backward( ) +@register_decomposition(aten.native_layer_norm) +def _native_layer_norm( + input: torch.Tensor, + normalized_shape: utils.ShapeType, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if input.is_mtia: + return NotImplemented + # We can write a util function to update decomp table if we have more ops to fallback. + return decomp_native_layer_norm(input, normalized_shape, weight, bias, eps) + + @register_decomposition([aten.sym_constrain_range_for_size.default]) def sym_constrain_range_for_size( symbol: torch.SymInt, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 090265d208c92..427997964bbb7 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2902,6 +2902,10 @@ def is_aligned(x): aten.embedding_dense_backward, warn=False ) # (XPU-only and faster than decomp) +if torch.mtia._is_compiled(): + make_fallback( + aten.native_layer_norm, warn=False + ) # (MTIA-only and faster than decomp) # 1.5) Easy or Impossible make_fallback(aten._cdist_forward) # p=2 should be feasible From 3418bd29475dff06695045fcdf93e7d0dac67da8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 3 Dec 2025 04:47:36 +0000 Subject: [PATCH 161/338] Revert " [10/N] Use Python 3.10 typing (#169229)" This reverts commit 518c2b1b3dab9a2ef2849e04b3bc2f20c1c41db9. Reverted https://github.com/pytorch/pytorch/pull/169229 on behalf of https://github.com/huydhn due to Sorry for reverting your change but there is a subtlety w.r.t export/import that is surfaced by this change ([comment](https://github.com/pytorch/pytorch/pull/169229#issuecomment-3605071398)) --- torch/__init__.py | 60 +++++---- torch/_compile.py | 6 +- torch/_guards.py | 88 ++++++------- torch/_jit_internal.py | 6 +- torch/_linalg_utils.py | 10 +- torch/_lobpcg.py | 80 ++++++------ torch/_lowrank.py | 19 +-- torch/_meta_registrations.py | 204 +++++++++++++++--------------- torch/_ops.py | 35 ++++- torch/_sources.py | 8 +- torch/_tensor.py | 52 ++++---- torch/_tensor_str.py | 4 +- torch/_utils.py | 6 +- torch/_utils_internal.py | 10 +- torch/_vmap_internals.py | 12 +- torch/_weights_only_unpickler.py | 12 +- torch/functional.py | 42 +++--- torch/hub.py | 12 +- torch/library.py | 66 +++++----- torch/masked/_ops.py | 174 ++++++++++++------------- torch/nn/_reduction.py | 9 +- torch/nn/common_types.py | 8 +- torch/nn/init.py | 28 ++-- torch/nn/modules/activation.py | 41 +++--- torch/nn/modules/batchnorm.py | 18 +-- torch/nn/modules/container.py | 14 +- torch/nn/modules/conv.py | 20 +-- torch/nn/modules/lazy.py | 4 +- torch/nn/modules/loss.py | 25 ++-- torch/nn/modules/module.py | 40 +++--- torch/nn/modules/normalization.py | 6 +- torch/nn/modules/pooling.py | 34 ++--- torch/nn/modules/rnn.py | 36 +++--- torch/nn/modules/transformer.py | 82 ++++++------ torch/nn/modules/upsampling.py | 25 ++-- torch/overrides.py | 4 +- torch/quasirandom.py | 11 +- torch/serialization.py | 48 +++---- torch/storage.py | 52 ++++---- torch/types.py | 22 ++-- torch/xpu/__init__.py | 18 +-- torch/xpu/random.py | 7 +- 42 files changed, 755 insertions(+), 703 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index e6f9cfcb54472..ad32f8a054dc7 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -320,7 +320,7 @@ def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> ctypes.CDLL(lib_path) -def _preload_cuda_deps(err: OSError | None = None) -> None: +def _preload_cuda_deps(err: _Optional[OSError] = None) -> None: cuda_libs: list[tuple[str, str]] = [ ("cublas", "libcublas.so.*[0-9]"), ("cudnn", "libcudnn.so.*[0-9]"), @@ -1276,7 +1276,7 @@ def set_default_device(device: "Device") -> None: _GLOBAL_DEVICE_CONTEXT.device_context = device_context -def set_default_tensor_type(t: type["torch.Tensor"] | str, /) -> None: +def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None: r""" .. warning:: @@ -1524,7 +1524,7 @@ def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: return _C._get_deterministic_algorithms_warn_only() -def set_deterministic_debug_mode(debug_mode: builtins.int | str) -> None: +def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None: r"""Sets the debug mode for deterministic operations. .. note:: This is an alternative interface for @@ -1686,7 +1686,7 @@ def is_warn_always_enabled() -> builtins.bool: def _check_with( error_type, - cond: builtins.bool | SymBool, + cond: _Union[builtins.bool, SymBool], message: _Callable[[], str], ): # noqa: F811 if not isinstance(cond, (builtins.bool, SymBool)): @@ -2092,7 +2092,7 @@ def _dtype(self): return torch.quint2x4 -_storage_classes: set[type[TypedStorage | UntypedStorage]] = { +_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = { UntypedStorage, DoubleStorage, FloatStorage, @@ -2398,13 +2398,13 @@ def __eq__(self, other): and self.dynamic == other.dynamic ) - def apply_mode(self, mode: str | None): + def apply_mode(self, mode: _Optional[str]): if mode and mode != "default": from torch._inductor import list_mode_options self.apply_options(list_mode_options(mode, self.dynamic)) - def apply_options(self, options: dict[str, _Any] | None): + def apply_options(self, options: _Optional[dict[str, _Any]]): if not options: return @@ -2524,10 +2524,12 @@ def compile( model: _Callable[_InputT, _RetT], *, fullgraph: builtins.bool = False, - dynamic: builtins.bool | None = None, - backend: str | _Callable = "inductor", - mode: str | None = None, - options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, + dynamic: _Optional[builtins.bool] = None, + backend: _Union[str, _Callable] = "inductor", + mode: _Union[str, None] = None, + options: _Optional[ + dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] + ] = None, disable: builtins.bool = False, ) -> _Callable[_InputT, _RetT]: ... @@ -2537,27 +2539,31 @@ def compile( model: None = None, *, fullgraph: builtins.bool = False, - dynamic: builtins.bool | None = None, - backend: str | _Callable = "inductor", - mode: str | None = None, - options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, + dynamic: _Optional[builtins.bool] = None, + backend: _Union[str, _Callable] = "inductor", + mode: _Union[str, None] = None, + options: _Optional[ + dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] + ] = None, disable: builtins.bool = False, ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ... def compile( - model: _Callable[_InputT, _RetT] | None = None, + model: _Optional[_Callable[_InputT, _RetT]] = None, *, fullgraph: builtins.bool = False, - dynamic: builtins.bool | None = None, - backend: str | _Callable = "inductor", - mode: str | None = None, - options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, + dynamic: _Optional[builtins.bool] = None, + backend: _Union[str, _Callable] = "inductor", + mode: _Union[str, None] = None, + options: _Optional[ + dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] + ] = None, disable: builtins.bool = False, -) -> ( - _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]] - | _Callable[_InputT, _RetT] -): +) -> _Union[ + _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]], + _Callable[_InputT, _RetT], +]: """ Optimizes given model/function using TorchDynamo and specified backend. If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile` @@ -2865,7 +2871,7 @@ def __getattr__(name): @functools.cache -def get_device_module(device: torch.device | str | None = None): +def get_device_module(device: _Optional[_Union[torch.device, str]] = None): """ Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). If no device is given, return the module for the current accelerator or CPU if none is present. @@ -2891,8 +2897,8 @@ def get_device_module(device: torch.device | str | None = None): def _constrain_as_size( symbol, - min: builtins.int | None = None, - max: builtins.int | None = None, + min: _Optional[builtins.int] = None, + max: _Optional[builtins.int] = None, ): """ This indicates that a given int is size-like, and can be used in any context where a size is expected. diff --git a/torch/_compile.py b/torch/_compile.py index bf7d715883d58..76ddd3ccb05b4 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -5,7 +5,7 @@ import functools from collections.abc import Callable -from typing import overload, TypeVar +from typing import Optional, overload, TypeVar, Union from typing_extensions import ParamSpec @@ -26,8 +26,8 @@ def _disable_dynamo( def _disable_dynamo( - fn: Callable[_P, _T] | None = None, recursive: bool = True -) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: + fn: Optional[Callable[_P, _T]] = None, recursive: bool = True +) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: """ This API should be only used inside torch, external users should still use torch._dynamo.disable. The main goal of this API is to avoid circular diff --git a/torch/_guards.py b/torch/_guards.py index c9daab1e69e81..1bd32fc7f08ec 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -14,7 +14,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, TYPE_CHECKING, TypeVar +from typing import Any, Generic, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union import torch from torch.utils import _pytree as pytree @@ -92,7 +92,7 @@ def __str__(self) -> str: return f"{self.frame_id}/{self.frame_compile_id}" @classmethod - def from_string(cls, compile_id: str | None) -> CompileId | None: + def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]: """ Factory method that creates a CompileId from its string representation. Keep this in sync with the __str__ method. @@ -255,14 +255,14 @@ class Guard: create_fn: Callable[[GuardBuilderBase, Guard], None] # Export only. These values are written to at time of guard check_fn creation. - guard_types: list[str] | None = None - code_list: list[str] | None = None - obj_weakref: object | None = None - guarded_class_weakref: weakref.ReferenceType[Any] | None = None - - stack: CapturedTraceback | None = None - user_stack: traceback.StackSummary | None = None - _hash: int | None = None + guard_types: Optional[list[str]] = None + code_list: Optional[list[str]] = None + obj_weakref: Optional[object] = None + guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None + + stack: Optional[CapturedTraceback] = None + user_stack: Optional[traceback.StackSummary] = None + _hash: Optional[int] = None _unserializable: bool = False def __hash__(self) -> int: @@ -379,7 +379,7 @@ def create_fn_name(self) -> str: def set_export_info( self, guard_type: str, - guarded_class: weakref.ReferenceType[Any] | None, + guarded_class: Optional[weakref.ReferenceType[Any]], code_list: list[str], obj_weakref: object, ) -> None: @@ -492,7 +492,7 @@ class GuardsCheckpointState: def __init__(self, dynamo_guards: set[Guard]) -> None: self.dynamo_guards = dynamo_guards - def diff(self, other: GuardsCheckpointState) -> set[Guard] | None: + def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]: """ Produces a delta against another GuardsCheckpointState. @@ -516,7 +516,7 @@ class ModuleContextCheckpointState: def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None: self.nn_modules = nn_modules - def diff(self, other: ModuleContextCheckpointState) -> set[str] | None: + def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]: """ Produces a delta against another ModuleContextCheckpointState. @@ -552,7 +552,7 @@ class GlobalContextCheckpointState: def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None: self.global_state = global_states - def diff(self, other: GlobalContextCheckpointState) -> set[str] | None: + def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]: """ Produces a delta against another GlobalContextCheckpointState. @@ -605,7 +605,7 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: # Like a Set[Guard] but will record the user stack on all guards at the # time they were installed at their destination class GuardsSet: - def __init__(self, inner: set[Guard] | None = None) -> None: + def __init__(self, inner: Optional[set[Guard]] = None) -> None: if inner is None: inner = set() self.inner = inner @@ -683,13 +683,13 @@ def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ... def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod - def get_autograd_key_entry(self, identifier: str) -> Callable | None: ... + def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ... @abstractmethod def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod - def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: ... + def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ... @abstractmethod def add_lazy_bwd_entry( @@ -702,7 +702,7 @@ def add_lazy_bwd_entry( @abstractmethod def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] - ) -> tuple[torch.fx.GraphModule | None, int | None]: ... + ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ... class InvokeSubgraphCache(HopSubgraphCache): @@ -726,13 +726,13 @@ def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: self.autograd_cache[identifier] = key - def get_autograd_key_entry(self, identifier: str) -> Callable | None: + def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: return self.autograd_cache.get(identifier, None) def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: self.proxy_dispatch_cache[identifier] = key - def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: + def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: return self.proxy_dispatch_cache.get(identifier, None) def add_lazy_bwd_entry( @@ -748,7 +748,7 @@ def add_lazy_bwd_entry( def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] - ) -> tuple[torch.fx.GraphModule | None, int | None]: + ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: if identifier not in self.lazy_bwd_cache: return (None, None) @@ -765,7 +765,7 @@ def add_effects(self, identifier: str, effects: set) -> None: ) self.effects_cache[identifier] = effects - def get_effects(self, identifier: str) -> set | None: + def get_effects(self, identifier: str) -> Optional[set]: """Retrieve the effect types for a given invoke_subgraph identifier.""" return self.effects_cache.get(identifier, None) @@ -814,7 +814,7 @@ def get() -> CompileContext: def try_get() -> CompileContext | None: return getattr(_TLS, "compile_context", None) - def __init__(self, compile_id: CompileId | None) -> None: + def __init__(self, compile_id: Optional[CompileId]) -> None: assert compile_id is None or isinstance(compile_id, CompileId) self.compile_id: CompileId | None = compile_id self.attempt = 0 @@ -822,14 +822,14 @@ def __init__(self, compile_id: CompileId | None) -> None: self.shape_env_guards: list[str] = [] @staticmethod - def current_compile_id() -> CompileId | None: + def current_compile_id() -> Optional[CompileId]: self = CompileContext.try_get() if self is None: return None return self.compile_id @staticmethod - def current_trace_id() -> TraceId | None: + def current_trace_id() -> Optional[TraceId]: self = CompileContext.try_get() if self is None: return None @@ -858,13 +858,13 @@ def get() -> TracingContext: "TracingContext.get() must be called within an ongoing trace." ) - def __init__(self, fake_mode: FakeTensorMode | None) -> None: + def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: self.guards_context = GuardsContext() self.module_context = ModuleContext() self.global_context = GlobalContext() self.previously_inlined_functions: dict[Any, Any] = dict() self.previously_cleaned_instructions: dict[Any, Any] = dict() - self.fake_mode: FakeTensorMode | None = fake_mode + self.fake_mode: Optional[FakeTensorMode] = fake_mode self.frame_summary_stack: list[traceback.FrameSummary] = [] # This is morally part of frame_summary_stack, but it is kept separate # for clarity. As we process a frame, this variable gets updated @@ -872,16 +872,16 @@ def __init__(self, fake_mode: FakeTensorMode | None) -> None: # function call, this gets cleared and the frame location is pushed # to frame_summary_stack (prepping this variable for the inner frame's # progress) - self.loc_in_frame: tuple[str, int, str] | None = None + self.loc_in_frame: Optional[tuple[str, int, str]] = None # this is only set after aot_autograd - self.fw_metadata: ViewAndMutationMeta | None = None + self.fw_metadata: Optional[ViewAndMutationMeta] = None # this is only set when the DDPOptimizer is used - self.ddp_optimizer_ctx: DDPOptimizerContext | None = None + self.ddp_optimizer_ctx: Optional[DDPOptimizerContext] = None # this is only set after aot_autograd - self.aot_graph_name: list[str] | None = None - self.params_flat: list[Any] | None = None - self.params_flat_unwrap_subclasses: list[Any] | None = None - self.params_unwrapped_to_flat_index: list[Any] | None = None + self.aot_graph_name: Optional[list[str]] = None + self.params_flat: Optional[list[Any]] = None + self.params_flat_unwrap_subclasses: Optional[list[Any]] = None + self.params_unwrapped_to_flat_index: Optional[list[Any]] = None # this is for extended return calling convention from backend # compiler to aot_autograd # Per output, what the compiler specified stride of the output is, @@ -985,7 +985,7 @@ def clear_frame() -> Generator[None, None, None]: @staticmethod @contextlib.contextmanager def current_frame( - frame_summary: traceback.FrameSummary | None, + frame_summary: Optional[traceback.FrameSummary], ) -> Generator[None, None, None]: # frame_summary can be None to solely take advantage of real_stack # attachment to thrown exceptions @@ -1008,7 +1008,7 @@ def current_frame( @staticmethod @contextlib.contextmanager def report_output_strides() -> Generator[ - list[tuple[int, ...] | None] | None, None, None + Optional[list[Optional[tuple[int, ...]]]], None, None ]: tc = TracingContext.try_get() if tc is None: @@ -1028,7 +1028,7 @@ def set_current_loc(filename: str, lineno: int, frame_name: str) -> None: TracingContext.get().loc_in_frame = (filename, lineno, frame_name) @staticmethod - def get_traced_code() -> list[CodeType] | None: + def get_traced_code() -> Optional[list[CodeType]]: tc = TracingContext.try_get() if tc is None: return None @@ -1037,8 +1037,8 @@ def get_traced_code() -> list[CodeType] | None: @contextmanager def compile_context( - context: CompileContext | None, -) -> Generator[CompileContext | None, None, None]: + context: Optional[CompileContext], +) -> Generator[Optional[CompileContext], None, None]: old_context = getattr(_TLS, "compile_context", None) _TLS.compile_context = context try: @@ -1049,8 +1049,8 @@ def compile_context( @contextmanager def tracing( - context: TracingContext | None, -) -> Generator[TracingContext | None, None, None]: + context: Optional[TracingContext], +) -> Generator[Optional[TracingContext], None, None]: """ This function installs the passed in tracing context as a dynamic scoped global variable. @@ -1127,7 +1127,7 @@ def get_base(self) -> Source: return current -def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None: +def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: """ Attempts to "detect" what the current fake mode is. If there is one ambiently available from TracingContext, we preferentially use that. Otherwise, we @@ -1164,7 +1164,7 @@ def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None: # pyrefly: ignore [bad-argument-type] fake_modes.append((flat_input.fake_mode, "fake tensor input", i)) if is_traceable_wrapper_subclass(flat_input): - out: list[torch.Tensor | int | torch.SymInt] = [] + out: list[Union[torch.Tensor, int, torch.SymInt]] = [] get_plain_tensors(flat_input, out=out) # type: ignore[arg-type] fake_tensors: list[FakeTensor] = [ x for x in out if isinstance(x, FakeTensor) @@ -1193,7 +1193,7 @@ def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None: return None -def active_fake_mode() -> FakeTensorMode | None: +def active_fake_mode() -> Optional[FakeTensorMode]: """ Inspects the dispatch mode stack for an active fake mode and returns it. Returns None if no fake mode is active. diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 27c5768477dab..9efa0583cdea7 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -52,7 +52,7 @@ _P = ParamSpec("_P") _R = TypeVar("_R") -BuiltinUnionType: type | tuple[type, ...] = types.UnionType +BuiltinUnionType: Union[type, tuple[type, ...]] = types.UnionType LockType: type try: @@ -1236,7 +1236,7 @@ def _try_get_dispatched_fn(fn): def _get_named_tuple_properties( obj, - loc: torch._C._jit_tree_views.SourceRange | None = None, + loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None, ): if loc is None: @@ -1531,7 +1531,7 @@ def _extract_tensors(obj): return tensors -def _get_model_id(obj) -> str | None: +def _get_model_id(obj) -> Optional[str]: if isinstance(obj, torch.jit.ScriptModule): return str(obj._c._type()) elif isinstance(obj, torch.jit.ScriptFunction): diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index 213393da9aa99..43c8b65767e00 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs """Various linear algebra utility methods for internal use.""" +from typing import Optional + import torch from torch import Tensor @@ -27,7 +29,7 @@ def get_floating_dtype(A): return torch.float32 -def matmul(A: Tensor | None, B: Tensor) -> Tensor: +def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: """Multiply two matrices. If A is None, return B. A can be sparse or dense. B is always @@ -40,12 +42,12 @@ def matmul(A: Tensor | None, B: Tensor) -> Tensor: return torch.matmul(A, B) -def bform(X: Tensor, A: Tensor | None, Y: Tensor) -> Tensor: +def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: """Return bilinear form of matrices: :math:`X^T A Y`.""" return matmul(X.mT, matmul(A, Y)) -def qform(A: Tensor | None, S: Tensor): +def qform(A: Optional[Tensor], S: Tensor): """Return quadratic form :math:`S^T A S`.""" return bform(S, A, S) @@ -55,7 +57,7 @@ def basis(A): return torch.linalg.qr(A).Q -def symeig(A: Tensor, largest: bool | None = False) -> tuple[Tensor, Tensor]: +def symeig(A: Tensor, largest: Optional[bool] = False) -> tuple[Tensor, Tensor]: """Return eigenpairs of A with specified ordering.""" if largest is None: largest = False diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index cdc426047c33f..1137efdc5f63a 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -3,6 +3,8 @@ # Author: Pearu Peterson # Created: February 2020 +from typing import Optional + import torch from torch import _linalg_utils as _utils, Tensor from torch.overrides import handle_torch_function, has_torch_function @@ -256,19 +258,19 @@ class LOBPCGAutogradFunction(torch.autograd.Function): def forward( # type: ignore[override] ctx, A: Tensor, - k: int | None = None, - B: Tensor | None = None, - X: Tensor | None = None, - n: int | None = None, - iK: Tensor | None = None, - niter: int | None = None, - tol: float | None = None, - largest: bool | None = None, - method: str | None = None, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, tracker: None = None, - ortho_iparams: dict[str, int] | None = None, - ortho_fparams: dict[str, float] | None = None, - ortho_bparams: dict[str, bool] | None = None, + ortho_iparams: Optional[dict[str, int]] = None, + ortho_fparams: Optional[dict[str, float]] = None, + ortho_bparams: Optional[dict[str, bool]] = None, ) -> tuple[Tensor, Tensor]: # makes sure that input is contiguous for efficiency. # Note: autograd does not support dense gradients for sparse input yet. @@ -342,19 +344,19 @@ def backward(ctx, D_grad, U_grad): # pyrefly: ignore # bad-override def lobpcg( A: Tensor, - k: int | None = None, - B: Tensor | None = None, - X: Tensor | None = None, - n: int | None = None, - iK: Tensor | None = None, - niter: int | None = None, - tol: float | None = None, - largest: bool | None = None, - method: str | None = None, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, tracker: None = None, - ortho_iparams: dict[str, int] | None = None, - ortho_fparams: dict[str, float] | None = None, - ortho_bparams: dict[str, bool] | None = None, + ortho_iparams: Optional[dict[str, int]] = None, + ortho_fparams: Optional[dict[str, float]] = None, + ortho_bparams: Optional[dict[str, bool]] = None, ) -> tuple[Tensor, Tensor]: """Find the k largest (or smallest) eigenvalues and the corresponding eigenvectors of a symmetric positive definite generalized @@ -582,19 +584,19 @@ def lobpcg( def _lobpcg( A: Tensor, - k: int | None = None, - B: Tensor | None = None, - X: Tensor | None = None, - n: int | None = None, - iK: Tensor | None = None, - niter: int | None = None, - tol: float | None = None, - largest: bool | None = None, - method: str | None = None, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, tracker: None = None, - ortho_iparams: dict[str, int] | None = None, - ortho_fparams: dict[str, float] | None = None, - ortho_bparams: dict[str, bool] | None = None, + ortho_iparams: Optional[dict[str, int]] = None, + ortho_fparams: Optional[dict[str, float]] = None, + ortho_bparams: Optional[dict[str, bool]] = None, ) -> tuple[Tensor, Tensor]: # A must be square: assert A.shape[-2] == A.shape[-1], A.shape @@ -694,10 +696,10 @@ class LOBPCG: def __init__( self, - A: Tensor | None, - B: Tensor | None, + A: Optional[Tensor], + B: Optional[Tensor], X: Tensor, - iK: Tensor | None, + iK: Optional[Tensor], iparams: dict[str, int], fparams: dict[str, float], bparams: dict[str, bool], diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 25089d66d35ea..182883cfc5e59 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -2,6 +2,7 @@ __all__ = ["svd_lowrank", "pca_lowrank"] +from typing import Optional import torch from torch import _linalg_utils as _utils, Tensor @@ -11,8 +12,8 @@ def get_approximate_basis( A: Tensor, q: int, - niter: int | None = 2, - M: Tensor | None = None, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, ) -> Tensor: """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is @@ -84,9 +85,9 @@ def get_approximate_basis( def svd_lowrank( A: Tensor, - q: int | None = 6, - niter: int | None = 2, - M: Tensor | None = None, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, ) -> tuple[Tensor, Tensor, Tensor]: r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, batches of matrices, or a sparse matrix :math:`A` such that @@ -148,9 +149,9 @@ def svd_lowrank( def _svd_lowrank( A: Tensor, - q: int | None = 6, - niter: int | None = 2, - M: Tensor | None = None, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, ) -> tuple[Tensor, Tensor, Tensor]: # Algorithm 5.1 in Halko et al., 2009 @@ -182,7 +183,7 @@ def _svd_lowrank( def pca_lowrank( A: Tensor, - q: int | None = None, + q: Optional[int] = None, center: bool = True, niter: int = 2, ) -> tuple[Tensor, Tensor, Tensor]: diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index d48b421f105c7..a54bf3c026fe2 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Sequence from enum import Enum from functools import wraps -from typing import TypeVar +from typing import Optional, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -547,9 +547,9 @@ def meta_sparse_structured_linear( input: Tensor, weight: Tensor, _meta: Tensor, - bias: Tensor | None = None, - _activation_opt: str | None = None, - out_dtype: torch.dtype | None = None, + bias: Optional[Tensor] = None, + _activation_opt: Optional[str] = None, + out_dtype: Optional[torch.dtype] = None, ): output_sizes = list(input.shape) if bias is not None: @@ -581,7 +581,7 @@ def meta_sparse_structured_mm( mat1: Tensor, mat1_meta: Tensor, mat2: Tensor, - out_dtype: torch.dtype | None = None, + out_dtype: Optional[torch.dtype] = None, ): assert len(mat1.shape) == 2 assert len(mat1_meta.shape) == 2 @@ -610,7 +610,7 @@ def meta_sparse_structured_addmm( *, alpha=1, beta=1, - out_dtype: torch.dtype | None = None, + out_dtype: Optional[torch.dtype] = None, ): assert len(input.shape) == 1, ( "only input broadcasted to columns of mat1 * mat2 product is supported" @@ -640,9 +640,9 @@ def meta_sparse_structured_addmm( def meta__cslt_sparse_mm( compressed_A: torch.Tensor, dense_B: torch.Tensor, - bias: Tensor | None = None, - alpha: Tensor | None = None, - out_dtype: torch.dtype | None = None, + bias: Optional[Tensor] = None, + alpha: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, transpose_result: bool = False, alg_id: int = 0, split_k: int = 1, @@ -724,9 +724,9 @@ def meta_segment_reduce( data: Tensor, reduce: str, *, - lengths: Tensor | None = None, - indices: Tensor | None = None, - offsets: Tensor | None = None, + lengths: Optional[Tensor] = None, + indices: Optional[Tensor] = None, + offsets: Optional[Tensor] = None, axis: int = 0, unsafe: bool = False, initial=None, @@ -1468,7 +1468,7 @@ def _linalg_svd_meta( A: Tensor, full_matrices: bool = False, compute_uv: bool = True, - driver: str | None = None, + driver: Optional[str] = None, ): checkIsMatrix(A, "linalg.svd") checkFloatingOrComplex(A, "linalg.svd") @@ -1521,7 +1521,7 @@ def _linalg_broadcast_batch_dims( def _linalg_broadcast_batch_dims_name( arg1: Tensor, arg2: Tensor, - name: str | None, + name: Optional[str], ) -> tuple[Tensor, Tensor]: # If there's no name we assume we don't want to check the errors if name: @@ -1553,10 +1553,10 @@ def _linalg_solve_ex( *, left: bool = True, check_errors: bool = False, - result: Tensor | None = None, - LU: Tensor | None = None, - pivots: Tensor | None = None, - info: Tensor | None = None, + result: Optional[Tensor] = None, + LU: Optional[Tensor] = None, + pivots: Optional[Tensor] = None, + info: Optional[Tensor] = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: checkFloatingOrComplex(A, "linalg.solve") torch._check( @@ -1613,7 +1613,7 @@ def linalg_solve_triangular_meta( upper: bool, left: bool = True, unitriangular: bool = False, - out: Tensor | None = None, + out: Optional[Tensor] = None, ) -> Tensor: if out is None: out = A.new_empty([0]) @@ -2264,7 +2264,7 @@ def meta__fused_moving_avg_obs_fq_helper( @register_meta(aten.mm) @out_wrapper(exact_dtype=True) -def meta_mm(a, b, out_dtype: torch.dtype | None = None): +def meta_mm(a, b, out_dtype: Optional[torch.dtype] = None): torch._check(a.dim() == 2, lambda: "a must be 2D") torch._check(b.dim() == 2, lambda: "b must be 2D") N, M1 = a.shape @@ -2313,12 +2313,12 @@ def device_hint(tensor) -> "str": def calc_conv_nd_return_shape( input_tensor: torch.Tensor, weight: torch.Tensor, - stride: list[int] | int, - padding: list[int] | int, - dilation: list[int] | int, + stride: Union[list[int], int], + padding: Union[list[int], int], + dilation: Union[list[int], int], is_transposed: bool, groups: int, - output_padding: list[int] | int | None = None, + output_padding: Optional[Union[list[int], int]] = None, ): def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ @@ -2384,7 +2384,7 @@ def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int elif len(dilation) == 1: dilation = [dilation[0]] * len(dims) - output_padding_list: list[int] | None = None + output_padding_list: Optional[list[int]] = None if output_padding: if isinstance(output_padding, IntLike): # pyrefly: ignore [bad-assignment] @@ -2435,9 +2435,9 @@ def is_channels_last(ten): def meta_miopen_batch_norm( input_tensor: torch.Tensor, weight: torch.Tensor, - bias: torch.Tensor | None, - running_mean: torch.Tensor | None, - running_var: torch.Tensor | None, + bias: Optional[torch.Tensor], + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], training: bool, exponential_average_factor: float, epsilon: float, @@ -3383,7 +3383,7 @@ def meta_index_Tensor(self, indices): torch._check(bool(indices), lambda: "at least one index must be provided") # aten::index is the internal advanced indexing implementation # checkIndexTensorTypes and expandTensors - result: list[Tensor | None] = [] + result: list[Optional[Tensor]] = [] for i, index in enumerate(indices): if index is not None: torch._check( @@ -3853,7 +3853,7 @@ def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs): @register_meta([aten._dyn_quant_pack_4bit_weight]) def meta__dyn_quant_pack_4bit_weight( - weights, scales_zeros, bias: Tensor | None, block_size, in_features, out_features + weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features ): torch._check( weights.dtype is torch.uint8, @@ -5655,7 +5655,7 @@ def meta__scaled_dot_product_flash_attention( dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: float | None = None, + scale: Optional[float] = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -5737,12 +5737,12 @@ def meta__scaled_dot_product_cudnn_attention( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Tensor | None, + attn_bias: Optional[Tensor], compute_log_sumexp: bool, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: float | None = None, + scale: Optional[float] = None, ): B = query.size(0) H = query.size(1) @@ -5781,11 +5781,11 @@ def meta__scaled_dot_product_fused_attention_overrideable( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Tensor | None = None, + attn_bias: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: float | None = None, + scale: Optional[float] = None, ): B = query.size(0) H_Q = query.size(1) @@ -5839,7 +5839,7 @@ def meta__scaled_dot_product_flash_backward( is_causal: bool, philox_seed: Tensor, philox_offset: Tensor, - scale: float | None = None, + scale: Optional[float] = None, ): grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2) grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2) @@ -5858,8 +5858,8 @@ def meta__scaled_dot_product_flash_attention_for_cpu( value: Tensor, dropout_p: float = 0.0, is_causal: bool = False, - attn_mask: Tensor | None = None, - scale: float | None = None, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -5895,8 +5895,8 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( logsumexp: Tensor, dropout_p: float, is_causal: bool, - attn_mask: Tensor | None = None, - scale: float | None = None, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, ): # cpus's grad layout is different from cuda's, # i.e. (batch_size, seq_len, num_heads, head_dim) @@ -5927,11 +5927,11 @@ def meta__scaled_dot_product_attention_math_for_mps( query: Tensor, key: Tensor, value: Tensor, - attn_mask: Tensor | None = None, + attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, - dropout_mask: Tensor | None = None, - scale: float | None = None, + dropout_mask: Optional[Tensor] = None, + scale: Optional[float] = None, ) -> tuple[Tensor, Tensor]: def ensure_4d(x): if x.dim() == 3: @@ -5982,11 +5982,11 @@ def meta__scaled_dot_product_efficient_attention( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Tensor | None, + attn_bias: Optional[Tensor], compute_log_sumexp: bool, dropout_p=0.0, is_causal: bool = False, - scale: float | None = None, + scale: Optional[float] = None, ): query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -6032,7 +6032,7 @@ def meta__scaled_dot_product_efficient_backward( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Tensor | None, + attn_bias: Optional[Tensor], out: Tensor, logsumexp: Tensor, philox_seed: Tensor, @@ -6040,7 +6040,7 @@ def meta__scaled_dot_product_efficient_backward( dropout_p: float, grad_input_mask: list[bool], is_causal: bool = False, - scale: float | None = None, + scale: Optional[float] = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -6103,7 +6103,7 @@ def meta__scaled_dot_product_cudnn_backward( max_k: int, dropout_p: float, is_causal: bool, - scale: float | None = None, + scale: Optional[float] = None, ): grad_q = torch.empty_like(query) grad_k = torch.empty_like(key) @@ -6120,18 +6120,18 @@ def meta__flash_attention_forward( query: Tensor, key: Tensor, value: Tensor, - cum_seq_q: Tensor | None, - cum_seq_k: Tensor | None, + cum_seq_q: Optional[Tensor], + cum_seq_k: Optional[Tensor], max_q: int, max_k: int, dropout_p: float, is_causal: bool, return_debug_mask: bool, - scale: float | None = None, - window_size_left: int | None = None, - window_size_right: int | None = None, - seqused_k: Tensor | None = None, - alibi_slopes: Tensor | None = None, + scale: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + seqused_k: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, ): # NB: there are two underlying paths: # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) @@ -6211,9 +6211,9 @@ def meta__flash_attention_backward( is_causal: bool, philox_seed: Tensor, philox_offset: Tensor, - scale: float | None = None, - window_size_left: int | None = None, - window_size_right: int | None = None, + scale: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, ): grad_query = torch.empty_like(query) grad_key = torch.empty_like(key) @@ -6231,18 +6231,18 @@ def meta__efficient_attention_forward( query: Tensor, key: Tensor, value: Tensor, - bias: Tensor | None, - cu_seqlens_q: Tensor | None, - cu_seqlens_k: Tensor | None, - max_seqlen_q: int | None, - max_seqlen_k: int | None, + bias: Optional[Tensor], + cu_seqlens_q: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], dropout_p: float, custom_mask_type: int, compute_log_sumexp: bool = False, - scale: float | None = None, - causal_diagonal: Tensor | None = None, - seqlen_k: Tensor | None = None, - window_size: int | None = None, + scale: Optional[float] = None, + causal_diagonal: Optional[Tensor] = None, + seqlen_k: Optional[Tensor] = None, + window_size: Optional[int] = None, ): B = query.size(0) M = query.size(1) @@ -6284,9 +6284,9 @@ def meta__efficient_attention_backward( query: Tensor, key: Tensor, value: Tensor, - bias: Tensor | None, - cu_seqlens_q: Tensor | None, - cu_seqlens_k: Tensor | None, + bias: Optional[Tensor], + cu_seqlens_q: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], max_seqlen_q: torch.SymInt, max_seqlen_k: torch.SymInt, logsumexp: Tensor, @@ -6295,8 +6295,8 @@ def meta__efficient_attention_backward( philox_offset: Tensor, custom_mask_type: int, bias_requires_grad: bool, - scale: float | None = None, - num_splits_key: int | None = None, + scale: Optional[float] = None, + num_splits_key: Optional[int] = None, shared_storage_dqdkdv: bool = False, ): if shared_storage_dqdkdv: @@ -6339,9 +6339,9 @@ def _check_scaled_mm_sizes( mat2: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - bias: torch.Tensor | None = None, - scale_result: torch.Tensor | None = None, - out_dtype: torch.dtype | None = None, + bias: Optional[torch.Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): def is_fp8_or_fp4_type(dtype): @@ -6520,9 +6520,9 @@ def meta_scaled_mm( mat2: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - bias: torch.Tensor | None = None, - scale_result: torch.Tensor | None = None, - out_dtype: torch.dtype | None = None, + bias: Optional[torch.Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): return _check_scaled_mm_sizes( @@ -6537,10 +6537,10 @@ def _check_scaled_mm_sizes_v2( scale_recipe_a: list[ScalingType], scale_b: list[torch.Tensor], scale_recipe_b: list[ScalingType], - bias: torch.Tensor | None = None, - out_dtype: torch.dtype | None = None, - swizzle_a: list[SwizzleType] | None = None, - swizzle_b: list[SwizzleType] | None = None, + bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + swizzle_a: Optional[list[SwizzleType]] = None, + swizzle_b: Optional[list[SwizzleType]] = None, use_fast_accum: bool = False, ): def is_fp8_or_fp4_type(dtype): @@ -6872,9 +6872,9 @@ def meta_scaled_mm_v2( scale_b: list[torch.Tensor], scale_recipe_b: list[ScalingType], swizzle_b: list[SwizzleType], - bias: torch.Tensor | None = None, - output_dtype: torch.dtype | None = None, - contraction_dims: list[int] | None = None, + bias: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, + contraction_dims: Optional[list[int]] = None, use_fast_accum: bool = False, ): return _check_scaled_mm_sizes_v2( @@ -6997,10 +6997,10 @@ def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None): ) def upsample_nearest2d_backward( grad_output: Tensor, - output_size: Sequence[int | torch.SymInt], - input_size: Sequence[int | torch.SymInt], - scales_h: float | None = None, - scales_w: float | None = None, + output_size: Sequence[Union[int, torch.SymInt]], + input_size: Sequence[Union[int, torch.SymInt]], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, ): full_output_size = upsample_common_check( input_size, output_size, num_spatial_dims=2 @@ -7842,12 +7842,12 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): def _meta_grouped_mm_common( mat_a: Tensor, mat_b: Tensor, - scale_a: torch.Tensor | None, - scale_b: torch.Tensor | None, - offs: Tensor | None = None, - bias: Tensor | None = None, - scale_result: torch.Tensor | None = None, - out_dtype: torch.dtype | None = None, + scale_a: Optional[torch.Tensor], + scale_b: Optional[torch.Tensor], + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): torch._check( @@ -8055,9 +8055,9 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): def meta_grouped_mm( mat_a: Tensor, mat_b: Tensor, - offs: Tensor | None = None, - bias: Tensor | None = None, - out_dtype: torch.dtype | None = None, + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> Tensor: return _meta_grouped_mm_common( mat_a, @@ -8077,10 +8077,10 @@ def meta_scaled_grouped_mm( mat_b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - offs: torch.Tensor | None = None, - bias: torch.Tensor | None = None, - scale_result: torch.Tensor | None = None, - out_dtype: torch.dtype | None = None, + offs: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): # matching _scaled_grouped_mm_cuda Blas.cpp implementation diff --git a/torch/_ops.py b/torch/_ops.py index 23108117a9870..8f8a7328429fa 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -8,7 +8,16 @@ import types from collections.abc import Callable, Iterator from functools import cached_property -from typing import Any, ClassVar, Concatenate, final, Generic, TYPE_CHECKING +from typing import ( + Any, + ClassVar, + Concatenate, + final, + Generic, + Optional, + TYPE_CHECKING, + Union, +) from typing_extensions import ParamSpec, TypeVar import torch @@ -70,7 +79,9 @@ def __init__(self): # for use with OpOverload; cache lookup is done entirely from C++ # for speed. # TODO: The cache is NOT currently used by HigherOrderOperator, but it should! - self._dispatch_cache: dict[DispatchKey, DispatchKey | Callable[..., Any]] = {} + self._dispatch_cache: dict[ + DispatchKey, Union[DispatchKey, Callable[..., Any]] + ] = {} # This table allows you to override the behavior of a particular # dispatch key to call a custom Python function, rather than the @@ -88,7 +99,7 @@ def __init__(self): # makes sense that you should be able to register them, the same # way you can register dispatch keys. self.python_key_table: dict[ - type[TorchDispatchMode | torch.Tensor], Callable[..., Any] + type[Union[TorchDispatchMode, torch.Tensor]], Callable[..., Any] ] = {} # This table allows you to override the behavior of functorch @@ -110,7 +121,12 @@ def has_kernel_for_any_dispatch_key(self, ks): def py_impl( self, - k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey, + k: Union[ + type[TorchDispatchMode], + type[torch.Tensor], + TransformType, + DispatchKey, + ], ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]: if inspect.isclass(k) and ( @@ -169,7 +185,7 @@ def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: return fn(CppFunctionalizeAPI(), *args, **kwargs) def functionalize_dispatch_mode_fn( - mode: FunctionalTensorMode | None, *args: _P.args, **kwargs: _P.kwargs + mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs ) -> _T: return fn(PythonFunctionalizeAPI(mode), *args, **kwargs) @@ -291,7 +307,12 @@ def __init__(self, name, *, cacheable=False): def py_impl( self, - k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey, + k: Union[ + type[TorchDispatchMode], + type[torch.Tensor], + TransformType, + DispatchKey, + ], ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k): self.non_fallthrough_keys = self.non_fallthrough_keys.add(k) @@ -873,7 +894,7 @@ def _uncache_dispatch(self, key: DispatchKey) -> None: self._dispatch_cache.pop(key, None) # This implements the pre-computation logic for the Python dispatcher. - def _get_dispatch(self, key: DispatchKey) -> DispatchKey | Callable[_P, _T]: + def _get_dispatch(self, key: DispatchKey) -> Union[DispatchKey, Callable[_P, _T]]: # This is only called upon a cache miss assert key not in self._dispatch_cache, f"{self} {key}" diff --git a/torch/_sources.py b/torch/_sources.py index e0ab883a8b46c..1327729a717b1 100644 --- a/torch/_sources.py +++ b/torch/_sources.py @@ -3,7 +3,7 @@ import functools import inspect from textwrap import dedent -from typing import Any, NamedTuple +from typing import Any, NamedTuple, Optional from torch._C import ErrorReport from torch._C._jit_tree_views import SourceRangeFactory @@ -11,8 +11,8 @@ def get_source_lines_and_file( obj: Any, - error_msg: str | None = None, -) -> tuple[list[str], int, str | None]: + error_msg: Optional[str] = None, +) -> tuple[list[str], int, Optional[str]]: """ Wrapper around inspect.getsourcelines and inspect.getsourcefile. @@ -113,7 +113,7 @@ class ParsedDef(NamedTuple): ast: ast.Module ctx: SourceContext source: str - filename: str | None + filename: Optional[str] file_lineno: int diff --git a/torch/_tensor.py b/torch/_tensor.py index c1093f35aa984..c6351ed75ffcb 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -8,7 +8,7 @@ from collections.abc import Callable from copy import deepcopy from numbers import Number -from typing import Any, cast, Concatenate, TypeVar, Union +from typing import Any, cast, Concatenate, Optional, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -180,10 +180,10 @@ def __deepcopy__(self, memo): new_storage = self._typed_storage()._deepcopy(memo) if self.is_quantized: # quantizer_params can be different type based on torch attribute - quantizer_params: ( - tuple[torch.qscheme, float, int] - | tuple[torch.qscheme, Tensor, Tensor, int] - ) + quantizer_params: Union[ + tuple[torch.qscheme, float, int], + tuple[torch.qscheme, Tensor, Tensor, int], + ] if self.qscheme() == torch.per_tensor_affine: quantizer_params = ( self.qscheme(), @@ -366,9 +366,9 @@ def _reduce_ex_internal(self, proto): "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" ) # quantizer_params can be different type based on torch attribute - quantizer_params: ( - tuple[torch.qscheme, float, int] | tuple[Any, Tensor, Tensor, int] - ) + quantizer_params: Union[ + tuple[torch.qscheme, float, int], tuple[Any, Tensor, Tensor, int] + ] if self.qscheme() == torch.per_tensor_affine: quantizer_params = ( torch.per_tensor_affine, @@ -893,7 +893,7 @@ def __reversed__(self): def norm( self, - p: float | str | None = "fro", + p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, dtype=None, @@ -944,15 +944,15 @@ def lu(self, pivot=True, get_infos=False): def stft( self, n_fft: int, - hop_length: int | None = None, - win_length: int | None = None, - window: "Tensor | None" = None, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - onesided: bool | None = None, - return_complex: bool | None = None, - align_to_window: bool | None = None, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, ): r"""See :func:`torch.stft` @@ -993,13 +993,13 @@ def stft( def istft( self, n_fft: int, - hop_length: int | None = None, - win_length: int | None = None, - window: "Tensor | None" = None, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, center: bool = True, normalized: bool = False, - onesided: bool | None = None, - length: int | None = None, + onesided: Optional[bool] = None, + length: Optional[int] = None, return_complex: bool = False, ): r"""See :func:`torch.istft`""" @@ -1528,7 +1528,9 @@ def to_sparse_coo(self): """ return self.to_sparse() - def dim_order(self, *, ambiguity_check: bool | list[torch.memory_format] = False): + def dim_order( + self, *, ambiguity_check: Union[bool, list[torch.memory_format]] = False + ): """ dim_order(ambiguity_check=False) -> tuple @@ -1710,10 +1712,10 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def __dlpack__( self, *, - stream: Any | None = -1, - max_version: tuple[int, int] | None = None, - dl_device: tuple[enum.IntEnum, int] | None = None, - copy: bool | None = None, + stream: Optional[Any] = -1, + max_version: Optional[tuple[int, int]] = None, + dl_device: Optional[tuple[enum.IntEnum, int]] = None, + copy: Optional[bool] = None, ): """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 46af738829312..613fa9ad6ff95 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -3,7 +3,7 @@ import dataclasses import math import textwrap -from typing import Any +from typing import Any, Optional import torch from torch import inf @@ -15,7 +15,7 @@ class __PrinterOptions: threshold: float = 1000 edgeitems: int = 3 linewidth: int = 80 - sci_mode: bool | None = None + sci_mode: Optional[bool] = None PRINT_OPTS = __PrinterOptions() diff --git a/torch/_utils.py b/torch/_utils.py index 70641a7c534d7..01cf9d393188b 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Callable from types import ModuleType -from typing import Any, Generic, TYPE_CHECKING +from typing import Any, Generic, Optional, TYPE_CHECKING from typing_extensions import deprecated, ParamSpec import torch @@ -856,7 +856,7 @@ def _get_device_index( """ if isinstance(device, str): device = torch.device(device) - device_idx: int | None = None + device_idx: Optional[int] = None if isinstance(device, torch.device): if not allow_cpu and device.type == "cpu": raise ValueError(f"Expected a non cpu device, but got: {device}") @@ -1054,7 +1054,7 @@ def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None: ) -def try_import(module_name: str) -> ModuleType | None: +def try_import(module_name: str) -> Optional[ModuleType]: # Implementation based on # https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported if (module := sys.modules.get(module_name, None)) is not None: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 6f95511b5ce80..3a172a814e2e5 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -6,7 +6,7 @@ import tempfile import typing_extensions from collections.abc import Callable -from typing import Any, TypeVar +from typing import Any, Optional, TypeVar from typing_extensions import ParamSpec import torch @@ -255,7 +255,7 @@ def max_clock_rate(): return 1100 -def get_mast_job_name_version() -> tuple[str, int] | None: +def get_mast_job_name_version() -> Optional[tuple[str, int]]: return None @@ -274,7 +274,7 @@ def get_mast_job_name_version() -> tuple[str, int] | None: REQUIRES_SET_PYTHON_MODULE = False -def maybe_upload_prof_stats_to_manifold(profile_path: str) -> str | None: +def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]: print("Uploading profile stats (fb-only otherwise no-op)") return None @@ -367,11 +367,11 @@ def get_default_numa_options(): return None -def log_triton_builds(fail: str | None): +def log_triton_builds(fail: Optional[str]): pass -def find_compile_subproc_binary() -> str | None: +def find_compile_subproc_binary() -> Optional[str]: """ Allows overriding the binary used for subprocesses """ diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 861d4fd4b4153..3f303f78a4713 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools from collections.abc import Callable -from typing import Any +from typing import Any, Optional, Union from typing_extensions import deprecated import torch @@ -9,13 +9,13 @@ from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten -in_dims_t = int | tuple -out_dims_t = int | tuple[int, ...] +in_dims_t = Union[int, tuple] +out_dims_t = Union[int, tuple[int, ...]] # Checks that all args-to-be-batched have the same batch dim size def _validate_and_get_batch_size( - flat_in_dims: list[int | None], + flat_in_dims: list[Optional[int]], flat_args: list, ) -> int: batch_sizes = [ @@ -31,7 +31,7 @@ def _validate_and_get_batch_size( return batch_sizes[0] -def _num_outputs(batched_outputs: Tensor | tuple[Tensor, ...]) -> int: +def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int: if isinstance(batched_outputs, tuple): return len(batched_outputs) return 1 @@ -115,7 +115,7 @@ def _create_batched_inputs( # Undos the batching (and any batch dimensions) associated with the `vmap_level`. def _unwrap_batched( - batched_outputs: Tensor | tuple[Tensor, ...], + batched_outputs: Union[Tensor, tuple[Tensor, ...]], out_dims: out_dims_t, vmap_level: int, batch_size: int, diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index a4c8aaafa351b..5aaa77b25697a 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -69,7 +69,7 @@ ) from struct import unpack from sys import maxsize -from typing import Any +from typing import Any, Union import torch from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING @@ -84,15 +84,15 @@ "nt", ] -_marked_safe_globals_set: set[Callable | tuple[Callable, str]] = set() +_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set() -def _add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]): +def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]): global _marked_safe_globals_set _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals)) -def _get_safe_globals() -> list[Callable | tuple[Callable, str]]: +def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: global _marked_safe_globals_set return list(_marked_safe_globals_set) @@ -103,14 +103,14 @@ def _clear_safe_globals(): def _remove_safe_globals( - globals_to_remove: list[Callable | tuple[Callable, str]], + globals_to_remove: list[Union[Callable, tuple[Callable, str]]], ): global _marked_safe_globals_set _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove) class _safe_globals: - def __init__(self, safe_globals: list[Callable | tuple[Callable, str]]): + def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]): self.safe_globals = safe_globals def __enter__(self): diff --git a/torch/functional.py b/torch/functional.py index 33b0ada75324c..013832d59cfb3 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -2,7 +2,7 @@ import itertools import operator from collections.abc import Sequence -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, Union import torch import torch.nn.functional as F @@ -120,7 +120,7 @@ def broadcast_shapes(*shapes): def split( tensor: Tensor, - split_size_or_sections: int | list[int], + split_size_or_sections: Union[int, list[int]], dim: int = 0, ) -> tuple[Tensor, ...]: r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. @@ -387,13 +387,13 @@ def parse_subscript(n: int) -> str: if TYPE_CHECKING: # The JIT doesn't understand Union, so only add type annotation for mypy def meshgrid( - *tensors: Tensor | list[Tensor], indexing: str | None = None + *tensors: Union[Tensor, list[Tensor]], indexing: Optional[str] = None ) -> tuple[Tensor, ...]: return _meshgrid(*tensors, indexing=indexing) else: - def meshgrid(*tensors, indexing: str | None = None) -> tuple[Tensor, ...]: + def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]: r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors. This is helpful when you want to visualize data over some @@ -490,7 +490,7 @@ def meshgrid(*tensors, indexing: str | None = None) -> tuple[Tensor, ...]: return _meshgrid(*tensors, indexing=indexing) -def _meshgrid(*tensors, indexing: str | None): +def _meshgrid(*tensors, indexing: Optional[str]): if has_torch_function(tensors): return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing) if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): @@ -508,15 +508,15 @@ def _meshgrid(*tensors, indexing: str | None): def stft( input: Tensor, n_fft: int, - hop_length: int | None = None, - win_length: int | None = None, - window: Tensor | None = None, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - onesided: bool | None = None, - return_complex: bool | None = None, - align_to_window: bool | None = None, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, ) -> Tensor: r"""Short-time Fourier transform (STFT). @@ -788,7 +788,7 @@ def _unique_impl( sorted: bool = True, return_inverse: bool = False, return_counts: bool = False, - dim: int | None = None, + dim: Optional[int] = None, ) -> _unique_impl_out: r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor] @@ -956,7 +956,7 @@ def _unique_consecutive_impl( input: Tensor, return_inverse: bool = False, return_counts: bool = False, - dim: int | None = None, + dim: Optional[int] = None, ) -> _unique_impl_out: r"""Eliminates all but the first element from every consecutive group of equivalent elements. @@ -1201,7 +1201,7 @@ def tensordot( a, b, dims: int = 2, - out: torch.Tensor | None = None, + out: Optional[torch.Tensor] = None, ): pass @@ -1210,7 +1210,7 @@ def tensordot( # noqa: F811 a, b, dims: tuple[list[int], list[int]], - out: torch.Tensor | None = None, + out: Optional[torch.Tensor] = None, ): pass @@ -1219,7 +1219,7 @@ def tensordot( # noqa: F811 a, b, dims: list[list[int]], - out: torch.Tensor | None = None, + out: Optional[torch.Tensor] = None, ): pass @@ -1228,7 +1228,7 @@ def tensordot( # noqa: F811 a, b, dims: torch.Tensor, - out: torch.Tensor | None = None, + out: Optional[torch.Tensor] = None, ): pass @@ -1237,7 +1237,7 @@ def tensordot( # noqa: F811 a, b, dims=2, - out: torch.Tensor | None = None, + out: Optional[torch.Tensor] = None, ): r"""Returns a contraction of a and b over multiple dimensions. @@ -1659,7 +1659,7 @@ def norm( # noqa: F811 def norm( # noqa: F811 input, - p: float | str | None = "fro", + p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, @@ -1882,7 +1882,7 @@ def norm( # noqa: F811 def unravel_index( indices: Tensor, - shape: int | Sequence[int] | torch.Size, + shape: Union[int, Sequence[int], torch.Size], ) -> tuple[Tensor, ...]: r"""Converts a tensor of flat indices into a tuple of coordinate tensors that index into an arbitrary tensor of the specified shape. @@ -1938,7 +1938,7 @@ def unravel_index( return res_tensor.unbind(-1) -def _unravel_index(indices: Tensor, shape: int | Sequence[int]) -> Tensor: +def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: torch._check_type( not indices.is_complex() and not indices.is_floating_point() diff --git a/torch/hub.py b/torch/hub.py index 3ec285fcb3a9e..bf138f7784347 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -12,7 +12,7 @@ import warnings import zipfile from pathlib import Path -from typing import Any +from typing import Any, Optional, Union from typing_extensions import deprecated from urllib.error import HTTPError, URLError from urllib.parse import urlparse # noqa: F401 @@ -91,7 +91,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): VAR_DEPENDENCY = "dependencies" MODULE_HUBCONF = "hubconf.py" READ_DATA_CHUNK = 128 * 1024 -_hub_dir: str | None = None +_hub_dir: Optional[str] = None @contextlib.contextmanager @@ -417,7 +417,7 @@ def get_dir() -> str: return os.path.join(_get_torch_home(), "hub") -def set_dir(d: str | os.PathLike) -> None: +def set_dir(d: Union[str, os.PathLike]) -> None: r""" Optionally set the Torch Hub directory used to save downloaded models & weights. @@ -694,7 +694,7 @@ def _load_local(hubconf_dir, model, *args, **kwargs): def download_url_to_file( url: str, dst: str, - hash_prefix: str | None = None, + hash_prefix: Optional[str] = None, progress: bool = True, ) -> None: r"""Download object at the given URL to a local path. @@ -816,11 +816,11 @@ def _legacy_zip_load( def load_state_dict_from_url( url: str, - model_dir: str | None = None, + model_dir: Optional[str] = None, map_location: MAP_LOCATION = None, progress: bool = True, check_hash: bool = False, - file_name: str | None = None, + file_name: Optional[str] = None, weights_only: bool = False, ) -> dict[str, Any]: r"""Loads the Torch serialized object at the given URL. diff --git a/torch/library.py b/torch/library.py index 5305d647bc613..76e5d27aae434 100644 --- a/torch/library.py +++ b/torch/library.py @@ -7,7 +7,7 @@ import traceback import weakref from collections.abc import Callable, Sequence -from typing import Any, overload, TYPE_CHECKING, TypeVar, Union +from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import deprecated, ParamSpec import torch @@ -98,7 +98,7 @@ def __init__(self, ns, kind, dispatch_key=""): frame = traceback.extract_stack(limit=2)[0] filename, lineno = frame.filename, frame.lineno - self.m: Any | None = torch._C._dispatch_library( + self.m: Optional[Any] = torch._C._dispatch_library( kind, ns, dispatch_key, filename, lineno ) self.ns = ns @@ -399,7 +399,7 @@ def fallback(self, fn, dispatch_key="", *, with_keyset=False): self.m.fallback(dispatch_key, fn, with_keyset) - def _register_effectful_op(self, op_name: str, effect: EffectType | None): + def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]): """ Registers an effect to an operator. This is used to register an op that has side effects that is not capturable by the schema. @@ -570,20 +570,20 @@ def wrap(f): @overload def impl( qualname: str, - types: str | Sequence[str], + types: Union[str, Sequence[str]], func: None = None, *, - lib: Library | None = None, + lib: Optional[Library] = None, ) -> Callable[[Callable[..., object]], None]: ... @overload def impl( qualname: str, - types: str | Sequence[str], + types: Union[str, Sequence[str]], func: Callable[..., object], *, - lib: Library | None = None, + lib: Optional[Library] = None, ) -> None: ... @@ -599,10 +599,10 @@ def impl( @functools.singledispatch def impl( qualname: str, - types: str | Sequence[str], - func: Callable[_P, _T] | None = None, + types: Union[str, Sequence[str]], + func: Optional[Callable[_P, _T]] = None, *, - lib: Library | None = None, + lib: Optional[Library] = None, ) -> object: """Register an implementation for a device type for this operator. @@ -683,10 +683,10 @@ def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]: @overload def _impl( qualname: str, - types: str | Sequence[str], + types: Union[str, Sequence[str]], func: None = None, *, - lib: Library | None = None, + lib: Optional[Library] = None, disable_dynamo: bool = False, ) -> Callable[[Callable[..., object]], None]: ... @@ -694,22 +694,22 @@ def _impl( @overload def _impl( qualname: str, - types: str | Sequence[str], + types: Union[str, Sequence[str]], func: Callable[..., object], *, - lib: Library | None = None, + lib: Optional[Library] = None, disable_dynamo: bool = False, ) -> None: ... def _impl( qualname: str, - types: str | Sequence[str], - func: Callable[..., object] | None = None, + types: Union[str, Sequence[str]], + func: Optional[Callable[..., object]] = None, *, - lib: Library | None = None, + lib: Optional[Library] = None, disable_dynamo: bool = False, -) -> Callable[[Callable[..., object]], None] | None: +) -> Optional[Callable[[Callable[..., object]], None]]: # See impl() if isinstance(types, str): types = (types,) @@ -786,10 +786,10 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): def register_kernel( op: _op_identifier, device_types: device_types_t, - func: Callable | None = None, + func: Optional[Callable] = None, /, *, - lib: Library | None = None, + lib: Optional[Library] = None, ): """Register an implementation for a device type for this operator. @@ -857,7 +857,7 @@ def register_autocast( cast_inputs: _dtype, /, *, - lib: Library | None = None, + lib: Optional[Library] = None, ): r"""Register an autocast dispatch rule for this custom op. @@ -948,10 +948,10 @@ def kernel(_, *args, **kwargs): def register_fake( op: _op_identifier, - func: Callable | None = None, + func: Optional[Callable] = None, /, *, - lib: Library | None = None, + lib: Optional[Library] = None, _stacklevel: int = 1, allow_override: bool = False, ): @@ -1084,9 +1084,9 @@ def register(func): def _register_effectful_op( op: _op_identifier, - effect: EffectType | None, + effect: Optional[EffectType], *, - lib: Library | None = None, + lib: Optional[Library] = None, ) -> None: r""" To specify that an operator has side-effects, we must register an effect @@ -1125,7 +1125,7 @@ def register_autograd( backward: Callable, /, *, - setup_context: Callable | None = None, + setup_context: Optional[Callable] = None, lib=None, ) -> None: r"""Register a backward formula for this custom op. @@ -1253,10 +1253,10 @@ def register_autograd( def register_torch_dispatch( op: _op_identifier, torch_dispatch_class: Any, - func: Callable | None = None, + func: Optional[Callable] = None, /, *, - lib: Library | None = None, + lib: Optional[Library] = None, ): r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. @@ -1333,7 +1333,7 @@ def register(func): def register_vmap( op: _op_identifier, - func: Callable | None = None, + func: Optional[Callable] = None, /, *, lib=None, @@ -1525,7 +1525,7 @@ def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": def get_kernel( - op: _op_identifier, dispatch_key: str | torch.DispatchKey + op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey] ) -> torch._C._SafeKernelFunction: """Returns the computed kernel for a given operator and dispatch key. @@ -1607,11 +1607,11 @@ def get_kernel( def opcheck( - op: torch._ops.OpOverload | torch._ops.OpOverloadPacket | CustomOpDef, + op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], args: tuple[Any, ...], - kwargs: dict[str, Any] | None = None, + kwargs: Optional[dict[str, Any]] = None, *, - test_utils: str | Sequence[str] = _OPCHECK_DEFAULT_UTILS, + test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS, raise_exception: bool = True, atol=None, rtol=None, diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index dd3ff69fd6af8..4bae914f0292b 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import warnings from collections.abc import Callable -from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar +from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -16,7 +16,7 @@ from torch._prims_common import DimsType from torch.types import _dtype as DType - DimOrDims: TypeAlias = DimsType | None + DimOrDims: TypeAlias = Optional[DimsType] else: # The JIT doesn't understand Union, nor torch.dtype here DType = int @@ -624,7 +624,7 @@ def _sparse_coo_scatter_reduction_helper( mask_input: Tensor, dims: tuple[int, ...], keepdim: bool, - dtype: DType | None = None, + dtype: Optional[DType] = None, ) -> Tensor: reduce = op.__name__ valid_reductions = ["sum", "prod", "amax", "amin"] @@ -744,7 +744,7 @@ def _sparse_csr_segment_reduction_helper( mask_input: Tensor, dims: tuple[int, ...], keepdim: bool, - dtype: DType | None = None, + dtype: Optional[DType] = None, ) -> Tensor: # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True # FIXME: when dense dimensions are implemented for CSR tensors @@ -869,7 +869,7 @@ def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: ) -def _input_mask(input: Tensor | MaskedTensor, *args, **kwargs) -> Tensor: +def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor: """Return canonical input mask. A canonical input mask is defined as a boolean mask tensor that @@ -1000,7 +1000,9 @@ def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor: ) -def _combine_input_and_mask(op, input: MaskedTensor | Tensor, mask, *args) -> Tensor: +def _combine_input_and_mask( + op, input: Union[MaskedTensor, Tensor], mask, *args +) -> Tensor: def helper(input, mask): if mask is None: return input @@ -1044,12 +1046,12 @@ def backward(ctx, grad_output): @_apply_docstring_templates def sum( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: @@ -1097,12 +1099,12 @@ def sum( @_apply_docstring_templates def prod( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: @@ -1177,8 +1179,8 @@ def cumsum( input: Tensor, dim: int, *, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1197,8 +1199,8 @@ def cumprod( input: Tensor, dim: int, *, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1214,12 +1216,12 @@ def cumprod( @_apply_docstring_templates def amax( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1264,12 +1266,12 @@ def amax( @_apply_docstring_templates def amin( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1314,12 +1316,12 @@ def amin( @_apply_docstring_templates def argmax( - input: Tensor | MaskedTensor, - dim: int | None = None, + input: Union[Tensor, MaskedTensor], + dim: Optional[int] = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1340,12 +1342,12 @@ def argmax( @_apply_docstring_templates def argmin( - input: Tensor | MaskedTensor, - dim: int | None = None, + input: Union[Tensor, MaskedTensor], + dim: Optional[int] = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1366,12 +1368,12 @@ def argmin( @_apply_docstring_templates def mean( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1433,12 +1435,12 @@ def mean( @_apply_docstring_templates def median( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: int = -1, *, keepdim: bool = False, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1480,8 +1482,8 @@ def logsumexp( dim: DimOrDims = None, *, keepdim: bool = False, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1497,12 +1499,12 @@ def logsumexp( # Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations def logaddexp( - input: Tensor | MaskedTensor, - other: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], + other: Union[Tensor, MaskedTensor], *, - dtype: DType | None = None, - input_mask: Tensor | None = None, - other_mask: Tensor | None = None, + dtype: Optional[DType] = None, + input_mask: Optional[Tensor] = None, + other_mask: Optional[Tensor] = None, ) -> Tensor: """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor @@ -1559,13 +1561,13 @@ def logaddexp( @_apply_docstring_templates def norm( - input: Tensor | MaskedTensor, - ord: float | None = 2.0, + input: Union[Tensor, MaskedTensor], + ord: Optional[float] = 2.0, dim: DimOrDims = None, *, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1594,15 +1596,15 @@ def norm( def _std_var( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims, - unbiased: bool | None, + unbiased: Optional[bool], *, - correction_opt: int | float | None, - keepdim: bool | None, - dtype: DType | None, - mask: Tensor | None, - take_sqrt: bool | None, + correction_opt: Optional[Union[int, float]], + keepdim: Optional[bool], + dtype: Optional[DType], + mask: Optional[Tensor], + take_sqrt: Optional[bool], ) -> Tensor: assert unbiased is None or correction_opt is None, ( "Only one of unbiased and correction may be given" @@ -1675,14 +1677,14 @@ def _std_var( @_apply_docstring_templates def var( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims = None, - unbiased: bool | None = None, + unbiased: Optional[bool] = None, *, - correction: int | float | None = None, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + correction: Optional[Union[int, float]] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1706,14 +1708,14 @@ def var( @_apply_docstring_templates def std( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: DimOrDims = None, - unbiased: bool | None = None, + unbiased: Optional[bool] = None, *, - correction: int | None = None, - keepdim: bool | None = False, - dtype: DType | None = None, - mask: Tensor | None = None, + correction: Optional[int] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: """\ {reduction_signature} @@ -1737,11 +1739,11 @@ def std( @_apply_docstring_templates def softmax( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: int, *, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1757,11 +1759,11 @@ def softmax( @_apply_docstring_templates def log_softmax( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: int, *, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1777,11 +1779,11 @@ def log_softmax( @_apply_docstring_templates def softmin( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], dim: int, *, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1797,13 +1799,13 @@ def softmin( @_apply_docstring_templates def normalize( - input: Tensor | MaskedTensor, + input: Union[Tensor, MaskedTensor], ord: float, dim: int, *, eps: float = 1e-12, - dtype: DType | None = None, - mask: Tensor | None = None, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, ) -> Tensor: if dtype is None: dtype = input.dtype diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index a3ca62929a3b5..9764f935b7c3d 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -1,4 +1,5 @@ import warnings +from typing import Optional # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h @@ -30,8 +31,8 @@ def get_enum(reduction: str) -> int: # We use these functions in torch/legacy as well, in which case we'll silence the warning def legacy_get_string( - size_average: bool | None, - reduce: bool | None, + size_average: Optional[bool], + reduce: Optional[bool], emit_warning: bool = True, ) -> str: warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." @@ -53,8 +54,8 @@ def legacy_get_string( def legacy_get_enum( - size_average: bool | None, - reduce: bool | None, + size_average: Optional[bool], + reduce: Optional[bool], emit_warning: bool = True, ) -> int: return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/common_types.py b/torch/nn/common_types.py index e1928414a396e..9262c45472271 100644 --- a/torch/nn/common_types.py +++ b/torch/nn/common_types.py @@ -1,4 +1,4 @@ -from typing import TypeAlias as _TypeAlias, TypeVar +from typing import Optional, TypeAlias as _TypeAlias, TypeVar from torch import Tensor @@ -29,9 +29,9 @@ _size_6_t: _TypeAlias = _scalar_or_tuple_6_t[int] # For arguments which represent optional size parameters (eg, adaptive pool parameters) -_size_any_opt_t: _TypeAlias = _scalar_or_tuple_any_t[int | None] -_size_2_opt_t: _TypeAlias = _scalar_or_tuple_2_t[int | None] -_size_3_opt_t: _TypeAlias = _scalar_or_tuple_3_t[int | None] +_size_any_opt_t: _TypeAlias = _scalar_or_tuple_any_t[Optional[int]] +_size_2_opt_t: _TypeAlias = _scalar_or_tuple_2_t[Optional[int]] +_size_3_opt_t: _TypeAlias = _scalar_or_tuple_3_t[Optional[int]] # For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) _ratio_2_t: _TypeAlias = _scalar_or_tuple_2_t[float] diff --git a/torch/nn/init.py b/torch/nn/init.py index 900b2d34bc08f..3956d9399876e 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -3,7 +3,7 @@ import math import warnings from collections.abc import Callable -from typing import Literal, TypeVar +from typing import Literal, Optional as _Optional, TypeVar from typing_extensions import ParamSpec import torch @@ -67,7 +67,7 @@ # managers, so these need to be implemented as builtins. Using these wrappers # lets us keep those builtins small and reusable. def _no_grad_uniform_( - tensor: Tensor, a: float, b: float, generator: torch.Generator | None = None + tensor: Tensor, a: float, b: float, generator: _Optional[torch.Generator] = None ) -> Tensor: with torch.no_grad(): return tensor.uniform_(a, b, generator=generator) @@ -77,7 +77,7 @@ def _no_grad_normal_( tensor: Tensor, mean: float, std: float, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: with torch.no_grad(): return tensor.normal_(mean, std, generator=generator) @@ -89,7 +89,7 @@ def _no_grad_trunc_normal_( std: float, a: float, b: float, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x: float) -> float: @@ -138,7 +138,7 @@ def _no_grad_zero_(tensor: Tensor) -> Tensor: def calculate_gain( - nonlinearity: _NonlinearityType, param: int | float | None = None + nonlinearity: _NonlinearityType, param: _Optional[int | float] = None ) -> float: r"""Return the recommended gain value for the given nonlinearity function. @@ -215,7 +215,7 @@ def uniform_( tensor: Tensor, a: float = 0.0, b: float = 1.0, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from the uniform distribution. @@ -242,7 +242,7 @@ def normal_( tensor: Tensor, mean: float = 0.0, std: float = 1.0, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from the normal distribution. @@ -271,7 +271,7 @@ def trunc_normal_( std: float = 1.0, a: float = -2.0, b: float = 2.0, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from a truncated normal distribution. @@ -438,7 +438,7 @@ def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]: def xavier_uniform_( tensor: Tensor, gain: float = 1.0, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Xavier uniform distribution. @@ -471,7 +471,7 @@ def xavier_uniform_( def xavier_normal_( tensor: Tensor, gain: float = 1.0, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Xavier normal distribution. @@ -515,7 +515,7 @@ def kaiming_uniform_( a: float = 0, mode: _FanMode = "fan_in", nonlinearity: _NonlinearityType = "leaky_relu", - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. @@ -580,7 +580,7 @@ def kaiming_normal_( a: float = 0, mode: _FanMode = "fan_in", nonlinearity: _NonlinearityType = "leaky_relu", - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming normal distribution. @@ -631,7 +631,7 @@ def kaiming_normal_( def orthogonal_( tensor: Tensor, gain: float = 1, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input `Tensor` with a (semi) orthogonal matrix. @@ -683,7 +683,7 @@ def sparse_( tensor: Tensor, sparsity: float, std: float = 0.01, - generator: torch.Generator | None = None, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the 2D input `Tensor` as a sparse matrix. diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index dac27cdb0d246..edd65601db985 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import warnings +from typing import Optional import torch import torch.nn.functional as F @@ -260,8 +261,8 @@ def __init__( min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False, - min_value: float | None = None, - max_value: float | None = None, + min_value: Optional[float] = None, + max_value: Optional[float] = None, ) -> None: super().__init__() if min_value is not None: @@ -1052,7 +1053,7 @@ def extra_repr(self) -> str: return str(self.lambd) -def _check_arg_device(x: torch.Tensor | None) -> bool: +def _check_arg_device(x: Optional[torch.Tensor]) -> bool: if x is not None: return x.device.type in [ "cpu", @@ -1062,7 +1063,7 @@ def _check_arg_device(x: torch.Tensor | None) -> bool: return True -def _arg_requires_grad(x: torch.Tensor | None) -> bool: +def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool: if x is not None: return x.requires_grad return False @@ -1155,8 +1156,8 @@ class MultiheadAttention(Module): """ __constants__ = ["batch_first"] - bias_k: torch.Tensor | None - bias_v: torch.Tensor | None + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] def __init__( self, @@ -1257,12 +1258,12 @@ def forward( query: Tensor, key: Tensor, value: Tensor, - key_padding_mask: Tensor | None = None, + key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, - attn_mask: Tensor | None = None, + attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True, is_causal: bool = False, - ) -> tuple[Tensor, Tensor | None]: + ) -> tuple[Tensor, Optional[Tensor]]: r"""Compute attention outputs using query, key, and value embeddings. Supports optional parameters for padding, masks and attention weights. @@ -1516,10 +1517,10 @@ def forward( def merge_masks( self, - attn_mask: Tensor | None, - key_padding_mask: Tensor | None, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], query: Tensor, - ) -> tuple[Tensor | None, int | None]: + ) -> tuple[Optional[Tensor], Optional[int]]: r"""Determine mask type and combine masks if necessary. If only one mask is provided, that mask @@ -1534,8 +1535,8 @@ def merge_masks( merged_mask: merged mask mask_type: merged mask type (0, 1, or 2) """ - mask_type: int | None = None - merged_mask: Tensor | None = None + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None if key_padding_mask is not None: mask_type = 1 @@ -1731,9 +1732,9 @@ class Softmin(Module): """ __constants__ = ["dim"] - dim: int | None + dim: Optional[int] - def __init__(self, dim: int | None = None) -> None: + def __init__(self, dim: Optional[int] = None) -> None: super().__init__() self.dim = dim @@ -1796,9 +1797,9 @@ class Softmax(Module): """ __constants__ = ["dim"] - dim: int | None + dim: Optional[int] - def __init__(self, dim: int | None = None) -> None: + def __init__(self, dim: Optional[int] = None) -> None: super().__init__() self.dim = dim @@ -1881,9 +1882,9 @@ class LogSoftmax(Module): """ __constants__ = ["dim"] - dim: int | None + dim: Optional[int] - def __init__(self, dim: int | None = None) -> None: + def __init__(self, dim: Optional[int] = None) -> None: super().__init__() self.dim = dim diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 40a912b4f0568..2ac05f2e8f933 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any +from typing import Any, Optional import torch from torch import Tensor @@ -29,7 +29,7 @@ class _NormBase(Module): __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] num_features: int eps: float - momentum: float | None + momentum: Optional[float] affine: bool track_running_stats: bool # WARNING: weight and bias purposely not defined here. @@ -39,7 +39,7 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: float | None = 0.1, + momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, @@ -65,8 +65,8 @@ def __init__( self.register_buffer( "running_var", torch.ones(num_features, **factory_kwargs) ) - self.running_mean: Tensor | None - self.running_var: Tensor | None + self.running_mean: Optional[Tensor] + self.running_var: Optional[Tensor] self.register_buffer( "num_batches_tracked", torch.tensor( @@ -76,7 +76,7 @@ def __init__( **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, ), ) - self.num_batches_tracked: Tensor | None + self.num_batches_tracked: Optional[Tensor] else: self.register_buffer("running_mean", None) self.register_buffer("running_var", None) @@ -146,7 +146,7 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: float | None = 0.1, + momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, @@ -718,10 +718,10 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: float | None = 0.1, + momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True, - process_group: Any | None = None, + process_group: Optional[Any] = None, device=None, dtype=None, ) -> None: diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index d99151369e18e..f062c4bcbd12b 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -4,7 +4,7 @@ import operator from collections import abc as container_abcs, OrderedDict from itertools import chain, islice -from typing import Any, overload, TYPE_CHECKING, TypeVar +from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar from typing_extensions import deprecated, Self import torch @@ -358,7 +358,7 @@ def forward(self, x): _modules: dict[str, Module] # type: ignore[assignment] - def __init__(self, modules: Iterable[Module] | None = None) -> None: + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: super().__init__() if modules is not None: self += modules @@ -545,7 +545,7 @@ def forward(self, x, choice, act): _modules: dict[str, Module] # type: ignore[assignment] - def __init__(self, modules: Mapping[str, Module] | None = None) -> None: + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: super().__init__() if modules is not None: self.update(modules) @@ -673,7 +673,7 @@ def forward(self, x): return x """ - def __init__(self, values: Iterable[Any] | None = None) -> None: + def __init__(self, values: Optional[Iterable[Any]] = None) -> None: super().__init__() self._size = 0 if values is not None: @@ -888,7 +888,7 @@ def copy(self) -> ParameterDict: def __contains__(self, key: str) -> bool: return key in self._keys - def setdefault(self, key: str, default: Any | None = None) -> Any: + def setdefault(self, key: str, default: Optional[Any] = None) -> Any: """Set the default for a key in the Parameterdict. If key is in the ParameterDict, return its value. @@ -927,7 +927,7 @@ def popitem(self) -> tuple[str, Any]: del self[k] return k, val - def get(self, key: str, default: Any | None = None) -> Any: + def get(self, key: str, default: Optional[Any] = None) -> Any: r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not. Args: @@ -937,7 +937,7 @@ def get(self, key: str, default: Any | None = None) -> Any: return self[key] if key in self else default # noqa: SIM401 def fromkeys( - self, keys: Iterable[str], default: Any | None = None + self, keys: Iterable[str], default: Optional[Any] = None ) -> ParameterDict: r"""Return a new ParameterDict with the keys provided. diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 8b74b6a5a39e8..b539203f6fedd 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -67,7 +67,7 @@ class _ConvNd(Module): __annotations__ = {"bias": Optional[torch.Tensor]} def _conv_forward( # type: ignore[empty-body] - self, input: Tensor, weight: Tensor, bias: Tensor | None + self, input: Tensor, weight: Tensor, bias: Optional[Tensor] ) -> Tensor: ... in_channels: int @@ -82,7 +82,7 @@ def _conv_forward( # type: ignore[empty-body] groups: int padding_mode: Literal["zeros", "reflect", "replicate", "circular"] weight: Tensor - bias: Tensor | None + bias: Optional[Tensor] def __init__( self, @@ -353,7 +353,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != "zeros": return F.conv1d( F.pad( @@ -531,7 +531,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != "zeros": return F.conv2d( F.pad( @@ -701,7 +701,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != "zeros": return F.conv3d( F.pad( @@ -766,12 +766,12 @@ def __init__( def _output_padding( self, input: Tensor, - output_size: list[int] | None, + output_size: Optional[list[int]], stride: list[int], padding: list[int], kernel_size: list[int], num_spatial_dims: int, - dilation: list[int] | None = None, + dilation: Optional[list[int]] = None, ) -> list[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already @@ -965,7 +965,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: + def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: if self.padding_mode != "zeros": raise ValueError( "Only `zeros` padding mode is supported for ConvTranspose1d" @@ -1153,7 +1153,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: + def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: """ Performs the forward pass. @@ -1344,7 +1344,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: + def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: if self.padding_mode != "zeros": raise ValueError( "Only `zeros` padding mode is supported for ConvTranspose3d" diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index 72d90d1c10364..d4c192ee8ce4a 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import itertools -from typing import Any, Protocol +from typing import Any, Optional, Protocol import torch from torch.nn.parameter import is_lazy @@ -167,7 +167,7 @@ class LazyModuleMixin: # modules inheriting from this will change their __class__ to the specified # one after they are fully initialized - cls_to_become: type[Any] | None = None + cls_to_become: Optional[type[Any]] = None def __init__(self: _LazyProtocol, *args, **kwargs): # Mypy doesn't like this super call in a mixin diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 00ada62febded..05b39ba762f47 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs from collections.abc import Callable +from typing import Optional from typing_extensions import deprecated from torch import Tensor @@ -49,14 +50,14 @@ def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> N class _WeightedLoss(_Loss): def __init__( self, - weight: Tensor | None = None, + weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = "mean", ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) - self.weight: Tensor | None + self.weight: Optional[Tensor] class L1Loss(_Loss): @@ -240,7 +241,7 @@ class NLLLoss(_WeightedLoss): def __init__( self, - weight: Tensor | None = None, + weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -271,7 +272,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: class NLLLoss2d(NLLLoss): def __init__( self, - weight: Tensor | None = None, + weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -816,17 +817,17 @@ class BCEWithLogitsLoss(_Loss): def __init__( self, - weight: Tensor | None = None, + weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = "mean", - pos_weight: Tensor | None = None, + pos_weight: Optional[Tensor] = None, ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) self.register_buffer("pos_weight", pos_weight) - self.weight: Tensor | None - self.pos_weight: Tensor | None + self.weight: Optional[Tensor] + self.pos_weight: Optional[Tensor] def forward(self, input: Tensor, target: Tensor) -> Tensor: """Runs the forward pass.""" @@ -1346,7 +1347,7 @@ class probabilities only when a single class label per minibatch item is too res def __init__( self, - weight: Tensor | None = None, + weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -1625,7 +1626,7 @@ def __init__( self, p: int = 1, margin: float = 1.0, - weight: Tensor | None = None, + weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = "mean", @@ -1868,7 +1869,7 @@ class TripletMarginWithDistanceLoss(_Loss): def __init__( self, *, - distance_function: Callable[[Tensor, Tensor], Tensor] | None = None, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, margin: float = 1.0, swap: bool = False, reduction: str = "mean", @@ -1878,7 +1879,7 @@ def __init__( raise ValueError( f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead" ) - self.distance_function: Callable[[Tensor, Tensor], Tensor] | None = ( + self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ( distance_function if distance_function is not None else PairwiseDistance() ) self.margin = margin diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f9795cc1c74aa..6557f60389964 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -115,7 +115,7 @@ def __setstate__(self, state: dict): purposes""" _global_backward_pre_hooks: dict[int, Callable] = OrderedDict() _global_backward_hooks: dict[int, Callable] = OrderedDict() -_global_is_full_backward_hook: bool | None = None +_global_is_full_backward_hook: Optional[bool] = None _global_forward_pre_hooks: dict[int, Callable] = OrderedDict() _global_forward_hooks: dict[int, Callable] = OrderedDict() _global_forward_hooks_always_called: dict[int, bool] = OrderedDict() @@ -453,12 +453,12 @@ def forward(self, x): the change.""" training: bool - _parameters: dict[str, Parameter | None] - _buffers: dict[str, Tensor | None] + _parameters: dict[str, Optional[Parameter]] + _buffers: dict[str, Optional[Tensor]] _non_persistent_buffers_set: set[str] _backward_pre_hooks: dict[int, Callable] _backward_hooks: dict[int, Callable] - _is_full_backward_hook: bool | None + _is_full_backward_hook: Optional[bool] _forward_hooks: dict[int, Callable] # Marks whether the corresponding _forward_hooks accept kwargs or not. # As JIT does not support set[int], this dict is used as a set, where all @@ -477,7 +477,7 @@ def forward(self, x): _load_state_dict_post_hooks: dict[int, Callable] _modules: dict[str, Optional["Module"]] call_super_init: bool = False - _compiled_call_impl: Callable | None = None + _compiled_call_impl: Optional[Callable] = None def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" @@ -526,7 +526,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: forward: Callable[..., Any] = _forward_unimplemented def register_buffer( - self, name: str, tensor: Tensor | None, persistent: bool = True + self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: r"""Add a buffer to the module. @@ -589,7 +589,7 @@ def register_buffer( else: self._non_persistent_buffers_set.add(name) - def register_parameter(self, name: str, param: Parameter | None) -> None: + def register_parameter(self, name: str, param: Optional[Parameter]) -> None: r"""Add a parameter to the module. The parameter can be accessed as an attribute using given name. @@ -1073,7 +1073,7 @@ def apply(self, fn: Callable[["Module"], None]) -> Self: fn(self) return self - def cuda(self, device: int | device | None = None) -> Self: + def cuda(self, device: Optional[int | device] = None) -> Self: r"""Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So @@ -1092,7 +1092,7 @@ def cuda(self, device: int | device | None = None) -> Self: """ return self._apply(lambda t: t.cuda(device)) - def ipu(self, device: int | device | None = None) -> Self: + def ipu(self, device: Optional[int | device] = None) -> Self: r"""Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So @@ -1111,7 +1111,7 @@ def ipu(self, device: int | device | None = None) -> Self: """ return self._apply(lambda t: t.ipu(device)) - def xpu(self, device: int | device | None = None) -> Self: + def xpu(self, device: Optional[int | device] = None) -> Self: r"""Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So @@ -1130,7 +1130,7 @@ def xpu(self, device: int | device | None = None) -> Self: """ return self._apply(lambda t: t.xpu(device)) - def mtia(self, device: int | device | None = None) -> Self: + def mtia(self, device: Optional[int | device] = None) -> Self: r"""Move all model parameters and buffers to the MTIA. This also makes associated parameters and buffers different objects. So @@ -1218,7 +1218,9 @@ def bfloat16(self) -> Self: """ return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) - def to_empty(self, *, device: DeviceLikeType | None, recurse: bool = True) -> Self: + def to_empty( + self, *, device: Optional[DeviceLikeType], recurse: bool = True + ) -> Self: r"""Move the parameters and buffers to the specified device without copying storage. Args: @@ -1237,8 +1239,8 @@ def to_empty(self, *, device: DeviceLikeType | None, recurse: bool = True) -> Se @overload def to( self, - device: DeviceLikeType | None = ..., - dtype: dtype | None = ..., + device: Optional[DeviceLikeType] = ..., + dtype: Optional[dtype] = ..., non_blocking: bool = ..., ) -> Self: ... @@ -1621,9 +1623,9 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn) -> None: def register_forward_pre_hook( self, - hook: Callable[[T, tuple[Any, ...]], Any | None] + hook: Callable[[T, tuple[Any, ...]], Optional[Any]] | Callable[ - [T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None + [T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]] ], *, prepend: bool = False, @@ -1684,8 +1686,8 @@ def register_forward_pre_hook( def register_forward_hook( self, - hook: Callable[[T, tuple[Any, ...], Any], Any | None] - | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None], + hook: Callable[[T, tuple[Any, ...], Any], Optional[Any]] + | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]], *, prepend: bool = False, with_kwargs: bool = False, @@ -2828,7 +2830,7 @@ def modules(self) -> Iterator["Module"]: def named_modules( self, - memo: set["Module"] | None = None, + memo: Optional[set["Module"]] = None, prefix: str = "", remove_duplicate: bool = True, ): diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index d492cdb3cf5a0..4a7302d5cae33 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import numbers -from typing import Union +from typing import Optional, Union import torch from torch import Size, Tensor @@ -375,13 +375,13 @@ class RMSNorm(Module): __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: tuple[int, ...] - eps: float | None + eps: Optional[float] elementwise_affine: bool def __init__( self, normalized_shape: _shape_t, - eps: float | None = None, + eps: Optional[float] = None, elementwise_affine: bool = True, device=None, dtype=None, diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 1dc57c25b1683..777e6b0abd8c4 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch.nn.functional as F from torch import Tensor from torch.nn.common_types import ( @@ -55,7 +57,7 @@ class _MaxPoolNd(Module): def __init__( self, kernel_size: _size_any_t, - stride: _size_any_t | None = None, + stride: Optional[_size_any_t] = None, padding: _size_any_t = 0, dilation: _size_any_t = 1, return_indices: bool = False, @@ -387,7 +389,7 @@ class MaxUnpool1d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_1_t, - stride: _size_1_t | None = None, + stride: Optional[_size_1_t] = None, padding: _size_1_t = 0, ) -> None: super().__init__() @@ -396,7 +398,7 @@ def __init__( self.padding = _single(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: list[int] | None = None + self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool1d( @@ -483,7 +485,7 @@ class MaxUnpool2d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_2_t, - stride: _size_2_t | None = None, + stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, ) -> None: super().__init__() @@ -492,7 +494,7 @@ def __init__( self.padding = _pair(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: list[int] | None = None + self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool2d( @@ -562,7 +564,7 @@ class MaxUnpool3d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_3_t, - stride: _size_3_t | None = None, + stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, ) -> None: super().__init__() @@ -571,7 +573,7 @@ def __init__( self.padding = _triple(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: list[int] | None = None + self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool3d( @@ -760,11 +762,11 @@ class AvgPool2d(_AvgPoolNd): def __init__( self, kernel_size: _size_2_t, - stride: _size_2_t | None = None, + stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: int | None = None, + divisor_override: Optional[int] = None, ) -> None: super().__init__() self.kernel_size = kernel_size @@ -877,11 +879,11 @@ class AvgPool3d(_AvgPoolNd): def __init__( self, kernel_size: _size_3_t, - stride: _size_3_t | None = None, + stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: int | None = None, + divisor_override: Optional[int] = None, ) -> None: super().__init__() self.kernel_size = kernel_size @@ -962,8 +964,8 @@ class FractionalMaxPool2d(Module): def __init__( self, kernel_size: _size_2_t, - output_size: _size_2_t | None = None, - output_ratio: _ratio_2_t | None = None, + output_size: Optional[_size_2_t] = None, + output_ratio: Optional[_ratio_2_t] = None, return_indices: bool = False, _random_samples=None, ) -> None: @@ -1048,8 +1050,8 @@ class FractionalMaxPool3d(Module): def __init__( self, kernel_size: _size_3_t, - output_size: _size_3_t | None = None, - output_ratio: _ratio_3_t | None = None, + output_size: Optional[_size_3_t] = None, + output_ratio: Optional[_ratio_3_t] = None, return_indices: bool = False, _random_samples=None, ) -> None: @@ -1104,7 +1106,7 @@ def __init__( self, norm_type: float, kernel_size: _size_any_t, - stride: _size_any_t | None = None, + stride: Optional[_size_any_t] = None, ceil_mode: bool = False, ) -> None: super().__init__() diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 68e8292870fc8..13cd9ec08cb55 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -4,7 +4,7 @@ import numbers import warnings import weakref -from typing import overload +from typing import Optional, overload from typing_extensions import deprecated import torch @@ -106,7 +106,7 @@ def __init__( self.dropout = float(dropout) self.bidirectional = bidirectional self.proj_size = proj_size - self._flat_weight_refs: list[weakref.ReferenceType[Parameter] | None] = [] + self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = [] num_directions = 2 if bidirectional else 1 if ( @@ -298,7 +298,7 @@ def reset_parameters(self) -> None: for weight in self.parameters(): init.uniform_(weight, -stdv, stdv) - def check_input(self, input: Tensor, batch_sizes: Tensor | None) -> None: + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: if not torch.jit.is_scripting(): if ( input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr] @@ -318,7 +318,7 @@ def check_input(self, input: Tensor, batch_sizes: Tensor | None) -> None: ) def get_expected_hidden_size( - self, input: Tensor, batch_sizes: Tensor | None + self, input: Tensor, batch_sizes: Optional[Tensor] ) -> tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) @@ -362,14 +362,14 @@ def _weights_have_changed(self): return weights_changed def check_forward_args( - self, input: Tensor, hidden: Tensor, batch_sizes: Tensor | None + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] ) -> None: self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) self.check_hidden_size(hidden, expected_hidden_size) - def permute_hidden(self, hx: Tensor, permutation: Tensor | None): + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): if permutation is None: return hx return _apply_permutation(hx, permutation) @@ -645,7 +645,7 @@ def __init__(self, *args, **kwargs): def forward( self, input: Tensor, - hx: Tensor | None = None, + hx: Optional[Tensor] = None, ) -> tuple[Tensor, Tensor]: pass @@ -654,7 +654,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Tensor | None = None, + hx: Optional[Tensor] = None, ) -> tuple[PackedSequence, Tensor]: pass @@ -990,7 +990,7 @@ def __init__(self, *args, **kwargs): super().__init__("LSTM", *args, **kwargs) def get_expected_cell_size( - self, input: Tensor, batch_sizes: Tensor | None + self, input: Tensor, batch_sizes: Optional[Tensor] ) -> tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) @@ -1010,7 +1010,7 @@ def check_forward_args( self, input: Tensor, hidden: tuple[Tensor, Tensor], # type: ignore[override] - batch_sizes: Tensor | None, + batch_sizes: Optional[Tensor], ) -> None: self.check_input(input, batch_sizes) self.check_hidden_size( @@ -1028,7 +1028,7 @@ def check_forward_args( def permute_hidden( # type: ignore[override] self, hx: tuple[Tensor, Tensor], - permutation: Tensor | None, + permutation: Optional[Tensor], ) -> tuple[Tensor, Tensor]: if permutation is None: return hx @@ -1042,7 +1042,7 @@ def permute_hidden( # type: ignore[override] def forward( self, input: Tensor, - hx: tuple[Tensor, Tensor] | None = None, + hx: Optional[tuple[Tensor, Tensor]] = None, ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1052,7 +1052,7 @@ def forward( def forward( self, input: PackedSequence, - hx: tuple[Tensor, Tensor] | None = None, + hx: Optional[tuple[Tensor, Tensor]] = None, ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1338,7 +1338,7 @@ def __init__(self, *args, **kwargs): def forward( self, input: Tensor, - hx: Tensor | None = None, + hx: Optional[Tensor] = None, ) -> tuple[Tensor, Tensor]: # noqa: F811 pass @@ -1347,7 +1347,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Tensor | None = None, + hx: Optional[Tensor] = None, ) -> tuple[PackedSequence, Tensor]: # noqa: F811 pass @@ -1584,7 +1584,7 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) self.nonlinearity = nonlinearity - def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" @@ -1704,7 +1704,7 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) def forward( - self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None ) -> tuple[Tensor, Tensor]: if input.dim() not in (1, 2): raise ValueError( @@ -1815,7 +1815,7 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) - def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index f5775f63ff4ad..abcd7240a742c 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -2,7 +2,7 @@ import copy import warnings from collections.abc import Callable -from typing import Any +from typing import Any, Optional import torch import torch.nn.functional as F @@ -28,8 +28,8 @@ def _generate_square_subsequent_mask( sz: int, - device: torch.device | None = None, - dtype: torch.dtype | None = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. @@ -41,7 +41,7 @@ def _generate_square_subsequent_mask( ) -def _get_seq_len(src: Tensor, batch_first: bool) -> int | None: +def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: if src.is_nested: return None else: @@ -106,8 +106,8 @@ def __init__( dim_feedforward: int = 2048, dropout: float = 0.1, activation: str | Callable[[Tensor], Tensor] = F.relu, - custom_encoder: Any | None = None, - custom_decoder: Any | None = None, + custom_encoder: Optional[Any] = None, + custom_decoder: Optional[Any] = None, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, @@ -182,14 +182,14 @@ def forward( self, src: Tensor, tgt: Tensor, - src_mask: Tensor | None = None, - tgt_mask: Tensor | None = None, - memory_mask: Tensor | None = None, - src_key_padding_mask: Tensor | None = None, - tgt_key_padding_mask: Tensor | None = None, - memory_key_padding_mask: Tensor | None = None, - src_is_causal: bool | None = None, - tgt_is_causal: bool | None = None, + src_mask: Optional[Tensor] = None, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + src_is_causal: Optional[bool] = None, + tgt_is_causal: Optional[bool] = None, memory_is_causal: bool = False, ) -> Tensor: r"""Take in and process masked source/target sequences. @@ -301,8 +301,8 @@ def forward( @staticmethod def generate_square_subsequent_mask( sz: int, - device: torch.device | None = None, - dtype: torch.dtype | None = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. @@ -354,7 +354,7 @@ def __init__( self, encoder_layer: "TransformerEncoderLayer", num_layers: int, - norm: Module | None = None, + norm: Optional[Module] = None, enable_nested_tensor: bool = True, mask_check: bool = True, ) -> None: @@ -407,9 +407,9 @@ def __init__( def forward( self, src: Tensor, - mask: Tensor | None = None, - src_key_padding_mask: Tensor | None = None, - is_causal: bool | None = None, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -587,7 +587,7 @@ def __init__( self, decoder_layer: "TransformerDecoderLayer", num_layers: int, - norm: Module | None = None, + norm: Optional[Module] = None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") @@ -599,11 +599,11 @@ def forward( self, tgt: Tensor, memory: Tensor, - tgt_mask: Tensor | None = None, - memory_mask: Tensor | None = None, - tgt_key_padding_mask: Tensor | None = None, - memory_key_padding_mask: Tensor | None = None, - tgt_is_causal: bool | None = None, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: Optional[bool] = None, memory_is_causal: bool = False, ) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer in turn. @@ -798,8 +798,8 @@ def __setstate__(self, state): def forward( self, src: Tensor, - src_mask: Tensor | None = None, - src_key_padding_mask: Tensor | None = None, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, is_causal: bool = False, ) -> Tensor: r"""Pass the input through the encoder layer. @@ -959,8 +959,8 @@ def forward( def _sa_block( self, x: Tensor, - attn_mask: Tensor | None, - key_padding_mask: Tensor | None, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], is_causal: bool = False, ) -> Tensor: x = self.self_attn( @@ -1088,10 +1088,10 @@ def forward( self, tgt: Tensor, memory: Tensor, - tgt_mask: Tensor | None = None, - memory_mask: Tensor | None = None, - tgt_key_padding_mask: Tensor | None = None, - memory_key_padding_mask: Tensor | None = None, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: bool = False, memory_is_causal: bool = False, ) -> Tensor: @@ -1156,8 +1156,8 @@ def forward( def _sa_block( self, x: Tensor, - attn_mask: Tensor | None, - key_padding_mask: Tensor | None, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], is_causal: bool = False, ) -> Tensor: x = self.self_attn( @@ -1176,8 +1176,8 @@ def _mha_block( self, x: Tensor, mem: Tensor, - attn_mask: Tensor | None, - key_padding_mask: Tensor | None, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], is_causal: bool = False, ) -> Tensor: x = self.multihead_attn( @@ -1212,9 +1212,9 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: def _detect_is_causal_mask( - mask: Tensor | None, - is_causal: bool | None = None, - size: int | None = None, + mask: Optional[Tensor], + is_causal: Optional[bool] = None, + size: Optional[int] = None, ) -> bool: """Return whether the given attention mask is causal. diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index 29e58bc6a9f37..7fd102a768225 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from typing import Optional import torch.nn.functional as F from torch import Tensor @@ -142,19 +143,19 @@ class Upsample(Module): "recompute_scale_factor", ] name: str - size: _size_any_t | None - scale_factor: _ratio_any_t | None + size: Optional[_size_any_t] + scale_factor: Optional[_ratio_any_t] mode: str - align_corners: bool | None - recompute_scale_factor: bool | None + align_corners: Optional[bool] + recompute_scale_factor: Optional[bool] def __init__( self, - size: _size_any_t | None = None, - scale_factor: _ratio_any_t | None = None, + size: Optional[_size_any_t] = None, + scale_factor: Optional[_ratio_any_t] = None, mode: str = "nearest", - align_corners: bool | None = None, - recompute_scale_factor: bool | None = None, + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, ) -> None: super().__init__() self.name = type(self).__name__ @@ -241,8 +242,8 @@ class UpsamplingNearest2d(Upsample): def __init__( self, - size: _size_2_t | None = None, - scale_factor: _ratio_2_t | None = None, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, ) -> None: super().__init__(size, scale_factor, mode="nearest") @@ -292,7 +293,7 @@ class UpsamplingBilinear2d(Upsample): def __init__( self, - size: _size_2_t | None = None, - scale_factor: _ratio_2_t | None = None, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, ) -> None: super().__init__(size, scale_factor, mode="bilinear", align_corners=True) diff --git a/torch/overrides.py b/torch/overrides.py index b1193bab3d6dc..e0597eafd8107 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -30,7 +30,7 @@ import warnings from collections.abc import Callable, Iterable from functools import wraps -from typing import Any, TypeVar +from typing import Any, Optional, TypeVar from typing_extensions import ParamSpec import torch @@ -1609,7 +1609,7 @@ def wrapped(*args, **kwargs): def _get_overloaded_args( relevant_args: Iterable[Any], - get_type_fn: Callable[[Any], type] | None = None, + get_type_fn: Optional[Callable[[Any], type]] = None, ) -> list[Any]: """Returns a list of arguments on which to call __torch_function__. diff --git a/torch/quasirandom.py b/torch/quasirandom.py index f9e6619cab180..b5d4540e592f1 100644 --- a/torch/quasirandom.py +++ b/torch/quasirandom.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from typing import Optional import torch @@ -77,8 +78,8 @@ def __init__(self, dimension, scramble=False, seed=None): def draw( self, n: int = 1, - out: torch.Tensor | None = None, - dtype: torch.dtype | None = None, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: r""" Function to draw a sequence of :attr:`n` points from a Sobol sequence. @@ -130,8 +131,8 @@ def draw( def draw_base2( self, m: int, - out: torch.Tensor | None = None, - dtype: torch.dtype | None = None, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: r""" Function to draw a sequence of :attr:`2**m` points from a Sobol sequence. @@ -186,7 +187,7 @@ def fast_forward(self, n): return self def _scramble(self): - g: torch.Generator | None = None + g: Optional[torch.Generator] = None if self.seed is not None: g = torch.Generator() g.manual_seed(self.seed) diff --git a/torch/serialization.py b/torch/serialization.py index 1a6acc8010634..398d011f324b5 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -16,7 +16,7 @@ from collections.abc import Callable from contextlib import closing, contextmanager from enum import Enum -from typing import Any, cast, Generic, IO, TypeAlias, TypeVar +from typing import Any, cast, Generic, IO, Optional, TypeAlias, TypeVar, Union from typing_extensions import TypeIs import torch @@ -66,10 +66,10 @@ PROTOCOL_VERSION = 1001 STORAGE_KEY_SEPARATOR = "," -MAP_LOCATION: TypeAlias = ( - Callable[[Storage, str], Storage] | torch.device | str | dict[str, str] | None -) -STORAGE: TypeAlias = Storage | torch.storage.TypedStorage | torch.UntypedStorage +MAP_LOCATION: TypeAlias = Optional[ + Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]] +] +STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] IS_WINDOWS = sys.platform == "win32" @@ -99,7 +99,7 @@ def _default_to_weights_only(pickle_module): class _SerializationLocal(threading.local): def __init__(self): super().__init__() - self.map_location: MAP_LOCATION | None = None + self.map_location: Optional[MAP_LOCATION] = None self.skip_data: bool = False self.materialize_fake_tensors: bool = False @@ -123,8 +123,8 @@ def mkdtemp(): _package_registry: list[ tuple[ int, - Callable[[STORAGE], str | None], - Callable[[STORAGE, str], STORAGE | None], + Callable[[STORAGE], Optional[str]], + Callable[[STORAGE, str], Optional[STORAGE]], ] ] = [] @@ -135,7 +135,7 @@ class LoadEndianness(Enum): BIG = 3 -def get_default_load_endianness() -> LoadEndianness | None: +def get_default_load_endianness() -> Optional[LoadEndianness]: """ Get fallback byte order for loading files @@ -197,7 +197,7 @@ def set_crc32_options(compute_crc32: bool): config.save.compute_crc32 = compute_crc32 -def get_default_mmap_options() -> int | None: +def get_default_mmap_options() -> Optional[int]: """ Get default mmap options for :func:`torch.load` with ``mmap=True``. @@ -272,14 +272,14 @@ def clear_safe_globals() -> None: _weights_only_unpickler._clear_safe_globals() -def get_safe_globals() -> list[Callable | tuple[Callable, str]]: +def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: """ Returns the list of user-added globals that are safe for ``weights_only`` load. """ return _weights_only_unpickler._get_safe_globals() -def add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]) -> None: +def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None: """ Marks the given globals as safe for ``weights_only`` load. For example, functions added to this list can be called during unpickling, classes could be instantiated @@ -443,8 +443,8 @@ def _is_zipfile(f) -> bool: def register_package( priority: int, - tagger: Callable[[STORAGE], str | None], - deserializer: Callable[[STORAGE, str], STORAGE | None], + tagger: Callable[[STORAGE], Optional[str]], + deserializer: Callable[[STORAGE, str], Optional[STORAGE]], ): """ Registers callables for tagging and deserializing storage objects with an associated priority. @@ -672,7 +672,7 @@ def _deserialize(backend_name, obj, location): def location_tag( - storage: Storage | torch.storage.TypedStorage | torch.UntypedStorage, + storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage], ): for _, tagger, _ in _package_registry: location = tagger(storage) @@ -726,7 +726,7 @@ def storage_to_tensor_type(storage): return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) -def _is_path(name_or_buffer: object) -> TypeIs[str | os.PathLike]: +def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]: return isinstance(name_or_buffer, (str, os.PathLike)) @@ -745,7 +745,7 @@ def __exit__(self, *args): class _open_file(_opener[IO[bytes]]): - def __init__(self, name: str | os.PathLike[str], mode: str) -> None: + def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None: super().__init__(open(name, mode)) # noqa: SIM115 def __exit__(self, *args): @@ -776,7 +776,7 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]: class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]): - def __init__(self, name_or_buffer: str | IO[bytes]) -> None: + def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None: super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) @@ -829,7 +829,7 @@ def __exit__(self, *args) -> None: self.buffer.flush() -def _open_zipfile_writer(name_or_buffer: str | IO[bytes]) -> _opener: +def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener: container: type[_opener] if _is_path(name_or_buffer): container = _open_zipfile_writer_file @@ -1004,7 +1004,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: # TODO: This feature could be added in the future storage_dtypes: dict[int, torch.dtype] = {} - def persistent_id(obj: Any) -> tuple | None: + def persistent_id(obj: Any) -> Optional[tuple]: # FIXME: the docs say that persistent_id should only return a string # but torch store returns tuples. This works only in the binary protocol # see @@ -1064,7 +1064,7 @@ def persistent_id(obj: Any) -> tuple | None: else: storage_dtypes[storage.data_ptr()] = storage_dtype - view_metadata: tuple[str, int, int] | None + view_metadata: Optional[tuple[str, int, int]] # Offset is always 0, but we keep it for backwards compatibility # with the old serialization format (which supported storage views) @@ -1291,8 +1291,8 @@ def load( map_location: MAP_LOCATION = None, pickle_module: Any = None, *, - weights_only: bool | None = None, - mmap: bool | None = None, + weights_only: Optional[bool] = None, + mmap: Optional[bool] = None, **pickle_load_args: Any, ) -> Any: # Reference: https://github.com/pytorch/pytorch/issues/54354 @@ -1852,7 +1852,7 @@ def persistent_load(saved_id): return result -def _maybe_decode_ascii(bytes_str: bytes | str) -> str: +def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: # When using encoding='bytes' in Py3, some **internal** keys stored as # strings in Py2 are loaded as bytes. This function decodes them with # ascii encoding, one that Py3 uses by default. diff --git a/torch/storage.py b/torch/storage.py index 29847d958523d..1b9023121ddfb 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -8,7 +8,7 @@ import io import threading import warnings -from typing import Any, cast, TYPE_CHECKING, TypeVar +from typing import Any, cast, Optional as _Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import Self import torch @@ -35,7 +35,7 @@ _share_memory_lock = threading.Lock() _share_memory_map: dict[int, threading.RLock] = {} -T = TypeVar("T", bound="_StorageBase | TypedStorage") +T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]") class _StorageBase: @@ -46,9 +46,9 @@ class _StorageBase: # Used when # (1) stashing FakeTensor device onto storage in torch.serialization.skip_data # (2) stashing device onto storage to propagate to FakeTensor when torch.load under FakeTensorMode - _fake_device: torch.device | None = None + _fake_device: _Optional[torch.device] = None # Used when loading with FakeTensorMode to give information about offset of storage in torch.saved-file - _checkpoint_offset: int | None = None + _checkpoint_offset: _Optional[int] = None def __init__(self, *args, **kwargs): pass @@ -62,10 +62,10 @@ def __getitem__(self, idx): def __setitem__(self, *args, **kwargs): raise NotImplementedError - def copy_(self, source: T, non_blocking: _bool | None = None) -> T: + def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: raise NotImplementedError - def new(self) -> _StorageBase | TypedStorage: + def new(self) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def nbytes(self) -> _int: @@ -75,11 +75,13 @@ def size(self) -> _int: return self.nbytes() def type( - self, dtype: str | None = None, non_blocking: _bool = False - ) -> _StorageBase | TypedStorage: + self, dtype: _Optional[str] = None, non_blocking: _bool = False + ) -> Union[_StorageBase, TypedStorage]: return _type(self, dtype, non_blocking) - def cuda(self, device=None, non_blocking=False) -> _StorageBase | TypedStorage: + def cuda( + self, device=None, non_blocking=False + ) -> Union[_StorageBase, TypedStorage]: """Returns a copy of this object in CUDA memory. If this object is already in CUDA memory and on the correct device, then @@ -94,7 +96,7 @@ def cuda(self, device=None, non_blocking=False) -> _StorageBase | TypedStorage: device2 = torch.device("cuda", device) if device else torch.device("cuda") return self.to(device=device2, non_blocking=non_blocking) - def hpu(self, device=None, non_blocking=False) -> _StorageBase | TypedStorage: + def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: """Returns a copy of this object in HPU memory. If this object is already in HPU memory and on the correct device, then @@ -164,7 +166,7 @@ def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self: def _new_with_weak_ptr(cls, *args, **kwargs) -> Self: raise NotImplementedError - def _shared_decref(self) -> _StorageBase | TypedStorage: + def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def _write_file(self, *args, **kwargs): @@ -173,7 +175,7 @@ def _write_file(self, *args, **kwargs): def resize_(self, size: _int): raise NotImplementedError - def _weak_ref(self, *args, **kwargs) -> _StorageBase | TypedStorage: + def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def _set_from_file(self, *args, **kwargs): @@ -208,17 +210,17 @@ def is_hpu(self): raise NotImplementedError @classmethod - def from_file(cls, filename, shared, nbytes) -> _StorageBase | TypedStorage: + def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError @classmethod - def _expired(cls, *args, **kwargs) -> _StorageBase | TypedStorage: + def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: raise NotImplementedError def _byteswap(self, *args, **kwargs): raise NotImplementedError - def _get_filename(self, *args, **kwargs) -> str | None: + def _get_filename(self, *args, **kwargs) -> _Optional[str]: raise NotImplementedError def __repr__(self): @@ -352,7 +354,7 @@ def float8_e4m3fnuz(self): """Casts this storage to float8_e4m3fnuz type""" return self._to(torch.float8_e4m3fnuz) - def is_pinned(self, device: str | torch.device = "cuda"): + def is_pinned(self, device: Union[str, torch.device] = "cuda"): r"""Determine whether the CPU storage is already pinned on device. Args: @@ -368,7 +370,7 @@ def is_pinned(self, device: str | torch.device = "cuda"): .is_pinned(device) ) - def pin_memory(self, device: str | torch.device = "cuda"): + def pin_memory(self, device: Union[str, torch.device] = "cuda"): r"""Copy the CPU storage to pinned memory, if it's not already pinned. Args: @@ -476,7 +478,7 @@ def is_hpu(self): return self.device.type == "hpu" @property - def filename(self) -> str | None: + def filename(self) -> _Optional[str]: """Returns the file name associated with this storage. The file name will be a string if the storage is on CPU and was created via @@ -669,7 +671,7 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: torch.device | None = None + _fake_device: _Optional[torch.device] = None dtype: torch.dtype @@ -678,7 +680,7 @@ def _dtype(self): return self.dtype @property - def filename(self) -> str | None: + def filename(self) -> _Optional[str]: """Returns the file name associated with this storage if the storage was memory mapped from a file. or ``None`` if the storage was not created by memory mapping a file.""" return self._untyped_storage.filename @@ -1016,7 +1018,7 @@ def _getitem(self, idx): ).set_(self) return tmp_tensor[idx_wrapped].item() - def copy_(self, source: T, non_blocking: bool | None = None): + def copy_(self, source: T, non_blocking: _Optional[bool] = None): _warn_typed_storage_removal() if isinstance(source, TypedStorage): self._untyped_storage.copy_(source._untyped_storage, non_blocking) @@ -1034,9 +1036,9 @@ def _nbytes(self): def type( self, - dtype: str | None = None, + dtype: _Optional[str] = None, non_blocking: bool = False, - ) -> _StorageBase | TypedStorage | str: + ) -> Union[_StorageBase, TypedStorage, str]: _warn_typed_storage_removal() if dtype is None: legacy_class = self._get_legacy_storage_class() @@ -1155,7 +1157,7 @@ def cpu(self): _warn_typed_storage_removal() return self._new_wrapped_storage(self._untyped_storage.cpu()) - def is_pinned(self, device: str | torch.device = "cuda"): + def is_pinned(self, device: Union[str, torch.device] = "cuda"): r"""Determine whether the CPU TypedStorage is already pinned on device. Args: @@ -1168,7 +1170,7 @@ def is_pinned(self, device: str | torch.device = "cuda"): _warn_typed_storage_removal() return self._untyped_storage.is_pinned(device) - def pin_memory(self, device: str | torch.device = "cuda"): + def pin_memory(self, device: Union[str, torch.device] = "cuda"): r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned. Args: diff --git a/torch/types.py b/torch/types.py index 9ed69a859b1ee..0388c9c66aefe 100644 --- a/torch/types.py +++ b/torch/types.py @@ -38,7 +38,7 @@ # Convenience aliases for common composite types that we need # to talk about in PyTorch -_TensorOrTensors: TypeAlias = Tensor | Sequence[Tensor] # noqa: PYI047 +_TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]] # noqa: PYI047 _TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047 Tensor, Sequence[Tensor], @@ -46,32 +46,32 @@ Sequence["GradientEdge"], ] -_size: TypeAlias = Size | list[int] | tuple[int, ...] # noqa: PYI042,PYI047 -_symsize: TypeAlias = Size | Sequence[int | SymInt] # noqa: PYI042,PYI047 -_dispatchkey: TypeAlias = str | DispatchKey # noqa: PYI042,PYI047 +_size: TypeAlias = Union[Size, list[int], tuple[int, ...]] # noqa: PYI042,PYI047 +_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047 +_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047 # int or SymInt -IntLikeType: TypeAlias = int | SymInt +IntLikeType: TypeAlias = Union[int, SymInt] # float or SymFloat -FloatLikeType: TypeAlias = float | SymFloat +FloatLikeType: TypeAlias = Union[float, SymFloat] # bool or SymBool -BoolLikeType: TypeAlias = bool | SymBool +BoolLikeType: TypeAlias = Union[bool, SymBool] py_sym_types = (SymInt, SymFloat, SymBool) # left un-annotated intentionally -PySymType: TypeAlias = SymInt | SymFloat | SymBool +PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool] # Meta-type for "numeric" things; matches our docs -Number: TypeAlias = int | float | bool +Number: TypeAlias = Union[int, float, bool] # tuple for isinstance(x, Number) checks. # FIXME: refactor once python 3.9 support is dropped. _Number = (int, float, bool) -FileLike: TypeAlias = str | os.PathLike[str] | IO[bytes] +FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]] # Meta-type for "device-like" things. Not to be confused with 'device' (a # literal device object). This nomenclature is consistent with PythonArgParser. # None means use the default device (typically CPU) -Device: TypeAlias = _device | str | int | None +Device: TypeAlias = Union[_device, str, int, None] # Storage protocol implemented by ${Type}StorageBase classes diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 6cb4f9b9c012b..194684e3388e4 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -218,7 +218,7 @@ def set_device(device: _device_t) -> None: torch._C._xpu_setDevice(device) -def get_device_name(device: _device_t | None = None) -> str: +def get_device_name(device: Optional[_device_t] = None) -> str: r"""Get the name of a device. Args: @@ -234,7 +234,7 @@ def get_device_name(device: _device_t | None = None) -> str: @lru_cache(None) -def get_device_capability(device: _device_t | None = None) -> dict[str, Any]: +def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: r"""Get the xpu capability of a device. Args: @@ -259,7 +259,7 @@ def get_device_capability(device: _device_t | None = None) -> dict[str, Any]: def get_device_properties( - device: _device_t | None = None, + device: Optional[_device_t] = None, ) -> _XpuDeviceProperties: # pyrefly: ignore # not-a-type r"""Get the properties of a device. @@ -281,7 +281,7 @@ def current_device() -> int: return torch._C._xpu_getDevice() -def _get_device(device: int | str | torch.device) -> torch.device: +def _get_device(device: Union[int, str, torch.device]) -> torch.device: r"""Return the torch.device type object from the passed in device. Args: @@ -395,7 +395,7 @@ def set_stream(stream: Stream) -> None: ) -def current_stream(device: _device_t | None = None) -> Stream: +def current_stream(device: Optional[_device_t] = None) -> Stream: r"""Return the currently selected :class:`Stream` for a given device. Args: @@ -413,7 +413,9 @@ def current_stream(device: _device_t | None = None) -> Stream: ) -def get_stream_from_external(data_ptr: int, device: _device_t | None = None) -> Stream: +def get_stream_from_external( + data_ptr: int, device: Optional[_device_t] = None +) -> Stream: r"""Return a :class:`Stream` from an external SYCL queue. This function is used to wrap SYCL queue created in other libraries in order @@ -482,7 +484,7 @@ def _get_generator(device: torch.device) -> torch._C.Generator: def _set_rng_state_offset( - offset: int, device: int | str | torch.device = "xpu" + offset: int, device: Union[int, str, torch.device] = "xpu" ) -> None: r"""Set the random number generator state offset of the specified GPU. @@ -500,7 +502,7 @@ def cb() -> None: _lazy_call(cb) -def _get_rng_state_offset(device: int | str | torch.device = "xpu") -> int: +def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: r"""Return the random number generator state offset of the specified GPU. Args: diff --git a/torch/xpu/random.py b/torch/xpu/random.py index 8b489e871f7c5..ec770225aef39 100644 --- a/torch/xpu/random.py +++ b/torch/xpu/random.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs from collections.abc import Iterable +from typing import Union import torch from torch import Tensor @@ -7,7 +8,7 @@ from . import _lazy_call, _lazy_init, current_device, device_count -def get_rng_state(device: int | str | torch.device = "xpu") -> Tensor: +def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor: r"""Return the random number generator state of the specified GPU as a ByteTensor. Args: @@ -35,7 +36,9 @@ def get_rng_state_all() -> list[Tensor]: return results -def set_rng_state(new_state: Tensor, device: int | str | torch.device = "xpu") -> None: +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "xpu" +) -> None: r"""Set the random number generator state of the specified GPU. Args: From 135f3753c418a6879b1954904184937b67e61688 Mon Sep 17 00:00:00 2001 From: Su Tong Date: Wed, 3 Dec 2025 04:49:24 +0000 Subject: [PATCH 162/338] [xpu][feature] [3/3] Register the `scaled_mm` and `scaled_mm_v2` for xpu (#166056) This PR registers the `scaled_mm` op for XPU support. It does the following: 1. Registered the `_scaled_mm` and `_scaled_mm_v2` op for XPU. 2. Enables XPU tests in `test_scaled_matmul_cuda.py`. 3. Update torch-xpu-ops pin to remove fallback `scaled_mm` to CPU implementation. ## PR Stack: - https://github.com/pytorch/pytorch/pull/165978 : implementation of XPU scaled_mm and oneDNN kernel - https://github.com/pytorch/pytorch/pull/167518 : implementation of XPU scaled_mm_v2 - -> https://github.com/pytorch/pytorch/pull/166056 : Op registration ## Task tracker: We will track all the scaled_mm related tasks in: https://github.com/pytorch/pytorch/issues/167170 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166056 Approved by: https://github.com/EikanWang, https://github.com/slayton58, https://github.com/drisspg --- aten/src/ATen/native/native_functions.yaml | 4 + test/inductor/test_cutlass_backend.py | 15 +- test/inductor/test_fp8.py | 131 ++++++++++++------ test/test_scaled_matmul_cuda.py | 122 +++++++++------- torch/_meta_registrations.py | 4 +- .../aoti_torch/generated/c_shim_xpu.h | 2 + torch/testing/_internal/common_cuda.py | 2 + 7 files changed, 180 insertions(+), 100 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 81a782f733245..39df81ff44bce 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7278,6 +7278,7 @@ dispatch: CPU: _scaled_mm_cpu CUDA: _scaled_mm_cuda + XPU: _scaled_mm_xpu tags: needs_exact_strides @@ -7286,17 +7287,20 @@ dispatch: CPU: _scaled_mm_out_cpu CUDA: _scaled_mm_out_cuda + XPU: _scaled_mm_out_xpu tags: needs_exact_strides - func: _scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor variants: function dispatch: CUDA: _scaled_mm_cuda_v2 + XPU: _scaled_mm_xpu_v2 - func: _scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CUDA: _scaled_mm_cuda_v2_out + XPU: _scaled_mm_xpu_v2_out - func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 55f8dd5d24ebc..b4c4f6f18f1eb 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -2096,7 +2096,10 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): for op, deserialized_op in zip(ops, deserialized_ops, strict=False): self.assertTrue(_check_if_instances_equal(op, deserialized_op)) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+") + @unittest.skipIf( + torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8, + "FP8 is only supported on H100+", + ) @unittest.skipIf(not SM90OrLater, "need sm_90") @fp8_config @parametrize("float8_dtype", (torch.float8_e4m3fn,)) @@ -2170,7 +2173,10 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, output_dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+") + @unittest.skipIf( + torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8, + "FP8 is only supported on H100+", + ) @unittest.skipIf(not SM90OrLater, "need sm_90") @fp8_config @parametrize("float8_dtype", (torch.float8_e4m3fn,)) @@ -2264,7 +2270,10 @@ def forward(self, x): torch.testing.assert_close(expected, actual, rtol=1e-2, atol=0.05) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+") + @unittest.skipIf( + torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8, + "FP8 is only supported on H100+", + ) @unittest.skipIf(not SM90OrLater, "need sm_90") @fp8_config @parametrize("float8_dtype", (torch.float8_e4m3fn,)) diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index f1067b8ffebb3..621f4b4632f7a 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -17,11 +17,13 @@ PLATFORM_SUPPORTS_FP8, PLATFORM_SUPPORTS_MX_GEMM, ) -from torch.testing._internal.common_quantized import ceil_div, to_blocked -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCUDA, + onlyOn, ) +from torch.testing._internal.common_quantized import ceil_div, to_blocked +from torch.testing._internal.common_utils import parametrize from torch.testing._internal.inductor_utils import ( _quantize_blockwise, _quantize_rowwise, @@ -36,7 +38,7 @@ torch.set_float32_matmul_precision("high") -f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" +f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ and XPU devices" def _fix_fp8_dtype_for_rocm( @@ -66,10 +68,8 @@ def _fix_fp8_dtype_for_rocm( return dtype -@instantiate_parametrized_tests class TestFP8Types(TestCase): @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) - @parametrize("device", ("cuda", "cpu")) def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str): """ TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4 @@ -92,7 +92,6 @@ def f(x): torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2) @parametrize("dtype", (torch.float16, torch.bfloat16)) - @parametrize("device", ("cuda", "cpu")) def test_eager_fallback(self, dtype: torch.dtype, device: torch.device): if device == "cuda" and not PLATFORM_SUPPORTS_FP8: raise unittest.SkipTest(f8_msg) @@ -137,7 +136,6 @@ def fp8_matmul_unwrapped(x): @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("shape", ("15,3,13", "4,2048,4096")) @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)]) - @parametrize("device", ("cuda", "cpu")) def test_valid_cast( self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device ): @@ -161,7 +159,7 @@ def fp8_cast(x): torch.testing.assert_close(y1_fp8, x, rtol=5e-1, atol=5e-1) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - def test_bad_cast(self): + def test_bad_cast(self, device): def fp8_cast(x, dtype): return x.to(dtype=dtype) @@ -173,20 +171,19 @@ def fp8_cast(x, dtype): torch._dynamo.exc.BackendCompilerFailed, "Conversions between float8_e5m2 and float8_e4m3fn is not supported!", ): - x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e4m3fn) + x = torch.rand(*x_shape, device=device).to(dtype=torch.float8_e4m3fn) compiled_fp8_cast(x, torch.float8_e5m2) with self.assertRaisesRegex( torch._dynamo.exc.BackendCompilerFailed, "Conversions between float8_e5m2 and float8_e4m3fn is not supported!", ): - x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2) + x = torch.rand(*x_shape, device=device).to(dtype=torch.float8_e5m2) compiled_fp8_cast(x, torch.float8_e4m3fn) @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("16,16,16", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) def test_to_fp8_saturated( self, src_dtype: torch.dtype, @@ -213,7 +210,6 @@ def fp8_saturated(x, dtype): @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) def test_amax_fp8_quant( self, float8_dtype: torch.dtype, shape: str, device: torch.device ): @@ -244,7 +240,6 @@ def amax_fp8(x: Tensor, scale: Tensor): @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) def test_amax_along_with_fp8_quant( self, float8_dtype: torch.dtype, shape: str, device: torch.device ): @@ -279,7 +274,6 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("amax_keep_dim", (True, False)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) def test_layernorm_fp8_quant( self, float8_dtype: torch.dtype, @@ -326,6 +320,7 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2 ) + @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("4,2048,4096",)) @@ -391,7 +386,6 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): ) -@instantiate_parametrized_tests class TestFP8Lowering(TestCase): @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.bfloat16, torch.float32)) @@ -401,6 +395,7 @@ class TestFP8Lowering(TestCase): @parametrize( "persistent_matmul", [False, True] if has_triton_tma_device() else [False] ) + @onlyOn(["cuda", "xpu"]) def test_tensorwise_scaling( self, dtype: torch.dtype, @@ -408,11 +403,10 @@ def test_tensorwise_scaling( has_bias: bool, use_fast_accum: bool, persistent_matmul: bool, + device, ): if dtype is torch.float32 and has_bias: self.skipTest("bias is not supported when output dtype is float32") - - device = "cuda" dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -426,6 +420,9 @@ def test_tensorwise_scaling( if has_bias: bias = torch.randn(N, device=device, dtype=torch.bfloat16) + # if "xpu" in device and use_fast_accum: + self.skipTest("XPU does not support use_fast_accum=True for now") + # quantize weight (prior to inference) w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8) w_t_fp8 = w_fp8.t() @@ -475,10 +472,14 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - def test_scaled_mm_preserves_strides(self): + @onlyOn(["cuda", "xpu"]) + def test_scaled_mm_preserves_strides(self, device): """Test that scaled_mm preserves stride ordering through a custom pass.""" - GPU_TYPE = "cuda" + GPU_TYPE = device + use_fast_accum = True + if "xpu" in device: + use_fast_accum = False def f(a, b, scale_a, scale_b): # Convert to fp8 with correct strides for scaled_mm @@ -487,7 +488,12 @@ def f(a, b, scale_a, scale_b): a_fp8 = a.to(dtype_float8).contiguous() # row-major b_fp8 = b.t().contiguous().t().to(dtype_float8) # column-major return torch._scaled_mm( - a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16 + a_fp8, + b_fp8, + scale_a, + scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=use_fast_accum, ) class ScaledMMStridePass(PatternMatcherPass): @@ -555,6 +561,7 @@ def __call__(self, g: torch.fx.Graph): # The clones should be visible in the generated code self.assertIn("clone", wrapper.lower()) + @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" @@ -567,8 +574,10 @@ def test_tensorwise_scaling_tma_template( dtype: torch.dtype, shape: str, use_fast_accum: bool, + device, ): - device = "cuda" + if "xpu" in device and use_fast_accum: + self.skipTest("XPU does not support use_fast_accum=True for now") dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -641,6 +650,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @onlyOn(["cuda", "xpu"]) @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512")) @parametrize("has_bias", (False, True)) @parametrize("use_fast_accum", (False, True)) @@ -648,11 +658,17 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): "persistent_matmul", [False, True] if has_triton_tma_device() else [False] ) def test_rowwise_scaling( - self, shape: str, has_bias: bool, use_fast_accum: bool, persistent_matmul: bool + self, + shape: str, + has_bias: bool, + use_fast_accum: bool, + persistent_matmul: bool, + device, ): + if "xpu" in device and use_fast_accum: + self.skipTest("XPU does not support use_fast_accum=True for now") # Only bf16 output type is supported for row-wise scaling, not fp32 dtype: torch.dtype = torch.bfloat16 - device = "cuda" dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -710,16 +726,17 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) + @onlyCUDA @parametrize("shape", ("16,32,32", "1024,1024,512")) @parametrize("use_fast_accum", (False, True)) def test_rowwise_scaling_tma_template( self, shape: str, use_fast_accum: bool, + device, ): # Only bf16 output type is supported for row-wise scaling, not fp32 dtype: torch.dtype = torch.bfloat16 - device = "cuda" dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -794,6 +811,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): _get_torch_cuda_version() < (12, 9), "cuBLAS blockwise scaling added in CUDA 12.9", ) + @onlyCUDA @parametrize("shape", ((16, 256, 256), (1024, 512, 1024))) @parametrize("use_fast_accum", (False, True)) @parametrize( @@ -804,10 +822,12 @@ def test_main_loop_scaling( shape: tuple[int, int, int], use_fast_accum: bool, scaling_block_sizes: tuple[int, int, int, int], + device, ): + if "xpu" in device and use_fast_accum: + self.skipTest("XPU does not support use_fast_accum=True for now") # Only bf16 output type is supported for non-tensorwise scaling, not fp32 dtype: torch.dtype = torch.bfloat16 - device = "cuda" dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -896,6 +916,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @onlyOn(["cuda", "xpu"]) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @parametrize("N", (16, 2048)) @@ -903,12 +924,14 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): "persistent_matmul", [False, True] if has_triton_tma_device() else [False] ) def test_tensorwise_scaling_acceptable_input_dims( - self, M: int, K: int, N: int, persistent_matmul: bool + self, M: int, K: int, N: int, persistent_matmul: bool, device ): # alignment requirements: K and N divisible by 16 dtype: torch.dtype = torch.bfloat16 use_fast_accum = True - device = "cuda" + # xpu does not support fast_accum now + if "xpu" in device: + use_fast_accum = False dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -953,9 +976,13 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07) + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @torch._inductor.config.patch("emulate_precision_casts", True) - def test_mx_fusion(self): + def test_mx_fusion(self, device): + # use a device key for library registration + device_type = torch.device(device).type + device_dispatch_key = "CUDA" if device_type == "cuda" else "XPU" # Register fake_scaled_mm custom op scoped to this test with torch.library._scoped_library("test_fp8", "FRAGMENT") as lib: # Define the op schema @@ -966,8 +993,8 @@ def test_mx_fusion(self): ) input_values = [] - # Register CUDA implementation - @torch.library.impl(lib, "fake_scaled_mm", "CUDA") + # Register CUDA/XPU implementation + @torch.library.impl(lib, "fake_scaled_mm", device_dispatch_key) def fake_scaled_mm_impl( mat_a, mat_b, @@ -1036,7 +1063,7 @@ def forward( ) isnan = torch.ops.aten.isnan.default(unsqueeze) scalar_tensor = torch.ops.aten.scalar_tensor.default( - 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + 255, dtype=torch.uint8, layout=torch.strided, device=device ) where = torch.ops.aten.where.self( isnan, scalar_tensor, convert_element_type @@ -1086,7 +1113,7 @@ def forward( isnan_1 = torch.ops.aten.isnan.default(unsqueeze_1) unsqueeze_1 = None scalar_tensor_1 = torch.ops.aten.scalar_tensor.default( - 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + 255, dtype=torch.uint8, layout=torch.strided, device=device ) where_1 = torch.ops.aten.where.self( isnan_1, scalar_tensor_1, convert_element_type_3 @@ -1152,7 +1179,6 @@ def forward( # Run with largest shape M, K, N = 8192, 8192, 8192 - device = "cuda" A = torch.randn(M, K, dtype=torch.float32, device=device) B = torch.randn(K, N, dtype=torch.float32, device=device) @@ -1188,6 +1214,7 @@ def forward( ) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @onlyOn(["cuda", "xpu"]) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @parametrize("N", (16, 2048)) @@ -1195,11 +1222,13 @@ def forward( "persistent_matmul", [False, True] if has_triton_tma_device() else [False] ) def test_rowwise_scaling_acceptable_input_dims( - self, M: int, K: int, N: int, persistent_matmul: bool + self, M: int, K: int, N: int, persistent_matmul: bool, device ): dtype: torch.dtype = torch.bfloat16 use_fast_accum = True - device = "cuda" + # xpu does not support fast_accum now + if "xpu" in device: + use_fast_accum = False dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -1246,11 +1275,11 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, "Not supported on non B200") - def test_mx_fp8_max_autotune(self): + def test_mx_fp8_max_autotune(self, device): M, K, N = 128, 32, 128 BLOCK_SIZE = 32 - device = "cuda" dtype = torch.bfloat16 A_ref = torch.eye(M, device=device, dtype=torch.bfloat16) B_ref = torch.eye(N, device=device, dtype=torch.bfloat16) @@ -1284,14 +1313,18 @@ def linear(A, B, A_scale, B_scale): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - def test_unacceptable_input_dims(self): + def test_unacceptable_input_dims(self, device): # for compiled ops, type checking is in torch/_meta_registrations.py dtype: torch.dtype = torch.bfloat16 - device = "cuda" dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) + # xpu does not support fast_accum now + use_fast_accum = True + if "xpu" in device: + use_fast_accum = False M, K, N = 64, 15, 2048 # K needs to be a multiple of 16 x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) @@ -1308,7 +1341,7 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): w_inverse_scale, bias, out_dtype=dtype, - use_fast_accum=True, + use_fast_accum=use_fast_accum, ) return y @@ -1326,9 +1359,9 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): ) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - def test_unacceptable_scale_dims_rowwise_scaling(self): + @onlyOn(["cuda", "xpu"]) + def test_unacceptable_scale_dims_rowwise_scaling(self, device): dtype: torch.dtype = torch.bfloat16 - device = "cuda" dtype_float8 = torch.float8_e4m3fn dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) @@ -1338,6 +1371,10 @@ def test_unacceptable_scale_dims_rowwise_scaling(self): bias = torch.randn(N, device=device, dtype=torch.bfloat16) w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8) w_t_fp8 = w_fp8.t() + # xpu does not support fast_accum now + use_fast_accum = True + if "xpu" in device: + use_fast_accum = False def linear(x, w_t_fp8, w_inverse_scale, bias): x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8) @@ -1348,7 +1385,7 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): x_inverse_scale, bias, out_dtype=dtype, - use_fast_accum=True, + use_fast_accum=use_fast_accum, ) return y @@ -1363,6 +1400,10 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): self.assertTrue("Invalid scaling configuration." in str(cm.exception)) +instantiate_device_type_tests(TestFP8Types, globals(), allow_xpu=True) +instantiate_device_type_tests(TestFP8Lowering, globals(), allow_xpu=True) + + if __name__ == "__main__": if HAS_CUDA_AND_TRITON or HAS_CPU: run_tests() diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 25c4efe35a1ab..f620df52a6d3c 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -35,10 +35,13 @@ from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, + onlyOn, e4m3_type, e5m2_type, E4M3_MAX_POS, E5M2_MAX_POS, + skipXPU, + skipCUDAIf, ) from torch.testing._internal.common_utils import ( @@ -65,7 +68,7 @@ if TEST_CUDA: _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 -f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" +f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ and XPU devices" f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" mxfp8_grouped_mm_skip_msg = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+" @@ -73,6 +76,12 @@ # avoid division by zero when calculating scale EPS = 1e-12 +def _device_supports_scaled_mm_fp8(device): + if device not in ['cpu', 'xpu'] and (torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8): + return False + return True + + def amax_to_scale( amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype ): @@ -687,7 +696,7 @@ def _test_tautological_mm(self, device: str = "cuda", y_dtype: torch.dtype = e4m3_type, out_dtype: Optional[torch.dtype] = None, size: int = 16) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: + if not _device_supports_scaled_mm_fp8(device): raise unittest.SkipTest(f8_msg) x_fp8 = torch.rand(size, size, device=device).to(x_dtype) y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t() @@ -700,12 +709,12 @@ def _test_tautological_mm(self, device: str = "cuda", self.assertEqual(out_fp32, out_fp8.to(torch.float)) def test_float8_basics(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: + if not _device_supports_scaled_mm_fp8(device): raise unittest.SkipTest(f8_msg) self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported # supported on ROCm but fails on CUDA - ctx = self.assertRaises(ValueError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext() + ctx = self.assertRaises(ValueError) if torch.version.hip is None and "cuda" in device else contextlib.nullcontext() with ctx: self._test_tautological_mm(device, e5m2_type, e5m2_type) @@ -716,11 +725,15 @@ def test_float8_basics(self, device) -> None: self._test_tautological_mm(device, size=96, out_dtype=torch.float32) self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16) - with self.assertRaises(AssertionError if torch.version.hip or device == "cpu" else RuntimeError): + with self.assertRaises( + AssertionError if (torch.version.hip or "xpu" in device or "cpu" in device) + else RuntimeError + ): self._test_tautological_mm(device, out_dtype=e5m2_type) + def test_float8_scale(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: + if not _device_supports_scaled_mm_fp8(device): raise unittest.SkipTest(f8_msg) size = (16, 16) x = torch.full(size, .5, device=device, dtype=e4m3_type) @@ -736,7 +749,6 @@ def test_float8_scale(self, device) -> None: self.assertEqual(out_fp8, out_fp8_s) - @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg) @parametrize("G", [1, 4, 16]) @parametrize("M", [2048, 2049]) @@ -951,14 +963,14 @@ def _2d_to_blocked_scaled(X, K, G, offs, format): @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) - def test_scaled_mm_vs_emulated(self, base_dtype): + def test_scaled_mm_vs_emulated(self, base_dtype, device="cuda"): torch.manual_seed(42) input_dtype = e4m3_type output_dtype = base_dtype compare_type = torch.float32 - x = torch.randn(16, 16, device="cuda", dtype=base_dtype) - y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + x = torch.randn(16, 16, device=device, dtype=base_dtype) + y = torch.randn(32, 16, device=device, dtype=base_dtype).t() x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float() @@ -1001,14 +1013,14 @@ def test_scaled_mm_vs_emulated(self, base_dtype): @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) - def test_scaled_mm_change_stride(self, base_dtype): + def test_scaled_mm_change_stride(self, base_dtype, device="cuda"): torch.manual_seed(42) input_dtype = e4m3_type output_dtype = base_dtype compare_type = torch.float32 - x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype) - y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype) + x = torch.empty_strided((16, 16), (16, 1), device=device, dtype=base_dtype) + y = torch.empty_strided((16, 32), (1, 64), device=device, dtype=base_dtype) x.normal_() y.normal_() @@ -1051,10 +1063,9 @@ def test_scaled_mm_change_stride(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @onlyCUDA + @onlyOn(["cuda", "xpu"]) + @skipCUDAIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_bias(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) (k, l, m) = (16, 48, 32) x = torch.ones((k, l), device=device).to(e4m3_type) y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() @@ -1069,7 +1080,7 @@ def test_float8_bias(self, device) -> None: difference = torch.abs(out_fp32 - outb_fp32) self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32)) - @onlyCUDA + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("bias", [True, False]) def test_non_divisible_leading_dim(self, device, bias: bool) -> None: @@ -1082,7 +1093,7 @@ def test_non_divisible_leading_dim(self, device, bias: bool) -> None: input_bias = torch.rand((16,), device=device).to(torch.bfloat16) _ = scaled_mm_wrap(x, y, scale_a, scale_b, bias=input_bias) - @onlyCUDA + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_bias_relu_edgecase(self, device) -> None: (k, l, m) = (16, 48, 32) @@ -1095,7 +1106,7 @@ def test_float8_bias_relu_edgecase(self, device) -> None: outb_fp32 = outb_fp8.to(torch.float32) self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)) - @onlyCUDA + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float32_output_errors_with_bias(self, device) -> None: (k, l, m) = (16, 48, 32) @@ -1104,11 +1115,13 @@ def test_float32_output_errors_with_bias(self, device) -> None: scale_a = torch.tensor(1.0, device=device) scale_b = torch.tensor(1.0, device=device) bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16) - self.assertRaisesRegex( - ValueError, - "Bias is not supported when out_dtype is set to Float32", - lambda: scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32), - ) + # XPU supports the case when out_dtype is fp32 + bias. So we just test it with normal run. + if "xpu" not in device: + self.assertRaisesRegex( + ValueError if torch.cuda.is_available() else RuntimeError, + "Bias is not supported when out_dtype is set to Float32", + lambda: scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32), + ) @onlyCUDA @unittest.skipIf(PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(), f8_msg) @@ -1139,11 +1152,14 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=e4m3_type, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) - @onlyCUDA + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) - @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") + @skipCUDAIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") @parametrize("use_fast_accum", [True, False]) def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: + if torch.xpu.is_available() and use_fast_accum: + raise unittest.SkipTest("XPU does not support fast accum yet") + M, K, N = (1024, 512, 2048) fill_value = 0.5 x = torch.full((M, K), fill_value, device=device) @@ -1167,7 +1183,7 @@ def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> No out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) ) - @onlyCUDA + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) def test_float8_error_messages(self, device) -> None: M, K, N = (1024, 512, 2048) @@ -1184,8 +1200,8 @@ def test_float8_error_messages(self, device) -> None: scaled_mm_wrap( x_fp8, y_fp8, - scale_a=torch.ones((1, 1), device="cuda"), - scale_b=torch.ones((1, 2), device="cuda"), + scale_a=torch.ones((1, 1), device=device), + scale_b=torch.ones((1, 2), device=device), scale_recipe_a=ScalingType.TensorWise, scale_recipe_b=ScalingType.TensorWise, out_dtype=torch.bfloat16, @@ -1197,8 +1213,8 @@ def test_float8_error_messages(self, device) -> None: scaled_mm_wrap( x_fp8, y_fp8, - scale_a=torch.ones((M, 1), device="cuda"), - scale_b=torch.ones((1, N + 1), device="cuda"), + scale_a=torch.ones((M, 1), device=device), + scale_b=torch.ones((1, N + 1), device=device), scale_recipe_a=ScalingType.RowWise, scale_recipe_b=ScalingType.RowWise, out_dtype=torch.bfloat16, @@ -1209,8 +1225,8 @@ def test_float8_error_messages(self, device) -> None: scaled_mm_wrap( x_fp8, y_fp8, - scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, 1), device="cuda"), + scale_a=torch.ones((M), device=device), + scale_b=torch.ones((N, 1), device=device), scale_recipe_a=ScalingType.RowWise, scale_recipe_b=ScalingType.RowWise, out_dtype=torch.bfloat16, @@ -1222,8 +1238,8 @@ def test_float8_error_messages(self, device) -> None: scaled_mm_wrap( x_fp8, y_fp8, - scale_a=torch.ones((M, 1), device="cuda"), - scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2], + scale_a=torch.ones((M, 1), device=device), + scale_b=torch.ones((1, N * 2), device=device)[:, ::2], scale_recipe_a=ScalingType.RowWise, scale_recipe_b=ScalingType.RowWise, out_dtype=torch.bfloat16, @@ -1233,13 +1249,17 @@ def e5m2(): out = scaled_mm_wrap( x_fp8, y_fp8.to(e5m2_type), - scale_a=torch.ones((M, 1), device="cuda"), - scale_b=torch.ones((1, N), device="cuda"), + scale_a=torch.ones((M, 1), device=device), + scale_b=torch.ones((1, N), device=device), out_dtype=torch.bfloat16, ) return out - if torch.cuda.get_device_capability() == (9, 0) and torch.version.cuda and torch.version.cuda >= "12.9": + if (torch.xpu.is_available() or + (torch.cuda.is_available() and + torch.cuda.get_device_capability() == (9, 0) and + torch.version.cuda and + torch.version.cuda >= "12.9")): out = e5m2() self.assertEqual(out, torch.ones_like(out) * 128.) else: @@ -1258,39 +1278,39 @@ def e5m2(): e5m2() @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) - @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") + @skipCUDAIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") @parametrize("base_dtype", [torch.bfloat16, torch.float16, torch.float32]) @parametrize("shapes", [ (128, 512, 256), ]) @with_tf32_off - def test_scaled_mm_vs_emulated_row_wise(self, base_dtype, shapes): + def test_scaled_mm_vs_emulated_row_wise(self, base_dtype, shapes, device): M, K, N = shapes # Fp32 out_dtype is only supported by cuBLAS, which however only started # shipping row-wise kernels in CUDA 12.9, and only for sm90+. if base_dtype is torch.float32: if torch.version.hip: raise unittest.SkipTest("hipblaslt rowwise _scaled_mm only supports BFloat16") - if _get_torch_cuda_version() < (12, 9): + if torch.cuda.is_available() and _get_torch_cuda_version() < (12, 9): raise unittest.SkipTest("Need CUDA 12.9+ for row-wise fp8 w/ cuBLAS") - if torch.cuda.get_device_capability() < (9, 0): + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0): raise unittest.SkipTest("Need sm90+ for row-wise fp8 w/ cuBLAS") if base_dtype is torch.float16: if torch.version.hip: raise unittest.SkipTest("hipblaslt rowwise _scaled_mm only supports BFloat16") - if torch.cuda.get_device_capability() < (9, 0): + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0): raise unittest.SkipTest("Need sm90+ for row-wise fp8 w/ cuBLAS") torch.manual_seed(42) input_dtype = e4m3_type output_dtype = base_dtype - x = torch.randn(M, K, device="cuda", dtype=base_dtype) - y = torch.randn(N, K, device="cuda", dtype=base_dtype).t() + x = torch.randn(M, K, device=device, dtype=base_dtype) + y = torch.randn(N, K, device=device, dtype=base_dtype).t() bias = None if base_dtype in {torch.bfloat16, torch.float16}: - bias = torch.randn((N,), device="cuda", dtype=base_dtype) + bias = torch.randn((N,), device=device, dtype=base_dtype) x_scales = tensor_to_scale(x, input_dtype, dim=1).float() y_scales = tensor_to_scale(y, input_dtype, dim=0).float() @@ -1328,7 +1348,7 @@ def test(): # only cuBLAS supports rowwise with fp32 output and cuBLAS only supports # rowwise on SM 9.0 - if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float: + if torch.cuda.is_available() and torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float: with self.assertRaisesRegex( ValueError, "Only bf16 and fp16 high precision output types are supported for row-wise scaling." @@ -1683,8 +1703,7 @@ def test_scaled_mm_deepseek_error_messages( @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("which_dim_zero", [0, 1, 2]) @parametrize("use_torch_compile", [False, True]) - def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: - device = "cuda" + def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile, device) -> None: x_dtype, y_dtype = e4m3_type, e4m3_type out_dtype = torch.bfloat16 M, K, N = 32, 32, 32 @@ -1782,6 +1801,7 @@ def test_honor_sm_carveout(self) -> None: self.assertNotEqual(no_carveout, carveout_66) self.assertNotEqual(carveout_66, carveout_0) + @skipXPU def test_pack_uint4(self): """ Verify that given a tensor with high precision values [val0, val1], @@ -2115,6 +2135,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, sqnr = compute_error(C_ref, C) assert sqnr.item() > approx_match_sqnr_target + @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg) @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"]) def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: @@ -2390,6 +2411,7 @@ def test_blockwise_mxfp8_compile(self) -> None: ) torch.testing.assert_close(C, C_ref, atol=0, rtol=0) + @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) def test_blockwise_nvfp4_compile(self) -> None: @@ -2421,7 +2443,7 @@ def test_blockwise_nvfp4_compile(self) -> None: torch.testing.assert_close(C, C_ref, atol=0, rtol=0) -instantiate_device_type_tests(TestFP8Matmul, globals(), except_for="cpu") +instantiate_device_type_tests(TestFP8Matmul, globals(), except_for="cpu", allow_xpu=True) if __name__ == '__main__': TestCase._default_dtype_check_enabled = True diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index a54bf3c026fe2..ce00b67373e26 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6362,7 +6362,7 @@ def is_fp8_or_fp4_type(dtype): lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", ) - if device_hint(self) == "cuda": + if device_hint(self) == "cuda" or device_hint(self) == "xpu": def is_row_major(stride): return stride[0] > stride[1] and stride[1] == 1 @@ -6592,7 +6592,7 @@ def is_fp4_type(dtype): SwizzleType.NO_SWIZZLE, ] - if device_hint(self) == "cuda": + if device_hint(self) == "cuda" or device_hint(self) == "xpu": def is_row_major(stride): return stride[0] > stride[1] and stride[1] == 1 diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 39f0dec86165a..3f41e4e1a6b12 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -17,6 +17,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 0fe9813d51b34..b1e23016dacbc 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -118,6 +118,8 @@ def evaluate_platform_supports_fp8(): return True else: return SM90OrLater or torch.cuda.get_device_capability() == (8, 9) + if torch.xpu.is_available(): + return True return False def evaluate_platform_supports_fp8_grouped_gemm(): From 5fafc13038c9988d9ac21fa793fbd5890604b447 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 1 Dec 2025 13:19:45 -0800 Subject: [PATCH 163/338] [MPS] Fix dlpack exports/imports for sliced tensors (#169272) For MPS tensor, one must pass both `id` (which is `t.storage().data()` and `t.storage_offset()`) Luckily, DLTensor already has `byte_offset` field, which feels natural to use as product of `storage_offset` and element_size. Partially extends https://github.com/pytorch/pytorch/pull/168193, but instead of writing a completely new test, fix both export and import paths of sliced tensor and unskip test_from_dlpack_noncontinguous for MPS Error out if one is attempting to create tensor with non-zero `byte_offsets` and no strides, as there are no `at::from_blob` variant that could be used Fixes https://github.com/pytorch/pytorch/issues/168177 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169272 Approved by: https://github.com/ngimel --- aten/src/ATen/DLConvertor.cpp | 21 +++++++++++++++++++-- test/test_dlpack.py | 7 +++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index ccb0ae15a11e6..b39f3eafa32df 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -356,8 +356,18 @@ ScalarType toScalarType(const DLDataType& dtype) { return stype; } + namespace { +int64_t toStorageOffset(int64_t byte_offset, ScalarType stype) { + if (byte_offset == 0) { + return 0; + } + const auto element_size = c10::elementSize(stype); + TORCH_CHECK_VALUE(byte_offset % element_size == 0, "byte offset must be multiple of element size"); + return byte_offset / element_size; +} + // The templated classes below are needed for supporting both: // - DLManagedTensor // - DLManagedTensorVersioned @@ -393,13 +403,18 @@ T* toDLPackImpl(const Tensor& src) { atDLMTensor->handle = src; atDLMTensor->tensor.manager_ctx = atDLMTensor; atDLMTensor->tensor.deleter = &deleter; - atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); + if (src.device().type() == kMPS) { + atDLMTensor->tensor.dl_tensor.data = src.storage().mutable_data(); + atDLMTensor->tensor.dl_tensor.byte_offset = src.storage_offset() * c10::elementSize(src.scalar_type()); + } else { + atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); + atDLMTensor->tensor.dl_tensor.byte_offset = 0; + } atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device()); atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); atDLMTensor->tensor.dl_tensor.shape = const_cast(src.sizes().data()); atDLMTensor->tensor.dl_tensor.strides = const_cast(src.strides().data()); - atDLMTensor->tensor.dl_tensor.byte_offset = 0; fillVersion(&atDLMTensor->tensor); return &(atDLMTensor->tensor); @@ -426,6 +441,7 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { ScalarType stype = toScalarType(dl_tensor.dtype); if (!dl_tensor.strides) { + TORCH_CHECK_VALUE(dl_tensor.byte_offset == 0, "Expected zero byte_offset"); return at::from_blob( dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), @@ -437,6 +453,7 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), IntArrayRef(dl_tensor.strides, dl_tensor.ndim), + toStorageOffset(dl_tensor.byte_offset, stype), deleter, at::device(device).dtype(stype), {device}); diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 3d6c4ae7484cb..3d27678b5864a 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -21,7 +21,6 @@ from torch.testing._internal.common_utils import ( IS_JETSON, run_tests, - skipIfMPS, skipIfTorchDynamo, TestCase, ) @@ -157,7 +156,6 @@ def test_from_dlpack(self, device, dtype): self.assertEqual(x, y) @skipMeta - @skipIfMPS # MPS crashes with noncontiguous now @onlyNativeDeviceTypes @dtypes( *all_types_and_complex_and( @@ -169,6 +167,11 @@ def test_from_dlpack(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS( + *all_mps_types_and( + torch.bool, torch.cfloat, torch.chalf, torch.uint16, torch.uint32 + ) + ) def test_from_dlpack_noncontinguous(self, device, dtype): x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5) From 2bec68e73b64715354af076ad309335f943e36cd Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 2 Dec 2025 16:41:31 -0800 Subject: [PATCH 164/338] [dynamo][hops] Remove restore_side_effect from the auto flattening codepath (#169394) Earlier, restore_side_effects was doing 2 things 1) Restoring side effects 2) If False, it would also expose intermediates to outputs This was getting confusing with autograd.Function, where for the backward, we do want to restore side effects but dont want to expose intermediates as outputs. So, here we fully get rid of restore_side_effects, and leave it to the caller of the auto_flattening speculate subgraph to do this manually. We introduce allow_side_effects flag, which exposes intermediates to outputs and also sets up side effects to not raise when it sees a mutation during hop tracing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169394 Approved by: https://github.com/tugsbayasgalan --- test/higher_order_ops/test_invoke_subgraph.py | 1 - torch/_dynamo/output_graph.py | 1 + torch/_dynamo/side_effects.py | 1 + torch/_dynamo/variables/higher_order_ops.py | 69 ++++++++++--------- 4 files changed, 37 insertions(+), 35 deletions(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index a5a02e4143527..67c4fa0757769 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -2588,7 +2588,6 @@ def f(x, other): self.assertEqual(f(x, other), f_compile(x, other)) self.assertTrue(called) - @unittest.expectedFailure def test_udf_output(self): class Foo: def __init__(self, a, b): diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 981c441bd2986..6ff908ff0394f 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -2979,6 +2979,7 @@ def __init__( # via torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer. # Note: Externally visible side-effects are allowed if this flag OR the above flag is True. self.unsafe_allow_externally_visible_side_effects = False + self.traced_with_externally_visible_side_effects = False # True if we want to allow side effects by returning them as extra outputs from the subgraph. # This is set when enable_side_effects_in_hop=True for HOPs like invoke_subgraph # and checkpoint (when skip_fwd_side_effects_in_bwd_under_checkpoint config is True). diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index e153d7489c7d9..999bd145c3e57 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1219,6 +1219,7 @@ def allow_externally_visible_side_effects_in_subtracer( orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects try: tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True + tx.output.current_tracer.traced_with_externally_visible_side_effects = True yield finally: tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = orig_val diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 3524cb142cdd7..4fe11f4dd03d1 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1190,16 +1190,10 @@ def trace_hop_function_with_auto_output_flattening( tx, subtracer, enable_grad, - restore_side_effects, + allow_side_effects, args, sub_kwargs, ): - # For the new unified control flow ops, we couple restore_side_effects - # with allow_side_effects_in_hop using the new semantics: - # - restore_side_effects=False means side effects become extra outputs - # - This allows mutations to be tracked and replayed - enable_side_effects_with_extra_outputs = not restore_side_effects - autograd_ctx = ( dynamo_enable_grad(tx, enable_grad) if enable_grad is not None @@ -1207,22 +1201,13 @@ def trace_hop_function_with_auto_output_flattening( ) side_effects_ctx = ( dynamo_allow_side_effects_in_hop(tx) - if enable_side_effects_with_extra_outputs + if allow_side_effects else contextlib.nullcontext() ) - if restore_side_effects: - prev_side_effects = tx.output.side_effects.clone() - with autograd_ctx, side_effects_ctx: output = f.call_function(tx, args, sub_kwargs) - if restore_side_effects: - new_side_effects = tx.output.side_effects.clone() - prev_side_effects.track_runahead_tensor_and_symvar_side_effects( - new_side_effects - ) - tx.output.side_effects = prev_side_effects return output @@ -1271,7 +1256,9 @@ def speculate_subgraph_with_auto_output_flattening( set_subgraph_inputs: Literal[ "automatic", "semi_automatic", "flatten_manual", "manual" ] = "automatic", - restore_side_effects: bool = True, + # If True, exposes intermediates to subgraph outputs to allow later tensor ops to + # access intermediates from the subgraph, this is useful for mutation + allow_side_effects: bool = False, # TODO - supports input_mutation and aliasing should be False by default for strictness supports_input_mutation: bool = True, supports_aliasing: bool = True, @@ -1386,12 +1373,25 @@ def gn(x): tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description ) + # Special case - if users uses + # `traced_with_externally_visible_side_effects`, we still need to + # return the intermediates as outputs. However, this API gets + # triggered during the hop tracing, and we don't know at this point + # of time, if the API will take into effect. To handle this, we have + # a flag traced_with_externally_visible_side_effects (default=False) + # that is set to True anytime + # `traced_with_externally_visible_side_effects` is set. We reset it + # with the old value after the hop is traced out. + old_value = ( + tx.output.current_tracer.traced_with_externally_visible_side_effects + ) + output = trace_hop_function_with_auto_output_flattening( f, tx, subtracer, enable_grad, - restore_side_effects, + allow_side_effects, args, sub_kwargs, ) @@ -1469,13 +1469,21 @@ def visit(vt): # want this to be supported for other Hops as well, specifically # nested_compile_region and autograd.Function. Today, its safe # because we error out on seeing a side-effect. - enable_side_effects_with_extra_outputs = not restore_side_effects - if enable_side_effects_with_extra_outputs: + + allow_side_effects = ( + allow_side_effects + or tx.output.current_tracer.traced_with_externally_visible_side_effects + ) + if allow_side_effects: extra_outputs = _collect_intermediate_outputs( tx, subtracer, graph_output_vts ) graph_output_vts = graph_output_vts + tuple(extra_outputs) + tx.output.current_tracer.traced_with_externally_visible_side_effects = ( + old_value + ) + validate_subgraph_output_types(graph_output_vts) # The output proxies might not belong to this SubgraphTracer @@ -2877,9 +2885,7 @@ def call_function( class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): supports_input_mutation = True supports_aliasing = True - # TODO - Go through all subclasses of WrapHigherOrderVariable to see if - # restore_side_effects can be ignored. For now, this is conservative. - restore_side_effects = True + allow_side_effects = False def install_subgraph_in_output_graph( self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" @@ -2900,7 +2906,6 @@ def create_wrapped_node( subgraph_name="wrap_body", ): # See NOTE [HigherOrderOperator tracing design] for more details - ( body_r, body_graph, @@ -2913,7 +2918,7 @@ def create_wrapped_node( kwargs, description, source_target=self.value, - restore_side_effects=self.restore_side_effects, + allow_side_effects=self.allow_side_effects, supports_input_mutation=self.supports_input_mutation, supports_aliasing=self.supports_aliasing, ) @@ -3372,10 +3377,8 @@ def _call_function( class CheckpointHigherOrderVariable(WrapHigherOrderVariable): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - # When skip_fwd_side_effects_in_bwd is True, we allow side effects by NOT restoring them. - # This enables collecting intermediate outputs for side effects. - self.restore_side_effects = ( - not torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint + self.allow_side_effects = ( + torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint ) def _call_function( @@ -4240,10 +4243,8 @@ def _call_function( class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): supports_input_mutation = True supports_aliasing = False - # TODO (tmanlaibaatar) This is in preparation for supporting side effects in invoke_subgraph. - # invoke_subgraph does not support side effects, so we restore them (default behavior). - # This means enable_side_effects_with_extra_outputs will be False. - restore_side_effects = True + # TODO - make this true to support mutation + allow_side_effects = False def install_subgraph_in_output_graph( self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name From bac403c0b38c63bdbcc0c31f1c2b0bc0260f610f Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Wed, 3 Dec 2025 07:40:41 +0000 Subject: [PATCH 165/338] [Dynamo] Adding `strict` as a kwarg for `map()` in Python3.14 (#167828) The [map()](https://docs.python.org/3.14/library/functions.html#map) function now has an optional keyword-only strict flag like [zip()](https://docs.python.org/3.14/library/functions.html#zip) to check that all the iterables are of equal length. Refer [this](https://docs.python.org/3.14/whatsnew/3.14.html). Pull Request resolved: https://github.com/pytorch/pytorch/pull/167828 Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42 --- test/dynamo/test_functions.py | 70 ++++++++++++++++++++++++++++++ torch/_dynamo/variables/builtin.py | 27 +++++++++++- torch/_dynamo/variables/iter.py | 22 +++++++--- 3 files changed, 112 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 840d4b32ab389..c0f40052c8d63 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4778,6 +4778,76 @@ def fn(x, ys, zs): with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys, zs[:1]) + def test_map_strict(self): + def fn(x, ys, zs): + x = x.clone() + for y, z in map(lambda a, b: (a, b), ys, zs, strict=True): + x += y * z + return x, map(lambda a, b: a + b, ys, zs, strict=True) + + opt_fn = torch.compile(fn, backend="eager") + nopython_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = torch.ones(3) + ys = [1.0, 2.0, 3.0] + zs = [2.0, 5.0, 8.0] + + if sys.version_info < (3, 14): + with self.assertRaises(TypeError): + opt_fn(x, ys, zs) + with self.assertRaises(TypeError): + nopython_fn(x, ys, zs) + return + + ref = fn(x, ys, zs) + res = opt_fn(x, ys, zs) + self.assertEqual(ref[0], res[0]) + self.assertEqual(list(ref[1]), list(res[1])) + self.assertIsInstance(res[1], map) + + # If nopython, should raise UserError + with self.assertRaisesRegex(torch._dynamo.exc.UserError, "map()"): + nopython_fn(x, ys[:1], zs) + + with self.assertRaisesRegex(torch._dynamo.exc.UserError, "map()"): + nopython_fn(x, ys, zs[:1]) + + # Should cause fallback if allow graph break + with self.assertRaisesRegex(ValueError, "map()"): + opt_fn(x, ys[:1], zs) + + with self.assertRaisesRegex(ValueError, "map()"): + opt_fn(x, ys, zs[:1]) + + # Check strict is set by testing a map returned from dynamo + opt_map_fn = torch.compile( + lambda ys, zs: map(lambda a, b: a + b, ys, zs, strict=True), backend="eager" + ) + strict_map_from_dynamo = opt_map_fn(ys[:1], zs) + with self.assertRaises(ValueError): + list(strict_map_from_dynamo) + + @unittest.skipIf(sys.version_info < (3, 14), "strict requires Python 3.14+") + def test_map_strict_with_graph_break(self): + def f(a): + a += 1 + + def g(x, y): + nonlocal a + a += 1 + return x + y + + m = map(g, [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], strict=True) + a += next(m) # won't graph break + torch._dynamo.graph_break() + a += next(m) # will graph break + return a + + cnts = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=cnts) + self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) + self.assertEqual(cnts.frame_count, 3) + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_gpu_current_device(self): def fn(x): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 8fdaefea56f89..9bd1bae080508 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -26,6 +26,7 @@ import logging import math import operator +import sys import types import typing import unittest @@ -2474,8 +2475,31 @@ def call_hasattr( return None def call_map( - self, tx: "InstructionTranslator", fn: VariableTracker, *seqs: VariableTracker + self, + tx: "InstructionTranslator", + fn: VariableTracker, + *seqs: VariableTracker, + **kwargs: VariableTracker, ) -> VariableTracker: + strict = ConstantVariable.create(False) + if kwargs: + if sys.version_info >= (3, 14): + if not (len(kwargs) == 1 and "strict" in kwargs): + raise_args_mismatch( + tx, + "map", + "1 kwargs (`strict`)", + f"{len(kwargs)} kwargs", + ) + strict = kwargs.pop("strict", ConstantVariable.create(False)) + else: + raise_args_mismatch( + tx, + "map", + "0 kwargs", + f"{len(kwargs)} kwargs", + ) + seq_list = [ seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq for seq in seqs @@ -2483,6 +2507,7 @@ def call_map( return variables.MapVariable( fn, seq_list, # type: ignore[arg-type] + strict=strict.as_python_constant(), mutation_type=ValueMutationNew(), ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index c111dca9f2d68..2689d5e094977 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,6 +14,7 @@ """ import itertools +import sys from collections.abc import Callable, Sequence from typing import Any, TYPE_CHECKING, Union @@ -522,12 +523,21 @@ def reconstruct(self, codegen: "PyCodegen") -> None: ) codegen(self.fn) self.reconstruct_items(codegen) - codegen.extend_output( - [ - create_build_tuple(len(self.iterables) + 1), - *create_call_function_ex(False, False), - ] - ) + codegen.append_output(create_build_tuple(len(self.iterables) + 1)) + if self.strict: + assert sys.version_info >= (3, 14), ( + "Unexpected bug: map(strict=True) requires Python 3.14+" + ) + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + *create_call_function_ex(True, False), + ] + ) + else: + codegen.extend_output(create_call_function_ex(False, False)) class FilterVariable(IteratorVariable): From 74fe26a1ebe32931783569f2e762e3c2c974901f Mon Sep 17 00:00:00 2001 From: arkadip-maitra Date: Wed, 3 Dec 2025 08:23:15 +0000 Subject: [PATCH 166/338] Fixes complex datatype handling in ddp (#166863) Fixes #158753 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166863 Approved by: https://github.com/wconstab --- test/distributed/test_c10d_nccl.py | 222 ++++++++++++++++++++++++ torch/csrc/distributed/c10d/reducer.cpp | 130 +++++++++++--- torch/csrc/distributed/c10d/reducer.hpp | 3 + 3 files changed, 327 insertions(+), 28 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 5b1b6c8925806..60deb3654df27 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -1997,6 +1997,228 @@ def _test_nccl_backend( process_group, devices, device_ids, multi_device, gradient_as_bucket_view ) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_ddp_complex_params_and_grads(self): + # test ddp with complex parameters and gradients + process_group = self._get_process_group() + device_id = gpus_for_rank(self.world_size)[self.rank][0] + device = torch.device(f"cuda:{device_id}") + + torch.manual_seed(42 + self.rank) + model = nn.Sequential( + nn.Linear(4, 8, dtype=torch.cfloat), + nn.Linear(8, 2, dtype=torch.cfloat), + ).to(device) + + torch.manual_seed(42 + self.rank) + ref_model = nn.Sequential( + nn.Linear(4, 8, dtype=torch.cfloat), + nn.Linear(8, 2, dtype=torch.cfloat), + ).to(device) + + # 0.001 forces tiny buckets, creating multiple buckets, stress-testing bucketing + ddp_model = DistributedDataParallel( + model, + device_ids=[device_id], + process_group=process_group, + bucket_cap_mb=0.001, + ) + + torch.manual_seed(100) + batch_size = 16 + input_dim = 4 + output_dim = 2 + + x = torch.randn(batch_size, input_dim, dtype=torch.cfloat, device=device) + y = torch.randn(batch_size, output_dim, dtype=torch.cfloat, device=device) + + optimizer_ddp = torch.optim.SGD(ddp_model.parameters(), lr=0.01) + optimizer_ref = torch.optim.SGD(ref_model.parameters(), lr=0.01) + + for iteration in range(5): + optimizer_ddp.zero_grad() + output_ddp = ddp_model(x) + loss_ddp = torch.mean(torch.abs(output_ddp - y) ** 2) + loss_ddp.backward() + + optimizer_ref.zero_grad() + with torch.no_grad(): + for p_ddp, p_ref in zip(ddp_model.parameters(), ref_model.parameters()): + p_ref.copy_(p_ddp) + + output_ref = ref_model(x) + loss_ref = torch.mean(torch.abs(output_ref - y) ** 2) + loss_ref.backward() + + for param in ref_model.parameters(): + if param.grad is not None: + dist.all_reduce( + param.grad.data, op=dist.ReduceOp.SUM, group=process_group + ) + param.grad.data /= self.world_size + + for name, (p_ddp, p_ref) in enumerate( + zip(ddp_model.parameters(), ref_model.parameters()) + ): + self.assertIsNotNone( + p_ddp.grad, + f"DDP gradient is None at iteration {iteration}, param {name}", + ) + + self.assertIsNotNone( + p_ref.grad, + f"Reference gradient is None at iteration {iteration}, param {name}", + ) + + self.assertTrue( + p_ddp.grad.is_complex(), + f"DDP gradient lost complex dtype at iteration {iteration}, param {name}", + ) + + self.assertTrue( + p_ref.grad.is_complex(), + f"Reference gradient lost complex dtype at iteration {iteration}, param {name}", + ) + + self.assertFalse( + torch.allclose(p_ddp.grad.imag, torch.zeros_like(p_ddp.grad.imag)), + f"DDP imaginary gradient is all zeros at iteration {iteration}, param {name}! " + f"This indicates the complex gradient bug.", + ) + + self.assertTrue( + torch.allclose( + p_ddp.grad.real, p_ref.grad.real, rtol=1e-5, atol=1e-5 + ), + f"Real gradient mismatch at iteration {iteration}, param {name}\n" + f"DDP real: {p_ddp.grad.real.mean():.6f}, " + f"Ref real: {p_ref.grad.real.mean():.6f}", + ) + + self.assertTrue( + torch.allclose( + p_ddp.grad.imag, p_ref.grad.imag, rtol=1e-5, atol=1e-5 + ), + f"Imaginary gradient mismatch at iteration {iteration}, param {name}\n" + f"DDP imag: {p_ddp.grad.imag.mean():.6f}, " + f"Ref imag: {p_ref.grad.imag.mean():.6f}", + ) + + optimizer_ddp.step() + optimizer_ref.step() + + for p_ddp, p_ref in zip(ddp_model.parameters(), ref_model.parameters()): + self.assertTrue( + torch.allclose(p_ddp, p_ref, rtol=1e-4, atol=1e-4), + "Final model parameters don't match after training", + ) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_ddp_mixed_real_and_complex_params(self): + # test ddp with mixed real and complex parameters and gradients + process_group = self._get_process_group() + device_id = gpus_for_rank(self.world_size)[self.rank][0] + device = torch.device(f"cuda:{device_id}") + + class MixedModule(nn.Module): + def __init__(self): + super().__init__() + self.complex_fc = nn.Linear(4, 4, dtype=torch.cfloat) + self.real_fc = nn.Linear(4, 4, dtype=torch.float32) + self.final_fc = nn.Linear(4, 2, dtype=torch.cfloat) + + def forward(self, x_complex, x_real): + complex_branch = self.complex_fc(x_complex) + real_branch = self.real_fc(x_real) + real_as_complex = torch.complex( + real_branch, torch.zeros_like(real_branch) + ) + return self.final_fc(complex_branch + real_as_complex) + + torch.manual_seed(42 + self.rank) + model = MixedModule().to(device) + ref_model = MixedModule().to(device) + + # 100 forces large bucket, forcing the BucketKey mechanism to segregate buckets, testing bucket segregation by dtype + ddp_model = DistributedDataParallel( + model, + device_ids=[device_id], + process_group=process_group, + bucket_cap_mb=100, + ) + + optimizer_ddp = torch.optim.SGD(ddp_model.parameters(), lr=0.01) + optimizer_ref = torch.optim.SGD(ref_model.parameters(), lr=0.01) + + torch.manual_seed(100) + x_complex = torch.randn(8, 4, dtype=torch.cfloat, device=device) + x_real = torch.randn(8, 4, dtype=torch.float32, device=device) + target = torch.randn(8, 2, dtype=torch.cfloat, device=device) + + for iteration in range(5): + optimizer_ddp.zero_grad() + loss_ddp = torch.mean(torch.abs(ddp_model(x_complex, x_real) - target) ** 2) + loss_ddp.backward() + + optimizer_ref.zero_grad() + with torch.no_grad(): + for p_ddp, p_ref in zip(ddp_model.parameters(), ref_model.parameters()): + p_ref.copy_(p_ddp) + loss_ref = torch.mean(torch.abs(ref_model(x_complex, x_real) - target) ** 2) + loss_ref.backward() + for param in ref_model.parameters(5): + if param.grad is not None and param.grad.is_floating_point(): + dist.all_reduce( + param.grad.data, + op=dist.ReduceOp.SUM, + group=process_group, + ) + param.grad.data /= self.world_size + + for name, (p_ddp, p_ref) in enumerate( + zip(ddp_model.parameters(), ref_model.parameters()) + ): + self.assertIsNotNone( + p_ddp.grad, + f"DDP gradient is None at iteration {iteration}, param {name}", + ) + self.assertIsNotNone( + p_ref.grad, + f"Reference gradient is None at iteration {iteration}, param {name}", + ) + + self.assertTrue( + p_ddp.grad.is_complex() == p_ref.grad.is_complex(), + f"Gradient dtype mismatch at iteration {iteration}, param {name}", + ) + + if p_ddp.grad.is_complex(): + self.assertFalse( + torch.allclose( + p_ddp.grad.imag, torch.zeros_like(p_ddp.grad.imag) + ), + f"DDP imaginary gradient is all zeros at iteration {iteration}, param {name}", + ) + self.assertTrue( + torch.allclose( + p_ddp.grad.real, p_ref.grad.real, rtol=1e-5, atol=1e-5 + ), + f"Real gradient mismatch at iteration {iteration}, param {name}", + ) + self.assertTrue( + torch.allclose( + p_ddp.grad.imag, p_ref.grad.imag, rtol=1e-5, atol=1e-5 + ), + f"Imaginary gradient mismatch at iteration {iteration}, param {name}", + ) + else: + self.assertTrue( + torch.allclose(p_ddp.grad, p_ref.grad, rtol=1e-5, atol=1e-5), + f"Real gradient mismatch at iteration {iteration}, param {name}", + ) + @requires_nccl() @skip_if_lt_x_gpu(2) def test_nccl_propagate_error_reason(self): diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index c4af19ef44209..d2bf2c6cf7f62 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -1151,6 +1151,9 @@ void Reducer::initialize_buckets( } if (!options.has_dtype()) { options = options.dtype(variable.dtype()); + if (variable.is_complex()) { + bucket.is_complex_bucket = true; + } } else { REDUCER_CHECK( variable.dtype() == options.dtype(), @@ -1201,6 +1204,10 @@ void Reducer::initialize_buckets( LOG(INFO) << "Reducer: comm-optimized memory allocator not found, using regular one"; bucket.gradients = at::empty({bucketSize}, options); + + if (bucket.is_complex_bucket) { + bucket.gradients = at::view_as_real(bucket.gradients).reshape({-1}); + } } // Note: "Gradient Layout Contract" @@ -1267,21 +1274,55 @@ void Reducer::initialize_bucket_views(Reducer::Bucket& bucket) { const auto offset = bucket.offsets[i]; const auto length = bucket.lengths[i]; - if (v.is_non_overlapping_and_dense()) { - // If the param's memory is dense, match its layout, anticipating - // the autograd engine (AccumulateGrad) will also create gradients - // matching its layout. - bucket.bucket_views_in.push_back( - gradients.as_strided(v.sizes(), v.strides(), offset)); + if (v.is_complex() && bucket.is_complex_bucket) { + const auto real_offset = offset * 2; + const auto real_length = length * 2; + + if (v.is_non_overlapping_and_dense()) { + auto complex_strides = v.strides(); + std::vector real_strides; + real_strides.reserve(complex_strides.size() + 1); + for (auto s : complex_strides) { + real_strides.push_back(s * 2); + } + real_strides.push_back(1); + + auto complex_sizes = v.sizes(); + std::vector real_sizes( + complex_sizes.begin(), complex_sizes.end()); + real_sizes.push_back(2); + + auto real_view = + gradients.as_strided(real_sizes, real_strides, real_offset); + auto complex_view = at::view_as_complex(real_view); + bucket.bucket_views_in.push_back(complex_view); + } else { + auto real_view = gradients.narrow( + 0, + static_cast(real_offset), + static_cast(real_length)); + auto complex_view = at::view_as_complex( + real_view.reshape({static_cast(length), 2})); + bucket.bucket_views_in.push_back(complex_view.view(v.sizes())); + } } else { - // Fall back to a C-style contiguous view, again anticipating - // AccumulateGrad will do the same when stashing grads for non-dense - // params. - bucket.bucket_views_in.push_back( - gradients - .narrow( - 0, static_cast(offset), static_cast(length)) - .view(v.sizes())); + if (v.is_non_overlapping_and_dense()) { + // If the param's memory is dense, match its layout, anticipating + // the autograd engine (AccumulateGrad) will also create gradients + // matching its layout. + bucket.bucket_views_in.push_back( + gradients.as_strided(v.sizes(), v.strides(), offset)); + } else { + // Fall back to a C-style contiguous view, again anticipating + // AccumulateGrad will do the same when stashing grads for non-dense + // params. + bucket.bucket_views_in.push_back(gradients + .narrow( + 0, + static_cast(offset), + static_cast(length)) + .view(v.sizes())); + } } // By default `bucket_views_out` and `bucket_views_in` are // essentially the same thing. @@ -1322,21 +1363,54 @@ void Reducer::populate_bucket_views_out( const auto offset = bucket.offsets[i]; const auto length = bucket.lengths[i]; - if (v.is_non_overlapping_and_dense()) { - // If the param's memory is dense, match its layout, anticipating - // the autograd engine (AccumulateGrad) will also create gradients - // matching its layout. - bucket.bucket_views_out.push_back( - tensor.as_strided(v.sizes(), v.strides(), offset)); + if (v.is_complex() && bucket.is_complex_bucket) { + const auto real_offset = offset * 2; + + if (v.is_non_overlapping_and_dense()) { + auto complex_strides = v.strides(); + std::vector real_strides; + real_strides.reserve(complex_strides.size() + 1); + for (auto s : complex_strides) { + real_strides.push_back(s * 2); + } + real_strides.push_back(1); + + auto complex_sizes = v.sizes(); + std::vector real_sizes( + complex_sizes.begin(), complex_sizes.end()); + real_sizes.push_back(2); + + auto real_view = + tensor.as_strided(real_sizes, real_strides, real_offset); + bucket.bucket_views_out.push_back(at::view_as_complex(real_view)); + } else { + const auto real_length = length * 2; + auto real_view = tensor.narrow( + 0, + static_cast(real_offset), + static_cast(real_length)); + auto complex_view = at::view_as_complex( + real_view.reshape({static_cast(length), 2})); + bucket.bucket_views_out.push_back(complex_view.view(v.sizes())); + } } else { - // Fall back to a C-style contiguous view, again anticipating - // AccumulateGrad will do the same when stashing grads for non-dense - // params. - bucket.bucket_views_out.push_back( - tensor - .narrow( - 0, static_cast(offset), static_cast(length)) - .view(v.sizes())); + if (v.is_non_overlapping_and_dense()) { + // If the param's memory is dense, match its layout, anticipating + // the autograd engine (AccumulateGrad) will also create gradients + // matching its layout. + bucket.bucket_views_out.push_back( + tensor.as_strided(v.sizes(), v.strides(), offset)); + } else { + // Fall back to a C-style contiguous view, again anticipating + // AccumulateGrad will do the same when stashing grads for non-dense + // params. + bucket.bucket_views_out.push_back(tensor + .narrow( + 0, + static_cast(offset), + static_cast(length)) + .view(v.sizes())); + } } } } diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index 4e5ed6a9a5c3f..37ea033445177 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -386,6 +386,9 @@ class TORCH_API Reducer { // If no hook is registered, a temporary vanilla allreduce hook is used. c10::intrusive_ptr future_work; + // if this bucket contains complex parameters + bool is_complex_bucket = false; + // If this bucket should expect a single sparse gradient // If `true`, then this implies that `bucket.variables.size() == 1`. bool expect_sparse_gradient = false; From 0bbbdf1750567a980634ad907a325357ba8ba8f2 Mon Sep 17 00:00:00 2001 From: Jason Xie Date: Wed, 3 Dec 2025 08:35:09 +0000 Subject: [PATCH 167/338] [Inductor] Fix Diode / exhaustive autotune crash on AMD (#169225) Summary: Two issues prevent using Diode w/ expanded search space on AMD: 1. matrix_instr_nonkdim=2 and kpack=2 causes triton compile to fail 2. GROUP_M=0 crashes AMD GPU (but not NV) repro: P2057901593 Test Plan: MODEL=822608598; SNAPSHOT=0 TORCHINDUCTOR_COMPILE_THREADS=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCH_COMPILE_DEBUG=1 HIP_VISIBLE_DEVICES=7 buck2 run -m rocm640 mode/opt-split-dwarf mode/inplace mode/amd-gpu -c fbcode.triton_backend=amd -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=mi300 caffe2/torch/fb/model_transform/fx2trt/packaging:generate_merge_net_file -- --action=generate --max-batch-size=3072 --preset_lowerer='relu_nan_to_num;disable_new_lowering_weights' --input-file=/home/${USER}/models/${MODEL}/${SNAPSHOT}/input.predictor.disagg.gpu.merge --output-file=/home/${USER}/models/${MODEL}/${SNAPSHOT}/fp8_amd_output_diode.predictor.disagg.gpu.merge --diode-config="{'top_k': 100, 'expand_search_space': True, 'discard_unpredicted': False}" --lower-backend aot_inductor --add_passes="use_matmul_lce_replace_normal_LCE,use_triton_dot_compress,use_matmul_fuse_lce_replace_first_LCE,use_contiguous_linear_reduction_replace_linear_reduction" --aot_inductor_config="{'max_autotune': True, 'comprehensive_padding': False, 'aot_inductor.use_runtime_constant_folding': True}" --hardware-type GFX942_X86 --node_replacement_dict="{'torch.nn.Linear':{'(3000+, 3000+)':'fp8_float_model_dynamic_quantization_rowwise_triton'}}" 2>&1 | tee ~/logs/lower_diode.log Differential Revision: D87963125 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169225 Approved by: https://github.com/coconutruben --- torch/_inductor/template_heuristics/triton.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index a9925db292a36..55c798922184a 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -1278,7 +1278,16 @@ def _prune_exhaustive_configs( configs: list[BaseConfig], dtype_size: int, ) -> list[BaseConfig]: - return configs + # these cause AMD compile to crash + pruned_configs = [ + c + for c in configs + if not ( + getattr(c, "matrix_instr_nonkdim", 0) == 2 + and getattr(c, "kpack", 0) == 2 + ) + ] + return pruned_configs def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: """ @@ -1329,6 +1338,9 @@ def _finalize_mm_configs( # Check if gemm specific arg exists - add to key if does group_m = getattr(conf, "group_m", None) + # AMD GPU crashes if group_m = 0 + if group_m is not None and group_m <= 0: + group_m = 8 if group_m is not None: key += (group_m,) From 5bf1cdf4755c54ef462b44cb8041b0a57311556b Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Mon, 1 Dec 2025 17:38:15 -0800 Subject: [PATCH 168/338] [DTensor] compute shape and offset for arbitrary _StridedShard (#168146) resolve https://github.com/pytorch/pytorch/issues/167859 for _StridedShard, compute_local_shape_and_global_offset was landed to consider fsdp2 + tp: (_StridedShard(0, split_factor=mesh.size(k)), Shard(0)). Need to extend it to arbitrary _StridedShard for example, `_StridedShard(dim=0, split_factor=batch_size), _StridedShard(dim=0, split_factor=batch_size * seq_len / device_mesh.size(0))` This PR ensure correct local shape for DTensor views with _StridedShard Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D87897203](https://our.internmc.facebook.com/intern/diff/D87897203) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168146 Approved by: https://github.com/wconstab --- test/distributed/tensor/test_utils.py | 471 +++++++++++------- torch/distributed/tensor/_api.py | 2 +- .../distributed/tensor/_ops/_common_rules.py | 5 +- torch/distributed/tensor/_ops/_matrix_ops.py | 2 +- torch/distributed/tensor/_sharding_prop.py | 2 +- torch/distributed/tensor/_utils.py | 238 ++++----- torch/distributed/tensor/placement_types.py | 51 +- 7 files changed, 452 insertions(+), 319 deletions(-) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 11b70c8554e52..5f3225d174cb2 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -16,7 +16,6 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._utils import ( _compute_local_shape_and_global_offset, - _explicit_order_placements, compute_global_tensor_info, compute_global_tensor_shape, compute_local_shape_and_global_offset, @@ -46,85 +45,6 @@ class LocalTest(TestCase): - def test_explicit_order_placements(self): - # mesh_shape: ShapeType, placements: Sequence[Placement] - test_cases = [ - { - "mesh_shape": [2, 4], - "placements": [Replicate(), Replicate()], - "ordered": [(0, Replicate()), (1, Replicate())], - }, - { - "mesh_shape": [3, 2], - "placements": [Shard(0), Replicate()], - "ordered": [(0, Shard(0)), (1, Replicate())], - }, - { - "mesh_shape": [2, 4], - "placements": [_StridedShard(0, split_factor=4), Shard(0)], - "ordered": [(1, Shard(0)), (0, Shard(0))], - }, - { - "mesh_shape": [2, 3, 4], - "placements": [Shard(0), _StridedShard(0, split_factor=4), Shard(0)], - "ordered": [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))], - }, - { - "mesh_shape": [2, 3, 4], - "placements": [ - _StridedShard(0, split_factor=12), - _StridedShard(0, split_factor=4), - Shard(0), - ], - "ordered": [(2, Shard(0)), (1, Shard(0)), (0, Shard(0))], - }, - ] - for test_case in test_cases: - actual = _explicit_order_placements( - test_case["mesh_shape"], test_case["placements"] - ) - expected = test_case["ordered"] - - self.assertEqual( - actual, - expected, - f"mesh_shape={test_case['mesh_shape']} placements={test_case['placements']}, output: {actual=}, {expected=}", - ) - - error_cases = [ - { - "mesh_shape": [2, 3, 4], - "placements": [Shard(0), _StridedShard(0, split_factor=3), Shard(0)], - "exception_type": RuntimeError, - "exception_text": "Can only convert _StridedShard to ordered Shard if split_factor", - }, - { - "mesh_shape": [2, 3, 4], - "placements": [ - _StridedShard(0, split_factor=3), - Shard(0), - Shard(0), - ], - "exception_type": NotImplementedError, - "exception_text": r"Strided sharding does not allow Shard\(\) to appear after the strided part has ended", - }, - { - "mesh_shape": [2, 3], - "placements": [ - Shard(0), - ], - "exception_type": RuntimeError, - "exception_text": "Expected one placement per mesh dim", - }, - ] - for test_case in error_cases: - with self.assertRaisesRegex( - test_case["exception_type"], test_case["exception_text"] - ): - _explicit_order_placements( - test_case["mesh_shape"], test_case["placements"] - ) - def test_compute_local_shape_and_global_offset_uneven(self): # This case is not only 'uneven' bug also has an empty shard # (e.g. most DP ranks have local shape 18,4096, one has 8,4096, one has 0,4096 @@ -151,6 +71,225 @@ def test_compute_local_shape_and_global_offset_uneven(self): self.assertEqual(local_shape, (expected_shard_size, 4096)) self.assertEqual(global_offset, (expected_shard_offset, 0)) + # S, S uneven without empty + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), Shard(0)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + + dp012_shard_size = 5 + if dp_rank in (0, 1, 2): + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 3 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 2 + else: + assert dp_rank == 3 + tp0_shard_size = 2 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 1 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # S, S uneven with empty + global_shape = (13, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), Shard(0)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + + dp012_shard_size = 4 + if dp_rank in (0, 1, 2): + tp0_shard_size = 2 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 2 + else: + assert dp_rank == 3 + tp0_shard_size = 1 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 1 + else: + assert tp_rank == 1 + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # SS, Shard + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [_StridedShard(0, split_factor=TP), Shard(0)] + TP_shard_size = int(global_shape[0] / TP) + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + expected_shard_size = 3 + expected_shard_offset = ( + tp_rank * TP_shard_size + expected_shard_size * dp_rank + ) + if dp_rank == 3: + expected_shard_size = 0 + expected_shard_offset = 18 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # SS, SS + global_shape = (39, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [ + _StridedShard(0, split_factor=3), + _StridedShard(0, split_factor=4), + ] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if dp_rank in (0, 1, 2): + tp0_shard_size = 8 + if tp_rank == 0: + expected_shard_offset = 4 * dp_rank + expected_shard_size = tp0_shard_size + else: + assert tp_rank == 1 + expected_shard_offset = 4 * dp_rank + 2 + expected_shard_size = 4 + else: + assert dp_rank == 3 + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = 4 * dp_rank + expected_shard_size = 3 + else: + assert tp_rank == 1 + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # (Shard, SS) + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), _StridedShard(0, split_factor=2)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if dp_rank in (0, 1, 2): + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = 5 * dp_rank + expected_shard_size = tp0_shard_size + else: + assert tp_rank == 1 + expected_shard_offset = 5 * dp_rank + 2 + expected_shard_size = 2 + else: + assert dp_rank == 3 + if tp_rank == 0: + expected_shard_offset = 5 * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = 5 * dp_rank + 1 + expected_shard_size = 1 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # (Shard, SS, Shard) + global_shape = (39, 2) + mesh0, mesh1, mesh2 = 4, 2, 3 + mesh_shape = (mesh0, mesh1, mesh2) + placements = [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] + for my_coordinate in itertools.product( + range(mesh0), range(mesh1), range(mesh2) + ): + mesh0_rank, mesh1_rank, mesh2_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if mesh0_rank in (0, 1, 2): + if mesh1_rank == 0: + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 2 + expected_shard_size = 2 + else: + expected_shard_offset = 10 * mesh0_rank + 6 + expected_shard_size = 2 + else: + assert mesh1_rank == 1 + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + 3 + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 8 + expected_shard_size = 2 + else: + assert mesh2_rank == 2 + expected_shard_size = 0 + expected_shard_offset = global_shape[0] + else: + assert mesh0_rank == 3 + if mesh1_rank == 0: + if mesh2_rank in (0, 1): + expected_shard_offset = 10 * mesh0_rank + 2 * mesh2_rank + expected_shard_size = 2 + else: + assert mesh2_rank == 2 + expected_shard_offset = 10 * mesh0_rank + 6 + expected_shard_size = 1 + else: + assert mesh1_rank == 1 + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + 3 + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 7 + expected_shard_size = 2 + else: + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + class UtilTest(DTensorTestBase): @property @@ -292,6 +431,78 @@ def test_compute_local_shape_and_global_offset_2D(self): global_tensor[dim0_start:dim0_end, dim1_start:dim1_end], ) + @with_comms + def test_compute_local_shape_and_global_offset_3D(self): + global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) + mesh_size_0 = 2 + mesh_size_1 = 2 + mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + Shard(0), + Shard(0), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank = global_mesh.get_coordinate() + self.assertEqual(local_shape, [2, 2 * self.world_size]) + self.assertEqual( + global_offset, (4 * mesh0_rank + 8 * mesh1_rank + 2 * mesh2_rank, 0) + ) + + @with_comms + def test_compute_local_shape_and_global_offset_4D(self): + global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) + mesh_size_0 = 1 + mesh_size_1 = 2 + mesh_size_2 = 2 + mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(1), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() + self.assertEqual( + local_shape, (2 * mesh_size_1 * mesh_size_3, 2 * mesh_size_0 * mesh_size_2) + ) + self.assertEqual( + global_offset, + (8 * mesh2_rank + 4 * mesh0_rank, 8 * mesh3_rank + 4 * mesh1_rank), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(0), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() + self.assertEqual( + local_shape, (2 * mesh_size_1, 2 * mesh_size_2 * mesh_size_3 * mesh_size_0) + ) + self.assertEqual( + global_offset, + (8 * mesh2_rank + 0 * mesh0_rank + 4 * mesh3_rank, 4 * mesh1_rank), + ) + @with_comms def test_fsdp_tp_meta_compute(self): # FSDP + TP sharding @@ -362,106 +573,6 @@ def test_hsdp_tp_meta_compute(self): self.assertEqual(local_shape, expected_local_shape) self.assertEqual(global_offset, expected_global_offset) - # TODO: remove this test once we support general meta compute on strided sharding - @with_comms - def test_strided_sharding_assumption_in_meta_compute(self): - # current ``compute_local_shape_and_global_offset`` does not allow Shard(i) - # placement to appear after the strided sharding part has ended. This test - # check that ``compute_local_shape_and_global_offset`` does not allow placements - # that violate the assumption and does not forbid the allowed ones. - - # Test 0: 2-D mesh - mesh_size_0 = 2 - mesh_size_1 = self.world_size // mesh_size_0 - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1), - mesh_dim_names=("mesh-0", "mesh-1"), - ) - global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) - - for shard_dim in [0, 1]: - placements = [ - _StridedShard(shard_dim, split_factor=mesh_size_1), - Shard(shard_dim), - ] - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # Test 1: 3-D mesh - mesh_size_0 = 2 - mesh_size_1 = 2 - mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), - ) - - # legal placements: Shard() appear after the strided part but it's on another - # tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - Shard(0), - Shard(1), - ] - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # illegal placements: Shard() appear after the strided part and it's on the - # same tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - Shard(0), - Shard(0), - ] - with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # Test 2: 4-D mesh - mesh_size_0 = 1 - mesh_size_1 = 2 - mesh_size_2 = 2 - mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), - ) - # legal placements: Shard() appear after the strided part but it's on another - # tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(1), - ] - local_shape, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - expected_local_shape = ( - 2 * mesh_size_1 * mesh_size_3, - 2 * mesh_size_0 * mesh_size_2, - ) - self.assertEqual(local_shape, expected_local_shape) - - # illegal placements: Shard() appear after the strided part and it's on the - # same tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(0), - ] - with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - class UtilSingleDeviceTest(TestCase): def test_compute_global_tensor_info_unsupported_placement(self): diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 070d8625f50e0..f10a17a3154bd 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -1071,7 +1071,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] # get local tensor shape local_shape, _ = compute_local_shape_and_global_offset( - size, device_mesh, placements + size, device_mesh, placements, skip_offset=True ) # initialize the local tensor diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 88a6e4298d246..2312f8e56c554 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -168,7 +168,10 @@ def merge_sharding(dim: str, a: int, b: int) -> int: assert input_spec.tensor_meta is not None global_shape = input_spec.tensor_meta.shape local_shape, _ = compute_local_shape_and_global_offset( - global_shape, input_spec.mesh, input_spec.placements + global_shape, + input_spec.mesh, + input_spec.placements, + skip_offset=True, ) cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) # pyrefly: ignore [bad-argument-type] diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 5911e4cef1e7d..f633088e946ed 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -1035,7 +1035,7 @@ def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta: meta: TensorMeta = spec.output_specs.tensor_meta local_stride = compute_local_stride(meta.stride, mesh, placements) local_shape, _ = compute_local_shape_and_global_offset( - meta.shape, mesh, placements + meta.shape, mesh, placements, skip_offset=True ) return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index f3cbb90dc8f04..11eb7a8ce667b 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -658,7 +658,7 @@ def _adjust_shape_and_stride_args( # adjust shape to be the same as that of the _local_tensor # of the DTensor input arg at index 0, which is inferred expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( - out_tensor_meta.shape, spec.mesh, spec.placements + out_tensor_meta.shape, spec.mesh, spec.placements, skip_offset=True ) # adjust the stride arg for aten.new_empty_strided.default diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index adf0e8e8069a6..aa65dbc08529f 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,12 +1,12 @@ import threading -from collections import defaultdict from collections.abc import Sequence -from typing import cast +from typing import Any, cast, Optional import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._api as dtensor from torch._prims_common import ShapeType +from torch.distributed._local_tensor import maybe_run_for_local_tensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._collective_utils import redistribute_cost from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -17,7 +17,6 @@ Replicate, Shard, ) -from torch.utils._typing_utils import not_none class ExplicitRedistributionContext: @@ -56,61 +55,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): ExplicitRedistributionContext._local._active = self._prev -def _explicit_order_placements( - mesh_shape: ShapeType, placements: Sequence[Placement] -) -> Sequence[tuple[int, Placement]]: - """ - Replace Strided Shards with regular shards in an adjusted order. - - Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order. - - ex. - [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] -> - [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))] - - """ - if not len(placements) == len(mesh_shape): - raise RuntimeError( - "Expected one placement per mesh dim, " - f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims." - ) - ordered = [] - deferred_strided_placements = defaultdict(list) - strided_part_ended_for_dim = set() - for mesh_dim, p in enumerate(placements): - if isinstance(p, _StridedShard): - # validate the stride is the correct multiple of the meshdim and the earlier shard - deferred_strided_placements[p.dim].append((mesh_dim, p)) - - else: - ordered.append((mesh_dim, p)) - if isinstance(p, Shard): - if p.dim in strided_part_ended_for_dim: - raise NotImplementedError( - f"Strided sharding does not allow Shard() to appear after " - f"the strided part has ended. {p} at mesh dim {mesh_dim} in " - f"{placements} violates this assumption." - ) - - if p.dim in deferred_strided_placements: - strided_part_ended_for_dim.add(p.dim) - strided_placements = deferred_strided_placements.pop(p.dim) - aggregate_size = mesh_shape[mesh_dim] - while len(strided_placements) > 0: - strided_mesh_dim, strided = strided_placements.pop() - if not strided.split_factor == aggregate_size: - raise RuntimeError( - f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})" - f" == aggregate mesh size ({aggregate_size})" - ) - aggregate_size *= mesh_shape[strided_mesh_dim] - ordered.append((strided_mesh_dim, Shard(p.dim))) - - return ordered - - def compute_local_shape_and_global_offset( - global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + skip_offset: bool = False, ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Compute the local tensor shape and the global offsets into the original tensor @@ -143,24 +92,68 @@ def compute_local_shape_and_global_offset( global_shape (ShapeType): The global shape of the DTensor. mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on. placements (Sequence[:class:`Placement`]]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. Return: local_shape: the shape of the DTensor's _local_tensor on the current rank. global_offset: a tuple of offsets for each dimension of the global tensor shape, - identifying how this shard fits into the global tensor in each dimension. + identifying how this shard fits into the global tensor in each dimension. If + skip_offset is True, this will be an empty tuple. """ return _compute_local_shape_and_global_offset( - global_shape, mesh.shape, mesh.get_coordinate(), placements + global_shape, mesh.shape, mesh.get_coordinate(), placements, skip_offset ) +@maybe_run_for_local_tensor +def _get_shard_size_and_offsets( + curr_local_size: int, + mesh_dim_size: int, + rank: int, + placement: Shard | _StridedShard, + previous_offsets, + zero_global_offset: int, + skip_offset: bool, +) -> tuple[int, Optional[torch.Tensor]]: + kwargs: dict[str, Any] = { + "curr_local_size": curr_local_size, + "num_chunks": mesh_dim_size, + "rank": rank, + } + if isinstance(placement, _StridedShard): + kwargs["return_first_offset"] = False + shard_size, shard_offsets = placement._local_shard_size_and_offset(**kwargs) + if skip_offset: + return shard_size, None + if shard_size == 0: + return shard_size, torch.arange(zero_global_offset, zero_global_offset + 1) + if isinstance(placement, Shard) and not isinstance(placement, _StridedShard): + assert isinstance(shard_offsets, int) + index = torch.arange(shard_offsets, shard_offsets + shard_size) + else: + assert isinstance(shard_offsets, list) + index = torch.tensor(shard_offsets) + if previous_offsets is None: + return shard_size, index + else: + return shard_size, previous_offsets[index] + + +@maybe_run_for_local_tensor +def _get_first_offset(offsets: torch.Tensor) -> int: + return int(offsets[0]) + + # accept 'plain data types' to enable simpler unit testing without creating device mesh def _compute_local_shape_and_global_offset( global_shape: ShapeType, mesh_shape: ShapeType, my_coordinate: list[int] | None, placements: Sequence[Placement], + skip_offset: bool = False, ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Suppose you have a full tensor with size global_shape, and you have sharded @@ -176,85 +169,68 @@ def _compute_local_shape_and_global_offset( This function is fairly simple if your tensor is evenly sharded; the complication is around uneven splits. There is also some complication for handling StridedShard, which changes the order you should apply sharding. + + Args: + global_shape (ShapeType): The global shape of the tensor. + mesh_shape (ShapeType): The shape of the device mesh. + my_coordinate (Optional[list[int]]): The coordinate of the current rank in the device mesh. + placements (Sequence[Placement]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. + + Returns: + tuple: A tuple containing: + - local_shape (tuple[int, ...]): The shape of the local shard on the current rank. + - global_offset (tuple[int, ...]): The offsets for each dimension identifying where + this shard begins in the global tensor. If skip_offset is True, this will be an + empty tuple. """ + empty_offset = () if my_coordinate is None: # if rank not in the mesh, return empty offset - return ((0,), ()) - - # StridedShard implies a non-standard order to apply shards; get the - # correct order to start applying splits - ordered_placements = _explicit_order_placements(mesh_shape, placements) + return ((0,), empty_offset) local_shape = list(global_shape) - # We'll compute the data for where the shard begins on a per-dim basis. - # However, a single dim can be sharded multiple times, so we will end up - # doing a Sum(size*stride) like computation to determine the location of our - # shard for each of the shardings on that dim. + # Perform shard from left to right. For example, + # global tensor: [0, 1, 2, 3, 4, 5, 6, 7] + # placements: S(0), SS(0, split_factor=2) + # mesh_shape: (2, 2) + # After S(0), shard_dim_to_global_offsets are + # {0: [0, 1, 2, 3]} on my_coordinate [0, 0] [0, 1] + # {0: [4, 5, 6, 7]} on my_coordinate [1, 0] [1, 1] + # After SS(0, split_factor=2), shard_dim_to_global_offsets are + # {0: [0, 2]} on my_coordinate [0, 0] + # {0: [1, 3]} on my_coordinate [0, 1] + # {0: [4, 6]} on my_coordinate [1, 0] + # {0: [5, 7]} on my_coordinate [1, 1] + shard_dim_to_global_offsets = {} + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, (Shard, _StridedShard)): + continue + shard_dim = placement.dim + zero_global_offset = global_shape[shard_dim] + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) + previous_offsets = shard_dim_to_global_offsets.get(shard_dim) + shard_size, shard_offsets = _get_shard_size_and_offsets( + local_shape[shard_dim], + mesh_shape[mesh_dim], + my_coordinate[mesh_dim], + placement, + previous_offsets, + zero_global_offset, + skip_offset, + ) + local_shape[shard_dim] = shard_size + shard_dim_to_global_offsets[shard_dim] = shard_offsets + if skip_offset: + return tuple(local_shape), empty_offset global_offset = [0] * len(global_shape) - - for mesh_dim, placement in ordered_placements: - mesh_dim_size = mesh_shape[mesh_dim] - if isinstance(placement, Shard): - shard_dim = placement.dim - assert shard_dim < len(local_shape), ( - f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" - ) - shard_size, shard_offset = placement._local_shard_size_and_offset( - local_shape[shard_dim], - mesh_dim_size, - my_coordinate[mesh_dim], - ) - - local_shape[shard_dim] = shard_size - - shard_global_offset = global_offset[shard_dim] + not_none(shard_offset) - - zero_global_offset = global_shape[shard_dim] - if isinstance(shard_global_offset, torch.SymInt) and not isinstance( - zero_global_offset, torch.SymInt - ): - zero_global_offset = torch.SymInt(zero_global_offset) - - global_offset[shard_dim] = torch.sym_ite( - shard_size == 0, - # Special case to fill in a standardized non-garbage value for - # the global_offset of zero-sized shards. This value is out - # of bounds of the tensor, so it won't conflict with any real - # offsets. DCP may rely on this value to de-duplicate shards. - # Note that you can end up with zero-size shards that are - # still otherwise in bounds for the tensor (TODO: give an - # example). - zero_global_offset, - # As we successively shard the same dimension, we keep - # advancing our pointer beyond our original offset until we - # get to the final chunk start. - shard_global_offset, - ) - - # NOTE: the offset compute relies on the local shard index and it has no - # problem when strided sharding is not present. To correctly compute, we assume - # that the ``_StridedShard.split_factor`` field encodes how many partitions - # each local tensor will be further split into when sharding on higher mesh - # dimensions. However, this number is only correct if the DTensor is not - # sharded after the strided sharding completes. For example, - # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements - # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on - # device mesh dim-2, and last on mesh dim-1. We define the - # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding - # part because strided sharding happens on mesh dim-1 and it was caused by - # the fact that sharding on dim-2 occurred ahead. In this case, there's no - # further sharding after this strided sharding part and ``split_factor`` - # correctly encodes the number. Another example is - # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's - # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh - # dim-2. This violates our assumption that no further sharding shall occur - # after the strided sharding part and ``split_factor`` won't correctly - # encode the number of further split. So far, the only case where _StridedShard - # placement would appear is FSDP2 + TP on 2D mesh and the above case could only - # happen on mesh of 3 or more dimensions. - # TODO: change this function to correctly address this. - # TODO: this logic can be applied to contiguous sharding as well + for shard_dim, global_offsets in shard_dim_to_global_offsets.items(): + global_offset[shard_dim] = _get_first_offset(global_offsets) return tuple(local_shape), tuple(global_offset) diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index a9f253c177ef2..590ec80b8f009 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -684,12 +684,49 @@ def _to_replicate_tensor( def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int: return len(sharded_indices[rank]) - def _local_shard_size_and_offset( + # delete pyre-ignore once separating _StridedShard from Shard + def _local_shard_size_and_offset( # pyre-ignore[bad-override] self, curr_local_size: int, num_chunks: int, rank: int, - ) -> tuple[int, int | None]: + return_first_offset: bool = True, + ) -> tuple[int, list[int]]: + return _StridedShard.local_shard_size_and_offset( + self, curr_local_size, num_chunks, rank, return_first_offset + ) + + @staticmethod + @maybe_run_for_local_tensor + def local_shard_size_and_offset( # pyre-ignore[bad-override] + self, + curr_local_size: int, + num_chunks: int, + rank: int, + return_first_offset: bool = True, + ) -> tuple[int, list[int] | int]: + """ + Compute the local shard size and offset(s) for a _StridedShard placement. + + Unlike the regular Shard placement which produces contiguous offsets, _StridedShard + produces non-contiguous (strided) offsets due to the right-to-left sharding semantics. + This method computes the actual indices that belong to the local shard. + + Args: + self (_StridedShard): The _StridedShard placement instance. + curr_local_size (int): The current size of the tensor dimension to be sharded. + num_chunks (int): Number of chunks to split the dimension into (typically the mesh dimension size). + rank (int): The rank index to compute the shard for. + return_first_offset (bool): If True, return only the first offset as an int. If False, + return all offsets as a list. Defaults to True. + + Returns: + tuple: A tuple containing: + - local_shard_size (int): The number of elements in the local shard for this rank. + - offset (int | list[int]): If return_first_offset is True, returns the first offset + as an int. If False or if the shard size is 0, returns a list of all offsets + (which may be empty for empty shards). + """ # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [curr_local_size] @@ -707,9 +744,15 @@ def _local_shard_size_and_offset( sharded_indices = [shard.view(-1) for shard in sharded_indices] local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank) + if local_shard_size > 0: + offsets = sharded_indices[rank].tolist() + else: + offsets = [] + + if return_first_offset and len(offsets) > 0: + offsets = offsets[0] - # offsets from _StridedShard is never used - return local_shard_size, None + return local_shard_size, offsets class Replicate(torch._C._distributed.Replicate): From 47b28ddf7bd74b50fa93b307a7d3b183a6d77f54 Mon Sep 17 00:00:00 2001 From: Ali Raza Date: Wed, 3 Dec 2025 09:43:37 +0000 Subject: [PATCH 169/338] Eliminate GPU to CPU sync in linalg.eig for CUDA tensors (#168283) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a CUDA kernel for linalg_eig_make_complex_eigenvectors, eliminating unnecessary GPU→CPU→GPU synchronization when computing eigenvalues and eigenvectors of real-valued CUDA tensors with complex eigenvalues. When torch.linalg.eig() is called on a real-valued CUDA tensor that produces complex eigenvalues, the implementation previously transferred eigenvectors to CPU for decoding, then transferred them back to GPU. This PR keeps all operations on the GPU, improving performance by ~8-10%. ## PR Overview This PR makes linalg_eig_make_complex_eigenvectors dispatchable and adds a CUDA kernel implementation. The changes span 4 commits: **Commit 1: Add dispatch infrastructure** - Declares DECLARE_DISPATCH stub for linalg_eig_make_complex_eigenvectors - File: aten/src/ATen/native/BatchLinearAlgebra.h **Commit 2: Refactor for dispatch** - Moves CPU implementation to dispatch system - Removes GPU-to-CPU synchronization from call site - Preserves existing CPU implementation exactly (rename only) - Files: aten/src/ATen/native/BatchLinearAlgebra.cpp, BatchLinearAlgebraKernel.cpp **Commit 3: Add CUDA kernel** - Implements parallel CUDA kernel for eigenvector decoding - Parallelizes across batch, eigenvectors, and elements - Uses AT_DISPATCH_V2 for float32/float64 support - File: aten/src/ATen/native/cuda/BatchLinearAlgebraEig.cu **Commit 4: Add comprehensive test** - Tests CUDA kernel with rotation matrices (known analytical eigenvalues) - Validates fundamental identity: A@v = λv - Covers complex eigenvalues, real eigenvalues, and batched operations - File: test/test_linalg.py Now, for CUDA tensors, the entire pipeline runs on GPU with no host synchronization. ## Performance Results Benchmarked on NVIDIA H200 with CUDA 12.8 using block-diagonal rotation matrices (which guarantee complex eigenvalues, triggering the eigenvector decoding path). Timing method: Wall-clock with GPU synchronization, 20 iterations after 5 warmup runs. Configuration: batch_size × matrix_rows × matrix_cols (e.g., 1×100×100 = single 100×100 matrix, 10×100×100 = batch of 10 matrices of size 100×100) | Configuration | Before (ms) | After (ms) | Improvement | Speedup | |----------------|-------------|------------|-------------|---------| | 1×100×100 | 11.459 | 10.506 | 0.953 ms | 8.3% | | 1×500×500 | 58.677 | 52.918 | 5.759 ms | 9.8% | | 1×1000×1000 | 123.130 | 110.919 | 12.211 ms | 9.9% | | 10×100×100 | 112.885 | 103.173 | 9.712 ms | 8.6% | | 100×100×100 | 1109.828 | 1024.709 | 85.119 ms | 7.7% | **Average speedup: ~9% for the entire linalg.eig() operation** Fixes #167105 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168283 Approved by: https://github.com/lezcano --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 80 +++--------- aten/src/ATen/native/BatchLinearAlgebra.h | 11 ++ .../ATen/native/BatchLinearAlgebraKernel.cpp | 61 +++++++++ .../ATen/native/cuda/BatchLinearAlgebraEig.cu | 119 ++++++++++++++++++ test/test_linalg.py | 85 +++++++++++++ 5 files changed, 295 insertions(+), 61 deletions(-) create mode 100644 aten/src/ATen/native/cuda/BatchLinearAlgebraEig.cu diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 8ebf50e913a75..40eaa6463de19 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2857,61 +2857,24 @@ Tensor& linalg_eigvalsh_out(const Tensor& A, std::string_view uplo, Tensor& L) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// This function returns complex-valued eigenvectors that is obtained from LAPACK GEEV's real-valued output -// This function is also used for the MAGMA path because intermediate MAGMA's results live on CPU -template -static void linalg_eig_make_complex_eigenvectors_impl(Tensor& result, const Tensor& complex_values, const Tensor& real_vectors) { - // From GEEV documentation: - // Complex conjugate pairs of eigenvalues appear consecutively with the eigenvalue having the positive imaginary part first - // If the j-th eigenvalue is real, then v(j) = VR(:,j), the j-th column of VR. - // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, then v(j) = VR(:,j) + i*VR(:,j+1) and v(j+1) = VR(:,j) - i*VR(:,j+1). - - auto batch_size = batchCount(real_vectors); - auto n = real_vectors.size(-1); - auto matrix_stride = matrixStride(real_vectors); - - auto result_data = result.data_ptr>(); - auto real_vectors_data = real_vectors.const_data_ptr(); - auto values_data = complex_values.const_data_ptr>(); - - for (auto b = decltype(batch_size){0}; b < batch_size; b++) { - const scalar_t* vecs = &real_vectors_data[b * matrix_stride]; - c10::complex* res = &result_data[b * matrix_stride]; - const c10::complex* vals = &values_data[b * n]; - for (auto j = decltype(n){0}; j < n; j++) { - if (vals[j].imag() == 0.0) { // eigenvalue is real, then v(j) = VR(:,j) - for (auto i = decltype(n){0}; i < n; i++) { - res[j * n + i] = c10::complex(vecs[j * n + i], 0); - } - } else { - for (auto i = decltype(n){0}; i < n; i++) { - res[j * n + i] = c10::complex(vecs[j * n + i], vecs[(j+1) * n + i]); // v(j) = VR(:,j) + i*VR(:,j+1) - res[(j+1) * n + i] = c10::complex(vecs[j * n + i], -vecs[(j+1) * n + i]); // v(j+1) = VR(:,j) - i*VR(:,j+1) - } - j++; - } - } - } -} +DEFINE_DISPATCH(linalg_eig_make_complex_eigenvectors_stub); -static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, const Tensor& complex_values, const Tensor& real_vectors) { - // These asserts make explicit the requirements on tensors for 'linalg_eig_make_complex_eigenvectors_impl' - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.device() == at::kCPU); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.device() == at::kCPU); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.device() == at::kCPU); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.is_complex()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_complex()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.is_floating_point()); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.mT().is_contiguous()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_contiguous()); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.mT().is_contiguous()); +// Converts LAPACK's real-valued eigenvector encoding to complex eigenvectors. +// This function dispatches to device-specific implementations (CPU or CUDA) based +// on the device type of the input tensors. +void linalg_eig_make_complex_eigenvectors(const Tensor& complex_vectors, const Tensor& complex_values, const Tensor& real_vectors) { + // Device consistency checks + TORCH_CHECK( + complex_vectors.device() == complex_values.device() && + complex_vectors.device() == real_vectors.device(), + "linalg_eig_make_complex_eigenvectors: all tensors must be on the same device"); - AT_DISPATCH_FLOATING_TYPES(real_vectors.scalar_type(), "linalg_eig_make_complex_vector", [&]{ - linalg_eig_make_complex_eigenvectors_impl(complex_vectors, complex_values, real_vectors); - }); - return complex_vectors; + // Dispatch to device-specific implementation + linalg_eig_make_complex_eigenvectors_stub( + complex_vectors.device().type(), + complex_vectors, + complex_values, + real_vectors); } DEFINE_DISPATCH(linalg_eig_stub); @@ -3006,14 +2969,9 @@ static std::tuple linalg_eig_out_info(const Tensor& input, Ten } if (compute_eigenvectors) { if (vectors.is_complex()) { - // We move to the CPU because linalg_eig_make_complex_eigenvectors requires it. - // Performance note: this function could be implemented via a TensorIterator, - // which would avoid an explicit host-device synchronization. - auto vectors_cpu = vectors.cpu(); - auto values_cpu = values.cpu(); - auto maybe_complex_vectors_cpu = maybe_complex_vectors.cpu(); - vectors_cpu = linalg_eig_make_complex_eigenvectors(vectors_cpu, values_cpu, maybe_complex_vectors_cpu); - vectors.copy_(vectors_cpu); + // Decode LAPACK's real eigenvector format into complex eigenvectors + // This now dispatches to device-specific implementations (CPU/CUDA) + linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors); } else { TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.") } diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index 1b8ce2bdf5417..577bdf000aacf 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -236,6 +236,17 @@ using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/ DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub) +// Converts LAPACK's real-valued eigenvector encoding to complex eigenvectors +TORCH_API void linalg_eig_make_complex_eigenvectors( + const Tensor& complex_vectors, + const Tensor& complex_values, + const Tensor& real_vectors); + +DECLARE_DISPATCH( + void(*)(const Tensor&, const Tensor&, const Tensor&), + linalg_eig_make_complex_eigenvectors_stub) + + using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/); DECLARE_DISPATCH(geqrf_fn, geqrf_stub) diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index fdc0c09124978..bba7a61aeb5f6 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -136,6 +137,59 @@ Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper) return result; } +// This function returns complex-valued eigenvectors that is obtained from LAPACK GEEV's real-valued output +// This function is also used for the MAGMA path because intermediate MAGMA's results live on CPU +template +static void linalg_eig_make_complex_eigenvectors_cpu_impl(const Tensor& result, const Tensor& complex_values, const Tensor& real_vectors) { + // From GEEV documentation: + // Complex conjugate pairs of eigenvalues appear consecutively with the eigenvalue having the positive imaginary part first + // If the j-th eigenvalue is real, then v(j) = VR(:,j), the j-th column of VR. + // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, then v(j) = VR(:,j) + i*VR(:,j+1) and v(j+1) = VR(:,j) - i*VR(:,j+1). + + auto batch_size = batchCount(real_vectors); + auto n = real_vectors.size(-1); + auto matrix_stride = matrixStride(real_vectors); + + auto result_data = result.data_ptr>(); + auto real_vectors_data = real_vectors.const_data_ptr(); + auto values_data = complex_values.const_data_ptr>(); + + for (auto b = decltype(batch_size){0}; b < batch_size; b++) { + const scalar_t* vecs = &real_vectors_data[b * matrix_stride]; + c10::complex* res = &result_data[b * matrix_stride]; + const c10::complex* vals = &values_data[b * n]; + for (auto j = decltype(n){0}; j < n; j++) { + if (vals[j].imag() == 0.0) { // eigenvalue is real, then v(j) = VR(:,j) + for (auto i = decltype(n){0}; i < n; i++) { + res[j * n + i] = c10::complex(vecs[j * n + i], 0); + } + } else { + for (auto i = decltype(n){0}; i < n; i++) { + res[j * n + i] = c10::complex(vecs[j * n + i], vecs[(j+1) * n + i]); // v(j) = VR(:,j) + i*VR(:,j+1) + res[(j+1) * n + i] = c10::complex(vecs[j * n + i], -vecs[(j+1) * n + i]); // v(j+1) = VR(:,j) - i*VR(:,j+1) + } + j++; + } + } + } +} + +// CPU dispatch kernel +void linalg_eig_make_complex_eigenvectors_cpu(const Tensor& complex_vectors, const Tensor& complex_values, const Tensor& real_vectors) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.mT().is_contiguous()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_contiguous()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.mT().is_contiguous()); + + AT_DISPATCH_V2( + real_vectors.scalar_type(), + "linalg_eig_make_complex_eigenvectors_cpu", + AT_WRAP([&] { + linalg_eig_make_complex_eigenvectors_cpu_impl( + complex_vectors, complex_values, real_vectors); + }), + AT_EXPAND(AT_FLOATING_TYPES)); +} + /* LAPACK query functions return workspace size as floating point value, which means that it might not be accurately represented if it's size exceed mantissa of the @@ -1166,6 +1220,13 @@ REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl) +REGISTER_ARCH_DISPATCH(linalg_eig_make_complex_eigenvectors_stub, DEFAULT, &linalg_eig_make_complex_eigenvectors_cpu) +REGISTER_AVX512_DISPATCH(linalg_eig_make_complex_eigenvectors_stub, &linalg_eig_make_complex_eigenvectors_cpu) +REGISTER_AVX2_DISPATCH(linalg_eig_make_complex_eigenvectors_stub, &linalg_eig_make_complex_eigenvectors_cpu) +REGISTER_VSX_DISPATCH(linalg_eig_make_complex_eigenvectors_stub, &linalg_eig_make_complex_eigenvectors_cpu) +REGISTER_ZVECTOR_DISPATCH(linalg_eig_make_complex_eigenvectors_stub, &linalg_eig_make_complex_eigenvectors_cpu) +REGISTER_SVE256_DISPATCH(linalg_eig_make_complex_eigenvectors_stub, &linalg_eig_make_complex_eigenvectors_cpu) + REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel) REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraEig.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraEig.cu new file mode 100644 index 0000000000000..3be4b8d953361 --- /dev/null +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraEig.cu @@ -0,0 +1,119 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS + +#include +#include +#include +#include +#include + +namespace at::native { + +namespace { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig_make_complex_eigenvectors ~~~~~~~~~~~~~~~~~~~~~~~ + +// Processes all columns in parallel. For complex conjugate pairs, each thread +// reads from neighboring columns but writes only to its own column. +template +__global__ void linalg_eig_make_complex_eigenvectors_kernel( + c10::complex* __restrict__ result, + const c10::complex* __restrict__ eigenvalues, + const scalar_t* __restrict__ vectors, + const int64_t batch_size, + const int64_t n, + const int64_t matrix_stride) { + + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total_elements = batch_size * n * n; + + if (idx >= total_elements) return; + + const int64_t batch_idx = idx / (n * n); + const int64_t local_idx = idx % (n * n); + const int64_t col = local_idx / n; + const int64_t row = local_idx % n; + + const auto* batch_eigenvalues = eigenvalues + batch_idx * n; + const auto* batch_vectors = vectors + batch_idx * matrix_stride; + auto* batch_result = result + batch_idx * matrix_stride; + + const auto eigenvalue = batch_eigenvalues[col]; + + if (eigenvalue.imag() == scalar_t(0)) { + batch_result[col * n + row] = c10::complex( + batch_vectors[col * n + row], + scalar_t(0)); + } else if (eigenvalue.imag() > scalar_t(0)) { + batch_result[col * n + row] = c10::complex( + batch_vectors[col * n + row], + batch_vectors[(col + 1) * n + row]); + } else { + batch_result[col * n + row] = c10::complex( + batch_vectors[(col - 1) * n + row], + -batch_vectors[col * n + row]); + } +} + +template +void linalg_eig_make_complex_eigenvectors_cuda_impl( + const Tensor& complex_vectors, + const Tensor& complex_values, + const Tensor& real_vectors) { + + const auto n = real_vectors.size(-1); + const auto matrix_stride = matrixStride(real_vectors); + const auto batch_size = batchCount(real_vectors); + + if (batch_size == 0 || n == 0) return; + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.mT().is_contiguous()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_contiguous()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.mT().is_contiguous()); + + const int64_t total_elements = batch_size * n * n; + + const int threads = 256; + const int blocks = (total_elements + threads - 1) / threads; + + auto* result_ptr = complex_vectors.data_ptr>(); + const auto* eigenvalues_ptr = complex_values.const_data_ptr>(); + const auto* vectors_ptr = real_vectors.const_data_ptr(); + + linalg_eig_make_complex_eigenvectors_kernel + <<>>( + result_ptr, + eigenvalues_ptr, + vectors_ptr, + batch_size, + n, + matrix_stride); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void linalg_eig_make_complex_eigenvectors_cuda( + const Tensor& complex_vectors, + const Tensor& complex_values, + const Tensor& real_vectors) { + + TORCH_INTERNAL_ASSERT(complex_vectors.is_cuda()); + TORCH_INTERNAL_ASSERT(complex_values.is_cuda()); + TORCH_INTERNAL_ASSERT(real_vectors.is_cuda()); + + c10::cuda::CUDAGuard device_guard(real_vectors.device()); + + AT_DISPATCH_V2( + real_vectors.scalar_type(), + "linalg_eig_make_complex_eigenvectors_cuda", + AT_WRAP([&] { + linalg_eig_make_complex_eigenvectors_cuda_impl( + complex_vectors, complex_values, real_vectors); + }), + AT_EXPAND(AT_FLOATING_TYPES)); +} + +} // anonymous namespace + +REGISTER_CUDA_DISPATCH(linalg_eig_make_complex_eigenvectors_stub, &linalg_eig_make_complex_eigenvectors_cuda) + +} // namespace at::native diff --git a/test/test_linalg.py b/test/test_linalg.py index ed3ca079748fd..cabc561277b35 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -2267,6 +2267,91 @@ def test_eig_check_magma(self, device, dtype): # check correctness using eigendecomposition identity self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3) + @onlyCUDA + @dtypes(torch.float32, torch.float64) + def test_eig_cuda_complex_eigenvectors(self, device, dtype): + """Test CUDA eigenvector decoding with known ground truth, including batching.""" + + # Test 1: Rotation matrix (complex eigenvalues - conjugate pairs) + theta = math.pi / 4 + A_complex = torch.tensor([ + [math.cos(theta), -math.sin(theta)], + [math.sin(theta), math.cos(theta)] + ], dtype=dtype, device=device) + + vals_complex, vecs_complex = torch.linalg.eig(A_complex) + + # Verify eigenvalues are e^(±iθ) for rotation by θ + # For θ = π/4, eigenvalues are e^(±iπ/4) - a conjugate pair + expected_eigenvalue = complex(math.cos(theta), math.sin(theta)) + expected_val = torch.tensor( + expected_eigenvalue, dtype=vals_complex.dtype, device=device + ) + expected_val_conj = torch.tensor( + expected_eigenvalue.conjugate(), dtype=vals_complex.dtype, device=device + ) + # Check both eigenvalues are present and form a conjugate pair + match_0_pos = torch.allclose(vals_complex[0], expected_val, atol=1e-5, rtol=1e-5) + match_0_neg = torch.allclose(vals_complex[0], expected_val_conj, atol=1e-5, rtol=1e-5) + match_1_pos = torch.allclose(vals_complex[1], expected_val, atol=1e-5, rtol=1e-5) + match_1_neg = torch.allclose(vals_complex[1], expected_val_conj, atol=1e-5, rtol=1e-5) + # Valid if (vals[0]=λ AND vals[1]=λ*) OR (vals[0]=λ* AND vals[1]=λ) + self.assertTrue( + (match_0_pos and match_1_neg) or (match_0_neg and match_1_pos), + f"Expected conjugate pair {{λ, λ*}}, got {vals_complex[0]}, {vals_complex[1]}" + ) + + # Verify output is complex type + self.assertTrue(vals_complex.dtype in [torch.complex64, torch.complex128]) + self.assertTrue(vecs_complex.dtype in [torch.complex64, torch.complex128]) + + # Verify Av = λv for all eigenpairs (vectorized) + lhs = A_complex.to(vecs_complex.dtype) @ vecs_complex + rhs = vals_complex.unsqueeze(-2) * vecs_complex + self.assertEqual(lhs, rhs, atol=1e-5, rtol=1e-5) + + # Test 2: Diagonal matrix (all real eigenvalues) + A_real = torch.diag(torch.tensor([1.0, 2.0, 3.0], dtype=dtype, device=device)) + + vals_real, vecs_real = torch.linalg.eig(A_real) + + # Output is still complex type, but imaginary parts should be ~zero + self.assertTrue(torch.allclose(vals_real.imag, torch.zeros_like(vals_real.imag), atol=1e-6)) + # Real parts should match diagonal values + self.assertTrue(torch.allclose( + torch.sort(vals_real.real)[0], + torch.tensor([1., 2., 3.], dtype=dtype, device=device), + atol=1e-6, rtol=1e-6 + )) + + # Verify Av = λv for all eigenpairs (vectorized) + lhs = A_real.to(vecs_real.dtype) @ vecs_real + rhs = vals_real.unsqueeze(-2) * vecs_real + self.assertEqual(lhs, rhs, atol=1e-5, rtol=1e-5) + + # Test 3: Batched - mix of real and complex eigenvalues + A_batch = torch.stack([ + # Rotation (complex eigenvalues) + torch.tensor([ + [math.cos(math.pi / 6), -math.sin(math.pi / 6)], + [math.sin(math.pi / 6), math.cos(math.pi / 6)] + ], dtype=dtype, device=device), + # Diagonal (real eigenvalues) + torch.diag(torch.tensor([4.0, 5.0], dtype=dtype, device=device)), + # Another rotation (complex eigenvalues) + torch.tensor([ + [math.cos(math.pi / 3), -math.sin(math.pi / 3)], + [math.sin(math.pi / 3), math.cos(math.pi / 3)] + ], dtype=dtype, device=device), + ]) + + vals_batch, vecs_batch = torch.linalg.eig(A_batch) + + # Verify Av = λv for all matrices in batch + lhs = A_batch.to(vecs_batch.dtype) @ vecs_batch + rhs = vals_batch.unsqueeze(-2) * vecs_batch + self.assertEqual(lhs, rhs, atol=1e-5, rtol=1e-5) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) From 5f21d27e71268464d362a96c9ac09ea475f7f202 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Wed, 3 Dec 2025 10:24:40 +0000 Subject: [PATCH 170/338] [xpu][feature] Upgrade XPU OneDNN to v3.10.2 (#169443) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169443 Approved by: https://github.com/EikanWang --- cmake/Modules/FindMKLDNN.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 0349b09119cae..7f53dacadef59 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -47,7 +47,7 @@ IF(NOT MKLDNN_FOUND) endif() ExternalProject_Add(xpu_mkldnn_proj GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN - GIT_TAG v3.9.1 + GIT_TAG v3.10.2 PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx From 3d35fd20a78ff4d016fa80f4e5fad37191d7bcae Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 3 Dec 2025 10:31:20 +0000 Subject: [PATCH 171/338] Fix device determination logic in Conditional (#169199) Fixes #169197. Primarily use operand device for output device, use predicate device as a fallback. This is because it's possible that the predicate is on a different device than operands. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169199 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 39 ++++++++++++++++++++++++++++++ torch/_inductor/ir.py | 6 ++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 1e71936d5653d..6c0c932023638 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -2461,6 +2461,45 @@ def false_fn(x): dynamic_shapes=dynamic_shapes, ) + @common_utils.parametrize("max_autotune", [False, True]) + def test_cond_cpu_predicate_cuda_operands(self, max_autotune): + """ + Test torch.cond with CPU predicate and CUDA operands. + This is a regression test for the bug where inductor incorrectly + determined device from [predicate] + operands, causing CPU predicates + to force CUDA outputs onto CPU during autotuning. + """ + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self, input_dim=4, hidden_dim=8): + super().__init__() + self.true_linear = torch.nn.Linear(input_dim, hidden_dim, bias=True) + self.false_linear = torch.nn.Linear(input_dim, hidden_dim, bias=True) + self.another_linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, predicate: torch.Tensor, x: torch.Tensor): + def true_fn(x): + return self.true_linear(x) * 2.0 + + def false_fn(x): + return self.false_linear(x) + 1.0 + + res = torch.cond(predicate, true_fn, false_fn, (x,)) + return self.another_linear(res) + + # Predicate on CPU, data on CUDA + predicate = torch.tensor(True, dtype=torch.bool, device="cpu") + x = torch.randn(4, 4, device=self.device) + example_inputs = (predicate, x) + + with config.patch({"max_autotune": max_autotune}): + self.check_model( + Model().to(self.device), + example_inputs=example_inputs, + ) + def test_while_loop_simple(self): inputs = ( torch.randn((10, 20), device=self.device), diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b4bc3bbf19e88..de4b4ab20a779 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -8835,9 +8835,13 @@ def create( assert t_o.get_dtype() == f_o.get_dtype(), (i, t_o, f_o) assert t_o.get_layout().offset == f_o.get_layout().offset, (i, t_o, f_o) + # Determine device from operands and predicate + # The predicate can be on a different device (e.g., CPU for control flow) + # while the data operands and outputs should be on the compute device, so + # using predicate device as a fallback. device = next( o.get_device() - for o in [predicate] + operands + for o in operands + [predicate] if not isinstance(o, ShapeAsConstantBuffer) ) unbacked_bindings = resolve_unbacked_bindings( From 6ff831180d2fa436c7f1c1af3adac641fce9d60e Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Wed, 3 Dec 2025 10:53:02 +0000 Subject: [PATCH 172/338] [xpu] Enable TransformerEncoderLayer Fast Path for XPU Device (#168234) This PR aims to fix the device compatibility check for the Fast Path (high-performance C++ fusion operation) within torch.nn.TransformerEncoderLayer. In the forward method of TransformerEncoderLayer, the device whitelist for the Fast Path implementation by default only includes "cpu" and "cuda". This change explicitly adds the "xpu" device type to the supported list. This ensures that when XPU tensors are used as input, the code correctly selects and executes the highly optimized C++ fused kernel (torch._transformer_encoder_layer_fwd), instead of falling back to the slower, unoptimized Python implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168234 Approved by: https://github.com/guangyey, https://github.com/albanD --- torch/nn/modules/transformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index abcd7240a742c..ed35224423aa6 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -511,6 +511,7 @@ def forward( _supported_device_type = [ "cpu", "cuda", + "xpu", torch.utils.backend_registration._privateuse1_backend_name, ] if torch.overrides.has_torch_function(tensor_args): @@ -895,6 +896,7 @@ def forward( _supported_device_type = [ "cpu", "cuda", + "xpu", torch.utils.backend_registration._privateuse1_backend_name, ] if torch.overrides.has_torch_function(tensor_args): From 9b3e34d8589b29f7b4e7fab6f78711b7ca6e4639 Mon Sep 17 00:00:00 2001 From: linhaifeng <1371675203@qq.com> Date: Wed, 3 Dec 2025 10:55:34 +0000 Subject: [PATCH 173/338] [MPS][BugFix] Fix MaxPool2d/MaxPool3d output size validation (Issue #168246) (#168332) Fix MaxPool2d and MaxPool3d on Apple MPS backend to properly validate output dimensions and raise errors for invalid kernel sizes, matching CPU behavior. Fixes #168246 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168332 Approved by: https://github.com/malfet, https://github.com/cyyever Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .../src/ATen/native/mps/operations/Pooling.mm | 45 ++++++++++++++++++- test/nn/test_pooling.py | 1 - .../_internal/common_methods_invocations.py | 11 +++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index ecd5f12df17f8..a8e25389b25a3 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -369,7 +369,8 @@ static PoolSizes process_pool_sizes(const Tensor& input, out_size += stride_expanded[dim] - 1; } - out_size = out_size / stride_expanded[dim] + 1; + // Use div_rtn for proper floor division (matching CPU behavior) + out_size = div_rtn(out_size, static_cast(stride_expanded[dim])) + 1; if (ceil_mode) { if (((out_size - 1) * stride_expanded[dim]) >= (input.size(leading_dims + dim) + padding_expanded[dim])) { @@ -387,6 +388,48 @@ static PoolSizes process_pool_sizes(const Tensor& input, output_size[leading_dims + dim] = output_pooling_size[dim]; } + // Validate output sizes using the same shape check functions as CPU/CUDA + if (pooling_dims == 2) { + const auto memory_format = input.suggest_memory_format(); + pool2d_shape_check(input, + kernel_size_expanded[0], + kernel_size_expanded[1], + stride_expanded[0], + stride_expanded[1], + padding_expanded[0], + padding_expanded[1], + dilation_expanded[0], + dilation_expanded[1], + input.size(leading_dims - 1), + input.size(leading_dims), + input.size(leading_dims + 1), + output_pooling_size[0], + output_pooling_size[1], + memory_format); + } else if (pooling_dims == 3) { + pool3d_shape_check(input, + input.size(leading_dims - 1), + kernel_size_expanded[0], + kernel_size_expanded[1], + kernel_size_expanded[2], + stride_expanded[0], + stride_expanded[1], + stride_expanded[2], + padding_expanded[0], + padding_expanded[1], + padding_expanded[2], + dilation_expanded[0], + dilation_expanded[1], + dilation_expanded[2], + input.size(leading_dims), + input.size(leading_dims + 1), + input.size(leading_dims + 2), + output_pooling_size[0], + output_pooling_size[1], + output_pooling_size[2], + op_name.c_str()); + } + return PoolSizes(dims, output_size, kernel_size_expanded, diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index f20ee2a29d573..f5240031def91 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -2045,7 +2045,6 @@ def helper(pool): helper(nn.AdaptiveAvgPool2d((2**6, 2**6))) @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) - @expectedFailureMPS @dtypes(torch.float) def test_pool_invalid_size(self, device, dtype): for op in ("max", "avg"): diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5f3454ef54cca..4578789eddf22 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3801,6 +3801,11 @@ def error_inputs_max_pool2d(op_info, device, **kwargs): kwargs={'kernel_size': 1}), error_regex=err_msg) + # error: inputs when kernel size too large for input + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), + kwargs={'kernel_size': 2}), + error_regex='Output size is too small') + def error_inputs_max_pool3d(op_info, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) @@ -3833,6 +3838,12 @@ def error_inputs_max_pool3d(op_info, device, **kwargs): kwargs={'kernel_size': 1}), error_regex=err_msg) + # error: inputs when kernel size too large for input + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 4, 4)), + kwargs={'kernel_size': 2}), + error_regex='Output size is too small') + + def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad) From 6ceb4a32f92ae67ce5d7d97931d17401ebf5ffa5 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 3 Dec 2025 09:01:25 +0000 Subject: [PATCH 174/338] Remove `LocalGeneratorObjectVariable::_get_inline_tracer()` (#169306) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169306 Approved by: https://github.com/zou3519 --- torch/_dynamo/variables/functions.py | 29 +++++++++++----------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 02bbcebe5c02a..f493e0e1fd961 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -98,6 +98,8 @@ if TYPE_CHECKING: from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import ( + InliningGeneratorInstructionTranslator, + InliningInstructionTranslator, InstructionTranslator, InstructionTranslatorBase, ) @@ -899,7 +901,7 @@ def __init__( self, code: types.CodeType, f_globals: dict[str, Any], - inline_tracer: Optional["InstructionTranslator"], + inline_tracer: "InliningGeneratorInstructionTranslator", **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -944,7 +946,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: temp = temporarely_allow_writes_to_output_graph(tx) with save, disallow, temp: - tracer = self._get_inline_tracer(tx) + tracer = self.inline_tracer if not tracer.generator_exhausted: self.remaining_items = self.force_unpack_var_sequence(tx) variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) @@ -963,17 +965,8 @@ def get_globals(self) -> dict[str, Any]: def python_type(self) -> type: return types.GeneratorType - def _get_inline_tracer(self, tx: "InstructionTranslator") -> Any: - from torch._dynamo.symbolic_convert import InliningInstructionTranslator - - if self.inline_tracer is None: - self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( # type: ignore[assignment] - tx, self, [], {} - ) - return self.inline_tracer - def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: - tracer = self._get_inline_tracer(tx) + tracer = self.inline_tracer if self._is_generator_exhausted(): raise_observed_exception(StopIteration, tx) @@ -1030,7 +1023,7 @@ def should_allow_nested_graph_breaks(self): def _setup_exception( self, tx: "InstructionTranslator", exc: VariableTracker ) -> None: - tracer = self._get_inline_tracer(tx) + tracer = self.inline_tracer try: tracer._raise_exception_variable(exc) except ObservedException as e: @@ -1068,7 +1061,7 @@ def call_method( for arg in args ): raise_observed_exception(TypeError, tx) - tracer = self._get_inline_tracer(tx) + tracer = self.inline_tracer tracer.push_many(args) return self.next_variable(tx) elif name == "close": @@ -1085,7 +1078,7 @@ def call_method( # Return None if close is called on a just-started generator # See test GeneratorCloseCpythonTests::test_close_not_started - tracer = self._get_inline_tracer(tx) + tracer = self.inline_tracer if self._is_generator_just_started() or self._is_generator_exhausted(): tracer.generator_exhausted = True return variables.ConstantVariable(None) @@ -1145,7 +1138,7 @@ def call_method( # or raises a different exception, then that exception propagates to the caller. # Setup the exception table and jump target in case of try...finally - tracer = self._get_inline_tracer(tx) + tracer = self.inline_tracer try: # In Python 3.9, the exception is represented as a triple (typ, val, tb) # In such cases, we re-raise the exception object given to avoid @@ -1278,7 +1271,7 @@ def _build_inline_tracer( tx: "InstructionTranslatorBase", args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "InstructionTranslatorBase": + ) -> "InliningInstructionTranslator": from torch._dynamo.symbolic_convert import InliningInstructionTranslator return InliningInstructionTranslator.build_inline_tracer( @@ -1340,7 +1333,7 @@ def _build_inline_tracer( tx: "InstructionTranslatorBase", args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "InstructionTranslatorBase": + ) -> "InliningGeneratorInstructionTranslator": # NOTE: This only exists to not break support for context manager when # config.enable_faithful_generator_behavior = False and # config.enable_trace_contextlib = True. In case the former is false, From 78adb3b3df41b45d2368b67226d2f864b78939a6 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 3 Dec 2025 09:01:26 +0000 Subject: [PATCH 175/338] [generator] Close all open generators in compile_subgraph (#157149) **Motivation and Example** Explictly close all open generators in compile_subgraph to ensure that all remaining finally blocks are executed. In CPython this is done by invoking the `tp_finalize` function of the generator object, which triggers `genclose`: https://github.com/python/cpython/blob/58a42dea97f4fa0df38ef4a95a2ede65e0549f71/Objects/genobject.c#L128-L134 ```python import gc def whoo(t): nonlocal z z = 0 try: z += 1 yield t.sin() except ValueError: z += 10 yield t.cos() except RuntimeError: z += 100 yield t.tan() finally: z += 1000 z += 10_000 gen = whoo(t) a = next(gen) b = gen.throw(RuntimeError) gc.collect() print(z) # 1101 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157149 Approved by: https://github.com/zou3519 ghstack dependencies: #169306 --- test/dynamo/test_generator.py | 8 +- ...ns-ExceptionTests.test_generator_leaking3} | 0 ...onTest.test_except_throw_exception_context | 0 torch/_dynamo/output_graph.py | 5 + torch/_dynamo/side_effects.py | 20 ++++ torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/variables/functions.py | 95 +++---------------- 7 files changed, 45 insertions(+), 85 deletions(-) rename test/dynamo_expected_failures/{CPython313-test_generators-ExceptionTest.test_except_throw => CPython313-test_exceptions-ExceptionTests.test_generator_leaking3} (100%) delete mode 100644 test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index c02126c7404ff..2a0bd874f881c 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1009,7 +1009,7 @@ def test_close_with_side_effects(self): z = 0 def whoo(t): - nonlocal z + nonlocal z # noqa: F824 try: L.append(1) yield t.sin() @@ -1050,7 +1050,6 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=True) def fn(t): - nonlocal z gen = whoo(t) i = next(gen) y = gen.close() @@ -1078,7 +1077,6 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=fullgraph) def fn(t): - nonlocal z gen = whoo(t) i = next(gen) gen.close() @@ -1380,8 +1378,10 @@ def fn(t): a = next(gen) try: gen.throw(ValueError) - except StopIteration: + except StopIteration as e: + assert len(e.args) == 0 return a + raise AssertionError("Expected StopIteration") t = torch.randn(2) y = self._compile_check(fn, (t,)) diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking3 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw rename to test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking3 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 6ff908ff0394f..7374035854c4f 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1629,6 +1629,11 @@ def compile_subgraph( ) self.codegen_suffix(tx, stack_values_flat, pass1) + # Close all generators opened while tracing. Needs to be done after + # pass1, as PyCodegen might try to reconstruct the generator, which + # sets LocalGeneratorObjectVariable.remaining_items + self.side_effects.close_local_generators() + # Use `pass1.uses` to selectively cache multi-user variables into a # temporary local source. This (a). speeds up loading VTs with long # chained source, and (b). avoids redundantly saving single-user VT diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 999bd145c3e57..594c7fd7060aa 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -59,6 +59,7 @@ if TYPE_CHECKING: from torch._dynamo.output_graph import OutputGraph from torch._dynamo.symbolic_convert import InstructionTranslatorBase + from torch._dynamo.variables.functions import LocalGeneratorObjectVariable from torch._dynamo.variables.lists import ListVariable @@ -134,6 +135,7 @@ def __init__( self.keepalive = keepalive or [] self.save_for_backward = save_for_backward or [] self.tensor_hooks = tensor_hooks or {} + self.local_generators: list[LocalGeneratorObjectVariable] = [] # Used by MappingProxyVariable to graph break in case of any mutated # dict self._has_existing_dict_mutation = False @@ -227,6 +229,24 @@ def should_allow_side_effects_in_hop(self) -> bool: and output_graph.current_tx.output.current_tracer.allow_side_effects_in_hop ) + def track_generator(self, gen: "LocalGeneratorObjectVariable") -> None: + self.local_generators.append(gen) + + def untrack_generator(self, gen: "LocalGeneratorObjectVariable") -> None: + self.local_generators.remove(gen) + + def close_local_generators(self) -> None: + from .symbolic_convert import temporarely_allow_writes_to_output_graph + + output_graph = self.output_graph_weakref() + if output_graph: + tx = output_graph.root_tx + with temporarely_allow_writes_to_output_graph(tx): + for gen in self.local_generators: + if not gen._is_generator_exhausted(): + # pyrefly: ignore[bad-argument-type] + gen.call_method(tx, "close", [], {}) + def is_reconstructing_generator(self) -> bool: output_graph = self.output_graph_weakref() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index f401b9d6178b9..5410de2f4365e 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1016,7 +1016,7 @@ class ExceptionStack: # and "stack" sometimes refers to a C variable with the same name and the # exception stack, respectively. # - # The lifetime of an exception is (Python 3.11+): + # The lifetime of an exception in Python 3.11+ is: # + tx._raise_exception_variable(...) := sets the current_exception variable # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack* # + POP_EXCEPT := pops TOS from the *exception stack* diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index f493e0e1fd961..df37a5d9a4cbc 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -45,7 +45,6 @@ from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..exc import ( format_skip_frame_message, - get_dynamo_observed_exception, handle_observed_exception, InfiniteGeneratorError, ObservedException, @@ -908,6 +907,7 @@ def __init__( self.code = code self.f_globals = f_globals self.inline_tracer = inline_tracer + inline_tracer.output.side_effects.track_generator(self) def get_code(self) -> types.CodeType: return self.code @@ -976,9 +976,12 @@ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: # created on call_function. Any exception needs to be propagated to tx # for Dynamo to behave correctly return tracer.inline_call_() - except ObservedException as e: + except ObservedUserStopIteration: + tracer.output.side_effects.untrack_generator(self) + raise + except ObservedException: tracer.generator_exhausted = True - raise e + raise except InfiniteGeneratorError: # test/dynamo/test_misc.py::test_iterator_limit raise @@ -1020,9 +1023,10 @@ def force_apply_to_var_sequence( def should_allow_nested_graph_breaks(self): return False - def _setup_exception( + def _setup_and_raise_exception( self, tx: "InstructionTranslator", exc: VariableTracker ) -> None: + # Raise an exception at the point where the generator is paused tracer = self.inline_tracer try: tracer._raise_exception_variable(exc) @@ -1086,7 +1090,7 @@ def call_method( # Raise GeneratorExit to see if user code catches it. Any other exception # is propagated to the parent frame. try: - self._setup_exception( + self._setup_and_raise_exception( tx, variables.ExceptionVariable(GeneratorExit, ()) ) # There's an extra block on Python 3.12+ to handle StopIteration @@ -1135,7 +1139,7 @@ def call_method( # returns the next value yielded by the generator. # * If the generator exits without yielding, raise StopIteration # * If the generator function does not catch the passed-in exception, - # or raises a different exception, then that exception propagates to the caller. + # or raises a different exception, then that new exception propagates to the caller. # Setup the exception table and jump target in case of try...finally tracer = self.inline_tracer @@ -1144,84 +1148,15 @@ def call_method( # In such cases, we re-raise the exception object given to avoid # creating a new object, so that IS_OP works. # See: https://github.com/pytorch/pytorch/pull/146496 - self._setup_exception(tx, args[1] if len(args) == 3 else args[0]) + self._setup_and_raise_exception( + tx, args[1] if len(args) == 3 else args[0] + ) except ObservedException: # noqa: TRY203 # propagate the exception back to the parent caller raise - retval = self.next_variable(tx) - - # The exception raised before is still active. We need to check the exception - # table one more time to find the next target. But why? Let's walk - # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M - # - # z = 0 - # def whoo(): - # global z - # z = 0 - # try: - # yield 1 - # except ValueError: - # yield 2 - # finally: - # z += 1 - # z += 10 - # - # gen = whoo() - # next(gen) - # gen.throw(ValueError) - # print('z', z) -> z = 1 - # - # ... - # >> 58 PUSH_EXC_INFO - # - # 8 60 LOAD_GLOBAL 2 (ValueError) - # 70 CHECK_EXC_MATCH - # 72 POP_JUMP_IF_FALSE 7 (to 88) - # 74 POP_TOP - # - # 9 76 LOAD_CONST 3 (2) - # 78 YIELD_VALUE 3 <------ ValueError is still active here - # 80 RESUME 1 - # 82 POP_TOP - # 84 POP_EXCEPT - # 86 jump_backward 34 (to 20) - # ... - # - # ExceptionTable: - # 4 to 8 -> 124 [0] lasti - # 12 to 18 -> 58 [0] - # 20 to 56 -> 124 [0] lasti - # 58 to 82 -> 90 [1] lasti <------ move to 90 - # 84 to 86 -> 96 [0] - # 88 to 88 -> 90 [1] lasti - # 90 to 94 -> 96 [0] - # 96 to 116 -> 118 [1] lasti - # 118 to 122 -> 124 [0] lasti - # - # In this scenario, a generator can yield after `throw()` is called. Even - # after the exception is raised a few lines above, it remains active - # within the `78 YIELD_VALUE` instruction. When the generator resumes - # after the second yield on instruction `80 RESUME`, we cannot simply - # return the control flow to the next instruction. Instead, one must - # check the exception table (or equivalent) to find the next target - # In this case, it says the instruction pointer must be moved to 90. - # - # Without this step, if we let the trace proceed to the next - # instruction, it would follow the control flow where the exception - # raised by `throw()` was handled and swallowed, potentially leading - # to incorrect behavior. - exc_type = type("__InternalThrowException", (Exception,), {}) - - try: - self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) - self.next_variable(tx) - except get_dynamo_observed_exception(exc_type): - # We should get back the exception raised before. - pass - else: - raise_observed_exception(RuntimeError, tracer) - return retval + # If reaches here, it means user code captured the exception + return self.next_variable(tx) return super().call_method(tx, name, args, kwargs) From 9f0df5686cb4ada94f94620acba2e3c3f363b11d Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 3 Dec 2025 14:00:40 +0000 Subject: [PATCH 176/338] Restore LD_LIBRARY_PATH after binary wheel checks (#169444) Fixes: https://github.com/pytorch/pytorch/issues/169331 Looks like https://github.com/pytorch/pytorch/pull/168349 caused the issue above. Make this check less intrusive by restoring LD_LIBRARY_PATH to previous value Pull Request resolved: https://github.com/pytorch/pytorch/pull/169444 Approved by: https://github.com/yangw-dev, https://github.com/desertfire --- .ci/pytorch/check_binary.sh | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.ci/pytorch/check_binary.sh b/.ci/pytorch/check_binary.sh index 95d57f35ce4bd..c8c89fe871fe3 100755 --- a/.ci/pytorch/check_binary.sh +++ b/.ci/pytorch/check_binary.sh @@ -26,6 +26,8 @@ set -eux -o pipefail # libtorch package. # ensure we don't link to system libraries, linked libraries should be found from RPATH +# Save the old LD_LIBRARY_PATH to restore it later +OLD_LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}" unset LD_LIBRARY_PATH if [[ -z ${DESIRED_PYTHON:-} ]]; then @@ -308,3 +310,10 @@ except RuntimeError as e: exit 1 fi fi + +############################################################################### +# Restore LD_LIBRARY_PATH to its original value +############################################################################### +if [[ -n "$OLD_LD_LIBRARY_PATH" ]]; then + export LD_LIBRARY_PATH="$OLD_LD_LIBRARY_PATH" +fi From 8d56e98c8db988a22cb2dfaeefb30bc7d2a3cc43 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 1 Dec 2025 11:38:24 -0800 Subject: [PATCH 177/338] Smoke tests for binary symbol hiding (#169307) This was extracted from https://github.com/pytorch/pytorch/pull/167496 and https://github.com/pytorch/pytorch/pull/167855 Landing smoke tests as separate PR per offline discussion with @atalman Pull Request resolved: https://github.com/pytorch/pytorch/pull/169307 Approved by: https://github.com/atalman ghstack dependencies: #167496 --- .../smoke_test/check_binary_symbols.py | 348 ++++++++++++++++++ 1 file changed, 348 insertions(+) diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index b0c607659c72d..7ad10ca946215 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -100,6 +100,347 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None: ) +def _compile_and_extract_symbols( + cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None +) -> list[str]: + """ + Helper to compile a C++ file and extract all symbols. + + Args: + cpp_content: C++ source code to compile + compile_flags: Compilation flags + exclude_list: List of symbol names to exclude. Defaults to ["main"]. + + Returns: + List of all symbols found in the object file (excluding those in exclude_list). + """ + import subprocess + import tempfile + + if exclude_list is None: + exclude_list = ["main"] + + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + cpp_file = tmppath / "test.cpp" + obj_file = tmppath / "test.o" + + cpp_file.write_text(cpp_content) + + result = subprocess.run( + compile_flags + [str(cpp_file), "-o", str(obj_file)], + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode != 0: + raise RuntimeError(f"Compilation failed: {result.stderr}") + + symbols = get_symbols(str(obj_file)) + + # Return all symbol names, excluding those in the exclude list + return [name for _addr, _stype, name in symbols if name not in exclude_list] + + +def check_stable_only_symbols(install_root: Path) -> None: + """ + Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code. + + This approach tests: + 1. WITHOUT macros -> many torch symbols exposed (compilation succeeds) + 2. WITH TORCH_STABLE_ONLY -> compilation fails with #error directive + 3. WITH TORCH_TARGET_VERSION -> compilation fails with #error directive + 4. WITH both macros -> compilation fails with #error directive + """ + import subprocess + import tempfile + + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + test_cpp_content = """ +// Main torch C++ API headers +#include +#include + +// ATen tensor library +#include + +// Core c10 headers (commonly used) +#include +#include +#include +#include +#include + +int main() { return 0; } +""" + + base_compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", # Compile only, don't link + ] + + # Compile WITHOUT any macros - should succeed + symbols_without = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=base_compile_flags, + ) + + # We expect constexpr symbols, inline functions used by other headers etc. + # to produce symbols + num_symbols_without = len(symbols_without) + print(f"Found {num_symbols_without} symbols without any macros defined") + assert num_symbols_without != 0, ( + "Expected a non-zero number of symbols without any macros" + ) + + # Helper to verify compilation fails with expected error + def _expect_compilation_failure(compile_flags: list[str], macro_name: str) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + cpp_file = tmppath / "test.cpp" + obj_file = tmppath / "test.o" + + cpp_file.write_text(test_cpp_content) + + result = subprocess.run( + compile_flags + [str(cpp_file), "-o", str(obj_file)], + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode == 0: + raise RuntimeError( + f"Expected compilation to fail with {macro_name} defined, but it succeeded" + ) + + stderr = result.stderr + expected_error_msg = ( + "This file should not be included when either TORCH_STABLE_ONLY " + "or TORCH_TARGET_VERSION is defined." + ) + + if expected_error_msg not in stderr: + raise RuntimeError( + f"Expected error message to contain:\n '{expected_error_msg}'\n" + f"but got:\n{stderr[:1000]}" + ) + + print(f"Compilation correctly failed with {macro_name} defined") + + compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"] + _expect_compilation_failure(compile_flags_with_stable_only, "TORCH_STABLE_ONLY") + + compile_flags_with_target_version = base_compile_flags + [ + "-DTORCH_TARGET_VERSION=1" + ] + _expect_compilation_failure( + compile_flags_with_target_version, "TORCH_TARGET_VERSION" + ) + + compile_flags_with_both = base_compile_flags + [ + "-DTORCH_STABLE_ONLY", + "-DTORCH_TARGET_VERSION=1", + ] + _expect_compilation_failure(compile_flags_with_both, "both macros") + + +def check_stable_api_symbols(install_root: Path) -> None: + """ + Test that stable API headers still expose symbols with TORCH_STABLE_ONLY. + The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + stable_dir = include_dir / "torch" / "csrc" / "stable" + assert stable_dir.exists(), f"Expected {stable_dir} to be present" + + stable_headers = list(stable_dir.rglob("*.h")) + if not stable_headers: + raise RuntimeError("Could not find any stable headers") + + includes = [] + for header in stable_headers: + rel_path = header.relative_to(include_dir) + includes.append(f"#include <{rel_path.as_posix()}>") + + includes_str = "\n".join(includes) + test_stable_content = f""" +{includes_str} +int main() {{ return 0; }} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_stable = _compile_and_extract_symbols( + cpp_content=test_stable_content, + compile_flags=compile_flags, + ) + num_symbols_stable = len(symbols_stable) + print(f"Found {num_symbols_stable} symbols in torch/csrc/stable") + assert num_symbols_stable > 0, ( + f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_stable} symbols" + ) + + +def check_headeronly_symbols(install_root: Path) -> None: + """ + Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # Find all headers in torch/headeronly + headeronly_dir = include_dir / "torch" / "headeronly" + assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present" + headeronly_headers = list(headeronly_dir.rglob("*.h")) + if not headeronly_headers: + raise RuntimeError("Could not find any headeronly headers") + + # Filter out platform-specific headers that may not compile everywhere + platform_specific_keywords = [ + "cpu/vec", + ] + + filtered_headers = [] + for header in headeronly_headers: + rel_path = header.relative_to(include_dir).as_posix() + if not any( + keyword in rel_path.lower() for keyword in platform_specific_keywords + ): + filtered_headers.append(header) + + includes = [] + for header in filtered_headers: + rel_path = header.relative_to(include_dir) + includes.append(f"#include <{rel_path.as_posix()}>") + + includes_str = "\n".join(includes) + test_headeronly_content = f""" +{includes_str} +int main() {{ return 0; }} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_headeronly = _compile_and_extract_symbols( + cpp_content=test_headeronly_content, + compile_flags=compile_flags, + ) + num_symbols_headeronly = len(symbols_headeronly) + print(f"Found {num_symbols_headeronly} symbols in torch/headeronly") + assert num_symbols_headeronly > 0, ( + f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_headeronly} symbols" + ) + + +def check_aoti_shim_symbols(install_root: Path) -> None: + """ + Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # There are no constexpr symbols etc., so we need to actually use functions + # so that some symbols are found. + test_shim_content = """ +#include +int main() { + int32_t (*fp1)() = &aoti_torch_device_type_cpu; + int32_t (*fp2)() = &aoti_torch_dtype_float32; + (void)fp1; (void)fp2; + return 0; +} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_shim = _compile_and_extract_symbols( + cpp_content=test_shim_content, + compile_flags=compile_flags, + ) + num_symbols_shim = len(symbols_shim) + assert num_symbols_shim > 0, ( + f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_shim} symbols" + ) + + +def check_stable_c_shim_symbols(install_root: Path) -> None: + """ + Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # Check if the stable C shim exists + stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h" + if not stable_shim.exists(): + raise RuntimeError("Could not find stable c shim") + + # There are no constexpr symbols etc., so we need to actually use functions + # so that some symbols are found. + test_stable_shim_content = """ +#include +int main() { + // Reference stable C API functions to create undefined symbols + AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string; + AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads; + (void)fp1; (void)fp2; + return 0; +} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_stable_shim = _compile_and_extract_symbols( + cpp_content=test_stable_shim_content, + compile_flags=compile_flags, + ) + num_symbols_stable_shim = len(symbols_stable_shim) + assert num_symbols_stable_shim > 0, ( + f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_stable_shim} symbols" + ) + + def check_lib_symbols_for_abi_correctness(lib: str) -> None: print(f"lib: {lib}") cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS) @@ -129,6 +470,13 @@ def main() -> None: check_lib_symbols_for_abi_correctness(libtorch_cpu_path) check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path) + # Check symbols when TORCH_STABLE_ONLY is defined + check_stable_only_symbols(install_root) + check_stable_api_symbols(install_root) + check_headeronly_symbols(install_root) + check_aoti_shim_symbols(install_root) + check_stable_c_shim_symbols(install_root) + if __name__ == "__main__": main() From 597930f6b568852356ca9795dac76f9e4653adbd Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 3 Dec 2025 15:07:21 +0000 Subject: [PATCH 178/338] Correctly set max_numwarps in coordinate_descent_tuner (#159146) Current max_numwarps is incorrect on ROCm as warp_size is not taken into account. This PR resolves this and handles in a none hardcoded way using device props when available. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159146 Approved by: https://github.com/jansel, https://github.com/shunting314 --- torch/_inductor/runtime/coordinate_descent_tuner.py | 11 ++++++++--- torch/_inductor/runtime/hints.py | 2 ++ torch/_inductor/runtime/triton_heuristics.py | 5 +++++ torch/_inductor/utils.py | 11 +++++++++++ torch/csrc/cuda/Module.cpp | 2 ++ 5 files changed, 28 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 36bd64cbae280..91736febd29f6 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -7,6 +7,7 @@ from torch.utils._ordered_set import OrderedSet +from ..utils import get_max_numwarps from .hints import TRITON_MAX_BLOCK from .runtime_utils import red_text, triton_config_to_hashable @@ -81,9 +82,13 @@ def get_config_max(self, prefix: str) -> int: return min(max_block, size_hint) if size_hint is not None else max_block def get_warpsmax(self): - # Currently, CUDA has a maximum of 1024 threads, so 32 is the max - # number of warps. - return 1024 // 32 + # Avoid querying device directly if device properties are populated in inductor_meta + warp_size = self.inductor_meta.get("warp_size") + max_threads_per_block = self.inductor_meta.get("max_threads_per_block") + if warp_size and max_threads_per_block: + return max_threads_per_block // warp_size + else: + return get_max_numwarps() def cache_benchmark_result(self, config, timing): self.cached_benchmark_results[triton_config_to_hashable(config)] = timing diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 7e7409c698e90..a9ddf91e9a59c 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -135,6 +135,7 @@ class DeviceProperties(typing.NamedTuple): major: int | None = None regs_per_multiprocessor: int | None = None max_threads_per_multi_processor: int | None = None + max_threads_per_block: int | None = None warp_size: int | None = None @classmethod @@ -169,6 +170,7 @@ def create(cls, device) -> DeviceProperties: max_threads_per_multi_processor=getattr( props, "max_threads_per_multi_processor", None ), + max_threads_per_block=getattr(props, "max_threads_per_block", 1024), warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None), ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index ce3cd317934fe..5a37a0afccb34 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -296,6 +296,11 @@ def __init__( "device_type": self.device_props.type, } self.inductor_meta = {} if inductor_meta is None else inductor_meta + # Add device properties to inductor_meta for use by coordinate descent tuner + self.inductor_meta["warp_size"] = self.device_props.warp_size + self.inductor_meta["max_threads_per_block"] = ( + self.device_props.max_threads_per_block + ) self.deterministic_mode = self.inductor_meta.get("deterministic", False) self.save_cache_hook = save_cache_hook diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index a45d9c0275b73..884d060a1b071 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2708,6 +2708,17 @@ def get_gpu_shared_memory() -> int: return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) +def get_max_numwarps() -> int: + if torch.cuda.is_available(): + warp_size = torch.cuda.get_device_properties().warp_size + max_threads_per_block = torch.cuda.get_device_properties().max_threads_per_block + else: + # Defaults + warp_size = 32 + max_threads_per_block = 1024 + return max_threads_per_block // warp_size + + def is_welford_reduction(reduction_type: str) -> bool: return reduction_type.startswith("welford") diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index a8ae82b1b66ea..ec7e5be7eefe7 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -1060,6 +1060,8 @@ static void registerCudaDeviceProperties(PyObject* module) { .def_readonly( "max_threads_per_multi_processor", &cudaDeviceProp::maxThreadsPerMultiProcessor) + .def_readonly( + "max_threads_per_block", &cudaDeviceProp::maxThreadsPerBlock) .def_readonly("warp_size", &cudaDeviceProp::warpSize) #ifndef USE_ROCM // NVIDIA-only properties From ecbcc3f6bf327856b435b259ac63cc2f328c4b4e Mon Sep 17 00:00:00 2001 From: linhaifeng <1371675203@qq.com> Date: Wed, 3 Dec 2025 15:28:15 +0000 Subject: [PATCH 179/338] [Fix] Add safeguard for unsafe index (#169234) Inspired by #169140, without checking if the args tuple is empty, which could lead to IndexError and mask the actual error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169234 Approved by: https://github.com/Lucaskabela --- torch/_dynamo/functional_export.py | 16 ++++++++++------ torch/_export/non_strict_utils.py | 7 ++++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 6eb2dcb59b7f3..19e8007f86fdf 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -53,9 +53,10 @@ def post_process_error_msg( orig_sig = inspect.signature(func) flat_input_paths = _get_input_paths((args, kwargs), orig_sig) - constraint_violation_error.args = ( - _replace_sources(constraint_violation_error.args[0], flat_input_paths), - ) + if constraint_violation_error.args: + constraint_violation_error.args = ( + _replace_sources(constraint_violation_error.args[0], flat_input_paths), + ) return constraint_violation_error @@ -423,9 +424,12 @@ def _suggest_or_raise_constraint_violation( forced_specializations, ) if constraint_violation_error: - constraint_violation_error.args = ( - constraint_violation_error.args[0] + msg, - ) + if constraint_violation_error.args: + constraint_violation_error.args = ( + constraint_violation_error.args[0] + msg, + ) + else: + constraint_violation_error.args = (msg,) else: if forced_specializations: constraint_violation_error = ConstraintViolationError(msg) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index e84e67e5c5b9b..1c064845fe160 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -587,7 +587,12 @@ def produce_guards_and_solve_constraints( ) if constraint_violation_error: - constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) + if constraint_violation_error.args: + constraint_violation_error.args = ( + constraint_violation_error.args[0] + msg, + ) + else: + constraint_violation_error.args = (msg,) elif forced_specializations: constraint_violation_error = ConstraintViolationError(msg) if constraint_violation_error: From ee87bbe876c42575e961b32a0827d76bc9782ca2 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 2 Dec 2025 18:49:32 -0800 Subject: [PATCH 180/338] Support AC rematerializing in forward+loss+bwd graph (#168082) This PR adds an optimization pass that minimizes memory usage during forward+backward+loss mode compilation by rematerializing activation checkpoint (AC) nodes closer to backward region to reduce their lifetime. For example: ```python def inference_with_grad(): x = x_data.requires_grad_(True) # Forward: checkpoint node computed here z = torch.utils.checkpoint.checkpoint(lambda a: torch.sin(a), x) loss = z.sum() # Backward: checkpoint node only used here dx = torch.autograd.grad(loss, x)[0] return dx ``` Without optimization, the checkpointed `sin(x)` is computed during forward and held in memory until backward. This wastes memory if it's only needed in backward. Adds `rematerialize_ac_nodes()` pass that: 1. **Duplicates AC nodes used in both forward/backward**: Keeps in forward, recomputes in backward The pass is controlled by the config flag `torch._functorch.config.enable_inference_mode_ac_reordering` (default: `False`). The pass works in several steps: 1. **Find backward boundary**: Identify where backward computation starts using metadata annotations 2. **Categorize AC nodes**: Classify each AC node based on usage 3. **Rebuild graph**: - Backward: Duplicate AC chains for nodes used in both forward and backward while maintaining original order 4. **DCE** Use DCE to clean up forward nodes since some of them are only used in backward. **User code:** ```python def inference_with_grad(arg0_1, arg1_1): # Checkpointed computation z = torch.utils.checkpoint.checkpoint( lambda a, b: torch.sigmoid(torch.mm(a, b)), arg0_1, arg1_1, use_reentrant=False ) loss = z.sum() # Backward in same graph with torch.fx.traceback.annotate({"backward": 0}): grads = torch.autograd.grad(loss, (arg0_1, arg1_1)) return grads ``` **Compiled graph without AC reordering:** ```python # Forward mm = torch.mm(arg0_1, arg1_1) sigmoid = torch.sigmoid(mm) # AC node computed in forward detach = sigmoid.detach() # Saved for backward sum_1 = sigmoid.sum() # Backward sigmoid_backward = ... # Uses saved sigmoid (held in memory since forward) mm_1 = torch.mm(t, sigmoid_backward) mm_2 = torch.mm(sigmoid_backward, t_1) # Operation counts: mm=3, sigmoid=1 # Memory: sigmoid held from forward through backward ``` **Compiled graph with AC reordering:** ```python # Forward mm = torch.mm(arg0_1, arg1_1) sigmoid = torch.sigmoid(mm) # AC node computed for sum sum_1 = sigmoid.sum() # Used in forward, then freed # Backward - AC nodes recomputed just-in-time mm_recomputed = torch.mm(arg0_1, arg1_1) # Recompute sigmoid_recomputed = torch.sigmoid(mm_recomputed) # Recompute detach = sigmoid_recomputed.detach() # Use recomputed version sigmoid_backward = ... mm_1 = torch.mm(t, sigmoid_backward) mm_2 = torch.mm(sigmoid_backward, t_1) # Operation counts: mm=4, sigmoid=2 # Memory: sigmoid freed after forward, recomputed just-in-time ``` We tested on a synthetic large model which looks like: ```python def large_model(x, weights): """ Large model with multiple checkpointed transformer blocks. Simulates a model with ~6 checkpointed layers. """ # Initial projection x = torch.sin(x) # Checkpointed block 1 w1_1, w1_2, w1_3, w1_4 = weights[0:4] x = torch.utils.checkpoint.checkpoint( transformer_block, x, w1_1, w1_2, w1_3, w1_4, use_reentrant=False ) x = torch.nn.functional.layer_norm(x, x.shape[-1:]) # Checkpointed block 2 w2_1, w2_2, w2_3, w2_4 = weights[4:8] x = torch.utils.checkpoint.checkpoint( transformer_block, x, w2_1, w2_2, w2_3, w2_4, use_reentrant=False ) x = torch.nn.functional.layer_norm(x, x.shape[-1:]) # Checkpointed block 3 w3_1, w3_2, w3_3, w3_4 = weights[8:12] x = torch.utils.checkpoint.checkpoint( transformer_block, x, w3_1, w3_2, w3_3, w3_4, use_reentrant=False ) x = torch.nn.functional.layer_norm(x, x.shape[-1:]) # Checkpointed block 4 w4_1, w4_2, w4_3, w4_4 = weights[12:16] x = torch.utils.checkpoint.checkpoint( transformer_block, x, w4_1, w4_2, w4_3, w4_4, use_reentrant=False ) x = torch.nn.functional.layer_norm(x, x.shape[-1:]) # Checkpointed block 5 w5_1, w5_2, w5_3, w5_4 = weights[16:20] x = torch.utils.checkpoint.checkpoint( transformer_block, x, w5_1, w5_2, w5_3, w5_4, use_reentrant=False ) x = torch.nn.functional.layer_norm(x, x.shape[-1:]) # Checkpointed block 6 w6_1, w6_2, w6_3, w6_4 = weights[20:24] x = torch.utils.checkpoint.checkpoint( transformer_block, x, w6_1, w6_2, w6_3, w6_4, use_reentrant=False ) x = torch.nn.functional.layer_norm(x, x.shape[-1:]) return x ``` The naive inference graph has peak memory of 2700MB while with reordering pass, we achieve around 1100MB which is more or less same as eager peak memory. **Next Steps** Currently, we manually annotate the backward region using fx annotation API, but we should automatically tag them when torchdynamo encounter torch.autograd.grad. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168082 Approved by: https://github.com/soulitzer --- test/dynamo/test_activation_checkpointing.py | 331 ++++++++++++++++++ ..._using_tags_for_fwd_loss_bwd_graph_pass.py | 134 +++++++ .../_functorch/_aot_autograd/graph_compile.py | 7 + torch/_functorch/config.py | 8 + torch/_functorch/partitioners.py | 8 +- 5 files changed, 485 insertions(+), 3 deletions(-) create mode 100644 torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 064cf606182f9..8c3acaba18583 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -67,6 +67,11 @@ def inner(*args): return inner +@torch._dynamo.allow_in_graph +def _grad(*args, **kwargs): + return torch.autograd.grad(*args, **kwargs) + + def count_ops( gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None ): @@ -1994,6 +1999,332 @@ def forward(self, primals_1: "f32[4, 4]"): ) +class RematerializeACNodesPassTests(torch._dynamo.test_case.TestCase): + """Tests for AC reordering optimization in full graph (forward+backward in one graph).""" + + def count_op(self, gm, target): + return sum(1 for n in gm.graph.nodes if n.target == target) + + def _compile_and_capture(self, fn, remat_using_tags_for_fwd_loss_bwd_graph, inputs): + captured_gm = None + + def compiler(gm, example_inputs): + nonlocal captured_gm + captured_gm = gm + return gm.forward + + backend = aot_autograd( + fw_compiler=compiler, + bw_compiler=None, + partition_fn=None, + ) + + with torch._functorch.config.patch( + remat_using_tags_for_fwd_loss_bwd_graph=remat_using_tags_for_fwd_loss_bwd_graph + ): + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + result = compiled_fn(*inputs) + + return result, captured_gm + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_ac_rematerialize_simple_forward_backward(self): + x = torch.randn(4, 4, requires_grad=True) + y = torch.randn(4, 4, requires_grad=True) + + def simple_fwd_bwd(x, y): + z = torch.utils.checkpoint.checkpoint( + lambda a, b: torch.sigmoid(torch.matmul(a, b)), + x, + y, + use_reentrant=False, + ) + loss = z.sum() + + with torch.fx.traceback.annotate({"remat_pass_tag": "is_backward"}): + dx, dy = _grad(loss, (x, y)) + + return dx.detach(), dy.detach() + + (dx1, dy1), gm_without = self._compile_and_capture( + simple_fwd_bwd, False, (x, y) + ) + (dx2, dy2), gm_with = self._compile_and_capture(simple_fwd_bwd, True, (x, y)) + + self.assertTrue(torch.allclose(dx1, dx2)) + self.assertTrue(torch.allclose(dy1, dy2)) + + mm_with = self.count_op(gm_with, torch.ops.aten.mm.default) + mm_without = self.count_op(gm_without, torch.ops.aten.mm.default) + sigmoid_with = self.count_op(gm_with, torch.ops.aten.sigmoid.default) + sigmoid_without = self.count_op(gm_without, torch.ops.aten.sigmoid.default) + self.assertEqual(mm_with, 4, "mm should be recomputed in backward") + self.assertEqual(mm_without, 3) + self.assertEqual(sigmoid_with, 2, "sigmoid should be recomputed in backward") + self.assertEqual(sigmoid_without, 1) + + self.assertExpectedInline( + gm_with.code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + mm = torch.ops.aten.mm.default(arg0_1, arg1_1) + sigmoid = torch.ops.aten.sigmoid.default(mm); mm = None + sum_1 = torch.ops.aten.sum.default(sigmoid); sigmoid = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [4, 4]); ones_like = None + mm_recomputed = torch.ops.aten.mm.default(arg0_1, arg1_1) + sigmoid_recomputed = torch.ops.aten.sigmoid.default(mm_recomputed); mm_recomputed = None + detach_recomputed = torch.ops.aten.detach.default(sigmoid_recomputed); sigmoid_recomputed = None + detach_2 = torch.ops.aten.detach.default(detach_recomputed); detach_recomputed = None + sigmoid_backward = torch.ops.aten.sigmoid_backward.default(expand, detach_2); expand = detach_2 = None + t = torch.ops.aten.t.default(arg0_1); arg0_1 = None + mm_2 = torch.ops.aten.mm.default(t, sigmoid_backward); t = None + t_1 = torch.ops.aten.t.default(arg1_1); arg1_1 = None + mm_3 = torch.ops.aten.mm.default(sigmoid_backward, t_1); sigmoid_backward = t_1 = None + detach_3 = torch.ops.aten.detach.default(mm_3); mm_3 = None + detach_4 = torch.ops.aten.detach.default(mm_2); mm_2 = None + return (detach_3, detach_4)""", + ) + + def test_ac_rematerialize_with_rng_ops_raises_error(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd_with_rng(x): + z = torch.utils.checkpoint.checkpoint( + lambda a: torch.sigmoid(a + torch.rand_like(a)), x, use_reentrant=False + ) + loss = z.sum() + + with torch.fx.traceback.annotate({"remat_pass_tag": "is_backward"}): + dx = _grad(loss, x)[0] + + return dx + + with self.assertRaisesRegex( + torch._dynamo.exc.BackendCompilerFailed, + "Activation checkpoint rematerializing in `forward-loss-backward` graph does not support RNG ops in checkpointed regions.", + ): + self._compile_and_capture(fwd_bwd_with_rng, True, (x,)) + + def test_ac_rematerialize_with_no_annotations_warns_and_returns_unchanged(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd(x): + z = torch.utils.checkpoint.checkpoint( + lambda a: torch.sigmoid(a + 4), x, use_reentrant=False + ) + loss = z.sum() + return _grad(loss, x)[0] + + # Without backward annotations, the pass should warn and return unchanged + # We verify this by checking that remat_using_tags=True produces the same + # graph as remat_using_tags=False (i.e., no recomputation happens) + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result_with, gm_with = self._compile_and_capture(fwd_bwd, True, (x,)) + + # Check warning was issued + self.assertTrue( + any("no backward region" in str(warning.message) for warning in w), + f"Expected warning about no backward region, got: {[str(warning.message) for warning in w]}", + ) + + # Get the graph without the pass for comparison + result_without, gm_without = self._compile_and_capture(fwd_bwd, False, (x,)) + + # Results should be correct + self.assertTrue(torch.allclose(result_with, result_without)) + + # Both graphs should have the same number of sigmoid ops (no recomputation) + sigmoid_with = self.count_op(gm_with, torch.ops.aten.sigmoid.default) + sigmoid_without = self.count_op(gm_without, torch.ops.aten.sigmoid.default) + self.assertEqual(sigmoid_with, sigmoid_without) + + def test_ac_rematerialize_with_selective_checkpoint_policy(self): + x = torch.randn(4, 128, requires_grad=True) + w1 = torch.randn(128, 128, requires_grad=True) + b1 = torch.randn(128, requires_grad=True) + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.addmm.default: + return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE + return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE + + context_fn = functools.partial( + torch.utils.checkpoint.create_selective_checkpoint_contexts, policy_fn + ) + + def fwd_bwd_with_policy(x, w1, b1): + def checkpoint_fn(inp, w, b): + linear = torch.nn.functional.linear(inp, w, b) + return torch.relu(linear) + + result = torch.utils.checkpoint.checkpoint( + checkpoint_fn, x, w1, b1, use_reentrant=False, context_fn=context_fn + ) + loss = result.sum() + + with torch.fx.traceback.annotate({"remat_pass_tag": "is_backward"}): + dx, dw, db = _grad(loss, (x, w1, b1)) + return dx, dw, db + + result_with, gm_with = self._compile_and_capture( + fwd_bwd_with_policy, True, (x, w1, b1) + ) + result_without, gm_without = self._compile_and_capture( + fwd_bwd_with_policy, False, (x, w1, b1) + ) + + torch.testing.assert_close(result_with[0], result_without[0]) + torch.testing.assert_close(result_with[1], result_without[1]) + torch.testing.assert_close(result_with[2], result_without[2]) + + addmm_without = self.count_op(gm_without, torch.ops.aten.addmm.default) + relu_without = self.count_op(gm_without, torch.ops.aten.relu.default) + + addmm_with = self.count_op(gm_with, torch.ops.aten.addmm.default) + relu_with = self.count_op(gm_with, torch.ops.aten.relu.default) + + self.assertEqual(addmm_without, addmm_with) + self.assertEqual(relu_with, relu_without + 1) + + recomputed_nodes = [ + n.name for n in gm_with.graph.nodes if "_recomputed" in n.name + ] + self.assertNotIn("addmm_recomputed", recomputed_nodes) + + self.assertTrue( + any("relu" in name for name in recomputed_nodes), + f"Expected relu_recomputed but got: {recomputed_nodes}", + ) + + def _compile_with_joint_graph_pass_and_capture(self, fn, inputs): + from torch._inductor.fx_passes.joint_graph import joint_graph_passes + + captured_gm_before = None + captured_gm_after = None + + def custom_compiler(gm, example_inputs): + nonlocal captured_gm_before, captured_gm_after + import copy + + captured_gm_before = copy.deepcopy(gm) + joint_graph_passes(gm) + captured_gm_after = gm + return gm.forward + + backend = aot_autograd( + fw_compiler=custom_compiler, + bw_compiler=None, + partition_fn=None, + ) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + result = compiled_fn(*inputs) + + return result, captured_gm_before, captured_gm_after + + def test_joint_graph_passes_view_optimization(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd_with_views(x): + def checkpoint_fn(a): + b = a.view(16) + c = b.view(4, 4) + return torch.sigmoid(c) + + z = torch.utils.checkpoint.checkpoint( + checkpoint_fn, + x, + use_reentrant=False, + ) + loss = z.sum() + + with torch.fx.traceback.annotate({"remat_pass_tag": "is_backward"}): + dx = _grad(loss, x)[0] + + return dx.detach() + + result, gm_before, gm_after = self._compile_with_joint_graph_pass_and_capture( + fwd_bwd_with_views, (x,) + ) + + result_eager = torch.autograd.grad(torch.sigmoid(x).sum(), x)[0] + self.assertTrue(torch.allclose(result, result_eager, atol=1e-5)) + + view_count_before = self.count_op(gm_before, torch.ops.aten.view.default) + view_count_after = self.count_op(gm_after, torch.ops.aten.view.default) + self.assertTrue(view_count_after == 0) + self.assertTrue(view_count_before == 6) + + self.assertExpectedInline( + gm_after.code.strip(), + """\ +def forward(self, arg0_1): + sigmoid = torch.ops.aten.sigmoid.default(arg0_1) + sum_1 = torch.ops.aten.sum.default(sigmoid); sigmoid = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [4, 4]); ones_like = None + sigmoid_recomputed = torch.ops.aten.sigmoid.default(arg0_1); arg0_1 = None + detach_recomputed = torch.ops.aten.detach.default(sigmoid_recomputed); sigmoid_recomputed = None + detach_2 = torch.ops.aten.detach.default(detach_recomputed); detach_recomputed = None + sigmoid_backward = torch.ops.aten.sigmoid_backward.default(expand, detach_2); expand = detach_2 = None + detach_3 = torch.ops.aten.detach.default(sigmoid_backward); sigmoid_backward = None + return (detach_3,)""", + ) + + def test_joint_graph_passes_permute_optimization(self): + x = torch.randn(4, 4, requires_grad=True) + + def fwd_bwd_with_permute(x): + def checkpoint_fn(a): + b = a.permute(1, 0) + c = b.permute(1, 0) + return torch.sigmoid(c) + + z = torch.utils.checkpoint.checkpoint( + checkpoint_fn, + x, + use_reentrant=False, + ) + loss = z.sum() + + with torch.fx.traceback.annotate({"remat_pass_tag": "is_backward"}): + dx = _grad(loss, x)[0] + + return dx.detach() + + result, gm_before, gm_after = self._compile_with_joint_graph_pass_and_capture( + fwd_bwd_with_permute, (x,) + ) + + result_eager = torch.autograd.grad(torch.sigmoid(x).sum(), x)[0] + self.assertTrue(torch.allclose(result, result_eager, atol=1e-5)) + + permute_count_before = self.count_op(gm_before, torch.ops.aten.permute.default) + permute_count_after = self.count_op(gm_after, torch.ops.aten.permute.default) + self.assertTrue(permute_count_after == 0) + self.assertTrue(permute_count_before == 6) + + self.assertExpectedInline( + gm_after.code.strip(), + """\ +def forward(self, arg0_1): + sigmoid = torch.ops.aten.sigmoid.default(arg0_1) + sum_1 = torch.ops.aten.sum.default(sigmoid); sigmoid = None + ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand = torch.ops.aten.expand.default(ones_like, [4, 4]); ones_like = None + sigmoid_recomputed = torch.ops.aten.sigmoid.default(arg0_1); arg0_1 = None + detach_recomputed = torch.ops.aten.detach.default(sigmoid_recomputed); sigmoid_recomputed = None + detach_2 = torch.ops.aten.detach.default(detach_recomputed); detach_recomputed = None + sigmoid_backward = torch.ops.aten.sigmoid_backward.default(expand, detach_2); expand = detach_2 = None + detach_3 = torch.ops.aten.detach.default(sigmoid_backward); sigmoid_backward = None + return (detach_3,)""", + ) + + devices = ["cuda", "hpu"] instantiate_device_type_tests( ActivationCheckpointingViaTagsTests, globals(), only_for=devices diff --git a/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py b/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py new file mode 100644 index 0000000000000..7adc1e0302d11 --- /dev/null +++ b/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py @@ -0,0 +1,134 @@ +""" +AC rematerialize pass: Duplicates checkpointed nodes for backward, then DCE removes unused forward versions. +""" + +import warnings + +import torch +import torch.fx as fx +from torch._functorch import config +from torch._functorch.compile_utils import raise_getitems +from torch._functorch.partitioners import ( + cleanup_recompute_tags, + force_save_bw_mutation_src, + force_save_collectives, + has_recomputable_ops, + has_recomputable_rng_ops, + is_not_collective, + must_recompute, +) + + +def is_impure_node_for_dce(node): + # Check for special collectives that should be treated as pure + if not is_not_collective(node): + # It's a collective (wait_tensor, all_gather_into_tensor, etc.) + # Treat as pure - can be eliminated if unused + return False + + # For everything else, fall back to the DEFAULT logic + # This is what eliminate_dead_code() calls when is_impure_node=None + impure_random = True + if torch._guards.TracingContext.try_get(): + impure_random = torch._inductor.config.fallback_random + return node.is_impure(impure_random) + + +def _is_backward_node(node: fx.Node) -> bool: + """Check if node is in backward region via annotation""" + return node.meta.get("custom", {}).get("remat_pass_tag", None) == "is_backward" + + +def remat_using_tags_for_fwd_loss_bwd_graph(gm: fx.GraphModule) -> fx.GraphModule: + """ + Duplicate checkpointed nodes for backward use. DCE removes unused forward versions. We assume that + you already annotated your backward region with fx.traceback.annotate({"remat_pass_tag": "is_backward"}) + which helps us identify the backward region. + """ + if not has_recomputable_ops(gm): + return gm + + if has_recomputable_rng_ops(gm): + raise RuntimeError( + "Activation checkpoint rematerializing in `forward-loss-backward` graph does not support RNG ops " + "in checkpointed regions. Please move RNG operations outside " + "of checkpoint regions, or use joint graph mode (where partitioner handles RNG)." + ) + + # Use partitioner pass to normalize AC node tags. + gm = cleanup_recompute_tags(gm, is_default_partition=True) + + if not config.unsafe_allow_optimization_of_collectives: + force_save_collectives(gm) + + force_save_bw_mutation_src(gm) + + # Find backward boundary and build ordering + bwd_start: int | None = None + order = {} + for idx, node in enumerate(gm.graph.nodes): + order[node] = idx + if _is_backward_node(node) and bwd_start is None: + bwd_start = idx + + if bwd_start is None: + warnings.warn( + "remat_using_tags_for_fwd_loss_bwd_graph: Graph has recomputable ops but no backward region. " + "This may indicate a forward-only graph (e.g., from nested compilation) or missing backward annotations. " + "Returning graph unchanged." + ) + return gm + + new_graph = fx.Graph() + env: dict[fx.Node, fx.Node] = {} + recomputed_nodes: dict[fx.Node, fx.Node] = {} + + # Insert forward nodes + for node in list(gm.graph.nodes)[:bwd_start]: + env[node] = new_graph.node_copy(node, lambda x: env[x]) + + def remat_input(x): + # fx.Node can have args that are primitive types (e.g. int, float, bool) + if not isinstance(x, fx.Node): + return x + return recomputed_nodes.get(x, env[x]) + + def gather_checkpointed_deps(node: fx.Node, visited: set) -> None: + if node in visited or node in recomputed_nodes: + return + visited.add(node) + for inp in node.all_input_nodes: + if must_recompute(inp): + gather_checkpointed_deps(inp, visited) + + # Insert backward nodes + for node in list(gm.graph.nodes)[bwd_start:]: + # Gather all checkpointed deps needed by this node + deps = set() + for inp in node.all_input_nodes: + if must_recompute(inp): + gather_checkpointed_deps(inp, deps) + + # Insert deps in forward order (guaranteed disjoint from already-inserted) + # This is not as inefficient as it looks, because we only add fresh dependencies + # when they are not yet processed as recomputed nodes. + for dep in sorted(deps, key=lambda n: order[n]): + assert dep not in recomputed_nodes, "We shouldn't have recomputed it before" + dup = new_graph.node_copy(dep, remat_input) + dup.name = dep.name + "_recomputed" + recomputed_nodes[dep] = dup + + env[node] = new_graph.node_copy(node, remat_input) + + new_gm = torch.fx.GraphModule(gm, new_graph) + + # DCE with custom is_impure_node (like default_partition) + # Treats certain collectives as pure while delegating to default impurity logic + new_gm.graph.eliminate_dead_code(is_impure_node=is_impure_node_for_dce) + + # raise_getitems pass for better memory (like default_partition) + new_gm = raise_getitems(new_gm) + + new_gm.recompile() + + return new_gm diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 78320c1b37563..c4b1939a741e5 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -236,6 +236,13 @@ def orig_flat_fn2(*args: FxValue) -> tuple[list[FxValue], list[AOTOutput]]: fw_metadata=aot_state.fw_metadata, ) ) + # Apply AC rematerialization to forward+loss+bwd graph + if torch._functorch.config.remat_using_tags_for_fwd_loss_bwd_graph: + from torch._functorch._activation_checkpointing.remat_using_tags_for_fwd_loss_bwd_graph_pass import ( + remat_using_tags_for_fwd_loss_bwd_graph, + ) + + graph = remat_using_tags_for_fwd_loss_bwd_graph(graph) if config.selective_decompose: from torch.fx.experimental.proxy_tensor import selective_decompose diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 42d6f308f831a..49a069a096f58 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -140,6 +140,14 @@ def remote_autograd_cache_default() -> Optional[bool]: # Generally a good idea since views are free to recompute. recompute_views = False +# Rematerialize AC nodes for graphs with forward+loss+backward in one graph. +# This optimization minimizes activation checkpoint node lifetimes by computing them +# just-in-time. For AC nodes only used in backward, they are deferred to backward region +# instead of being computed and saved in forward. This reduces peak memory usage. +# Note: This only applies to forward+loss+backward graphs where torch.autograd.grad is allowed +# in the graph. Joint graphs (standard AOTAutograd) use the partitioner instead. +remat_using_tags_for_fwd_loss_bwd_graph = True + # By default, the partitioner is purely trying to optimize for runtime (although # it should always use less memory than eager) # This knob controls the partitioner to make that tradeoff for you, choosing the diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index f98aca82fe328..3e2abf2b5650f 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -174,6 +174,11 @@ def __repr__(self): return "Invalid Node" +# Run DCE while overriding the definition of is_impure_node +def is_not_collective(node): + return getattr(node.target, "namespace", None) != "_c10d_functional" + + InvalidNode = InvalidNodeBase() @@ -1170,9 +1175,6 @@ def is_impure(node): ) # Run DCE while overriding the definition of is_impure_node - def is_not_collective(node): - return getattr(node.target, "namespace", None) != "_c10d_functional" - fw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) bw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) From 7ba4680f3755a560af81aa0f688791e367aa3609 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 2 Dec 2025 23:56:50 +0000 Subject: [PATCH 181/338] Fix deprecation warning (#169421) Fixes: ```Shell /home/dev/meta/pytorch/torch/_dynamo/variables/user_defined.py:1824: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead. return ctor(*args, **kwargs) /home/dev/meta/pytorch/torch/_dynamo/variables/user_defined.py:1824: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead. return ctor(*args, **kwargs) Benchmarking (20 iters)... ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/169421 Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/cyyever --- torch/_dynamo/variables/user_defined.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 012bea32620e9..0863d8592abd2 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1808,6 +1808,10 @@ def as_python_constant(self): "currently can't reconstruct arbitrary frozen dataclass instances" ) + # LeafSpec is deprecated, use treespec_leaf() instead + if istype(self.value, pytree.LeafSpec): + return pytree.treespec_leaf() + args = [] kwargs = {} for field in fields(self.value): From fe0e65adfc0e7ca6e5f57e6ea8b16bd5cc967307 Mon Sep 17 00:00:00 2001 From: Gassan Date: Wed, 3 Dec 2025 16:02:17 +0000 Subject: [PATCH 182/338] Enabling WOQ fusion path on ACL (#165643) - Enabling _register_woq_lowerings() under the ACL (Arm Compute Library) path, while keeping the other fusion operators disabled. - Improves test coverage on (`inductor/test_cpu_select_algorithm.py test_int8_woq_mm_batch_*`). Previously, a list of **54 assertions were disabled/failed for aarch64**. List attached below. - Modified pattern matcher assertions under `test/inductor/test_mkldnn_pattern_matcher.py` to reflect that both ACL should have a woq_matcher_count == 1 for `test_woq_int8`.
test_cpu_select_algorithm.py previously failing tests - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_1024_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_1024_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_1024_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_128_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_128_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_128_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_144_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_144_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_144_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_1024_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_1024_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_1024_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_128_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_128_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_128_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_144_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_144_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_17_mid_dim_8_in_features_144_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_1024_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_1024_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_1024_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_128_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_128_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_128_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_144_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_144_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_1_in_features_144_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_1024_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_1024_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_1024_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_128_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_128_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_128_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_144_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_144_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_1_mid_dim_8_in_features_144_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_1024_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_1024_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_1024_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_128_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_128_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_128_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_144_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_144_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_1_in_features_144_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_1024_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_1024_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_1024_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_128_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_128_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_128_out_features_65_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_144_out_features_1024_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_144_out_features_64_cpu_bfloat16 - TestSelectAlgorithmCPU::test_int8_woq_mm_batch_size_32_mid_dim_8_in_features_144_out_features_65_cpu_bfloat16
Sample error Message: TestSelectAlgorithmCPU::test_int8_woq_mm_batch_* (AssertionError) **Status:** failure **Framework:** pytest **Repro command (from repo root):** ```bash python test/inductor/test_cpu_select_algorithm.py TestSelectAlgorithmCPU.test_int8_woq_mm_batch_size_17_mid_dim_1_in_features_144_out_features_65_cpu_bfloat16 ``` **Assertion:** ```text AssertionError: Scalars are not equal! Expected 1 but got 0. Absolute difference: 1 Relative difference: 1.0 ``` **Traceback (sanitized):** ```text Traceback (most recent call last): ... in test_int8_woq_mm self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) ... in assertEqual return super().assertEqual(x, y, *args, **kwargs) ... in assertEqual raise error_metas.pop()[0].to_error() AssertionError: Scalars are not equal! ``` **Note:** Set `PYTORCH_PRINT_REPRO_ON_FAILURE=0` to suppress this message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165643 Approved by: https://github.com/robert-hardwick, https://github.com/fadara01, https://github.com/jgong5, https://github.com/malfet Co-authored-by: Aditya Tewari Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .ci/pytorch/test.sh | 1 + test/inductor/test_mkldnn_pattern_matcher.py | 4 +--- torch/_inductor/fx_passes/mkldnn_fusion.py | 11 +++++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index fa884ecf2b52a..154e1cfa1b94a 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1734,6 +1734,7 @@ test_linux_aarch64() { inductor/test_split_cat_fx_passes inductor/test_compile inductor/test_torchinductor \ inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes inductor/test_memory \ inductor/test_triton_cpu_backend inductor/test_triton_extension_backend inductor/test_mkldnn_pattern_matcher inductor/test_cpu_cpp_wrapper \ + inductor/test_cpu_select_algorithm \ --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose } diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index e91b7b9339ca4..1001a8a9f997a 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -4287,9 +4287,7 @@ def forward(self, x, weight, scales): s = torch.randn(s_shape, dtype=torch.bfloat16) def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["woq_matcher_count"], 0 if TEST_ACL else 1 - ) + self.assertEqual(counters["inductor"]["woq_matcher_count"], 1) self._test_common( mod, diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 08252e58dd566..8f729596cbb1f 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -1562,16 +1562,19 @@ def _mkldnn_fusion_init(): # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. # Otherwise even the matmul or innerproduct can not be accelerated with acl if ( - torch.backends.mkldnn.enabled - and torch.backends.mkldnn.is_available() - and not torch.ops.mkldnn._is_mkldnn_acl_supported() + not torch.backends.mkldnn.enabled + or not torch.backends.mkldnn.is_available() ): + return + + if not torch.ops.mkldnn._is_mkldnn_acl_supported(): _register_unary_fusion() _register_inplace_fusion() _register_binary_unary_fusion() _register_binary_fusion() _register_quantization_lowerings() - _register_woq_lowerings() + + _register_woq_lowerings() @functools.cache def _mkldnn_weight_pack_init(): From e3f24fd73ad74c6e7176687986436956c7c18235 Mon Sep 17 00:00:00 2001 From: Robert Hardwick Date: Tue, 2 Dec 2025 20:17:18 +0000 Subject: [PATCH 183/338] Fix test_matmul_mv_cpu_float32 and Enable test_linalg for AArch64 (#167069) I have edited this PR multiple times. ### Problem Summary ### The problem we are trying to solve is that we see test_matmul_mv_cpu_float32 and test_matmul_mv_cpu_float16 and test_matmul_mv_cpu_bfloat16 fail for various reasons on AArch64. I have done a thorough investigation and some of these tests are passing simply by pure random chance. test_matmul_mv_cpu_bfloat16 passes in it's current form but can be made to fail simply by changing the random seed. test_matmul_mv_cpu_float32 fails but equally can be made to pass by setting the random seed. The failure message is this something like this ``` Mismatched elements: 50000 / 50000 (100.0%) Greatest absolute difference: 0.0859375 at index (0,) (up to 1e-05 allowed) Greatest relative difference: 3.4299155231565237e-06 at index (0,) (up to 1.3e-06 allowed) To execute this test, run the following from the base repo dir: python test/test_linalg.py TestLinalgCPU.test_matmul_mv_cpu_float32 ``` The test is the following ``` n = 50_000 A = torch.ones(n, n, dtype=dtype, device=device) B = torch.rand(n, dtype=dtype, device=device) C = torch.matmul(A, B) self.assertEqual(C, B.sum().expand(B.shape)) ``` The problem is that we are testing 2 different reduction paths for a very large N( torch.matmul and torch.sum will perform a different order of summation, especially with vectorization ). So it is entirely expected that C and B.sum() will have large differences. This is made worse by the use of the uniform distribution instead of the normal distribution. Below are 2 charts which demonstrate that errors tends to accumulate into larger magnitudes for the uniform distribution. ( Note the x axis is square root of N ) ### mean error calculated over different 100 seeds for torch.rand linalg_matmul ### mean error calculated over different 100 seeds for torch.randn randn There is some mathematical proof for this.... following a random walk model to calculate the expected error .. with normally distributed data the error scales with square root of N ( N^1/2 ) , where as for uniformly distributed data the error scales with N^3/2 . The reason being is that for normally distributed data the running sum tends to oscillate around 0, so the magnitude of the error stays roughly constant, where as for non-zero mean random data the sum grows and so does the magnitude of the error. ### The Fix ### - I have changed rand to randn. - Technically the comment says this is regression test for a segmentation fault https://github.com/pytorch/pytorch/issues/150637 so in theory we don't need to check for closeness, we could just put sanity checks in opn the size and shape. - For completeness we keep the assertEqual but i have overridden the tolerances with suitably large atol and rtol which have been made from the mathematical theory but also from evidence of running this with 100 seeds. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167069 Approved by: https://github.com/malfet, https://github.com/jondea --- .ci/pytorch/test.sh | 1 + test/test_linalg.py | 30 +++++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 154e1cfa1b94a..c597f1ee648ec 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1716,6 +1716,7 @@ test_linux_aarch64() { test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \ test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops profiler/test_memory_profiler \ distributed/elastic/timer/api_test distributed/elastic/timer/local_timer_example distributed/elastic/timer/local_timer_test \ + test_linalg \ --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose # Dynamo tests diff --git a/test/test_linalg.py b/test/test_linalg.py index cabc561277b35..eec9c173e8a14 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -17,6 +17,10 @@ from typing import Union, Optional from torch._prims_common import DimsType from packaging import version +from torch.testing._internal.common_device_type import ( + tol, + toleranceOverride +) from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, @@ -3824,11 +3828,11 @@ def run_test_atol(shape0, shape1, batch): # Test broadcasting of tol if a.ndim > 2: tolerances.append(make_tensor(a.shape[-3], dtype=torch.float32, device=device, low=0)) - for tol in tolerances: - actual = torch.linalg.matrix_rank(a, atol=tol) - actual_tol = torch.linalg.matrix_rank(a, tol=tol) + for tol_ in tolerances: + actual = torch.linalg.matrix_rank(a, atol=tol_) + actual_tol = torch.linalg.matrix_rank(a, tol=tol_) self.assertEqual(actual, actual_tol) - numpy_tol = tol if isinstance(tol, float) else tol.cpu().numpy() + numpy_tol = tol_ if isinstance(tol_, float) else tol_.cpu().numpy() expected = np.linalg.matrix_rank(a.cpu().numpy(), tol=numpy_tol) self.assertEqual(actual, expected) @@ -10151,13 +10155,29 @@ def gen_mat(w, h, use_transpose: bool = False): @dtypes(torch.float, torch.half, torch.bfloat16) @largeTensorTest('16GB') + @toleranceOverride({ + torch.float32: tol(atol=1e-05, rtol=1e-05), + torch.float16: tol(atol=0.6, rtol=1e-03), + torch.bfloat16: tol(atol=5.0, rtol=1e-03) + }) def test_matmul_mv(self, device, dtype): # Regression test for https://github.com/pytorch/pytorch/issues/150637 # Such matrix will take more than 4Gb in memory + + # It is expected that we have very large errors when we are summing + # 50,000 random numbers in low precision dtypes using 2 different + # reduction paths so atol,rtol values above reflect this. n = 50_000 A = torch.ones(n, n, dtype=dtype, device=device) - B = torch.rand(n, dtype=dtype, device=device) + B = torch.randn(n, dtype=dtype, device=device) C = torch.matmul(A, B) + + # Sanity Checks + self.assertEqual(C.shape, (n,)) + self.assertEqual(C.dtype, dtype) + self.assertFalse(torch.isnan(C).any()) + self.assertFalse(torch.isinf(C).any()) + self.assertEqual(C, B.sum().expand(B.shape)) @onlyCUDA From 18f3ca08f13b8de61307f5e8cd7d4cccb67e9d11 Mon Sep 17 00:00:00 2001 From: Chinmay Kuchinad Date: Wed, 3 Dec 2025 16:49:11 +0000 Subject: [PATCH 184/338] [ROCm] Enable StaticCudaLauncher for ROCm (#166492) This PR enables ROCm/HIP support for PyTorch's StaticCudaLauncher, which provides static compilation and launching of Triton kernels. The implementation has been tested on AMD MI300 and MI200 hardware. **Changes** **Python (torch/_inductor/runtime/)** - static_cuda_launcher.py: Added ROCm detection, .hsaco binary support, and ROCm-specific scratch parameter handling - triton_heuristics.py: Updated device type checks to support both cuda and hip **C++ (torch/csrc/)** - Module.cpp: Enabled StaticCudaLauncher for ROCm builds - inductor/static_cuda_launcher.cpp: Added HIP API equivalents for all CUDA driver calls - inductor/static_cuda_launcher.h: Updated header guard **Tests (test/inductor/)** - test_static_cuda_launcher.py: Removed @skipIfRocm decorators and updated binary file handling **Enabled Unit Tests** All tests in test/inductor/test_static_cuda_launcher.py now pass on ROCm: 1. test_basic 2. test_unsigned_integers 3. test_signed_integers 4. test_basic_1arg 5. test_constexpr 6. test_implied_constant 7. test_kernel_no_args 8. test_high_shared_mem 9. test_too_high_shared_mem 10. test_kernel_empty_tensor 11. test_kernel_many_args 12. test_basic_compile 13. test_incompatible_code 14. test_static_launch_user_defined_triton_kernels 15. test_empty_tensor 16. test_any 17. test_disable_static_cuda_launcher In addition to this, the following tests from test/inductor/test_codecache.py also pass: 1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False 2. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False 3. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True 4. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False 5. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False 6. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True The following tests are skipped since triton bundling is necessary for StaticCudaLauncher: 1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True 2. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True Pull Request resolved: https://github.com/pytorch/pytorch/pull/166492 Approved by: https://github.com/jeffdaily --- test/inductor/test_ck_backend.py | 1 + test/inductor/test_codecache.py | 9 +- test/inductor/test_static_cuda_launcher.py | 21 +--- .../_inductor/runtime/static_cuda_launcher.py | 56 ++++++++-- torch/_inductor/runtime/triton_heuristics.py | 11 +- torch/_inductor/utils.py | 5 + torch/csrc/Module.cpp | 2 +- torch/csrc/inductor/static_cuda_launcher.cpp | 102 ++++++++++++++++-- torch/csrc/inductor/static_cuda_launcher.h | 2 +- 9 files changed, 168 insertions(+), 41 deletions(-) diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 079be79fcc9d8..405e46d8ded52 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -235,6 +235,7 @@ def mm(a, b): Y_eager = a @ b torch.testing.assert_close(Y_compiled, Y_eager, equal_nan=True) + @unittest.skip("Autotune Mismatch being investigated") @unittest.skipIf(not torch.version.hip, "ROCM only") @unittest.mock.patch.dict(os.environ, _test_env) @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 1ab261051f4c6..e86a673ad813f 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -479,14 +479,17 @@ def test_remote_cache_load_function( if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if ( + device == "cuda" + and torch.version.hip is None + and dtype == torch.bfloat16 + and not SM80OrLater + ): raise unittest.SkipTest("requires SM80 or later") if use_static_cuda_launcher and not (device == "cuda" and bundle_triton): raise unittest.SkipTest( "Static cuda launcher requires cuda and triton bundling" ) - if use_static_cuda_launcher and TEST_WITH_ROCM: - raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM") def fn(x, y): return (x * 2, y @ y) diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 654bfd269f761..ec9586197d085 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -12,7 +12,6 @@ from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -39,8 +38,9 @@ def write_cubin_to_tmp(self, kernel: CompiledKernel) -> str: # Just used by tests for now. # TODO: derive cubin_path from wherever triton stores the cubin file on disk. tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False) + binary_key = "hsaco" if torch.version.hip else "cubin" with tmp_file: - tmp_file.write(kernel.asm["cubin"]) + tmp_file.write(kernel.asm[binary_key]) self.tmp_files.append(tmp_file) return tmp_file.name @@ -64,7 +64,6 @@ def _make_launcher( result.load_kernel(device_interface.current_device()) return result - @skipIfRocm def test_basic(self): @triton.jit def simple_kernel(arg0, arg1): @@ -91,7 +90,6 @@ def simple_kernel(arg0, arg1): # 2. triton relies on inspect.get_source to get the type annotations # so I can't even use exec() to generate the test cases. # So we'll just make a few kernels by hand - @skipIfRocm def test_unsigned_integers(self): @triton.jit def unsigned_integers( @@ -115,7 +113,6 @@ def unsigned_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_signed_integers(self): @triton.jit def signed_integers( @@ -139,7 +136,6 @@ def signed_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_basic_1arg(self): @triton.jit def simple_kernel_1_arg(arg0): @@ -164,7 +160,6 @@ def simple_kernel_1_arg(arg0): ) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_constexpr(self): # Constexprs are compiled directly into the cubin file, # so we never need to pass it to StaticCudaLauncher. @@ -193,7 +188,6 @@ def kernel_constexpr(arg0, CONSTANT: tl.constexpr): ) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_implied_constant(self): """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr""" @@ -246,7 +240,6 @@ def triton_red_fused_any_isinf_0( launcher.run(1, 1, 1, stream, arg0, arg2, 128) self.assertEqual(arg1, arg2) - @skipIfRocm def test_kernel_no_args(self): # Just an easy way to test incompatible number of arguments @triton.jit @@ -259,7 +252,6 @@ def kernel_no_op(): stream = device_interface.get_raw_stream(device_interface.current_device()) launcher.run(1, 1, 1, stream) - @skipIfRocm def test_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -283,7 +275,6 @@ def simple_kernel(arg0, arg1): launcher.run(1, 1, 1, stream, new_arg0, arg1) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_too_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -303,7 +294,6 @@ def simple_kernel(arg0, arg1): lambda: self._make_launcher(compiled_kernel), ) - @skipIfRocm def test_kernel_empty_tensor(self): # Triton kernel generated by torch.compile of the following: # @torch.compile() @@ -364,7 +354,6 @@ def triton_poi_fused_cat_0( launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel) self.assertEqual(buf0, buf1) - @skipIfRocm def test_kernel_many_args(self): N = 200 # Make 200 arguments @@ -405,7 +394,6 @@ class TestStaticTritonCompileResult(TestCase): Tests static cuda launcher with torch.compile() """ - @skipIfRocm def test_basic_compile(self): @torch.compile def foo(x, y): @@ -415,7 +403,6 @@ def foo(x, y): y = torch.randn(10, device="cuda") self.assertEqual(foo(x, y), x + y) - @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch("compile_threads", 1) def test_incompatible_code(self): @@ -438,7 +425,6 @@ def foo(x): lambda: foo(x), ) - @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch( {"compile_threads": 1, "static_launch_user_defined_triton_kernels": True} @@ -460,7 +446,6 @@ def foo(x): x2 = x.clone().detach_() self.assertEqual(foo(x), x2 + 5) - @skipIfRocm def test_empty_tensor(self): @torch.compile() def foo(x, y): @@ -472,7 +457,6 @@ def foo(x, y): result = foo(x, y) self.assertEqual(result, torch.cat(((x * 4), y + 10))) - @skipIfRocm def test_any(self): def fn(x): return ( @@ -492,7 +476,6 @@ def fn(x): compiled_result = compiled_fn(arg) self.assertEqual(eager_result, compiled_result) - @skipIfRocm def test_disable_static_cuda_launcher(self): @torch.compile def fn(x, y): diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index f48f351ce823a..a53ef35f4cf83 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -3,6 +3,7 @@ from typing import Any from typing_extensions import Unpack +from ..utils import is_rocm from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs from .triton_helpers import get_constexprs @@ -38,7 +39,20 @@ def __init__(self, kernel: CompiledKernel) -> None: # pyrefly: ignore [missing-attribute] self.name = kernel.src.fn.__name__ # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm.get("cubin", None) + if "hsaco" in kernel.asm: + # pyrefly: ignore [missing-attribute] + self.cubin_raw = kernel.asm["hsaco"] + + # pyrefly: ignore [missing-attribute] + elif "cubin" in kernel.asm: + # pyrefly: ignore [missing-attribute] + self.cubin_raw = kernel.asm["cubin"] + + else: + raise RuntimeError( + "Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm" + ) + # pyrefly: ignore [missing-attribute] self.cubin_path = kernel._cubin_path @@ -245,12 +259,42 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher - # Add a None if triton wants extra parameters for scratch spaces arg_tys = self.arg_tys - for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: - if has_scratch: - arg_tys = arg_tys + "O" - args = (*args, None) + + if is_rocm(): + # ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both + # global_scratch and profile_scratch parameters in the kernel signature, + # even when the kernel doesn't use them (i.e., when has_*_scratch is False). + # + # This differs fundamentally from CUDA, where these parameters are only + # present in the signature if the corresponding has_*_scratch flag is True. + # + # The flags indicate whether memory will be allocated/used: + # - has_global_scratch: Whether global scratch workspace is needed + # - has_profile_scratch: Whether profiling instrumentation is enabled + # + # However, regardless of flag values, we MUST always pass both parameters + # to match the HIP kernel ABI. Passing None is safe: + # + # - If scratch is not needed (has_*_scratch=False or scratch_size=0): + # The None becomes nullptr, which the kernel never dereferences + # + # - If scratch is needed (has_*_scratch=True and scratch_size>0): + # The None becomes nullptr initially, but the HIP runtime intercepts + # the kernel launch, allocates the required scratch memory based on + # kernel metadata, and replaces the nullptr with a valid pointer before + # the kernel actually executes + # + # Not passing both parameters causes segmentation faults because the kernel + # expects them at specific positions in the argument array. + arg_tys = arg_tys + "OO" + args = (*args, None, None) + + else: + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) # pyrefly: ignore [bad-argument-type] assert len(args) == len(arg_tys) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5a37a0afccb34..2e2cd8a8db780 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1613,9 +1613,8 @@ def can_statically_launch( return None def check_can_launch() -> StaticallyLaunchedCudaKernel: - if triton_meta.get("device_type") != "cuda": - # Only cuda kernels - raise CannotStaticallyLaunchKernel("Non-cuda device") + if triton_meta.get("device_type") not in ("cuda", "hip"): + raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device") if torch._inductor.config.cpp_wrapper: # If we're running with cpp wrapper, it doesn't @@ -1641,10 +1640,11 @@ def check_can_launch() -> StaticallyLaunchedCudaKernel: "static launch does not support launch attributes" ) + binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin" cubin_location = os.path.join( triton_cache_dir(triton_meta.get("device", 0)), triton_hash_to_path_key(kernel.hash), - f"{kernel.src.fn.__name__}.cubin", + f"{kernel.src.fn.__name__}.{binary_ext}", ) if not os.path.exists(cubin_location): @@ -1676,10 +1676,11 @@ def reload_cubin_path(self): When loading from cache on disk, we want to reload cubin files from their appropriate location on disc. """ + binary_ext = "hsaco" if torch.version.hip else "cubin" cubin_location = os.path.join( triton_cache_dir(self.compile_meta.get("device", 0)), triton_hash_to_path_key(self.kernel.hash), - f"{self.kernel.name}.cubin", + f"{self.kernel.name}.{binary_ext}", ) if not os.path.exists(cubin_location): if self.kernel.cubin_raw is not None: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 884d060a1b071..4d1ddc9ad4769 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3053,6 +3053,11 @@ def is_gpu(device: Optional[str]) -> bool: return device in GPU_TYPES +def is_rocm() -> bool: + """Check if we're running on ROCm/HIP platform.""" + return torch.version.hip is not None + + def device_need_guard(device: str) -> bool: return device != "mps" and is_gpu(device) # TODO: MPS does not expose streams now diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 61ef99e8086f9..4de6ba3976688 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2150,7 +2150,7 @@ PyObject* initModule() { #ifdef USE_CUDA torch::cuda::initModule(module); #endif -#if defined(USE_CUDA) && !defined(USE_ROCM) +#if defined(USE_CUDA) ASSERT_TRUE(StaticCudaLauncher_init(module)); #endif #ifdef USE_MPS diff --git a/torch/csrc/inductor/static_cuda_launcher.cpp b/torch/csrc/inductor/static_cuda_launcher.cpp index 59916b6763bfa..da61cd28c1b6f 100644 --- a/torch/csrc/inductor/static_cuda_launcher.cpp +++ b/torch/csrc/inductor/static_cuda_launcher.cpp @@ -1,7 +1,4 @@ -#if defined(USE_CUDA) && !defined(USE_ROCM) -// We disable this file from being hipified because there are CUDA drivers hip -// has not implemented yet. Also, we're passing in a cubin file directly, so it -// would take more work to support ROCM anyway. +#if defined(USE_CUDA) || defined(USE_ROCM) #include #include @@ -16,6 +13,11 @@ #include #include #include + +#if defined(USE_ROCM) +#include +#endif + /** Implements a static launcher for triton compiled CUDA kernels. Given a path to a cubin file, a function name, and some metadata, @@ -56,8 +58,14 @@ const at::cuda::NVRTC& nvrtc() { CUdeviceptr getPointer(PyObject* obj) { CUdeviceptr data_ptr = 0; + if (THPUtils_checkLong(obj)) { +#if defined(USE_ROCM) + data_ptr = reinterpret_cast(THPUtils_unpackUInt64(obj)); +#else data_ptr = THPUtils_unpackUInt64(obj); +#endif + return data_ptr; } if (obj == Py_None) { @@ -73,13 +81,25 @@ CUdeviceptr getPointer(PyObject* obj) { TORCH_CHECK( THPUtils_checkLong(ret), "data_ptr method of Pointer object must return 64-bit int"); + +#if defined(USE_ROCM) + data_ptr = reinterpret_cast(THPUtils_unpackUInt64(ret)); +#else data_ptr = THPUtils_unpackUInt64(ret); +#endif + if (!data_ptr) return data_ptr; CUdeviceptr dev_ptr = 0; +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute( + &dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute( &dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); +#endif + return dev_ptr; } @@ -98,6 +118,15 @@ CUfunction loadKernel( } CUmodule mod = nullptr; CUfunction func = nullptr; + +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str())); + AT_CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str())); + int shared_optin = 0; + AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( + &shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device)); + +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str())); AT_CUDA_DRIVER_CHECK( nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str())); @@ -106,6 +135,9 @@ CUfunction loadKernel( &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); + +#endif + // Shared memory logic from triton/third-party/nvidia/backend/driver.c // If we're using more than 48 KB of shared memory, and we have // access to more than 48 KB of shared memory on the device, @@ -124,6 +156,21 @@ CUfunction loadKernel( " Reducing block sizes or `num_stages` may help."); if (sharedMemBytes > SHARED_MEM_STATIC_MAX && shared_optin > SHARED_MEM_STATIC_MAX) { +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipFuncSetCacheConfig(func, hipFuncCachePreferShared)); + int shared_total = 0, shared_static = 0; + AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( + &shared_total, + hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, + device)); + AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( + &shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); + AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute( + func, + hipFuncAttributeMaxDynamicSharedMemorySize, + shared_optin - shared_static)); + +#else AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED)); int shared_total = 0, shared_static = 0; @@ -137,6 +184,7 @@ CUfunction loadKernel( func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); +#endif } return func; } @@ -152,6 +200,27 @@ inline void launchKernel( cudaStream_t stream) { // cta_args is always 1 for inductor generated triton kernels, // so we don't need to figure out grid dimension here +#if defined(USE_ROCM) + int device = 0; + AT_CUDA_DRIVER_CHECK(hipGetDevice(&device)); + int warp_size = 0; + AT_CUDA_DRIVER_CHECK( + hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device)); + + AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel( + func, + gridX, + gridY, + gridZ, + warp_size * numWarps, // blockDim.x + 1, // blockDim.y + 1, // blockDim.z + sharedMemBytes, + stream, + args, + nullptr)); + +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( func, gridX, @@ -164,6 +233,7 @@ inline void launchKernel( stream, args, nullptr)); +#endif } template @@ -269,11 +339,20 @@ PyObject* load_kernel(PyObject* self, PyObject* args) { CUdevice device = static_cast(device_ptr); // NOLINT CUfunction func = nullptr; func = loadKernel(filePath, funcName, sharedMemBytes, device); - // Taken from triton/nvidia/backend/driver.c + +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK( + hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func)); + AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( + &n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); + +#else AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func)); AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute( &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); + +#endif n_spills /= 4; // Return a tuple of CUFunction, n_regs, n_spills return Py_BuildValue( @@ -299,7 +378,6 @@ PyObject* launch_kernel_inner( std::array argStorage = {}; std::array kernelArgs = {}; parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data()); - launchKernel( func, gridX, @@ -386,13 +464,25 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) { Py_RETURN_NONE; } CUcontext pctx = nullptr; +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx)); +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); +#endif + if (!pctx) { // Ensure device context exists CUdevice device = 0; +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0)); + AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device)); + AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx)); +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0)); AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device)); AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx)); + +#endif } CUfunction func = reinterpret_cast(func_ptr); // NOLINT cudaStream_t cudaStream = reinterpret_cast(stream); // NOLINT diff --git a/torch/csrc/inductor/static_cuda_launcher.h b/torch/csrc/inductor/static_cuda_launcher.h index 517036b9975e6..6f3980172275b 100644 --- a/torch/csrc/inductor/static_cuda_launcher.h +++ b/torch/csrc/inductor/static_cuda_launcher.h @@ -1,5 +1,5 @@ #pragma once -#if defined(USE_CUDA) && !defined(USE_ROCM) +#if defined(USE_CUDA) #include #include From 7348cb355ff0a6f79cd4871215aea72185748734 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 2 Dec 2025 23:56:52 +0000 Subject: [PATCH 185/338] [Submodule] cutlass bump to minor bug fix (#169433) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169433 Approved by: https://github.com/danielvegamyhre ghstack dependencies: #169421 --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index e67e63c331d6e..f88806b1e31df 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit e67e63c331d6e4b729047c95cf6b92c8454cba89 +Subproject commit f88806b1e31dfa579842638740216dd41fc6c588 From 4ae684634d842f5228ab83e0d41d76c9bd0e4de1 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 2 Dec 2025 18:37:58 +0000 Subject: [PATCH 186/338] Add inductor core tests for Python 3.11/3.12 (#167542) Follow-up of PR #166978, which adds 2 new CI jobs to run the inductor core tests (`test/inductor/*`) on Python 3.11 and 3.12. Tests are splitted into 2 shards and each one takes around 1 hour to finish. The job is scheduled to run every night at 1:29 PDT. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167542 Approved by: https://github.com/malfet ghstack dependencies: #169032 --- .ci/pytorch/test.sh | 25 ++++++++++++++++++ .github/workflows/inductor-unittest.yml | 35 +++++++++++++++++++++++++ test/run_test.py | 16 +++++++++++ 3 files changed, 76 insertions(+) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index c597f1ee648ec..af6ddbcb838ce 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -461,6 +461,29 @@ test_inductor_distributed() { assert_git_not_dirty } +test_inductor_core() { + time python test/run_test.py \ + --include-inductor-core-tests \ + --exclude inductor/test_benchmark_fusion \ + inductor/test_cutlass_backend \ + inductor/test_flex_attention \ + inductor/test_max_autotune \ + inductor/test_aot_inductor_arrayref \ + inductor/test_aot_inductor_arrayref \ + inductor/test_compiled_autograd \ + inductor/test_compile_subprocess \ + inductor/test_cpu_cpp_wrapper \ + inductor/test_cpu_repro \ + inductor/test_cpu_select_algorithm \ + inductor/test_torchinductor_dynamic_shapes \ + inductor/test_torchinductor \ + inductor/test_mkldnn_pattern_matcher \ + inductor/test_torchinductor_codegen_dynamic_shapes \ + --verbose \ + --upload-artifacts-while-running + assert_git_not_dirty +} + test_inductor_shard() { if [[ -z "$NUM_TEST_SHARDS" ]]; then echo "NUM_TEST_SHARDS must be defined to run a Python test shard" @@ -1914,6 +1937,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then if [[ "$SHARD_NUMBER" -eq "1" ]]; then test_inductor_aoti_cpp fi +elif [[ "${TEST_CONFIG}" == *inductor_core* ]]; then + test_inductor_core elif [[ "${TEST_CONFIG}" == *inductor* ]]; then install_torchvision test_inductor_shard "${SHARD_NUMBER}" diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 9c1dd3d82769d..b4e8ca2526811 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -160,3 +160,38 @@ jobs: docker-image: ${{ needs.inductor-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-cpu-build.outputs.test-matrix }} secrets: inherit + + inductor-cpu-core-build: + name: inductor-cpu-core-build + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + strategy: + matrix: + python-version: ['3.11', '3.12'] + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py${{ matrix.python-version }}-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12 + test-matrix: | + { include: [ + { config: "inductor_core", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, + { config: "inductor_core", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, + ]} + secrets: inherit + + inductor-cpu-core-test: + name: inductor-cpu-core-test + uses: ./.github/workflows/_linux-test.yml + needs: [get-label-type, inductor-cpu-core-build] + strategy: + matrix: + python-version: ['3.11', '3.12'] + with: + build-environment: linux-jammy-py${{ matrix.python-version }}-clang12 + docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12 + test-matrix: | + { include: [ + { config: "inductor_core", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, + { config: "inductor_core", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, + ]} + secrets: inherit diff --git a/test/run_test.py b/test/run_test.py index 39b13980c2f04..349d0755360ec 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1361,6 +1361,16 @@ def parse_args(): "(including dynamo tests)." ), ) + parser.add_argument( + "--include-inductor-core-tests", + "--include-inductor-core-tests", + action="store_true", + help=( + "If this flag is present, we will only run inductor tests. " + "If this flag is not present, we will run all tests " + "(including inductor tests)." + ), + ) parser.add_argument( "--functorch", "--functorch", @@ -1633,6 +1643,12 @@ def get_selected_tests(options) -> list[str]: filter(lambda test_name: test_name in DYNAMO_CORE_TESTS, selected_tests) ) + # Filter to only run dynamo tests when --include-inductor-core-tests option is specified + if options.include_inductor_core_tests: + selected_tests = list( + filter(lambda test_name: test_name in INDUCTOR_TESTS, selected_tests) + ) + # Filter to only run functorch tests when --functorch option is specified if options.functorch: selected_tests = list( From 09076941a95c76f4d9ad189d064dfd8baa39e672 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 2 Dec 2025 18:37:58 +0000 Subject: [PATCH 187/338] Run dynamo/inductor core tests on Python 3.13 (#168293) This one should be merged after https://github.com/pytorch/pytorch/pull/166679 which moves CI to 3.14 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168293 Approved by: https://github.com/williamwen42, https://github.com/malfet ghstack dependencies: #169032, #167542 --- .github/workflows/dynamo-unittest.yml | 4 ++-- .github/workflows/inductor-unittest.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/dynamo-unittest.yml b/.github/workflows/dynamo-unittest.yml index e1399b1376de4..8177d64a2d5ee 100644 --- a/.github/workflows/dynamo-unittest.yml +++ b/.github/workflows/dynamo-unittest.yml @@ -36,7 +36,7 @@ jobs: needs: get-label-type strategy: matrix: - python-version: ['3.11', '3.12'] + python-version: ['3.11', '3.12', '3.13'] with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py${{ matrix.python-version }}-clang12 @@ -56,7 +56,7 @@ jobs: needs: [get-label-type, dynamo-build] strategy: matrix: - python-version: ['3.11', '3.12'] + python-version: ['3.11', '3.12', '3.13'] with: build-environment: linux-jammy-py${{ matrix.python-version }}-clang12 docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12 diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index b4e8ca2526811..22801a7f2f158 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -167,7 +167,7 @@ jobs: needs: get-label-type strategy: matrix: - python-version: ['3.11', '3.12'] + python-version: ['3.11', '3.12', '3.13'] with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py${{ matrix.python-version }}-clang12 @@ -185,7 +185,7 @@ jobs: needs: [get-label-type, inductor-cpu-core-build] strategy: matrix: - python-version: ['3.11', '3.12'] + python-version: ['3.11', '3.12', '3.13'] with: build-environment: linux-jammy-py${{ matrix.python-version }}-clang12 docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12 From 07dcc0b83db3211653a38565a24e15acdba75654 Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:50:15 +0000 Subject: [PATCH 188/338] [ROCm][CI] Add docker caching for MI250 runners (#169300) Enables docker image caching for MI250 CI runners. Tested via: https://github.com/pytorch/pytorch/actions/runs/19841041285 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169300 Approved by: https://github.com/jeffdaily --- .github/actionlint.yaml | 1 + .github/workflows/docker-cache-rocm.yml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index dfb30e155b162..46d0b2b20b127 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -59,6 +59,7 @@ self-hosted-runner: - linux.rocm.gpu.mi250 - linux.rocm.gpu.2 - linux.rocm.gpu.4 + - linux.rocm.mi250.docker-cache # gfx942 runners - linux.rocm.gpu.gfx942.1 - linux.rocm.gpu.gfx942.2 diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index ffb2007ca105f..0ce02dbc1de57 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -57,7 +57,7 @@ jobs: strategy: fail-fast: false matrix: - runner: [linux.rocm.gfx942.docker-cache] + runner: [linux.rocm.gfx942.docker-cache, linux.rocm.mi250.docker-cache] docker-image: [ "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}" From f7e1bd80a063e17453c361837ba6ea2570920a73 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:01:26 +0000 Subject: [PATCH 189/338] Pass GPU_FLAGS inside container during the test (#169397) By adding `cuda` to the `build-environment:` workflow parameter name Also move it to g6 runners to match the build args Pull Request resolved: https://github.com/pytorch/pytorch/pull/169397 Approved by: https://github.com/oulgen --- .github/workflows/inductor-unittest.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 22801a7f2f158..308e3bedf2ea0 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -96,7 +96,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" }, + { config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.12xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -105,7 +105,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-pallas-build with: - build-environment: linux-jammy-py3.12-gcc11 + build-environment: linux-jammy-cuda12.8-py3.12-gcc11 docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }} secrets: inherit From 1d21b4df2babe322e5d085ceb6de884eb260a62d Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Tue, 2 Dec 2025 11:55:26 -0800 Subject: [PATCH 190/338] Add more local test coverage for DTensor based tests (#169396) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169396 Approved by: https://github.com/dolpm --- .../tensor/parallel/test_parallelize_api.py | 20 ++++++++++++++++++- .../tensor/parallel/test_tp_style.py | 5 +++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py index 2ef70f1a447e3..017f61234a4ae 100644 --- a/test/distributed/tensor/parallel/test_parallelize_api.py +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -15,7 +15,9 @@ ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, + map_local_tensor_for_rank, MLPModule, MLPStacked, with_comms, @@ -78,7 +80,14 @@ def _compare_module( # check forward correctness local_output = local_module(inp) - inp = inp.chunk(self.world_size, dim=-1)[self.rank] if rowwise else inp + inp = map_local_tensor_for_rank( + inp, + self.rank, + lambda inp, rank: inp.chunk(self.world_size, dim=-1)[rank] + if rowwise + else inp, + ) + # inp = inp.chunk(self.world_size, dim=-1)[self.rank] if rowwise else inp dist_output = dist_module(inp) dist_output = ( dist_output.redistribute(dist_output.device_mesh, [Replicate()]).to_local() @@ -404,5 +413,14 @@ def test_empty_plan(self): parallelize_module(model, device_mesh) +TensorParallelAPITestsWithLocalTensor = create_local_tensor_test_class( + TensorParallelAPITests, + skipped_tests=[ + # Uses mesh_scatter that has local rank dependent logic + "test_parallelize_module_src_data_rank", + ], +) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py index b34d707a7e65e..7eb54cbde3a1c 100644 --- a/test/distributed/tensor/parallel/test_tp_style.py +++ b/test/distributed/tensor/parallel/test_tp_style.py @@ -19,6 +19,7 @@ from torch.distributed.tensor.placement_types import _Partial from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, NUM_DEVICES, RMSNormPython, @@ -434,5 +435,9 @@ def test_sequence_parallel_style(self): self.assertEqual(comm_mode.get_total_counts(), 2) +TensorParallelStyleTestWithLocalTensor = create_local_tensor_test_class( + TensorParallelStyleTest, +) + if __name__ == "__main__": run_tests() From 87329491c82a5f8c1cc4ec11d8f55a5de2551ece Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 3 Dec 2025 15:39:31 +0200 Subject: [PATCH 191/338] Add shims for setCurrentCUDAStream, getStreamFromPool, and synchronize. (#169376) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169376 Approved by: https://github.com/eqy, https://github.com/mikaylagawarecki --- .../csrc/test_cuda_stream.cu | 36 +++++++++++ .../libtorch_agnostic_2_10/ops.py | 51 +++++++++++++++ test/cpp_extensions/test_libtorch_agnostic.py | 63 +++++++++++++++++++ torch/csrc/cuda/shim_common.cpp | 30 +++++++++ torch/csrc/stable/c/shim.h | 11 ++++ 5 files changed, 191 insertions(+) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cuda_stream.cu diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cuda_stream.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cuda_stream.cu new file mode 100644 index 0000000000000..5daa429476d43 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cuda_stream.cu @@ -0,0 +1,36 @@ +#include +#include + +void* my_get_current_cuda_stream(int32_t device_index) { + void* ret_stream; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_index, &ret_stream)); + return ret_stream; +} + +void my_set_current_cuda_stream(void* stream, int32_t device_index) { + TORCH_ERROR_CODE_CHECK(torch_set_current_cuda_stream(stream, device_index)); +} + +void* my_get_cuda_stream_from_pool(bool isHighPriority, int32_t device_index) { + void* ret_stream; + TORCH_ERROR_CODE_CHECK(torch_get_cuda_stream_from_pool(isHighPriority, device_index, &ret_stream)); + return ret_stream; +} + +void my_cuda_stream_synchronize(void* stream, int32_t device_index) { + TORCH_ERROR_CODE_CHECK(torch_cuda_stream_synchronize(stream, device_index)); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_get_current_cuda_stream(int device_index) -> int"); + m.def("my_set_current_cuda_stream(int stream, int device_index) -> ()"); + m.def("my_get_cuda_stream_from_pool(bool isHighPriority, int device_index) -> int"); + m.def("my_cuda_stream_synchronize(int stream, int device_index) -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { + m.impl("my_get_current_cuda_stream", TORCH_BOX(&my_get_current_cuda_stream)); + m.impl("my_set_current_cuda_stream", TORCH_BOX(&my_set_current_cuda_stream)); + m.impl("my_get_cuda_stream_from_pool", TORCH_BOX(&my_get_cuda_stream_from_pool)); + m.impl("my_cuda_stream_synchronize", TORCH_BOX(&my_cuda_stream_synchronize)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index d53e481ca4a10..b1fca47322e1b 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -265,3 +265,54 @@ def my_string_op(t, accessor, passthru) -> tuple[list[str], int]: Returns: tuple - (list of [accessor, value, passthru] as strings, value) """ return torch.ops.libtorch_agnostic_2_10.my_string_op.default(t, accessor, passthru) + + +def my_get_current_cuda_stream(device_index: int) -> int: + """ + Return the current cudaStream_t pointer value. + + Args: + device_index: int - device index + """ + return torch.ops.libtorch_agnostic_2_10.my_get_current_cuda_stream.default( + device_index + ) + + +def my_set_current_cuda_stream(stream: int, device_index: int): + """ + Set the current stream to cudaStream_t pointer value. + + Args: + stream: int - cudaStream_t pointer value + device_index: int - device index + """ + return torch.ops.libtorch_agnostic_2_10.my_set_current_cuda_stream.default( + stream, device_index + ) + + +def my_get_cuda_stream_from_pool(high_priority: bool, device_index: int) -> int: + """ + Return the cudaStream_t pointer value from pool. + + Args: + high_priority: bool - if true, return a stream with high priority + device_index: int - device index + """ + return torch.ops.libtorch_agnostic_2_10.my_get_cuda_stream_from_pool.default( + high_priority, device_index + ) + + +def my_cuda_stream_synchronize(stream: int, device_index: int): + """ + Synchronize cuda stream. + + Args: + stream: int - cudaStream_t pointer value + device_index: int - device index + """ + return torch.ops.libtorch_agnostic_2_10.my_cuda_stream_synchronize( + stream, device_index + ) diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index dfb9b6b37f593..10f1bba1e3179 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -859,6 +859,69 @@ def test_my_string_op(self, device): with self.assertRaisesRegex(RuntimeError, "Unsupported accessor value: "): libtorch_agnostic.ops.my_string_op(t, "invalid", "") + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_my_get_current_cuda_stream(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + device_index = torch.device(device).index + res = libtorch_agnostic.ops.my_get_current_cuda_stream(device_index) + expected = torch.cuda.current_stream(device_index).cuda_stream + self.assertEqual(res, expected) + + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_my_set_current_cuda_stream(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + device_index = torch.device(device).index + prev_stream = torch.cuda.current_stream(device_index).cuda_stream + new_stream = torch.cuda.streams.Stream(device_index).cuda_stream + + try: + libtorch_agnostic.ops.my_set_current_cuda_stream( + new_stream, device_index + ) + expected = torch.cuda.current_stream(device_index).cuda_stream + self.assertEqual(new_stream, expected) + finally: + libtorch_agnostic.ops.my_set_current_cuda_stream( + prev_stream, device_index + ) + + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_my_get_cuda_stream_from_pool(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + device_index = torch.device(device).index + prev_stream = torch.cuda.current_stream(device_index).cuda_stream + + try: + for high_priority in [False, True]: + stream = libtorch_agnostic.ops.my_get_cuda_stream_from_pool( + high_priority, device_index + ) + libtorch_agnostic.ops.my_set_current_cuda_stream( + stream, device_index + ) + expected = torch.cuda.current_stream(device_index).cuda_stream + self.assertEqual(stream, expected) + finally: + libtorch_agnostic.ops.my_set_current_cuda_stream( + prev_stream, device_index + ) + + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_my_cuda_stream_synchronize(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + device_index = torch.device(device).index + stream = torch.cuda.current_stream(device_index).cuda_stream + # sanity check for torch_cuda_stream_synchronize: + libtorch_agnostic.ops.my_cuda_stream_synchronize(stream, device_index) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/cuda/shim_common.cpp b/torch/csrc/cuda/shim_common.cpp index cb5f28dba0152..24cee443bb1aa 100644 --- a/torch/csrc/cuda/shim_common.cpp +++ b/torch/csrc/cuda/shim_common.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -7,3 +8,32 @@ AOTITorchError torch_get_current_cuda_blas_handle(void** ret_handle) { *(cublasHandle_t*)(ret_handle) = at::cuda::getCurrentCUDABlasHandle(); }); } + +AOTITorchError torch_set_current_cuda_stream( + void* stream, + int32_t device_index) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::cuda::setCurrentCUDAStream(at::cuda::getStreamFromExternal( + static_cast(stream), device_index)); + }); +} + +AOTITorchError torch_get_cuda_stream_from_pool( + const bool isHighPriority, + int32_t device_index, + void** ret_stream) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + *(cudaStream_t*)(ret_stream) = + at::cuda::getStreamFromPool(isHighPriority, device_index); + }); +} + +AOTITorchError torch_cuda_stream_synchronize( + void* stream, + int32_t device_index) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::cuda::getStreamFromExternal( + static_cast(stream), device_index) + .synchronize(); + }); +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 202ca3ba40c05..545cb3eeb2c56 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -122,6 +122,17 @@ torch_string_c_str(StringHandle handle, const char** data); AOTI_TORCH_EXPORT AOTITorchError torch_get_current_cuda_blas_handle(void** ret_handle); +AOTI_TORCH_EXPORT AOTITorchError +torch_set_current_cuda_stream(void* stream, int32_t device_index); + +AOTI_TORCH_EXPORT AOTITorchError torch_get_cuda_stream_from_pool( + bool isHighPriority, + int32_t device_index, + void** ret_stream); + +AOTI_TORCH_EXPORT AOTITorchError +torch_cuda_stream_synchronize(void* stream, int32_t device_index); + #endif // USE_CUDA #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 From d038b0130ec7c20ebcac219301292fd8e98a1ace Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Wed, 3 Dec 2025 18:37:58 +0000 Subject: [PATCH 192/338] The Nested Pool (#168382) This PR fixes issue #161193 by simply reversing the iteration order over captures_underway. After discussing with @galv, we decided to land this minimal fix first to unblock nested MemPool usage. Long-term, the underlying infrastructure (e.g., captures_underway) still needs refactoring to clearly define the interaction between graph capture, MemPools, and threads. That broader cleanup will be addressed in #168137. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168382 Approved by: https://github.com/eqy, https://github.com/ngimel, https://github.com/galv --- c10/cuda/CUDACachingAllocator.cpp | 22 +++++++++++++--------- test/test_cuda.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 1d70edde5a4ca..3d1837061e7b2 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1838,9 +1838,11 @@ class DeviceCachingAllocator { if (graph_reuse_context.find(info.capture_id) == graph_reuse_context.end()) { bool found = false; - for (auto& entry : captures_underway) { - if (entry.second(stream)) { - auto graph_pool = graph_pools.find(entry.first); + // Use the reverse iterator to search captures_underway in LIFO order. + for (auto it = captures_underway.rbegin(); it != captures_underway.rend(); + ++it) { + if (it->second(stream)) { + auto graph_pool = graph_pools.find(it->first); TORCH_INTERNAL_ASSERT( graph_pool != graph_pools.end(), "Could not find graph pool for capture."); @@ -2530,10 +2532,10 @@ class DeviceCachingAllocator { std::function filter) { std::lock_guard lock(mutex); create_or_incref_pool(mempool_id); - for (auto it2 = captures_underway.begin(); it2 != captures_underway.end(); - ++it2) { + for (auto it = captures_underway.begin(); it != captures_underway.end(); + ++it) { TORCH_CHECK( - it2->first != mempool_id, + it->first != mempool_id, "beginAllocateToPool: already recording to mempool_id"); } captures_underway.emplace_back(mempool_id, std::move(filter)); @@ -2962,9 +2964,11 @@ class DeviceCachingAllocator { // a capture, so it's usually 0, and we can short-circuit // cudaStreamCaptureStatus (which does a TLS lookup). if (C10_UNLIKELY(!captures_underway.empty())) { - for (auto& entry : captures_underway) { - if (entry.second(stream)) { - auto it1 = graph_pools.find(entry.first); + // Use the reverse iterator to search captures_underway in LIFO order. + for (auto it = captures_underway.rbegin(); it != captures_underway.rend(); + ++it) { + if (it->second(stream)) { + auto it1 = graph_pools.find(it->first); TORCH_INTERNAL_ASSERT(it1 != graph_pools.end()); if (size <= kSmallSize) { return it1->second->small_blocks; diff --git a/test/test_cuda.py b/test/test_cuda.py index 5712187775ef6..1ad9769072c23 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -5710,6 +5710,37 @@ def my_function(pool): s = p.snapshot() self.assertEqual(len(s), 1, "Expected to have a single segment") + @serialTest() + def test_nested_mempool(self): + torch.cuda.empty_cache() + pool1 = torch.cuda.MemPool() + pool2 = torch.cuda.MemPool() + pool3 = torch.cuda.MemPool() + + data = [] + nelem_1mb = 1024 * 1024 // 4 + + def allocate_data(): + x = torch.empty(nelem_1mb * 20, device="cuda") + data.append(x) + + with torch.cuda.use_mem_pool(pool1): + allocate_data() + with torch.cuda.use_mem_pool(pool2): + allocate_data() + with torch.cuda.use_mem_pool(pool3): + allocate_data() + allocate_data() + allocate_data() + + pool1_segments = torch.cuda.memory.memory_snapshot(pool1.id) + pool2_segments = torch.cuda.memory.memory_snapshot(pool2.id) + pool3_segments = torch.cuda.memory.memory_snapshot(pool3.id) + + self.assertEqual(len(pool1_segments), 2) + self.assertEqual(len(pool2_segments), 2) + self.assertEqual(len(pool3_segments), 1) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) From 52ac0f0dc4acacd219f1317fbc28ec631c01e07a Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 3 Dec 2025 11:04:08 -0800 Subject: [PATCH 193/338] [CI] Run both CPU and CUDA Pallas tests Need to be better fixed later on, but it's a move in the right direction --- .ci/pytorch/test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index af6ddbcb838ce..44dff52974320 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1888,6 +1888,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then test_inductor_halide elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then + # NS: Remove me later, but pallas tests are pretty small + unset PYTORCH_TESTING_DEVICE_ONLY_FOR test_inductor_pallas elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then test_inductor_triton_cpu From 2887faaec6295d081580d09fce161201826c6d87 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 2 Dec 2025 23:05:52 -0800 Subject: [PATCH 194/338] [inductor_on_demand] Refactor - scoop regions and then compile regions (#169457) The intention is to give power users capability to just scoop out the regions in the beginning and then run their own passes and then finally compile the scooped out regions. It might be easier to work with a graph that already has scooped out regions for better abstraction. If not done this way, it is necessary to run the regional inductor pass at the very end of the compiler to avoid striding issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169457 Approved by: https://github.com/eellison, https://github.com/yushangdi --- torch/fx/passes/regional_inductor.py | 118 ++++++++++++++++++++------- 1 file changed, 87 insertions(+), 31 deletions(-) diff --git a/torch/fx/passes/regional_inductor.py b/torch/fx/passes/regional_inductor.py index 4146fd6c967bf..ae98950ab60b0 100644 --- a/torch/fx/passes/regional_inductor.py +++ b/torch/fx/passes/regional_inductor.py @@ -121,46 +121,100 @@ def _needs_inductor_compile(node: torch.fx.Node): ) -def _compile_fx_annotated_nodes_with_inductor(gm): - from torch.fx.graph import _BoxedCodeGen - from torch.fx.passes.operator_support import OperatorSupport +class _RegionScooper: + """ + Scoops out the inductor marked regions. It does NOT compile them. + """ - found_marked_node = False - for node in gm.graph.nodes: - if _needs_inductor_compile(node): - found_marked_node = True - break + @staticmethod + def scoop_regions(gm): + from torch.fx.passes.operator_support import OperatorSupport + + found_marked_node = False + for node in gm.graph.nodes: + if _needs_inductor_compile(node): + found_marked_node = True + break + + if not found_marked_node: + logger.info("No inductor marked nodes found") + return gm + + class InductorMarkedNodes(OperatorSupport): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return _needs_inductor_compile(node) + + marked_nodes = InductorMarkedNodes() + return _partition_by_supported_nodes( + gm, marked_nodes, "__marked_inductor_submod" + ) + + @staticmethod + def recursively_scoop_regions(gm): + for node in gm.graph.find_nodes(op="get_attr"): + if _needs_inductor_compile(node): + # If the get_attr itself is marked for compile, the outer graph will + # take care of it. If we dont do that, we end up with nested + # regional inductor compiles that do not work well. + continue + submod = getattr(gm, node.target) + if isinstance(submod, torch.fx.GraphModule): + _RegionScooper.recursively_scoop_regions(submod) + + return _RegionScooper.scoop_regions(gm) + + def __call__(self, gm): + with torch.fx.traceback.preserve_node_meta(enable=False): + return _RegionScooper.recursively_scoop_regions(gm) + + +class _RegionCompiler: + """ + Compiles the scooped out regions. + """ - if not found_marked_node: - logger.info("No inductor marked nodes found") + @staticmethod + def compile_region(gm): + from torch.fx.graph import _BoxedCodeGen + + gm = _compile_submod(gm, "__marked_inductor_submod") + gm.graph.set_codegen(_BoxedCodeGen()) + gm.recompile() return gm - class InductorMarkedNodes(OperatorSupport): - def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - return _needs_inductor_compile(node) + @staticmethod + def recursively_compile_regions(gm): + # Find if the graph module has a scooped out region + found_region = False + for node in gm.graph.find_nodes(op="call_module"): + submod = getattr(gm, node.target) + if isinstance(submod, torch.fx.GraphModule): + if node.target.startswith("__marked_inductor_submod"): + found_region = True - marked_nodes = InductorMarkedNodes() - gm = _partition_by_supported_nodes(gm, marked_nodes, "__marked_inductor_submod") - gm = _compile_submod(gm, "__marked_inductor_submod") + # Recurse through the subgraphs + for node in gm.graph.find_nodes(op="get_attr"): + submod = getattr(gm, node.target) + if isinstance(submod, torch.fx.GraphModule): + _RegionCompiler.recursively_compile_regions(submod) - gm.graph.set_codegen(_BoxedCodeGen()) - gm.recompile() + if found_region: + return _RegionCompiler.compile_region(gm) + return gm - return gm + def __call__(self, gm): + with torch.fx.traceback.preserve_node_meta(enable=False): + return _RegionCompiler.recursively_compile_regions(gm) -def _recursive_compile_fx_annotated_nodes_with_inductor(gm): - for node in gm.graph.find_nodes(op="get_attr"): - if _needs_inductor_compile(node): - # If the get_attr itself is marked for compile, the outer graph will - # take care of it. If we dont do that, we end up with nested - # regional inductor compiles that do not work well. - continue - submod = getattr(gm, node.target) - if isinstance(submod, torch.fx.GraphModule): - _recursive_compile_fx_annotated_nodes_with_inductor(submod) +def _create_inductor_marked_regions(gm): + with torch.fx.traceback.preserve_node_meta(enable=False): + return _RegionScooper()(gm) - return _compile_fx_annotated_nodes_with_inductor(gm) + +def _compile_inductor_marked_regions(gm): + with torch.fx.traceback.preserve_node_meta(enable=False): + return _RegionCompiler()(gm) @compatibility(is_backward_compatible=False) @@ -181,4 +235,6 @@ def regional_inductor(gm, *example_args): # fuser utils create new nodes using create_proxy which retains the seq_nr # metadata and cause issues with torch.fx.traceback.preserve_node_meta(enable=False): - return _recursive_compile_fx_annotated_nodes_with_inductor(gm) + gm = _create_inductor_marked_regions(gm) + gm = _compile_inductor_marked_regions(gm) + return gm From abfa1a6d65c7c159e35c72c25979b9da4971689e Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 19:32:43 +0000 Subject: [PATCH 195/338] [aot_compile] Pass additional globals to callable (#169070) Corresponding vllm change -- https://github.com/vllm-project/vllm/pull/29428 Fixes https://github.com/vllm-project/vllm/issues/27591 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169070 Approved by: https://github.com/zhxchen17 --- test/dynamo/test_aot_compile.py | 32 +++++++++++++++++++++++++ torch/_dynamo/aot_compile.py | 15 ++++++++---- torch/_dynamo/convert_frame.py | 41 +++++++++++++++++++++++++++++++++ torch/compiler/__init__.py | 7 ++++-- 4 files changed, 89 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py index 8ea9ca2bb72c0..3146a37cb661a 100644 --- a/test/dynamo/test_aot_compile.py +++ b/test/dynamo/test_aot_compile.py @@ -34,6 +34,11 @@ EPS = torch.tensor(1e-7) +class MooType: + def __init__(self, x): + self.x = x + + class CustomCompiledFunction(torch._dynamo.aot_compile.SerializableCallable): def __init__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]): self.gm = gm @@ -800,6 +805,33 @@ def compute(x, y): actual = compiled_fn(*inputs) self.assertEqual(expected, actual) + def test_external_refs_validation(self): + """Test that external refs tracking and f_globals parameter work correctly""" + + def fn(x, y): + return MooType(x + y) + + def make_inputs(): + return (torch.randn(3, 4), torch.randn(3, 4)) + + compiled_fn = torch.compile(fn, fullgraph=True).aot_compile((make_inputs(), {})) + test_inputs = make_inputs() + expected = fn(*test_inputs) + actual = compiled_fn(*test_inputs) + self.assertEqual(expected.x, actual.x) + compiled_fn.save_compiled_function(self.path()) + + with self.assertRaisesRegex(RuntimeError, "Missing required external ref"): + with open(self.path(), "rb") as f: + compiled_fn = torch.compiler.load_compiled_function(f) + + with open(self.path(), "rb") as f: + compiled_fn = torch.compiler.load_compiled_function( + f, f_globals=fn.__globals__ + ) + actual = compiled_fn(*test_inputs) + self.assertEqual(expected.x, actual.x) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py index 14309dbe15541..7bc03aff84a20 100644 --- a/torch/_dynamo/aot_compile.py +++ b/torch/_dynamo/aot_compile.py @@ -78,6 +78,7 @@ def reducer_override(self, obj: Any) -> Any: class AOTCompiledFunction: _artifacts: CompileArtifacts _guard_check_enabled: bool = True + _extra_globals: Optional[dict[str, object]] = None def guard_check(self, *args: Any, **kwargs: Any) -> bool: f_locals: dict[str, Any] = {} @@ -101,7 +102,9 @@ def __post_init__(self) -> None: # pyrefly: ignore [read-only] self.fn = self._artifacts.runtime_env.forward_callable( - self._artifacts.backend_id, self._artifacts.compiled_fn + self._artifacts.backend_id, + self._artifacts.compiled_fn, + extra_globals=self._extra_globals, ) if self._artifacts.guard_manager is None: @@ -149,7 +152,9 @@ def serialize(cls, fn: "AOTCompiledFunction") -> bytes: return buf.getvalue() @classmethod - def deserialize(cls, data: bytes) -> "AOTCompiledFunction": + def deserialize( + cls, data: bytes, f_globals: Optional[dict[str, object]] = None + ) -> "AOTCompiledFunction": from torch._dynamo.package import SerializedCode state = pickle.loads(data) @@ -163,7 +168,7 @@ def deserialize(cls, data: bytes) -> "AOTCompiledFunction": state["original_code"] = SerializedCode.to_code_object(state["original_code"]) artifacts = CompileArtifacts(**state) - return cls(artifacts) + return cls(artifacts, _extra_globals=f_globals) def disable_guard_check(self) -> None: self._guard_check_enabled = False @@ -271,7 +276,9 @@ def new_guard_filter_fn( device_type=device_type, backend_name=getattr(backend, "compiler_name", "unknown"), ) - aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts) + aot_compiled_fn = AOTCompiledFunction( + _artifacts=artifacts, _extra_globals=fn.__globals__ + ) return aot_compiled_fn diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 87dc80e99bd79..0da68fa5fe042 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -26,6 +26,7 @@ import collections import contextlib import cProfile +import dataclasses import dis import functools import gc @@ -932,6 +933,7 @@ class GraphRuntimeEnv: used_globals: dict[str, Any] closure: Optional[tuple[Any, ...]] argdefs: Optional[tuple[Any, ...]] + external_refs: set[str] = dataclasses.field(default_factory=set) def forward_callable( self, @@ -950,6 +952,10 @@ def forward_callable( **(extra_globals or {}), backend_id: compiled_fn, } + + # check that all external references are available + self._check_external_refs(f_globals) + return types.FunctionType( self.bytecode, f_globals, @@ -957,6 +963,18 @@ def forward_callable( argdefs=self.argdefs, ) + def _check_external_refs(self, f_globals: dict[str, Any]) -> None: + missing_refs = [] + for ref in self.external_refs: + if ref not in f_globals: + missing_refs.append(ref) + + if missing_refs: + raise RuntimeError( + f"Missing required external references: {missing_refs}. " + "Please load AOT compiled function with `f_globals=`" + ) + @dataclass class GraphCaptureOutput: @@ -1003,14 +1021,37 @@ def get_runtime_env(self) -> GraphRuntimeEnv: if global_name in self.f_globals: used_globals[global_name] = self.f_globals[global_name] + # Scan bytecode for all external references + external_refs = self._get_external_refs(self.bytecode) + return GraphRuntimeEnv( bytecode=self.bytecode, import_sources=self.import_sources, used_globals=used_globals, closure=self.closure, argdefs=self.argdefs, + external_refs=external_refs, ) + @staticmethod + def _get_external_refs(bytecode: types.CodeType) -> set[str]: + import dis + + external_refs: set[str] = set() + + # Get all instructions from the bytecode + for instruction in dis.get_instructions(bytecode): + # LOAD_GLOBAL loads a global variable or a builtin + if instruction.opname == "LOAD_GLOBAL": + if instruction.argval: + external_refs.add(instruction.argval) + # LOAD_NAME loads a name (used in module-level code, less common in functions) + elif instruction.opname == "LOAD_NAME": + if instruction.argval: + external_refs.add(instruction.argval) + + return external_refs + @dataclass class CaptureOutput: diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 809ec86fa5ec4..442cd7d765b89 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -668,7 +668,9 @@ def nested_compile_region(fn=None): return _mark_compile_region(fn) -def load_compiled_function(file: io.IOBase) -> Callable[..., Any]: +def load_compiled_function( + file: io.IOBase, *, f_globals: Optional[dict[str, object]] = None +) -> Callable[..., Any]: """ Load an aot-compiled function from a file. @@ -678,6 +680,7 @@ def load_compiled_function(file: io.IOBase) -> Callable[..., Any]: Args: file: A file-like object containing the serialized compiled function. + f_globals: Optional globals to be loaded into the compiled function. Returns: A torch-compiled function with compilation preloaded from disk. @@ -685,4 +688,4 @@ def load_compiled_function(file: io.IOBase) -> Callable[..., Any]: from torch._dynamo.aot_compile import AOTCompiledFunction data = file.read() - return AOTCompiledFunction.deserialize(data) + return AOTCompiledFunction.deserialize(data, f_globals) From 8c73bbbb02159223c0c97d268a0a74cb78158a1c Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 3 Dec 2025 11:56:36 -0600 Subject: [PATCH 196/338] [MPS] Migrate `clamp.Tensor_out` to Metal (#169407) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169407 Approved by: https://github.com/malfet --- aten/src/ATen/native/mps/MetalShaderLibrary.h | 1 + aten/src/ATen/native/mps/OperationUtils.mm | 91 +++++++++ .../native/mps/kernels/BinaryKernel.metal | 18 ++ .../native/mps/kernels/TensorCompare.metal | 25 +++ .../native/mps/operations/BinaryKernel.mm | 10 + .../native/mps/operations/TensorCompare.mm | 15 +- aten/src/ATen/native/native_functions.yaml | 3 +- c10/metal/indexing.h | 172 ++++++++++++++++++ 8 files changed, 329 insertions(+), 6 deletions(-) create mode 100644 aten/src/ATen/native/mps/kernels/TensorCompare.metal diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index fcdf39b8a9f4b..9a12220eca486 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -146,6 +146,7 @@ class MetalShaderLibrary { const std::string& name, const std::optional alpha = std::nullopt, const std::optional scalar_arg_type = std::nullopt); + void exec_ternary_kernel(TensorIteratorBase& iter, const std::string& name); template void exec_unary_kernel_with_params( diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index d5ed84aec5617..df06013492f57 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -1133,6 +1133,97 @@ static dispatch_data_t getSectionData(const std::string& name) { }); } +void MetalShaderLibrary::exec_ternary_kernel(TensorIteratorBase& iter, const std::string& name) { + // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?) + // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with + // double as common dtype (because Python floating point are always 64-bit values) + TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS"); + + // Skip for empty iterators + if (iter.numel() == 0) { + return; + } + + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_binary_kernel(sub_iter, name); + } + return; + } + + auto convert_double_scalar = [](Tensor& t) { + if (t.dim() != 0) { + return; + } + if (t.scalar_type() == kDouble) { + t = t.to(kFloat); + } else if (t.scalar_type() == kComplexDouble) { + t = t.to(kComplexFloat); + } + }; + + Tensor input = iter.input(0); + Tensor other1 = iter.input(1); + Tensor other2 = iter.input(2); + Tensor out = iter.output(); + + convert_double_scalar(input); + convert_double_scalar(other1); + convert_double_scalar(other2); + + MPSStream* mpsStream = getCurrentMPSStream(); + const auto cast_needed = + (input.scalar_type() != other1.scalar_type()) || (input.scalar_type() != other2.scalar_type()); + const auto suffix = iter.is_contiguous() ? "dense" : "strided"; + // TODO: Implicitly pass both input and output types to non-cast kernels + const auto kernel_name = cast_needed + ? fmt::format("{}_{}_cast_{}", name, suffix, scalarToMetalTypeString(out)) + : fmt::format("{}_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input)); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = mpsStream->commandEncoder(); + auto binaryPSO = getPipelineStateForFunc(kernel_name); + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other1, other2}); + [computeEncoder setComputePipelineState:binaryPSO]; + // Set input and output tensors + bind_iter_tensors(computeEncoder, iter); + // Iterator is contiguous if all of its elements are dense in storage, + // i.e. it's true for both row-first and column-first tensors + if (iter.is_contiguous()) { + if (cast_needed) { + std::array sizes = {static_cast(c10::elementSize(input.scalar_type())), + static_cast(c10::elementSize(other1.scalar_type())), + static_cast(c10::elementSize(other2.scalar_type()))}; + std::array types = {static_cast(input.scalar_type()), + static_cast(other1.scalar_type()), + static_cast(other2.scalar_type())}; + mtl_setArgs<4>(computeEncoder, sizes, types); + } + } else { + // Please note that shapes and strides of the iterator might be + // different than that of its operands, for example binary op + // between 4x4 tensor and scalar will result in 1D 16 element iterator + std::array types = {static_cast(input.scalar_type()), + static_cast(other1.scalar_type()), + static_cast(other2.scalar_type()), + static_cast(out.scalar_type())}; + mtl_setArgs<4>(computeEncoder, + iter.shape(), + iter.strides(0), + iter.strides(1), + iter.strides(2), + iter.strides(3), + iter.ndim(), + types); + } + mtl_dispatch1DJob(computeEncoder, binaryPSO, iter.numel()); + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); +} + MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() { static BundledShaderLibary l; return l; diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 5cb6dd38822a6..c0ac66b6cf501 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -60,6 +60,20 @@ struct fmin_functor { } }; +struct maximum_functor { + template + inline T operator()(const T a, const T b) { + return max(a, b); + } +}; + +struct minimum_functor { + template + inline T operator()(const T a, const T b) { + return min(a, b); + } +}; + struct copysign_functor { template inline enable_if_t, T> operator()( @@ -396,6 +410,10 @@ REGISTER_FLOAT_BINARY_OP(copysign); REGISTER_INT2FLOAT_BINARY_OP(copysign); REGISTER_FLOAT_BINARY_OP(fmax); REGISTER_FLOAT_BINARY_OP(fmin); +REGISTER_FLOAT_BINARY_OP(maximum); +REGISTER_INTEGER_BINARY_OP(maximum); +REGISTER_FLOAT_BINARY_OP(minimum); +REGISTER_INTEGER_BINARY_OP(minimum); REGISTER_FLOAT_BINARY_OP(nextafter); REGISTER_FLOAT_BINARY_OP(zeta); REGISTER_INT2FLOAT_BINARY_OP(zeta); diff --git a/aten/src/ATen/native/mps/kernels/TensorCompare.metal b/aten/src/ATen/native/mps/kernels/TensorCompare.metal new file mode 100644 index 0000000000000..0f34dfc898384 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/TensorCompare.metal @@ -0,0 +1,25 @@ +#include +#include +#include +#include +using namespace metal; + +struct clamp_functor { + template + inline T operator()(const T a, const T b_min, const T c_max) { + return c10::metal::min(c10::metal::max(a, b_min), c_max); + } +}; + +#define REGISTER_ALL_CLAMP_OPS(T) REGISTER_TERNARY_OP(clamp, T, T); + +REGISTER_ALL_CLAMP_OPS(long); +REGISTER_ALL_CLAMP_OPS(int); +REGISTER_ALL_CLAMP_OPS(short); +REGISTER_ALL_CLAMP_OPS(uchar); +REGISTER_ALL_CLAMP_OPS(char); +REGISTER_ALL_CLAMP_OPS(bool); + +REGISTER_ALL_CLAMP_OPS(float); +REGISTER_ALL_CLAMP_OPS(half); +REGISTER_ALL_CLAMP_OPS(bfloat); diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index f8baf2e7f1171..c08f828b26e08 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -75,6 +75,14 @@ static void fmin_mps_kernel(TensorIteratorBase& iter) { } } +static void maximum_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "maximum"); +} + +static void minimum_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "minimum"); +} + static void copysign_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "copysign"); } @@ -216,6 +224,8 @@ static void hypot_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel) REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel) +REGISTER_DISPATCH(maximum_stub, &maximum_mps_kernel) +REGISTER_DISPATCH(minimum_stub, &minimum_mps_kernel) REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel) REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel) REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel) diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index ed659bddd65cc..af8dad7671f26 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -23,6 +23,12 @@ #endif namespace at::native { +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif + namespace mps { struct CachedGraph : public MPSCachedGraph { @@ -374,10 +380,6 @@ static void is_posneginf_helper(TensorIteratorBase& iter, bool is_neg) { } // namespace mps // APIs exposed to at::native scope -TORCH_IMPL_FUNC(clamp_Tensor_out_mps) -(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t) { - mps::clamp_tensor_out_mps(input_t, min, max, output_t, __func__); -} TORCH_IMPL_FUNC(clamp_out_mps) (const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t) { @@ -604,8 +606,13 @@ static void isposinf_kernel_mps(TensorIteratorBase& iter) { mps::is_posneginf_helper(iter, false); } +static void clamp_kernel_mps(TensorIteratorBase& iter) { + lib.exec_ternary_kernel(iter, "clamp"); +} + REGISTER_DISPATCH(where_kernel, &where_kernel_mps) REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_mps) REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_mps) +REGISTER_DISPATCH(clamp_stub, &clamp_kernel_mps) } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 39df81ff44bce..a4d7797273d0d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1572,8 +1572,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_Tensor_out - MPS: clamp_Tensor_out_mps + CPU, CUDA, MPS: clamp_Tensor_out tags: pointwise - func: clamp_max(Tensor self, Scalar max) -> Tensor diff --git a/c10/metal/indexing.h b/c10/metal/indexing.h index 9cfe65f6a03a8..79cde5554fb25 100644 --- a/c10/metal/indexing.h +++ b/c10/metal/indexing.h @@ -475,5 +475,177 @@ kernel void binary_alpha_dense_cast( constant DTYPEA& alpha, \ constant uint4& sizes_types, \ uint tid) + +// Ternary elementwise ops kernels +// Right now there are 4 flavors available: +// - ternary_dense where both input, other1, other2, and output are dense and +// share the same type +// - ternary_strided when all inputs are of the same types, but some elements +// are strided +// - ternary_dense_cast - inputs are dense, but of different dtypes +// - ternary_strided_cast - inputs or output are strided and of different dtypes +// Note about accuracy (for more info see +// https://github.com/pytorch/pytorch/issues/152736) Sometimes when kernel is +// invoked to produce `half` output, but one of the arguments is float arguments +// should be upcast to float, rather than downcast to half At the moment this is +// expressed with `om_t` optional argument (which stands for opmath_type) which +// is identical to output type but could be something else + +template +kernel void ternary_strided( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other1 [[buffer(2)]], + constant void* other2 [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other1_strides [[buffer(7)]], + constant long* other2_strides [[buffer(8)]], + constant uint& ndim [[buffer(9)]], + constant uint4& types [[buffer(10)]], + uint index [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto other1_offs = offset_from_coord(pos, other1_strides, ndim); + const auto other2_offs = offset_from_coord(pos, other2_strides, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + const auto a = val_at_offs(input, input_offs); + const auto b = val_at_offs(other1, other1_offs); + const auto c = val_at_offs(other2, other2_offs); + ref_at_offs(output, output_offs) = + static_cast(f(om_t(a), om_t(b), om_t(c))); +} + +template > +kernel void ternary_strided_cast( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other1 [[buffer(2)]], + constant void* other2 [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other1_strides [[buffer(7)]], + constant long* other2_strides [[buffer(8)]], + constant uint& ndim [[buffer(9)]], + constant uint4& types [[buffer(10)]], + uint index [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto other1_offs = offset_from_coord(pos, other1_strides, ndim); + const auto other2_offs = offset_from_coord(pos, other2_strides, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + const auto a = + val_at_offs(input, input_offs, static_cast(types.x)); + const auto b = + val_at_offs(other1, other1_offs, static_cast(types.y)); + const auto c = + val_at_offs(other2, other2_offs, static_cast(types.z)); + ref_at_offs(output, output_offs) = static_cast(f(a, b, c)); +} + +template > +kernel void ternary_dense( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* other1 [[buffer(2)]], + constant T* other2 [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + out[tid] = static_cast( + f(om_t(input[tid]), om_t(other1[tid]), om_t(other2[tid]))); +} + +template +kernel void ternary_dense_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other1 [[buffer(2)]], + constant void* other2 [[buffer(3)]], + constant uint3& sizes [[buffer(4)]], + constant uint3& types [[buffer(5)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + const auto a = + val_at_offs(input, tid * sizes.x, static_cast(types.x)); + const auto b = val_at_offs( + other1, tid * sizes.y, static_cast(types.y)); + const auto c = val_at_offs( + other2, tid * sizes.z, static_cast(types.z)); + out[tid] = static_cast(f(a, b, c)); +} + +#define REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \ + static_assert( \ + ::metal::is_same_v< \ + DTYPEO, \ + ::c10::metal::result_of>, \ + "Output dtype mismatch for ternary op " #NAME " and input " #DTYPEI); \ + template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ + c10::metal::ternary_strided( \ + device void* out, \ + constant void* input, \ + constant void* other1, \ + constant void* other2, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other1_strides, \ + constant long* other2_strides, \ + constant uint& ndim, \ + constant uint4& types, \ + uint tid); \ + template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \ + metal::ternary_strided_cast( \ + device void* out, \ + constant void* input, \ + constant void* other1, \ + constant void* other2, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other1_strides, \ + constant long* other2_strides, \ + constant uint& ndim, \ + constant uint4& types, \ + uint tid); \ + template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ + c10::metal::ternary_dense( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant DTYPEI * input_, \ + constant DTYPEI * other1_, \ + constant DTYPEI * other2_, \ + uint tid); \ + template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \ + metal::ternary_dense_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant void* input, \ + constant void* other1, \ + constant void* other2, \ + constant uint3& sizes, \ + constant uint3& types, \ + uint tid) + +// OpMath ternary Op promotes inputs to higher precision type before Functor +// call +#define REGISTER_OPMATH_TERNARY_OP(NAME, DTYPEI, DTYPEO) \ + REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, ::c10::metal::opmath_t) + +#define REGISTER_TERNARY_OP(NAME, DTYPEI, DTYPEO) \ + REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI) + } // namespace metal } // namespace c10 From 39d07dbf03a911bdd45d1af78d8638dc92074938 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 3 Dec 2025 07:53:24 -0800 Subject: [PATCH 197/338] Use stable topological sort in fuse_by_partitions (#167397) legalize_graph() performs a topo sort that shuffles the nodes is a global way, making the result unpredictable. We should avoid this in graph pass in general. This problem is discovered when testing regional_inductor, a single fuse region trigger the global reordering. Before https://www.internalfb.com/intern/diffing/?before_paste_number=2029217728&after_paste_number=2029218006®ex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff After https://www.internalfb.com/intern/diffing/?paste_number=2029162294®ex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff Left is gm before regional_inductor, right is after. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167397 Approved by: https://github.com/ezyang --- docs/source/conf.py | 1 + test/allowlist_for_publicAPI.json | 2 +- torch/fx/passes/tools_common.py | 71 ++++++++++++++++++++++++++++ torch/fx/passes/utils/fuser_utils.py | 46 +++++++++++------- 4 files changed, 103 insertions(+), 17 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a3663ca062df..5c404f8c129fc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -950,6 +950,7 @@ "get_node_target", "is_node_output_tensor", "legalize_graph", + "stable_topological_sort", # torch.fx.passes.utils.common "compare_graphs", "lift_subgraph_as_module", diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index d01d41d37997e..b6c203aea4ab6 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2090,7 +2090,7 @@ "SimpleQueue", "Tuple", "compatibility", - "legalize_graph", + "stable_topological_sort", "lift_subgraph_as_module" ], "torch.fx.tensor_type": [ diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index 212b094e86e35..d6a8f0df84497 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import collections +import heapq import operator from collections.abc import Mapping from dataclasses import dataclass @@ -17,6 +18,7 @@ "is_node_output_tensor", "FxNetAccFusionsFinder", "legalize_graph", + "stable_topological_sort", ] Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]] @@ -258,6 +260,10 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: Returns: The graph module in-place sorted + + Warning: + This topological sort is NOT stable, it will NOT preserve the original node order. + If you need a stable topological sort, use stable_topological_sort instead. """ # These operators are used for making runtime assertions before any @@ -317,3 +323,68 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: new_graph._codegen = gm.graph._codegen gm.graph = new_graph return gm + + +@compatibility(is_backward_compatible=False) +def stable_topological_sort(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order while preserving the original node order + as much as possible. + + This function performs a stable topological sort where nodes appear in an order that: + 1. Respects data dependencies (topological ordering) + 2. Preserves the original node order when there are no dependency constraints + + The algorithm uses Kahn's algorithm with a priority queue: nodes with all dependencies + satisfied are added to a min-heap, ordered by their original position. This ensures + we always process the earliest node in the original order among ready nodes. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + Returns: + The graph module in-place sorted + """ + indeg = dict.fromkeys(gm.graph.nodes, 0) + new_graph = torch.fx.Graph() + + # Build node to original index mapping + node_to_id: dict[torch.fx.Node, int] = { + node: idx for idx, node in enumerate(gm.graph.nodes) + } + + # Track how many unfulfilled dependencies each node has + for node in gm.graph.nodes: + for user in node.users: + indeg[user] += 1 + + # Priority queue: (original_index, node) + # Use min-heap to always process the node with smallest original index + ready_queue: list[tuple[int, torch.fx.Node]] = [] + for node in gm.graph.nodes: + if indeg[node] == 0: + heapq.heappush(ready_queue, (node_to_id[node], node)) + + env: dict[torch.fx.Node, torch.fx.Node] = {} + + # Process nodes + while ready_queue: + # Pop node with smallest original index + _, cur = heapq.heappop(ready_queue) + env[cur] = new_graph.node_copy(cur, lambda x: env[x]) + + # Update in-degrees and add newly ready nodes + for user in cur.users: + indeg[user] -= 1 + if indeg[user] == 0: + heapq.heappush(ready_queue, (node_to_id[user], user)) + + # Check if all nodes were processed + assert len(new_graph.nodes) == len(gm.graph.nodes), ( + f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}" + ) + + new_graph._codegen = gm.graph._codegen + gm.graph = new_graph + return gm diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 33db9fd03d790..e5509187b39dd 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -7,7 +7,7 @@ from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet +from torch.fx.passes.tools_common import NodeList, NodeSet, stable_topological_sort from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] @@ -220,22 +220,36 @@ def insert_subgm( submodule_name = sub_gm.__class__.__name__ gm.add_submodule(submodule_name, sub_gm) - # Create a call_module node in main graph. - module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None) + def last_node(target_nodes: tuple[Node, ...]) -> Node | None: + for node in reversed(gm.graph.nodes): + if node in target_nodes: + return node + return None - output_node = sub_gm.graph.output_node() - if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple): - # main_remapping[comp.orig_outputs[0]] = module_node - orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True) - else: - for i, orig_output in enumerate(orig_outputs): - # Use Proxy to record getitem access. - proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] - orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + last_output_node: Node | None = last_node(orig_outputs) + assert last_output_node is not None - module_node.meta["val"] = tuple( - orig_output.meta.get("val", None) for orig_output in orig_outputs + # Create a call_module node in main graph. + with gm.graph.inserting_after(last_output_node): + module_node = gm.graph.call_module( + submodule_name, args=orig_inputs, kwargs=None ) + output_node = sub_gm.graph.output_node() + + next_node = module_node.next + with gm.graph.inserting_before(next_node): + if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple): + # main_remapping[comp.orig_outputs[0]] = module_node + orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True) + else: + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + + module_node.meta["val"] = tuple( + orig_output.meta.get("val", None) for orig_output in orig_outputs + ) return gm @@ -269,7 +283,7 @@ def fuse_by_partitions( erase_nodes(gm, sorted_nodes) - # topological sort original gm with newly created sub_gm - legalize_graph(gm) + stable_topological_sort(gm) + gm.graph.lint() return gm From 195f92e98d3d66738577f11f22c4b5c8a1c76dd5 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 2 Dec 2025 14:49:31 -0800 Subject: [PATCH 198/338] [pallas backend] fix breakages (#169406) test/inductor/test_pallas.py::PallasTestsCUDA::test_different_shapes test/inductor/test_pallas.py::PallasTestsCUDA::test_non_power_of_2_multiple_ops est/inductor/test_pallas.py::PallasTestsCUDA::test_non_power_of_2_sizes started failing since CI wasnt running, lets fix them Pull Request resolved: https://github.com/pytorch/pytorch/pull/169406 Approved by: https://github.com/yarongmu-google, https://github.com/malfet --- torch/_inductor/codegen/pallas.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 2ae68dbca575f..854adf3f53d34 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -988,7 +988,9 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove # Generate iteration variables as jnp.arange arrays # These are used by index_expr operations like torch.arange - if self.range_tree_nodes: + # Skip on GPU with masked ops - iteration vars would create non-power-of-2 arrays + # which are not supported by Pallas Triton backend + if self.range_tree_nodes and not self.use_masked_ops: code.writeline("# Define iteration variables as JAX arrays") # Get the first output buffer's shape for reshaping first_output_shape = None @@ -1020,6 +1022,11 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove except (TypeError, ValueError): length_val = None + # Skip symbolic lengths - jnp.arange requires concrete values + # This happens with dynamic shapes + if length_val is None: + continue + if ( first_output_shape and len(first_output_shape) > 1 From dfbd3714d15c37a7b83b322a6b60f997fc00f50c Mon Sep 17 00:00:00 2001 From: tianrengao Date: Wed, 3 Dec 2025 20:29:07 +0000 Subject: [PATCH 199/338] Enable custom collective op autotuning (#167294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add collective op autotuning with distributed benchmarking Collective operations can have multiple equivalent implementations with different performance. Existing autotuning only supports single-process benchmarking, making it impossible to autotune collective ops that require multi-rank coordination. **Summary:** Added distributed benchmarking via `benchmark_collective_choice()` and `register_custom_op_autotuning()` API. All ranks are synchronized and benchmarked at the same time to ensure a fair comparison. The benchmark across all possible ranks(world_size) has a timeout detection(default 30s, configurable in configs.py) and exception handling to prevent deadlocks. If any rank times out or encounters an exception, all ranks fallback to regular benchmark together. **Example:** ```python # Define custom op with multiple implementations @torch.library.custom_op("mylib::allreduce", mutates_args=()) def my_allreduce(x: torch.Tensor) -> torch.Tensor: return torch.ops._c10d_functional.all_reduce_(x.clone(), "sum", "default") # Register autotuning choices register_custom_op_autotuning( my_allreduce, configs=[ CustomOpConfig(lambda x: all_reduce_(x.clone(), "sum", "default")), CustomOpConfig(lambda x: all_reduce_(x.clone(), "avg", "default") * world_size), ], ) # Compile and autotune model = torch.compile(MyModel()) output = model(input) # Autotuning happens here ``` **Implementation:** Different from regular benchmark which calls benchmarker.benchmark with warmup and rep, collective op benchmarking should handle barrier synchronization and timing. So I created `collective_benchmark` for subgraphcaller choice to only run the implementation once with cached compiled modules. Then the timing and synchronization will be handled in select_algorithm.py **example log** ``` Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} [rank0]:W1126 14:09:52.629000 394749 torch/_inductor/select_algorithm.py:3996] [0/0] [COLLECTIVE AUTOTUNING] All timings: [rank0]:W1126 14:09:52.630000 394749 torch/_inductor/select_algorithm.py:3999] [0/0] - test::vllm_allreduce_autotuned_nccl_allreduce_direct_1: 0.004397 ms ← SELECTED [rank0]:W1126 14:09:52.630000 394749 torch/_inductor/select_algorithm.py:3999] [0/0] - test::vllm_allreduce_autotuned_vllm_buffer_copy_allreduce_0: 0.004507 ms [rank0]:W1126 14:09:52.630000 394749 torch/_inductor/select_algorithm.py:3999] [0/0] - test::vllm_allreduce_autotuned_fallback_default: 0.012899 ms ``` **Test Plan:** * Added `test_equivalent_allreduce_strategies` (2 ranks) and `test_all_gather_4ranks` (4 ranks) * Verified timeout detection and exception handling prevent deadlocks Pull Request resolved: https://github.com/pytorch/pytorch/pull/167294 Approved by: https://github.com/shunting314, https://github.com/mlazos, https://github.com/eellison --- test/inductor/test_collective_autotuning.py | 189 +++++++++++++++ torch/_inductor/codegen/subgraph.py | 55 ++++- torch/_inductor/config.py | 10 + torch/_inductor/kernel/custom_op.py | 47 ++-- torch/_inductor/select_algorithm.py | 244 +++++++++++++++++++- torch/_inductor/utils.py | 21 ++ 6 files changed, 532 insertions(+), 34 deletions(-) create mode 100644 test/inductor/test_collective_autotuning.py diff --git a/test/inductor/test_collective_autotuning.py b/test/inductor/test_collective_autotuning.py new file mode 100644 index 0000000000000..a5a05d05a9028 --- /dev/null +++ b/test/inductor/test_collective_autotuning.py @@ -0,0 +1,189 @@ +# Owner(s): ["module: inductor"] + +import torch +import torch.distributed as dist +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import run_tests + + +class TestCollectiveAutotuning2Ranks(MultiProcessTestCase): + """Test collective autotuning with 2 ranks""" + + @property + def world_size(self): + return 2 + + def setUp(self): + super().setUp() + self._spawn_processes() + + @skip_if_lt_x_gpu(2) + def test_equivalent_allreduce_strategies(self): + """ + Test autotuning between mathematically equivalent all_reduce strategies. + + Strategy 1: sum all_reduce + Strategy 2: avg all_reduce * world_size + """ + dist.init_process_group( + backend="nccl", + init_method=f"file:///tmp/test_equiv_allreduce_{self.id()}", + world_size=self.world_size, + rank=self.rank, + ) + + dist.barrier() + + rank = dist.get_rank() + device = f"cuda:{rank}" + + from torch._C._distributed_c10d import _register_process_group + + _register_process_group("default", dist.group.WORLD) + + @torch.library.custom_op("test::equiv_ar", mutates_args=()) + def equiv_ar(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + @equiv_ar.register_fake + def _(x): + return torch.empty_like(x) + + def sum_allreduce(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + def avg_allreduce_scaled(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + result = torch.ops._c10d_functional.all_reduce_(result, "avg", "default") + return result * self.world_size + + from torch._inductor.kernel.custom_op import ( + CustomOpConfig, + register_custom_op_autotuning, + ) + + register_custom_op_autotuning( + equiv_ar, + configs=[ + CustomOpConfig(sum_allreduce), + CustomOpConfig(avg_allreduce_scaled), + ], + ) + + class EquivAllReduceModel(torch.nn.Module): + def forward(self, x): + return equiv_ar(x) + + model = torch.compile(EquivAllReduceModel()).to(device) + + torch.manual_seed(42) + x = torch.randn(128, 128, device=device) + dist.broadcast(x, src=0) + + _ = model(x) + + dist.barrier() + dist.destroy_process_group() + + +class TestCollectiveAutotuning4Ranks(MultiProcessTestCase): + """Test collective autotuning with 4 ranks""" + + @property + def world_size(self): + return 4 + + def setUp(self): + super().setUp() + self._spawn_processes() + + @skip_if_lt_x_gpu(4) + def test_vllm_style_allreduce(self): + """ + Test vLLM-style custom allreduce with buffer copy pattern. + + vLLM uses custom allreduce optimized for small tensors (<8MB). + Two implementations simulate vLLM's registered=False mode vs standard NCCL. + """ + dist.init_process_group( + backend="nccl", + init_method=f"file:///tmp/test_vllm_allreduce_{self.id()}", + world_size=self.world_size, + rank=self.rank, + ) + + dist.barrier() + + rank = dist.get_rank() + device = f"cuda:{rank}" + + from torch._C._distributed_c10d import _register_process_group + + _register_process_group("default", dist.group.WORLD) + + @torch.library.custom_op("test::vllm_allreduce", mutates_args=()) + def vllm_allreduce(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + @vllm_allreduce.register_fake + def _(x): + return torch.empty_like(x) + + def vllm_buffer_copy_allreduce(x: torch.Tensor) -> torch.Tensor: + """ + vLLM registered=False: flatten -> copy to IPC buffer -> allreduce -> reshape + + vLLM code: + inp_size = inp.numel() * inp.element_size() + self.buffer_ptrs[self.rank][:inp_size].copy_(inp.view(-1)) + ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) + """ + original_shape = x.shape + flat_x = x.contiguous().view(-1) + buffer_copy = flat_x.clone() + result = torch.ops._c10d_functional.all_reduce_( + buffer_copy, "sum", "default" + ) + return result.view(original_shape) + + def nccl_allreduce_direct(x: torch.Tensor) -> torch.Tensor: + """Standard NCCL allreduce without buffer copy.""" + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + from torch._inductor.kernel.custom_op import ( + CustomOpConfig, + register_custom_op_autotuning, + ) + + register_custom_op_autotuning( + vllm_allreduce, + configs=[ + CustomOpConfig(vllm_buffer_copy_allreduce), + CustomOpConfig(nccl_allreduce_direct), + ], + ) + + class VLLMAllReduceModel(torch.nn.Module): + def forward(self, x): + return vllm_allreduce(x) + + model = torch.compile(VLLMAllReduceModel()).to(device) + + torch.manual_seed(42 + rank) + x = torch.randn(128, 256, device=device) + + y = model(x) + self.assertEqual(y.shape, x.shape) + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 1c1f0f1c9cd2c..7b931fb3bf47e 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -71,16 +71,25 @@ def __init__( self.sym_inputs = get_symbolic_inputs(self.input_nodes) + # Cache compiled module to avoid recompiling on every benchmark call + self._compiled_module: Any = None + self._compiled_sym_inputs: list[Any] | None = None + def __str__(self) -> str: return f"SubgraphCaller({self.name})" - def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: - # Codegen Subgraph for benchmarking - # Need GraphLowering instead of SubgraphLowering to generate - # fully callable module + def _compile_for_benchmarking(self, *args: list[Any]) -> tuple[Any, list[Any]]: + """ + Compile the subgraph for benchmarking and return (module, sym_inputs). + + TODO: Add precompile() method to enable parallel compilation of all choices + before benchmarking. + """ import torch._inductor.config as inductor_config from torch._inductor.graph import GraphLowering + safe_name = self.name.replace("::", "_").replace(".", "_") + bm_graph_lowering = GraphLowering( gm=self.gm, example_inputs=self.example_inputs, @@ -90,7 +99,7 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: extern_node_serializer=V.graph.extern_node_serializer, is_inference=V.graph.is_inference, is_backward=V.graph.is_backward, - name=f"benchmark_{self.name}", + name=f"benchmark_{safe_name}", ) for sym_inp in self.sym_inputs: @@ -123,9 +132,23 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: ): bm_graph_lowering.run(*self.example_inputs) mod = bm_graph_lowering.compile_to_module() - bm_func = mod.call - bm_func([*sym_inputs, *args]) + return mod, sym_inputs + + def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: + """ + Regular benchmarking: compile and use benchmarker with warmup/rep. + """ + if self._compiled_module is None: + mod, sym_inputs = self._compile_for_benchmarking(*args) + self._compiled_module = mod + self._compiled_sym_inputs = sym_inputs + else: + mod = self._compiled_module + sym_inputs = self._compiled_sym_inputs + assert sym_inputs is not None # Type narrowing + + bm_func = mod.call if config.profile_bandwidth_with_do_bench_using_profiling: return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args])) return benchmarker.benchmark( @@ -134,6 +157,24 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: device=benchmarker.infer_device(*sym_inputs, *args), ) + def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None: + """ + Only run once with cached compiled module. + Called by benchmark_collective_choice which handles warmup + and timing with barrier synchronization across all ranks. + """ + if self._compiled_module is None: + mod, sym_inputs = self._compile_for_benchmarking(*args) + self._compiled_module = mod + self._compiled_sym_inputs = sym_inputs + else: + mod = self._compiled_module + sym_inputs = self._compiled_sym_inputs + assert sym_inputs is not None # Type narrowing + + bm_func = mod.call + bm_func([*sym_inputs, *args]) + def hash_key(self) -> str: return "-".join( [ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 7ba93575ce8bf..fcfb8f51ae6e7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -608,6 +608,16 @@ def prologue_fusion_enabled() -> bool: # If autotuning in subprocess, whether to use multiple devices autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" +# Number of benchmark runs for collective operations +collective_benchmark_nruns = int( + os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_NRUNS", "50") +) + +# Timeout in seconds for collective benchmarking +collective_benchmark_timeout = float( + os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_TIMEOUT", "30") +) + coordinate_descent_tuning = ( os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" ) diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 12cc68dcb9844..c6a641ce83b17 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -6,7 +6,6 @@ from typing import Any, Optional, Union import torch -from torch._inductor import config from torch._inductor.codegen.subgraph import SubgraphTemplate from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox from torch._inductor.lowering import lowerings, validate_ir @@ -21,6 +20,28 @@ log = logging.getLogger(__name__) +def _detect_collective_ops(choices: list) -> bool: + """ + Detect if choices contain collective operations. + """ + from torch._inductor.utils import is_collective_op + + for choice in choices: + if not hasattr(choice, "gm") or choice.gm is None: + continue + + for node in choice.gm.graph.nodes: + if node.op == "call_function" and node.target is not None: + op_name = str(node.target) + + if is_collective_op(op_name) or is_collective_op( + f"torch.ops.{op_name}" + ): + return True + + return False + + class CustomOpConfig: """Config for custom op autotuning. @@ -180,14 +201,8 @@ def create_internal_input_gen_fn( """Create internal input generator that converts IR buffer to user's fake tensor.""" def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor: - raw_shape = ir_buffer.get_size() - concrete_shape = V.graph.sizevars.size_hints( - raw_shape, fallback=config.unbacked_symint_fallback - ) - - fake_tensor = torch.empty( - concrete_shape, dtype=ir_buffer.get_dtype(), device="meta" - ) + fake_tensor = ir_node_to_tensor(ir_buffer) + assert fake_tensor is not None, "ir_node_to_tensor returned None" return user_function(fake_tensor) return internal_input_gen_fn @@ -321,6 +336,8 @@ def autotune_custom_op( ) input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) + is_collective = _detect_collective_ops(choices) + # Run autotuning and get both result and winning choice selected_result, winning_choice = autotune_select_algorithm( name=name, @@ -329,6 +346,7 @@ def autotune_custom_op( layout=choices[0].layout, input_gen_fns=input_gen_fns, return_choice=True, + is_collective=is_collective, ) # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl) @@ -363,16 +381,7 @@ def _generate_dynamic_configs( param_names = list(sig.parameters.keys()) with V.fake_mode: - fake_tensors = [] - for inp in tensor_inputs: - raw_shape = inp.get_size() - concrete_shape = V.graph.sizevars.size_hints( - raw_shape, fallback=config.unbacked_symint_fallback - ) - fake_tensor = torch.empty( - concrete_shape, dtype=inp.get_dtype(), device=inp.get_device() - ) - fake_tensors.append(fake_tensor) + fake_tensors = [ir_node_to_tensor(inp) for inp in tensor_inputs] fake_tensors_dict = dict(zip(param_names, fake_tensors)) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 77448c914df80..df71bdd3db502 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2335,6 +2335,10 @@ def autoheuristic_id(self): class ExternKernelCaller(ChoiceCaller): + """ + Caller for external kernel implementations + """ + def __init__( self, choice: ExternKernelChoice, @@ -2370,6 +2374,19 @@ def benchmark(self, *args, out): return do_bench_using_profiling(lambda: algo(*args)) return benchmarker.benchmark(algo, args, {}) + def benchmark_collective(self, *args, out): + """ + Called by benchmark_collective_choice, only run once, timing handled externally with barrier sync. + """ + if out.numel() == 0: + return + + algo = self.to_callable() + if self.has_out_variant: + algo(*args, out=out) + else: + algo(*args) + def to_callable(self): fn = self.choice.to_callable() if self.kwargs: @@ -2733,6 +2750,7 @@ def __call__( return_multi_template=False, best_config_future=None, return_choice=False, # TODO: return_choice is temporary and will be refactored soon + is_collective=False, ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2843,6 +2861,7 @@ def get_timings(hint_override: Optional[int] = None): choices, precompile_fn, best_config_future=best_config_future, + is_collective=is_collective, ) # if timings is empty, we really have no choice but to return a semi-random # choice. returning the first `ExternKernelCaller` is probably the safest bet @@ -2874,6 +2893,7 @@ def get_timings(hint_override: Optional[int] = None): # if we got any timings at all, pick the best of those choice = min(timings, key=timings.__getitem__) node = choice.output_node() + log.debug("Autotuning selected choice: %s", node) if return_choice: return node, choice @@ -2886,12 +2906,18 @@ def benchmark( layout, input_gen_fns, hint_override: Optional[int] = None, + is_collective=False, ): counters["inductor"]["select_algorithm_autotune"] += 1 # TODO(nmacchioni): remove this layer of abstraction # construct `benchmark_fn` which should pick between in-process and sub-process autotuning benchmark_fn = self.make_benchmark_fn( - choices, input_nodes, layout, input_gen_fns, hint_override=hint_override + choices, + input_nodes, + layout, + input_gen_fns, + hint_override=hint_override, + is_collective=is_collective, ) # `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which # maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds @@ -2905,6 +2931,7 @@ def autotune( input_gen_fns, choices, hint_override: Optional[int] = None, + is_collective=False, ): log.debug("Starting autotuning") @@ -2915,7 +2942,12 @@ def autotune( metadata=_autotune_metadata(input_nodes), ): benchmark_results = self.benchmark( - choices, input_nodes, layout, input_gen_fns, hint_override=hint_override + choices, + input_nodes, + layout, + input_gen_fns, + hint_override=hint_override, + is_collective=is_collective, ) if config.max_autotune_report_choices_stats: _log_autotune_choices_stats( @@ -2934,6 +2966,7 @@ def do_autotuning( precompile_fn, hint_override: Optional[int] = None, best_config_future=None, + is_collective=False, ): """Execute the autotuning process for kernel algorithm selection. @@ -3071,6 +3104,7 @@ def track_has_autotuned(choices): input_gen_fns, choices, hint_override=hint_override, + is_collective=is_collective, ) timings = self.lookup( @@ -3084,6 +3118,17 @@ def track_has_autotuned(choices): autotune_elapse = time.time() - autotune_start_ts log.debug("Autotuning elapsed time: %.02fs", autotune_elapse) + # For collective: if any choice returned inf (timeout or failure), fallback to default + if is_collective and timings: + has_inf = any(not math.isfinite(timing) for timing in timings.values()) + if has_inf: + log.warning( + "At least one choice failed or timed out during collective benchmarking. " + "Falling back to default implementation." + ) + return {} + + # For regular: if all choices returned inf, raise error if timings and all(not math.isfinite(timing) for timing in timings.values()): raise NoValidChoicesError @@ -3100,6 +3145,7 @@ def track_has_autotuned(choices): precompile_elapse, prescreening_elapse, hint_override=hint_override, + is_collective=is_collective, ) def profiler_bench_function(): @@ -3461,16 +3507,162 @@ def benchmark_choice( autotune_args.verify(**VERIFY) return result + @classmethod + def _run_collective_benchmark( + cls, + choice: ChoiceCaller, + inputs: tuple, + output: torch.Tensor, + nruns: int, + process_group, + timeout, + ) -> float: + """ + Single function for benchmarking collective operations. + Used for both warmup and actual benchmarking. + + Returns total time in milliseconds, or raises TimeoutError if any collective times out. + """ + import torch.distributed as dist + + work = dist.barrier(group=process_group, async_op=True) + if not work.wait(timeout): + raise TimeoutError("Barrier timeout before benchmarking") + + torch.cuda.synchronize() + + total_time = 0.0 + + for i in range(nruns): + torch.cuda.synchronize() + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + start_evt.record() + choice.benchmark_collective(*inputs, out=output) # type: ignore[attr-defined] + end_evt.record() + end_evt.synchronize() + + total_time += start_evt.elapsed_time(end_evt) + + return total_time + + @classmethod + def benchmark_collective_choice( + cls, + choice: ChoiceCaller, + autotune_args: AutotuneArgs, + ) -> float: + """ + Benchmark a choice for collective operations with cross-rank synchronization. + This method ensures all ranks synchronize before benchmarking + to get accurate measurements for distributed collective operations. + + Timeout/Error handling: If ANY rank times out or encounters an error during + the collective operations, ALL ranks will naturally time out (since the collective + won't complete), allowing the autotuner to fall back to the default implementation. + """ + from datetime import timedelta + + import torch.distributed as dist + + timeout_seconds = config.collective_benchmark_timeout + + nruns = config.collective_benchmark_nruns + nwarmup = ir.autotune_warmup + + # Use default process group (None = all ranks) + process_group = None + rank = dist.get_rank(process_group) + + benchmark_tensors: BenchmarkTensors = autotune_args.get_benchmark_tensors( + cls._is_extern(choice) + ) + inputs, output = benchmark_tensors.unpack() + output.zero_() + + timeout = timedelta(seconds=timeout_seconds) + + try: + # Do n warmups + total_time = cls._run_collective_benchmark( + choice, inputs, output, nwarmup, process_group, timeout + ) + + # Do n actual benchmarking runs + total_time = cls._run_collective_benchmark( + choice, inputs, output, nruns, process_group, timeout + ) + + avg_time = total_time / nruns + + # All-reduce to get avg time across ranks + time_tensor = torch.tensor( + [avg_time], dtype=torch.float32, device=f"cuda:{rank}" + ) + work = dist.all_reduce( + time_tensor, + op=dist.ReduceOp.AVG, + group=process_group, + async_op=True, + ) + if not work.wait(timeout): + raise TimeoutError( + "All-reduce timeout when collecting benchmark results" + ) + + timing = time_tensor.item() + + log.info( + "Collective benchmark for %s: %.6f ms", + choice.name, + timing, + ) + + return timing + + except Exception: + log.warning( + "Collective benchmark exception for choice %s. Skipping this choice.", + getattr(choice, "name", ""), + exc_info=True, + ) + return float("inf") + @classmethod def benchmark_choices( cls, choices: Sequence[ChoiceCaller], autotune_args: AutotuneArgs, + is_collective: bool = False, ) -> dict[ChoiceCaller, float]: + """ + Benchmark a list of choices and return timing dict. + """ + if is_collective: + import torch.distributed as dist + + if not dist.is_initialized(): + log.warning( + "Collective op detected but distributed not initialized. " + "Falling back to regular benchmarking." + ) + is_collective = False + else: + rank = dist.get_rank(None) # Use default process group + log.debug( + "Using collective benchmarking for %d choices on rank %d", + len(choices), + rank, + ) timings = {} for choice in choices: try: - timing = cls.benchmark_choice(choice, autotune_args) + if is_collective: + timing = cls.benchmark_collective_choice(choice, autotune_args) + else: + timing = cls.benchmark_choice(choice, autotune_args) except CUDACompileError: from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -3524,6 +3716,16 @@ def benchmark_choices( timings[choice] = timing + # If a collective choice failed or timed out, skip the rest of the choices + if is_collective and not math.isfinite(timing): + log.warning( + "Choice %s failed or timed out during collective benchmarking. " + "Stopping further benchmarking to avoid NCCL corruption.", + getattr(choice, "name", ""), + ) + timings.update({c: float("inf") for c in choices if c not in timings}) + break + return timings @classmethod @@ -3534,11 +3736,16 @@ def benchmark_in_current_process( layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], hint_override: Optional[int] = None, + is_collective=False, ) -> dict[ChoiceCaller, float]: inputs = cls.get_inputs( choices, input_nodes, layout, input_gen_fns, hint_override=hint_override ) - return cls.benchmark_choices(choices, inputs) + return cls.benchmark_choices( + choices, + inputs, + is_collective=is_collective, + ) @classmethod def benchmark_in_sub_process( @@ -3570,21 +3777,24 @@ def make_benchmark_fn( layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], hint_override: Optional[int] = None, + is_collective=False, ): if DEBUG: print(f"{len(choices)} tuning requests:") - if config.autotune_in_subproc: + # Collective ops must use current process + if is_collective or not config.autotune_in_subproc: return functools.partial( - cls.benchmark_in_sub_process, + cls.benchmark_in_current_process, input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, hint_override=hint_override, + is_collective=is_collective, ) else: return functools.partial( - cls.benchmark_in_current_process, + cls.benchmark_in_sub_process, input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, @@ -3816,8 +4026,26 @@ def log_results( precompile_elapse: float, prescreening_elapse: Optional[float] = None, hint_override: Optional[int] = None, + is_collective: bool = False, ): - """Log the autotuning results, currently only handles mm and flex""" + """Log the autotuning results, currently only handles mm and flex. Log Collective op autotuning result""" + if is_collective and timings: + import torch.distributed as dist + + # Only rank 0 logs to avoid duplicate logs + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + best_choice = min(timings, key=timings.__getitem__) + log.warning("[COLLECTIVE AUTOTUNING] All timings:") + for c, t in sorted(timings.items(), key=lambda x: x[1]): + choice_name = getattr(c, "name", str(c)) + log.warning( + " - %s: %.6f ms %s", + choice_name, + t if math.isfinite(t) else float("inf"), + "← SELECTED" if c == best_choice else "", + ) + V.debug.log_autotuning_results( name, input_nodes, timings, elapse, precompile_elapse ) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 4d1ddc9ad4769..d7f3844cdf1ba 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -4126,3 +4126,24 @@ def should_fallback_by_default(node: torch.fx.Node) -> bool: return target in fallback_hops return not _needs_inductor_compile(node) + + +# Collective operation names for specialized benchmarking +COLLECTIVE_OPS = OrderedSet( + [ + "torch.ops._c10d_functional.all_reduce.default", + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.all_gather_into_tensor.default", + "torch.ops._c10d_functional.reduce_scatter_tensor.default", + "torch.ops._c10d_functional.all_to_all_single.default", + "torch.ops._c10d_functional_autograd.all_reduce.default", + "torch.ops._c10d_functional_autograd.all_gather_into_tensor.default", + "torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default", + "torch.ops._c10d_functional_autograd.all_to_all_single.default", + ] +) + + +def is_collective_op(op_name: str) -> bool: + """Check if an operation is a collective operation.""" + return op_name in COLLECTIVE_OPS From d40f4950f2b7f7aa380a22fe0f6166e71680fbcf Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 09:39:36 -0800 Subject: [PATCH 200/338] [hoo] Invoke subgraph + effects Inductor support (#167364) In order to support effect tokens with invoke subgraph in Inductor, I reimplemented how we handle with_effects, which now has a similar behavior as control_deps. Previously the behavior of with_effects is that we turn each with_effects(op, ..) call into an `ir.EffectfulKernel`, and the `op` is treated as a FallbackKernel. However in the case of `invoke_subgraph`, we want it to be an `ir.InvokeSubgraph` to trigger the invoke_subgraph inductor lowering logic. So now instead of turning with_effects into an `ir.EffectfulKernel`, I just add this dependency information to `V.graph.additional_star_deps`, similar to `V.graph.additional_buffer_deps` which is used by `control_deps`. Differential Revision: [D87668982](https://our.internmc.facebook.com/intern/diff/D87668982) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167364 Approved by: https://github.com/fxdawnn --- test/higher_order_ops/test_with_effects.py | 5 ++ test/inductor/test_torchbind.py | 20 ++--- torch/_inductor/graph.py | 1 + torch/_inductor/ir.py | 15 +++- torch/_inductor/lowering.py | 90 ++++++++++++++++++---- torch/_inductor/scheduler.py | 4 + 6 files changed, 110 insertions(+), 25 deletions(-) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index b7840c0729e27..c612c3a65ce0b 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -1012,6 +1012,11 @@ def forward(self, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): return (addmm_1,)""", # noqa: B950 ) + recorded_list.clear() + out2 = torch.compile(model)(x) + self.assertEqual(len(recorded_list), 4) + self.assertTrue(torch.allclose(model(x)[0], out2[0], atol=1e-7, rtol=1e-4)) + if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index c604f8450bbbf..88a39e14583f7 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -192,7 +192,7 @@ def test_torchbind_aot_compile(self): { "nodes": [ { - "name": "buf3", + "name": "buf1", "node": { "target": "_TorchScriptTesting::takes_foo_tuple_return", "inputs": [ @@ -208,20 +208,20 @@ def test_torchbind_aot_compile(self): }, { "name": "x", - "arg": {"as_tensor": {"name": "buf2"}}, + "arg": {"as_tensor": {"name": "buf0"}}, "kind": 1, }, ], "outputs": [ - {"as_tensor": {"name": "buf4"}}, - {"as_tensor": {"name": "buf5"}}, + {"as_tensor": {"name": "buf2"}}, + {"as_tensor": {"name": "buf3"}}, ], "metadata": {}, "is_hop_single_tensor_return": None, }, }, { - "name": "buf7", + "name": "buf5", "node": { "target": "_TorchScriptTesting::takes_foo", "inputs": [ @@ -237,17 +237,17 @@ def test_torchbind_aot_compile(self): }, { "name": "x", - "arg": {"as_tensor": {"name": "buf6"}}, + "arg": {"as_tensor": {"name": "buf4"}}, "kind": 1, }, ], - "outputs": [{"as_tensor": {"name": "buf8"}}], + "outputs": [{"as_tensor": {"name": "buf6"}}], "metadata": {}, "is_hop_single_tensor_return": None, }, }, { - "name": "buf9", + "name": "buf7", "node": { "target": "call_torchbind", "inputs": [ @@ -268,11 +268,11 @@ def test_torchbind_aot_compile(self): }, { "name": "_1", - "arg": {"as_tensor": {"name": "buf2"}}, + "arg": {"as_tensor": {"name": "buf0"}}, "kind": 1, }, ], - "outputs": [{"as_tensor": {"name": "buf10"}}], + "outputs": [{"as_tensor": {"name": "buf8"}}], "metadata": {}, "is_hop_single_tensor_return": None, }, diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index a16e09f3ca5cf..b136f7ab9eddf 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -390,6 +390,7 @@ def __init__( self.additional_buffer_deps: dict[str, OrderedSet[str]] = defaultdict( OrderedSet ) + self.additional_star_deps: dict[str, OrderedSet[str]] = defaultdict(OrderedSet) # Inplace padding may require Inductor to allocate slightly larger # tensor for padding. diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index de4b4ab20a779..8ba7ab9311b6c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -8662,11 +8662,22 @@ def create( fake_operands = None if eager_input_vals := current_node.meta.get("eager_input_vals"): # eager_input_vals is (args_values, kwargs_values). We need args for invoke_subgraph - fake_operands = eager_input_vals[0][2:] + offset = 2 + if current_node.target is torch.ops.higher_order.with_effects: + # Aruguments eagerly are (token, subgraph, identifier, *operands) + assert current_node.args[1] is torch.ops.higher_order.invoke_subgraph + offset = 3 + fake_operands = eager_input_vals[0][offset:] else: + offset = 2 + if current_node.target is torch.ops.higher_order.with_effects: + # with_effects args: (token, invoke_subgraph, subgraph, identifier, *operands) + assert current_node.args[1] is torch.ops.higher_order.invoke_subgraph + offset = 4 + # For the partitioned backward graph, we do not have # eager_input_vals. Here, we rely on the recorded example values. - fx_operands = current_node.args[2:] + fx_operands = current_node.args[offset:] fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] # Realize the inputs. Also intermediates can have different strides than diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 427997964bbb7..45f660f04674b 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -7449,26 +7449,90 @@ def _sink_tokens(tokens): return None +@register_lowering(torch.ops.prims._make_token.default) +def _make_token(): + return None + + @register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None) def with_effects(token, op, *args, **kwargs): - result = ir.EffectfulKernel.create(op, *args, **kwargs) - - from torch._higher_order_ops.effects import _get_effect + """ + We lower the operator directly, and then we add StarDep dependencies to all + the newly created nodes in the graph. + """ + from torch._higher_order_ops.effects import _get_effect, _get_schema + # Get effect type effect_type = _get_effect(op) - assert effect_type is not None - effectful_kernel = V.graph.effectful_ops[effect_type] + if effect_type is None and op is torch.ops.higher_order.invoke_subgraph: + from torch._guards import InvokeSubgraphCache, TracingContext - if result is None: - return (effectful_kernel,) + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + # args[1] is identifier + effects = invoke_subgraph_cache.get_effects(args[1]) + if effects: + assert len(effects) == 1, "Multiple effects NYI" + effect_type = next(iter(effects)) + + # Track operations before + operation_len = len(V.graph.operations) + + # Lower the op + if op in lowerings: + result = lowerings[op](*args, **kwargs) + # Realize so that we can get the ops to show up in V.graph.operations + pytree.tree_map_only(TensorBox, lambda a: a.realize(), result) + else: + + def wrap_tensors(x): + return TensorBox.create(x) if isinstance(x, ir.IRNode) else x + + result = pytree.tree_map( + wrap_tensors, ir.FallbackKernel.create(op, *args, **kwargs) + ) + + # Get all the operations created during the lowering above, and add StarDeps + # to the previous node with the same effect + assert len(V.graph.operations[operation_len:]) > 0, ( + f"No operation nodes were generated when lowering effectful operator {op}." + ) + if effect_type: + prev_effect_buffer = V.graph.effectful_ops.get(effect_type) + for new_op in V.graph.operations[operation_len:]: + # Patch has_side_effects to return True + new_op.has_side_effects = lambda: True # pyrefly: ignore[missing-attribute] + if prev_effect_buffer: + op_name = new_op.get_name() # pyrefly: ignore[missing-attribute] + V.graph.additional_star_deps[op_name].add(prev_effect_buffer.get_name()) + # Update the effectful ops chain to point to the latest operation + V.graph.effectful_ops[effect_type] = ( # pyrefly: ignore[missing-attribute] + new_op # pyrefly: ignore[unsupported-operation] + ) + + try: + args, kwargs = pytree.tree_map_only( + ir.TorchBindObject, lambda a: a.get_value(), (args, kwargs) + ) + schema = _get_schema(op, args, kwargs) + except RuntimeError as e: + error_msg = str(e) + log.warning( + "Failed to get schema for %s: %s. Assuming list output", op, error_msg + ) + return (token, *result) - result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result) - # See [NOTE: with_effects return type] - # Only return `result` if it is a tuple, not list. - if not isinstance(result, tuple): - return (effectful_kernel, result) + if len(schema.returns) == 0: + return (token, result) + elif len(schema.returns) == 1: + return (token, result) else: - return (effectful_kernel, *result) + return (token, *result) from .comm_lowering import register_comm_lowerings diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index f285a65470e78..aeaed244cab2f 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3073,6 +3073,10 @@ def add_user( add_user(add_dep, node, is_weak=True) node.add_fake_dep(WeakDep(add_dep, node.get_name())) + for add_dep in V.graph.additional_star_deps[node.get_name()]: + add_user(add_dep, node, is_weak=False) # Strong dependency + node.add_fake_dep(StarDep(add_dep)) + # add normal non-mutation dependencies for read in node.read_writes.reads: if not isinstance(read, WeakDep): From 201e2c4117eb9744594dad6a5c18213d7b4705d7 Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 3 Dec 2025 21:00:49 +0000 Subject: [PATCH 201/338] Fix update slow test workflow after #168334 (#169495) Looks like https://github.com/pytorch/pytorch/pull/168334 switched update slow test to python 3.10 however workflow itself is using python 3.9. And has been failing since then: https://github.com/pytorch/pytorch/actions/runs/19815030887/job/56764391794 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169495 Approved by: https://github.com/malfet --- .github/workflows/weekly.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/weekly.yml b/.github/workflows/weekly.yml index b95dadd5f2b1c..7bed6c785d4db 100644 --- a/.github/workflows/weekly.yml +++ b/.github/workflows/weekly.yml @@ -44,7 +44,7 @@ jobs: - name: Setup Python uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: - python-version: '3.9' + python-version: '3.10' - name: Install requirements shell: bash run: | From bc6a4863c7246a6493d16d4ea6eee71ec07c6a09 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Wed, 3 Dec 2025 21:04:02 +0000 Subject: [PATCH 202/338] [xpu][feature][2/N]Enable SDPA XPU FlashAttention backend with SYCL-TLA implementation (#167057) This is a PR to upstream [SYCL-TLA](https://github.com/intel/sycl-tla) version FlashAttention for Pytorch XPU. This is the second PR to register SYCL-TLA version FlashAttention forward/backward xpu kernels into SDPA's FlashAttention XPU backend. PR stacks: - https://github.com/pytorch/pytorch/pull/169101 - https://github.com/pytorch/pytorch/pull/167057 Currently, we support Intel Ponte Vecchio and Battlemage on Linux. In terms of other platform support, we are WIP. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167057 Approved by: https://github.com/EikanWang, https://github.com/drisspg Co-authored-by: Eikan Wang --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 14 +- aten/src/ATen/native/native_functions.yaml | 2 + .../ATen/native/transformers/attention.cpp | 14 +- test/test_transformers.py | 397 ++++++++++++++---- torch/_meta_registrations.py | 2 +- torch/csrc/Module.cpp | 8 +- .../aoti_torch/generated/c_shim_xpu.h | 2 + torch/nn/attention/bias.py | 10 +- 8 files changed, 337 insertions(+), 112 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 7be355b74c2f8..1dff18181b420 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -74,11 +75,6 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) { return sdp::check_tensor_dtype(params, supported_dtypes, debug); } -bool can_use_flash_attention(sdp::sdp_params const& params, bool debug) { - // Currently, XPU fallbacks flash attention to overridable - return can_use_overrideable_attention(params, debug); -} - bool can_use_cudnn_attention(sdp::sdp_params const& params, bool debug) { if (debug) { TORCH_WARN("XPU don't support SDPA cudnn attention backend."); @@ -142,10 +138,8 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { break; case sdp::SDPBackend::flash_attention: if (ctx.userEnabledFlashSDP() && - can_use_flash_attention(kernel_params, print_debug)) { - TORCH_WARN_ONCE( - "SDPA Flash Attention backend is not supported on XPU, falling back to OVERRIDEABLE backend."); - return sdp::SDPBackend::overrideable; + sdp::can_use_flash_attention(kernel_params, print_debug)) { + return sdp::SDPBackend::flash_attention; } break; case sdp::SDPBackend::cudnn_attention: @@ -172,7 +166,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { print_debug = true; TORCH_WARN("Flash attention kernel not used because:"); - can_use_flash_attention(kernel_params, print_debug); + sdp::can_use_flash_attention(kernel_params, print_debug); TORCH_WARN("Overrideable attention kernel not used because:"); can_use_overrideable_attention(kernel_params, print_debug); TORCH_WARN("CuDNN attention kernel not used because:"); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a4d7797273d0d..248a3e1875e55 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15132,6 +15132,7 @@ - func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) dispatch: CUDA: _scaled_dot_product_flash_attention_cuda + XPU: _scaled_dot_product_flash_attention_xpu NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda tags: nondeterministic_seeded @@ -15151,6 +15152,7 @@ variants: function dispatch: CUDA: _scaled_dot_product_flash_attention_backward_cuda + XPU: _scaled_dot_product_flash_attention_backward_xpu NestedTensorCUDA: _scaled_dot_product_flash_attention_backward_nested - func: _scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 7aad4309924d4..72326a8b5e249 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -614,8 +614,8 @@ at::Tensor preprocess_mask( // This causes the kernel to maybe alias query, key, value // So instead we pad the head_dimensions to be a multiple of 8 in the composite // region -template -at::Tensor pad_last_dim(const at::Tensor& attn_bias) { +template +at::Tensor pad_last_dim(const at::Tensor& attn_bias, int alignment_size) { auto last_dim_size = attn_bias.sym_size(-1); if (last_dim_size % alignment_size == 0) { return attn_bias; @@ -743,11 +743,13 @@ Tensor scaled_dot_product_attention( return std::get<0>(out_lse_softmax); } case SDPBackend::flash_attention: { - if(query_device_type == DeviceType::CUDA){ + if(query_device_type == DeviceType::CUDA || + query_device_type == DeviceType::XPU) { c10::SymInt og_size = query_.sym_size(-1); - Tensor query_padded = pad_last_dim<8, false>(query_); - Tensor key_padded = pad_last_dim<8, false>(key); - Tensor value_padded = pad_last_dim<8, false>(value); + int alignment_size = (query_device_type == DeviceType::XPU) ? 64 : 8; + Tensor query_padded = pad_last_dim(query_, alignment_size); + Tensor key_padded = pad_last_dim(key, alignment_size); + Tensor value_padded = pad_last_dim(value, alignment_size); // We need to calculate the scale based off the OG head dim size auto og_scale = sdp::calculate_scale(query_, scale); auto out_lse_softmax = at::_scaled_dot_product_flash_attention( diff --git a/test/test_transformers.py b/test/test_transformers.py index ad7ae56307eb1..1897548f560cf 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -4208,6 +4208,8 @@ class TestSDPAXpuOnly(NNTestCase): Mostly migrate from TestSDPACudaOnly in test/test_transformers.py """ + PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION = torch.xpu.is_available() and torch._C._is_flash_attention_available() + @parametrize("type", ["dense"]) @parametrize("dropout", [0.0, 0.7]) @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half]) @@ -4222,7 +4224,36 @@ def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: to else: assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.OVERRIDEABLE.value - def test_fused_attention_different_dk_dv(self, device): + def test_backends_set_to_math(self, device): + dtype = torch.bfloat16 + q_shape = SdpaShape(1, 1, 8, 16) + kv_shape = SdpaShape(1, 1, 12, 16) + make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) + make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) + q, k, v = make_q(), make_kv(), make_kv() + with sdpa_kernel(backends=[SDPBackend.MATH]): + self.assertTrue(torch._C._get_math_sdp_enabled()) + self.assertFalse(torch._C._get_overrideable_sdp_enabled()) + _ = F.scaled_dot_product_attention(q, k, v) + + def test_default_priority_order(self, device): + # The default priority order of xpu is overridable, math, flash, efficient, cudnn + # For xpu backend, we need to make sure that overridable > math > flash + dtype = torch.bfloat16 + shape = SdpaShape(1, 1, 1, 1) + make_tensor = partial(torch.rand, shape, device=device, dtype=dtype) + t = make_tensor() + # run sdp_choice to make sure priority_order is set by XPU default priority_order + torch._fused_sdp_choice(t, t, t) + from torch.nn.attention import _cur_sdpa_kernel_backends + default_priority = _cur_sdpa_kernel_backends(with_priority=True) + flash_index = default_priority.index(SDPBackend.FLASH_ATTENTION) + overrideable_index = default_priority.index(SDPBackend.OVERRIDEABLE) + math_index = default_priority.index(SDPBackend.MATH) + self.assertTrue(overrideable_index < math_index < flash_index, + f"Expected overrideable < math < flash, got {overrideable_index}, {math_index}, {flash_index}") + + def test_onednn_attention_different_dk_dv(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64 @@ -4231,51 +4262,16 @@ def test_fused_attention_different_dk_dv(self, device): v_shape = SdpaShape(batch, num_heads, 2, head_dim_v) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) - actual = F.scaled_dot_product_attention( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + with sdpa_kernel([SDPBackend.OVERRIDEABLE]): + actual = F.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) - math_ref = torch.ops.aten._scaled_dot_product_attention_math( - query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0] + with sdpa_kernel([SDPBackend.MATH]): + math_ref = F.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @parametrize("dtype", [torch.half, torch.bfloat16]) - @parametrize("batch_size,n_head,n_head_kv,q_size,kv_size,head_dim", [ - (2, 64, 16, 9216, 77, 64), - (2, 32, 4, 2304, 2304, 64), - (2, 32, 2, 2304, 77, 64), - (2, 20, 2, 576, 576, 64), - (2, 20, 2, 576, 77, 64), - (2, 20, 2, 144, 144, 64), - (2, 20, 2, 144, 77, 64), - (1, 32, 2, 1, 32, 128), - (4, 32, 4, 1, 32, 128), - (1, 32, 2, 32, 32, 128), - (4, 32, 4, 32, 32, 128), - (1, 32, 2, 2016, 2016, 128), - (4, 32, 4, 2016, 2016, 128), - ]) - @parametrize("is_causal", [True, False]) - def test_fused_attention_gqa(self, device, dtype, batch_size, n_head, n_head_kv, q_size, kv_size, head_dim, is_causal): - tol = Tolerances(1e-5, 5e-6) - if dtype is torch.bfloat16: - tol = Tolerances(5e-2, 5e-2) - if dtype is torch.float16: - tol = Tolerances(1e-2, 1e-2) - make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) - q_shape = SdpaShape(batch_size, n_head, q_size, head_dim) - k_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) - v_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) - query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) - - actual = F.scaled_dot_product_attention( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) - - math_ref = torch.ops.aten._scaled_dot_product_attention_math( - query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True)[0] - - self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=tol.atol, rtol=tol.rtol) - def test_onednn_attention_fail_d576(self, device): # Test that onednn graph attention dispatching correctly bails out on d > 576 b, h = 1, 2 @@ -4290,7 +4286,7 @@ def test_onednn_attention_fail_d576(self, device): with self.assertRaisesRegex(RuntimeError, "No available kernel."): _ = F.scaled_dot_product_attention(q, k, v) - def test_fused_attention_broadcasted_input(self, device): + def test_onednn_attention_broadcasted_input(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) batch, num_heads, seqlen, head_dim = 32, 16, 128, 32 @@ -4304,15 +4300,17 @@ def test_fused_attention_broadcasted_input(self, device): attn_mask = attn_mask.expand(1, 1, seqlen, seqlen) # test that we do not dispatch to onednn for an unsupported case - actual = F.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]): + actual = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) - math_ref = torch.ops.aten._scaled_dot_product_attention_math( - query.float(), key.float(), value.float(), attn_mask=attn_mask, dropout_p=0.0, is_causal=False)[0] + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - def test_attention_preserves_query_layout(self, device): + def test_onednn_attention_preserves_query_layout(self, device): def test_attention(permute_order: list[list[int]]): BHSqD = [4, 16, 256, 64] @@ -4328,7 +4326,8 @@ def test_attention(permute_order: list[list[int]]): self.assertEqual(k.shape, BHSkvD) self.assertEqual(v.shape, BHSkvD) - out = F.scaled_dot_product_attention(q, k, v) + with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]): + out = F.scaled_dot_product_attention(q, k, v) self.assertTrue(out.permute(permute_order).is_contiguous()) permutable = [0, 1, 2] @@ -4337,36 +4336,7 @@ def test_attention(permute_order: list[list[int]]): for permute_order in permute_orders: test_attention(list(permute_order) + [3]) - def test_backends_set_to_math(self, device): - dtype = torch.bfloat16 - q_shape = SdpaShape(1, 1, 8, 16) - kv_shape = SdpaShape(1, 1, 12, 16) - make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) - make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) - q, k, v = make_q(), make_kv(), make_kv() - with sdpa_kernel(backends=[SDPBackend.MATH]): - self.assertTrue(torch._C._get_math_sdp_enabled()) - self.assertFalse(torch._C._get_overrideable_sdp_enabled()) - _ = F.scaled_dot_product_attention(q, k, v) - - def test_default_priority_order(self, device): - # The default priority order of xpu is overridable, math, flash, efficient, cudnn - # For xpu backend, we need to make sure that overridable > math > flash - dtype = torch.bfloat16 - shape = SdpaShape(1, 1, 1, 1) - make_tensor = partial(torch.rand, shape, device=device, dtype=dtype) - t = make_tensor() - # run sdp_choice to make sure priority_order is set by XPU default priority_order - torch._fused_sdp_choice(t, t, t) - from torch.nn.attention import _cur_sdpa_kernel_backends - default_priority = _cur_sdpa_kernel_backends(with_priority=True) - flash_index = default_priority.index(SDPBackend.FLASH_ATTENTION) - overrideable_index = default_priority.index(SDPBackend.OVERRIDEABLE) - math_index = default_priority.index(SDPBackend.MATH) - self.assertTrue(overrideable_index < math_index < flash_index, - f"Expected overrideable < math < flash, got {overrideable_index}, {math_index}, {flash_index}") - - def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device): + def test_onednn_attention_fused_kernels_safe_softmax(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) batch, num_heads, seqlen, head_dim = 32, 16, 32, 64 @@ -4377,17 +4347,18 @@ def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device): attn_mask = torch.full((seqlen, seqlen), float('-inf'), device=device, dtype=torch.bfloat16) - actual = F.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) - - math_ref = torch.ops.aten._scaled_dot_product_attention_math( - query.float(), key.float(), value.float(), attn_mask=attn_mask, dropout_p=0.0, is_causal=False)[0] + with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]): + actual = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) @parametrize("type", ["dense"]) @parametrize("is_contiguous", [True, False]) - def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): + def test_onednn_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 @@ -4409,12 +4380,53 @@ def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: s with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]): actual = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) - math_ref = torch.ops.aten._scaled_dot_product_attention_math( - query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False)[0] + + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = torch.nn.functional.scaled_dot_product_attention( + query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False) self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) - @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE]) + @parametrize("dtype", [torch.half, torch.bfloat16]) + @parametrize("batch_size,n_head,n_head_kv,q_size,kv_size,head_dim", [ + (2, 64, 16, 9216, 77, 64), + (2, 32, 4, 2304, 2304, 64), + (2, 32, 2, 2304, 77, 64), + (2, 20, 2, 576, 576, 64), + (2, 20, 2, 576, 77, 64), + (2, 20, 2, 144, 144, 64), + (2, 20, 2, 144, 77, 64), + (1, 32, 2, 1, 32, 128), + (4, 32, 4, 1, 32, 128), + (1, 32, 2, 32, 32, 128), + (4, 32, 4, 32, 32, 128), + (1, 32, 2, 2016, 2016, 128), + (4, 32, 4, 2016, 2016, 128), + ]) + @parametrize("is_causal", [True, False]) + def test_onednn_attention_gqa_vs_math(self, device, dtype, batch_size, n_head, n_head_kv, q_size, kv_size, head_dim, is_causal): + tol = Tolerances(1e-5, 5e-6) + if dtype is torch.bfloat16: + tol = Tolerances(5e-2, 5e-2) + if dtype is torch.float16: + tol = Tolerances(1e-2, 1e-2) + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) + q_shape = SdpaShape(batch_size, n_head, q_size, head_dim) + k_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) + v_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) + query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]): + actual = F.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) + + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = F.scaled_dot_product_attention( + query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) + + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=tol.atol, rtol=tol.rtol) + + @parametrize("fused_kernel", [SDPBackend.OVERRIDEABLE]) @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32]) @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [ (2, 5, 9216, 9216, 64), @@ -4434,7 +4446,7 @@ def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: s ]) @parametrize("mask_type", ["float", "causal"]) @parametrize("train", [False]) - def test_scaled_dot_product_fused_attention_mask_vs_math( + def test_onednn_attention_mask_vs_math( self, device, fused_kernel, @@ -4501,6 +4513,213 @@ def test_scaled_dot_product_fused_attention_mask_vs_math( self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol) + @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + @parametrize("dtype", [torch.float32, torch.float64]) + def test_flash_attention_unsupport_dtypes(self, device, dtype): + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) + batch, num_heads, seqlen, head_dim = 32, 16, 32, 64 + q_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + k_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + v_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + q, k, v = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + with self.assertRaisesRegex(RuntimeError, "No available kernel"): + F.scaled_dot_product_attention(q, k, v) + + @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + def test_flash_attention_unsupport_dropout(self, device): + dtype = torch.bfloat16 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) + batch, num_heads, seqlen, head_dim = 32, 16, 32, 64 + q_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + k_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + v_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + q, k, v = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + with self.assertRaisesRegex(RuntimeError, "No available kernel"): + F.scaled_dot_product_attention(q, k, v, dropout_p=0.1) + + @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + def test_flash_attention_unsupport_bhsd_layout(self, device): + dtype = torch.bfloat16 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) + batch, num_heads, seqlen, head_dim = 32, 16, 32, 64 + q_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + k_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + v_shape = SdpaShape(batch, seqlen, num_heads, head_dim) + q, k, v = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + # (B, S, H, D) + q = q.view(batch, seqlen, num_heads, head_dim).transpose(1, 2) + k = k.view(batch, seqlen, num_heads, head_dim).transpose(1, 2) + v = v.view(batch, seqlen, num_heads, head_dim).transpose(1, 2) + + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + F.scaled_dot_product_attention(q, k, v) + + # (B, H, S, D) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + with self.assertRaisesRegex(RuntimeError, "No available kernel"): + F.scaled_dot_product_attention(q, k, v) + + @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + def test_flash_attention_headdim_size(self, device): + dtype = torch.bfloat16 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) + batch, num_heads, seqlen = 32, 2, 32 + + max_supported_head_dim = 192 + q_shape = SdpaShape(batch, seqlen, num_heads, max_supported_head_dim) + k_shape = SdpaShape(batch, seqlen, num_heads, max_supported_head_dim) + v_shape = SdpaShape(batch, seqlen, num_heads, max_supported_head_dim) + q, k, v = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + F.scaled_dot_product_attention(q, k, v) + + q_shape = SdpaShape(batch, seqlen, num_heads, max_supported_head_dim + 1) + k_shape = SdpaShape(batch, seqlen, num_heads, max_supported_head_dim + 1) + v_shape = SdpaShape(batch, seqlen, num_heads, max_supported_head_dim + 1) + q, k, v = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + with self.assertRaisesRegex(RuntimeError, "No available kernel"): + F.scaled_dot_product_attention(q, k, v) + + @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + def test_flash_attention_fail_with_non_square_causal_attention(self, device): + dtype = torch.bfloat16 + q_shape = SdpaShape(1, 1, 8, 16) + kv_shape = SdpaShape(1, 1, 12, 16) + make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) + make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) + q, k, v = make_q(), make_kv(), make_kv() + warning_str = "Flash attention XPU does not support the is_causal flag when seqlen_q != seqlen_k." + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + with self.assertWarnsRegex(UserWarning, warning_str): + self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( + q, k, v, None, 0.0, is_causal=True)) + + @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) + @parametrize("dtype", [torch.half, torch.bfloat16]) + @parametrize("batch_size", [1, 2, 4]) + @parametrize("n_head", [[3, 1], [4, 2], [10, 2]]) + @parametrize("q_size", [1, 32, 77, 128, 144, 512, 576]) + @parametrize("kv_size", [1, 32, 77, 128, 144, 512, 576]) + @parametrize("head_dim", [64, 96, 128, 192]) + @parametrize("mask_type", [None, "causal"]) + @parametrize("train", [True, False]) + @parametrize("layout", ["bshd"]) + @parametrize("enable_gqa", [True, False]) + def test_flash_attention_vs_math( + self, + device, + fused_kernel, + dtype, + batch_size, + q_size, + kv_size, + n_head, + head_dim, + mask_type, + train, + layout, + enable_gqa, + ): + if mask_type == "causal" and q_size != kv_size: + self.skipTest("Flash Attention V2 does not accept is_causal when seq_len_q != seq_len_k") + + tol = Tolerances(1e-5, 5e-6) + if dtype is torch.bfloat16: + tol = Tolerances(5e-2, 5e-2) + if dtype is torch.float16: + tol = Tolerances(1e-2, 1e-2) + make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False) + + if enable_gqa: + n_head_q, n_head_kv = n_head[0], n_head[1] + else: + n_head_q = n_head_kv = n_head[0] + + q_shape = SdpaShape(batch_size, n_head_q, q_size, head_dim) + kv_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) + q = make_tensor(q_shape) + k = make_tensor(kv_shape) + v = make_tensor(kv_shape) + + # (B, S, H, D) by default + q = q.view(batch_size, q_size, n_head_q, head_dim).transpose(1, 2) + k = k.view(batch_size, kv_size, n_head_kv, head_dim).transpose(1, 2) + v = v.view(batch_size, kv_size, n_head_kv, head_dim).transpose(1, 2) + if layout == "bhsd": + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + is_causal = False + if mask_type == "causal": + is_causal = True + + q2, k2, v2 = q.clone(), k.clone(), v.clone() + q2, k2, v2 = q2.float(), k2.float(), v2.float() + + if train: + q = q.detach().clone().requires_grad_(True) + k = k.detach().clone().requires_grad_(True) + v = v.detach().clone().requires_grad_(True) + q2 = q2.detach().clone().requires_grad_(True) + k2 = k2.detach().clone().requires_grad_(True) + v2 = v2.detach().clone().requires_grad_(True) + + with sdpa_kernel(backends=[fused_kernel]): + actual = F.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, is_causal=is_causal, enable_gqa=enable_gqa) + + with sdpa_kernel(backends=[SDPBackend.MATH]): + if is_causal: + bottom_right_mask = causal_lower_right(q_size, kv_size) + math_ref = F.scaled_dot_product_attention( + q2, k2, v2, dropout_p=0.0, attn_mask=bottom_right_mask, enable_gqa=enable_gqa) + else: + math_ref = F.scaled_dot_product_attention( + q2, k2, v2, dropout_p=0.0, is_causal=is_causal, enable_gqa=enable_gqa) + + if dtype in [torch.float16, torch.bfloat16]: + math_ref = math_ref.to(dtype) + + self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol) + + if train: + loss = torch.mean(actual) + loss_ref = torch.mean(math_ref) + loss.backward() + loss_ref.backward() + + grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad + grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad + if dtype in [torch.float16, torch.bfloat16]: + grad_q_ref = grad_q_ref.to(dtype) + grad_k_ref = grad_k_ref.to(dtype) + grad_v_ref = grad_v_ref.to(dtype) + + self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol) + self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol) + self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol) class TestAttnBias(NNTestCase): diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index ce00b67373e26..75e14dbc86b96 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5691,7 +5691,7 @@ def meta__scaled_dot_product_flash_attention( # are going to use cudagraphs or not, so we return meta tensors here # it's possible we'll need to have some special handling in inductor for sdpa # See [Note] BC breaking change to flash seed/offset - if torch.version.hip and torch.cuda.is_available(): + if torch.version.hip and torch.cuda.is_available() or device_hint(query) == "xpu": # Maintain old path on AMD seed = torch.empty((), dtype=torch.long, device="meta") offset = torch.empty((), dtype=torch.long, device="meta") diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 4de6ba3976688..6a9e2ca842050 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -125,6 +125,10 @@ #endif #endif +#ifdef USE_XPU +#include +#endif + #ifdef USE_DISTRIBUTED #ifdef USE_C10D #include @@ -2477,7 +2481,7 @@ Call this whenever a new thread is created in order to propagate values from .value("OVERRIDEABLE", sdp::SDPBackend::overrideable); py_module.def("_is_flash_attention_available", []() { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_XPU) return sdp::is_flash_attention_available(); #else return false; @@ -2486,7 +2490,7 @@ Call this whenever a new thread is created in order to propagate values from py_module.def( "_can_use_flash_attention", [](const sdp::sdp_params& params, bool debug) { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_XPU) return sdp::can_use_flash_attention(params, debug); #else return false; diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 3f41e4e1a6b12..49adef8de4031 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -15,6 +15,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHand AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 0e491d0eb635a..a524b6ab43fd8 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -232,13 +232,15 @@ def _dispatch( query, key, value, None, dropout_p, is_causal, enable_gqa ) if can_use_flash_attention(sdpa_params): - needs_padding = query.size(-1) % 8 != 0 + alignment = 64 if query.device.type == "xpu" else 8 og_head_size = query.size(-1) og_scale = _calculate_scale(og_head_size, scale) + needs_padding = og_head_size % alignment != 0 if needs_padding: - query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8)) - key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8)) - value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8)) + pad_len = alignment - (og_head_size % alignment) + query = torch.nn.functional.pad(query, (0, pad_len)) + key = torch.nn.functional.pad(key, (0, pad_len)) + value = torch.nn.functional.pad(value, (0, pad_len)) out = torch.ops.aten._scaled_dot_product_flash_attention( query, key, From 34a98608afa0cb5b48f0d6d30432fdd0a2614ddf Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 3 Dec 2025 12:08:59 -0600 Subject: [PATCH 203/338] [MPS] Add `linalg.lu_solve` and `linalg.lu` (#167569) Fixes #167238 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167569 Approved by: https://github.com/malfet --- .../native/mps/operations/LinearAlgebra.mm | 16 ++++ aten/src/ATen/native/native_functions.yaml | 1 + torch/_refs/linalg/__init__.py | 74 +++++++++++++++++++ torch/testing/_internal/common_mps.py | 17 ++--- .../_internal/opinfo/definitions/linalg.py | 8 +- 5 files changed, 106 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 00f9c96b78af8..c6d766f92f2b0 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -28,7 +28,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -1541,6 +1543,20 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors); } +TORCH_IMPL_FUNC(linalg_lu_out_mps)(const Tensor& A, bool pivot, const Tensor& P, const Tensor& L, const Tensor& U) { + Tensor LU = at::empty({0}, A.scalar_type(), std::nullopt, kMPS, std::nullopt, MemoryFormat::Contiguous); + auto pivots = at::empty({0}, A.options().dtype(kInt)); + auto info = at::empty({0}, A.options().dtype(kInt)); + mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, /*check_errors=*/false); + at::lu_unpack_out(const_cast(P), + const_cast(L), + const_cast(U), + LU, + pivots, + /*unpack_data=*/true, + /*unpack_pivots=*/pivot); +} + TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 248a3e1875e55..e6b5bfcd18727 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14227,6 +14227,7 @@ structured: True dispatch: CPU, CUDA: linalg_lu_out + MPS: linalg_lu_out_mps # linalg.lu_solve - func: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 4d194f773f859..393e42b06d15c 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import math from functools import partial from typing import Optional, Union @@ -359,3 +360,76 @@ def svdvals(A: TensorLikeType) -> Tensor: def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: check_fp_or_complex(x.dtype, "linalg.vecdot") return (x.conj() * y).sum(dim=dim) + + +def _pivots_to_permutation(pivots, shape, *, inverse=False): + perm = torch.empty(shape, dtype=torch.int32, device=pivots.device) + perm[..., :] = torch.arange(shape[-1], dtype=torch.int32, device=pivots.device) + indices = range(shape[-1]) + if inverse: + indices = reversed(indices) + + if len(shape) > 1: + for i in indices: + j_s = pivots[..., i] + perm_i = perm[..., i].clone() + j_idx = torch.meshgrid( + *[torch.arange(s, device=perm.device) for s in j_s.shape], indexing="ij" + ) + (j_s,) + perm_j = perm[j_idx] + perm.index_put_(j_idx, perm_i) + perm[..., i].copy_(perm_j) + + else: + for i in indices: + j = pivots[i] + perm_i = perm[i].clone() + perm_j = perm[j].clone() + perm[i].copy_(perm_j) + perm[j].copy_(perm_i) + + return perm + + +def _apply_pivots(a, pivots, shape, *, inverse=False): + perm = _pivots_to_permutation(pivots - 1, shape, inverse=inverse) + + if len(shape) == 1: + return a[perm, :] + else: + idx = torch.meshgrid( + *[torch.arange(s, device=a.device) for s in perm.shape], indexing="ij" + )[:-1] + (perm, slice(None)) + return a[idx] + + +def linalg_lu_solve_out_mps(LU, pivots, B, *, left=True, adjoint=False, out): + if out.numel() == 0: + return + + if not left: + adjoint = not adjoint + B = B.mH + + if adjoint: + lu_ = LU.mH + x = torch.linalg.solve_triangular(lu_, B, left=True, upper=False) + x = torch.linalg.solve_triangular( + lu_, x, left=True, upper=True, unitriangular=True + ) + x = _apply_pivots(x, pivots, LU.shape[:-1], inverse=True) + else: + x = _apply_pivots(B, pivots, LU.shape[:-1]) + x = torch.linalg.solve_triangular( + LU, x, left=True, upper=False, unitriangular=True + ) + x = torch.linalg.solve_triangular(LU, x, left=True, upper=True) + + if not left: + x = x.mH + + out.copy_(x) + + +mps_lib = torch.library.Library("aten", "IMPL", "MPS") # noqa: TOR901 +mps_lib.impl("aten::linalg_lu_solve.out", linalg_lu_solve_out_mps) diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index 9d3d65aba9a2d..2b1f1be0e02f9 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -330,15 +330,12 @@ def mps_ops_modifier( "linalg.ldl_solve": None, "linalg.lstsq": None, "linalg.lstsqgrad_oriented": None, - "linalg.lu": None, - "linalg.lu_solve": None, "linalg.matrix_norm": [torch.float32], "linalg.norm": [torch.float32], "linalg.normsubgradients_at_zero": [torch.float32], "linalg.qr": None, "linalg.svdvals": None, "linalg.vecdot": None, - "lu_solve": None, "masked.median": None, "matrix_exp": None, "mode": None, @@ -700,16 +697,18 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: torch.float16, torch.float32, ], # missing `aten::lu_solve`. + # `linalg.lu_solve`'s backward pass for the `LU` arg calls + # `lu_unpack`, and pivots are unpacked if `left == adjoint`. When + # unpacking pivots, `lu_unpack` incorrectly raises an error if + # `pivots.shape` is zero in any of the batch dims and the last dim + # is greater than 1. + "linalg.lu_solve": None, + # lu_solve only fails on MacOS 14 for some reason + "lu_solve": None if MACOS_VERSION < 15.0 else [], "linalg.tensorsolve": [ torch.float16, torch.float32, ], # missing `aten::lu_solve`. - "linalg.det": [torch.float16, torch.float32], # missing aten::lu_solve.out - "linalg.slogdet": [ - torch.float16, - torch.float32, - ], # missing aten::lu_solve.out - "logdet": [torch.float16, torch.float32], # missing aten::lu_solve.out "aminmax": [torch.float32, torch.float16], "special.i1": [torch.float16], # "i1_backward" not implemented for 'Half' "special.i1e": [torch.float16], # "i1e_backward" not implemented for 'Half' diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index da75f82815507..95cb59df0fcb4 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -256,7 +256,13 @@ def clone(X, requires_grad): for n, batch, rhs in product(ns, batches, nrhs): A = make_a(*(batch + (n, n))) - LU, pivots = torch.linalg.lu_factor(A) + if torch.device(device).type == "mps": + # TODO: Fix lu_factor for MPS, because it does not work for all of + # these cases. So we resort to the CPU impl here and move the + # outputs back to MPS. + LU, pivots = (x.to(device) for x in torch.linalg.lu_factor(A.cpu())) + else: + LU, pivots = torch.linalg.lu_factor(A) B = make_b(batch + (n, rhs)) From 89e3bbcb5b5321dc8b9520b4d5a8ee60cea1d0b4 Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Wed, 3 Dec 2025 21:37:30 +0000 Subject: [PATCH 204/338] [Accelerator] Add Accelerator Capabilities API (#165631) # Motivation There are several issues related to the data type and precision that an accelerator supports (see #165038 and #143112). Sometimes, we have to check for these capabilities in the document, and then hard-code. This PR proposes a new unified API for users to check their accelerator capabilities. # Changes This PR creates a new data structure `DeviceCapability` containing the capabilities that an accelerator commonly has: - Supporting DataType (set to be supported as default): - `fp16`, `int32`, `complex` ... etc - Other capabilities (need to be discussed) To access the structure, this PR defines a new Python API in the Accelerator module -- `get_device_capability`. It takes `device` as an input and returns a dictionary containing the capabilities (now we have `supported_dtypes` as the key). # Usage ```python >>> import torch >>> import torch_openreg >>> torch.accelerator.get_device_capability('openreg:0') {'supported_dtypes': [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64, torch.complex32, torch.complex64, torch.complex128, torch.bool, torch.qint8, torch.quint8, torch.qint32, torch.bfloat16, torch.quint4x2, torch.quint2x4, torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16, torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e5m2fnuz, torch.float8_e4m3fnuz, torch.uint16, torch.uint32, torch.uint64, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.int1, torch.int2, torch.int3, torch.int4, torch.int5, torch.int6, torch.int7, torch.float8_e8m0fnu, torch.float4_e2m1fn_x2]} ``` # TODO - So far, precision is the only capability to track, based on my knowledge. But we can find more capabilities in common, and the API should be designed for good extension. - It will support other in-tree accelerators, such as **cuda** and **mps**. - Clarify whether the capabilities are software or hardware supported. (By @guangyey ) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165631 Approved by: https://github.com/guangyey, https://github.com/albanD Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> Co-authored-by: Jiawei Li --- aten/src/ATen/DeviceAccelerator.cpp | 6 ++ aten/src/ATen/DeviceAccelerator.h | 5 ++ c10/core/DeviceCapability.h | 76 +++++++++++++++++++ c10/core/impl/DeviceGuardImplInterface.h | 27 +++++++ c10/core/impl/VirtualGuardImpl.h | 4 + .../torch_openreg/csrc/runtime/OpenRegGuard.h | 9 +++ .../torch_openreg/tests/test_device.py | 9 ++- torch/_C/__init__.pyi.in | 1 + torch/accelerator/__init__.py | 27 ++++++- torch/csrc/DeviceAccelerator.cpp | 19 +++++ 10 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 c10/core/DeviceCapability.h diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index aa9d6e6b1ce9b..efab9ec9c5927 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -130,6 +130,12 @@ c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index) { impl.uncheckedSetDevice({device_type, device_index}); return impl.getDevice().index(); } + +c10::DeviceCapability getDeviceCapability(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + c10::impl::VirtualGuardImpl impl(device_type); + return impl.getDeviceCapability({device_type, device_index}); +} // NOLINTEND(bugprone-unchecked-optional-access) } // namespace at::accelerator diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index 2cc4cff7cd1f2..d24b42ca459e7 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -73,6 +74,10 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +// Get the device capability of the given device index. +TORCH_API c10::DeviceCapability getDeviceCapability( + c10::DeviceIndex device_index); + TORCH_API inline void emptyCache() { const auto device_type = getAccelerator(true).value(); at::getDeviceAllocator(device_type)->emptyCache(); diff --git a/c10/core/DeviceCapability.h b/c10/core/DeviceCapability.h new file mode 100644 index 0000000000000..cc171dfcd6ffe --- /dev/null +++ b/c10/core/DeviceCapability.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +constexpr size_t NUMBER_OF_DEVICE_CAPABILITIES = NumScalarTypes; + +// Generate bitfields for each scalar type +#define DEFINE_SCALAR_TYPE(_1, n) unsigned int has_##n : 1; + +// Generate enum indices for each scalar type +#define DEFINE_SCALAR_ENUM(_1, name) kIndex_##name, + +enum ScalarTypeIndex { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_ENUM) +}; + +/** + * @brief DeviceCapability represents the the common capabilities that all + * devices should support. + * + * This struct provides a compact way to represent the common capabilities that + * all devices should support. Includes the following capabilities: + * - Supported data types + * + * Purpose + * - Enable device-specific optimizations based on supported capabilities + * + * Contract + * + * Supported data types: + * - Each bitfield represents support for one device capability + * - Bit value 1 means the capability is supported, 0 means not supported + * - The struct is initialized with all capabilities enabled by default + * + * @note Adding New Capabilities + * + * 1. Define the new capability in the `DeviceCapability` struct + * 2. Update the support of the new capability in each accelerator + * implementation + * 3. Add the new capability to the returned PyObject Dictionary + */ +struct C10_API DeviceCapability { + union { + struct { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE) + } supported_scalar_types; + uint64_t capability_bits; // Allow direct bit manipulation + } capability_data; + + // Default constructor with all capabilities enabled. + DeviceCapability() { + capability_data.capability_bits = + ((1ULL << NUMBER_OF_DEVICE_CAPABILITIES) - 1); + } + + // Iterate supported ScalarTypes without allocating a vector + template + void forEachSupportedScalarType(F&& visitor) const { +#define VISIT_SCALAR_TYPE(_1, n) \ + if (capability_data.supported_scalar_types.has_##n) { \ + visitor(ScalarType::n); \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(VISIT_SCALAR_TYPE) + +#undef VISIT_SCALAR_TYPE + } +}; + +#undef DEFINE_SCALAR_ENUM +#undef DEFINE_SCALAR_TYPE +} // namespace c10 diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index f9f67497c6315..141fb05cb77d1 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -191,6 +192,15 @@ struct C10_API DeviceGuardImplInterface { */ virtual DeviceIndex deviceCount() const noexcept = 0; + /** + * Get the following capabilities of the current device: + * (1) Data type support + * Returns DeviceCapability object. + */ + virtual DeviceCapability getDeviceCapability(Device /*unused*/) const { + TORCH_CHECK(false, "Backend doesn't support getting device capabilities."); + } + /** * Return true if all the work previously enqueued on the stream for * asynchronous execution has completed running on the device. @@ -291,6 +301,23 @@ struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface { return 1; } + DeviceCapability getDeviceCapability(Device /*unused*/) const override { + DeviceCapability cap; + if constexpr (D == DeviceType::Meta) { + cap.capability_data.capability_bits = 0; + // Meta only supports basic types for shape inference + // Byte, Char, Short, Int, Long, Float, Double, + // Bool, ComplexFloat, ComplexDouble + cap.capability_data.capability_bits = (1ULL << kIndex_Byte) | + (1ULL << kIndex_Char) | (1ULL << kIndex_Short) | + (1ULL << kIndex_Int) | (1ULL << kIndex_Long) | + (1ULL << kIndex_Float) | (1ULL << kIndex_Double) | + (1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_ComplexDouble) | + (1ULL << kIndex_Bool); + } + return cap; + } + // Event-related functions void record( void** /*event*/, diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 3d259f5e390e3..0254c69baba00 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -57,6 +57,10 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { return impl_->deviceCount(); } + DeviceCapability getDeviceCapability(Device d) const override { + return impl_->getDeviceCapability(d); + } + // Event functions void record( void** event, diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h index 59bc2d5cdbff5..3c1c1193d3cdb 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -50,6 +51,14 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { return c10::Device(static_type, device_index); } + /** + * Get the device capability for a given device. + * By default, OpenReg has 2 same devices with the same capability. + */ + c10::DeviceCapability getDeviceCapability(c10::Device /*unused*/) const override { + return c10::DeviceCapability(); + } + /** * Set the current device to c10::Device. */ diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py index f925f15600ce7..9cb4a785d36e7 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_device.py @@ -1,7 +1,7 @@ # Owner(s): ["module: PrivateUse1"] import torch -import torch_openreg # noqa: F401 +from torch.testing._internal.common_dtype import get_all_dtypes from torch.testing._internal.common_utils import run_tests, TestCase @@ -31,6 +31,13 @@ def test_invalid_device_index(self): with self.assertRaisesRegex(RuntimeError, "The device index is out of range"): torch.accelerator.set_device_index(2) + def test_device_capability(self): + capability = torch.accelerator.get_device_capability("openreg:0") + supported_dtypes = capability["supported_dtypes"] + expected_dtypes = get_all_dtypes(include_complex32=True, include_qint=True) + + self.assertTrue(all(dtype in supported_dtypes for dtype in expected_dtypes)) + if __name__ == "__main__": run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 532815d535d5e..520d07d487270 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2494,6 +2494,7 @@ def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails def _accelerator_getAccelerator() -> _device: ... def _accelerator_setDeviceIndex(device_index: _int) -> None: ... def _accelerator_getDeviceIndex() -> _int: ... +def _accelerator_getDeviceCapability(device_index: _int) -> dict[str, Any]: ... def _accelerator_setStream(Stream) -> None: ... def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e1a82aa63ce22..b0dfbe400bfbc 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -2,7 +2,8 @@ This package introduces support for the current :ref:`accelerator` in python. """ -from typing import Optional +from functools import cache +from typing import Any from typing_extensions import deprecated import torch @@ -25,6 +26,7 @@ "current_accelerator", "current_device_idx", # deprecated "current_device_index", + "get_device_capability", "current_stream", "device_count", "device_index", @@ -152,6 +154,29 @@ def current_device_index() -> int: """ +@cache +def get_device_capability(device: _device_t = None, /) -> dict[str, Any]: + r"""Return the capability of the currently selected device. + + Args: + device (:class:`torch.device`, str, int, optional): The device to query capabilities for + :ref:`accelerator` device type. If not given, + use :func:`torch.accelerator.current_device_index` by default. + + Returns: + dict[str, Any]: A dictionary containing device capability information. The dictionary includes: + - ``supported_dtypes`` (set(torch.dtype)): Set of PyTorch data types supported by the device + + Examples: + >>> # xdoctest: +SKIP("requires cuda") + >>> # Query capabilities for current device + >>> capabilities = torch.accelerator.get_device_capability("cuda:0") + >>> print("Supported dtypes:", capabilities["supported_dtypes"]) + """ + device_index = _get_device_index(device, optional=True) + return torch._C._accelerator_getDeviceCapability(device_index) + + def set_device_index(device: _device_t, /) -> None: r"""Set the current device index to a given device. diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 14e54851178f5..c6ffa893d95ae 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -33,6 +33,25 @@ void initModule(PyObject* module) { return at::accelerator::getDeviceIndex(); }); + m.def("_accelerator_getDeviceCapability", [](c10::DeviceIndex device_index) { + const auto device_type = at::accelerator::getAccelerator(true).value(); + torch::utils::maybe_initialize_device(device_type); + auto caps = at::accelerator::getDeviceCapability(device_index); + + py::dict dict; + + py::set dtype_set; + caps.forEachSupportedScalarType([&](c10::ScalarType dtype) { + THPDtype* thp_dtype = torch::getTHPDtype(dtype); + py::object dtype_obj = + py::reinterpret_borrow((PyObject*)thp_dtype); + dtype_set.add(dtype_obj); + }); + + dict["supported_dtypes"] = dtype_set; + return dict; + }); + m.def("_accelerator_setStream", [](c10::Stream stream) { const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); From e7d24d3ff93d1503ba63860b7057438ad93f918e Mon Sep 17 00:00:00 2001 From: Dino Viehland Date: Wed, 3 Dec 2025 22:00:27 +0000 Subject: [PATCH 205/338] [dynamo, 3.14] Don't use _PyObject_GC_TRACK which is internal and calls an unexported API symbol (#169490) Summary: In CPython 3.14.1 _PyObject_GC_TRACK calls an unexported function causing build failures. Switch to using the public exported PyObject_GC_Track. Test Plan: `buck build --flagfile fbcode//mode/opt fbsource//third-party/pypi/docling-ibm-models/__test__:test-cpython3.14` succceds now. Reviewed By: itamaro Differential Revision: D88288162 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169490 Approved by: https://github.com/williamwen42, https://github.com/malfet --- torch/csrc/dynamo/cpython_defs.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c index df7b40ba0da7b..e0cb3bfe29607 100644 --- a/torch/csrc/dynamo/cpython_defs.c +++ b/torch/csrc/dynamo/cpython_defs.c @@ -252,7 +252,7 @@ static void THP_take_ownership(PyFrameObject* f, _PyInterpreterFrame* frame) { PyErr_SetRaisedException(exc); } if (!_PyObject_GC_IS_TRACKED((PyObject*)f)) { - _PyObject_GC_TRACK((PyObject*)f); + PyObject_GC_Track((PyObject*)f); } Py_END_CRITICAL_SECTION(); } From 61be54a31dc09b59d99b62176fb935aee0b924ef Mon Sep 17 00:00:00 2001 From: "Chen Chen (AI Infra)" Date: Wed, 3 Dec 2025 22:13:51 +0000 Subject: [PATCH 206/338] update the MHA to enable cudnn-frontend 1.12.1 (#169086) Summary: As titled. Before: ``` fbcode/caffe2/aten/src/ATen/native/cudnn/MHA.cpp:490:12: error: no member named 'set_generate_stats' in 'cudnn_frontend::graph::SDPA_attributes' 488 | fe::graph::SDPA_attributes() | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 489 | .set_name("CUDNN_SDPA") | ~~~~~~~~~~~~~~~~~~~~~~~ 490 | .set_generate_stats(return_softmaxstats) | ^ fbcode/caffe2/aten/src/ATen/native/cudnn/MHA.cpp:708:12: error: no member named 'set_generate_stats' in 'cudnn_frontend::graph::SDPA_attributes' 706 | fe::graph::SDPA_attributes() | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 707 | .set_name("CUDNN_SDPA_NESTEDTENSOR") | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 708 | .set_generate_stats(return_softmaxstats) | ^ 2 errors generated. ``` Test Plan: sandcastle Differential Revision: D87832011 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169086 Approved by: https://github.com/eqy, https://github.com/drisspg, https://github.com/Skylion007 --- aten/src/ATen/native/cudnn/MHA.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 504688f203333..58dd0552cab5e 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -487,7 +487,11 @@ std::unique_ptr build_graph( auto scaled_dot_product_flash_attention_options = fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA") +#if CUDNN_FRONTEND_VERSION <= 11200 + .set_is_inference(!return_softmaxstats) +#else .set_generate_stats(return_softmaxstats) +#endif .set_causal_mask(is_causal) .set_attn_scale(attn_scale); if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { @@ -705,7 +709,11 @@ std::unique_ptr build_graph_nestedtensor( auto scaled_dot_product_flash_attention_options = fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA_NESTEDTENSOR") +#if CUDNN_FRONTEND_VERSION <= 11200 + .set_is_inference(!return_softmaxstats) +#else .set_generate_stats(return_softmaxstats) +#endif .set_causal_mask(is_causal) .set_attn_scale(attn_scale) .set_seq_len_q(SEQ_LEN_Q_) From fdf863d5e1de3b2688c9511e96876e34581dbfd7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 3 Dec 2025 22:23:29 +0000 Subject: [PATCH 207/338] Revert "Triton 3.6 pin update (#168096)" This reverts commit 93d0d6838c56af59b0dba794e6aa08f0c1c7799c. Reverted https://github.com/pytorch/pytorch/pull/168096 on behalf of https://github.com/atalman due to Causes timeouts https://github.com/pytorch/pytorch/issues/169492 ([comment](https://github.com/pytorch/pytorch/pull/168096#issuecomment-3609092057)) --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- .github/scripts/amd/package_triton_wheel.sh | 1 - .../rocm/dynamic_inductor_timm_training.csv | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 263fcf2e0bdbb..7aab8bed1c108 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -5261b27331eb1dd09df9ec1bd6acc21cbb184481 +bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 40c341bdcdbe8..d5c0c99142898 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.6.0 +3.5.1 diff --git a/.github/scripts/amd/package_triton_wheel.sh b/.github/scripts/amd/package_triton_wheel.sh index 501e50e2fe2f1..fe8d915422dac 100755 --- a/.github/scripts/amd/package_triton_wheel.sh +++ b/.github/scripts/amd/package_triton_wheel.sh @@ -87,7 +87,6 @@ done cp -r $ROCM_HOME/include/hip $TRITON_ROCM_DIR/include cp -r $ROCM_HOME/include/roctracer $TRITON_ROCM_DIR/include cp -r $ROCM_HOME/include/hsa $TRITON_ROCM_DIR/include -cp -r $ROCM_HOME/include/hipblas-common $TRITON_ROCM_DIR/include # Copy linker mkdir -p $TRITON_ROCM_DIR/llvm/bin diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv index 702da0cb57f89..2d087e6595526 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv @@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7 -convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7 +convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 From 296e67c92635443c67b11c0ae1bd045f03ebb7bc Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 09:39:37 -0800 Subject: [PATCH 208/338] [effect] Remove special handling for profiler op (#168389) We shouldn't need this anymore as we have a registration for the op to have no effect Differential Revision: [D87680134](https://our.internmc.facebook.com/intern/diff/D87680134) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168389 Approved by: https://github.com/zou3519 ghstack dependencies: #167364 --- torch/_higher_order_ops/effects.py | 5 ----- torch/_ops.py | 34 +----------------------------- 2 files changed, 1 insertion(+), 38 deletions(-) diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index 86707a4f55ef1..96d7872048ec8 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -112,11 +112,6 @@ def has_aliasing(op: OpType): def has_effects(op) -> bool: - # Skip over the profiler's RecordFunction as they should not show up in the graph - _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction} - if op in _skip_ops: - return False - return ( isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) and not has_aliasing(op) diff --git a/torch/_ops.py b/torch/_ops.py index 8f8a7328429fa..75905d78da5b5 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1043,28 +1043,6 @@ def _may_use_fallthrough_instead_of_fallback(key: DispatchKey): if _may_use_fallthrough_instead_of_fallback(key) ] - @contextlib.contextmanager - def _register_as_effectful_op_temporarily(self): - from torch._higher_order_ops.effects import ( - _EffectType, - _get_effect, - _register_effectful_op, - ) - - try: - # We don't want to register the effect if there already exists a - # registration, especially if the registration is None (explicitly - # no effect) - register_tmp_effect = _get_effect(self) is None - handle = None - if register_tmp_effect: - handle = _register_effectful_op(self, _EffectType.ORDERED) - yield - finally: - if register_tmp_effect: - assert handle is not None - handle.destroy() - # Use positional-only argument to avoid naming collision with aten ops arguments # that are named "self". This way, all the aten ops can be called by kwargs. def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: @@ -1072,17 +1050,7 @@ def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: # When any inputs are FakeScriptObject, we need to # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject. - # - # Note: - # 1. We only register the torchbind op temporarily as effectful op because we only want - # the effect token functionalization logic to be applied during tracing. Otherwise, the behavior - # of the eagerly executing the op might change after tracing. - # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might - # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction. - with self._register_as_effectful_op_temporarily(): - return self._dispatch_in_python( - self._fallthrough_keys(), *args, **kwargs - ) + return self._dispatch_in_python(self._fallthrough_keys(), *args, **kwargs) return self._op(*args, **kwargs) def _dispatch_in_python( From 6e404e9b7d6f5fb0de86aa73888c3038248c17f8 Mon Sep 17 00:00:00 2001 From: sekyonda <127536312+sekyondaMeta@users.noreply.github.com> Date: Wed, 3 Dec 2025 22:41:46 +0000 Subject: [PATCH 209/338] Compiler User Guide Landing Pages (#165635) Moving the Torch.compile pages to the User Guide (from the API section) to provide an easier user experience. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165635 Approved by: https://github.com/angelayi --- docs/source/onnx_export.md | 4 +- docs/source/pytorch-api.md | 4 +- .../torch.compiler_troubleshooting_old.md | 2 +- docs/source/user_guide/index.md | 8 +++ .../user_guide/torch_compiler/advanced.md | 13 ++++ .../torch_compiler/api_reference.md | 12 ++++ .../_static/dynamo_summary_diagram.png | Bin ...dynamic_shapes_advanced_control_options.md | 10 +-- .../compile/dynamic_shapes_backed_unbacked.md | 0 .../dynamic_shapes_beyond_the_basics.md | 0 .../compile/dynamic_shapes_core_concepts.md | 0 ...mic_shapes_debugging_tlparse_torch_logs.md | 4 +- .../compile/dynamic_shapes_troubleshooting.md | 0 ...c_shapes_troubleshooting_guardon_errors.md | 0 .../dynamic_shapes_zero_one_specialization.md | 0 .../torch_compiler}/compile/header_code.py | 0 .../programming_model.common_graph_breaks.md | 0 .../programming_model.compiler_disable.md | 0 .../compile/programming_model.custom_ops.md | 0 .../programming_model.dynamo_core_concepts.md | 0 ...rogramming_model.dynamo_nonstrict_trace.md | 0 .../programming_model.error_on_graph_break.md | 0 .../programming_model.fullgraph_false.md | 0 .../programming_model.fullgraph_true.md | 0 .../programming_model.graph_breaks_index.md | 0 .../compile/programming_model.md | 2 + .../programming_model.nested_graph_breaks.md | 0 ...gramming_model.non_strict_tracing_model.md | 0 .../programming_model.observability.md | 0 .../programming_model.recompilation.md | 0 .../programming_model.reporting_issues.md | 0 .../programming_model.skipped_functions.md | 0 ...rogramming_model.where_to_apply_compile.md | 0 .../torch_compiler/core_concepts.md | 13 ++++ .../{ => user_guide/torch_compiler}/export.md | 4 +- .../torch_compiler}/export/api_reference.md | 0 .../torch_compiler}/export/draft_export.md | 4 +- .../torch_compiler}/export/ir_spec.md | 0 .../export/joint_with_descriptors.md | 0 .../export/programming_model.md | 0 .../torch_compiler}/export/pt2_archive.md | 0 .../user_guide/torch_compiler/performance.md | 12 ++++ .../torch_compiler}/torch.compiler.config.md | 0 .../torch_compiler}/torch.compiler.md | 58 +++++++++--------- .../torch.compiler_aot_inductor.md | 2 +- ...h.compiler_aot_inductor_debugging_guide.md | 2 +- .../torch.compiler_aot_inductor_minifier.md | 0 .../torch.compiler_backward.md | 2 + .../torch.compiler_cudagraph_trees.md | 0 .../torch.compiler_custom_backends.md | 0 .../torch.compiler_dynamic_shapes.md | 8 +-- .../torch.compiler_dynamo_deepdive.md | 0 .../torch.compiler_dynamo_overview.md | 8 ++- .../torch.compiler_fake_tensor.md | 0 .../torch_compiler}/torch.compiler_faq.md | 4 +- .../torch.compiler_fine_grain_apis.md | 2 +- .../torch.compiler_get_started.md | 2 +- .../torch.compiler_inductor_profiling.md | 8 +-- .../torch.compiler_inductor_provenance.rst | 14 ++--- .../torch_compiler}/torch.compiler_ir.md | 4 +- .../torch.compiler_nn_module.md | 4 +- .../torch.compiler_performance_dashboard.md | 0 .../torch.compiler_profiling_torch_compile.md | 16 ++--- .../torch.compiler_transformations.md | 0 .../torch.compiler_troubleshooting.md | 2 +- .../torch_compiler/troubleshooting_faqs.md | 13 ++++ 66 files changed, 159 insertions(+), 82 deletions(-) create mode 100644 docs/source/user_guide/torch_compiler/advanced.md create mode 100644 docs/source/user_guide/torch_compiler/api_reference.md rename docs/source/{ => user_guide/torch_compiler}/compile/_static/dynamo_summary_diagram.png (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_advanced_control_options.md (96%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_backed_unbacked.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_beyond_the_basics.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_core_concepts.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_debugging_tlparse_torch_logs.md (95%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_troubleshooting.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_troubleshooting_guardon_errors.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/dynamic_shapes_zero_one_specialization.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/header_code.py (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.common_graph_breaks.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.compiler_disable.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.custom_ops.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.dynamo_core_concepts.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.dynamo_nonstrict_trace.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.error_on_graph_break.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.fullgraph_false.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.fullgraph_true.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.graph_breaks_index.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.md (95%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.nested_graph_breaks.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.non_strict_tracing_model.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.observability.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.recompilation.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.reporting_issues.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.skipped_functions.md (100%) rename docs/source/{ => user_guide/torch_compiler}/compile/programming_model.where_to_apply_compile.md (100%) create mode 100644 docs/source/user_guide/torch_compiler/core_concepts.md rename docs/source/{ => user_guide/torch_compiler}/export.md (99%) rename docs/source/{ => user_guide/torch_compiler}/export/api_reference.md (100%) rename docs/source/{ => user_guide/torch_compiler}/export/draft_export.md (98%) rename docs/source/{ => user_guide/torch_compiler}/export/ir_spec.md (100%) rename docs/source/{ => user_guide/torch_compiler}/export/joint_with_descriptors.md (100%) rename docs/source/{ => user_guide/torch_compiler}/export/programming_model.md (100%) rename docs/source/{ => user_guide/torch_compiler}/export/pt2_archive.md (100%) create mode 100644 docs/source/user_guide/torch_compiler/performance.md rename docs/source/{ => user_guide/torch_compiler}/torch.compiler.config.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler.md (82%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_aot_inductor.md (99%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_aot_inductor_debugging_guide.md (98%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_aot_inductor_minifier.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_backward.md (99%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_cudagraph_trees.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_custom_backends.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_dynamic_shapes.md (95%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_dynamo_deepdive.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_dynamo_overview.md (98%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_fake_tensor.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_faq.md (99%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_fine_grain_apis.md (99%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_get_started.md (99%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_inductor_profiling.md (95%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_inductor_provenance.rst (88%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_ir.md (93%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_nn_module.md (98%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_performance_dashboard.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_profiling_torch_compile.md (94%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_transformations.md (100%) rename docs/source/{ => user_guide/torch_compiler}/torch.compiler_troubleshooting.md (99%) create mode 100644 docs/source/user_guide/torch_compiler/troubleshooting_faqs.md diff --git a/docs/source/onnx_export.md b/docs/source/onnx_export.md index 0adfec359d0b8..cf1f0ab4a9687 100644 --- a/docs/source/onnx_export.md +++ b/docs/source/onnx_export.md @@ -179,7 +179,7 @@ The overall ONNX graph has the following `metadata_props`: This property contains a string representation of the graph_signature from the original PyTorch ExportedProgram. The graph signature describes the structure of the model's inputs and outputs and how they map to the ONNX graph. The inputs are defined as `InputSpec` objects, which include the kind of input (e.g., `InputKind.PARAMETER` for parameters, `InputKind.USER_INPUT` for user-defined inputs), the argument name, the target (which can be a specific node in the model), and whether the input is persistent. The outputs are defined as `OutputSpec` objects, which specify the kind of output (e.g., `OutputKind.USER_OUTPUT`) and the argument name. - To read more about the graph signature, please see the {doc}`torch.export ` for more information. + To read more about the graph signature, please see the {doc}`torch.export ` for more information. - **pkg.torch.export.ExportedProgram.range_constraints** @@ -188,7 +188,7 @@ The overall ONNX graph has the following `metadata_props`: *Example:* `s0: VR[2, int_oo]`, which indicates that the size of the input tensor must be at least 2. - To read more about range constraints, please see the {doc}`torch.export ` for more information. + To read more about range constraints, please see the {doc}`torch.export ` for more information. Each input value in the ONNX graph may have the following metadata property: diff --git a/docs/source/pytorch-api.md b/docs/source/pytorch-api.md index c0f1302b8e8ed..b2e42f5e381d6 100644 --- a/docs/source/pytorch-api.md +++ b/docs/source/pytorch-api.md @@ -32,7 +32,7 @@ mtia.memory mtia.mtia_graph meta torch.backends -torch.export +torch.export torch.distributed torch.distributed.tensor torch.distributed.algorithms.join @@ -45,7 +45,7 @@ torch.distributed.pipelining torch.distributed._symmetric_memory torch.distributed.checkpoint torch.distributions -torch.compiler +torch.compiler torch.fft torch.func futures diff --git a/docs/source/torch.compiler_troubleshooting_old.md b/docs/source/torch.compiler_troubleshooting_old.md index ef13fc1772374..b10441161a51a 100644 --- a/docs/source/torch.compiler_troubleshooting_old.md +++ b/docs/source/torch.compiler_troubleshooting_old.md @@ -211,7 +211,7 @@ debugging. There are two tools available to enable this: If the error does not occur with the `"eager"` backend, then the backend compiler is the source of the error ([example error](https://gist.github.com/mlazos/2f13681e3cc6c43b3911f336327032de)). -There are [different choices](./torch.compiler.md) +There are [different choices](./user_guide/torch_compiler/torch.compiler.md) for backend compilers for TorchDynamo, with TorchInductor fitting the needs of most users. This section focuses on TorchInductor as the motivating example, but some tools can also be used with other diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md index 3a341893ef90b..1340e9e013abb 100644 --- a/docs/source/user_guide/index.md +++ b/docs/source/user_guide/index.md @@ -25,6 +25,14 @@ Learn the Basics pytorch_main_components ``` +```{toctree} +:maxdepth: 1 +:caption: Torch Compile + +Torch.compile +Torch.export +``` + ```{toctree} :maxdepth: 1 :caption: Beyond the Basics diff --git a/docs/source/user_guide/torch_compiler/advanced.md b/docs/source/user_guide/torch_compiler/advanced.md new file mode 100644 index 0000000000000..acfa3cd60a462 --- /dev/null +++ b/docs/source/user_guide/torch_compiler/advanced.md @@ -0,0 +1,13 @@ +# Advanced + +Deep dive into compiler internals, custom backends, transformations, and advanced features. + +```{toctree} +:maxdepth: 1 + +torch.compiler_dynamo_deepdive.md +torch.compiler_transformations.md +torch.compiler_fake_tensor.md +torch.compiler_custom_backends.md +torch.compiler_dynamic_shapes +``` diff --git a/docs/source/user_guide/torch_compiler/api_reference.md b/docs/source/user_guide/torch_compiler/api_reference.md new file mode 100644 index 0000000000000..aa3eec1b03797 --- /dev/null +++ b/docs/source/user_guide/torch_compiler/api_reference.md @@ -0,0 +1,12 @@ +# Reference/API + +Complete API documentation, configuration options, and fine-grained compiler controls. + +```{toctree} +:maxdepth: 1 + +../../torch.compiler_api.md +torch.compiler.config.md +torch.compiler_fine_grain_apis.md +torch.compiler_inductor_provenance.rst +``` diff --git a/docs/source/compile/_static/dynamo_summary_diagram.png b/docs/source/user_guide/torch_compiler/compile/_static/dynamo_summary_diagram.png similarity index 100% rename from docs/source/compile/_static/dynamo_summary_diagram.png rename to docs/source/user_guide/torch_compiler/compile/_static/dynamo_summary_diagram.png diff --git a/docs/source/compile/dynamic_shapes_advanced_control_options.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_advanced_control_options.md similarity index 96% rename from docs/source/compile/dynamic_shapes_advanced_control_options.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_advanced_control_options.md index e822766817175..280d596afb20e 100644 --- a/docs/source/compile/dynamic_shapes_advanced_control_options.md +++ b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_advanced_control_options.md @@ -28,7 +28,7 @@ follow these steps using `tlparse`: 1. In the `tlparse` output, identify the line number of the frame of interest. Example: - ```{image} ../_static/img/dynamic_shapes/tlparse4_pgo.png + ```{image} ../../../_static/img/dynamic_shapes/tlparse4_pgo.png ``` 2. Open `local_code` using `put_local_code_state_` or `put_remote_code_state_` for the @@ -113,7 +113,7 @@ For example, in the following `tlparse` snapshot, Dynamo graphs 20/0, graph 20/0 vs. graph 20/2). In the Dynamo graph of 20/2, sizes `s0`, `s1`, and `s5` are used for `rotary_pos_emb_` and `x`. -```{image} ../_static/img/dynamic_shapes/tlparse5_dynamic_shapes.png +```{image} ../../../_static/img/dynamic_shapes/tlparse5_dynamic_shapes.png ``` ```{tip} @@ -147,12 +147,12 @@ Check the following: reason is size-related and not due to other factors. For example, while in these screenshot the recomplile reason is size-related: -```{image} ../_static/img/dynamic_shapes/tlparse6_size_related_recompilations.png +```{image} ../../../_static/img/dynamic_shapes/tlparse6_size_related_recompilations.png ``` In the one below it is not, which indicates that dynamic shapes won't resolve it: -```{image} ../_static/img/dynamic_shapes/tlparse7_not_size_related_recompilations.png +```{image} ../../../_static/img/dynamic_shapes/tlparse7_not_size_related_recompilations.png :width: 500px :align: center ``` @@ -215,7 +215,7 @@ call to a Triton kernel. To identify the reason for specialization: * **Using tlparse:** Check the `compilation_metrics` for a specialization section, which will indicate what got specialized and the user and framework stack when it happened. Example: - ```{image} ../_static/img/dynamic_shapes/tlparse8_compilation_metrics.png + ```{image} ../../../_static/img/dynamic_shapes/tlparse8_compilation_metrics.png ``` The log above indicates that `s0` is specialized to `33` due to the following code: diff --git a/docs/source/compile/dynamic_shapes_backed_unbacked.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_backed_unbacked.md similarity index 100% rename from docs/source/compile/dynamic_shapes_backed_unbacked.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_backed_unbacked.md diff --git a/docs/source/compile/dynamic_shapes_beyond_the_basics.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_beyond_the_basics.md similarity index 100% rename from docs/source/compile/dynamic_shapes_beyond_the_basics.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_beyond_the_basics.md diff --git a/docs/source/compile/dynamic_shapes_core_concepts.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_core_concepts.md similarity index 100% rename from docs/source/compile/dynamic_shapes_core_concepts.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_core_concepts.md diff --git a/docs/source/compile/dynamic_shapes_debugging_tlparse_torch_logs.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_debugging_tlparse_torch_logs.md similarity index 95% rename from docs/source/compile/dynamic_shapes_debugging_tlparse_torch_logs.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_debugging_tlparse_torch_logs.md index 46c7cb2daee4c..3fa2999823191 100644 --- a/docs/source/compile/dynamic_shapes_debugging_tlparse_torch_logs.md +++ b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_debugging_tlparse_torch_logs.md @@ -65,7 +65,7 @@ fn(x, y) To identify where dynamic shape guards originate, use `tlparse`. Here is an example tlparse output: -```{image} ../_static/img/dynamic_shapes/tlparse9_debugging_guards.png +```{image} ../../../_static/img/dynamic_shapes/tlparse9_debugging_guards.png ``` By clicking on the `dynamo_cpp_guards` link, you can view all guards from the compilation, including the symbolic shape guard `L['x'].size()[0] <= 9`. @@ -92,7 +92,7 @@ fn(x, y) Now, this compiled region can be used for inputs of size 0 and 1: -```{image} ../_static/img/dynamic_shapes/tlparse10_debugging_guards_unbacked.png +```{image} ../../../_static/img/dynamic_shapes/tlparse10_debugging_guards_unbacked.png ``` ```{seealso} diff --git a/docs/source/compile/dynamic_shapes_troubleshooting.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_troubleshooting.md similarity index 100% rename from docs/source/compile/dynamic_shapes_troubleshooting.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_troubleshooting.md diff --git a/docs/source/compile/dynamic_shapes_troubleshooting_guardon_errors.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_troubleshooting_guardon_errors.md similarity index 100% rename from docs/source/compile/dynamic_shapes_troubleshooting_guardon_errors.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_troubleshooting_guardon_errors.md diff --git a/docs/source/compile/dynamic_shapes_zero_one_specialization.md b/docs/source/user_guide/torch_compiler/compile/dynamic_shapes_zero_one_specialization.md similarity index 100% rename from docs/source/compile/dynamic_shapes_zero_one_specialization.md rename to docs/source/user_guide/torch_compiler/compile/dynamic_shapes_zero_one_specialization.md diff --git a/docs/source/compile/header_code.py b/docs/source/user_guide/torch_compiler/compile/header_code.py similarity index 100% rename from docs/source/compile/header_code.py rename to docs/source/user_guide/torch_compiler/compile/header_code.py diff --git a/docs/source/compile/programming_model.common_graph_breaks.md b/docs/source/user_guide/torch_compiler/compile/programming_model.common_graph_breaks.md similarity index 100% rename from docs/source/compile/programming_model.common_graph_breaks.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.common_graph_breaks.md diff --git a/docs/source/compile/programming_model.compiler_disable.md b/docs/source/user_guide/torch_compiler/compile/programming_model.compiler_disable.md similarity index 100% rename from docs/source/compile/programming_model.compiler_disable.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.compiler_disable.md diff --git a/docs/source/compile/programming_model.custom_ops.md b/docs/source/user_guide/torch_compiler/compile/programming_model.custom_ops.md similarity index 100% rename from docs/source/compile/programming_model.custom_ops.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.custom_ops.md diff --git a/docs/source/compile/programming_model.dynamo_core_concepts.md b/docs/source/user_guide/torch_compiler/compile/programming_model.dynamo_core_concepts.md similarity index 100% rename from docs/source/compile/programming_model.dynamo_core_concepts.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.dynamo_core_concepts.md diff --git a/docs/source/compile/programming_model.dynamo_nonstrict_trace.md b/docs/source/user_guide/torch_compiler/compile/programming_model.dynamo_nonstrict_trace.md similarity index 100% rename from docs/source/compile/programming_model.dynamo_nonstrict_trace.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.dynamo_nonstrict_trace.md diff --git a/docs/source/compile/programming_model.error_on_graph_break.md b/docs/source/user_guide/torch_compiler/compile/programming_model.error_on_graph_break.md similarity index 100% rename from docs/source/compile/programming_model.error_on_graph_break.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.error_on_graph_break.md diff --git a/docs/source/compile/programming_model.fullgraph_false.md b/docs/source/user_guide/torch_compiler/compile/programming_model.fullgraph_false.md similarity index 100% rename from docs/source/compile/programming_model.fullgraph_false.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.fullgraph_false.md diff --git a/docs/source/compile/programming_model.fullgraph_true.md b/docs/source/user_guide/torch_compiler/compile/programming_model.fullgraph_true.md similarity index 100% rename from docs/source/compile/programming_model.fullgraph_true.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.fullgraph_true.md diff --git a/docs/source/compile/programming_model.graph_breaks_index.md b/docs/source/user_guide/torch_compiler/compile/programming_model.graph_breaks_index.md similarity index 100% rename from docs/source/compile/programming_model.graph_breaks_index.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.graph_breaks_index.md diff --git a/docs/source/compile/programming_model.md b/docs/source/user_guide/torch_compiler/compile/programming_model.md similarity index 95% rename from docs/source/compile/programming_model.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.md index 0de06b6f62137..a5499300ad015 100644 --- a/docs/source/compile/programming_model.md +++ b/docs/source/user_guide/torch_compiler/compile/programming_model.md @@ -1,3 +1,5 @@ +(compile_programming_model)= + # torch.compile Programming Model The `torch.compile` programming model: diff --git a/docs/source/compile/programming_model.nested_graph_breaks.md b/docs/source/user_guide/torch_compiler/compile/programming_model.nested_graph_breaks.md similarity index 100% rename from docs/source/compile/programming_model.nested_graph_breaks.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.nested_graph_breaks.md diff --git a/docs/source/compile/programming_model.non_strict_tracing_model.md b/docs/source/user_guide/torch_compiler/compile/programming_model.non_strict_tracing_model.md similarity index 100% rename from docs/source/compile/programming_model.non_strict_tracing_model.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.non_strict_tracing_model.md diff --git a/docs/source/compile/programming_model.observability.md b/docs/source/user_guide/torch_compiler/compile/programming_model.observability.md similarity index 100% rename from docs/source/compile/programming_model.observability.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.observability.md diff --git a/docs/source/compile/programming_model.recompilation.md b/docs/source/user_guide/torch_compiler/compile/programming_model.recompilation.md similarity index 100% rename from docs/source/compile/programming_model.recompilation.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.recompilation.md diff --git a/docs/source/compile/programming_model.reporting_issues.md b/docs/source/user_guide/torch_compiler/compile/programming_model.reporting_issues.md similarity index 100% rename from docs/source/compile/programming_model.reporting_issues.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.reporting_issues.md diff --git a/docs/source/compile/programming_model.skipped_functions.md b/docs/source/user_guide/torch_compiler/compile/programming_model.skipped_functions.md similarity index 100% rename from docs/source/compile/programming_model.skipped_functions.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.skipped_functions.md diff --git a/docs/source/compile/programming_model.where_to_apply_compile.md b/docs/source/user_guide/torch_compiler/compile/programming_model.where_to_apply_compile.md similarity index 100% rename from docs/source/compile/programming_model.where_to_apply_compile.md rename to docs/source/user_guide/torch_compiler/compile/programming_model.where_to_apply_compile.md diff --git a/docs/source/user_guide/torch_compiler/core_concepts.md b/docs/source/user_guide/torch_compiler/core_concepts.md new file mode 100644 index 0000000000000..4355db69cb88f --- /dev/null +++ b/docs/source/user_guide/torch_compiler/core_concepts.md @@ -0,0 +1,13 @@ +# Core Concepts + +Understand how `torch.compile` works, including the programming model, graph breaks, and compilation behavior. + +```{toctree} +:maxdepth: 1 + +compile/programming_model.md +torch.compiler_dynamo_overview.md +torch.compiler_nn_module.md +torch.compiler_backward.md + +``` diff --git a/docs/source/export.md b/docs/source/user_guide/torch_compiler/export.md similarity index 99% rename from docs/source/export.md rename to docs/source/user_guide/torch_compiler/export.md index 2ab7d85303c0d..f2176da171869 100644 --- a/docs/source/export.md +++ b/docs/source/user_guide/torch_compiler/export.md @@ -660,8 +660,8 @@ export/ir_spec export/pt2_archive export/draft_export export/joint_with_descriptors -cond -generated/exportdb/index +../../cond +../../generated/exportdb/index torch.compiler_aot_inductor torch.compiler_ir ``` diff --git a/docs/source/export/api_reference.md b/docs/source/user_guide/torch_compiler/export/api_reference.md similarity index 100% rename from docs/source/export/api_reference.md rename to docs/source/user_guide/torch_compiler/export/api_reference.md diff --git a/docs/source/export/draft_export.md b/docs/source/user_guide/torch_compiler/export/draft_export.md similarity index 98% rename from docs/source/export/draft_export.md rename to docs/source/user_guide/torch_compiler/export/draft_export.md index b1ec6ca5d44e6..451747b3a91db 100644 --- a/docs/source/export/draft_export.md +++ b/docs/source/user_guide/torch_compiler/export/draft_export.md @@ -126,7 +126,7 @@ Running the `tlparse` command in the terminal will generate a [tlparse](https://github.com/pytorch/tlparse) HTML report. Here is an example of the `tlparse` report: -```{image} ../_static/img/export/draft_export_report.png +```{image} ../../../_static/img/export/draft_export_report.png ``` Clicking into the Data Dependent Error, we will see the following page which @@ -136,7 +136,7 @@ contains information to help debug this error. Specifically, it contains: - A list of local variables and their shapes - Information for how this guard was created -```{image} ../_static/img/export/draft_export_report_dde.png +```{image} ../../../_static/img/export/draft_export_report_dde.png ``` ## The returned Exported Program diff --git a/docs/source/export/ir_spec.md b/docs/source/user_guide/torch_compiler/export/ir_spec.md similarity index 100% rename from docs/source/export/ir_spec.md rename to docs/source/user_guide/torch_compiler/export/ir_spec.md diff --git a/docs/source/export/joint_with_descriptors.md b/docs/source/user_guide/torch_compiler/export/joint_with_descriptors.md similarity index 100% rename from docs/source/export/joint_with_descriptors.md rename to docs/source/user_guide/torch_compiler/export/joint_with_descriptors.md diff --git a/docs/source/export/programming_model.md b/docs/source/user_guide/torch_compiler/export/programming_model.md similarity index 100% rename from docs/source/export/programming_model.md rename to docs/source/user_guide/torch_compiler/export/programming_model.md diff --git a/docs/source/export/pt2_archive.md b/docs/source/user_guide/torch_compiler/export/pt2_archive.md similarity index 100% rename from docs/source/export/pt2_archive.md rename to docs/source/user_guide/torch_compiler/export/pt2_archive.md diff --git a/docs/source/user_guide/torch_compiler/performance.md b/docs/source/user_guide/torch_compiler/performance.md new file mode 100644 index 0000000000000..3cc4c15f7328f --- /dev/null +++ b/docs/source/user_guide/torch_compiler/performance.md @@ -0,0 +1,12 @@ +# Performance + +Learn how to profile, benchmark, and optimize your models with `torch.compile`. + +```{toctree} +:maxdepth: 1 + +torch.compiler_performance_dashboard.md +torch.compiler_inductor_profiling.md +torch.compiler_profiling_torch_compile.md +torch.compiler_cudagraph_trees.md +``` diff --git a/docs/source/torch.compiler.config.md b/docs/source/user_guide/torch_compiler/torch.compiler.config.md similarity index 100% rename from docs/source/torch.compiler.config.md rename to docs/source/user_guide/torch_compiler/torch.compiler.config.md diff --git a/docs/source/torch.compiler.md b/docs/source/user_guide/torch_compiler/torch.compiler.md similarity index 82% rename from docs/source/torch.compiler.md rename to docs/source/user_guide/torch_compiler/torch.compiler.md index 11e22aae4cf3f..d6fff109bc118 100644 --- a/docs/source/torch.compiler.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler.md @@ -80,50 +80,48 @@ Some of the most commonly used backends include: - Uses OpenVINO for inference optimizations. `Read more `__ ``` -## Read More + + + +```{toctree} +:maxdepth: 1 +:hidden: + +torch.compiler_get_started.md +``` + +```{toctree} +:maxdepth: 1 +:hidden: + +core_concepts +``` ```{toctree} -:caption: Getting Started for PyTorch Users -:maxdepth: 2 - -torch.compiler_get_started -torch.compiler_api -torch.compiler.config -torch.compiler_dynamic_shapes -torch.compiler_fine_grain_apis -torch.compiler_backward -torch.compiler_aot_inductor -torch.compiler_inductor_profiling -torch.compiler_profiling_torch_compile -torch.compiler_faq -torch.compiler_troubleshooting -torch.compiler_performance_dashboard -torch.compiler_inductor_provenance +:maxdepth: 1 +:hidden: + +performance ``` ```{toctree} -:caption: torch.compile Programming Model -:maxdepth: 2 +:maxdepth: 1 +:hidden: -compile/programming_model +advanced ``` ```{toctree} -:caption: Deep Dive for PyTorch Developers :maxdepth: 1 +:hidden: + -torch.compiler_dynamo_overview -torch.compiler_dynamo_deepdive -torch.compiler_nn_module -torch.compiler_cudagraph_trees -torch.compiler_fake_tensor +troubleshooting_faqs ``` ```{toctree} -:caption: HowTo for PyTorch Backend Vendors :maxdepth: 1 +:hidden: -torch.compiler_custom_backends -torch.compiler_transformations -torch.compiler_ir +api_reference ``` diff --git a/docs/source/torch.compiler_aot_inductor.md b/docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor.md similarity index 99% rename from docs/source/torch.compiler_aot_inductor.md rename to docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor.md index e1de040114915..257deb73fc57a 100644 --- a/docs/source/torch.compiler_aot_inductor.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor.md @@ -199,7 +199,7 @@ Below are some useful tools for debugging AOT Inductor. :caption: Debugging Tools :maxdepth: 1 -logging +../../logging torch.compiler_aot_inductor_minifier torch.compiler_aot_inductor_debugging_guide ``` diff --git a/docs/source/torch.compiler_aot_inductor_debugging_guide.md b/docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor_debugging_guide.md similarity index 98% rename from docs/source/torch.compiler_aot_inductor_debugging_guide.md rename to docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor_debugging_guide.md index 331e1abd886a0..cb29e7f699c5c 100644 --- a/docs/source/torch.compiler_aot_inductor_debugging_guide.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor_debugging_guide.md @@ -34,7 +34,7 @@ CUDA_LAUNCH_BLOCKING=1 These flags take effect at runtime: - `PYTORCH_NO_CUDA_MEMORY_CACHING=1` disables PyTorch's Caching Allocator, which allocates a bigger buffer than needed immediately to reduce the number of buffer allocations. This is usually the reason why CUDA illegal memory access errors are non-deterministic. -![How PyTorch's caching allocator can mask CUDA illegal memory access errors](./_static/img/aoti_debugging_guide/cuda_ima_cca.png) +![How PyTorch's caching allocator can mask CUDA illegal memory access errors](../../_static/img/aoti_debugging_guide/cuda_ima_cca.png) *Figure: How PyTorch's caching allocator can mask CUDA illegal memory access errors* - `CUDA_LAUNCH_BLOCKING=1` forces the kernels to launch one at a time. Without this, we would get the famous "CUDA kernel errors might be asynchronously reported at some other API call" warning since kernels are launched asynchronously. diff --git a/docs/source/torch.compiler_aot_inductor_minifier.md b/docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor_minifier.md similarity index 100% rename from docs/source/torch.compiler_aot_inductor_minifier.md rename to docs/source/user_guide/torch_compiler/torch.compiler_aot_inductor_minifier.md diff --git a/docs/source/torch.compiler_backward.md b/docs/source/user_guide/torch_compiler/torch.compiler_backward.md similarity index 99% rename from docs/source/torch.compiler_backward.md rename to docs/source/user_guide/torch_compiler/torch.compiler_backward.md index 27cd66dc419c8..a596bfd6038fc 100644 --- a/docs/source/torch.compiler_backward.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_backward.md @@ -1,3 +1,5 @@ +(compiler_backward)= + ``torch.compile`` has different autograd semantics ================================================== diff --git a/docs/source/torch.compiler_cudagraph_trees.md b/docs/source/user_guide/torch_compiler/torch.compiler_cudagraph_trees.md similarity index 100% rename from docs/source/torch.compiler_cudagraph_trees.md rename to docs/source/user_guide/torch_compiler/torch.compiler_cudagraph_trees.md diff --git a/docs/source/torch.compiler_custom_backends.md b/docs/source/user_guide/torch_compiler/torch.compiler_custom_backends.md similarity index 100% rename from docs/source/torch.compiler_custom_backends.md rename to docs/source/user_guide/torch_compiler/torch.compiler_custom_backends.md diff --git a/docs/source/torch.compiler_dynamic_shapes.md b/docs/source/user_guide/torch_compiler/torch.compiler_dynamic_shapes.md similarity index 95% rename from docs/source/torch.compiler_dynamic_shapes.md rename to docs/source/user_guide/torch_compiler/torch.compiler_dynamic_shapes.md index 22cb482cd20bd..a14d7c9029040 100644 --- a/docs/source/torch.compiler_dynamic_shapes.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_dynamic_shapes.md @@ -71,7 +71,7 @@ f(torch.rand(40)) ``` In the produced output, you can see that four graphs were generated. -See the corresponding tlparse output +See the corresponding tlparse output By making the size dynamic, the function can handle various sizes without recompilation: @@ -88,7 +88,7 @@ f(torch.rand(40)) ``` With dynamic shapes enabled, only one graph is created. See the -corresponding tlparse output. +corresponding tlparse output. While compilation time differences are minimal for this small example, more complex use cases would show significant @@ -129,12 +129,12 @@ In the code above, we specialize that the graph requires an input size of 10, in case it will return `x * 10`. If the input size is less than 30, it will return `x * 200`. In the output, you can see that this creates three graphs. -See the corresponding tlparse output +See the corresponding tlparse output This is how graphs created for the above function: -```{image} _static/img/dynamic_shapes/dynamic_shapes_example_specialization.png +```{image} ../../_static/img/dynamic_shapes/dynamic_shapes_example_specialization.png ``` (enable-dynamic-behavior)= diff --git a/docs/source/torch.compiler_dynamo_deepdive.md b/docs/source/user_guide/torch_compiler/torch.compiler_dynamo_deepdive.md similarity index 100% rename from docs/source/torch.compiler_dynamo_deepdive.md rename to docs/source/user_guide/torch_compiler/torch.compiler_dynamo_deepdive.md diff --git a/docs/source/torch.compiler_dynamo_overview.md b/docs/source/user_guide/torch_compiler/torch.compiler_dynamo_overview.md similarity index 98% rename from docs/source/torch.compiler_dynamo_overview.md rename to docs/source/user_guide/torch_compiler/torch.compiler_dynamo_overview.md index 6baf75058a8e4..7ba68ad0c42f9 100644 --- a/docs/source/torch.compiler_dynamo_overview.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_dynamo_overview.md @@ -1,3 +1,5 @@ +(dynamo_overview)= + # Dynamo Overview Before you read this section, read {ref}`torch.compiler_overview`. @@ -20,7 +22,7 @@ backends to make PyTorch code faster with a single line decorator The following diagram demonstrates how PyTorch works with `torch.compile` and without it: -```{image} _static/img/dynamo/TorchDynamo.png +```{image} ../../_static/img/dynamo/TorchDynamo.png ``` `TorchInductor` is one of the backends @@ -327,7 +329,7 @@ def compiled_example(a, b): The following diagram demonstrates how `torch.compile` transforms and optimizes user-written code: it first extracts computation graphs from the user-written function, and compiles these graphs into optimized functions, then assembles them into a new function, which is functionally equivalent to the user-written code but optimized to have a good computation speed. -```{image} _static/img/dynamo/flowchart.jpg +```{image} ../../_static/img/dynamo/flowchart.jpg ``` -To learn more about how all this is implemented internally, see {ref}`torch.compiler_dynamo_deepdive`. \ No newline at end of file +To learn more about how all this is implemented internally, see {ref}`torch.compiler_dynamo_deepdive`. diff --git a/docs/source/torch.compiler_fake_tensor.md b/docs/source/user_guide/torch_compiler/torch.compiler_fake_tensor.md similarity index 100% rename from docs/source/torch.compiler_fake_tensor.md rename to docs/source/user_guide/torch_compiler/torch.compiler_fake_tensor.md diff --git a/docs/source/torch.compiler_faq.md b/docs/source/user_guide/torch_compiler/torch.compiler_faq.md similarity index 99% rename from docs/source/torch.compiler_faq.md rename to docs/source/user_guide/torch_compiler/torch.compiler_faq.md index 7a8eaaa5215fa..7aeddc0cf4b28 100644 --- a/docs/source/torch.compiler_faq.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_faq.md @@ -621,10 +621,10 @@ might need even finer control. Suppose you want to disable the tracing on just the `a_fn` function, but want to continue the tracing back in `aa_fn` and `ab_fn`. The image below demonstrates this use case: -:::{figure} _static/img/fine_grained_apis/call_stack_diagram.png +:::{figure} ../../_static/img/fine_grained_apis/call_stack_diagram.png :alt: diagram of torch.compile + disable(a_fn, recursive=False) ::: In this case, you can use `torch._dynamo.disable(recursive=False)`. In previous versions, this functionality was provided by `torch._dynamo.skip`. -This is now supported by the `recursive` flag inside `torch._dynamo.disable`. \ No newline at end of file +This is now supported by the `recursive` flag inside `torch._dynamo.disable`. diff --git a/docs/source/torch.compiler_fine_grain_apis.md b/docs/source/user_guide/torch_compiler/torch.compiler_fine_grain_apis.md similarity index 99% rename from docs/source/torch.compiler_fine_grain_apis.md rename to docs/source/user_guide/torch_compiler/torch.compiler_fine_grain_apis.md index fc4768ce2ebc0..7aa0044facabd 100644 --- a/docs/source/torch.compiler_fine_grain_apis.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_fine_grain_apis.md @@ -38,7 +38,7 @@ disable compilation are listed in the following table: TorchDynamo intercepts the execution of each Python function frame. So, suppose you have a code structure (image below) where the function `fn` calls functions `a_fn` and `b_fn`. And `a_fn` calls `aa_fn` and `ab_fn`. When you use the PyTorch eager mode rather than `torch.compile`, these function frames run as is. With `torch.compile`, TorchDynamo intercepts each of these function frames (indicated by the green color): -:::{figure} _static/img/fine_grained_apis/api_diagram.png +:::{figure} ../../_static/img/fine_grained_apis/api_diagram.png :alt: Callstack diagram of different apis. ::: diff --git a/docs/source/torch.compiler_get_started.md b/docs/source/user_guide/torch_compiler/torch.compiler_get_started.md similarity index 99% rename from docs/source/torch.compiler_get_started.md rename to docs/source/user_guide/torch_compiler/torch.compiler_get_started.md index adbc2184df250..c9182d16364ad 100644 --- a/docs/source/torch.compiler_get_started.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_get_started.md @@ -145,4 +145,4 @@ basic understanding of how torch.compile works. Here is what you check out next: - [torch.compile tutorial on training](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) - {ref}`torch.compiler_api` -- {ref}`torchdynamo_fine_grain_tracing` \ No newline at end of file +- {ref}`torchdynamo_fine_grain_tracing` diff --git a/docs/source/torch.compiler_inductor_profiling.md b/docs/source/user_guide/torch_compiler/torch.compiler_inductor_profiling.md similarity index 95% rename from docs/source/torch.compiler_inductor_profiling.md rename to docs/source/user_guide/torch_compiler/torch.compiler_inductor_profiling.md index c8e69e836b957..a0956e94dabb1 100644 --- a/docs/source/torch.compiler_inductor_profiling.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_inductor_profiling.md @@ -86,7 +86,7 @@ In the output, you can notice the following: Loading the trace into Chrome (visit chrome://tracing in the chrome browser and load the file as the UI suggested) will show UI as follows: - ```{image} _static/img/inductor_profiling/trace.png + ```{image} ../../_static/img/inductor_profiling/trace.png ``` You can zoom in and out to check the profile. @@ -124,7 +124,7 @@ In the output, you can notice the following: * We also call zoom into a certain category of kernels. For example, let’s check reduction kernels: - ```{image} _static/img/inductor_profiling/kernel_breakdown.png + ```{image} ../../_static/img/inductor_profiling/kernel_breakdown.png ``` We can see an ordered table of execution time for each individual @@ -149,7 +149,7 @@ We can lookup the kernel name in the ``fwd.py``, and find comment like: **# kernel path: /tmp/torchinductor_shunting/jk/cjk2vm3446xrk7rth7hr6pun7xxo3dnzubwcn6ydrpifal4eykrz.py** -```{image} _static/img/inductor_profiling/inductor_code.png +```{image} ../../_static/img/inductor_profiling/inductor_code.png ``` I’ll rename it k.py for convenience. Here is a paste for this [file](https://gist.github.com/shunting314/96a0afef9dce53d6357bf1633094f358). @@ -159,7 +159,7 @@ benchmark. Run ``k.py`` directly will report its execution time and bandwidth: - ```{image} _static/img/inductor_profiling/terminal_printout.png + ```{image} ../../_static/img/inductor_profiling/terminal_printout.png ``` We can check if max-autotune helps this kernel, by running: diff --git a/docs/source/torch.compiler_inductor_provenance.rst b/docs/source/user_guide/torch_compiler/torch.compiler_inductor_provenance.rst similarity index 88% rename from docs/source/torch.compiler_inductor_provenance.rst rename to docs/source/user_guide/torch_compiler/torch.compiler_inductor_provenance.rst index f20dfb40b2066..508062f38c3ad 100644 --- a/docs/source/torch.compiler_inductor_provenance.rst +++ b/docs/source/user_guide/torch_compiler/torch.compiler_inductor_provenance.rst @@ -15,10 +15,10 @@ The yellow highlighting shows the provenance of the nodes/kernels. Example screenshot of the provenance tracking tool for TorchInductor: - .. image:: _static/img/inductor_provenance/provenance_jit_inductor.png + .. image:: ../../_static/img/inductor_provenance/provenance_jit_inductor.png Example screenshot of the provenance tracking tool for AOTInductor: - .. image:: _static/img/inductor_provenance/provenance_aot_inductor.png + .. image:: ../../_static/img/inductor_provenance/provenance_aot_inductor.png Using the Provenance Tracking Highlighter @@ -53,7 +53,7 @@ Follow these steps to enable and use provenance tracking in your PyTorch project After running ``tlparse --inductor-provenance``, you should see an additional "Provenance Tracking" section in the tlparse output. Clicking into the link(s) to access the provenance tracking tool. For a demo, see: https://github.com/pytorch/tlparse/pull/93 - .. image:: _static/img/inductor_provenance/index.png + .. image:: ../../_static/img/inductor_provenance/index.png Source code corresponding to each Inductor kernel @@ -61,17 +61,17 @@ Source code corresponding to each Inductor kernel With ``INDUCTOR_PROVENANCE=1``, you can also view the source code corresponding to each Inductor kernel in tlparse. To access it, click the "readable_html" link next to "inductor_provenance_tracking_kernel_stack_traces.json" in the tlparse output. - .. image:: _static/img/inductor_provenance/index_2.png + .. image:: ../../_static/img/inductor_provenance/index_2.png Below are some example screenshots. The ``:1`` and ``:467`` suffixes at the end of the kernel names are used to distinguish different calls to the same kernel. We refer to these suffixes as debug handles. - .. image:: _static/img/inductor_provenance/kernel_source_1.png - .. image:: _static/img/inductor_provenance/kernel_source_2.png + .. image:: ../../_static/img/inductor_provenance/kernel_source_1.png + .. image:: ../../_static/img/inductor_provenance/kernel_source_2.png You can also find the debug handle in the comments within the kernel source code. - .. image:: _static/img/inductor_provenance/kernel_source_3.png + .. image:: ../../_static/img/inductor_provenance/kernel_source_3.png See Also diff --git a/docs/source/torch.compiler_ir.md b/docs/source/user_guide/torch_compiler/torch.compiler_ir.md similarity index 93% rename from docs/source/torch.compiler_ir.md rename to docs/source/user_guide/torch_compiler/torch.compiler_ir.md index ff66b8cc7efce..4aa439165d043 100644 --- a/docs/source/torch.compiler_ir.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_ir.md @@ -17,7 +17,7 @@ This opset is designed to serve as the functional IR to interface with backends. ``` ```{csv-table} - :file: ../build/ir/aten_ops.csv + :file: ../../../build/ir/aten_ops.csv :widths: auto :header-rows: 1 ``` @@ -34,7 +34,7 @@ This opset is designed to interface with compiler backends. ``` ```{csv-table} - :file: ../build/ir/prims_ops.csv + :file: ../../../build/ir/prims_ops.csv :widths: auto :header-rows: 1 ``` diff --git a/docs/source/torch.compiler_nn_module.md b/docs/source/user_guide/torch_compiler/torch.compiler_nn_module.md similarity index 98% rename from docs/source/torch.compiler_nn_module.md rename to docs/source/user_guide/torch_compiler/torch.compiler_nn_module.md index a694e2c88dbd6..4da3220860d07 100644 --- a/docs/source/torch.compiler_nn_module.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_nn_module.md @@ -1,3 +1,5 @@ +(compiler_nn_module)= + # PyTorch 2.0 NNModule Support **Author**: [Will Constable](https://github.com/wconstab) @@ -56,4 +58,4 @@ TODO: confirm if backward/pre_backward hooks are working or not and document acc State dict hooks have not yet been supported in `torch.compile`. -TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present. \ No newline at end of file +TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present. diff --git a/docs/source/torch.compiler_performance_dashboard.md b/docs/source/user_guide/torch_compiler/torch.compiler_performance_dashboard.md similarity index 100% rename from docs/source/torch.compiler_performance_dashboard.md rename to docs/source/user_guide/torch_compiler/torch.compiler_performance_dashboard.md diff --git a/docs/source/torch.compiler_profiling_torch_compile.md b/docs/source/user_guide/torch_compiler/torch.compiler_profiling_torch_compile.md similarity index 94% rename from docs/source/torch.compiler_profiling_torch_compile.md rename to docs/source/user_guide/torch_compiler/torch.compiler_profiling_torch_compile.md index 9c1a215920abf..25537ca2501e6 100644 --- a/docs/source/torch.compiler_profiling_torch_compile.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_profiling_torch_compile.md @@ -45,7 +45,7 @@ See also the [general pytorch profiler guide](https://pytorch.org/tutorials/reci **Viewing chrome traces**: In the Chrome browser, open chrome://tracing and load the json file. Use the “w” and “s” keys to zoom in and out, and use “a” and “d” to scroll left and right. “?” will show a “help” screen with a list of shortcuts. -```{figure} _static/img/profiling_torch_compile/basic_chrome_trace.png +```{figure} ../../_static/img/profiling_torch_compile/basic_chrome_trace.png :alt: Example of a basic chrome trace, visualized in the chrome://tracing viewer ``` @@ -59,7 +59,7 @@ Every kernel on the accelerator occurs after being launched by code running on t To view a flow connection, click on a GPU kernel and click “ac2g”: -```{figure} _static/img/profiling_torch_compile/ac2g.png +```{figure} ../../_static/img/profiling_torch_compile/ac2g.png :alt: Visualization in the chrome://trace viewer, showing an async flow between a kernel and its launching location. ``` @@ -121,7 +121,7 @@ See an example below: prof.export_chrome_trace("trace_compile.json") ``` -```{figure} _static/img/profiling_torch_compile/compilation_profiling.png +```{figure} ../../_static/img/profiling_torch_compile/compilation_profiling.png :alt: A visualization in the chrome://trace viewer, showing dynamo and inductor compilation steps ``` @@ -198,7 +198,7 @@ See the synthetic example below for a demonstration: prof.export_chrome_trace("trace_break.json") ``` -```{figure} _static/img/profiling_torch_compile/graph_breaks_with_torch_compiled_region.png +```{figure} ../../_static/img/profiling_torch_compile/graph_breaks_with_torch_compiled_region.png :alt: Visualization in the chrome://trace viewer, showing nested Torch-Compiled Region events and multiple CompiledFunction events - indicating graph breaks. ``` @@ -210,7 +210,7 @@ When an operator is launched, we expect to see a few events: 2. Kernel launch (if dealing with a GPU kernel) 3. GPU-side event -```{figure} _static/img/profiling_torch_compile/kernel_launch_labeled.png +```{figure} ../../_static/img/profiling_torch_compile/kernel_launch_labeled.png :alt: Visualization in the chrome://trace viewer, showing the three types of events - CPU-side event, kernel launch, and GPU-side event ``` @@ -219,7 +219,7 @@ When an operator is launched, we expect to see a few events: 2. The **kernel launch** should appear as cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops) 3. The **GPU-side event** should appear, and how descriptive the name will be depends on the inductor config for unique_kernel_names -```{figure} _static/img/profiling_torch_compile/triton_kernel_launch.png +```{figure} ../../_static/img/profiling_torch_compile/triton_kernel_launch.png ``` **Non-Inductor generated Triton kernels:** @@ -228,7 +228,7 @@ When an operator is launched, we expect to see a few events: 2. The **kernel launch** should appear s cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops) 3. The **GPU-side** event should appear, named similarly to the triton kernel that was authored. -```{figure} _static/img/profiling_torch_compile/noninductor_triton_kernel.png +```{figure} ../../_static/img/profiling_torch_compile/noninductor_triton_kernel.png ``` **Inductor-generated CPU kernels:** @@ -243,7 +243,7 @@ When an operator is launched, we expect to see a few events: One common issue is bad GPU utilization. A quick way to identify this is if there are large gaps between kernels on the GPU: -```{figure} _static/img/profiling_torch_compile/cpu_bound.png +```{figure} ../../_static/img/profiling_torch_compile/cpu_bound.png :alt: Visualization in the chrome://trace viewer, showing large gaps between GPU kernels. This indicates that the model is CPU bound, likely due to overhead during kernel launches. ``` diff --git a/docs/source/torch.compiler_transformations.md b/docs/source/user_guide/torch_compiler/torch.compiler_transformations.md similarity index 100% rename from docs/source/torch.compiler_transformations.md rename to docs/source/user_guide/torch_compiler/torch.compiler_transformations.md diff --git a/docs/source/torch.compiler_troubleshooting.md b/docs/source/user_guide/torch_compiler/torch.compiler_troubleshooting.md similarity index 99% rename from docs/source/torch.compiler_troubleshooting.md rename to docs/source/user_guide/torch_compiler/torch.compiler_troubleshooting.md index a4f7af3b9b8e9..ded51073c3d93 100644 --- a/docs/source/torch.compiler_troubleshooting.md +++ b/docs/source/user_guide/torch_compiler/torch.compiler_troubleshooting.md @@ -816,7 +816,7 @@ to debug real `torch.compile` issues. Below is a high-level overview of the stack: -![Torch Dynamo Stack](_static/img/dynamo/td_stack.png) +![Torch Dynamo Stack](../../_static/img/dynamo/td_stack.png) The stack comprises three main components: TorchDynamo, AOTAutograd, and Inductor. Our debugging strategy involves first identifying the component in which the error occurs diff --git a/docs/source/user_guide/torch_compiler/troubleshooting_faqs.md b/docs/source/user_guide/torch_compiler/troubleshooting_faqs.md new file mode 100644 index 0000000000000..263bc25cd0fac --- /dev/null +++ b/docs/source/user_guide/torch_compiler/troubleshooting_faqs.md @@ -0,0 +1,13 @@ +# Troubleshooting FAQs + +Find solutions to common issues, debugging guides, and answers to frequently asked questions. + +```{toctree} +:maxdepth: 1 + +compile/programming_model.observability +compile/programming_model.reporting_issues +torch.compiler_troubleshooting.md +torch.compiler_faq.md + +``` From eb5c63652a33da42e7018c23df5f20a3eb4c6ccf Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 3 Dec 2025 23:06:29 +0000 Subject: [PATCH 210/338] Generalize GraphView in manual bucketing (#169426) Generalize the GraphView construction so it doesn't have to rely on `nn_module_stack`. `nn_module_stack` can be inaccurate sometimes if the module calls are wrapped in closures. To mitigate this, we allow users to use fx_annotate to annotate their modules and use annotation to construct GraphView. also, we preserve annotation and stack trace on newly created nodes, in addition to nn_module_stack node meta Pull Request resolved: https://github.com/pytorch/pytorch/pull/169426 Approved by: https://github.com/eellison --- .../test_aten_comm_compute_reordering.py | 200 +++++++++++++++++- torch/_inductor/fx_passes/bucketing.py | 59 ++++++ torch/_inductor/fx_passes/graph_view.py | 48 ++++- .../fx_passes/overlap_manual_scheduling.py | 27 ++- 4 files changed, 320 insertions(+), 14 deletions(-) diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 60488496d0ffb..fb64e77f5bebf 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -1300,7 +1300,9 @@ def forward(self, x): return model -def apply_manual_reordering_and_get_graph(graph, module_bucket_plans, out_li) -> None: +def apply_manual_reordering_and_get_graph( + graph, module_bucket_plans, out_li, custom_module_stack_fn=None +) -> None: gm = graph.owning_module from torch._inductor.fx_passes.overlap_manual_scheduling import ( ManualOverlapScheduler, @@ -1323,18 +1325,24 @@ def apply_manual_reordering_and_get_graph(graph, module_bucket_plans, out_li) -> node.meta["nn_module_stack"] = {"test": ["module_2", ""]} overlapped_gm = ManualOverlapScheduler( - gm, module_bucket_plans, insert_overlap_deps=False + gm, + module_bucket_plans, + insert_overlap_deps=False, + module_stack_fn=custom_module_stack_fn, ).run() overlapped_gm.graph.lint() out_li.append(overlapped_gm.graph) -def run_and_get_manual_aten_graph(fn, module_bucket_plans, *inputs): +def run_and_get_manual_aten_graph( + fn, module_bucket_plans, *inputs, custom_module_stack_fn=None +): li = [] apply = functools.partial( apply_manual_reordering_and_get_graph, module_bucket_plans=module_bucket_plans, out_li=li, + custom_module_stack_fn=custom_module_stack_fn, ) with torch._inductor.config.patch(post_grad_custom_post_pass=apply): out = fn(*inputs) @@ -1377,6 +1385,77 @@ def test_make_graph_view_and_get_subgraph_by_path(self): ) self.assertEqual([n.name for n in mixed_nodes], ["layers_0_wq"]) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + def test_make_graph_view_and_get_subgraph_by_path_custom_module_stack_fn(self): + from torch._dynamo.functional_export import dynamo_graph_capture_for_export + from torch._inductor.fx_passes.graph_view import ( + get_subgraph_by_path, + make_graph_view, + ) + + model = get_toy_model(device_type) + + module_path_key = "module_path" + # Add annotation to node.meta["custom"] + for name, m in model.named_modules(): + m.forward = torch.fx.traceback.annotate_fn({module_path_key: name})( + m.forward + ) + + def module_stack_fn(node): + module_stack = node.meta.get("custom", {}).get(module_path_key, "") + return [(module_stack, torch.nn.Module)] + + gm = dynamo_graph_capture_for_export(model)(torch.randn(2, 4).to(device_type)) + + # delete "nn_module_stack" to make sure the graph view is only constructed from annotation + for n in gm.graph.nodes: + if "nn_module_stack" in n.meta: + del n.meta["nn_module_stack"] + + graph_view = make_graph_view(gm.graph, module_stack_fn=module_stack_fn) + # Fetch subgraph for first transformer layer + sub_nodes = get_subgraph_by_path(graph_view, "layers.0.wq") + self.assertEqual( + [n.name for n in sub_nodes], + [ + "l_func_self_modules_layers_modules_0_modules_wq_parameters_weight_", + "l_func_self_modules_layers_modules_0_modules_wq_parameters_bias_", + "linear", + ], + ) + + # Fetch multiple paths at once + multi_nodes = get_subgraph_by_path(graph_view, ["layers.0.wq", "layers.0.proj"]) + self.assertEqual( + [n.name for n in multi_nodes], + [ + "l_func_self_modules_layers_modules_0_modules_wq_parameters_weight_", + "l_func_self_modules_layers_modules_0_modules_wq_parameters_bias_", + "linear", + "l_func_self_modules_layers_modules_0_modules_proj_parameters_weight_", + "l_func_self_modules_layers_modules_0_modules_proj_parameters_bias_", + "x", + ], + ) + + # Fetch non existing paths + non_exist_nodes = get_subgraph_by_path(graph_view, "nonexistent.module.path") + self.assertEqual(non_exist_nodes, []) + + # Fetch mixed of existing and non existing paths + mixed_nodes = get_subgraph_by_path( + graph_view, ["layers.0.wq", "nonexistent.module.path"] + ) + self.assertEqual( + [n.name for n in mixed_nodes], + [ + "l_func_self_modules_layers_modules_0_modules_wq_parameters_weight_", + "l_func_self_modules_layers_modules_0_modules_wq_parameters_bias_", + "linear", + ], + ) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_manual_reordering_bucketing_pass_separate_buckets( self, @@ -1569,6 +1648,121 @@ def func(a, b, c, d, *, ranks): correct = func(a, b, c, d, ranks=ranks) self.assertTrue(same(out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + def test_bucketing_reordering_pass_single_bucket_custom_module_stack_fn( + self, + ): + module_path_key = "module_path" + + def module_stack_fn(node): + module_stack = node.meta.get("custom", {}).get(module_path_key, "") + return [(module_stack, torch.nn.Module)] + + def func(a, b, c, d, *, ranks): + # All 4 all-gathers are independent - COULD be bucketed together + with torch.fx.traceback.annotate({module_path_key: "my_module_1"}): + ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) + ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) + with torch.fx.traceback.annotate({module_path_key: "my_module_2"}): + ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks) + ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks) + + # First compute - can hide ag1 and ag2 + e = a * 5 # Use a to avoid fusion + mm1 = torch.matmul(e, e.T) + + # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred) + # Use first 8x8 elements to match mm1's shape + intermediate = ag1[:8, :8] + ag2[:8, :8] + + # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4 + mm2 = torch.matmul(mm1 + intermediate, c[:8]) + + # Use all results + result = ( + ag1.sum() * 1.1 + + ag2.sum() * 1.2 + + ag3.sum() * 1.3 + + ag4.sum() * 1.4 + + mm1.sum() + + mm2.sum() + ) + return result + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + a = torch.ones(8, 8, dtype=torch.float, device=device_type) + b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2 + c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3 + d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4 + ranks = list(range(self.world_size)) + + func_c = functools.partial(func, ranks=ranks) + compiled = torch.compile(func_c) + out, aten_graph = run_and_get_manual_aten_graph( + compiled, + [["my_module_1", "my_module_2"]], + a, + b, + c, + d, + custom_module_stack_fn=module_stack_fn, + ) + + ( + FileCheck() + .check("_pre_bucket_all_gather") + .check("all_gather_into_tensor_out") + .check("wait_tensor_4") + .run(str(aten_graph)) + ) + + correct = func(a, b, c, d, ranks=ranks) + self.assertTrue(same(out, correct)) + + # Add metadata to the collective nodes to test preservation + test_metadata = { + "nn_module_stack": { + "test": ("module_1", ""), + }, + "custom": { + "module_path": "my_module_1", + }, + } + + # Verify metadata preservation: new bucketed nodes should have the metadata + new_ag_nodes = aten_graph.find_nodes( + op="call_function", + target=torch.ops.bucketing._pre_bucket_all_gather.default, + ) + new_wait_nodes = aten_graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.wait_tensor.default, + ) + + all_new_nodes = list(new_ag_nodes) + list(new_wait_nodes) + self.assertGreater(len(all_new_nodes), 0, "Should have created new nodes") + + for node in all_new_nodes: + self.assertEqual( + node.meta.get("nn_module_stack"), test_metadata["nn_module_stack"] + ) + self.assertEqual(node.meta.get("custom"), test_metadata["custom"]) + self.assertTrue(node.meta.get("stack_trace", None) is not None) + self.assertTrue( + node.meta.get("bucketing_stack_trace_sources", None) is not None + ) + self.assertTrue( + node.meta.get("bucketing_custom_sources", None) is not None + ) + self.assertTrue( + node.meta.get("bucketing_nn_module_stack_sources", None) is not None + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index aba2c5182264a..e72cdccddb440 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -17,12 +17,15 @@ from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._logging import trace_structured from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.traceback import NodeSource, NodeSourceAction from torch.utils._ordered_set import OrderedSet logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") + BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"] @@ -74,6 +77,53 @@ def _schedulable_wait_node(node: torch.fx.Node) -> bool: return is_callable and is_collective +def _populate_node_meta( + bucket_nodes: list[torch.fx.Node], new_nodes: list[torch.fx.Node] +): + if bucket_nodes: + for n in new_nodes: + # For the following keys, we only store the information of the first node so + # gm.print_readable shows some information + # Full information are stored in "bucketing_{key}_sources" + for key, default in [ + ("nn_module_stack", ""), + ("fwd_nn_module_stack", ""), + ("stack_trace", ""), + ("custom", {}), + ]: + n.meta[key] = bucket_nodes[0].meta.get(key, default) + + # Collect sources from all bucket nodes for this metadata key, for debugging purposes only + bucketing_sources_key = f"bucketing_{key}_sources" + # Use set to remove duplicates + if key == "stack_trace": + sources = OrderedSet( + [ + node.meta.get(key, default) + for node in bucket_nodes + if node.meta.get(key, default) + ] + ) + else: + # type might not be hashable + sources = [ + node.meta.get(key, default) + for node in bucket_nodes + if node.meta.get(key, default) + ] + n.meta[bucketing_sources_key] = sources + + # used by inductor provenance tracking + n.meta["from_node"] = [ + NodeSource( + original_node, + "bucketing_pass", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + for original_node in bucket_nodes + ] + + def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None: if is_all_gather_into_tensor(node): group_key_fn = ( @@ -842,6 +892,15 @@ def process_collective_bucket( for node in nodes_to_move: wait_insertion_point.prepend(node) + # Preserve metadata from original collective nodes to new bucketed nodes + if bucket_nodes: + overlap_log.debug( + "Bucketing nodes: %s, New nodes: %s", + ",".join([n.name for n in bucket_nodes]), + ",".join([n.name for n in new_nodes]), + ) + _populate_node_meta(bucket_nodes, new_nodes) + # Erase old nodes for node, wait_n in zip(bucket_nodes, bucket_waits): g.erase_node(wait_n) diff --git a/torch/_inductor/fx_passes/graph_view.py b/torch/_inductor/fx_passes/graph_view.py index 88a78747ec607..5758551a9b8a5 100644 --- a/torch/_inductor/fx_passes/graph_view.py +++ b/torch/_inductor/fx_passes/graph_view.py @@ -2,12 +2,16 @@ import itertools import re -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import torch.fx as fx # noqa: TC001 from torch.utils._ordered_set import OrderedSet +if TYPE_CHECKING: + from collections.abc import Callable + + def _get_module_stack(node: fx.Node) -> list[tuple[str, type[Any]]]: nn_stack = node.meta.get("nn_module_stack", "") if nn_stack: @@ -105,7 +109,10 @@ def _is_root(stack: str) -> bool: return stack == "" -def make_graph_view(graph: fx.Graph) -> Optional[GraphView]: +def make_graph_view( + graph: fx.Graph, + module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None, +) -> Optional[GraphView]: """ Code from: https://github.com/meta-pytorch/autoparallel/pull/158 @@ -147,12 +154,45 @@ def make_graph_view(graph: fx.Graph) -> Optional[GraphView]: subgraph = get_subgraph_by_path(graph_view, "layers.0") where subgraph contains all the nodes that belong to this region + + module_stack_fn: Optional callable for extracting module hierarchy information from nodes. + + Signature: Callable[[fx.Node], list[tuple[str, type[Any]]]] + + Takes an FX node and returns a list of (module_path, module_class) tuples representing + the nested module hierarchy for that node, ordered from outermost to innermost scope. + + - module_path (str): Dot-separated path identifying the module in the hierarchy + (e.g., "layers.0.attention.wq") + - module_class (type): The Python class type of the module + + This enables custom logic for determining module membership, useful for: + - Graphs without standard nn_module_stack metadata + - Filtering or grouping nodes by custom criteria + + Example of getting the module stack from annotation: + + def module_stack_fn(node): + module_stack = node.meta.get("custom", {}).get("module_path", "") + return [(module_stack, torch.nn.Module)] + + If None, defaults to extracting from node.meta["nn_module_stack"] or + node.meta["fwd_nn_module_stack"]. """ + + def nn_module_stack_meta(node: fx.Node) -> list[tuple[str, type[Any]]]: + result = [] + for module_stack, module_class in _get_module_stack(node): + module_stack = _clean_stack_name(module_stack) + result.append((module_stack, module_class)) + return result + + if module_stack_fn is None: + module_stack_fn = nn_module_stack_meta nodes: list[fx.Node] = list(graph.nodes) nodes_by_module_stack_root: GraphView | None = None for node in nodes: - for module_stack, module_class in _get_module_stack(node): - module_stack = _clean_stack_name(module_stack) + for module_stack, module_class in module_stack_fn(node): nodes_by_module_stack: GraphView | None = nodes_by_module_stack_root for name in module_stack.split("."): if nodes_by_module_stack is None: diff --git a/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/torch/_inductor/fx_passes/overlap_manual_scheduling.py index d2c8b588d2011..540e73166ba45 100644 --- a/torch/_inductor/fx_passes/overlap_manual_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -2,7 +2,7 @@ import heapq from collections import Counter, defaultdict -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING import torch import torch.fx as fx @@ -28,6 +28,10 @@ from .graph_view import get_subgraph_by_path, GraphView, make_graph_view +if TYPE_CHECKING: + from collections.abc import Callable + + class ManualOverlapPreservingBucketer(OverlapPreservingBucketer): """ Buckets collective operations based on user specifications. @@ -106,14 +110,13 @@ def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: new_start = new_wait.args[0] assert isinstance(new_start, fx.Node) + # Set manual bucketing-specific metadata + # Note: Generic metadata (nn_module_stack, fwd_nn_module_stack, custom, stack_trace) + # is now preserved automatically by the bucketing functions in bucketing.py node_type = ( "bucketed_all_gather" if is_all_gather(first) else "bucketed_reduce_scatter" ) for n in new_nodes: - n.meta["nn_module_stack"] = coll_nodes[0].meta.get("nn_module_stack", "") - n.meta["fwd_nn_module_stack"] = coll_nodes[0].meta.get( - "fwd_nn_module_stack", "" - ) if n == new_wait: node_type = node_type + "_wait" n.meta["manual_bucket_node_type"] = node_type @@ -161,6 +164,7 @@ def __init__( gm: fx.GraphModule, module_bucket_plans: list[list[str] | str], insert_overlap_deps: bool, + module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None, ): super().__init__( gm, @@ -187,6 +191,8 @@ def __init__( ) self.insert_overlap_deps = insert_overlap_deps + self.module_stack_fn = module_stack_fn + def _identify_collectives(self) -> None: """Identify all collective operations.""" for node in self.nodes: @@ -317,7 +323,7 @@ def _obtain_nodes_in_subgraph(self) -> None: """ Obtain nodes in each subgraph from module_bucket_plans """ - graph_view: GraphView | None = make_graph_view(self.graph) + graph_view: GraphView | None = make_graph_view(self.graph, self.module_stack_fn) if graph_view is None: return @@ -340,6 +346,7 @@ def manual_overlap_bucketing( gm: torch.fx.GraphModule, module_bucket_plans: list[list[str] | str], insert_overlap_deps: bool = False, + module_stack_fn: None | Callable[[fx.Node], list[tuple[str, type[Any]]]] = None, ) -> torch.fx.GraphModule: """Schedule nodes based on user specifications in module_bucket_plans The manual overlapping consists of two steps: @@ -352,10 +359,16 @@ def manual_overlap_bucketing( Args: gm: input graph module to optimize. module_bucket_plans: user specified FQNs + module_stack_fn: Optional callable for extracting module hierarchy from nodes. + Used to construct a GraphView for identifying nodes in module_bucket_plans. + The module_class component of the returned tuples is not used by this pass. + + See the `module_stack_fn` parameter in `make_graph_view` (graph_view.py) for + detailed documentation on signature, return format, and usage examples. """ # decode abbreviated FQNs to actual FQNs overlapped_gm = ManualOverlapScheduler( - gm, module_bucket_plans, insert_overlap_deps + gm, module_bucket_plans, insert_overlap_deps, module_stack_fn ).run() overlapped_gm.recompile() return overlapped_gm From 7eb625920054b1126a7d2d99818aaa188c6ba95e Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 2 Dec 2025 20:42:12 -0800 Subject: [PATCH 211/338] [dynamo][benchmarks] add option to force amp to use bfloat16 instead of float16 (#169449) For models which hardcode bf16 like modded-nanogpt Tested on `python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --only modded_nanogpt --disable-cudagraphs` w/ https://github.com/pytorch/benchmark/pull/2660 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169449 Approved by: https://github.com/Lucaskabela --- benchmarks/dynamo/common.py | 10 +++++++++- benchmarks/dynamo/torchbench.yaml | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 3d3065ade8a5b..398ca2eab1556 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1796,7 +1796,10 @@ def setup_amp(self, current_device=None): self.autocast = functools.partial( torch.amp.autocast, device_type=devices[0] ) - if self.args.amp_dtype: + if self.args.amp_dtype is None: + if self.args.only in self.amp_dtype_bfloat16: + self.autocast_arg["dtype"] = torch.bfloat16 + else: amp_dtype = ( torch.float16 if self.args.amp_dtype == "float16" @@ -1881,6 +1884,10 @@ def force_amp_for_fp16_bf16_models(self): def force_fp16_for_bf16_models(self): return set() + @property + def amp_dtype_bfloat16(self): + return set() + @property def skip_not_suitable_for_training_models(self): return set() @@ -3877,6 +3884,7 @@ def run(runner, args, original_dir=None): # xfail: https://github.com/pytorch/pytorch/issues/145773 "llama", "cm3leon_generate", + "modded_nanogpt", } ) diff --git a/benchmarks/dynamo/torchbench.yaml b/benchmarks/dynamo/torchbench.yaml index 974c3d700a045..0566820b7ed5b 100644 --- a/benchmarks/dynamo/torchbench.yaml +++ b/benchmarks/dynamo/torchbench.yaml @@ -110,6 +110,8 @@ dtype: force_fp16_for_bf16_models: - vision_maskrcnn + amp_dtype_bfloat16: + - modded_nanogpt # models in canary_models that we should run anyway canary_models: @@ -138,6 +140,7 @@ only_training: - hf_Reformer - pytorch_struct - yolov3 + - modded_nanogpt trt_not_yet_working: @@ -198,6 +201,7 @@ skip: cpu: # model is CUDA only - cm3leon_generate + - modded_nanogpt # timeout - nanogpt # timeout From f9bd6c53624c7c0ea3772de78498326e84c2f0e7 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 2 Dec 2025 08:01:25 -0800 Subject: [PATCH 212/338] Support module.to in strict export (#167555) This diff makes it so that we can call module.to inside export forward. The main strategy is that we inline through eager module.to call and do polyfill for C++ binded functions. We also have to use @zhxchen17 's tracer because there are some graph breaks if we are in export path. Another caveat was that i had to turn off graph_break_on_nn_param_ctr. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167555 Approved by: https://github.com/zhxchen17 --- test/dynamo/test_misc.py | 66 +++++++++++++++++ test/export/test_export.py | 76 +++++++++++++++++++ torch/_dynamo/create_parameter_op.py | 3 +- torch/_dynamo/polyfills/__init__.py | 1 + torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/polyfills/torch_c_nn.py | 102 ++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 - 7 files changed, 248 insertions(+), 3 deletions(-) create mode 100644 torch/_dynamo/polyfills/torch_c_nn.py diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 78b5c7e4553da..c19ec2aa58b29 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -14163,6 +14163,72 @@ def _random_resize(image: torch.Tensor): self.assertTrue(224 <= h <= 256) self.assertTrue(224 <= w <= 256) + @unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device") + def test_module_to_with_shared_weights_compile(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) + + def forward(self, x): + token_ids = torch.randint(0, 10, (4,), device=x.device) + embedded = self.embedding(token_ids).sum() + return x.sum() + embedded.sum() + + class Container(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = Model() + + def forward(self, x): + if "cuda" in str(x.device): + mod = self.mod.to(x.device) + return mod(x) + else: + return x.sum() + + container = Container() + container_eager = copy.deepcopy(container) + with torch._dynamo.config.patch(graph_break_on_nn_param_ctor=False): + compiled = torch.compile(container, backend="eager", fullgraph=True) + + inp1 = torch.randn(4, 4, 4, device="cuda") + + # First call with CUDA input + compiled_result1 = compiled(inp1) + eager_result1 = container_eager(inp1) + same(compiled_result1, eager_result1) + + # Second call - weights are now on CUDA from first call + # This tests that .to(cuda) on already-cuda weights doesn't fail + compiled_result2 = compiled(inp1) + eager_result2 = container_eager(inp1) + same(compiled_result2, eager_result2) + + @unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device") + def test_module_to_move_compile(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc(x) + self.to("cpu") + return x + + mod = Model().cuda() + with torch._dynamo.config.patch(graph_break_on_nn_param_ctor=False): + fn = torch.compile(mod, backend="aot_eager", fullgraph=True) + x = torch.randn(10, 10, device="cuda") + ref = fn(x) + self.assertEqual(str(mod.fc.weight.device), "cpu") + mod.cuda() + ref = fn( + x + ) # second time compile runs, we should also move the module to cpu device + self.assertEqual(str(mod.fc.weight.device), "cpu") + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/export/test_export.py b/test/export/test_export.py index 92ea28c077e52..3a996faf5ed99 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -16323,6 +16323,82 @@ def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"): ignore_empty_lines=True, ) + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + def test_module_to_with_shared_weights(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) + + def forward(self, x): + token_ids = torch.ones((4,), device=x.device, dtype=torch.int64) + embedded = self.embedding(token_ids).sum() + return x.sum() + embedded.sum() + + class Container(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = Model() + + def forward(self, x): + if "cuda" in str(x.device): + mod = self.mod.to(x.device) + return mod(x) + else: + return x.sum() + + with ( + torch._dynamo.config.patch(graph_break_on_nn_param_ctor=False), + torch._export.config.patch(use_legacy_dynamo_graph_capture=False), + ): + torch.manual_seed(0) + container = Container() + container_eager = copy.deepcopy(container) + gm = torch.export.export( + container, + (torch.randn(4, 4, 4, device="cuda"),), + strict=True, + ).module() + + self.assertExpectedInline( + str(gm.code).strip(), + """\ +def forward(self, x): + args_0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) + mod_embedding_weight = self.mod.embedding.weight + _guards_fn = self._guards_fn(args_0); _guards_fn = None + empty = torch.ops.aten.empty.memory_format([10, 8], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False) + detach = torch.ops.aten.detach.default(empty); empty = None + submod_6 = self.submod_1 + to = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_6, mod_embedding_weight); submod_6 = mod_embedding_weight = None + getitem = to[0]; to = None + submod_7 = self.submod_3 + wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_7); submod_7 = wrap_with_set_grad_enabled = None + submod_8 = self.submod_4 + view_as = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_8, detach, getitem); submod_8 = detach = getitem = None + getitem_1 = view_as[0]; view_as = None + ones = torch.ops.aten.ones.default([4], dtype = torch.int64, device = device(type='cuda', index=0), pin_memory = False) + embedding = torch.ops.aten.embedding.default(getitem_1, ones); getitem_1 = ones = None + sum_1 = torch.ops.aten.sum.default(embedding); embedding = None + sum_2 = torch.ops.aten.sum.default(args_0); args_0 = None + sum_3 = torch.ops.aten.sum.default(sum_1); sum_1 = None + add = torch.ops.aten.add.Tensor(sum_2, sum_3); sum_2 = sum_3 = None + return pytree.tree_unflatten((add,), self._out_spec)""", + ) + + inp = torch.randn(4, 4, 4, device="cuda") + + # Call container first to move shared weights to CUDA + export_out = gm(inp) + eager_out = container_eager(inp) + self.assertEqual(export_out, eager_out) + + # This should not fail even though weights are now on CUDA + # and .to(cuda) returns the same parameter with requires_grad=True + export_out_v2 = gm(inp) + eager_out_v2 = container_eager(inp) + self.assertEqual(export_out_v2, eager_out_v2) + @testing.expectedFailureStrict # test_hop doesn't have a dynamo implementation @testing.expectedFailureStrictV2 # test_hop doesn't have a dynamo implementation @testing.expectedFailureRetraceability # test_hop doesn't have a dynamo implementation diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index 2a716865c3f48..a0bc7325c54b4 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -22,7 +22,8 @@ class TracableCreateParameter(torch.autograd.Function): @staticmethod # pyrefly: ignore [bad-override] def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter: - assert not tensor.requires_grad + if tensor.requires_grad: + tensor = tensor.detach() return placeholder.set_(tensor) @staticmethod diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 59f6f76317e6d..56f614b265a9c 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -34,6 +34,7 @@ pytree as pytree, struct as struct, sys as sys, + torch_c_nn as torch_c_nn, ) from torch.overrides import BaseTorchFunctionMode diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 31479e9d86ce6..46e6fa6df9c3e 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -25,6 +25,7 @@ "sys", "fx", "tensor", + "torch_c_nn", ) if python_pytree._cxx_pytree_dynamo_traceable: POLYFILLED_MODULE_NAMES += ("pytree",) diff --git a/torch/_dynamo/polyfills/torch_c_nn.py b/torch/_dynamo/polyfills/torch_c_nn.py new file mode 100644 index 0000000000000..e6eb8ce78da3e --- /dev/null +++ b/torch/_dynamo/polyfills/torch_c_nn.py @@ -0,0 +1,102 @@ +""" +Polyfills for torch._C._nn functions. +""" + +import torch + +from ..decorators import substitute_in_graph + + +@substitute_in_graph(torch._C._nn._parse_to, skip_signature_check=True) +def _parse_to_polyfill(*args, **kwargs): # noqa: F821 + """ + Polyfill for torch._C._nn._parse_to that parses arguments to nn.Module.to(). + + Signature mirrors torch._C._nn._parse_to which accepts: + - to(device) - device as string or torch.device + - to(dtype) - dtype as torch.dtype + - to(tensor) - extracts device and dtype from tensor + - to(device=..., dtype=..., non_blocking=..., memory_format=...) + + Returns: + tuple: (device, dtype, non_blocking, memory_format) + """ + device = None + dtype = None + non_blocking = False + memory_format = None + + # Handle positional arguments + if len(args) == 1: + arg = args[0] + # Check if it's a tensor + if isinstance(arg, torch.Tensor): + device = arg.device + dtype = arg.dtype + # Check if it's a dtype + elif isinstance(arg, torch.dtype): + dtype = arg + # Check if it's a device (string or torch.device) + elif isinstance(arg, (str, torch.device)): + device = torch.device(arg) if isinstance(arg, str) else arg + else: + raise TypeError( + f"to() received an invalid combination of arguments. Got: {type(arg)}" + ) + elif len(args) > 1: + raise TypeError( + f"to() received too many positional arguments. Got {len(args)}, expected at most 1" + ) + + # Handle keyword arguments + if "device" in kwargs: + device_arg = kwargs["device"] + if device_arg is not None: + device = ( + torch.device(device_arg) if isinstance(device_arg, str) else device_arg + ) + + if "dtype" in kwargs: + dtype = kwargs["dtype"] + + if "non_blocking" in kwargs: + non_blocking = kwargs["non_blocking"] + + if "memory_format" in kwargs: + memory_format = kwargs["memory_format"] + + return (device, dtype, non_blocking, memory_format) + + +@substitute_in_graph(torch.__future__.get_swap_module_params_on_conversion) +def get_swap_module_params_on_conversion_polyfill() -> bool: + """ + Polyfill for torch.__future__.get_swap_module_params_on_conversion. + + Returns the default value False to allow tracing through nn.Module._apply(). + """ + return False + + +@substitute_in_graph(torch._has_compatible_shallow_copy_type) +def _has_compatible_shallow_copy_type_polyfill( + input: torch.Tensor, from_: torch.Tensor +) -> bool: + """ + Polyfill for torch._has_compatible_shallow_copy_type. + + Checks if two tensors have compatible types for shallow copying. + The C++ implementation checks if input's TensorImpl has compatible shallow copy type + with from_'s key_set. We approximate this by checking if both tensors are the same type. + """ + # Check if both tensors are the same type (handles both regular tensors and subclasses) + # This is more permissive than checking exact torch.Tensor type equality + # but properly handles subclasses by allowing same-type shallow copies + return type(input) is type(from_) + + +__all__ = [ + "_parse_to_polyfill", + "get_swap_module_params_on_conversion_polyfill", + "_has_compatible_shallow_copy_type_polyfill", +] diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 083c8b1f93807..813247f4fd3c7 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1064,7 +1064,6 @@ "torch._C._nn._conv_depthwise2d", "torch._C._nn._pad_circular", "torch._C._nn._pad_enum", - "torch._C._nn._parse_to", "torch._C._nn._test_ambiguous_defaults", "torch._C._nn._test_optional_filled_intlist", "torch._C._nn._test_optional_floatlist", @@ -1587,7 +1586,6 @@ "torch._fw_primal_copy", "torch._grid_sampler_2d_cpu_fallback", "torch._grouped_mm", - "torch._has_compatible_shallow_copy_type", "torch._histogramdd_bin_edges", "torch._histogramdd_from_bin_cts", "torch._histogramdd_from_bin_tensors", From ec2c71f5c85021b8938cdafadce24c15a36fd93e Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 4 Dec 2025 00:19:34 +0000 Subject: [PATCH 213/338] [CI][Docker] Add triton to pallas build (#169494) Fixes #169480 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169494 Approved by: https://github.com/malfet --- .ci/docker/build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 748608005e622..f0f154f0c7c1f 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -255,6 +255,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 PALLAS=yes + TRITON=yes ;; pytorch-linux-jammy-py3.12-triton-cpu) CUDA_VERSION=12.6 From 305168768a95d69c444df5cd334bb774edfe06f1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 00:22:15 +0000 Subject: [PATCH 214/338] Revert "Enable custom collective op autotuning (#167294)" This reverts commit dfbd3714d15c37a7b83b322a6b60f997fc00f50c. Reverted https://github.com/pytorch/pytorch/pull/167294 on behalf of https://github.com/huydhn due to I was too hasty, the macos failure was legit ([comment](https://github.com/pytorch/pytorch/pull/167294#issuecomment-3609375254)) --- test/inductor/test_collective_autotuning.py | 189 --------------- torch/_inductor/codegen/subgraph.py | 55 +---- torch/_inductor/config.py | 10 - torch/_inductor/kernel/custom_op.py | 47 ++-- torch/_inductor/select_algorithm.py | 244 +------------------- torch/_inductor/utils.py | 21 -- 6 files changed, 34 insertions(+), 532 deletions(-) delete mode 100644 test/inductor/test_collective_autotuning.py diff --git a/test/inductor/test_collective_autotuning.py b/test/inductor/test_collective_autotuning.py deleted file mode 100644 index a5a05d05a9028..0000000000000 --- a/test/inductor/test_collective_autotuning.py +++ /dev/null @@ -1,189 +0,0 @@ -# Owner(s): ["module: inductor"] - -import torch -import torch.distributed as dist -from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_utils import run_tests - - -class TestCollectiveAutotuning2Ranks(MultiProcessTestCase): - """Test collective autotuning with 2 ranks""" - - @property - def world_size(self): - return 2 - - def setUp(self): - super().setUp() - self._spawn_processes() - - @skip_if_lt_x_gpu(2) - def test_equivalent_allreduce_strategies(self): - """ - Test autotuning between mathematically equivalent all_reduce strategies. - - Strategy 1: sum all_reduce - Strategy 2: avg all_reduce * world_size - """ - dist.init_process_group( - backend="nccl", - init_method=f"file:///tmp/test_equiv_allreduce_{self.id()}", - world_size=self.world_size, - rank=self.rank, - ) - - dist.barrier() - - rank = dist.get_rank() - device = f"cuda:{rank}" - - from torch._C._distributed_c10d import _register_process_group - - _register_process_group("default", dist.group.WORLD) - - @torch.library.custom_op("test::equiv_ar", mutates_args=()) - def equiv_ar(x: torch.Tensor) -> torch.Tensor: - result = x.clone() - return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") - - @equiv_ar.register_fake - def _(x): - return torch.empty_like(x) - - def sum_allreduce(x: torch.Tensor) -> torch.Tensor: - result = x.clone() - return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") - - def avg_allreduce_scaled(x: torch.Tensor) -> torch.Tensor: - result = x.clone() - result = torch.ops._c10d_functional.all_reduce_(result, "avg", "default") - return result * self.world_size - - from torch._inductor.kernel.custom_op import ( - CustomOpConfig, - register_custom_op_autotuning, - ) - - register_custom_op_autotuning( - equiv_ar, - configs=[ - CustomOpConfig(sum_allreduce), - CustomOpConfig(avg_allreduce_scaled), - ], - ) - - class EquivAllReduceModel(torch.nn.Module): - def forward(self, x): - return equiv_ar(x) - - model = torch.compile(EquivAllReduceModel()).to(device) - - torch.manual_seed(42) - x = torch.randn(128, 128, device=device) - dist.broadcast(x, src=0) - - _ = model(x) - - dist.barrier() - dist.destroy_process_group() - - -class TestCollectiveAutotuning4Ranks(MultiProcessTestCase): - """Test collective autotuning with 4 ranks""" - - @property - def world_size(self): - return 4 - - def setUp(self): - super().setUp() - self._spawn_processes() - - @skip_if_lt_x_gpu(4) - def test_vllm_style_allreduce(self): - """ - Test vLLM-style custom allreduce with buffer copy pattern. - - vLLM uses custom allreduce optimized for small tensors (<8MB). - Two implementations simulate vLLM's registered=False mode vs standard NCCL. - """ - dist.init_process_group( - backend="nccl", - init_method=f"file:///tmp/test_vllm_allreduce_{self.id()}", - world_size=self.world_size, - rank=self.rank, - ) - - dist.barrier() - - rank = dist.get_rank() - device = f"cuda:{rank}" - - from torch._C._distributed_c10d import _register_process_group - - _register_process_group("default", dist.group.WORLD) - - @torch.library.custom_op("test::vllm_allreduce", mutates_args=()) - def vllm_allreduce(x: torch.Tensor) -> torch.Tensor: - result = x.clone() - return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") - - @vllm_allreduce.register_fake - def _(x): - return torch.empty_like(x) - - def vllm_buffer_copy_allreduce(x: torch.Tensor) -> torch.Tensor: - """ - vLLM registered=False: flatten -> copy to IPC buffer -> allreduce -> reshape - - vLLM code: - inp_size = inp.numel() * inp.element_size() - self.buffer_ptrs[self.rank][:inp_size].copy_(inp.view(-1)) - ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) - """ - original_shape = x.shape - flat_x = x.contiguous().view(-1) - buffer_copy = flat_x.clone() - result = torch.ops._c10d_functional.all_reduce_( - buffer_copy, "sum", "default" - ) - return result.view(original_shape) - - def nccl_allreduce_direct(x: torch.Tensor) -> torch.Tensor: - """Standard NCCL allreduce without buffer copy.""" - result = x.clone() - return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") - - from torch._inductor.kernel.custom_op import ( - CustomOpConfig, - register_custom_op_autotuning, - ) - - register_custom_op_autotuning( - vllm_allreduce, - configs=[ - CustomOpConfig(vllm_buffer_copy_allreduce), - CustomOpConfig(nccl_allreduce_direct), - ], - ) - - class VLLMAllReduceModel(torch.nn.Module): - def forward(self, x): - return vllm_allreduce(x) - - model = torch.compile(VLLMAllReduceModel()).to(device) - - torch.manual_seed(42 + rank) - x = torch.randn(128, 256, device=device) - - y = model(x) - self.assertEqual(y.shape, x.shape) - dist.barrier() - dist.destroy_process_group() - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 7b931fb3bf47e..1c1f0f1c9cd2c 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -71,25 +71,16 @@ def __init__( self.sym_inputs = get_symbolic_inputs(self.input_nodes) - # Cache compiled module to avoid recompiling on every benchmark call - self._compiled_module: Any = None - self._compiled_sym_inputs: list[Any] | None = None - def __str__(self) -> str: return f"SubgraphCaller({self.name})" - def _compile_for_benchmarking(self, *args: list[Any]) -> tuple[Any, list[Any]]: - """ - Compile the subgraph for benchmarking and return (module, sym_inputs). - - TODO: Add precompile() method to enable parallel compilation of all choices - before benchmarking. - """ + def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: + # Codegen Subgraph for benchmarking + # Need GraphLowering instead of SubgraphLowering to generate + # fully callable module import torch._inductor.config as inductor_config from torch._inductor.graph import GraphLowering - safe_name = self.name.replace("::", "_").replace(".", "_") - bm_graph_lowering = GraphLowering( gm=self.gm, example_inputs=self.example_inputs, @@ -99,7 +90,7 @@ def _compile_for_benchmarking(self, *args: list[Any]) -> tuple[Any, list[Any]]: extern_node_serializer=V.graph.extern_node_serializer, is_inference=V.graph.is_inference, is_backward=V.graph.is_backward, - name=f"benchmark_{safe_name}", + name=f"benchmark_{self.name}", ) for sym_inp in self.sym_inputs: @@ -132,23 +123,9 @@ def _compile_for_benchmarking(self, *args: list[Any]) -> tuple[Any, list[Any]]: ): bm_graph_lowering.run(*self.example_inputs) mod = bm_graph_lowering.compile_to_module() + bm_func = mod.call - return mod, sym_inputs - - def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: - """ - Regular benchmarking: compile and use benchmarker with warmup/rep. - """ - if self._compiled_module is None: - mod, sym_inputs = self._compile_for_benchmarking(*args) - self._compiled_module = mod - self._compiled_sym_inputs = sym_inputs - else: - mod = self._compiled_module - sym_inputs = self._compiled_sym_inputs - assert sym_inputs is not None # Type narrowing - - bm_func = mod.call + bm_func([*sym_inputs, *args]) if config.profile_bandwidth_with_do_bench_using_profiling: return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args])) return benchmarker.benchmark( @@ -157,24 +134,6 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: device=benchmarker.infer_device(*sym_inputs, *args), ) - def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None: - """ - Only run once with cached compiled module. - Called by benchmark_collective_choice which handles warmup - and timing with barrier synchronization across all ranks. - """ - if self._compiled_module is None: - mod, sym_inputs = self._compile_for_benchmarking(*args) - self._compiled_module = mod - self._compiled_sym_inputs = sym_inputs - else: - mod = self._compiled_module - sym_inputs = self._compiled_sym_inputs - assert sym_inputs is not None # Type narrowing - - bm_func = mod.call - bm_func([*sym_inputs, *args]) - def hash_key(self) -> str: return "-".join( [ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index fcfb8f51ae6e7..7ba93575ce8bf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -608,16 +608,6 @@ def prologue_fusion_enabled() -> bool: # If autotuning in subprocess, whether to use multiple devices autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" -# Number of benchmark runs for collective operations -collective_benchmark_nruns = int( - os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_NRUNS", "50") -) - -# Timeout in seconds for collective benchmarking -collective_benchmark_timeout = float( - os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_TIMEOUT", "30") -) - coordinate_descent_tuning = ( os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" ) diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index c6a641ce83b17..12cc68dcb9844 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union import torch +from torch._inductor import config from torch._inductor.codegen.subgraph import SubgraphTemplate from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox from torch._inductor.lowering import lowerings, validate_ir @@ -20,28 +21,6 @@ log = logging.getLogger(__name__) -def _detect_collective_ops(choices: list) -> bool: - """ - Detect if choices contain collective operations. - """ - from torch._inductor.utils import is_collective_op - - for choice in choices: - if not hasattr(choice, "gm") or choice.gm is None: - continue - - for node in choice.gm.graph.nodes: - if node.op == "call_function" and node.target is not None: - op_name = str(node.target) - - if is_collective_op(op_name) or is_collective_op( - f"torch.ops.{op_name}" - ): - return True - - return False - - class CustomOpConfig: """Config for custom op autotuning. @@ -201,8 +180,14 @@ def create_internal_input_gen_fn( """Create internal input generator that converts IR buffer to user's fake tensor.""" def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor: - fake_tensor = ir_node_to_tensor(ir_buffer) - assert fake_tensor is not None, "ir_node_to_tensor returned None" + raw_shape = ir_buffer.get_size() + concrete_shape = V.graph.sizevars.size_hints( + raw_shape, fallback=config.unbacked_symint_fallback + ) + + fake_tensor = torch.empty( + concrete_shape, dtype=ir_buffer.get_dtype(), device="meta" + ) return user_function(fake_tensor) return internal_input_gen_fn @@ -336,8 +321,6 @@ def autotune_custom_op( ) input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) - is_collective = _detect_collective_ops(choices) - # Run autotuning and get both result and winning choice selected_result, winning_choice = autotune_select_algorithm( name=name, @@ -346,7 +329,6 @@ def autotune_custom_op( layout=choices[0].layout, input_gen_fns=input_gen_fns, return_choice=True, - is_collective=is_collective, ) # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl) @@ -381,7 +363,16 @@ def _generate_dynamic_configs( param_names = list(sig.parameters.keys()) with V.fake_mode: - fake_tensors = [ir_node_to_tensor(inp) for inp in tensor_inputs] + fake_tensors = [] + for inp in tensor_inputs: + raw_shape = inp.get_size() + concrete_shape = V.graph.sizevars.size_hints( + raw_shape, fallback=config.unbacked_symint_fallback + ) + fake_tensor = torch.empty( + concrete_shape, dtype=inp.get_dtype(), device=inp.get_device() + ) + fake_tensors.append(fake_tensor) fake_tensors_dict = dict(zip(param_names, fake_tensors)) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index df71bdd3db502..77448c914df80 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2335,10 +2335,6 @@ def autoheuristic_id(self): class ExternKernelCaller(ChoiceCaller): - """ - Caller for external kernel implementations - """ - def __init__( self, choice: ExternKernelChoice, @@ -2374,19 +2370,6 @@ def benchmark(self, *args, out): return do_bench_using_profiling(lambda: algo(*args)) return benchmarker.benchmark(algo, args, {}) - def benchmark_collective(self, *args, out): - """ - Called by benchmark_collective_choice, only run once, timing handled externally with barrier sync. - """ - if out.numel() == 0: - return - - algo = self.to_callable() - if self.has_out_variant: - algo(*args, out=out) - else: - algo(*args) - def to_callable(self): fn = self.choice.to_callable() if self.kwargs: @@ -2750,7 +2733,6 @@ def __call__( return_multi_template=False, best_config_future=None, return_choice=False, # TODO: return_choice is temporary and will be refactored soon - is_collective=False, ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2861,7 +2843,6 @@ def get_timings(hint_override: Optional[int] = None): choices, precompile_fn, best_config_future=best_config_future, - is_collective=is_collective, ) # if timings is empty, we really have no choice but to return a semi-random # choice. returning the first `ExternKernelCaller` is probably the safest bet @@ -2893,7 +2874,6 @@ def get_timings(hint_override: Optional[int] = None): # if we got any timings at all, pick the best of those choice = min(timings, key=timings.__getitem__) node = choice.output_node() - log.debug("Autotuning selected choice: %s", node) if return_choice: return node, choice @@ -2906,18 +2886,12 @@ def benchmark( layout, input_gen_fns, hint_override: Optional[int] = None, - is_collective=False, ): counters["inductor"]["select_algorithm_autotune"] += 1 # TODO(nmacchioni): remove this layer of abstraction # construct `benchmark_fn` which should pick between in-process and sub-process autotuning benchmark_fn = self.make_benchmark_fn( - choices, - input_nodes, - layout, - input_gen_fns, - hint_override=hint_override, - is_collective=is_collective, + choices, input_nodes, layout, input_gen_fns, hint_override=hint_override ) # `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which # maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds @@ -2931,7 +2905,6 @@ def autotune( input_gen_fns, choices, hint_override: Optional[int] = None, - is_collective=False, ): log.debug("Starting autotuning") @@ -2942,12 +2915,7 @@ def autotune( metadata=_autotune_metadata(input_nodes), ): benchmark_results = self.benchmark( - choices, - input_nodes, - layout, - input_gen_fns, - hint_override=hint_override, - is_collective=is_collective, + choices, input_nodes, layout, input_gen_fns, hint_override=hint_override ) if config.max_autotune_report_choices_stats: _log_autotune_choices_stats( @@ -2966,7 +2934,6 @@ def do_autotuning( precompile_fn, hint_override: Optional[int] = None, best_config_future=None, - is_collective=False, ): """Execute the autotuning process for kernel algorithm selection. @@ -3104,7 +3071,6 @@ def track_has_autotuned(choices): input_gen_fns, choices, hint_override=hint_override, - is_collective=is_collective, ) timings = self.lookup( @@ -3118,17 +3084,6 @@ def track_has_autotuned(choices): autotune_elapse = time.time() - autotune_start_ts log.debug("Autotuning elapsed time: %.02fs", autotune_elapse) - # For collective: if any choice returned inf (timeout or failure), fallback to default - if is_collective and timings: - has_inf = any(not math.isfinite(timing) for timing in timings.values()) - if has_inf: - log.warning( - "At least one choice failed or timed out during collective benchmarking. " - "Falling back to default implementation." - ) - return {} - - # For regular: if all choices returned inf, raise error if timings and all(not math.isfinite(timing) for timing in timings.values()): raise NoValidChoicesError @@ -3145,7 +3100,6 @@ def track_has_autotuned(choices): precompile_elapse, prescreening_elapse, hint_override=hint_override, - is_collective=is_collective, ) def profiler_bench_function(): @@ -3507,162 +3461,16 @@ def benchmark_choice( autotune_args.verify(**VERIFY) return result - @classmethod - def _run_collective_benchmark( - cls, - choice: ChoiceCaller, - inputs: tuple, - output: torch.Tensor, - nruns: int, - process_group, - timeout, - ) -> float: - """ - Single function for benchmarking collective operations. - Used for both warmup and actual benchmarking. - - Returns total time in milliseconds, or raises TimeoutError if any collective times out. - """ - import torch.distributed as dist - - work = dist.barrier(group=process_group, async_op=True) - if not work.wait(timeout): - raise TimeoutError("Barrier timeout before benchmarking") - - torch.cuda.synchronize() - - total_time = 0.0 - - for i in range(nruns): - torch.cuda.synchronize() - - start_evt = torch.cuda.Event(enable_timing=True) - end_evt = torch.cuda.Event(enable_timing=True) - - start_evt.record() - choice.benchmark_collective(*inputs, out=output) # type: ignore[attr-defined] - end_evt.record() - end_evt.synchronize() - - total_time += start_evt.elapsed_time(end_evt) - - return total_time - - @classmethod - def benchmark_collective_choice( - cls, - choice: ChoiceCaller, - autotune_args: AutotuneArgs, - ) -> float: - """ - Benchmark a choice for collective operations with cross-rank synchronization. - This method ensures all ranks synchronize before benchmarking - to get accurate measurements for distributed collective operations. - - Timeout/Error handling: If ANY rank times out or encounters an error during - the collective operations, ALL ranks will naturally time out (since the collective - won't complete), allowing the autotuner to fall back to the default implementation. - """ - from datetime import timedelta - - import torch.distributed as dist - - timeout_seconds = config.collective_benchmark_timeout - - nruns = config.collective_benchmark_nruns - nwarmup = ir.autotune_warmup - - # Use default process group (None = all ranks) - process_group = None - rank = dist.get_rank(process_group) - - benchmark_tensors: BenchmarkTensors = autotune_args.get_benchmark_tensors( - cls._is_extern(choice) - ) - inputs, output = benchmark_tensors.unpack() - output.zero_() - - timeout = timedelta(seconds=timeout_seconds) - - try: - # Do n warmups - total_time = cls._run_collective_benchmark( - choice, inputs, output, nwarmup, process_group, timeout - ) - - # Do n actual benchmarking runs - total_time = cls._run_collective_benchmark( - choice, inputs, output, nruns, process_group, timeout - ) - - avg_time = total_time / nruns - - # All-reduce to get avg time across ranks - time_tensor = torch.tensor( - [avg_time], dtype=torch.float32, device=f"cuda:{rank}" - ) - work = dist.all_reduce( - time_tensor, - op=dist.ReduceOp.AVG, - group=process_group, - async_op=True, - ) - if not work.wait(timeout): - raise TimeoutError( - "All-reduce timeout when collecting benchmark results" - ) - - timing = time_tensor.item() - - log.info( - "Collective benchmark for %s: %.6f ms", - choice.name, - timing, - ) - - return timing - - except Exception: - log.warning( - "Collective benchmark exception for choice %s. Skipping this choice.", - getattr(choice, "name", ""), - exc_info=True, - ) - return float("inf") - @classmethod def benchmark_choices( cls, choices: Sequence[ChoiceCaller], autotune_args: AutotuneArgs, - is_collective: bool = False, ) -> dict[ChoiceCaller, float]: - """ - Benchmark a list of choices and return timing dict. - """ - if is_collective: - import torch.distributed as dist - - if not dist.is_initialized(): - log.warning( - "Collective op detected but distributed not initialized. " - "Falling back to regular benchmarking." - ) - is_collective = False - else: - rank = dist.get_rank(None) # Use default process group - log.debug( - "Using collective benchmarking for %d choices on rank %d", - len(choices), - rank, - ) timings = {} for choice in choices: try: - if is_collective: - timing = cls.benchmark_collective_choice(choice, autotune_args) - else: - timing = cls.benchmark_choice(choice, autotune_args) + timing = cls.benchmark_choice(choice, autotune_args) except CUDACompileError: from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -3716,16 +3524,6 @@ def benchmark_choices( timings[choice] = timing - # If a collective choice failed or timed out, skip the rest of the choices - if is_collective and not math.isfinite(timing): - log.warning( - "Choice %s failed or timed out during collective benchmarking. " - "Stopping further benchmarking to avoid NCCL corruption.", - getattr(choice, "name", ""), - ) - timings.update({c: float("inf") for c in choices if c not in timings}) - break - return timings @classmethod @@ -3736,16 +3534,11 @@ def benchmark_in_current_process( layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], hint_override: Optional[int] = None, - is_collective=False, ) -> dict[ChoiceCaller, float]: inputs = cls.get_inputs( choices, input_nodes, layout, input_gen_fns, hint_override=hint_override ) - return cls.benchmark_choices( - choices, - inputs, - is_collective=is_collective, - ) + return cls.benchmark_choices(choices, inputs) @classmethod def benchmark_in_sub_process( @@ -3777,24 +3570,21 @@ def make_benchmark_fn( layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], hint_override: Optional[int] = None, - is_collective=False, ): if DEBUG: print(f"{len(choices)} tuning requests:") - # Collective ops must use current process - if is_collective or not config.autotune_in_subproc: + if config.autotune_in_subproc: return functools.partial( - cls.benchmark_in_current_process, + cls.benchmark_in_sub_process, input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, hint_override=hint_override, - is_collective=is_collective, ) else: return functools.partial( - cls.benchmark_in_sub_process, + cls.benchmark_in_current_process, input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, @@ -4026,26 +3816,8 @@ def log_results( precompile_elapse: float, prescreening_elapse: Optional[float] = None, hint_override: Optional[int] = None, - is_collective: bool = False, ): - """Log the autotuning results, currently only handles mm and flex. Log Collective op autotuning result""" - if is_collective and timings: - import torch.distributed as dist - - # Only rank 0 logs to avoid duplicate logs - rank = dist.get_rank() if dist.is_initialized() else 0 - if rank == 0: - best_choice = min(timings, key=timings.__getitem__) - log.warning("[COLLECTIVE AUTOTUNING] All timings:") - for c, t in sorted(timings.items(), key=lambda x: x[1]): - choice_name = getattr(c, "name", str(c)) - log.warning( - " - %s: %.6f ms %s", - choice_name, - t if math.isfinite(t) else float("inf"), - "← SELECTED" if c == best_choice else "", - ) - + """Log the autotuning results, currently only handles mm and flex""" V.debug.log_autotuning_results( name, input_nodes, timings, elapse, precompile_elapse ) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d7f3844cdf1ba..4d1ddc9ad4769 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -4126,24 +4126,3 @@ def should_fallback_by_default(node: torch.fx.Node) -> bool: return target in fallback_hops return not _needs_inductor_compile(node) - - -# Collective operation names for specialized benchmarking -COLLECTIVE_OPS = OrderedSet( - [ - "torch.ops._c10d_functional.all_reduce.default", - "torch.ops._c10d_functional.all_reduce_.default", - "torch.ops._c10d_functional.all_gather_into_tensor.default", - "torch.ops._c10d_functional.reduce_scatter_tensor.default", - "torch.ops._c10d_functional.all_to_all_single.default", - "torch.ops._c10d_functional_autograd.all_reduce.default", - "torch.ops._c10d_functional_autograd.all_gather_into_tensor.default", - "torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default", - "torch.ops._c10d_functional_autograd.all_to_all_single.default", - ] -) - - -def is_collective_op(op_name: str) -> bool: - """Check if an operation is a collective operation.""" - return op_name in COLLECTIVE_OPS From 65c4620d6bb0c6029f69762c22b91dda2294da9a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 00:25:01 +0000 Subject: [PATCH 215/338] Revert "[effect] Remove special handling for profiler op (#168389)" This reverts commit 296e67c92635443c67b11c0ae1bd045f03ebb7bc. Reverted https://github.com/pytorch/pytorch/pull/168389 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail test_python_dispatch ([comment](https://github.com/pytorch/pytorch/pull/168389#issuecomment-3609380860)) --- torch/_higher_order_ops/effects.py | 5 +++++ torch/_ops.py | 34 +++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index 96d7872048ec8..86707a4f55ef1 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -112,6 +112,11 @@ def has_aliasing(op: OpType): def has_effects(op) -> bool: + # Skip over the profiler's RecordFunction as they should not show up in the graph + _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction} + if op in _skip_ops: + return False + return ( isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) and not has_aliasing(op) diff --git a/torch/_ops.py b/torch/_ops.py index 75905d78da5b5..8f8a7328429fa 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1043,6 +1043,28 @@ def _may_use_fallthrough_instead_of_fallback(key: DispatchKey): if _may_use_fallthrough_instead_of_fallback(key) ] + @contextlib.contextmanager + def _register_as_effectful_op_temporarily(self): + from torch._higher_order_ops.effects import ( + _EffectType, + _get_effect, + _register_effectful_op, + ) + + try: + # We don't want to register the effect if there already exists a + # registration, especially if the registration is None (explicitly + # no effect) + register_tmp_effect = _get_effect(self) is None + handle = None + if register_tmp_effect: + handle = _register_effectful_op(self, _EffectType.ORDERED) + yield + finally: + if register_tmp_effect: + assert handle is not None + handle.destroy() + # Use positional-only argument to avoid naming collision with aten ops arguments # that are named "self". This way, all the aten ops can be called by kwargs. def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: @@ -1050,7 +1072,17 @@ def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: # When any inputs are FakeScriptObject, we need to # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject. - return self._dispatch_in_python(self._fallthrough_keys(), *args, **kwargs) + # + # Note: + # 1. We only register the torchbind op temporarily as effectful op because we only want + # the effect token functionalization logic to be applied during tracing. Otherwise, the behavior + # of the eagerly executing the op might change after tracing. + # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might + # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction. + with self._register_as_effectful_op_temporarily(): + return self._dispatch_in_python( + self._fallthrough_keys(), *args, **kwargs + ) return self._op(*args, **kwargs) def _dispatch_in_python( From 320de0c6b0a3e7c6d2693ea5c28d5d0156ba7991 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 3 Dec 2025 13:53:24 -0800 Subject: [PATCH 216/338] Add public documentation for stable_topological_sort (#169498) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169498 Approved by: https://github.com/albanD ghstack dependencies: #167397 --- docs/source/conf.py | 1 - docs/source/fx.md | 3 +++ test/allowlist_for_publicAPI.json | 2 +- torch/fx/passes/utils/fuser_utils.py | 4 ++-- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 5c404f8c129fc..7a3663ca062df 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -950,7 +950,6 @@ "get_node_target", "is_node_output_tensor", "legalize_graph", - "stable_topological_sort", # torch.fx.passes.utils.common "compare_graphs", "lift_subgraph_as_module", diff --git a/docs/source/fx.md b/docs/source/fx.md index b8447b378d3f9..4ce1c9d01f06a 100644 --- a/docs/source/fx.md +++ b/docs/source/fx.md @@ -1096,6 +1096,9 @@ The set of leaf modules can be customized by overriding ```{eval-rst} .. autofunction:: torch.fx.traceback.annotate ``` +```{eval-rst} +.. autofunction:: torch.fx.passes.tools_common.stable_topological_sort +`` diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index b6c203aea4ab6..d01d41d37997e 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2090,7 +2090,7 @@ "SimpleQueue", "Tuple", "compatibility", - "stable_topological_sort", + "legalize_graph", "lift_subgraph_as_module" ], "torch.fx.tensor_type": [ diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index e5509187b39dd..ea264e9fb2641 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -7,7 +7,7 @@ from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx.passes.tools_common import NodeList, NodeSet, stable_topological_sort +from torch.fx.passes.tools_common import NodeList, NodeSet from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] @@ -283,7 +283,7 @@ def fuse_by_partitions( erase_nodes(gm, sorted_nodes) - stable_topological_sort(gm) + torch.fx.passes.tools_common.stable_topological_sort(gm) gm.graph.lint() return gm From 6c8b6a043f1628188b6396b3a2a6e000ca68362b Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Thu, 4 Dec 2025 00:45:04 +0000 Subject: [PATCH 217/338] localtensor pointwise ops test (#169500) add localtensor test support for pointwise ops test. uncovered some weird enable/disable bugs but should be good now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169500 Approved by: https://github.com/dzmitry-huba --- test/distributed/tensor/test_pointwise_ops.py | 17 +++++ torch/distributed/_local_tensor/__init__.py | 55 ++++++++++---- torch/random.py | 4 +- .../distributed/_tensor/common_dtensor.py | 71 ++++++++++++++++++- 4 files changed, 128 insertions(+), 19 deletions(-) diff --git a/test/distributed/tensor/test_pointwise_ops.py b/test/distributed/tensor/test_pointwise_ops.py index 9d35e10f24ba8..54f8715b25671 100644 --- a/test/distributed/tensor/test_pointwise_ops.py +++ b/test/distributed/tensor/test_pointwise_ops.py @@ -20,8 +20,11 @@ from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorOpTestBase, + LocalDTensorOpTestBase, skip_unless_torch_gpu, + with_comms, ) @@ -141,6 +144,7 @@ def _run_sharded_elementwise_ops( kwargs=kwargs, ) + @with_comms def test_partial_add(self): device_mesh = self.build_device_mesh() d_1 = DTensor.from_local(torch.rand(2, 2), device_mesh, [Partial()]) @@ -148,6 +152,7 @@ def test_partial_add(self): d_3 = d_1 + d_2 self.assertTrue(d_3._spec.placements[0].is_partial()) + @with_comms def test_partial_replicate_add(self): device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() @@ -172,6 +177,7 @@ def test_partial_replicate_add(self): self.assertEqual(d_3.placements, (Partial(reduce_op=reduce_op),)) self.assertEqual(d_3.full_tensor(), d_1.full_tensor() + d_2.full_tensor()) + @with_comms def test_activations(self): device_mesh = self.build_device_mesh() self._run_sharded_elementwise_ops( @@ -211,6 +217,7 @@ def test_activations(self): op=torch.sigmoid, ) + @with_comms @skip( "testing RNG based ops is broken: https://github.com/pytorch/PiPPy/issues/494" ) @@ -239,6 +246,7 @@ def _reset_random_seed(): training=True, ) + @with_comms @skip_unless_torch_gpu def test_dropout_backward(self): device_mesh = self.build_device_mesh() @@ -271,6 +279,7 @@ def test_dropout_backward(self): ), ) + @with_comms @skip_unless_torch_gpu def test_dropout_errors(self): device_mesh = self.build_device_mesh() @@ -282,6 +291,7 @@ def test_dropout_errors(self): op=torch.nn.functional.dropout, ) + @with_comms def test_mul_out(self): device_mesh = self.build_device_mesh() torch.manual_seed(self.rank) @@ -300,6 +310,7 @@ def test_mul_out(self): self.assertEqual(input_tensor, dtensor.to_local()) self.assertEqual(expected, dt.to_local()) + @with_comms def test_mul_partial(self): # we only test the partial behavior for mul op as other placement # behaviors should be well tested in test_dtensor_ops.py @@ -356,6 +367,7 @@ def test_mul_partial(self): self.assertEqual(z.placements, (Replicate(),)) self.assertEqual(z.to_local(), input) + @with_comms def test_inplace_op_partial_to_replicate(self): # test that in-place operations that require redistribution raise an error # to preserve aliasing semantics (issue #163374) @@ -376,5 +388,10 @@ def test_inplace_op_partial_to_replicate(self): partial_dt.clamp_(max=10) +DistElementwiseOpsTestWithLocalTensor = create_local_tensor_test_class( + DistElementwiseOpsTest, base_class=LocalDTensorOpTestBase +) + + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index 4c8f12c11687b..c780e1ef7cb8a 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -355,7 +355,8 @@ def _for_each_rank_run_func( for r in sorted(ranks): if use_per_rank_rng: assert lm is not None - _set_rng_state(*lm._per_rank_rng_states[r]) + if r in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[r]) else: assert global_rng_state is not None _set_rng_state(*global_rng_state) @@ -1164,7 +1165,7 @@ def __init__(self, ranks: Union[int, frozenset[int]]): else: assert isinstance(ranks, frozenset) self.ranks = ranks - self._disable = False + self._disable = True self._old_get_coordinate = None self._old_torch_manual_seed: Any = None self._old_torch_initial_seed: Any = None @@ -1172,10 +1173,10 @@ def __init__(self, ranks: Union[int, frozenset[int]]): int, tuple[torch.Tensor, dict[int, torch.Tensor]] ] = {} + self.enable_() + def __enter__(self) -> "LocalTensorMode": - self._disable = False - self._patch_device_mesh() - self._patch_random_functions() + self.enable_() get_local_tensor_mode_list().append(self) # _distribute_region will compute correct per-shard offsets @@ -1196,9 +1197,7 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - self._disable = True - self._unpatch_device_mesh() - self._unpatch_random_functions() + self.disable_() get_local_tensor_mode_list().pop() super().__exit__(exc_type, exc_val, exc_tb) @@ -1314,6 +1313,22 @@ def __torch_dispatch__( return _for_each_rank_run_func(func, self.ranks, args, kwargs, alias=True) + def disable_(self): + if self._disable: + return + + self._unpatch_device_mesh() + self._unpatch_random_functions() + self._disable = True + + def enable_(self): + if not self._disable: + return + + self._patch_device_mesh() + self._patch_random_functions() + self._disable = False + @contextlib.contextmanager def disable(self) -> Generator[None, None, None]: """ @@ -1321,14 +1336,21 @@ def disable(self) -> Generator[None, None, None]: rank specific computations and merge results back before enabling LocalTensorMode back. """ - old = self._disable - self._disable = True - self._unpatch_device_mesh() + # don't unpatch again if already disabled + if self._disable: + try: + yield + finally: + # re-disable if the yield messed + # with the state + self.disable_() + return # noqa: B012 + + self.disable_() try: yield finally: - self._disable = old - self._patch_device_mesh() + self.enable_() def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor: """ @@ -1417,12 +1439,12 @@ def torch_manual_seed(seed) -> torch._C.Generator: for rank in sorted(lm.ranks): rank_seed = seed.node._local_ints[rank] - _manual_seed_impl(rank_seed, update_local_tensor_states=False) + _manual_seed_impl(rank_seed) lm._per_rank_rng_states[rank] = _get_rng_state() return torch.random.default_generator from torch.random import _manual_seed_impl - result = _manual_seed_impl(seed, update_local_tensor_states=False) + result = _manual_seed_impl(seed) if lm is not None and len(lm._per_rank_rng_states) > 0: cpu_state, cuda_states = _get_rng_state() @@ -1452,6 +1474,9 @@ def torch_initial_seed(): return torch.random.default_generator.initial_seed() +# Save the original get_coordinate method before any patching + + class _LocalDeviceMesh: """ Holds implementations of DeviceMesh functionality that must be patched while running diff --git a/torch/random.py b/torch/random.py index f86d7349019dc..e36f635c0df13 100644 --- a/torch/random.py +++ b/torch/random.py @@ -39,10 +39,10 @@ def manual_seed(seed) -> torch._C.Generator: is raised. Negative inputs are remapped to positive values with the formula `0xffff_ffff_ffff_ffff + seed`. """ - return _manual_seed_impl(seed, update_local_tensor_states=True) + return _manual_seed_impl(seed) -def _manual_seed_impl(seed, update_local_tensor_states) -> torch._C.Generator: +def _manual_seed_impl(seed) -> torch._C.Generator: seed = int(seed) import torch.cuda diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 54bc65bc93365..2c749ca2d5416 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -524,6 +524,12 @@ def wrapper( *args: tuple[object], **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: + # just passthrough if harness doesn't + # support init_pg e.g., DTensorOpTestBase + if not hasattr(self, "init_pg"): + func(self, *args, **kwargs) + return + self.init_pg(eager_init, backend) try: @@ -712,6 +718,65 @@ def to_dist_tensor( raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}") +class LocalDTensorOpTestBase(DTensorOpTestBase): + @property + def is_local_tensor_enabled(self) -> bool: + return True + + def _handle_test_skip(self, msg: str) -> None: + self.skipTest(msg) + + def _get_local_tensor_mode(self): + return LocalTensorMode(frozenset(range(self.world_size))) + + def setUp(self) -> None: + super().setUp() + torch.autograd._enable_record_function(False) + + def tearDown(self) -> None: + from torch.distributed.tensor import _random as random + + random._rng_tracker = None + super().tearDown() + torch.autograd._enable_record_function(True) + + @property + def rank(self): + return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)})) + + @rank.setter + def rank(self, rank): + pass + + def join_or_run(self, fn): + @wraps(fn) + def wrapper(self): + fn() + + return types.MethodType(wrapper, self) + + def build_device_mesh(self) -> DeviceMesh: + with maybe_disable_local_tensor_mode(): + return super().build_device_mesh() + + def init_pg(self, eager_init, backend: Optional[str] = None) -> None: + dist.init_process_group("fake", rank=0, world_size=self.world_size) + self._pg = dist.distributed_c10d._get_default_group() + + def destroy_pg(self, device_id: Optional[int] = None) -> None: + dist.destroy_process_group(self._pg) + self._pg = None + + def _spawn_processes(self) -> None: + pass + + def run_test(self, test_name: str, parent_pipe) -> None: + getattr(self, test_name)() + + def init_manual_seed_for_rank(self) -> None: + torch.manual_seed(0) + + class LocalDTensorTestBase(DTensorTestBase): @property def is_local_tensor_enabled(self) -> bool: @@ -790,7 +855,9 @@ def wrapped(self): return wrapped -def create_local_tensor_test_class(orig_cls, skipped_tests=None): +def create_local_tensor_test_class( + orig_cls, skipped_tests=None, base_class=LocalDTensorTestBase +): if skipped_tests is None: skipped_tests = [] @@ -809,7 +876,7 @@ def create_local_tensor_test_class(orig_cls, skipped_tests=None): cls = type( orig_cls.__name__ + "WithLocalTensor", - (LocalDTensorTestBase,) + orig_cls.__bases__, + (base_class,) + orig_cls.__bases__, dct, ) cls.__file__ = __file__ From 2df6058f116a65722a0e03073402feb242572d35 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 09:12:03 -0800 Subject: [PATCH 218/338] [opaque_obj] Remove free registration (#167739) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167739 Approved by: https://github.com/zou3519 --- test/test_opaque_obj_v2.py | 83 ++++++++++++++----- torch/_library/opaque_object.py | 38 ++++++--- .../csrc/jit/frontend/schema_type_parser.cpp | 8 ++ 3 files changed, 97 insertions(+), 32 deletions(-) diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py index 7dcddfb0f3906..99ff9058eda52 100644 --- a/test/test_opaque_obj_v2.py +++ b/test/test_opaque_obj_v2.py @@ -13,10 +13,11 @@ ) from torch._library.effects import EffectType from torch._library.fake_class_registry import FakeScriptObject -from torch._library.opaque_object import register_opaque_type +from torch._library.opaque_object import get_opaque_type_name, register_opaque_type from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.fx.graph import _illegal_char_regex from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -63,9 +64,15 @@ def increment_counter(self): self.counter += 1 -register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue") -register_opaque_type(RNGState, "_TestOpaqueObject_RNGState") -register_opaque_type(Counter, "_TestOpaqueObject_Counter") +class AddModule(torch.nn.Module): + def forward(self, x, y): + return x * y + + +register_opaque_type(OpaqueQueue) +register_opaque_type(RNGState) +register_opaque_type(Counter) +register_opaque_type(AddModule) class TestOpaqueObject(TestCase): @@ -74,7 +81,7 @@ def setUp(self): torch.library.define( "_TestOpaqueObject::queue_push", - "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()", + f"({get_opaque_type_name(OpaqueQueue)} a, Tensor b) -> ()", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @@ -91,7 +98,7 @@ def push_impl_fake(q: OpaqueQueue, b: torch.Tensor) -> None: pass self.lib.define( - "queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor", + f"queue_pop({get_opaque_type_name(OpaqueQueue)} a) -> Tensor", ) def pop_impl(queue: OpaqueQueue) -> torch.Tensor: @@ -126,7 +133,7 @@ def size_impl_fake(q: OpaqueQueue) -> int: torch.library.define( "_TestOpaqueObject::noisy_inject", - "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor", + f"(Tensor x, {get_opaque_type_name(RNGState)} obj) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @@ -227,7 +234,7 @@ def forward(self, arg0_1, arg1_1): def test_bad_fake(self, make_fx_tracing_mode): torch.library.define( "_TestOpaqueObject::bad_fake", - "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor", + f"(Tensor x, {get_opaque_type_name(RNGState)} obj) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @@ -326,7 +333,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): "_TestOpaqueObject::noisy_inject", None ) - def test_compile(self): + def test_compile1(self): def foo(rng_state, x): x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state) x = x * x @@ -342,10 +349,14 @@ def foo(rng_state, x): backend = AotEagerAndRecordGraphs() torch.compile(foo, fullgraph=True, backend=backend)(rng, x) + + # This is done in torch.fx's graph in _namespace.create_name() where it + # sanitizes the name + fx_class = _illegal_char_regex.sub("_", get_opaque_type_name(RNGState)) self.assertExpectedInline( backend.graphs[0].code.strip(), - """\ -def forward(self, L_x_ : torch.Tensor, L_rng_state_ : __main___RNGState): + f"""\ +def forward(self, L_x_ : torch.Tensor, L_rng_state_ : {fx_class}): l_x_ = L_x_ l_rng_state_ = L_rng_state_ x = torch.ops._TestOpaqueObject.noisy_inject(l_x_, l_rng_state_); l_x_ = None @@ -430,15 +441,9 @@ def bar(counter, x): torch.compile(bar)(counter, torch.ones(2, 3)) def test_export_joint(self): - class Moo(torch.nn.Module): - def forward(self, x, y): - return x * y - - register_opaque_type(Moo, "_TestOpaqueObject_Moo") - torch.library.define( "_TestOpaqueObject::module_mul", - "(_TestOpaqueObject_Moo a, Tensor b, SymInt c) -> Tensor", + f"({get_opaque_type_name(AddModule)} a, Tensor b, SymInt c) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @@ -446,12 +451,12 @@ def forward(self, x, y): @torch.library.impl( "_TestOpaqueObject::module_mul", "CompositeExplicitAutograd", lib=self.lib ) - def module_mul_impl(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor: - assert isinstance(m, Moo) + def module_mul_impl(m: AddModule, a: torch.Tensor, b: int) -> torch.Tensor: + assert isinstance(m, AddModule) return m(a, b) @torch.library.register_fake("_TestOpaqueObject::module_mul", lib=self.lib) - def module_mul_fake(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor: + def module_mul_fake(m: AddModule, a: torch.Tensor, b: int) -> torch.Tensor: return torch.empty_like(a) def module_mul_setup_context(ctx, inputs, output): @@ -471,7 +476,7 @@ def module_mul_backward(ctx, grad) -> torch.Tensor: class M(torch.nn.Module): def __init__(self): super().__init__() - self.moo = Moo() + self.moo = AddModule() def forward(self, x, y): b = y.item() @@ -496,6 +501,40 @@ def forward(self, primals, tangents): self.assertEqual(compiled_fn(*inp), M()(*inp)) + def test_invalid_schema(self): + with self.assertRaisesRegex( + RuntimeError, + "unknown type specifier", + ): + torch.library.define( + "_TestOpaqueObject::invalid_op1", + "(foo.bar.baz a) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=self.lib, + ) + + with self.assertRaisesRegex( + RuntimeError, + r"expected \) but found 'dots' here", + ): + torch.library.define( + "_TestOpaqueObject::invalid_op2", + "(......... a) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=self.lib, + ) + + with self.assertRaisesRegex( + RuntimeError, + "unknown type specifier", + ): + torch.library.define( + "_TestOpaqueObject::invalid_op5", + "(MyNamespace..MyClass a) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=self.lib, + ) + instantiate_parametrized_tests(TestOpaqueObject) diff --git a/torch/_library/opaque_object.py b/torch/_library/opaque_object.py index ce9b9cfe38a57..6ceebbf7ef1d6 100644 --- a/torch/_library/opaque_object.py +++ b/torch/_library/opaque_object.py @@ -1,4 +1,4 @@ -from typing import Any, NewType, Optional +from typing import Any, NewType import torch @@ -155,23 +155,41 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None: _OPAQUE_TYPES: dict[Any, str] = {} -def register_opaque_type(cls: Any, name: Optional[str] = None) -> None: +def get_opaque_type_name(cls: Any) -> str: + """ + Gets the registered opaque type name for a given class. + + Args: + cls (type): The class to get the type name for. + + Returns: + str: The registered type name for the class. + + Raises: + ValueError: If the class is not registered as an opaque type. + """ + if cls not in _OPAQUE_TYPES: + raise ValueError( + f"Class {cls} is not registered as an opaque type. " + f"Call register_opaque_type({cls.__name__}) first." + ) + return _OPAQUE_TYPES[cls] + + +def register_opaque_type(cls: Any) -> None: """ Registers the given type as an opaque type which allows this to be consumed by a custom operator. + The type name will be automatically generated from the class's fully + qualified name (ex. my_module.MyClass). + Args: cls (type): The class to register as an opaque type. - name (str): A unique qualified name of the type. """ - if name is None: - name = cls.__name__ + # Generate a fully qualified name by combining module and qualname + name = f"{cls.__module__}.{cls.__qualname__}" - if "." in name: - # The schema_type_parser will break up types with periods - raise ValueError( - f"Unable to accept name, {name}, for this opaque type as it contains a '.'" - ) _OPAQUE_TYPES[cls] = name torch._C._register_opaque_type(name) diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 735856dc10a7c..ec3d74c398779 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -101,6 +101,14 @@ TypePtr SchemaTypeParser::parseBaseType() { } std::string text = tok.text(); + // Check if this might be a dotted identifier (for opaque types) + // Keep consuming '.' + IDENT sequences to build fully qualified names + while (L.cur().kind == '.' && L.lookahead().kind == TK_IDENT) { + L.next(); // consume '.' + auto ident_tok = L.expect(TK_IDENT); + text += "." + ident_tok.text(); + } + // Check if this type is registered as an opaque type first if (isRegisteredOpaqueType(text)) { return c10::PyObjectType::get(); From 7cbc2d034cecd21ab5c9707d0a9c525c17143fb8 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 09:12:04 -0800 Subject: [PATCH 219/338] [opaque_obj] Remove inital opaque obj (#167740) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167740 Approved by: https://github.com/zou3519 ghstack dependencies: #167739 --- test/test_custom_ops.py | 3 - test/test_opaque_obj.py | 268 -------------------------- torch/_C/__init__.pyi.in | 3 - torch/_library/fake_class_registry.py | 2 +- torch/_library/infer_schema.py | 3 +- torch/_library/opaque_object.py | 132 +------------ torch/csrc/jit/python/init.cpp | 29 --- 7 files changed, 3 insertions(+), 437 deletions(-) delete mode 100644 test/test_opaque_obj.py diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index bcc9c377e5049..5098f05744ad2 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -34,7 +34,6 @@ TensorMetadata, ) from torch._library.infer_schema import tuple_to_list -from torch._library.opaque_object import make_opaque, OpaqueType from torch._utils_internal import get_file_path_2 # @manual from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv @@ -903,8 +902,6 @@ def _generate_examples(self, typ): return [torch.tensor(3)] if typ == Optional[torch.types.Number]: return [None, 2.718] - if typ == OpaqueType: - return [make_opaque("moo")] origin = typing.get_origin(typ) if origin is Union: args = typing.get_args(typ) diff --git a/test/test_opaque_obj.py b/test/test_opaque_obj.py deleted file mode 100644 index 2c47ffc5b59b6..0000000000000 --- a/test/test_opaque_obj.py +++ /dev/null @@ -1,268 +0,0 @@ -# Owner(s): ["module: custom-operators"] -import copy - -import torch -from torch._dynamo.test_case import run_tests, TestCase -from torch._library.fake_class_registry import maybe_to_fake_obj -from torch._library.opaque_object import ( - get_payload, - make_opaque, - OpaqueType, - set_payload, -) -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import make_fx -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) - - -class OpaqueQueue: - def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None: - super().__init__() - self.queue = queue - self.init_tensor_ = init_tensor_ - - # For testing purposes - self._push_counter = 0 - self._pop_counter = 0 - self._size_counter = 0 - - def push(self, tensor: torch.Tensor) -> None: - self._push_counter += 1 - self.queue.append(tensor) - - def pop(self) -> torch.Tensor: - self._pop_counter += 1 - if len(self.queue) > 0: - return self.queue.pop(0) - return self.init_tensor_ - - def size(self) -> int: - self._size_counter += 1 - return len(self.queue) - - def __eq__(self, other): - if len(self.queue) != len(other.queue): - return False - for q1, q2 in zip(self.queue, other.queue): - if not torch.allclose(q1, q2): - return False - return torch.allclose(self.init_tensor_, other.init_tensor_) - - -class TestOpaqueObject(TestCase): - def setUp(self): - self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901 - - torch.library.define( - "_TestOpaqueObject::queue_push", - "(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=self.lib, - ) - - @torch.library.impl( - "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib - ) - def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None: - queue = get_payload(q) - assert isinstance(queue, OpaqueQueue) - queue.push(b) - - @torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib) - def push_impl_fake(q: torch._C.ScriptObject, b: torch.Tensor) -> None: - pass - - self.lib.define( - "queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor", - ) - - def pop_impl(q: torch._C.ScriptObject) -> torch.Tensor: - queue = get_payload(q) - assert isinstance(queue, OpaqueQueue) - return queue.pop() - - self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd") - - def pop_impl_fake(q: torch._C.ScriptObject) -> torch.Tensor: - # This is not accurate since the queue could have tensors that are - # not rank 1 - ctx = torch._custom_op.impl.get_ctx() - u0 = ctx.new_dynamic_size() - return torch.empty(u0) - - self.lib._register_fake("queue_pop", pop_impl_fake) - - @torch.library.custom_op( - "_TestOpaqueObject::queue_size", - mutates_args=[], - ) - def size_impl(q: OpaqueType) -> int: - queue = get_payload(q) - assert isinstance(queue, OpaqueQueue) - return queue.size() - - @size_impl.register_fake - def size_impl_fake(q: torch._C.ScriptObject) -> int: - ctx = torch._custom_op.impl.get_ctx() - u0 = ctx.new_dynamic_size() - return u0 - - super().setUp() - - def tearDown(self): - self.lib._destroy() - - super().tearDown() - - def test_creation(self): - queue = OpaqueQueue([], torch.zeros(3)) - obj = make_opaque(queue) - self.assertTrue(isinstance(obj, torch._C.ScriptObject)) - self.assertEqual(str(obj._type()), "__torch__.torch.classes.aten.OpaqueObject") - - # obj.payload stores a direct reference to this python queue object - payload = get_payload(obj) - self.assertEqual(payload, queue) - queue.push(torch.ones(3)) - self.assertEqual(payload.size(), 1) - - def test_ops(self): - queue = OpaqueQueue([], torch.zeros(3)) - obj = make_opaque() - set_payload(obj, queue) - - torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3) + 1) - self.assertEqual(queue.size(), 1) - size = torch.ops._TestOpaqueObject.queue_size(obj) - self.assertEqual(size, queue.size()) - popped = torch.ops._TestOpaqueObject.queue_pop(obj) - self.assertEqual(popped, torch.ones(3) + 1) - self.assertEqual(queue.size(), 0) - - def test_eq(self): - self.assertTrue(make_opaque("moo") == make_opaque("moo")) - self.assertFalse(make_opaque("moo") == make_opaque("mop")) - - q1 = OpaqueQueue([torch.ones(3)], torch.zeros(3)) - q2 = OpaqueQueue([torch.ones(3)], torch.zeros(3)) - obj1 = make_opaque(q1) - obj2 = make_opaque(q2) - self.assertTrue(obj1 == obj1) - self.assertTrue(q1 == q2) - self.assertTrue(obj1 == obj2) - - def test_deepcopy(self): - q1 = OpaqueQueue([torch.ones(3), torch.ones(3) * 2], torch.zeros(3)) - obj1 = make_opaque(q1) - - obj2 = copy.deepcopy(obj1) - q2 = get_payload(obj2) - - self.assertTrue(q1 is not q2) - self.assertTrue(q1 == q2) - - def test_bad_fake(self): - torch.library.define( - "_TestOpaqueObject::bad_fake", - "(__torch__.torch.classes.aten.OpaqueObject q, Tensor x) -> Tensor", - lib=self.lib, - ) - - def f(q, x): - torch.ops._TestOpaqueObject.bad_fake(q, x) - return x.cos() - - def bad_fake1(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor: - payload = get_payload(q) - return b * payload - - torch.library.register_fake( - "_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib - ) - - with FakeTensorMode() as fake_mode: - obj = make_opaque(1) - fake_obj = maybe_to_fake_obj(fake_mode, obj) - x = torch.ones(3) - - with self.assertRaisesRegex( - ValueError, - "get_payload: this function was called with a FakeScriptObject", - ): - torch.ops._TestOpaqueObject.bad_fake(fake_obj, x) - - def bad_fake2(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor: - set_payload(q, 2) - return torch.empty_like(b) - - torch.library.register_fake( - "_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True - ) - - with FakeTensorMode() as fake_mode: - obj = make_opaque(1) - fake_obj = maybe_to_fake_obj(fake_mode, obj) - x = torch.ones(3) - - with self.assertRaisesRegex( - ValueError, - "set_payload: this function was called with a FakeScriptObject", - ): - torch.ops._TestOpaqueObject.bad_fake(fake_obj, x) - - @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) - def test_make_fx(self, make_fx_tracing_mode): - class M(torch.nn.Module): - def forward(self, queue, x): - torch.ops._TestOpaqueObject.queue_push(queue, x.tan()) - torch.ops._TestOpaqueObject.queue_push(queue, x.cos()) - torch.ops._TestOpaqueObject.queue_push(queue, x.sin()) - pop1 = torch.ops._TestOpaqueObject.queue_pop(queue) - size1 = torch.ops._TestOpaqueObject.queue_size(queue) - pop2 = torch.ops._TestOpaqueObject.queue_pop(queue) - size2 = torch.ops._TestOpaqueObject.queue_size(queue) - x_cos = pop1 + size1 - x_sin = pop2 - size2 - return x_sin + x_cos - - q1 = OpaqueQueue([], torch.empty(0).fill_(-1)) - obj1 = make_opaque(q1) - q2 = OpaqueQueue([], torch.empty(0).fill_(-1)) - obj2 = make_opaque(q2) - - x = torch.ones(2, 3) - gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(obj1, x) - self.assertTrue(torch.allclose(gm(obj1, x), M()(obj2, x))) - self.assertEqual(q1._push_counter, 3) - self.assertEqual(q1._pop_counter, 2) - self.assertEqual(q1._size_counter, 2) - self.assertEqual(q1.size(), 1) - self.assertExpectedInline( - gm.code.strip("\n"), - """\ -def forward(self, arg0_1, arg1_1): - tan = torch.ops.aten.tan.default(arg1_1) - queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None - cos = torch.ops.aten.cos.default(arg1_1) - queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None - sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None - queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None - queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1) - queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1) - queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1) - queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None - add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None - sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None - add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None - return add_1 - """, - ) - - -instantiate_parametrized_tests(TestOpaqueObject) - -if __name__ == "__main__": - run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 520d07d487270..9dc460d9522fa 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1633,9 +1633,6 @@ def _jit_pass_cse(Graph) -> _bool: ... def _jit_pass_dce(Graph) -> None: ... def _jit_pass_dce_graph(Graph) -> None: ... def _jit_pass_lint(Graph) -> None: ... -def _make_opaque_object(payload: Any) -> ScriptObject: ... -def _get_opaque_object_payload(obj: ScriptObject) -> Any: ... -def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ... def _register_opaque_type(type_name: str) -> None: ... def _is_opaque_type_registered(type_name: str) -> _bool: ... diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index 474df5116e460..c1bd96ec0589d 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -167,7 +167,7 @@ def maybe_to_fake_obj( OpaqueTypeStr, ) - if x is None or is_opaque_type(type(x)) or str(x._type()) == OpaqueTypeStr: + if x is None or is_opaque_type(type(x)): # In order to make OpaqueObjects truly opaque, the fake kernel should # not depend on the contents of the OpaqueObject at all. fake_x_wrapped = FakeScriptObject(FakeOpaqueObject(), OpaqueTypeStr, None) diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 8c10a23dab881..81189595297b1 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -9,7 +9,7 @@ from torch import device, dtype, Tensor, types from torch.utils._exposed_in import exposed_in -from .opaque_object import _OPAQUE_TYPES, is_opaque_type, OpaqueType, OpaqueTypeStr +from .opaque_object import _OPAQUE_TYPES, is_opaque_type # This is used as a negative test for @@ -263,7 +263,6 @@ def get_supported_param_types(): (types.Number, "Scalar", True, False, False), (dtype, "ScalarType", False, False, False), (device, "Device", False, False, False), - (OpaqueType, OpaqueTypeStr, False, False, False), ] result = [] for line in data: diff --git a/torch/_library/opaque_object.py b/torch/_library/opaque_object.py index 6ceebbf7ef1d6..775ee23ba0b88 100644 --- a/torch/_library/opaque_object.py +++ b/torch/_library/opaque_object.py @@ -2,7 +2,7 @@ import torch -from .fake_class_registry import FakeScriptObject, register_fake_class +from .fake_class_registry import register_fake_class @register_fake_class("aten::OpaqueObject") @@ -22,136 +22,6 @@ def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None: OpaqueType = NewType("OpaqueType", torch._C.ScriptObject) - -def make_opaque(payload: Any = None) -> torch._C.ScriptObject: - """ - Creates an opaque object which stores the given Python object. - This opaque object can be passed to any custom operator as an argument. - The Python object can then be accessed from the opaque object using the `get_payload()` API. - The opaque object has `._type()` - "__torch__.torch.classes.aten.OpaqueObject", which should be the type used - when creating custom operator schemas. - - Args: - payload (Any): The Python object to store in the opaque object. This can - be empty, and can be set with `set_payload()` later. - - Returns: - torch._C.ScriptObject: The opaque object that stores the given Python object. - - Example: - - >>> import random - >>> import torch - >>> from torch._library.opaque_object import ( - ... make_opaque, - ... get_payload, - ... set_payload, - ... ) - >>> - >>> class RNGState: - >>> def __init__(self, seed): - >>> self.rng = random.Random(seed) - >>> - >>> rng = RNGState(0) - >>> obj = make_opaque() - >>> set_payload(obj, rng) - >>> - >>> assert get_payload(obj) == rng - >>> - >>> lib = torch.library.Library("mylib", "FRAGMENT") - >>> - >>> torch.library.define( - >>> "mylib::noisy_inject", - >>> "(Tensor x, __torch__.torch.classes.aten.OpaqueObject obj) -> Tensor", - >>> tags=torch.Tag.pt2_compliant_tag, - >>> lib=lib, - >>> ) - >>> - >>> @torch.library.impl( - >>> "mylib::noisy_inject", "CompositeExplicitAutograd", lib=lib - >>> ) - >>> def noisy_inject(x: torch.Tensor, obj: torch._C.ScriptObject) -> torch.Tensor: - >>> rng_state = get_payload(obj) - >>> assert isinstance(rng_state, RNGState) - >>> out = x.clone() - >>> for i in range(out.numel()): - >>> out.view(-1)[i] += rng_state.rng.random() - >>> return out - >>> - >>> print(torch.ops.mylib.noisy_inject(torch.ones(3), obj)) - """ - return torch._C._make_opaque_object(payload) - - -def get_payload(opaque_object: torch._C.ScriptObject) -> Any: - """ - Retrieves the Python object stored in the given opaque object. - - Args: - torch._C.ScriptObject: The opaque object that stores the given Python object. - - Returns: - payload (Any): The Python object stored in the opaque object. This can - be set with `set_payload()`. - """ - if isinstance(opaque_object, FakeScriptObject): - raise ValueError( - "get_payload: this function was called with a FakeScriptObject " - "implying that you are calling get_payload inside of a fake kernel." - "The fake kernel should not depend on the contents of the " - "OpaqueObject at all, so we're erroring out. If you need this" - "functionality, consider creating a custom TorchBind Object instead" - "(but note that this is more difficult)." - ) - if not ( - isinstance(opaque_object, torch._C.ScriptObject) - and opaque_object._type().qualified_name() == OpaqueTypeStr - ): - type_ = ( - opaque_object._type().qualified_name() - if isinstance(opaque_object, torch._C.ScriptObject) - else type(opaque_object) - ) - raise ValueError( - f"Tried to get the payload from a non-OpaqueObject of type `{type_}`" - ) - return torch._C._get_opaque_object_payload(opaque_object) - - -def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None: - """ - Sets the Python object stored in the given opaque object. - - Args: - torch._C.ScriptObject: The opaque object that stores the given Python object. - payload (Any): The Python object to store in the opaque object. - """ - if isinstance(opaque_object, FakeScriptObject): - raise ValueError( - "set_payload: this function was called with a FakeScriptObject " - "implying that you are calling get_payload inside of a fake kernel." - "The fake kernel should not depend on the contents of the " - "OpaqueObject at all, so we're erroring out. If you need this" - "functionality, consider creating a custom TorchBind Object instead" - "(but note that this is more difficult)." - ) - - if not ( - isinstance(opaque_object, torch._C.ScriptObject) - and opaque_object._type().qualified_name() == OpaqueTypeStr - ): - type_ = ( - opaque_object._type().qualified_name() - if isinstance(opaque_object, torch._C.ScriptObject) - else type(opaque_object) - ) - raise ValueError( - f"Tried to get the payload from a non-OpaqueObject of type `{type_}`" - ) - torch._C._set_opaque_object_payload(opaque_object, payload) - - _OPAQUE_TYPES: dict[Any, str] = {} diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a7f16a7dc5a04..82a11af3714b4 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1862,35 +1862,6 @@ void initJITBindings(PyObject* module) { &parseSchema, py::arg("schema"), py::arg("allow_typevars") = true); - m.def( - "_make_opaque_object", - [](py::object payload) { - auto obj = c10::make_intrusive(payload); - auto typePtr = - torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject"); - return torch::jit::toPyObject(c10::IValue(std::move(obj))); - }, - R"doc(Creates an opaque object which stores the given Python object.)doc"); - m.def( - "_get_opaque_object_payload", - [](py::object obj) { - auto typePtr = - torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject"); - auto ivalue = torch::jit::toIValue(std::move(obj), typePtr); - auto customObj = ivalue.toCustomClass(); - return customObj->getPayload(); - }, - R"doc(Returns the Python object stored on the given opaque object.)doc"); - m.def( - "_set_opaque_object_payload", - [](py::object obj, py::object payload) { - auto typePtr = - torch::getCustomClass("__torch__.torch.classes.aten.OpaqueObject"); - auto ivalue = torch::jit::toIValue(std::move(obj), typePtr); - auto customObj = ivalue.toCustomClass(); - customObj->setPayload(std::move(payload)); - }, - R"doc(Sets the payload of the given opaque object with the given Python object.)doc"); m.def( "_register_opaque_type", [](const std::string& type_name) { From d16447dacaf2420ea175f0c275c75da951f57d39 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 09:12:04 -0800 Subject: [PATCH 220/338] [opaque obj] Set type of FakeScriptObj to be OpaqueObj type (#167741) When we create the FakeScriptObject we provide the name of the type of the FakeScriptObject. In the case of actual ScriptObjects this would be the CustomClassHolder class name, and for opaque objects this would've previously been a dummy name OpaqueTypeStr. This PR changes the dummy name to be the actual opaque object type (addresses https://github.com/pytorch/pytorch/pull/163936#discussion_r2519860197). Pull Request resolved: https://github.com/pytorch/pytorch/pull/167741 Approved by: https://github.com/zou3519 ghstack dependencies: #167739, #167740 --- torch/_dynamo/variables/script_object.py | 6 ++++-- torch/_library/fake_class_registry.py | 4 +++- torch/_library/opaque_object.py | 3 +++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index af7bd985287d7..25568760b159c 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -25,7 +25,7 @@ import torch from torch._guards import Source -from torch._library.opaque_object import is_opaque_type, OpaqueTypeStr +from torch._library.opaque_object import is_opaque_type from torch.fx.proxy import Proxy from .. import graph_break_hints @@ -81,7 +81,9 @@ def as_proxy(self) -> Proxy: "Dynamo cannot safely trace script object due to graph break." ) def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: - if getattr(self.value, "script_class_name", "") == OpaqueTypeStr: + if hasattr(self.value, "script_class_name") and is_opaque_type( + self.value.script_class_name + ): unimplemented( gb_type="Attempted to access attributes/methods on an OpaqueObject", context=f"value={self.value}, attr={name}", diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index c1bd96ec0589d..57342a752a84b 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -163,6 +163,7 @@ def maybe_to_fake_obj( from torch._library.opaque_object import ( FakeOpaqueObject, + get_opaque_type_name, is_opaque_type, OpaqueTypeStr, ) @@ -170,7 +171,8 @@ def maybe_to_fake_obj( if x is None or is_opaque_type(type(x)): # In order to make OpaqueObjects truly opaque, the fake kernel should # not depend on the contents of the OpaqueObject at all. - fake_x_wrapped = FakeScriptObject(FakeOpaqueObject(), OpaqueTypeStr, None) + type_name = OpaqueTypeStr if x is None else get_opaque_type_name(type(x)) + fake_x_wrapped = FakeScriptObject(FakeOpaqueObject(), type_name, None) return fake_x_wrapped else: # x.__obj_flatten__() could be calling some tensor operations inside but we don't diff --git a/torch/_library/opaque_object.py b/torch/_library/opaque_object.py index 775ee23ba0b88..567e2c837db7a 100644 --- a/torch/_library/opaque_object.py +++ b/torch/_library/opaque_object.py @@ -69,6 +69,9 @@ def is_opaque_type(cls: Any) -> bool: """ Checks if the given type is an opaque type. """ + if isinstance(cls, str): + return torch._C._is_opaque_type_registered(cls) + if cls not in _OPAQUE_TYPES: return False From d78f52b199c547106d4cd9d2856dd0805c118bf1 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 09:12:05 -0800 Subject: [PATCH 221/338] [opaque obj] Improve error msg for intermediate opaques (#167742) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167742 Approved by: https://github.com/zou3519 ghstack dependencies: #167739, #167740, #167741 --- test/test_opaque_obj_v2.py | 20 +++++++++++++++++++- torch/_dynamo/graph_break_registry.json | 11 +++++++++++ torch/_dynamo/variables/torch.py | 23 +++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py index 99ff9058eda52..3015defd88349 100644 --- a/test/test_opaque_obj_v2.py +++ b/test/test_opaque_obj_v2.py @@ -6,6 +6,7 @@ import torch from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import AotEagerAndRecordGraphs +from torch._dynamo.utils import counters as dynamo_counters from torch._functorch.aot_autograd import ( aot_compile_joint_with_descriptors, aot_export_joint_with_descriptors, @@ -376,7 +377,7 @@ def forward(self, arg0_1, arg1_1): return (add,)""", # noqa: B950 ) - def test_compile_intermediate(self): + def test_compile_global(self): counter = Counter(0) def foo(x, y): @@ -417,6 +418,23 @@ def forward(self, arg0_1, arg1_1, arg2_1): return (add,)""", # noqa: B950 ) + def test_compile_create_intermediate(self): + dynamo_counters.clear() + + def foo(x, y): + counter = Counter(0) + z = torch.ops._TestOpaqueObject.increment_counter(counter, y) + x = x * z + return x + + inp = (torch.tensor(1), torch.tensor(0)) + torch.compile(foo)(*inp) + self.assertEqual(len(dynamo_counters["graph_break"]), 1) + self.assertTrue( + "Opaque object were created in the middle of the program and passed to a custom op." + in next(iter(dynamo_counters["graph_break"].keys())), + ) + def test_compile_attribute(self): counter = Counter(0) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 7cf8e52d0197d..a425fae65a377 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3711,5 +3711,16 @@ "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." ] } + ], + "GB0367": [ + { + "Gb_type": "Opaque object were created in the middle of the program and passed to a custom op.", + "Context": "Opaque object types: {intermediate_opaques}. Function: {self.value}", + "Explanation": "Opaque objects cannot be created inside the torch.compile region. They must be created before entering the compiled function.", + "Hints": [ + "Please create the opaque object before calling torch.compile ", + "and pass it in as an argument or as a global variable." + ] + } ] } diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 78d87a09713ab..e5f21ebb72961 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -41,6 +41,7 @@ import torch.fx import torch.nn from torch._guards import TracingContext +from torch._library.opaque_object import is_opaque_type from torch._logging import warning_once from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type @@ -86,6 +87,7 @@ TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) +from .user_defined import UserDefinedObjectVariable try: @@ -1488,6 +1490,27 @@ def call_function( ) return self.call_tensor_method(tx, args, kwargs) + intermediate_opaques = [ + type(x.value) + for x in args + if x.source is None + and isinstance(x, UserDefinedObjectVariable) + and is_opaque_type(type(x.value)) + ] + if len(intermediate_opaques) > 0: + unimplemented( + gb_type="Opaque object were created in the middle of the program and passed to a custom op.", + context=f"Opaque object types: {intermediate_opaques}. Function: {self.value}", + explanation=( + "Opaque objects cannot be created inside the torch.compile region. " + "They must be created before entering the compiled function." + ), + hints=[ + "Please create the opaque object before calling torch.compile " + "and pass it in as an argument or as a global variable." + ], + ) + special_handler = self._get_handlers().get(self.value) if special_handler: result = special_handler(self, tx, *args, **kwargs) From bea4912944defdbcb8b061800caab6cbbbd01df5 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Wed, 3 Dec 2025 14:19:40 -0800 Subject: [PATCH 222/338] [Dynamo][Guard]Add the user-friendly TYPE_MATCH for type (#169025) Fix #168160 after the opensource PR https://github.com/pytorch/pytorch/pull/168272 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169025 Approved by: https://github.com/anijain2305 --- test/dynamo/test_check_type_id.py | 123 ++++++++++++++++++++++++++++++ torch/_dynamo/guards.py | 10 ++- 2 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 test/dynamo/test_check_type_id.py diff --git a/test/dynamo/test_check_type_id.py b/test/dynamo/test_check_type_id.py new file mode 100644 index 0000000000000..4f63c140246ef --- /dev/null +++ b/test/dynamo/test_check_type_id.py @@ -0,0 +1,123 @@ +# Owner(s): ["module: dynamo"] +""" +Test for TYPE_MATCH guard and ___check_type_id function. + +This test demonstrates how the TYPE_MATCH guard works in PyTorch Dynamo. +When a function is compiled, Dynamo installs guards to ensure the compiled +code remains valid. TYPE_MATCH guards ensure that values maintain their +exact type (using type identity, not just type equality). +""" + +import re + +import torch +import torch._dynamo +import torch._dynamo.test_case +from torch._dynamo.eval_frame import _debug_get_cache_entry_list +from torch.testing._internal.common_utils import munge_exc + + +class TestCheckTypeId(torch._dynamo.test_case.TestCase): + @staticmethod + def _find_guard_lines(guard_manager_str: str, keyword: str) -> list[str]: + # Normalize and anonymize type IDs, then return lines containing the keyword + normalized = re.sub( + r"\d{7,}", "", munge_exc(guard_manager_str), flags=re.MULTILINE + ) + pattern = re.compile(rf"^.*{re.escape(keyword)}.*$", re.MULTILINE) + return pattern.findall(normalized) + + def test_type_match_with_different_values(self): + """ + Test that TYPE_MATCH guard correctly identifies type mismatches. + + This test compiles a function that uses a global variable and verifies: + 1. The compiled function works with values of the same type + 2. The function recompiles when the type changes + 3. The ___check_type_id/check_obj_id guard is present in the generated code + 4. The check_type_id should present the user-friendly code that specify the type + """ + + # Define a global variable that we'll guard on + class Config: + multiplier = 2 # int type + + def fn(x): + # This will trigger a TYPE_MATCH guard on Config.multiplier + return x * Config.multiplier + + # Compile the function + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + # First call - should compile and install guards + x = torch.randn(4) + result1 = opt_fn(x) + expected1 = x * 2 + self.assertTrue(torch.allclose(result1, expected1)) + + # Get the cache entry to inspect guards + cache_entries = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(cache_entries), 1) + + # Check that the guard string contains check_type_id + guard_str = str(cache_entries[0].guard_manager) + matches = self._find_guard_lines(guard_str, "ID_MATCH") + self.assertIn("___check_obj_id", matches[0]) + self.assertIn( + "type=.Config'>", + matches[0], + ) + self.assertEqual( + matches[0].split("#")[0], + "| | +- ID_MATCH: ___check_obj_id(L['Config'], ), type=.Config'> ", + ) + + def test_type_match_with_custom_classes(self): + """ + Test TYPE_MATCH guard with custom class instances. + + Demonstrates that the guard checks type identity, not structural equality. + """ + + class Point: + def __init__(self, x, y): + self.x = x + self.y = y + + class Point2D: + def __init__(self, x, y): + self.x = x + self.y = y + + point = Point(1, 2) + + def fn(tensor): + # Access point's attributes, triggering TYPE_MATCH guard on point + return tensor + point.x + point.y + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + # First call with Point instance + x = torch.ones(4) + result1 = opt_fn(x) + expected1 = x + 1 + 2 + self.assertTrue(torch.allclose(result1, expected1)) + + # Verify guard contains check_type_id + cache_entries = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(cache_entries), 1) + + guard_str = str(cache_entries[0].guard_manager) + matches = self._find_guard_lines(guard_str, "TYPE_MATCH") + self.assertEqual( + matches[0].split("#")[0], + "| | +- TYPE_MATCH: ___check_type_id(L['point'], ), type=.Point'> ", + ) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 1a5f235ad916b..322578dc6444f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1946,7 +1946,7 @@ def TYPE_MATCH(self, guard: Guard) -> None: obj_id = self.id_ref(t, f"type({guard.name})") type_repr = repr(t) - code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}) # {type_repr}" + code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}), type={type_repr}" self._set_guard_export_info(guard, [code]) self.get_guard_manager(guard).add_type_match_guard( @@ -2048,7 +2048,13 @@ def id_match_unchecked( ref = self.arg_ref(guard) val = self.get(guard.name) id_val = self.id_ref(val, guard.name) - code = f"___check_obj_id({ref}, {id_val})" + try: + type_repr = repr(val) + except Exception: + # During deepcopy reconstruction or other state transitions, + # objects may be in an incomplete state where repr() fails + type_repr = f"<{type(val).__name__}>" + code = f"___check_obj_id({ref}, {id_val}), type={type_repr}" self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH") self.get_guard_manager(guard).add_id_match_guard( id_val, get_verbose_code_parts(code, guard, recompile_hint) From b3a7edb2311367974cc7cd764cfb11a5d6758b24 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 01:36:32 +0000 Subject: [PATCH 223/338] Revert "Add public documentation for stable_topological_sort (#169498)" This reverts commit 320de0c6b0a3e7c6d2693ea5c28d5d0156ba7991. Reverted https://github.com/pytorch/pytorch/pull/169498 on behalf of https://github.com/huydhn due to The doc test failure is legit ([comment](https://github.com/pytorch/pytorch/pull/169498#issuecomment-3609569000)) --- docs/source/conf.py | 1 + docs/source/fx.md | 3 --- test/allowlist_for_publicAPI.json | 2 +- torch/fx/passes/utils/fuser_utils.py | 4 ++-- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a3663ca062df..5c404f8c129fc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -950,6 +950,7 @@ "get_node_target", "is_node_output_tensor", "legalize_graph", + "stable_topological_sort", # torch.fx.passes.utils.common "compare_graphs", "lift_subgraph_as_module", diff --git a/docs/source/fx.md b/docs/source/fx.md index 4ce1c9d01f06a..b8447b378d3f9 100644 --- a/docs/source/fx.md +++ b/docs/source/fx.md @@ -1096,9 +1096,6 @@ The set of leaf modules can be customized by overriding ```{eval-rst} .. autofunction:: torch.fx.traceback.annotate ``` -```{eval-rst} -.. autofunction:: torch.fx.passes.tools_common.stable_topological_sort -`` diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index d01d41d37997e..b6c203aea4ab6 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2090,7 +2090,7 @@ "SimpleQueue", "Tuple", "compatibility", - "legalize_graph", + "stable_topological_sort", "lift_subgraph_as_module" ], "torch.fx.tensor_type": [ diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index ea264e9fb2641..e5509187b39dd 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -7,7 +7,7 @@ from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx.passes.tools_common import NodeList, NodeSet +from torch.fx.passes.tools_common import NodeList, NodeSet, stable_topological_sort from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] @@ -283,7 +283,7 @@ def fuse_by_partitions( erase_nodes(gm, sorted_nodes) - torch.fx.passes.tools_common.stable_topological_sort(gm) + stable_topological_sort(gm) gm.graph.lint() return gm From b2b6b034c9fd08672c40e63ef243556ad4c49bd2 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Thu, 4 Dec 2025 01:38:20 +0000 Subject: [PATCH 224/338] [export] Make RNNs exportable on GPUs (#163245) Summary: Fixes https://github.com/pytorch/pytorch/issues/155309 `torch.export` fails to export an RNN model on GPU with `cudnn` enabled. This is because RNN module's `flatten_parameters()` method calls `p.data_ptr()` for aliasing detection, causing "Cannot access data pointer of Tensor" errors on GPU. This PR fixes this issue by disabling cudnn during export time. Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_lstm_gpu buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_gru_gpu buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_rnn_flatten_parameters Rollback Plan: Differential Revision: D82687470 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163245 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 83 ++++++++++++++++++++++++++++++- test/export/test_export_opinfo.py | 8 ++- torch/export/_trace.py | 2 + 3 files changed, 90 insertions(+), 3 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 3a996faf5ed99..788aa8518f94b 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -78,6 +78,7 @@ IS_WINDOWS, run_tests, skipIfCrossRef, + skipIfRocm, skipIfXpu, TEST_TRANSFORMERS, TEST_WITH_CROSSREF, @@ -8104,6 +8105,84 @@ def _patch_config(kwargs): ): _ = export(mod, inp, strict=True) + @requires_gpu + @skipIfRocm + @testing.expectedFailureSerDer + @testing.expectedFailureSerDerNonStrict + def test_export_lstm_gpu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.rnn = torch.nn.LSTM( + input_size=4, hidden_size=5, num_layers=1, batch_first=True + ) + + def forward(self, x): + out, _ = self.rnn(x) + return out + + m = M().to(GPU_TYPE) + x = torch.randn(2, 3, 4, device=GPU_TYPE) + + ep = export(m, (x,)) + self.assertTrue(callable(ep.module())) + + eager_out = m(x) + export_out = ep.module()(x) + self.assertEqual(eager_out, export_out) + + @requires_gpu + @skipIfRocm + @testing.expectedFailureSerDer + @testing.expectedFailureSerDerNonStrict + def test_export_gru_gpu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.rnn = torch.nn.GRU( + input_size=4, hidden_size=5, num_layers=1, batch_first=True + ) + + def forward(self, x): + out, _ = self.rnn(x) + return out + + m = M().to(GPU_TYPE) + x = torch.randn(2, 3, 4, device=GPU_TYPE) + + ep = export(m, (x,)) + self.assertTrue(callable(ep.module())) + + eager_out = m(x) + export_out = ep.module()(x) + self.assertEqual(eager_out, export_out) + + @requires_gpu + @skipIfRocm + @testing.expectedFailureSerDer + @testing.expectedFailureSerDerNonStrict + def test_export_rnn_flatten_parameters_gpu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=3, hidden_size=4, num_layers=2, batch_first=True + ) + + def forward(self, x): + self.lstm.flatten_parameters() + out, (h, c) = self.lstm(x) + return out + + m = M().to(GPU_TYPE) + x = torch.randn(1, 5, 3, device=GPU_TYPE) + + ep = export(m, (x,), strict=False) + + eager_out = m(x) + export_out = ep.module()(x) + self.assertEqual(eager_out, export_out) + def test_device_to_static(self): class Module(torch.nn.Module): def forward(self, x): @@ -8729,7 +8808,7 @@ def forward(self, x): bn_num_batches_tracked = self.bn.num_batches_tracked; bn_num_batches_tracked = None _guards_fn = self._guards_fn(x); _guards_fn = None conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None - batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None + batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, False); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None return pytree.tree_unflatten((batch_norm,), self._out_spec)""", ) @@ -8750,7 +8829,7 @@ def forward(self, x): _guards_fn = self._guards_fn(x); _guards_fn = None conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None add_ = torch.ops.aten.add_.Tensor(bn_num_batches_tracked, 1); bn_num_batches_tracked = add_ = None - batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None + batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, False); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None return pytree.tree_unflatten((batch_norm,), self._out_spec)""", ) diff --git a/test/export/test_export_opinfo.py b/test/export/test_export_opinfo.py index 075fd6df119b9..5eb42c461e574 100644 --- a/test/export/test_export_opinfo.py +++ b/test/export/test_export_opinfo.py @@ -20,7 +20,12 @@ skipOps, xfail, ) -from torch.testing._internal.common_utils import run_tests, skipIfRocm, TestCase +from torch.testing._internal.common_utils import ( + IS_FBCODE, + run_tests, + skipIfRocm, + TestCase, +) from torch.utils import _pytree as pytree @@ -122,6 +127,7 @@ class TestExportOpInfo(TestCase): @skipOps( "TestExportOpInfo", "test_fake_export", export_failures | fake_export_failures ) + @unittest.skipIf(IS_FBCODE, "tests broken with unexpected successes internally") def test_fake_export(self, device, dtype, op): _test_export_helper(self, dtype, op) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 856f23f68b19e..fdffacf512c20 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -178,11 +178,13 @@ class ExportArtifact: def _ignore_backend_decomps(): orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False) orig_nnpack_flag = torch.backends.nnpack.set_flags(False) + orig_cudnn_flag = torch.backends.cudnn.set_flags(False) try: yield finally: torch.backends.mkldnn.set_flags(*orig_mkldnn_flag) torch.backends.nnpack.set_flags(*orig_nnpack_flag) + torch.backends.cudnn.set_flags(*orig_cudnn_flag) @contextmanager From 2ac3ef882afb23136adc188975f0a8802fc68adf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 01:52:57 +0000 Subject: [PATCH 225/338] Revert "[generator] Close all open generators in compile_subgraph (#157149)" This reverts commit 78adb3b3df41b45d2368b67226d2f864b78939a6. Reverted https://github.com/pytorch/pytorch/pull/157149 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some torchrec tests ([comment](https://github.com/pytorch/pytorch/pull/157149#issuecomment-3609616560)) --- test/dynamo/test_generator.py | 8 +- ...enerators-ExceptionTest.test_except_throw} | 0 ...onTest.test_except_throw_exception_context | 0 torch/_dynamo/output_graph.py | 5 - torch/_dynamo/side_effects.py | 20 ---- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/variables/functions.py | 95 ++++++++++++++++--- 7 files changed, 85 insertions(+), 45 deletions(-) rename test/dynamo_expected_failures/{CPython313-test_exceptions-ExceptionTests.test_generator_leaking3 => CPython313-test_generators-ExceptionTest.test_except_throw} (100%) create mode 100644 test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 2a0bd874f881c..c02126c7404ff 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1009,7 +1009,7 @@ def test_close_with_side_effects(self): z = 0 def whoo(t): - nonlocal z # noqa: F824 + nonlocal z try: L.append(1) yield t.sin() @@ -1050,6 +1050,7 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=True) def fn(t): + nonlocal z gen = whoo(t) i = next(gen) y = gen.close() @@ -1077,6 +1078,7 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=fullgraph) def fn(t): + nonlocal z gen = whoo(t) i = next(gen) gen.close() @@ -1378,10 +1380,8 @@ def fn(t): a = next(gen) try: gen.throw(ValueError) - except StopIteration as e: - assert len(e.args) == 0 + except StopIteration: return a - raise AssertionError("Expected StopIteration") t = torch.randn(2) y = self._compile_check(fn, (t,)) diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking3 b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking3 rename to test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 7374035854c4f..6ff908ff0394f 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1629,11 +1629,6 @@ def compile_subgraph( ) self.codegen_suffix(tx, stack_values_flat, pass1) - # Close all generators opened while tracing. Needs to be done after - # pass1, as PyCodegen might try to reconstruct the generator, which - # sets LocalGeneratorObjectVariable.remaining_items - self.side_effects.close_local_generators() - # Use `pass1.uses` to selectively cache multi-user variables into a # temporary local source. This (a). speeds up loading VTs with long # chained source, and (b). avoids redundantly saving single-user VT diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 594c7fd7060aa..999bd145c3e57 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -59,7 +59,6 @@ if TYPE_CHECKING: from torch._dynamo.output_graph import OutputGraph from torch._dynamo.symbolic_convert import InstructionTranslatorBase - from torch._dynamo.variables.functions import LocalGeneratorObjectVariable from torch._dynamo.variables.lists import ListVariable @@ -135,7 +134,6 @@ def __init__( self.keepalive = keepalive or [] self.save_for_backward = save_for_backward or [] self.tensor_hooks = tensor_hooks or {} - self.local_generators: list[LocalGeneratorObjectVariable] = [] # Used by MappingProxyVariable to graph break in case of any mutated # dict self._has_existing_dict_mutation = False @@ -229,24 +227,6 @@ def should_allow_side_effects_in_hop(self) -> bool: and output_graph.current_tx.output.current_tracer.allow_side_effects_in_hop ) - def track_generator(self, gen: "LocalGeneratorObjectVariable") -> None: - self.local_generators.append(gen) - - def untrack_generator(self, gen: "LocalGeneratorObjectVariable") -> None: - self.local_generators.remove(gen) - - def close_local_generators(self) -> None: - from .symbolic_convert import temporarely_allow_writes_to_output_graph - - output_graph = self.output_graph_weakref() - if output_graph: - tx = output_graph.root_tx - with temporarely_allow_writes_to_output_graph(tx): - for gen in self.local_generators: - if not gen._is_generator_exhausted(): - # pyrefly: ignore[bad-argument-type] - gen.call_method(tx, "close", [], {}) - def is_reconstructing_generator(self) -> bool: output_graph = self.output_graph_weakref() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 5410de2f4365e..f401b9d6178b9 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1016,7 +1016,7 @@ class ExceptionStack: # and "stack" sometimes refers to a C variable with the same name and the # exception stack, respectively. # - # The lifetime of an exception in Python 3.11+ is: + # The lifetime of an exception is (Python 3.11+): # + tx._raise_exception_variable(...) := sets the current_exception variable # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack* # + POP_EXCEPT := pops TOS from the *exception stack* diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index df37a5d9a4cbc..f493e0e1fd961 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -45,6 +45,7 @@ from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..exc import ( format_skip_frame_message, + get_dynamo_observed_exception, handle_observed_exception, InfiniteGeneratorError, ObservedException, @@ -907,7 +908,6 @@ def __init__( self.code = code self.f_globals = f_globals self.inline_tracer = inline_tracer - inline_tracer.output.side_effects.track_generator(self) def get_code(self) -> types.CodeType: return self.code @@ -976,12 +976,9 @@ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: # created on call_function. Any exception needs to be propagated to tx # for Dynamo to behave correctly return tracer.inline_call_() - except ObservedUserStopIteration: - tracer.output.side_effects.untrack_generator(self) - raise - except ObservedException: + except ObservedException as e: tracer.generator_exhausted = True - raise + raise e except InfiniteGeneratorError: # test/dynamo/test_misc.py::test_iterator_limit raise @@ -1023,10 +1020,9 @@ def force_apply_to_var_sequence( def should_allow_nested_graph_breaks(self): return False - def _setup_and_raise_exception( + def _setup_exception( self, tx: "InstructionTranslator", exc: VariableTracker ) -> None: - # Raise an exception at the point where the generator is paused tracer = self.inline_tracer try: tracer._raise_exception_variable(exc) @@ -1090,7 +1086,7 @@ def call_method( # Raise GeneratorExit to see if user code catches it. Any other exception # is propagated to the parent frame. try: - self._setup_and_raise_exception( + self._setup_exception( tx, variables.ExceptionVariable(GeneratorExit, ()) ) # There's an extra block on Python 3.12+ to handle StopIteration @@ -1139,7 +1135,7 @@ def call_method( # returns the next value yielded by the generator. # * If the generator exits without yielding, raise StopIteration # * If the generator function does not catch the passed-in exception, - # or raises a different exception, then that new exception propagates to the caller. + # or raises a different exception, then that exception propagates to the caller. # Setup the exception table and jump target in case of try...finally tracer = self.inline_tracer @@ -1148,15 +1144,84 @@ def call_method( # In such cases, we re-raise the exception object given to avoid # creating a new object, so that IS_OP works. # See: https://github.com/pytorch/pytorch/pull/146496 - self._setup_and_raise_exception( - tx, args[1] if len(args) == 3 else args[0] - ) + self._setup_exception(tx, args[1] if len(args) == 3 else args[0]) except ObservedException: # noqa: TRY203 # propagate the exception back to the parent caller raise - # If reaches here, it means user code captured the exception - return self.next_variable(tx) + retval = self.next_variable(tx) + + # The exception raised before is still active. We need to check the exception + # table one more time to find the next target. But why? Let's walk + # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M + # + # z = 0 + # def whoo(): + # global z + # z = 0 + # try: + # yield 1 + # except ValueError: + # yield 2 + # finally: + # z += 1 + # z += 10 + # + # gen = whoo() + # next(gen) + # gen.throw(ValueError) + # print('z', z) -> z = 1 + # + # ... + # >> 58 PUSH_EXC_INFO + # + # 8 60 LOAD_GLOBAL 2 (ValueError) + # 70 CHECK_EXC_MATCH + # 72 POP_JUMP_IF_FALSE 7 (to 88) + # 74 POP_TOP + # + # 9 76 LOAD_CONST 3 (2) + # 78 YIELD_VALUE 3 <------ ValueError is still active here + # 80 RESUME 1 + # 82 POP_TOP + # 84 POP_EXCEPT + # 86 jump_backward 34 (to 20) + # ... + # + # ExceptionTable: + # 4 to 8 -> 124 [0] lasti + # 12 to 18 -> 58 [0] + # 20 to 56 -> 124 [0] lasti + # 58 to 82 -> 90 [1] lasti <------ move to 90 + # 84 to 86 -> 96 [0] + # 88 to 88 -> 90 [1] lasti + # 90 to 94 -> 96 [0] + # 96 to 116 -> 118 [1] lasti + # 118 to 122 -> 124 [0] lasti + # + # In this scenario, a generator can yield after `throw()` is called. Even + # after the exception is raised a few lines above, it remains active + # within the `78 YIELD_VALUE` instruction. When the generator resumes + # after the second yield on instruction `80 RESUME`, we cannot simply + # return the control flow to the next instruction. Instead, one must + # check the exception table (or equivalent) to find the next target + # In this case, it says the instruction pointer must be moved to 90. + # + # Without this step, if we let the trace proceed to the next + # instruction, it would follow the control flow where the exception + # raised by `throw()` was handled and swallowed, potentially leading + # to incorrect behavior. + exc_type = type("__InternalThrowException", (Exception,), {}) + + try: + self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self.next_variable(tx) + except get_dynamo_observed_exception(exc_type): + # We should get back the exception raised before. + pass + else: + raise_observed_exception(RuntimeError, tracer) + return retval return super().call_method(tx, name, args, kwargs) From e115f9f4e4b039f8e9a642aaa2bd8254a920541b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 02:20:31 +0000 Subject: [PATCH 226/338] Revert "Support module.to in strict export (#167555)" This reverts commit f9bd6c53624c7c0ea3772de78498326e84c2f0e7. Reverted https://github.com/pytorch/pytorch/pull/167555 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167555#issuecomment-3609697942)) --- test/dynamo/test_misc.py | 66 ----------------- test/export/test_export.py | 76 ------------------- torch/_dynamo/create_parameter_op.py | 3 +- torch/_dynamo/polyfills/__init__.py | 1 - torch/_dynamo/polyfills/loader.py | 1 - torch/_dynamo/polyfills/torch_c_nn.py | 102 -------------------------- torch/_dynamo/trace_rules.py | 2 + 7 files changed, 3 insertions(+), 248 deletions(-) delete mode 100644 torch/_dynamo/polyfills/torch_c_nn.py diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index c19ec2aa58b29..78b5c7e4553da 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -14163,72 +14163,6 @@ def _random_resize(image: torch.Tensor): self.assertTrue(224 <= h <= 256) self.assertTrue(224 <= w <= 256) - @unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device") - def test_module_to_with_shared_weights_compile(self): - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) - - def forward(self, x): - token_ids = torch.randint(0, 10, (4,), device=x.device) - embedded = self.embedding(token_ids).sum() - return x.sum() + embedded.sum() - - class Container(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = Model() - - def forward(self, x): - if "cuda" in str(x.device): - mod = self.mod.to(x.device) - return mod(x) - else: - return x.sum() - - container = Container() - container_eager = copy.deepcopy(container) - with torch._dynamo.config.patch(graph_break_on_nn_param_ctor=False): - compiled = torch.compile(container, backend="eager", fullgraph=True) - - inp1 = torch.randn(4, 4, 4, device="cuda") - - # First call with CUDA input - compiled_result1 = compiled(inp1) - eager_result1 = container_eager(inp1) - same(compiled_result1, eager_result1) - - # Second call - weights are now on CUDA from first call - # This tests that .to(cuda) on already-cuda weights doesn't fail - compiled_result2 = compiled(inp1) - eager_result2 = container_eager(inp1) - same(compiled_result2, eager_result2) - - @unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device") - def test_module_to_move_compile(self): - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.fc = torch.nn.Linear(10, 10) - - def forward(self, x): - x = self.fc(x) - self.to("cpu") - return x - - mod = Model().cuda() - with torch._dynamo.config.patch(graph_break_on_nn_param_ctor=False): - fn = torch.compile(mod, backend="aot_eager", fullgraph=True) - x = torch.randn(10, 10, device="cuda") - ref = fn(x) - self.assertEqual(str(mod.fc.weight.device), "cpu") - mod.cuda() - ref = fn( - x - ) # second time compile runs, we should also move the module to cpu device - self.assertEqual(str(mod.fc.weight.device), "cpu") - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/export/test_export.py b/test/export/test_export.py index 788aa8518f94b..0bb21f47f9381 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -16402,82 +16402,6 @@ def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"): ignore_empty_lines=True, ) - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") - def test_module_to_with_shared_weights(self): - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) - - def forward(self, x): - token_ids = torch.ones((4,), device=x.device, dtype=torch.int64) - embedded = self.embedding(token_ids).sum() - return x.sum() + embedded.sum() - - class Container(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = Model() - - def forward(self, x): - if "cuda" in str(x.device): - mod = self.mod.to(x.device) - return mod(x) - else: - return x.sum() - - with ( - torch._dynamo.config.patch(graph_break_on_nn_param_ctor=False), - torch._export.config.patch(use_legacy_dynamo_graph_capture=False), - ): - torch.manual_seed(0) - container = Container() - container_eager = copy.deepcopy(container) - gm = torch.export.export( - container, - (torch.randn(4, 4, 4, device="cuda"),), - strict=True, - ).module() - - self.assertExpectedInline( - str(gm.code).strip(), - """\ -def forward(self, x): - args_0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) - mod_embedding_weight = self.mod.embedding.weight - _guards_fn = self._guards_fn(args_0); _guards_fn = None - empty = torch.ops.aten.empty.memory_format([10, 8], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False) - detach = torch.ops.aten.detach.default(empty); empty = None - submod_6 = self.submod_1 - to = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_6, mod_embedding_weight); submod_6 = mod_embedding_weight = None - getitem = to[0]; to = None - submod_7 = self.submod_3 - wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_7); submod_7 = wrap_with_set_grad_enabled = None - submod_8 = self.submod_4 - view_as = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_8, detach, getitem); submod_8 = detach = getitem = None - getitem_1 = view_as[0]; view_as = None - ones = torch.ops.aten.ones.default([4], dtype = torch.int64, device = device(type='cuda', index=0), pin_memory = False) - embedding = torch.ops.aten.embedding.default(getitem_1, ones); getitem_1 = ones = None - sum_1 = torch.ops.aten.sum.default(embedding); embedding = None - sum_2 = torch.ops.aten.sum.default(args_0); args_0 = None - sum_3 = torch.ops.aten.sum.default(sum_1); sum_1 = None - add = torch.ops.aten.add.Tensor(sum_2, sum_3); sum_2 = sum_3 = None - return pytree.tree_unflatten((add,), self._out_spec)""", - ) - - inp = torch.randn(4, 4, 4, device="cuda") - - # Call container first to move shared weights to CUDA - export_out = gm(inp) - eager_out = container_eager(inp) - self.assertEqual(export_out, eager_out) - - # This should not fail even though weights are now on CUDA - # and .to(cuda) returns the same parameter with requires_grad=True - export_out_v2 = gm(inp) - eager_out_v2 = container_eager(inp) - self.assertEqual(export_out_v2, eager_out_v2) - @testing.expectedFailureStrict # test_hop doesn't have a dynamo implementation @testing.expectedFailureStrictV2 # test_hop doesn't have a dynamo implementation @testing.expectedFailureRetraceability # test_hop doesn't have a dynamo implementation diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index a0bc7325c54b4..2a716865c3f48 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -22,8 +22,7 @@ class TracableCreateParameter(torch.autograd.Function): @staticmethod # pyrefly: ignore [bad-override] def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter: - if tensor.requires_grad: - tensor = tensor.detach() + assert not tensor.requires_grad return placeholder.set_(tensor) @staticmethod diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 56f614b265a9c..59f6f76317e6d 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -34,7 +34,6 @@ pytree as pytree, struct as struct, sys as sys, - torch_c_nn as torch_c_nn, ) from torch.overrides import BaseTorchFunctionMode diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 46e6fa6df9c3e..31479e9d86ce6 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -25,7 +25,6 @@ "sys", "fx", "tensor", - "torch_c_nn", ) if python_pytree._cxx_pytree_dynamo_traceable: POLYFILLED_MODULE_NAMES += ("pytree",) diff --git a/torch/_dynamo/polyfills/torch_c_nn.py b/torch/_dynamo/polyfills/torch_c_nn.py deleted file mode 100644 index e6eb8ce78da3e..0000000000000 --- a/torch/_dynamo/polyfills/torch_c_nn.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Polyfills for torch._C._nn functions. -""" - -import torch - -from ..decorators import substitute_in_graph - - -@substitute_in_graph(torch._C._nn._parse_to, skip_signature_check=True) -def _parse_to_polyfill(*args, **kwargs): # noqa: F821 - """ - Polyfill for torch._C._nn._parse_to that parses arguments to nn.Module.to(). - - Signature mirrors torch._C._nn._parse_to which accepts: - - to(device) - device as string or torch.device - - to(dtype) - dtype as torch.dtype - - to(tensor) - extracts device and dtype from tensor - - to(device=..., dtype=..., non_blocking=..., memory_format=...) - - Returns: - tuple: (device, dtype, non_blocking, memory_format) - """ - device = None - dtype = None - non_blocking = False - memory_format = None - - # Handle positional arguments - if len(args) == 1: - arg = args[0] - # Check if it's a tensor - if isinstance(arg, torch.Tensor): - device = arg.device - dtype = arg.dtype - # Check if it's a dtype - elif isinstance(arg, torch.dtype): - dtype = arg - # Check if it's a device (string or torch.device) - elif isinstance(arg, (str, torch.device)): - device = torch.device(arg) if isinstance(arg, str) else arg - else: - raise TypeError( - f"to() received an invalid combination of arguments. Got: {type(arg)}" - ) - elif len(args) > 1: - raise TypeError( - f"to() received too many positional arguments. Got {len(args)}, expected at most 1" - ) - - # Handle keyword arguments - if "device" in kwargs: - device_arg = kwargs["device"] - if device_arg is not None: - device = ( - torch.device(device_arg) if isinstance(device_arg, str) else device_arg - ) - - if "dtype" in kwargs: - dtype = kwargs["dtype"] - - if "non_blocking" in kwargs: - non_blocking = kwargs["non_blocking"] - - if "memory_format" in kwargs: - memory_format = kwargs["memory_format"] - - return (device, dtype, non_blocking, memory_format) - - -@substitute_in_graph(torch.__future__.get_swap_module_params_on_conversion) -def get_swap_module_params_on_conversion_polyfill() -> bool: - """ - Polyfill for torch.__future__.get_swap_module_params_on_conversion. - - Returns the default value False to allow tracing through nn.Module._apply(). - """ - return False - - -@substitute_in_graph(torch._has_compatible_shallow_copy_type) -def _has_compatible_shallow_copy_type_polyfill( - input: torch.Tensor, from_: torch.Tensor -) -> bool: - """ - Polyfill for torch._has_compatible_shallow_copy_type. - - Checks if two tensors have compatible types for shallow copying. - The C++ implementation checks if input's TensorImpl has compatible shallow copy type - with from_'s key_set. We approximate this by checking if both tensors are the same type. - """ - # Check if both tensors are the same type (handles both regular tensors and subclasses) - # This is more permissive than checking exact torch.Tensor type equality - # but properly handles subclasses by allowing same-type shallow copies - return type(input) is type(from_) - - -__all__ = [ - "_parse_to_polyfill", - "get_swap_module_params_on_conversion_polyfill", - "_has_compatible_shallow_copy_type_polyfill", -] diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 813247f4fd3c7..083c8b1f93807 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1064,6 +1064,7 @@ "torch._C._nn._conv_depthwise2d", "torch._C._nn._pad_circular", "torch._C._nn._pad_enum", + "torch._C._nn._parse_to", "torch._C._nn._test_ambiguous_defaults", "torch._C._nn._test_optional_filled_intlist", "torch._C._nn._test_optional_floatlist", @@ -1586,6 +1587,7 @@ "torch._fw_primal_copy", "torch._grid_sampler_2d_cpu_fallback", "torch._grouped_mm", + "torch._has_compatible_shallow_copy_type", "torch._histogramdd_bin_edges", "torch._histogramdd_from_bin_cts", "torch._histogramdd_from_bin_tensors", From 85a315917efe82c24306be805c584ec044951c75 Mon Sep 17 00:00:00 2001 From: tianrengao Date: Thu, 4 Dec 2025 02:43:17 +0000 Subject: [PATCH 227/338] Enable custom collective op autotuning (#167294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add collective op autotuning with distributed benchmarking Collective operations can have multiple equivalent implementations with different performance. Existing autotuning only supports single-process benchmarking, making it impossible to autotune collective ops that require multi-rank coordination. **Summary:** Added distributed benchmarking via `benchmark_collective_choice()` and `register_custom_op_autotuning()` API. All ranks are synchronized and benchmarked at the same time to ensure a fair comparison. The benchmark across all possible ranks(world_size) has a timeout detection(default 30s, configurable in configs.py) and exception handling to prevent deadlocks. If any rank times out or encounters an exception, all ranks fallback to regular benchmark together. **Example:** ```python # Define custom op with multiple implementations @torch.library.custom_op("mylib::allreduce", mutates_args=()) def my_allreduce(x: torch.Tensor) -> torch.Tensor: return torch.ops._c10d_functional.all_reduce_(x.clone(), "sum", "default") # Register autotuning choices register_custom_op_autotuning( my_allreduce, configs=[ CustomOpConfig(lambda x: all_reduce_(x.clone(), "sum", "default")), CustomOpConfig(lambda x: all_reduce_(x.clone(), "avg", "default") * world_size), ], ) # Compile and autotune model = torch.compile(MyModel()) output = model(input) # Autotuning happens here ``` **Implementation:** Different from regular benchmark which calls benchmarker.benchmark with warmup and rep, collective op benchmarking should handle barrier synchronization and timing. So I created `collective_benchmark` for subgraphcaller choice to only run the implementation once with cached compiled modules. Then the timing and synchronization will be handled in select_algorithm.py **example log** ``` Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} Autotune Choices Stats: {"num_choices": 3, "num_triton_choices": 0, "best_kernel": "test::vllm_allreduce_autotuned_nccl_allreduce_direct_1", "best_kernel_desc": "CustomOp nccl_allreduce_direct", "best_time": 0.004397066775709391} [rank0]:W1126 14:09:52.629000 394749 torch/_inductor/select_algorithm.py:3996] [0/0] [COLLECTIVE AUTOTUNING] All timings: [rank0]:W1126 14:09:52.630000 394749 torch/_inductor/select_algorithm.py:3999] [0/0] - test::vllm_allreduce_autotuned_nccl_allreduce_direct_1: 0.004397 ms ← SELECTED [rank0]:W1126 14:09:52.630000 394749 torch/_inductor/select_algorithm.py:3999] [0/0] - test::vllm_allreduce_autotuned_vllm_buffer_copy_allreduce_0: 0.004507 ms [rank0]:W1126 14:09:52.630000 394749 torch/_inductor/select_algorithm.py:3999] [0/0] - test::vllm_allreduce_autotuned_fallback_default: 0.012899 ms ``` **Test Plan:** * Added `test_equivalent_allreduce_strategies` (2 ranks) and `test_all_gather_4ranks` (4 ranks) * Verified timeout detection and exception handling prevent deadlocks Pull Request resolved: https://github.com/pytorch/pytorch/pull/167294 Approved by: https://github.com/shunting314, https://github.com/mlazos, https://github.com/eellison --- test/inductor/test_collective_autotuning.py | 189 +++++++++++++++ torch/_inductor/codegen/subgraph.py | 55 ++++- torch/_inductor/config.py | 10 + torch/_inductor/kernel/custom_op.py | 47 ++-- torch/_inductor/select_algorithm.py | 244 +++++++++++++++++++- torch/_inductor/utils.py | 21 ++ 6 files changed, 532 insertions(+), 34 deletions(-) create mode 100644 test/inductor/test_collective_autotuning.py diff --git a/test/inductor/test_collective_autotuning.py b/test/inductor/test_collective_autotuning.py new file mode 100644 index 0000000000000..a5a05d05a9028 --- /dev/null +++ b/test/inductor/test_collective_autotuning.py @@ -0,0 +1,189 @@ +# Owner(s): ["module: inductor"] + +import torch +import torch.distributed as dist +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import run_tests + + +class TestCollectiveAutotuning2Ranks(MultiProcessTestCase): + """Test collective autotuning with 2 ranks""" + + @property + def world_size(self): + return 2 + + def setUp(self): + super().setUp() + self._spawn_processes() + + @skip_if_lt_x_gpu(2) + def test_equivalent_allreduce_strategies(self): + """ + Test autotuning between mathematically equivalent all_reduce strategies. + + Strategy 1: sum all_reduce + Strategy 2: avg all_reduce * world_size + """ + dist.init_process_group( + backend="nccl", + init_method=f"file:///tmp/test_equiv_allreduce_{self.id()}", + world_size=self.world_size, + rank=self.rank, + ) + + dist.barrier() + + rank = dist.get_rank() + device = f"cuda:{rank}" + + from torch._C._distributed_c10d import _register_process_group + + _register_process_group("default", dist.group.WORLD) + + @torch.library.custom_op("test::equiv_ar", mutates_args=()) + def equiv_ar(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + @equiv_ar.register_fake + def _(x): + return torch.empty_like(x) + + def sum_allreduce(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + def avg_allreduce_scaled(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + result = torch.ops._c10d_functional.all_reduce_(result, "avg", "default") + return result * self.world_size + + from torch._inductor.kernel.custom_op import ( + CustomOpConfig, + register_custom_op_autotuning, + ) + + register_custom_op_autotuning( + equiv_ar, + configs=[ + CustomOpConfig(sum_allreduce), + CustomOpConfig(avg_allreduce_scaled), + ], + ) + + class EquivAllReduceModel(torch.nn.Module): + def forward(self, x): + return equiv_ar(x) + + model = torch.compile(EquivAllReduceModel()).to(device) + + torch.manual_seed(42) + x = torch.randn(128, 128, device=device) + dist.broadcast(x, src=0) + + _ = model(x) + + dist.barrier() + dist.destroy_process_group() + + +class TestCollectiveAutotuning4Ranks(MultiProcessTestCase): + """Test collective autotuning with 4 ranks""" + + @property + def world_size(self): + return 4 + + def setUp(self): + super().setUp() + self._spawn_processes() + + @skip_if_lt_x_gpu(4) + def test_vllm_style_allreduce(self): + """ + Test vLLM-style custom allreduce with buffer copy pattern. + + vLLM uses custom allreduce optimized for small tensors (<8MB). + Two implementations simulate vLLM's registered=False mode vs standard NCCL. + """ + dist.init_process_group( + backend="nccl", + init_method=f"file:///tmp/test_vllm_allreduce_{self.id()}", + world_size=self.world_size, + rank=self.rank, + ) + + dist.barrier() + + rank = dist.get_rank() + device = f"cuda:{rank}" + + from torch._C._distributed_c10d import _register_process_group + + _register_process_group("default", dist.group.WORLD) + + @torch.library.custom_op("test::vllm_allreduce", mutates_args=()) + def vllm_allreduce(x: torch.Tensor) -> torch.Tensor: + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + @vllm_allreduce.register_fake + def _(x): + return torch.empty_like(x) + + def vllm_buffer_copy_allreduce(x: torch.Tensor) -> torch.Tensor: + """ + vLLM registered=False: flatten -> copy to IPC buffer -> allreduce -> reshape + + vLLM code: + inp_size = inp.numel() * inp.element_size() + self.buffer_ptrs[self.rank][:inp_size].copy_(inp.view(-1)) + ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) + """ + original_shape = x.shape + flat_x = x.contiguous().view(-1) + buffer_copy = flat_x.clone() + result = torch.ops._c10d_functional.all_reduce_( + buffer_copy, "sum", "default" + ) + return result.view(original_shape) + + def nccl_allreduce_direct(x: torch.Tensor) -> torch.Tensor: + """Standard NCCL allreduce without buffer copy.""" + result = x.clone() + return torch.ops._c10d_functional.all_reduce_(result, "sum", "default") + + from torch._inductor.kernel.custom_op import ( + CustomOpConfig, + register_custom_op_autotuning, + ) + + register_custom_op_autotuning( + vllm_allreduce, + configs=[ + CustomOpConfig(vllm_buffer_copy_allreduce), + CustomOpConfig(nccl_allreduce_direct), + ], + ) + + class VLLMAllReduceModel(torch.nn.Module): + def forward(self, x): + return vllm_allreduce(x) + + model = torch.compile(VLLMAllReduceModel()).to(device) + + torch.manual_seed(42 + rank) + x = torch.randn(128, 256, device=device) + + y = model(x) + self.assertEqual(y.shape, x.shape) + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 1c1f0f1c9cd2c..7b931fb3bf47e 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -71,16 +71,25 @@ def __init__( self.sym_inputs = get_symbolic_inputs(self.input_nodes) + # Cache compiled module to avoid recompiling on every benchmark call + self._compiled_module: Any = None + self._compiled_sym_inputs: list[Any] | None = None + def __str__(self) -> str: return f"SubgraphCaller({self.name})" - def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: - # Codegen Subgraph for benchmarking - # Need GraphLowering instead of SubgraphLowering to generate - # fully callable module + def _compile_for_benchmarking(self, *args: list[Any]) -> tuple[Any, list[Any]]: + """ + Compile the subgraph for benchmarking and return (module, sym_inputs). + + TODO: Add precompile() method to enable parallel compilation of all choices + before benchmarking. + """ import torch._inductor.config as inductor_config from torch._inductor.graph import GraphLowering + safe_name = self.name.replace("::", "_").replace(".", "_") + bm_graph_lowering = GraphLowering( gm=self.gm, example_inputs=self.example_inputs, @@ -90,7 +99,7 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: extern_node_serializer=V.graph.extern_node_serializer, is_inference=V.graph.is_inference, is_backward=V.graph.is_backward, - name=f"benchmark_{self.name}", + name=f"benchmark_{safe_name}", ) for sym_inp in self.sym_inputs: @@ -123,9 +132,23 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: ): bm_graph_lowering.run(*self.example_inputs) mod = bm_graph_lowering.compile_to_module() - bm_func = mod.call - bm_func([*sym_inputs, *args]) + return mod, sym_inputs + + def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: + """ + Regular benchmarking: compile and use benchmarker with warmup/rep. + """ + if self._compiled_module is None: + mod, sym_inputs = self._compile_for_benchmarking(*args) + self._compiled_module = mod + self._compiled_sym_inputs = sym_inputs + else: + mod = self._compiled_module + sym_inputs = self._compiled_sym_inputs + assert sym_inputs is not None # Type narrowing + + bm_func = mod.call if config.profile_bandwidth_with_do_bench_using_profiling: return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args])) return benchmarker.benchmark( @@ -134,6 +157,24 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: device=benchmarker.infer_device(*sym_inputs, *args), ) + def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None: + """ + Only run once with cached compiled module. + Called by benchmark_collective_choice which handles warmup + and timing with barrier synchronization across all ranks. + """ + if self._compiled_module is None: + mod, sym_inputs = self._compile_for_benchmarking(*args) + self._compiled_module = mod + self._compiled_sym_inputs = sym_inputs + else: + mod = self._compiled_module + sym_inputs = self._compiled_sym_inputs + assert sym_inputs is not None # Type narrowing + + bm_func = mod.call + bm_func([*sym_inputs, *args]) + def hash_key(self) -> str: return "-".join( [ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 7ba93575ce8bf..fcfb8f51ae6e7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -608,6 +608,16 @@ def prologue_fusion_enabled() -> bool: # If autotuning in subprocess, whether to use multiple devices autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" +# Number of benchmark runs for collective operations +collective_benchmark_nruns = int( + os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_NRUNS", "50") +) + +# Timeout in seconds for collective benchmarking +collective_benchmark_timeout = float( + os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_TIMEOUT", "30") +) + coordinate_descent_tuning = ( os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" ) diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 12cc68dcb9844..c6a641ce83b17 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -6,7 +6,6 @@ from typing import Any, Optional, Union import torch -from torch._inductor import config from torch._inductor.codegen.subgraph import SubgraphTemplate from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox from torch._inductor.lowering import lowerings, validate_ir @@ -21,6 +20,28 @@ log = logging.getLogger(__name__) +def _detect_collective_ops(choices: list) -> bool: + """ + Detect if choices contain collective operations. + """ + from torch._inductor.utils import is_collective_op + + for choice in choices: + if not hasattr(choice, "gm") or choice.gm is None: + continue + + for node in choice.gm.graph.nodes: + if node.op == "call_function" and node.target is not None: + op_name = str(node.target) + + if is_collective_op(op_name) or is_collective_op( + f"torch.ops.{op_name}" + ): + return True + + return False + + class CustomOpConfig: """Config for custom op autotuning. @@ -180,14 +201,8 @@ def create_internal_input_gen_fn( """Create internal input generator that converts IR buffer to user's fake tensor.""" def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor: - raw_shape = ir_buffer.get_size() - concrete_shape = V.graph.sizevars.size_hints( - raw_shape, fallback=config.unbacked_symint_fallback - ) - - fake_tensor = torch.empty( - concrete_shape, dtype=ir_buffer.get_dtype(), device="meta" - ) + fake_tensor = ir_node_to_tensor(ir_buffer) + assert fake_tensor is not None, "ir_node_to_tensor returned None" return user_function(fake_tensor) return internal_input_gen_fn @@ -321,6 +336,8 @@ def autotune_custom_op( ) input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) + is_collective = _detect_collective_ops(choices) + # Run autotuning and get both result and winning choice selected_result, winning_choice = autotune_select_algorithm( name=name, @@ -329,6 +346,7 @@ def autotune_custom_op( layout=choices[0].layout, input_gen_fns=input_gen_fns, return_choice=True, + is_collective=is_collective, ) # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl) @@ -363,16 +381,7 @@ def _generate_dynamic_configs( param_names = list(sig.parameters.keys()) with V.fake_mode: - fake_tensors = [] - for inp in tensor_inputs: - raw_shape = inp.get_size() - concrete_shape = V.graph.sizevars.size_hints( - raw_shape, fallback=config.unbacked_symint_fallback - ) - fake_tensor = torch.empty( - concrete_shape, dtype=inp.get_dtype(), device=inp.get_device() - ) - fake_tensors.append(fake_tensor) + fake_tensors = [ir_node_to_tensor(inp) for inp in tensor_inputs] fake_tensors_dict = dict(zip(param_names, fake_tensors)) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 77448c914df80..df71bdd3db502 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2335,6 +2335,10 @@ def autoheuristic_id(self): class ExternKernelCaller(ChoiceCaller): + """ + Caller for external kernel implementations + """ + def __init__( self, choice: ExternKernelChoice, @@ -2370,6 +2374,19 @@ def benchmark(self, *args, out): return do_bench_using_profiling(lambda: algo(*args)) return benchmarker.benchmark(algo, args, {}) + def benchmark_collective(self, *args, out): + """ + Called by benchmark_collective_choice, only run once, timing handled externally with barrier sync. + """ + if out.numel() == 0: + return + + algo = self.to_callable() + if self.has_out_variant: + algo(*args, out=out) + else: + algo(*args) + def to_callable(self): fn = self.choice.to_callable() if self.kwargs: @@ -2733,6 +2750,7 @@ def __call__( return_multi_template=False, best_config_future=None, return_choice=False, # TODO: return_choice is temporary and will be refactored soon + is_collective=False, ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2843,6 +2861,7 @@ def get_timings(hint_override: Optional[int] = None): choices, precompile_fn, best_config_future=best_config_future, + is_collective=is_collective, ) # if timings is empty, we really have no choice but to return a semi-random # choice. returning the first `ExternKernelCaller` is probably the safest bet @@ -2874,6 +2893,7 @@ def get_timings(hint_override: Optional[int] = None): # if we got any timings at all, pick the best of those choice = min(timings, key=timings.__getitem__) node = choice.output_node() + log.debug("Autotuning selected choice: %s", node) if return_choice: return node, choice @@ -2886,12 +2906,18 @@ def benchmark( layout, input_gen_fns, hint_override: Optional[int] = None, + is_collective=False, ): counters["inductor"]["select_algorithm_autotune"] += 1 # TODO(nmacchioni): remove this layer of abstraction # construct `benchmark_fn` which should pick between in-process and sub-process autotuning benchmark_fn = self.make_benchmark_fn( - choices, input_nodes, layout, input_gen_fns, hint_override=hint_override + choices, + input_nodes, + layout, + input_gen_fns, + hint_override=hint_override, + is_collective=is_collective, ) # `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which # maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds @@ -2905,6 +2931,7 @@ def autotune( input_gen_fns, choices, hint_override: Optional[int] = None, + is_collective=False, ): log.debug("Starting autotuning") @@ -2915,7 +2942,12 @@ def autotune( metadata=_autotune_metadata(input_nodes), ): benchmark_results = self.benchmark( - choices, input_nodes, layout, input_gen_fns, hint_override=hint_override + choices, + input_nodes, + layout, + input_gen_fns, + hint_override=hint_override, + is_collective=is_collective, ) if config.max_autotune_report_choices_stats: _log_autotune_choices_stats( @@ -2934,6 +2966,7 @@ def do_autotuning( precompile_fn, hint_override: Optional[int] = None, best_config_future=None, + is_collective=False, ): """Execute the autotuning process for kernel algorithm selection. @@ -3071,6 +3104,7 @@ def track_has_autotuned(choices): input_gen_fns, choices, hint_override=hint_override, + is_collective=is_collective, ) timings = self.lookup( @@ -3084,6 +3118,17 @@ def track_has_autotuned(choices): autotune_elapse = time.time() - autotune_start_ts log.debug("Autotuning elapsed time: %.02fs", autotune_elapse) + # For collective: if any choice returned inf (timeout or failure), fallback to default + if is_collective and timings: + has_inf = any(not math.isfinite(timing) for timing in timings.values()) + if has_inf: + log.warning( + "At least one choice failed or timed out during collective benchmarking. " + "Falling back to default implementation." + ) + return {} + + # For regular: if all choices returned inf, raise error if timings and all(not math.isfinite(timing) for timing in timings.values()): raise NoValidChoicesError @@ -3100,6 +3145,7 @@ def track_has_autotuned(choices): precompile_elapse, prescreening_elapse, hint_override=hint_override, + is_collective=is_collective, ) def profiler_bench_function(): @@ -3461,16 +3507,162 @@ def benchmark_choice( autotune_args.verify(**VERIFY) return result + @classmethod + def _run_collective_benchmark( + cls, + choice: ChoiceCaller, + inputs: tuple, + output: torch.Tensor, + nruns: int, + process_group, + timeout, + ) -> float: + """ + Single function for benchmarking collective operations. + Used for both warmup and actual benchmarking. + + Returns total time in milliseconds, or raises TimeoutError if any collective times out. + """ + import torch.distributed as dist + + work = dist.barrier(group=process_group, async_op=True) + if not work.wait(timeout): + raise TimeoutError("Barrier timeout before benchmarking") + + torch.cuda.synchronize() + + total_time = 0.0 + + for i in range(nruns): + torch.cuda.synchronize() + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + start_evt.record() + choice.benchmark_collective(*inputs, out=output) # type: ignore[attr-defined] + end_evt.record() + end_evt.synchronize() + + total_time += start_evt.elapsed_time(end_evt) + + return total_time + + @classmethod + def benchmark_collective_choice( + cls, + choice: ChoiceCaller, + autotune_args: AutotuneArgs, + ) -> float: + """ + Benchmark a choice for collective operations with cross-rank synchronization. + This method ensures all ranks synchronize before benchmarking + to get accurate measurements for distributed collective operations. + + Timeout/Error handling: If ANY rank times out or encounters an error during + the collective operations, ALL ranks will naturally time out (since the collective + won't complete), allowing the autotuner to fall back to the default implementation. + """ + from datetime import timedelta + + import torch.distributed as dist + + timeout_seconds = config.collective_benchmark_timeout + + nruns = config.collective_benchmark_nruns + nwarmup = ir.autotune_warmup + + # Use default process group (None = all ranks) + process_group = None + rank = dist.get_rank(process_group) + + benchmark_tensors: BenchmarkTensors = autotune_args.get_benchmark_tensors( + cls._is_extern(choice) + ) + inputs, output = benchmark_tensors.unpack() + output.zero_() + + timeout = timedelta(seconds=timeout_seconds) + + try: + # Do n warmups + total_time = cls._run_collective_benchmark( + choice, inputs, output, nwarmup, process_group, timeout + ) + + # Do n actual benchmarking runs + total_time = cls._run_collective_benchmark( + choice, inputs, output, nruns, process_group, timeout + ) + + avg_time = total_time / nruns + + # All-reduce to get avg time across ranks + time_tensor = torch.tensor( + [avg_time], dtype=torch.float32, device=f"cuda:{rank}" + ) + work = dist.all_reduce( + time_tensor, + op=dist.ReduceOp.AVG, + group=process_group, + async_op=True, + ) + if not work.wait(timeout): + raise TimeoutError( + "All-reduce timeout when collecting benchmark results" + ) + + timing = time_tensor.item() + + log.info( + "Collective benchmark for %s: %.6f ms", + choice.name, + timing, + ) + + return timing + + except Exception: + log.warning( + "Collective benchmark exception for choice %s. Skipping this choice.", + getattr(choice, "name", ""), + exc_info=True, + ) + return float("inf") + @classmethod def benchmark_choices( cls, choices: Sequence[ChoiceCaller], autotune_args: AutotuneArgs, + is_collective: bool = False, ) -> dict[ChoiceCaller, float]: + """ + Benchmark a list of choices and return timing dict. + """ + if is_collective: + import torch.distributed as dist + + if not dist.is_initialized(): + log.warning( + "Collective op detected but distributed not initialized. " + "Falling back to regular benchmarking." + ) + is_collective = False + else: + rank = dist.get_rank(None) # Use default process group + log.debug( + "Using collective benchmarking for %d choices on rank %d", + len(choices), + rank, + ) timings = {} for choice in choices: try: - timing = cls.benchmark_choice(choice, autotune_args) + if is_collective: + timing = cls.benchmark_collective_choice(choice, autotune_args) + else: + timing = cls.benchmark_choice(choice, autotune_args) except CUDACompileError: from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -3524,6 +3716,16 @@ def benchmark_choices( timings[choice] = timing + # If a collective choice failed or timed out, skip the rest of the choices + if is_collective and not math.isfinite(timing): + log.warning( + "Choice %s failed or timed out during collective benchmarking. " + "Stopping further benchmarking to avoid NCCL corruption.", + getattr(choice, "name", ""), + ) + timings.update({c: float("inf") for c in choices if c not in timings}) + break + return timings @classmethod @@ -3534,11 +3736,16 @@ def benchmark_in_current_process( layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], hint_override: Optional[int] = None, + is_collective=False, ) -> dict[ChoiceCaller, float]: inputs = cls.get_inputs( choices, input_nodes, layout, input_gen_fns, hint_override=hint_override ) - return cls.benchmark_choices(choices, inputs) + return cls.benchmark_choices( + choices, + inputs, + is_collective=is_collective, + ) @classmethod def benchmark_in_sub_process( @@ -3570,21 +3777,24 @@ def make_benchmark_fn( layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], hint_override: Optional[int] = None, + is_collective=False, ): if DEBUG: print(f"{len(choices)} tuning requests:") - if config.autotune_in_subproc: + # Collective ops must use current process + if is_collective or not config.autotune_in_subproc: return functools.partial( - cls.benchmark_in_sub_process, + cls.benchmark_in_current_process, input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, hint_override=hint_override, + is_collective=is_collective, ) else: return functools.partial( - cls.benchmark_in_current_process, + cls.benchmark_in_sub_process, input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, @@ -3816,8 +4026,26 @@ def log_results( precompile_elapse: float, prescreening_elapse: Optional[float] = None, hint_override: Optional[int] = None, + is_collective: bool = False, ): - """Log the autotuning results, currently only handles mm and flex""" + """Log the autotuning results, currently only handles mm and flex. Log Collective op autotuning result""" + if is_collective and timings: + import torch.distributed as dist + + # Only rank 0 logs to avoid duplicate logs + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + best_choice = min(timings, key=timings.__getitem__) + log.warning("[COLLECTIVE AUTOTUNING] All timings:") + for c, t in sorted(timings.items(), key=lambda x: x[1]): + choice_name = getattr(c, "name", str(c)) + log.warning( + " - %s: %.6f ms %s", + choice_name, + t if math.isfinite(t) else float("inf"), + "← SELECTED" if c == best_choice else "", + ) + V.debug.log_autotuning_results( name, input_nodes, timings, elapse, precompile_elapse ) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 4d1ddc9ad4769..d7f3844cdf1ba 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -4126,3 +4126,24 @@ def should_fallback_by_default(node: torch.fx.Node) -> bool: return target in fallback_hops return not _needs_inductor_compile(node) + + +# Collective operation names for specialized benchmarking +COLLECTIVE_OPS = OrderedSet( + [ + "torch.ops._c10d_functional.all_reduce.default", + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.all_gather_into_tensor.default", + "torch.ops._c10d_functional.reduce_scatter_tensor.default", + "torch.ops._c10d_functional.all_to_all_single.default", + "torch.ops._c10d_functional_autograd.all_reduce.default", + "torch.ops._c10d_functional_autograd.all_gather_into_tensor.default", + "torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default", + "torch.ops._c10d_functional_autograd.all_to_all_single.default", + ] +) + + +def is_collective_op(op_name: str) -> bool: + """Check if an operation is a collective operation.""" + return op_name in COLLECTIVE_OPS From 2e0c2e170fe658c440775c8e5c44228aafcc47ec Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 3 Dec 2025 18:59:35 +0000 Subject: [PATCH 228/338] [xpu][feature] [1/2] Introduce XPUPluggableAllocator in cpp part (#168966) # Motivation This PR aims to introduce `XPUPluggableAllocator` and we make it as simple as possible. The follow-up PR would introduce the code related to the Python frontend part. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168966 Approved by: https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/eellison --- build_variables.bzl | 1 + c10/xpu/XPUCachingAllocator.cpp | 66 +++------- c10/xpu/XPUCachingAllocator.h | 41 +++++-- torch/csrc/xpu/XPUPluggableAllocator.cpp | 147 +++++++++++++++++++++++ torch/csrc/xpu/XPUPluggableAllocator.h | 80 ++++++++++++ 5 files changed, 279 insertions(+), 56 deletions(-) create mode 100644 torch/csrc/xpu/XPUPluggableAllocator.cpp create mode 100644 torch/csrc/xpu/XPUPluggableAllocator.h diff --git a/build_variables.bzl b/build_variables.bzl index ba856c5a97ba4..25f167191ab60 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -875,6 +875,7 @@ libtorch_python_xpu_sources = [ "torch/csrc/xpu/Event.cpp", "torch/csrc/xpu/Module.cpp", "torch/csrc/xpu/Stream.cpp", + "torch/csrc/xpu/XPUPluggableAllocator.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_xpu.cpp", "torch/csrc/inductor/aoti_torch/shim_xpu.cpp", ] diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index dfcccc94c9e32..92dffc9153977 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1353,7 +1353,7 @@ class NativeCachingAllocator : public XPUAllocator { public: std::vector> device_allocators; - void init(DeviceIndex device_count) { + void init(DeviceIndex device_count) override { const auto size = static_cast(device_allocators.size()); if (size < device_count) { device_allocators.resize(device_count); @@ -1538,88 +1538,62 @@ class NativeCachingAllocator : public XPUAllocator { } }; -static NativeCachingAllocator allocator; +static NativeCachingAllocator native_allocator; void local_raw_delete(void* ptr) { - allocator.free(ptr); + native_allocator.free(ptr); } -Allocator* get() { - return &allocator; -} - -void init(DeviceIndex device_count) { - return allocator.init(device_count); -} - -void emptyCache(MempoolId_t mempool_id) { - return allocator.emptyCache(mempool_id); -} - -void resetPeakStats(DeviceIndex device) { - return allocator.resetPeakStats(device); -} - -void resetAccumulatedStats(DeviceIndex device) { - return allocator.resetAccumulatedStats(device); -} +std::atomic allocator; -DeviceStats getDeviceStats(DeviceIndex device) { - return allocator.getDeviceStats(device); -} - -void* raw_alloc(size_t size) { - return allocator.raw_alloc(size); -} - -void raw_delete(void* ptr) { - return allocator.raw_delete(ptr); -} +struct NativeAllocatorStaticInitializer { + NativeAllocatorStaticInitializer() { + allocator.store(&native_allocator); + c10::SetAllocator(c10::kXPU, &native_allocator, 0); + } +}; -void recordStream(const DataPtr& dataPtr, XPUStream stream) { - return allocator.recordStream(dataPtr, stream); -} +static NativeAllocatorStaticInitializer native_allocator_static_initializer; void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) { - return allocator.enablePeerAccess(dev, dev_to_access); + return native_allocator.enablePeerAccess(dev, dev_to_access); } double getMemoryFraction(DeviceIndex device) { - return allocator.getMemoryFraction(device); + return native_allocator.getMemoryFraction(device); } void setMemoryFraction(double fraction, DeviceIndex device) { - return allocator.setMemoryFraction(fraction, device); + return native_allocator.setMemoryFraction(fraction, device); } void createOrIncrefPool( c10::DeviceIndex device, MempoolId_t mempool_id, XPUAllocator* allocator_ptr) { - return allocator.createOrIncrefPool(device, mempool_id, allocator_ptr); + return native_allocator.createOrIncrefPool(device, mempool_id, allocator_ptr); } void beginAllocateToPool( c10::DeviceIndex device, MempoolId_t mempool_id, std::function filter) { - return allocator.beginAllocateToPool(device, mempool_id, std::move(filter)); + return native_allocator.beginAllocateToPool( + device, mempool_id, std::move(filter)); } void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { - return allocator.endAllocateToPool(device, mempool_id); + return native_allocator.endAllocateToPool(device, mempool_id); } void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { - return allocator.releasePool(device, mempool_id); + return native_allocator.releasePool(device, mempool_id); } int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { - return allocator.getPoolUseCount(device, mempool_id); + return native_allocator.getPoolUseCount(device, mempool_id); } -REGISTER_ALLOCATOR(kXPU, &allocator) - } // namespace c10::xpu::XPUCachingAllocator namespace c10::xpu { diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index 0054e359e77fe..54c7387cc3897 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -8,28 +8,49 @@ namespace c10::xpu::XPUCachingAllocator { class XPUAllocator : public DeviceAllocator { public: + virtual void init(c10::DeviceIndex device_count) = 0; virtual void* raw_alloc(size_t nbytes) = 0; virtual void raw_delete(void* ptr) = 0; }; -C10_XPU_API Allocator* get(); +C10_XPU_API extern std::atomic allocator; -C10_XPU_API void init(DeviceIndex device_count); +inline XPUAllocator* get() { + return allocator.load(); +} -C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0}); +inline void init(c10::DeviceIndex device_count) { + get()->init(device_count); +} -C10_XPU_API void resetPeakStats(DeviceIndex device); +inline void emptyCache(MempoolId_t mempool_id = {0, 0}) { + get()->emptyCache(mempool_id); +} -C10_XPU_API void resetAccumulatedStats(DeviceIndex device); +inline void resetPeakStats(DeviceIndex device) { + get()->resetPeakStats(device); +} -C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats( - DeviceIndex device); +inline void resetAccumulatedStats(DeviceIndex device) { + get()->resetAccumulatedStats(device); +} -C10_XPU_API void* raw_alloc(size_t size); +inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + DeviceIndex device) { + return get()->getDeviceStats(device); +} -C10_XPU_API void raw_delete(void* ptr); +inline void* raw_alloc(size_t size) { + return get()->raw_alloc(size); +} -C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream); +inline void raw_delete(void* ptr) { + get()->raw_delete(ptr); +} + +inline void recordStream(const DataPtr& dataPtr, XPUStream stream) { + get()->recordStream(dataPtr, stream); +} C10_XPU_API void enablePeerAccess( c10::DeviceIndex dev, diff --git a/torch/csrc/xpu/XPUPluggableAllocator.cpp b/torch/csrc/xpu/XPUPluggableAllocator.cpp new file mode 100644 index 0000000000000..6534ac94f159d --- /dev/null +++ b/torch/csrc/xpu/XPUPluggableAllocator.cpp @@ -0,0 +1,147 @@ +#include + +namespace torch::xpu::XPUPluggableAllocator { + +void custom_raw_deleter(void* ptr); + +static c10::DeviceIndex device_count_ = 0; + +void* XPUPluggableAllocator::malloc( + size_t size, + c10::DeviceIndex device, + sycl::queue* queue) { + void* r = alloc_fn_(size, device, queue); + { + const std::lock_guard lock(allocator_mutex_); + allocation_metadata_.emplace(r, _AllocationMetadata(size, device, queue)); + } + return r; +} + +c10::DataPtr XPUPluggableAllocator::allocate(size_t size) { + auto device = c10::xpu::current_device(); + sycl::queue& queue = c10::xpu::getCurrentXPUStream(device); + void* r = this->malloc(size, device, &queue); + return {r, r, raw_deleter(), c10::Device(c10::kXPU, device)}; +} + +void* XPUPluggableAllocator::raw_alloc(size_t nbytes) { + auto device = c10::xpu::current_device(); + sycl::queue& queue = c10::xpu::getCurrentXPUStream(device); + return malloc(nbytes, device, &queue); +} + +c10::DeleterFnPtr XPUPluggableAllocator::raw_deleter() const { + return &custom_raw_deleter; +} + +void XPUPluggableAllocator::raw_delete(void* ptr) { + sycl::queue* queue = nullptr; + c10::DeviceIndex device_idx = -1; + size_t size = 0; + { + const std::lock_guard lock(allocator_mutex_); + TORCH_CHECK( + allocation_metadata_.count(ptr), + "Trying to free a pointer not allocated here"); + _AllocationMetadata& metadata = allocation_metadata_[ptr]; + size = metadata.size; + device_idx = metadata.device_idx; + queue = metadata.queue; + allocation_metadata_.erase(ptr); + } + free_fn_(ptr, size, device_idx, queue); +} + +void XPUPluggableAllocator::init(c10::DeviceIndex device_count) { + if (init_fn_) { + init_fn_(device_count); + } + device_count_ = device_count; + initialized_ = true; +} + +bool XPUPluggableAllocator::initialized() { + return initialized_; +} + +void XPUPluggableAllocator::copy_data( + void* dest, + const void* src, + std::size_t count) const { + c10::xpu::getCurrentXPUStream().queue().memcpy(dest, src, count); +} + +void XPUPluggableAllocator::recordStream( + const c10::DataPtr& ptr, + c10::Stream stream) { + if (record_stream_fn_) { + auto xpu_stream = c10::xpu::XPUStream(stream); + record_stream_fn_(ptr.get(), &xpu_stream.queue()); + } +} + +void XPUPluggableAllocator::emptyCache( + /*unused*/ c10::MempoolId_t mempool_id) { + TORCH_CHECK( + false, + "XPUPluggableAllocator does not yet support emptyCache. " + "If you need it, please file an issue describing your use case."); +} + +c10::CachingDeviceAllocator::DeviceStats XPUPluggableAllocator::getDeviceStats( + c10::DeviceIndex device) { + TORCH_CHECK( + false, + "XPUPluggableAllocator does not yet support getDeviceStats. " + "If you need it, please file an issue describing your use case."); +} + +void XPUPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) { + TORCH_CHECK( + false, + "XPUPluggableAllocator does not yet support resetAccumulatedStats. " + "If you need it, please file an issue describing your use case."); +} + +void XPUPluggableAllocator::resetPeakStats(c10::DeviceIndex device) { + TORCH_CHECK( + false, + "XPUPluggableAllocator does not yet support resetPeakStats. " + "If you need it, please file an issue describing your use case."); +} + +std::shared_ptr + current_custom_allocator; + +std::shared_ptr +getCurrentAllocator() { + return current_custom_allocator; +} + +std::shared_ptr +createCustomAllocator( + std::function alloc_fn, + std::function free_fn) { + auto allocator = std::make_shared( + std::move(alloc_fn), std::move(free_fn)); + allocator->init(device_count_); + return allocator; +} + +void changeCurrentAllocator( + const std::shared_ptr& + allocator) { + TORCH_CHECK( + !c10::xpu::XPUCachingAllocator::get()->initialized(), + "Can't swap an already initialized allocator"); + c10::xpu::XPUCachingAllocator::allocator.store(allocator.get()); + c10::SetAllocator(c10::kXPU, allocator.get()); + current_custom_allocator = allocator; +} + +void custom_raw_deleter(void* ptr) { + current_custom_allocator->raw_delete(ptr); +} + +} // namespace torch::xpu::XPUPluggableAllocator diff --git a/torch/csrc/xpu/XPUPluggableAllocator.h b/torch/csrc/xpu/XPUPluggableAllocator.h new file mode 100644 index 0000000000000..5133955c58876 --- /dev/null +++ b/torch/csrc/xpu/XPUPluggableAllocator.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include + +namespace torch::xpu::XPUPluggableAllocator { + +struct _AllocationMetadata { + _AllocationMetadata() {} + _AllocationMetadata( + size_t size, + c10::DeviceIndex device_idx, + sycl::queue* queue) + : size(size), device_idx(device_idx), queue(queue) {} + size_t size{0}; + c10::DeviceIndex device_idx{-1}; + sycl::queue* queue{}; +}; + +struct TORCH_PYTHON_API XPUPluggableAllocator + : public c10::xpu::XPUCachingAllocator::XPUAllocator { + XPUPluggableAllocator( + std::function alloc_fn, + std::function free_fn) + : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {} + + C10_DISABLE_COPY_AND_ASSIGN(XPUPluggableAllocator); + + ~XPUPluggableAllocator() override = default; + + void* malloc(size_t size, c10::DeviceIndex device, sycl::queue* stream); + + c10::DataPtr allocate(size_t size) override; + c10::DeleterFnPtr raw_deleter() const override; + + void* raw_alloc(size_t nbytes) override; + void raw_delete(void* ptr) override; + void init(c10::DeviceIndex device_count) override; + bool initialized() override; + void copy_data(void* dest, const void* src, std::size_t count) const final; + + void recordStream(const c10::DataPtr&, c10::Stream stream) override; + void emptyCache(c10::MempoolId_t mempool_id = {0, 0}) override; + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) override; + void resetAccumulatedStats(c10::DeviceIndex device) override; + void resetPeakStats(c10::DeviceIndex device) override; + + void set_init_fn(std::function init_fn) { + init_fn_ = std::move(init_fn); + } + void set_record_stream_fn( + std::function record_stream_fn) { + record_stream_fn_ = std::move(record_stream_fn); + } + + protected: + std::function alloc_fn_; + std::function free_fn_; + std::function init_fn_; + std::function record_stream_fn_; + std::mutex allocator_mutex_; + // We do the bookkeeping here in order to simplify custom allocators + std::unordered_map allocation_metadata_; + bool initialized_ = false; +}; + +TORCH_XPU_API std::shared_ptr +getCurrentAllocator(); + +TORCH_XPU_API std::shared_ptr +createCustomAllocator( + std::function alloc_fn, + std::function free_fn); + +TORCH_XPU_API void changeCurrentAllocator( + const std::shared_ptr& + allocator); + +} // namespace torch::xpu::XPUPluggableAllocator From 7b7af390ea8541c611d1ce2018a6934188fc197b Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 3 Dec 2025 12:08:59 -0600 Subject: [PATCH 229/338] [MPS] Migrate `lu_unpack` to Metal and fix `lu_solve` backward (#168120) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168120 Approved by: https://github.com/malfet ghstack dependencies: #167569 --- .../ATen/native/mps/kernels/LinearAlgebra.h | 6 ++ .../native/mps/kernels/LinearAlgebra.metal | 34 +++++++ .../native/mps/operations/LinearAlgebra.mm | 92 ++++++------------- aten/src/ATen/native/native_functions.yaml | 6 +- torch/testing/_internal/common_mps.py | 8 -- .../_internal/opinfo/definitions/linalg.py | 2 +- 6 files changed, 71 insertions(+), 77 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.h b/aten/src/ATen/native/mps/kernels/LinearAlgebra.h index e50753122028c..ff053de15377a 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.h +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.h @@ -14,3 +14,9 @@ struct OrgqrParams { ::c10::metal::array H_strides; ::c10::metal::array H_sizes; }; + +struct UnpackPivotsParams { + uint32_t perm_batch_stride; + uint32_t pivots_batch_stride; + uint32_t dim_size; +}; diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal index ecb2ddefd1fc1..e48d2c62cb02d 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal @@ -801,6 +801,27 @@ kernel void orgqr( } } +template +kernel void unpack_pivots( + device TO* perm [[buffer(0)]], + constant TI* pivots [[buffer(1)]], + constant UnpackPivotsParams& params [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + auto perm_batch_stride = params.perm_batch_stride; + auto pivots_batch_stride = params.pivots_batch_stride; + auto dim_size = params.dim_size; + + perm += perm_batch_stride * tid; + pivots += pivots_batch_stride * tid; + + for (uint32_t i = 0; i < dim_size; i++) { + auto j = pivots[i] - 1; + auto perm_j = perm[j]; + perm[j] = perm[i]; + perm[i] = perm_j; + } +} + #define INSTANTIATE_MM_OPS(DTYPE) \ template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ constant DTYPE * mat1Data [[buffer(0)]], \ @@ -860,3 +881,16 @@ REGISTER_ORGQR(half); REGISTER_ORGQR(bfloat); REGISTER_ORGQR(float2); REGISTER_ORGQR(half2); + +#define REGISTER_UNPACK_PIVOTS(TO, TI) \ + template [[host_name("unpack_pivots_" #TO "_" #TI)]] \ + kernel void unpack_pivots( \ + device TO * perm [[buffer(0)]], \ + constant TI * pivots [[buffer(1)]], \ + constant UnpackPivotsParams & params [[buffer(2)]], \ + uint tid [[thread_position_in_grid]]); + +REGISTER_UNPACK_PIVOTS(int, int); +REGISTER_UNPACK_PIVOTS(int, long); +REGISTER_UNPACK_PIVOTS(long, int); +REGISTER_UNPACK_PIVOTS(long, long); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index c6d766f92f2b0..d895382c660ef 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -1145,52 +1146,39 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const return out; } -static void lu_unpack_mps_impl(const Tensor& LU_data, - const Tensor& LU_pivots, - bool unpack_data, - bool unpack_pivots, - const Tensor& P, - const Tensor& L, - const Tensor& U) { - const auto ndim = LU_data.dim(); - TORCH_CHECK(ndim >= 2, "LU_data must have at least 2 dimensions"); - - const auto r = LU_data.size(-2); - const auto c = LU_data.size(-1); - const auto k = std::min(r, c); - - const auto batchSize = c10::multiply_integers(LU_data.sizes().begin(), LU_data.sizes().end() - 2); - - if (unpack_data) { - Tensor L_part = r < c ? slice(LU_data, -1, 0, k) : LU_data; - L.copy_(L_part.tril()); - (ndim == 2 ? L.diagonal() : L.diagonal(0, -2, -1)).fill_(1); - - Tensor U_part = r < c ? LU_data : slice(LU_data, -2, 0, k); - U.copy_(U_part.triu()); +static void unpack_pivots_stub_impl(TensorIterator& iter, const int64_t dim_size, const int64_t max_pivot) { + if (iter.numel() == 0 || dim_size == 0) { + return; } - if (unpack_pivots) { - // P as an identity matrix for pivots - P.fill_(0); - LU_pivots.dim() == 1 ? P.diagonal().fill_(1) : P.diagonal(0, -2, -1).fill_(1); + auto perm = iter.tensor(0); + auto pivots = iter.tensor(1); + + // TODO: Perhaps this should be disabled since it requires a sync? + TORCH_CHECK_TENSOR_ALL(pivots.le(max_pivot).logical_and(pivots.ge(1)), + "pivots passed to lu_unpack must be between 1 and LU.size(-2) inclusive." + "Did you properly pass the result of lu_factor?"); - auto stream = getCurrentMPSStream(); - auto device = MPSDevice::getInstance()->device(); - auto applyPivotsPSO = lib.getPipelineStateForFunc("applyPivots"); - uint32_t maxThreadsPerGroup = [applyPivotsPSO maxTotalThreadsPerThreadgroup]; + auto num_threads = iter.numel(); + MPSStream* stream = getCurrentMPSStream(); - auto pivots = (LU_pivots.dim() == 1) ? LU_pivots.sub(1) : LU_pivots.view({batchSize, -1}).sub(1); + UnpackPivotsParams params; + params.perm_batch_stride = safe_downcast((perm.dim() > 1) ? perm.stride(-2) : 0); + params.pivots_batch_stride = safe_downcast((pivots.dim() > 1) ? pivots.stride(-2) : 0); + params.dim_size = safe_downcast(dim_size); + dispatch_sync_with_rethrow(stream->queue(), ^() { @autoreleasepool { - dispatch_sync_with_rethrow(stream->queue(), ^() { - auto computeEncoder = stream->commandEncoder(); - mtl_setArgs(computeEncoder, P, pivots, r, k); - [computeEncoder setComputePipelineState:applyPivotsPSO]; - mtl_dispatch1DJob(computeEncoder, applyPivotsPSO, batchSize * maxThreadsPerGroup); - }); + id compute_encoder = stream->commandEncoder(); + auto pipeline_state = lib.getPipelineStateForFunc( + fmt::format("unpack_pivots_{}_{}", scalarToMetalTypeString(perm), scalarToMetalTypeString(pivots))); + getMPSProfiler().beginProfileKernel(pipeline_state, "unpack_pivots", {pivots}); + [compute_encoder setComputePipelineState:pipeline_state]; + mtl_setArgs(compute_encoder, perm, pivots, params); + mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads); + getMPSProfiler().endProfileKernel(pipeline_state); } - } + }); } static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper) { @@ -1527,41 +1515,17 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info); } -TORCH_IMPL_FUNC(lu_unpack_out_mps) -(const Tensor& LU_data, - const Tensor& LU_pivots, - bool unpack_data, - bool unpack_pivots, - const Tensor& P, - const Tensor& L, - const Tensor& U) { - mps::lu_unpack_mps_impl(LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U); -} - TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps) (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) { mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors); } -TORCH_IMPL_FUNC(linalg_lu_out_mps)(const Tensor& A, bool pivot, const Tensor& P, const Tensor& L, const Tensor& U) { - Tensor LU = at::empty({0}, A.scalar_type(), std::nullopt, kMPS, std::nullopt, MemoryFormat::Contiguous); - auto pivots = at::empty({0}, A.options().dtype(kInt)); - auto info = at::empty({0}, A.options().dtype(kInt)); - mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, /*check_errors=*/false); - at::lu_unpack_out(const_cast(P), - const_cast(L), - const_cast(U), - LU, - pivots, - /*unpack_data=*/true, - /*unpack_pivots=*/pivot); -} - TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info); } REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl) +REGISTER_DISPATCH(unpack_pivots_stub, mps::unpack_pivots_stub_impl) REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl); } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e6b5bfcd18727..1759951b68bdc 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9756,8 +9756,7 @@ variants: function structured: True dispatch: - CPU, CUDA: lu_unpack_out - MPS: lu_unpack_out_mps + CPU, CUDA, MPS: lu_unpack_out # TODO: remove dispatch section when porting TH CUDA to ATen - func: multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) @@ -14226,8 +14225,7 @@ variants: function structured: True dispatch: - CPU, CUDA: linalg_lu_out - MPS: linalg_lu_out_mps + CPU, CUDA, MPS: linalg_lu_out # linalg.lu_solve - func: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index 2b1f1be0e02f9..cedd0c92b6a4d 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -697,14 +697,6 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: torch.float16, torch.float32, ], # missing `aten::lu_solve`. - # `linalg.lu_solve`'s backward pass for the `LU` arg calls - # `lu_unpack`, and pivots are unpacked if `left == adjoint`. When - # unpacking pivots, `lu_unpack` incorrectly raises an error if - # `pivots.shape` is zero in any of the batch dims and the last dim - # is greater than 1. - "linalg.lu_solve": None, - # lu_solve only fails on MacOS 14 for some reason - "lu_solve": None if MACOS_VERSION < 15.0 else [], "linalg.tensorsolve": [ torch.float16, torch.float32, diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index 95cb59df0fcb4..87071c439f8e0 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -1077,7 +1077,7 @@ def out_fn(output): else: return output - batch_shapes = ((), (3,), (3, 3)) + batch_shapes = ((), (3,), (3, 3), (0,)) # pivot=False only supported in CUDA pivots = (True, False) if torch.device(device).type == "cuda" else (True,) deltas = (-2, -1, 0, +1, +2) From 02d8bd6974cf84b721680d773dbdb1b6f40ce272 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Wed, 3 Dec 2025 16:21:36 -0800 Subject: [PATCH 230/338] [pallas backend] Add special function support to Pallas backend (#169422) With this change, we are at 362 failed and 616 passed tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/169422 Approved by: https://github.com/yarongmu-google, https://github.com/yf225 --- test/inductor/test_torchinductor.py | 18 ++ torch/_inductor/codegen/pallas.py | 315 ++++++++++++++++++++++++++++ 2 files changed, 333 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index d3585bdb1d317..c5bdab4135b0f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -929,6 +929,12 @@ def is_triton_cpu_backend(device): return getattr(device, "type", device) == "cpu" and config.cpu_backend == "triton" +def is_pallas_backend(device): + if getattr(device, "type", device) == "cpu": + return config.cpu_backend == "pallas" + return config.cuda_backend == "pallas" + + def skip_if_triton_cpu(fn): import types @@ -13259,6 +13265,18 @@ def test_pointwise(self, name, op): ]: raise unittest.SkipTest(f"Triton CPU does not support {name}") + if is_pallas_backend(self.device) and name in { + "airy_ai", + "bessel_y0", + "bessel_y1", + "modified_bessel_k0", + "modified_bessel_k1", + "ndtri", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + }: + raise unittest.SkipTest(f"Pallas does not support {name}") + if name in {"gammainc", "gammaincc"}: args = ( torch.randn(8, 8, dtype=dtype, device=self.device), diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 854adf3f53d34..0e97ae1f8f58d 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -1,6 +1,7 @@ from __future__ import annotations import hashlib +import math from typing import Any, Optional, TYPE_CHECKING, Union import sympy # noqa: TC002 @@ -226,6 +227,12 @@ def constant(val, dtype: torch.dtype) -> str: jax_dtype = torch_dtype_to_jax(dtype) if dtype == torch.bool: return "True" if val else "False" + # Handle special float values + if isinstance(val, float): + if math.isnan(val): + return "jnp.nan" + if math.isinf(val): + return "jnp.inf" if val > 0 else "-jnp.inf" return f"jnp.array({val}, dtype={jax_dtype})" @staticmethod @@ -275,6 +282,18 @@ def le(a: str, b: str) -> str: def gt(a: str, b: str) -> str: return f"({a} > {b})" + @staticmethod + def isnan(x: str) -> str: + return f"jnp.isnan({x})" + + @staticmethod + def isinf(x: str) -> str: + return f"jnp.isinf({x})" + + @staticmethod + def isfinite(x: str) -> str: + return f"jnp.isfinite({x})" + @staticmethod def ge(a: str, b: str) -> str: return f"({a} >= {b})" @@ -351,6 +370,302 @@ def lgamma(x: str) -> str: def digamma(x: str) -> str: return f"jax.scipy.special.digamma({x})" + @staticmethod + def bessel_j0(x: str) -> str: + # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN) + # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n + # Handle by: convert to float64, compute, handle x=0, convert back + # J0(0) = 1.0 + return ( + f"jnp.where({x}.astype(jnp.float64) == 0.0, 1.0, " + f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=0)[0])" + f".astype({x}.dtype)" + ) + + @staticmethod + def bessel_j1(x: str) -> str: + # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN) + # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n + # Handle by: convert to float64, compute, handle x=0, convert back + # J1(0) = 0.0 + return ( + f"jnp.where({x}.astype(jnp.float64) == 0.0, 0.0, " + f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=1)[1])" + f".astype({x}.dtype)" + ) + + @staticmethod + def modified_bessel_i0(x: str) -> str: + # Modified Bessel function of the first kind I_0(x) + # I_0(x) = bessel_i0e(x) * exp(|x|) where bessel_i0e is the scaled version + return f"jax.lax.bessel_i0e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def modified_bessel_i1(x: str) -> str: + # Modified Bessel function of the first kind I_1(x) + # I_1(x) = bessel_i1e(x) * exp(|x|) where bessel_i1e is the scaled version + return f"jax.lax.bessel_i1e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def spherical_bessel_j0(x: str) -> str: + # Spherical Bessel function of the first kind j_0(x) = sin(x) / x + # Handle x=0: j_0(0) = 1 + return f"jnp.where({x} == 0.0, 1.0, jnp.sin({x}) / {x})" + + @staticmethod + def i0(x: str) -> str: + # Modified Bessel function I_0 (same as modified_bessel_i0) + return f"jax.lax.bessel_i0e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def i0e(x: str) -> str: + # Exponentially scaled modified Bessel function I_0 + return f"jax.lax.bessel_i0e({x})" + + @staticmethod + def i1(x: str) -> str: + # Modified Bessel function I_1 (same as modified_bessel_i1) + return f"jax.lax.bessel_i1e({x}) * jnp.exp(jnp.abs({x}))" + + @staticmethod + def i1e(x: str) -> str: + # Exponentially scaled modified Bessel function I_1 + return f"jax.lax.bessel_i1e({x})" + + @staticmethod + def gammainc(x: str, y: str) -> str: + # Regularized lower incomplete gamma function P(a, x) + # Note: PyTorch uses gammainc(input, other) where input is a (shape param) + return f"jax.scipy.special.gammainc({x}, {y})" + + @staticmethod + def gammaincc(x: str, y: str) -> str: + # Regularized upper incomplete gamma function Q(a, x) + return f"jax.scipy.special.gammaincc({x}, {y})" + + @staticmethod + def igamma(x: str, y: str) -> str: + # Regularized lower incomplete gamma function (alias for gammainc) + return f"jax.scipy.special.gammainc({x}, {y})" + + @staticmethod + def igammac(x: str, y: str) -> str: + # Regularized upper incomplete gamma function (alias for gammaincc) + return f"jax.scipy.special.gammaincc({x}, {y})" + + @staticmethod + def polygamma(x: str, y: str) -> str: + # Polygamma function psi^(n)(x), x is order n, y is the value + # Note: JAX uses polygamma(n, x) where n is integer order + return f"jax.scipy.special.polygamma({x}.astype(jnp.int32), {y})" + + @staticmethod + def ndtri(x: str) -> str: + # Inverse of the standard normal CDF + return f"jax.scipy.special.ndtri({x})" + + @staticmethod + def zeta(x: str, y: str) -> str: + # Hurwitz zeta function zeta(x, q) = sum_{k=0}^inf 1/(k+q)^x + return f"jax.scipy.special.zeta({x}, {y})" + + @staticmethod + def xlogy(x: str, y: str) -> str: + # x * log(y), with proper handling of x=0 + return f"jax.scipy.special.xlogy({x}, {y})" + + @staticmethod + def xlog1py(x: str, y: str) -> str: + # x * log1p(y), with proper handling of x=0 + return f"jax.scipy.special.xlog1py({x}, {y})" + + @staticmethod + def chebyshev_polynomial_t(x: str, n: str) -> str: + # Chebyshev polynomial of the first kind T_n(x) + # For |x| <= 1: T_n(x) = cos(n * arccos(x)) + # For x > 1: T_n(x) = cosh(n * arccosh(x)) + # For x < -1: T_n(x) = (-1)^n * cosh(n * arccosh(-x)) + return ( + f"jnp.where(jnp.abs({x}) <= 1, " + f"jnp.cos({n} * jnp.arccos(jnp.clip({x}, -1, 1))), " + f"jnp.where({x} > 1, " + f"jnp.cosh({n} * jnp.arccosh(jnp.maximum({x}, 1.0))), " + f"((-1.0) ** {n}) * jnp.cosh({n} * jnp.arccosh(jnp.maximum(-{x}, 1.0)))))" + ) + + @staticmethod + def chebyshev_polynomial_u(x: str, n: str) -> str: + # Chebyshev polynomial of the second kind U_n(x) + # For |x| < 1: U_n(x) = sin((n+1) * arccos(x)) / sqrt(1 - x^2) + # For x = 1: U_n(1) = n+1 + # For x = -1: U_n(-1) = (-1)^n * (n+1) + # For x > 1: U_n(x) = sinh((n+1) * arccosh(x)) / sqrt(x^2 - 1) + # For x < -1: U_n(x) = (-1)^n * U_n(-x) (symmetry) + return ( + f"jnp.where(jnp.abs({x}) < 1, " + f"jnp.sin(({n} + 1) * jnp.arccos(jnp.clip({x}, -1, 1))) / " + f"jnp.sqrt(jnp.maximum(1 - {x}**2, 1e-10)), " + f"jnp.where({x} >= 1, " + f"jnp.where({x} == 1, {n} + 1.0, " + f"jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum({x}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10))), " + f"jnp.where({x} == -1, ((-1.0) ** {n}) * ({n} + 1.0), " + f"((-1.0) ** {n}) * jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum(-{x}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10)))))" + ) + + @staticmethod + def chebyshev_polynomial_v(x: str, n: str) -> str: + # Chebyshev polynomial of the third kind V_n(x) + # V_n(x) = (T_n(x) - T_{n+1}(x)) / (1 - x) for x != 1 + # V_n(1) = 1, recurrence: V_0 = 1, V_1 = 2x - 1, V_n = 2x*V_{n-1} - V_{n-2} + # Explicit: V_0 = 1, V_1 = 2x-1, V_2 = 4x^2-2x-1, V_3 = 8x^3-4x^2-4x+1 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{x} - 1, " + f"jnp.where({n} == 2, 4*{x}**2 - 2*{x} - 1, " + f"jnp.where({n} == 3, 8*{x}**3 - 4*{x}**2 - 4*{x} + 1, " + f"jnp.where({n} == 4, 16*{x}**4 - 8*{x}**3 - 12*{x}**2 + 4*{x} + 1, " + f"jnp.where({n} == 5, 32*{x}**5 - 16*{x}**4 - 32*{x}**3 + 12*{x}**2 + 6*{x} - 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def chebyshev_polynomial_w(x: str, n: str) -> str: + # Chebyshev polynomial of the fourth kind W_n(x) + # W_n(x) = (T_n(x) + T_{n+1}(x)) / (1 + x) for x != -1 + # W_n(-1) = (-1)^n, recurrence: W_0 = 1, W_1 = 2x + 1, W_n = 2x*W_{n-1} - W_{n-2} + # Explicit: W_0 = 1, W_1 = 2x+1, W_2 = 4x^2+2x-1, W_3 = 8x^3+4x^2-4x-1 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{x} + 1, " + f"jnp.where({n} == 2, 4*{x}**2 + 2*{x} - 1, " + f"jnp.where({n} == 3, 8*{x}**3 + 4*{x}**2 - 4*{x} - 1, " + f"jnp.where({n} == 4, 16*{x}**4 + 8*{x}**3 - 12*{x}**2 - 4*{x} + 1, " + f"jnp.where({n} == 5, 32*{x}**5 + 16*{x}**4 - 32*{x}**3 - 12*{x}**2 + 6*{x} + 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_t(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the first kind T*_n(x) = T_n(2x - 1) + # T_n(y) where y = 2x - 1 + # Use same formula as chebyshev_polynomial_t + y = f"(2 * {x} - 1)" + return ( + f"jnp.where(jnp.abs({y}) <= 1, " + f"jnp.cos({n} * jnp.arccos(jnp.clip({y}, -1, 1))), " + f"jnp.where({y} > 1, " + f"jnp.cosh({n} * jnp.arccosh(jnp.maximum({y}, 1.0))), " + f"((-1.0) ** {n}) * jnp.cosh({n} * jnp.arccosh(jnp.maximum(-{y}, 1.0)))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_u(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the second kind U*_n(x) = U_n(2x - 1) + # Use same formula as chebyshev_polynomial_u + y = f"(2 * {x} - 1)" + return ( + f"jnp.where(jnp.abs({y}) < 1, " + f"jnp.sin(({n} + 1) * jnp.arccos(jnp.clip({y}, -1, 1))) / " + f"jnp.sqrt(jnp.maximum(1 - ({y})**2, 1e-10)), " + f"jnp.where({y} >= 1, " + f"jnp.where({y} == 1, {n} + 1.0, " + f"jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum({y}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({y}**2 - 1, 1e-10))), " + f"jnp.where({y} == -1, ((-1.0) ** {n}) * ({n} + 1.0), " + f"((-1.0) ** {n}) * jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum(-{y}, 1.0))) / " + f"jnp.sqrt(jnp.maximum({y}**2 - 1, 1e-10)))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_v(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the third kind V*_n(x) = V_n(2x - 1) + y = f"(2 * {x} - 1)" # shifted variable + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{y} - 1, " + f"jnp.where({n} == 2, 4*{y}**2 - 2*{y} - 1, " + f"jnp.where({n} == 3, 8*{y}**3 - 4*{y}**2 - 4*{y} + 1, " + f"jnp.where({n} == 4, 16*{y}**4 - 8*{y}**3 - 12*{y}**2 + 4*{y} + 1, " + f"jnp.where({n} == 5, 32*{y}**5 - 16*{y}**4 - 32*{y}**3 + 12*{y}**2 + 6*{y} - 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def shifted_chebyshev_polynomial_w(x: str, n: str) -> str: + # Shifted Chebyshev polynomial of the fourth kind W*_n(x) = W_n(2x - 1) + y = f"(2 * {x} - 1)" # shifted variable + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2*{y} + 1, " + f"jnp.where({n} == 2, 4*{y}**2 + 2*{y} - 1, " + f"jnp.where({n} == 3, 8*{y}**3 + 4*{y}**2 - 4*{y} - 1, " + f"jnp.where({n} == 4, 16*{y}**4 + 8*{y}**3 - 12*{y}**2 - 4*{y} + 1, " + f"jnp.where({n} == 5, 32*{y}**5 + 16*{y}**4 - 32*{y}**3 - 12*{y}**2 + 6*{y} + 1, " + f"jnp.zeros_like({x})))))))" + ) + + @staticmethod + def hermite_polynomial_h(x: str, n: str) -> str: + # Physicist's Hermite polynomial H_n(x) + # H_n(x) = 2^n * x^n - n*(n-1)/2 * 2^(n-2) * x^(n-2) + ... + # Use explicit formula: H_n(x) = n! * sum_{m=0}^{n//2} (-1)^m / (m! * (n-2m)!) * (2x)^(n-2m) + # For simplicity, use the relation: H_n(x) = 2^(n/2) * He_n(x * sqrt(2)) where He is probabilist's + # Actually simpler: use recurrence or closed form + # H_0 = 1, H_1 = 2x, H_2 = 4x^2 - 2, H_3 = 8x^3 - 12x + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 2 * {x}, " + f"jnp.where({n} == 2, 4 * {x}**2 - 2, " + f"jnp.where({n} == 3, 8 * {x}**3 - 12 * {x}, " + f"jnp.where({n} == 4, 16 * {x}**4 - 48 * {x}**2 + 12, " + f"jnp.where({n} == 5, 32 * {x}**5 - 160 * {x}**3 + 120 * {x}, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + + @staticmethod + def hermite_polynomial_he(x: str, n: str) -> str: + # Probabilist's Hermite polynomial He_n(x) + # He_0 = 1, He_1 = x, He_2 = x^2 - 1, He_3 = x^3 - 3x + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, {x}, " + f"jnp.where({n} == 2, {x}**2 - 1, " + f"jnp.where({n} == 3, {x}**3 - 3 * {x}, " + f"jnp.where({n} == 4, {x}**4 - 6 * {x}**2 + 3, " + f"jnp.where({n} == 5, {x}**5 - 10 * {x}**3 + 15 * {x}, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + + @staticmethod + def laguerre_polynomial_l(x: str, n: str) -> str: + # Laguerre polynomial L_n(x) + # L_0 = 1, L_1 = 1 - x, L_2 = (x^2 - 4x + 2)/2, L_3 = (-x^3 + 9x^2 - 18x + 6)/6 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, 1 - {x}, " + f"jnp.where({n} == 2, ({x}**2 - 4*{x} + 2) / 2, " + f"jnp.where({n} == 3, (-{x}**3 + 9*{x}**2 - 18*{x} + 6) / 6, " + f"jnp.where({n} == 4, ({x}**4 - 16*{x}**3 + 72*{x}**2 - 96*{x} + 24) / 24, " + f"jnp.where({n} == 5, (-{x}**5 + 25*{x}**4 - 200*{x}**3 + 600*{x}**2 - 600*{x} + 120) / 120, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + + @staticmethod + def legendre_polynomial_p(x: str, n: str) -> str: + # Legendre polynomial P_n(x) + # P_0 = 1, P_1 = x, P_2 = (3x^2 - 1)/2, P_3 = (5x^3 - 3x)/2 + return ( + f"jnp.where({n} == 0, jnp.ones_like({x}), " + f"jnp.where({n} == 1, {x}, " + f"jnp.where({n} == 2, (3 * {x}**2 - 1) / 2, " + f"jnp.where({n} == 3, (5 * {x}**3 - 3 * {x}) / 2, " + f"jnp.where({n} == 4, (35 * {x}**4 - 30 * {x}**2 + 3) / 8, " + f"jnp.where({n} == 5, (63 * {x}**5 - 70 * {x}**3 + 15 * {x}) / 8, " + f"jnp.zeros_like({x})))))))" # Fallback for higher n + ) + # Reciprocal and square @staticmethod def reciprocal(x: str) -> str: From d900f5e86745dec76713f4b0ef07005ef36b2f5a Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Thu, 4 Dec 2025 03:11:39 +0000 Subject: [PATCH 231/338] Update slow tests (#167967) This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml). Update the list of slow tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167967 Approved by: https://github.com/pytorchbot --- test/slow_tests.json | 510 +++++++++++++++++++++++-------------------- 1 file changed, 276 insertions(+), 234 deletions(-) diff --git a/test/slow_tests.json b/test/slow_tests.json index c027d3d1d0901..5f4a4934fd004 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,236 +1,278 @@ { - "EndToEndLSTM (__main__.RNNTest)": 190.48799641927084, - "MultiheadAttention (__main__.ModulesTest)": 141.2663370768229, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163, - "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578, - "test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547, - "test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844, - "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961, - "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619, - "test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869, - "test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669, - "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651, - "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238, - "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117, - "test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881, - "test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828, - "test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549, - "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568, - "test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942, - "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384, - "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164, - "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125, - "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984, - "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082, - "test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574, - "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125, - "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078, - "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134, - "test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844, - "test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711, - "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512, - "test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972, - "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582, - "test_count_nonzero_all (__main__.TestBool)": 631.2585144042969, - "test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461, - "test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766, - "test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614, - "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906, - "test_fail_random.py (__main__.TestTyping)": 74.83523060725285, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309, - "test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453, - "test_fuse_large_params_cuda (__main__.GPUTests)": 60.351999282836914, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625, - "test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582, - "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875, - "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458, - "test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558, - "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547, - "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234, - "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661, - "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188, - "test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707, - "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807, - "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604, - "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505, - "test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281, - "test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594, - "test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873, - "test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938, - "test_proper_exit (__main__.TestDataLoader)": 267.74214717320035, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 812.0880126953125, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619, - "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688, - "test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133, - "test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422, - "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656, - "test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586, - "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053, - "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816, - "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766, - "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247, - "test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281, - "test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566, - "test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164, - "test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516, - "test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702, - "test_terminate_signal (__main__.ForkTest)": 168.20485967184817, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573, - "test_terminate_signal (__main__.SpawnTest)": 172.16428443363733, - "test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331, - "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188, - "test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502, - "test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375, - "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172, - "test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672 + "EndToEndLSTM (__main__.RNNTest)": 195.11499938964843, + "MultiheadAttention (__main__.ModulesTest)": 142.00380249023436, + "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.6786651611328, + "test_RNN_cpu_vs_cudnn_no_dropout (__main__.TestNN)": 72.39199912548065, + "test_RNN_cpu_vs_cudnn_with_dropout (__main__.TestNN)": 73.05633429686229, + "test_StridedShard_to_shard_order (__main__.Test_StridedShard_with_shard_order)": 253.58512496948242, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 106.16550159454346, + "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.58166758219402, + "test_addmm_relu_tunableop_rocm_cuda_float32 (__main__.TestLinalgCUDA)": 62.60266876220703, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 62.53962421417236, + "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 177.9409942626953, + "test_aot_autograd_disable_functionalization_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.34580052693685, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.86476732889811, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 142.54474639892578, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 195.94950103759766, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 120.17424774169922, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.93349933624268, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 66.56851626980689, + "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 70.99724960327148, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 149.67525100708008, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 180.85475158691406, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 104.83274841308594, + "test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.97112907901887, + "test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.97749996185303, + "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 73.70349884033203, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 119.76774978637695, + "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 93.69075012207031, + "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 109.89175033569336, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 801.439599609375, + "test_avg_pool3d_backward2_cpu (__main__.CpuTritonTests)": 270.46433512369794, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 211.92539825439454, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 526.4229965209961, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 540.007625579834, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 73.37349891662598, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 146.07825088500977, + "test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 61.05500030517578, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 77.6555004119873, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 63.8514986038208, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 264.5168743133545, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 165.7322540283203, + "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 330.8664970397949, + "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 423.7527503967285, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 313.5642509460449, + "test_comprehensive_cholesky_inverse_cuda_float32 (__main__.TestDecompCUDA)": 70.66033256053925, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.57474899291992, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 103.05299949645996, + "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 67.24449920654297, + "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 68.60375022888184, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 105.27174758911133, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.67850112915039, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 458.8267517089844, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 451.6082458496094, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 298.8152503967285, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 255.6614990234375, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1176.4095153808594, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.6922492980957, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1098.8550109863281, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 77.52225112915039, + "test_comprehensive_linalg_lu_factor_cuda_complex128 (__main__.TestDecompCUDA)": 64.52633285522461, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.95650100708008, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.89800071716309, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.7504997253418, + "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 74.69425201416016, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 62.47725009918213, + "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 66.51850032806396, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 115.14674758911133, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 111.31599998474121, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 108.47875022888184, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.36350059509277, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 64.12074947357178, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 63.71774959564209, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 66.63899975731259, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 114.73800086975098, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 110.1662483215332, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 115.3847484588623, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 109.4905014038086, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 306.85575103759766, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 228.0407485961914, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 78.3700008392334, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 84.47775268554688, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 78.47249984741211, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 86.97974967956543, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 124.5634994506836, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 122.19799995422363, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1262.0645141601562, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1255.4177551269531, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1257.4462585449219, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 605.8682556152344, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 615.4145050048828, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 66.37674903869629, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 65.44024848937988, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 66.0570011138916, + "test_comprehensive_nn_functional_pad_reflect_cuda_complex64 (__main__.TestDecompCUDA)": 62.90416653951009, + "test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.29275035858154, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.26900100708008, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 112.6924991607666, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 73.96350288391113, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.25400161743164, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 70.42575073242188, + "test_comprehensive_pca_lowrank_cuda_complex64 (__main__.TestDecompCUDA)": 99.28966617584229, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 69.16975021362305, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 75.30550003051758, + "test_comprehensive_svd_lowrank_cuda_complex128 (__main__.TestDecompCUDA)": 120.88183275858562, + "test_comprehensive_svd_lowrank_cuda_complex64 (__main__.TestDecompCUDA)": 119.77483590443929, + "test_comprehensive_svd_lowrank_cuda_float32 (__main__.TestDecompCUDA)": 120.83816623687744, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 86.3487491607666, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 78.20924949645996, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 89.26825046539307, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 220.15350151062012, + "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 77.47299766540527, + "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 156.85225296020508, + "test_conv_bn_fuse_cpu (__main__.CpuTests)": 68.80920028686523, + "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 68.10125064849854, + "test_conv_large_batch_1_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 121.31333414713542, + "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 80.68750095367432, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 82.94275093078613, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 80.35500144958496, + "test_count_nonzero_all (__main__.TestBool)": 650.8682556152344, + "test_cross_entropy_large_tensor_reduction_sum_cuda (__main__.TestNNDeviceTypeCUDA)": 323.86448669433594, + "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 450.4883321126302, + "test_diff_hyperparams_sharding_strategy_str_no_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 60.20799891153971, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 87.0319995880127, + "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 1517.4078125, + "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 90.65559997558594, + "test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 67.08999888102214, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 72.2988748550415, + "test_fail_creation_ops.py (__main__.TestTyping)": 102.47843830759932, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 80.67500114440918, + "test_fn_grad_add_cpu_complex128 (__main__.TestComplexBwdGradientsCPU)": 75.49025793998472, + "test_fn_grad_constant_pad_nd_cpu_complex128 (__main__.TestComplexBwdGradientsCPU)": 177.47612947033298, + "test_fn_grad_constant_pad_nd_cuda_complex128 (__main__.TestComplexBwdGradientsCUDA)": 124.01433499654134, + "test_fn_grad_diagonal_scatter_cpu_complex128 (__main__.TestComplexBwdGradientsCPU)": 387.570063929404, + "test_fn_grad_diagonal_scatter_cuda_complex128 (__main__.TestComplexBwdGradientsCUDA)": 155.9375, + "test_fn_grad_flip_cpu_complex128 (__main__.TestComplexBwdGradientsCPU)": 61.3171936158211, + "test_fn_grad_rsub_cpu_complex128 (__main__.TestComplexBwdGradientsCPU)": 83.22996791716545, + "test_fn_grad_rsub_cuda_complex128 (__main__.TestComplexBwdGradientsCUDA)": 62.90033372243246, + "test_fn_grad_sub_cpu_complex128 (__main__.TestComplexBwdGradientsCPU)": 77.49893539182601, + "test_fn_grad_sub_cuda_complex128 (__main__.TestComplexBwdGradientsCUDA)": 61.0314998626709, + "test_fn_grad_where_cpu_complex128 (__main__.TestComplexBwdGradientsCPU)": 99.6264519230012, + "test_fn_grad_where_cuda_complex128 (__main__.TestComplexBwdGradientsCUDA)": 72.60183270772298, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 84.32150268554688, + "test_fuse_large_params_cpu (__main__.CpuTests)": 97.63633219401042, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 167.9266242980957, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 167.08250045776367, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 148.94650268554688, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 118.18500137329102, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 81.35249900817871, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 196.03149795532227, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 111.10725021362305, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 134.25675010681152, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 614.353271484375, + "test_graph_make_graphed_callables_same_pool (__main__.TestCuda)": 102.73666604359944, + "test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 385.05516481399536, + "test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 395.0171728134155, + "test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.79066467285156, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 312.15050506591797, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 71.82537364959717, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 85.09174919128418, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 129.14387321472168, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.64374923706055, + "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.71199989318848, + "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 126.44325256347656, + "test_list_clearing_cuda (__main__.GPUTests)": 61.48289999961853, + "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 104.84637451171875, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 102.5270004272461, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 136.7854986190796, + "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 426.58765665690106, + "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 133.9463348388672, + "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 359.5349934895833, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 68.19662570953369, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.87825012207031, + "test_nll_loss_large_tensor_reduction_sum_cuda (__main__.TestNNDeviceTypeCUDA)": 340.27033456166583, + "test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 135.83149814605713, + "test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 67.58062505722046, + "test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 198.98699951171875, + "test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 500.39749908447266, + "test_pool3d_large_size_int64_cuda (__main__.TestPoolingNNDeviceTypeCUDA)": 65.12433274586995, + "test_proper_exit (__main__.TestDataLoader)": 203.98437309265137, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 196.37637424468994, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 63.505500078201294, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 111.68949890136719, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.62675094604492, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 93.89300155639648, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 117.77149963378906, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.85300254821777, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 88.89249992370605, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 100.24625015258789, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 111.7132511138916, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 84.21674919128418, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 105.77849960327148, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 109.34375, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 92.73649978637695, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.92499923706055, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.25849914550781, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 64.16908399264018, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.11400032043457, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 114.92299842834473, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 61.62425025304159, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.86524963378906, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 105.85474967956543, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 66.22370831171672, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 113.55375099182129, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.45649909973145, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 573.1502685546875, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1091.4237670898438, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 781.7357482910156, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1477.8807678222656, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 91.73400115966797, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 274.52249908447266, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 142.28099822998047, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 227.64300155639648, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 78.95800018310547, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 139.07250213623047, + "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 70.76949882507324, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 100.25174903869629, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.1675033569336, + "test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 79.20649909973145, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 147.0157470703125, + "test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 85.56925010681152, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_inference_precision_amp (__main__.DeterministicTest)": 62.117165883382164, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_inference_precision_bfloat16 (__main__.DeterministicTest)": 80.71633275349934, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 162.40999857584634, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 112.27533340454102, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 147.1988321940104, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 104.0053342183431, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 61.973000844319664, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 61.754499435424805, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 60.08883412679037, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 86.1146666208903, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 104.33766746520996, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 114.78433227539062, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 86.49966684977214, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 71.44516626993816, + "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.8162488937378, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 72.04562425613403, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 184.2334976196289, + "test_scaled_gemm_offline_tunableop_cuda_float8_e4m3fnuz (__main__.TestLinalgCUDA)": 84.8563323020935, + "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.34962558746338, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 119.15850162506104, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 140.9133758544922, + "test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 106.76350021362305, + "test_sort_stable_cpu (__main__.CpuTritonTests)": 1319.0793050130208, + "test_sort_stable_cuda (__main__.GPUTests)": 96.01039962768554, + "test_split_cumsum_cpu (__main__.CpuTritonTests)": 90.8499984741211, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 304.4350051879883, + "test_tensor_split (__main__.TestVmapOperators)": 105.89132479213826, + "test_terminate_handler_on_crash (__main__.TestTorch)": 167.24449968338013, + "test_terminate_signal (__main__.ForkTest)": 199.22387313842773, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 199.12587642669678, + "test_terminate_signal (__main__.SpawnTest)": 200.71112155914307, + "test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 65.3956667582194, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 88.69500064849854, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 212.44074630737305, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 209.64949798583984, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.97124862670898, + "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.45366668701172, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 93.24074745178223, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 76.62825012207031, + "test_vec_compare_op_cpu_only (__main__.CPUReproTests)": 60.935458501180015, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 96.73649978637695, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 73.72424983978271, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 67.43249893188477, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 62.6795015335083, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 75.2802505493164, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 77.50925064086914, + "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 66.33838690480879, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 67.38049983978271, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 75.26774978637695, + "test_vmapjvpvjp_svd_cpu_float32 (__main__.TestOperatorsCPU)": 61.21835474814138, + "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 65.22375106811523, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 79.81699752807617, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 85.13375091552734, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 83.11999893188477, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 111.10899925231934, + "test_warp_softmax_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 154.79667123158774, + "test_warp_softmax_64bit_indexing_cuda_float32 (__main__.TestNNDeviceTypeCUDA)": 137.61766529083252 } \ No newline at end of file From 49a04d26088acc17d948ddd66920f3e16371e873 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 3 Dec 2025 09:39:37 -0800 Subject: [PATCH 232/338] [effect] Remove special handling for profiler op (#168389) We shouldn't need this anymore as we have a registration for the op to have no effect Differential Revision: [D87680134](https://our.internmc.facebook.com/intern/diff/D87680134) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168389 Approved by: https://github.com/zou3519 ghstack dependencies: #167364 --- torch/_higher_order_ops/effects.py | 5 ----- torch/_ops.py | 34 +----------------------------- 2 files changed, 1 insertion(+), 38 deletions(-) diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index 86707a4f55ef1..96d7872048ec8 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -112,11 +112,6 @@ def has_aliasing(op: OpType): def has_effects(op) -> bool: - # Skip over the profiler's RecordFunction as they should not show up in the graph - _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction} - if op in _skip_ops: - return False - return ( isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) and not has_aliasing(op) diff --git a/torch/_ops.py b/torch/_ops.py index 8f8a7328429fa..75905d78da5b5 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1043,28 +1043,6 @@ def _may_use_fallthrough_instead_of_fallback(key: DispatchKey): if _may_use_fallthrough_instead_of_fallback(key) ] - @contextlib.contextmanager - def _register_as_effectful_op_temporarily(self): - from torch._higher_order_ops.effects import ( - _EffectType, - _get_effect, - _register_effectful_op, - ) - - try: - # We don't want to register the effect if there already exists a - # registration, especially if the registration is None (explicitly - # no effect) - register_tmp_effect = _get_effect(self) is None - handle = None - if register_tmp_effect: - handle = _register_effectful_op(self, _EffectType.ORDERED) - yield - finally: - if register_tmp_effect: - assert handle is not None - handle.destroy() - # Use positional-only argument to avoid naming collision with aten ops arguments # that are named "self". This way, all the aten ops can be called by kwargs. def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: @@ -1072,17 +1050,7 @@ def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: # When any inputs are FakeScriptObject, we need to # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject. - # - # Note: - # 1. We only register the torchbind op temporarily as effectful op because we only want - # the effect token functionalization logic to be applied during tracing. Otherwise, the behavior - # of the eagerly executing the op might change after tracing. - # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might - # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction. - with self._register_as_effectful_op_temporarily(): - return self._dispatch_in_python( - self._fallthrough_keys(), *args, **kwargs - ) + return self._dispatch_in_python(self._fallthrough_keys(), *args, **kwargs) return self._op(*args, **kwargs) def _dispatch_in_python( From bc39b2b3bc7a6e19a42e62bd576974035086fe55 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 2 Dec 2025 02:32:54 +0000 Subject: [PATCH 233/338] Replace `msg` by `args` in `raise_observed_exception` (#169343) PR removes the `msg` argument in `raise_observed_exception` and replace all users of it by the existing `args` one Pull Request resolved: https://github.com/pytorch/pytorch/pull/169343 Approved by: https://github.com/Lucaskabela --- torch/_dynamo/exc.py | 4 ---- torch/_dynamo/variables/nn_module.py | 14 +++++++++----- torch/_dynamo/variables/user_defined.py | 10 +++++++--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 5b0e8a402dd96..a7bdf1caff241 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -387,15 +387,11 @@ def raise_observed_exception( *, args: Optional[list[Any]] = None, kwargs: Optional[dict[str, Any]] = None, - msg: Optional[str] = None, ) -> NoReturn: from .variables import BuiltinVariable # CPython here raises an exception. Since there is no python code, we have to manually setup the exception # stack and raise the exception. - # If a message is provided but no args, use the message as the first argument - if msg is not None and (args is None or len(args) == 0): - args = [msg] exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type] tx.exn_vt_stack.set_current_exception(exception_vt) # type: ignore[arg-type] raised_exc = get_dynamo_observed_exception(exc_type) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 525c42a009c1d..0c813cb2e0305 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -112,9 +112,11 @@ def convert_to_fake(x: Any) -> Any: raise_observed_exception( AttributeError, tx, - msg=str(e) - if str(e) - else "AttributeError during lazy module initialization", + args=[ + str(e) + if str(e) + else "AttributeError during lazy module initialization" + ], ) @@ -397,7 +399,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker raise_observed_exception( AttributeError, tx, - msg=f"'{type(base).__name__}' object has no attribute '{name}'", + args=[f"'{type(base).__name__}' object has no attribute '{name}'"], ) if name == "forward": @@ -1330,7 +1332,9 @@ def manually_trace_nn_module_getattr( raise_observed_exception( AttributeError, tx, - msg=f"'{type(self.value).__name__}' object has no attribute '{name}'", + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], ) assert out is not None return out diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0863d8592abd2..ce5c0a2d31294 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -285,7 +285,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke raise_observed_exception( AttributeError, tx, - msg=f"type object '{self.value.__name__}' has no attribute '{name}'", + args=[ + f"type object '{self.value.__name__}' has no attribute '{name}'" + ], ) else: # Cannot reason about classes with a custom metaclass @@ -1460,7 +1462,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): raise_observed_exception( AttributeError, tx, - msg=f"'{type(self.value).__name__}' object has no attribute '{name}'", + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], ) return result @@ -1736,7 +1740,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): raise_observed_exception( AttributeError, tx, - msg=f"'{type(self.value).__name__}' object has no attribute '{name}'", + args=[f"'{type(self.value).__name__}' object has no attribute '{name}'"], ) def call_obj_hasattr( From 6f53fefeb90ad3281119b5cfc4aa9ffd8a066e3d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 03:32:12 +0000 Subject: [PATCH 234/338] Revert "Avoid std::tie and returning value constructions in qconv_unpack.cpp (#169207)" This reverts commit bb3034198b459401fabeab254e1b99f0115046e2. Reverted https://github.com/pytorch/pytorch/pull/169207 on behalf of https://github.com/huydhn due to Sorry to keep reverting this, but the issue is still there ([comment](https://github.com/pytorch/pytorch/pull/169207#issuecomment-3609906136)) --- .../ATen/native/quantized/qconv_unpack.cpp | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/quantized/qconv_unpack.cpp b/aten/src/ATen/native/quantized/qconv_unpack.cpp index df66a6087f738..4c2352a396177 100644 --- a/aten/src/ATen/native/quantized/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/qconv_unpack.cpp @@ -82,31 +82,32 @@ class QConv1dUnpackWeightsInt8 final { static std::tuple> run( const c10::intrusive_ptr>& packed_weight) { auto& ctx = at::globalContext(); + at::Tensor weight; + std::optional bias; #ifdef USE_FBGEMM if (ctx.qEngine() == at::QEngine::FBGEMM || ctx.qEngine() == at::QEngine::X86) { - auto result = packed_weight->unpack(); - auto& weight = std::get<0>(result); + std::tie(weight, bias) = packed_weight->unpack(); weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return result; + return std::tuple>(weight, bias); } #endif #ifdef USE_PYTORCH_QNNPACK if (ctx.qEngine() == at::QEngine::QNNPACK) { - auto result = packed_weight->unpack(); - auto& weight = std::get<0>(result); - weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return result; + std::tie(weight, bias) = packed_weight->unpack(); + at::Tensor new_weight = weight.clone(); + new_weight = new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return std::tuple>(new_weight, bias); } #endif #if AT_MKLDNN_ENABLED() if (ctx.qEngine() == at::QEngine::ONEDNN) { - auto result = packed_weight->unpack(); - auto& weight = std::get<0>(result); - weight = weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return result; + std::tie(weight, bias) = packed_weight->unpack(); + at::Tensor new_weight = weight.clone(); + new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return std::tuple>(new_weight, bias); } #endif From a15066c28b3145e6edbfc88359d0411d14cfc70c Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 4 Dec 2025 03:51:34 +0000 Subject: [PATCH 235/338] Fix torch.fx for the newer "|" union syntax (#169453) This PR fixes torch.fx handling of the newer `|` type. Otherwise, they could be errors like ``` "/torch/package/importer.py", line 95, in get_name name = obj.__name__ AttributeError: 'types.UnionType' object has no attribute '__name__'. Did you mean: '__ne__'? ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/169453 Approved by: https://github.com/albanD, https://github.com/malfet --- test/test_fx.py | 10 ++++++++++ torch/fx/graph.py | 5 +++++ torch/fx/node.py | 2 ++ 3 files changed, 17 insertions(+) diff --git a/test/test_fx.py b/test/test_fx.py index 7fdd6552edc7b..e2584156bf730 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2381,6 +2381,16 @@ def test_typename_print_pre_pep585(self): self.assertTrue("typing.List[float]" in str(graph)) + def test_typename_print_union(self): + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,), type_expr=float|torch.Tensor|None + ) + output: torch.fx.Node = graph.output(b) + + self.assertTrue('float | torch.Tensor | None' in str(graph)) + def test_layout(self): class M(torch.nn.Module): def forward(self, x): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 36ef68a9a2e35..d4b0a1b1500d3 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -10,6 +10,7 @@ import os import pprint import re +import types import typing import warnings from collections import defaultdict @@ -499,6 +500,10 @@ def type_repr(o: Any): return "()" typename = _type_repr(o) + if isinstance(o, types.UnionType) and "|" in typename: + # str | int + args = [type_repr(arg) for arg in o.__args__] + return "|".join(args) if origin_type := getattr(o, "__origin__", None): # list[...], typing.List[...], TensorType[...] diff --git a/torch/fx/node.py b/torch/fx/node.py index 5afabe40ec341..85e6f3a82e969 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -174,6 +174,8 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: # Fixup segment_reduce mismatch if module == "torch" and name == "segment_reduce": name = "_" + name + if module == "torch.nn.functional" and name in ("_ScalingType", "_SwizzleType"): + name = name.removeprefix("_") return f"{module}.{name}" From cc0853af42122f8185321f542616f4474e717f09 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 03:51:59 +0000 Subject: [PATCH 236/338] Revert "Replace `msg` by `args` in `raise_observed_exception` (#169343)" This reverts commit bc39b2b3bc7a6e19a42e62bd576974035086fe55. Reverted https://github.com/pytorch/pytorch/pull/169343 on behalf of https://github.com/huydhn due to It looks like there is a land race here failing lint ([comment](https://github.com/pytorch/pytorch/pull/169343#issuecomment-3609956743)) --- torch/_dynamo/exc.py | 4 ++++ torch/_dynamo/variables/nn_module.py | 14 +++++--------- torch/_dynamo/variables/user_defined.py | 10 +++------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index a7bdf1caff241..5b0e8a402dd96 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -387,11 +387,15 @@ def raise_observed_exception( *, args: Optional[list[Any]] = None, kwargs: Optional[dict[str, Any]] = None, + msg: Optional[str] = None, ) -> NoReturn: from .variables import BuiltinVariable # CPython here raises an exception. Since there is no python code, we have to manually setup the exception # stack and raise the exception. + # If a message is provided but no args, use the message as the first argument + if msg is not None and (args is None or len(args) == 0): + args = [msg] exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type] tx.exn_vt_stack.set_current_exception(exception_vt) # type: ignore[arg-type] raised_exc = get_dynamo_observed_exception(exc_type) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 0c813cb2e0305..525c42a009c1d 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -112,11 +112,9 @@ def convert_to_fake(x: Any) -> Any: raise_observed_exception( AttributeError, tx, - args=[ - str(e) - if str(e) - else "AttributeError during lazy module initialization" - ], + msg=str(e) + if str(e) + else "AttributeError during lazy module initialization", ) @@ -399,7 +397,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker raise_observed_exception( AttributeError, tx, - args=[f"'{type(base).__name__}' object has no attribute '{name}'"], + msg=f"'{type(base).__name__}' object has no attribute '{name}'", ) if name == "forward": @@ -1332,9 +1330,7 @@ def manually_trace_nn_module_getattr( raise_observed_exception( AttributeError, tx, - args=[ - f"'{type(self.value).__name__}' object has no attribute '{name}'" - ], + msg=f"'{type(self.value).__name__}' object has no attribute '{name}'", ) assert out is not None return out diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ce5c0a2d31294..0863d8592abd2 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -285,9 +285,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke raise_observed_exception( AttributeError, tx, - args=[ - f"type object '{self.value.__name__}' has no attribute '{name}'" - ], + msg=f"type object '{self.value.__name__}' has no attribute '{name}'", ) else: # Cannot reason about classes with a custom metaclass @@ -1462,9 +1460,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): raise_observed_exception( AttributeError, tx, - args=[ - f"'{type(self.value).__name__}' object has no attribute '{name}'" - ], + msg=f"'{type(self.value).__name__}' object has no attribute '{name}'", ) return result @@ -1740,7 +1736,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): raise_observed_exception( AttributeError, tx, - args=[f"'{type(self.value).__name__}' object has no attribute '{name}'"], + msg=f"'{type(self.value).__name__}' object has no attribute '{name}'", ) def call_obj_hasattr( From dd18a75336a4fbd7497955cc5665904724fce889 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 3 Dec 2025 18:59:38 +0000 Subject: [PATCH 237/338] [xpu][feature] [2/2] Introduce XPUPluggableAllocator in frontend part (#169043) # Motivation This PR aims to introduce `torch.xpu.memory.XPUPluggableAllocator` and focus on the frontend part. Meanwhile, we introduce an API `torch.xpu.memory.change_current_allocator` that used to change the current `XPUAllocator`. Both APIs has the CUDA counterpart. # Additional Context This API would be used in popular repos such as https://github.com/search?q=repo%3Avllm-project%2Fvllm%20cudapluggableallocator&type=code Pull Request resolved: https://github.com/pytorch/pytorch/pull/169043 Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/eellison ghstack dependencies: #168966 --- docs/source/xpu.md | 2 + test/test_xpu.py | 137 +++++++++++++++++++++++++++++++++++ torch/_C/__init__.pyi.in | 6 ++ torch/csrc/xpu/Module.cpp | 31 ++++++++ torch/utils/cpp_extension.py | 17 ++++- torch/xpu/__init__.py | 4 + torch/xpu/memory.py | 86 +++++++++++++++++++++- 7 files changed, 278 insertions(+), 5 deletions(-) diff --git a/docs/source/xpu.md b/docs/source/xpu.md index 6cd82aa984159..d187efbfc77a2 100644 --- a/docs/source/xpu.md +++ b/docs/source/xpu.md @@ -75,6 +75,8 @@ :toctree: generated :nosignatures: + XPUPluggableAllocator + change_current_allocator empty_cache get_per_process_memory_fraction max_memory_allocated diff --git a/test/test_xpu.py b/test/test_xpu.py index 6b92dc4c96b38..307fa10fe0527 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,5 +1,6 @@ # Owner(s): ["module: intel"] +import ctypes import gc import re import subprocess @@ -580,6 +581,142 @@ def test_can_device_access_peer(self): torch.xpu.can_device_access_peer(peer, device), ) + def get_dummy_allocator(self, check_vars): + dummy_allocator_source_vars = """ + #include + #include + + extern "C" { + C10_EXPORT int called_dummy_alloc = 0; + C10_EXPORT int called_dummy_free = 0; + + C10_EXPORT void* dummy_alloc(size_t size, int device, sycl::queue* queue) { + called_dummy_alloc = 123; + auto& sycl_device = c10::xpu::get_raw_device(device); + auto& sycl_context = c10::xpu::get_device_context(); + void* ptr = sycl::malloc_shared(size, sycl_device, sycl_context); + return ptr; + } + + C10_EXPORT void dummy_free(void* ptr, size_t size, int device, sycl::queue* queue) { + called_dummy_free = 321; + sycl::free(ptr, c10::xpu::get_device_context()); + } + } + """ + dummy_allocator_source_no_vars = """ + #include + #include + + extern "C" { + C10_EXPORT void* dummy_alloc(size_t size, int device, sycl::queue* queue) { + auto& sycl_device = c10::xpu::get_raw_device(device); + auto& sycl_context = c10::xpu::get_device_context(); + void* ptr = sycl::malloc_shared(size, sycl_device, sycl_context); + return ptr; + } + + C10_EXPORT void dummy_free(void* ptr, size_t size, int device, sycl::queue* queue) { + sycl::free(ptr, c10::xpu::get_device_context()); + } + } + """ + + from torch.utils.cpp_extension import load_inline + + dummy_allocator_libname = "dummy_allocator" + dummy_allocator = load_inline( + name=dummy_allocator_libname, + cpp_sources=dummy_allocator_source_vars + if check_vars + else dummy_allocator_source_no_vars, + is_python_module=False, + keep_intermediates=False, + verbose=True, + with_sycl=True, + ) + allocator = torch.xpu.memory.XPUPluggableAllocator( + dummy_allocator, + "dummy_alloc", + "dummy_free", + ) + return allocator, dummy_allocator + + def test_xpu_pluggable_allocator(self): + torch.xpu.init() + allocator, dummy_allocator = self.get_dummy_allocator(True) + alloc_lib = ctypes.CDLL(dummy_allocator) + called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc") + called_dummy_free = ctypes.c_int.in_dll(alloc_lib, "called_dummy_free") + self.assertEqual(called_dummy_alloc.value, 0) + self.assertEqual(called_dummy_free.value, 0) + + with self.assertRaises(RuntimeError): + torch.xpu.memory.change_current_allocator(allocator) + + def check_output(script: str) -> str: + return ( + subprocess.check_output([sys.executable, "-c", script]) + .decode("ascii") + .strip() + ) + + test_script = """\ +import ctypes +import torch +from torch.utils.cpp_extension import load_inline + +dummy_allocator_source_vars = \"\"\"\ +#include +#include + +extern "C" { + C10_EXPORT int called_dummy_alloc = 0; + C10_EXPORT int called_dummy_free = 0; + + C10_EXPORT void* dummy_alloc(size_t size, int device, sycl::queue* queue) { + called_dummy_alloc = 123; + auto& sycl_device = c10::xpu::get_raw_device(device); + auto& sycl_context = c10::xpu::get_device_context(); + void* ptr = sycl::malloc_shared(size, sycl_device, sycl_context); + return ptr; + } + + C10_EXPORT void dummy_free(void* ptr, size_t size, int device, sycl::queue* queue) { + called_dummy_free = 321; + sycl::free(ptr, c10::xpu::get_device_context()); + } +} +\"\"\" + +if __name__ == "__main__": + dummy_allocator = load_inline( + name='dummy_allocator', + cpp_sources=dummy_allocator_source_vars, + is_python_module=False, + keep_intermediates=False, + verbose=True, + with_sycl=True, + ) + + allocator = torch.xpu.memory.XPUPluggableAllocator( + dummy_allocator, + "dummy_alloc", + "dummy_free", + ) + torch.xpu.memory.change_current_allocator(allocator) + tensor = torch.randn(100, device='xpu') + del tensor + allocator_lib = ctypes.CDLL(dummy_allocator) + called_dummy_alloc = ctypes.c_int.in_dll(allocator_lib, "called_dummy_alloc") + called_dummy_free = ctypes.c_int.in_dll(allocator_lib, "called_dummy_free") + print(called_dummy_alloc.value, called_dummy_free.value) +""" + rc = check_output(test_script).splitlines()[-1] + called_dummy_alloc_value, called_dummy_free_value = rc.split() + self.assertEqual(called_dummy_alloc_value, "123") + self.assertEqual(called_dummy_free_value, "321") + def test_torch_version_xpu(self): self.assertEqual(len(torch.version.xpu), 8) compiler_version = int(torch.version.xpu) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9dc460d9522fa..9ad00753fe25c 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2442,6 +2442,12 @@ class _XpuDeviceProperties: type: str uuid: Any +class _xpu_XPUAllocator: ... + +def _xpu_customAllocator(alloc_fn: _int, free_fn: _int) -> _xpu_XPUAllocator: ... +def _xpu_changeCurrentAllocator(allocator: _xpu_XPUAllocator) -> None: ... +def _xpu_getAllocator() -> _xpu_XPUAllocator: ... + # Defined in torch/csrc/xpu/Stream.cpp class _XpuStreamBase(Stream): stream_id: _int diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index ba5998ba3d3ce..08cfc9185a298 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -10,6 +10,7 @@ #include #include #include +#include using namespace torch; @@ -372,6 +373,35 @@ static void registerXpuDeviceProperties(PyObject* module) { }); } +static void registerXpuPluggableAllocator(PyObject* module) { + auto m = py::handle(module).cast(); + + py::class_< + c10::xpu::XPUCachingAllocator::XPUAllocator, + std::shared_ptr>( + m, "_xpu_XPUAllocator"); + + m.def("_xpu_getAllocator", []() { + return py::cast(torch::xpu::XPUPluggableAllocator::getCurrentAllocator()); + }); + m.def( + "_xpu_changeCurrentAllocator", + [](std::shared_ptr + allocator) { + torch::xpu::XPUPluggableAllocator::changeCurrentAllocator(allocator); + }); + m.def("_xpu_customAllocator", [](uint64_t malloc_ptr, uint64_t free_ptr) { + using MallocFuncType = void*(size_t, int, sycl::queue*); + using FreeFuncType = void(void*, size_t, int, sycl::queue*); + std::function malloc_fn = + reinterpret_cast(malloc_ptr); + std::function free_fn = + reinterpret_cast(free_ptr); + return torch::xpu::XPUPluggableAllocator::createCustomAllocator( + malloc_fn, free_fn); + }); +} + static void bindGetDeviceProperties(PyObject* module) { // Add method to torch.xpu auto m = py::handle(module).cast(); @@ -495,6 +525,7 @@ namespace torch::xpu { void initModule(PyObject* module) { registerXpuDeviceProperties(module); + registerXpuPluggableAllocator(module); initXpuMethodBindings(module); } diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 14ddcbf732b91..f29c382f0e3f3 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1551,7 +1551,7 @@ def SyclExtension(name, sources, *args, **kwargs): kwargs["libraries"] = libraries include_dirs = kwargs.get("include_dirs", []) - include_dirs += include_paths() + include_dirs += include_paths(device_type="xpu") kwargs["include_dirs"] = include_dirs kwargs["language"] = "c++" @@ -2388,12 +2388,16 @@ def _prepare_ldflags(extra_ldflags, with_cuda, with_sycl, verbose, is_standalone extra_ldflags.append('c10.lib') if with_cuda: extra_ldflags.append('c10_hip.lib' if IS_HIP_EXTENSION else 'c10_cuda.lib') + if with_sycl: + extra_ldflags.append('c10_xpu.lib') extra_ldflags.append('torch_cpu.lib') if with_cuda: extra_ldflags.append('torch_hip.lib' if IS_HIP_EXTENSION else 'torch_cuda.lib') # /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it. # Related issue: https://github.com/pytorch/pytorch/issues/31611 extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ') + if with_sycl: + extra_ldflags.append('torch_xpu.lib') extra_ldflags.append('torch.lib') extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}') if not is_standalone: @@ -2405,9 +2409,13 @@ def _prepare_ldflags(extra_ldflags, with_cuda, with_sycl, verbose, is_standalone extra_ldflags.append('-lc10') if with_cuda: extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') + if with_sycl: + extra_ldflags.append('-lc10_xpu') extra_ldflags.append('-ltorch_cpu') if with_cuda: extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda') + if with_sycl: + extra_ldflags.append('-ltorch_xpu') extra_ldflags.append('-ltorch') if not is_standalone: extra_ldflags.append('-ltorch_python') @@ -2443,10 +2451,11 @@ def _prepare_ldflags(extra_ldflags, with_cuda, with_sycl, verbose, is_standalone extra_ldflags.append('-lamdhip64') if with_sycl: if IS_WINDOWS: - extra_ldflags.append('c10_xpu.lib') - extra_ldflags.append('torch_xpu.lib') extra_ldflags.append(f'/LIBPATH:{_join_sycl_home("lib")}') extra_ldflags.append('sycl.lib') + else: + extra_ldflags.append(f'-L{_join_sycl_home("lib")}') + extra_ldflags.append('-lsycl') return extra_ldflags @@ -2754,6 +2763,8 @@ def _write_ninja_file_to_build_library(path, # TODO generalize with_cuda as specific device type. if with_cuda: system_includes = include_paths("cuda") + elif with_sycl: + system_includes = include_paths("xpu") else: system_includes = include_paths("cpu") # sysconfig.get_path('include') gives us the location of Python.h diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 194684e3388e4..93481a622494b 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -520,6 +520,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: # import here to avoid circular import from .memory import ( + change_current_allocator, empty_cache, get_per_process_memory_fraction, max_memory_allocated, @@ -532,6 +533,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: reset_accumulated_memory_stats, reset_peak_memory_stats, set_per_process_memory_fraction, + XPUPluggableAllocator, ) from .random import ( get_rng_state, @@ -550,7 +552,9 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: "Event", "Stream", "StreamContext", + "XPUPluggableAllocator", "can_device_access_peer", + "change_current_allocator", "current_device", "current_stream", "default_generators", diff --git a/torch/xpu/memory.py b/torch/xpu/memory.py index 3a9c7d7c83ee4..e9a95c7cde37c 100644 --- a/torch/xpu/memory.py +++ b/torch/xpu/memory.py @@ -1,12 +1,18 @@ import collections +import ctypes from typing import Any, Union import torch +from torch._utils import _dummy_type from torch.types import Device -from . import _get_device_index, _lazy_init, is_initialized +from . import _get_device_index, _is_compiled, _lazy_init, is_initialized +if not _is_compiled(): + # Define dummy base classes + torch._C.__dict__["_xpu_XPUAllocator"] = _dummy_type("_xpu_XPUAllocator") + _device_t = Union[Device, str, int, None] @@ -227,7 +233,7 @@ def set_per_process_memory_fraction(fraction: float, device: _device_t = None) - an out-of-memory error will be raised by the allocator. Arguments: - fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction. + fraction (float): Range: 0~1. Allowed memory equals total_memory * fraction. device (torch.device or int or str, optional): selected device. It uses the current device, given by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` (default). @@ -241,7 +247,83 @@ def set_per_process_memory_fraction(fraction: float, device: _device_t = None) - torch._C._xpu_setMemoryFraction(fraction, device) +class _XPUAllocator: + r"""Wrapper over internal XPU memory allocators.""" + + def __init__(self, allocator: torch._C._xpu_XPUAllocator): + self._allocator = allocator + + def allocator(self): + return self._allocator + + +class XPUPluggableAllocator(_XPUAllocator): + r"""XPU memory allocator loaded from a shared library.""" + + def __init__(self, path_to_lib_file: str, alloc_fn_name: str, free_fn_name: str): + r"""XPU memory allocator loaded dynamically from a shared library. + + This lets users provide custom allocation and free functions implemented + in a separate shared library. The allocator is registered through + ``torch._C._xpu_customAllocator`` and becomes available for use via + ``torch.memory.xpu.change_current_allocator``. + + Arguments: + path_to_lib_file (str): + Filesystem path to the shared library file containing the allocation + and free functions. + alloc_fn_name (str): + Name of the allocation function exported from the shared library. + The function must have the signature: + + ``void* alloc_fn(size_t size, int device, sycl::queue* queue);`` + + free_fn_name (str): + Name of the free function exported from the shared library. + The function must have the signature: + + ``void free_fn(void* ptr, size_t size, sycl::queue* queue);`` + """ + allocator_lib = ctypes.CDLL(path_to_lib_file) + + alloc_fn_ptr = getattr(allocator_lib, alloc_fn_name) + free_fn_ptr = getattr(allocator_lib, free_fn_name) + + alloc_fn_addr = ctypes.cast(alloc_fn_ptr, ctypes.c_void_p).value + free_fn_addr = ctypes.cast(free_fn_ptr, ctypes.c_void_p).value + + if alloc_fn_addr is None or free_fn_addr is None: + raise RuntimeError( + "Failed to load allocator symbols from the shared library." + ) + + self._allocator = torch._C._xpu_customAllocator(alloc_fn_addr, free_fn_addr) + + +def change_current_allocator(allocator: _XPUAllocator) -> None: + r"""Change the currently used memory allocator to be the one provided. + + .. note:: + If the current allocator has already been used/initialized, this function will error. + + Arguments: + allocator (torch.xpu.memory._XPUAllocator): allocator to be set as the active one. + """ + torch._C._xpu_changeCurrentAllocator(allocator.allocator()) + + +def _get_current_allocator() -> _XPUAllocator: + r"""Return the allocator being currently used. + + Returns: + _XPUAllocator: the allocator being currently used. + """ + return _XPUAllocator(torch._C._xpu_getAllocator()) + + __all__ = [ + "XPUPluggableAllocator", + "change_current_allocator", "empty_cache", "get_per_process_memory_fraction", "max_memory_allocated", From 597df3a4e2a67b9fdbe1a89b2f4d74f822274db6 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 3 Dec 2025 13:33:16 -0800 Subject: [PATCH 238/338] [CI] Use py3.10 in upload-utilization-stats (#169501) Summary: To fix failures like this on CI, ``` 2025-12-03T19:04:16.4264214Z ##[endgroup] 2025-12-03T19:04:16.8364369Z Traceback (most recent call last): 2025-12-03T19:04:16.8365336Z File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main 2025-12-03T19:04:16.8366153Z return _run_code(code, main_globals, None, 2025-12-03T19:04:16.8366898Z File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code 2025-12-03T19:04:16.8368113Z exec(code, run_globals) 2025-12-03T19:04:16.8368891Z File "/home/ec2-user/actions-runner/_work/pytorch/pytorch/tools/stats/upload_utilization_stats/upload_utilization_stats.py", line 17, in 2025-12-03T19:04:16.8369619Z from tools.stats.utilization_stats_lib import ( 2025-12-03T19:04:16.8370231Z File "/home/ec2-user/actions-runner/_work/pytorch/pytorch/tools/stats/utilization_stats_lib.py", line 13, in 2025-12-03T19:04:16.8370807Z class UtilizationStats: 2025-12-03T19:04:16.8371390Z File "/home/ec2-user/actions-runner/_work/pytorch/pytorch/tools/stats/utilization_stats_lib.py", line 14, in UtilizationStats 2025-12-03T19:04:16.8371999Z avg: float | None = None 2025-12-03T19:04:16.8372353Z TypeError: unsupported operand type(s) for |: 'type' and 'NoneType' ``` e.g. https://productionresultssa17.blob.core.windows.net/actions-results/e3792057-a4cd-4eaa-b759-317aa47a149b/workflow-job-run-9ddd4c2f-8892-5cca-89c2-e6d976902e9d/logs/job/job-logs.txt?rsct=text%2Fplain&se=2025-12-03T21%3A15%3A42Z&sig=gvBdBvL2Z6KRploHfvLZF07GZvwLvR7SvN%2FxM1Vo2Fw%3D&ske=2025-12-04T06%3A43%3A09Z&skoid=ca7593d4-ee42-46cd-af88-8b886a2f84eb&sks=b&skt=2025-12-03T18%3A43%3A09Z&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skv=2025-11-05&sp=r&spr=https&sr=b&st=2025-12-03T21%3A05%3A37Z&sv=2025-11-05 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169501 Approved by: https://github.com/malfet --- .github/actions/upload-utilization-stats/action.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/actions/upload-utilization-stats/action.yml b/.github/actions/upload-utilization-stats/action.yml index 3eb68e0aa5544..6dfdc9404b703 100644 --- a/.github/actions/upload-utilization-stats/action.yml +++ b/.github/actions/upload-utilization-stats/action.yml @@ -38,6 +38,10 @@ inputs: runs: using: composite steps: + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: '3.10' - name: Print Inputs shell: bash run: | From b1decff555cd50e2123c8c6e25cc0d447c411f62 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 04:46:04 +0000 Subject: [PATCH 239/338] Revert "Fix torch.fx for the newer "|" union syntax (#169453)" This reverts commit a15066c28b3145e6edbfc88359d0411d14cfc70c. Reverted https://github.com/pytorch/pytorch/pull/169453 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/169453#issuecomment-3610098687)) --- test/test_fx.py | 10 ---------- torch/fx/graph.py | 5 ----- torch/fx/node.py | 2 -- 3 files changed, 17 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index e2584156bf730..7fdd6552edc7b 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2381,16 +2381,6 @@ def test_typename_print_pre_pep585(self): self.assertTrue("typing.List[float]" in str(graph)) - def test_typename_print_union(self): - graph: torch.fx.Graph = torch.fx.Graph() - x: torch.fx.Node = graph.create_node("placeholder", "x") - b: torch.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,), type_expr=float|torch.Tensor|None - ) - output: torch.fx.Node = graph.output(b) - - self.assertTrue('float | torch.Tensor | None' in str(graph)) - def test_layout(self): class M(torch.nn.Module): def forward(self, x): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index d4b0a1b1500d3..36ef68a9a2e35 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -10,7 +10,6 @@ import os import pprint import re -import types import typing import warnings from collections import defaultdict @@ -500,10 +499,6 @@ def type_repr(o: Any): return "()" typename = _type_repr(o) - if isinstance(o, types.UnionType) and "|" in typename: - # str | int - args = [type_repr(arg) for arg in o.__args__] - return "|".join(args) if origin_type := getattr(o, "__origin__", None): # list[...], typing.List[...], TensorType[...] diff --git a/torch/fx/node.py b/torch/fx/node.py index 85e6f3a82e969..5afabe40ec341 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -174,8 +174,6 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: # Fixup segment_reduce mismatch if module == "torch" and name == "segment_reduce": name = "_" + name - if module == "torch.nn.functional" and name in ("_ScalingType", "_SwizzleType"): - name = name.removeprefix("_") return f"{module}.{name}" From 5191b2fa68ba19960912bfd7fd721c79d76bb1f3 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 3 Dec 2025 13:27:25 -0800 Subject: [PATCH 240/338] [AOTI] Fix a small buffer mutation issue (#169347) Summary: Fix https://github.com/pytorch/pytorch/issues/169118. A small named buffer can be optimized as inlined constant, but if that named buffer is mutated later on, AOTI can generate wrong code if not respecting the mutation information. This PR fixes the issue by recording mutated_named_buffers when unlifting the input graph and pass to Inductor graph lowering. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169347 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 29 ++++++++++++++++++++- test/inductor/test_aot_inductor_arrayref.py | 3 ++- torch/_inductor/compile_fx.py | 3 +++ torch/_inductor/graph.py | 4 +++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 6c0c932023638..8e4102a57d682 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6619,7 +6619,7 @@ def runner_call(*args, **kwargs): with self.assertRaises(AssertionError): torch.testing.assert_close(new_expected, new_output, atol=1e-3, rtol=1e-3) - def test_cond_share_predicte(self): + def test_cond_share_predicate(self): class Model(torch.nn.Module): def forward(self, predicate, x): y = torch.cond( @@ -6641,6 +6641,33 @@ def forward(self, predicate, x): ) self.check_model(Model(), example_inputs) + def test_cond_predicate_on_cpu(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "is_cache_initialized", + torch.tensor([False], dtype=torch.bool, device="cpu"), + persistent=False, + ) + + def forward(self, x): + def true_fn(x): + return x + 1.0 + + def false_fn(x): + return x + 0.0 + + out = torch.cond( + self.is_cache_initialized, true_fn, false_fn, operands=(x,) + ) + self.is_cache_initialized.fill_(True) + return out + + model = Model() + example_inputs = (torch.tensor([1.0], device=self.device),) + self.check_model(model, example_inputs) + @unittest.skipIf( IS_FBCODE, "To enable after the C shim FC window ends", diff --git a/test/inductor/test_aot_inductor_arrayref.py b/test/inductor/test_aot_inductor_arrayref.py index 492ad9c23c5c7..2b1214c863409 100644 --- a/test/inductor/test_aot_inductor_arrayref.py +++ b/test/inductor/test_aot_inductor_arrayref.py @@ -71,7 +71,8 @@ def fail_minimal_arrayref_interface(is_skip=False): "test_cond_with_parameters": fail_minimal_arrayref_interface(), "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(), "test_custom_op_in_subgraph": fail_minimal_arrayref_interface(), - "test_cond_share_predicte": fail_stack_allocation(is_skip=True), + "test_cond_share_predicate": fail_stack_allocation(is_skip=True), + "test_cond_predicate_on_cpu": fail_stack_allocation(is_skip=True), "test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), "test_while_loop_with_unbacked_symint_closure_dynamic_False": fail_minimal_arrayref_interface(), diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 46ca60483828d..98a4445f9cc30 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -473,6 +473,9 @@ def _unlift_graph( pytree.treespec_leaf(), None, ) + # After unlifting, the buffer mutation information is lost. Pass the information + # so that Inductor can do optimizations correctly. + unlifted_gm.meta["mutated_named_buffers"] = OrderedSet(buffer_mutations.values()) return unlifted_gm diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index b136f7ab9eddf..c5ae0c205ef58 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -412,6 +412,9 @@ def __init__( self.named_buffers: dict[str, torch.Tensor] = ( const_module.named_buffers if const_module else {} ) + self.mutated_named_buffers: OrderedSet[torch.Tensor] = gm.meta.get( + "mutated_named_buffers", OrderedSet() + ) self.named_parameters: dict[str, torch.Tensor] = ( const_module.named_parameters if const_module else {} ) @@ -1409,6 +1412,7 @@ def get_attr( config.aot_inductor.use_runtime_constant_folding or config.always_keep_tensor_constants or unsupported_output_tensor(value) + or target in self.mutated_named_buffers ): return self.add_tensor_constant(value, target) From 1c87554d74140eaee964ca8b1832cede67f5f520 Mon Sep 17 00:00:00 2001 From: Ting Lu Date: Thu, 4 Dec 2025 05:20:45 +0000 Subject: [PATCH 241/338] [CI] Add CUDA 13.0 inductor CI benchmarks (#165029) Adding CUDA 13.0 to the inductor bencharks as it is the latest support CUDA version Pull Request resolved: https://github.com/pytorch/pytorch/pull/165029 Approved by: https://github.com/atalman --- .ci/docker/build.sh | 11 +++ .../common/install_inductor_benchmark_deps.sh | 13 ++- .ci/docker/requirements-ci.txt | 4 + .ci/pytorch/build.sh | 5 ++ .ci/pytorch/common_utils.sh | 5 +- .ci/pytorch/test.sh | 25 ++++-- .github/workflows/docker-builds.yml | 1 + .../workflows/inductor-micro-benchmark.yml | 27 ++++++ .github/workflows/inductor-perf-compare.yml | 34 ++++++++ .../workflows/inductor-perf-test-nightly.yml | 86 +++++++++++++++++++ .github/workflows/inductor-periodic.yml | 51 +++++++++++ .github/workflows/inductor.yml | 30 +++++++ .github/workflows/pull.yml | 29 ++++++- .github/workflows/trunk.yml | 10 +++ cmake/Dependencies.cmake | 3 + 15 files changed, 321 insertions(+), 13 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index f0f154f0c7c1f..e175be2a6df4d 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -136,6 +136,17 @@ case "$tag" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; + pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks) + CUDA_VERSION=13.0.2 + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=11 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.12 diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index 81467d87f5140..8b2a3f3ac96c6 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -35,8 +35,19 @@ function install_torchbench() { # Pango is needed for weasyprint which is needed for doctr conda_install pango +# Detect CUDA version and use appropriate wheel index +# DESIRED_CUDA is set as ENV in the Dockerfile (e.g., "13.0.2", "12.8.1") +if [[ "${DESIRED_CUDA}" == 13.* ]]; then + CUDA_INDEX_URL="https://download.pytorch.org/whl/cu130" + echo "DESIRED_CUDA=${DESIRED_CUDA}, using cu130 wheels" +else + # Default to cu128 for CUDA 12.x + CUDA_INDEX_URL="https://download.pytorch.org/whl/cu128" + echo "DESIRED_CUDA=${DESIRED_CUDA}, using cu128 wheels" +fi + # Stable packages are ok here, just to satisfy TorchBench check -pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 +pip_install torch torchvision torchaudio --index-url "${CUDA_INDEX_URL}" install_torchbench install_huggingface diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index f00516ccf1293..a32161cae6a34 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -403,6 +403,10 @@ pyre-extensions==0.0.32 tabulate==0.9.0 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI +tqdm>=4.66.0 +#Description: progress bar library required for dynamo benchmarks +#test that import: benchmarks/dynamo/* + Jinja2==3.1.6 aiohttp==3.13.2 #Description: required for torch.distributed.debug diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 071f14700def4..6a8956e6fc4be 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -36,6 +36,11 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then nvcc --version fi +if [[ "$BUILD_ENVIRONMENT" == *cuda13* ]]; then + # Disable FBGEMM for CUDA 13 builds + export USE_FBGEMM=0 +fi + if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then if [[ "$BUILD_ENVIRONMENT" != *clang* ]]; then # TODO: there is a linking issue when building with UCC using clang, diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 68402766cbe79..1fd78664cc122 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -285,7 +285,10 @@ EOF rm -rf fbgemm else pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec - pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu + # Skip fbgemm for CUDA 13 as it's not compatible yet + if [[ "$BUILD_ENVIRONMENT" != *cuda13* ]]; then + pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu + fi fi } diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 44dff52974320..9118d6031a2a7 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -886,8 +886,14 @@ test_dynamo_benchmark() { local shard_id="$1" shift + # Exclude torchrec_dlrm for CUDA 13 as FBGEMM is not compatible + local extra_args=() + if [[ "$BUILD_ENVIRONMENT" == *cuda13* ]]; then + extra_args=(--exclude-exact torchrec_dlrm) + fi + if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then - test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@" + test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "${extra_args[@]}" "$@" elif [[ "${TEST_CONFIG}" == *perf* ]]; then # TODO (huydhn): Just smoke test some sample models if [[ "${TEST_CONFIG}" == *b200* ]]; then @@ -899,7 +905,7 @@ test_dynamo_benchmark() { export TORCHBENCH_ONLY_MODELS="BERT_pytorch" fi fi - test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@" + test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "${extra_args[@]}" "$@" else if [[ "${TEST_CONFIG}" == *cpu* ]]; then local dt="float32" @@ -907,17 +913,17 @@ test_dynamo_benchmark() { dt="amp" fi if [[ "${TEST_CONFIG}" == *freezing* ]]; then - test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --"$dt" --freezing "$@" + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --"$dt" --freezing "${extra_args[@]}" "$@" else - test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --"$dt" "$@" + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --"$dt" "${extra_args[@]}" "$@" fi elif [[ "${TEST_CONFIG}" == *aot_inductor* ]]; then - test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@" + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "${extra_args[@]}" "$@" elif [[ "${TEST_CONFIG}" == *max_autotune_inductor* ]]; then - test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@" + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "${extra_args[@]}" "$@" else - test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@" - test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@" + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "${extra_args[@]}" "$@" + test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "${extra_args[@]}" "$@" fi fi } @@ -1928,7 +1934,8 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then else # Do this after checkout_install_torchbench to ensure we clobber any # nightlies that torchbench may pull in - if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then + # Skip torchrec/fbgemm for cuda13 as they're not compatible yet + if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* && "${BUILD_ENVIRONMENT}" != *cuda13* ]]; then install_torchrec_and_fbgemm fi PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index fa1f083800fe0..31b189142172b 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -53,6 +53,7 @@ jobs: pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks, + pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.10-clang12, pytorch-linux-jammy-py3.11-clang12, diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 3421e2b9af77d..c10327c8f548c 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -55,3 +55,30 @@ jobs: test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 secrets: inherit + + build-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-build.yml + needs: + - get-default-label-prefix + with: + runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks + cuda-arch-list: '8.0' + test-matrix: | + { include: [ + { config: "inductor-micro-benchmark", shard: 1, num_shards: 1, runner: "linux.aws.a100", owners: ["oncall:pt2"] }, + ]} + secrets: inherit + + test-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: build-cuda13 + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + docker-image: ${{ needs.build-cuda13.outputs.docker-image }} + test-matrix: ${{ needs.build-cuda13.outputs.test-matrix }} + timeout-minutes: 720 + secrets: inherit diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 764e631819ccc..d38818eef4000 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -59,3 +59,37 @@ jobs: monitor-log-interval: 15 monitor-data-collect-interval: 4 secrets: inherit + + build-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-build.yml + needs: + - get-default-label-prefix + with: + runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks + cuda-arch-list: '8.0' + test-matrix: | + { include: [ + { config: "inductor_huggingface_perf_compare", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf_compare", shard: 1, num_shards: 2, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf_compare", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, + { config: "inductor_torchbench_perf_compare", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, + ]} + build-additional-packages: "vision audio torchao" + secrets: inherit + + test-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: build-cuda13 + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + docker-image: ${{ needs.build-cuda13.outputs.docker-image }} + test-matrix: ${{ needs.build-cuda13.outputs.test-matrix }} + # disable monitor in perf tests for more investigation + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 88a528ba1b075..2617fc990b933 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -164,3 +164,89 @@ jobs: monitor-log-interval: 15 monitor-data-collect-interval: 4 secrets: inherit + + build-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + # Every bit to make perf run faster helps + runner: linux.12xlarge.memory + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks + cuda-arch-list: '8.0' + test-matrix: | + { include: [ + { config: "inductor_huggingface_perf", shard: 1, num_shards: 5, runner: "linux.aws.a100" }, + { config: "inductor_huggingface_perf", shard: 2, num_shards: 5, runner: "linux.aws.a100" }, + { config: "inductor_huggingface_perf", shard: 3, num_shards: 5, runner: "linux.aws.a100" }, + { config: "inductor_huggingface_perf", shard: 4, num_shards: 5, runner: "linux.aws.a100" }, + { config: "inductor_huggingface_perf", shard: 5, num_shards: 5, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf", shard: 1, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf", shard: 2, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf", shard: 3, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf", shard: 4, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf", shard: 5, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_timm_perf", shard: 6, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_torchbench_perf", shard: 1, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_torchbench_perf", shard: 2, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_torchbench_perf", shard: 3, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_torchbench_perf", shard: 4, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_torchbench_perf", shard: 5, num_shards: 6, runner: "linux.aws.a100" }, + { config: "inductor_torchbench_perf", shard: 6, num_shards: 6, runner: "linux.aws.a100" }, + { config: "cachebench", shard: 1, num_shards: 2, runner: "linux.aws.a100" }, + { config: "cachebench", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, + ]} + selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio torchao" + secrets: inherit + + test-nightly-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: build-cuda13 + if: github.event.schedule == '0 7 * * 1-6' + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true + docker-image: ${{ needs.build-cuda13.outputs.docker-image }} + test-matrix: ${{ needs.build-cuda13.outputs.test-matrix }} + timeout-minutes: 720 + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit + + test-weekly-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: build-cuda13 + if: github.event.schedule == '0 7 * * 0' + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true + docker-image: ${{ needs.build-cuda13.outputs.docker-image }} + test-matrix: ${{ needs.build-cuda13.outputs.test-matrix }} + timeout-minutes: 1440 + # disable monitor in perf tests, next step is to enable it + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit + + test-cuda13: + name: cuda13.0-py3.10-gcc11-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: build-cuda13 + if: github.event_name == 'workflow_dispatch' + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm80 + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} + docker-image: ${{ needs.build-cuda13.outputs.docker-image }} + test-matrix: ${{ needs.build-cuda13.outputs.test-matrix }} + timeout-minutes: 720 + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index f3e34d6ecb52f..2a2f9049da99b 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -81,6 +81,56 @@ jobs: test-matrix: ${{ needs.periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit + periodic-dynamo-benchmarks-build-cuda13: + name: periodic-dynamo-benchmarks-build-cuda13 + uses: ./.github/workflows/_linux-build.yml + needs: get-default-label-prefix + with: + runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks + cuda-arch-list: '8.0;8.6' + test-matrix: | + { include: [ + { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + build-additional-packages: "vision audio torchao" + secrets: inherit + + periodic-dynamo-benchmarks-test-cuda13: + name: periodic-dynamo-benchmarks-test-cuda13 + uses: ./.github/workflows/_linux-test.yml + needs: periodic-dynamo-benchmarks-build-cuda13 + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm86 + docker-image: ${{ needs.periodic-dynamo-benchmarks-build-cuda13.outputs.docker-image }} + test-matrix: ${{ needs.periodic-dynamo-benchmarks-build-cuda13.outputs.test-matrix }} + secrets: inherit + rocm-periodic-dynamo-benchmarks-build: if: github.repository_owner == 'pytorch' name: rocm-periodic-dynamo-benchmarks-build @@ -158,6 +208,7 @@ jobs: test-matrix: ${{ needs.inductor-smoke-build.outputs.test-matrix }} secrets: inherit + periodic-dynamo-benchmarks-cpu-build: name: periodic-dynamo-benchmarks-cpu-build uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index e524ed548b741..b54910164fe62 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -74,6 +74,36 @@ jobs: test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit + inductor-build-cuda13: + name: inductor-build-cuda13 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks + cuda-arch-list: '8.6' + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + test-matrix: | + { include: [ + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + ]} + build-additional-packages: "vision audio torchao" + secrets: inherit + + inductor-test-cuda13: + name: inductor-test-cuda13 + uses: ./.github/workflows/_linux-test.yml + needs: inductor-build-cuda13 + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm86 + docker-image: ${{ needs.inductor-build-cuda13.outputs.docker-image }} + test-matrix: ${{ needs.inductor-build-cuda13.outputs.test-matrix }} + secrets: inherit + inductor-cpu-build: name: inductor-cpu-build uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index e1d46de9110b4..c85e2813a7f37 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -343,8 +343,33 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.test-matrix }} secrets: inherit - linux-noble-xpu-n-py3_10-build: - name: linux-noble-xpu-n-py3.10 + linux-jammy-cuda13_0-py3_10-gcc11-inductor-build: + name: cuda13.0-py3.10-gcc11-sm75 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm75 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks + cuda-arch-list: '7.5' + test-matrix: | + { include: [ + { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda13_0-py3_10-gcc11-inductor-test: + name: cuda13.0-py3.10-gcc11-sm75 + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda13_0-py3_10-gcc11-inductor-build + with: + build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm75 + docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-inductor-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-xpu-n-py3_10-build: + name: linux-jammy-xpu-n-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index d458bde5f9d30..8c873f5d15162 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -241,6 +241,16 @@ jobs: cuda-arch-list: '8.0' secrets: inherit + inductor-build-cuda13: + name: inductor-build-cuda13 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-cuda13.0-py3.12-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks + cuda-arch-list: '8.0' + secrets: inherit + # Test cross-compiled models with Windows libs extracted from wheel cross-compile-linux-test: name: cross-compile-linux-test diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index dfff1f2ad833a..9c0c1b6fd32af 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1400,6 +1400,9 @@ if(NOT INTERN_BUILD_MOBILE) # https://github.com/pytorch/pytorch/pull/55292 string(APPEND CMAKE_CUDA_FLAGS " -DCUB_WRAPPED_NAMESPACE=at_cuda_detail") + # Suppress cusparse warnings + string(APPEND CMAKE_CUDA_FLAGS " -DDISABLE_CUSPARSE_DEPRECATED") + message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor") string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1" " -D__CUDA_NO_HALF_OPERATORS__" From d54ff42903c2ae0533931ff11d23b35f875bdb3d Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 3 Dec 2025 13:22:47 -0800 Subject: [PATCH 242/338] [dynamo, guards] apply functools.cached_property to Source.name (#168131) Partial fix for https://github.com/pytorch/pytorch/issues/168118. Decreases guard build time from 25s -> 16s on a local tlparse. On the guard build benchmark, time went from 135.81s -> 84.66s However, there are a lot of changes to `source.name()` callsites. This could technically be avoided by writing our own `cached_property` decorator that requires the function call. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168131 Approved by: https://github.com/anijain2305, https://github.com/aorenste --- test/dynamo/test_guard_manager.py | 4 +- test/dynamo/test_subclasses.py | 2 +- torch/_dynamo/eval_frame.py | 4 +- torch/_dynamo/graph_break_registry.json | 6 +- torch/_dynamo/guards.py | 42 +++--- torch/_dynamo/output_graph.py | 14 +- torch/_dynamo/source.py | 133 ++++++++++++------ torch/_dynamo/variables/builder.py | 32 ++--- torch/_dynamo/variables/higher_order_ops.py | 4 +- torch/_dynamo/variables/misc.py | 2 +- torch/_dynamo/variables/nn_module.py | 2 +- torch/_dynamo/variables/optimizer.py | 6 +- torch/_dynamo/variables/tensor.py | 4 +- torch/_export/non_strict_utils.py | 2 +- .../_aot_autograd/frontend_utils.py | 2 +- torch/_guards.py | 5 +- torch/_subclasses/meta_utils.py | 4 +- torch/export/exported_program.py | 2 +- torch/fx/experimental/symbolic_shapes.py | 91 ++++++------ 19 files changed, 208 insertions(+), 153 deletions(-) diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index f11c04c8071d8..5515500d7cda7 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -928,8 +928,8 @@ def hook(guard_wrapper, f_locals, builder): foo_source = LocalSource("foo") foo_x_source = AttrSource(foo_source, "x") - self.assertTrue(builder.get(foo_source.name()) is foo) - self.assertTrue(builder.get(foo_x_source.name()) is foo.x) + self.assertTrue(builder.get(foo_source.name) is foo) + self.assertTrue(builder.get(foo_x_source.name) is foo.x) # Check types of foo.x foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 25c0da48f602f..3ee7119e8e02b 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1624,7 +1624,7 @@ def backend(gm, args): str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items() } curr_var_to_sources = { - str(k): v[0].name() + str(k): v[0].name for k, v in context.fake_mode.shape_env.var_to_sources.items() } return gm diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index a9091767f70fd..2249bc5aa762b 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1713,13 +1713,13 @@ def check_signature_rewritable(graph: torch.fx.GraphModule) -> None: stack = s break if stack is None: - msg = f"{source.name()}, a closed over free variable" + msg = f"{source.name}, a closed over free variable" else: tb = "".join(traceback.format_list(stack)) extra = "" if len(user_stacks) > 1: extra = f"(elided {len(user_stacks) - 1} more accesses)" - msg = f"{source.name()}, accessed at:\n{tb}{extra}" + msg = f"{source.name}, accessed at:\n{tb}{extra}" # TODO: option to print ALL of the stack traces at once input_errors.append(msg) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index a425fae65a377..5e706e77ba73d 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -389,7 +389,7 @@ { "Gb_type": "Encountered aliasing during higher order op tracing", "Context": "context", - "Explanation": "Higher order ops do not support aliasing. Found in {source_target.name()}", + "Explanation": "Higher order ops do not support aliasing. Found in {source_target.name}", "Hints": [ "Replace `return input` with `return input.clone()` to avoid aliasing.", "Consider using the debug context to change user code to avoid aliasing.", @@ -401,7 +401,7 @@ { "Gb_type": "Encountered input mutation during higher order op tracing", "Context": "context", - "Explanation": "Higher order ops do not support input mutation. Found in {source_target.name()}", + "Explanation": "Higher order ops do not support input mutation. Found in {source_target.name}", "Hints": [ "Consider using the debug context to change user code to avoid mutation.", "Please open an issue." @@ -1469,7 +1469,7 @@ { "Gb_type": "Unsupported function call (delayed)", "Context": "source: {self.source}", - "Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name()}`. Reason: {self.msg}", + "Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name}`. Reason: {self.msg}", "Hints": [] } ], diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 322578dc6444f..0f44cabf66f4a 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -914,7 +914,7 @@ def getitem_on_dict_manager( example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: - base_source_name = source.base.name() + base_source_name = source.base.name if isinstance(source.index, ConstDictKeySource): index = source.index.index else: @@ -1043,9 +1043,9 @@ def __init__( self.key_order_guarded_dict_ids = set() assert self.check_fn_manager.output_graph is not None for source in self.check_fn_manager.output_graph.guard_on_key_order: - dict_obj = self.get(source.name()) + dict_obj = self.get(source.name) if self.save_guards: - self.source_get_cache[source.name()] = dict_obj + self.source_get_cache[source.name] = dict_obj self.key_order_guarded_dict_ids.add(id(dict_obj)) # Keep track of weak references of objects with ID_MATCH guard. This @@ -1073,7 +1073,7 @@ def guard_on_dict_keys_and_ignore_order( ) # Iterate over the dicts and install a dict_getitem_manager. - dict_source = guard.originating_source.name() + dict_source = guard.originating_source.name # Ensure that we call dict.keys and not value.keys (which can call # overridden keys method). In the C++ guards, we relied on PyDict_Next @@ -1256,7 +1256,7 @@ def getitem_on_dict_mgr( l1_guard_manager_enum = l2_guard_manager_enum = None if l2_key: l1_source = AttrSource(source.base, l1_key) - l1_source_name = l1_source.name() + l1_source_name = l1_source.name l1_value = mod_dict[l1_key] # do not guard on key order for _parameters etc unless the user code # actually needs the key order (e.g. calling named_parameters) @@ -1304,7 +1304,7 @@ def getitem_on_dict_mgr( return l1_mgr def requires_key_order_guarding(self, source: Source) -> bool: - source_name = source.name() + source_name = source.name if source_name == "": return False obj_id = id(self.get(source_name)) @@ -1347,7 +1347,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: root_guard_manager = self.guard_manager.root example_value = None - source_name = source.name() + source_name = source.name if source_name != "" and source_name in self._cached_guard_managers: return self._cached_guard_managers[source_name] @@ -1364,7 +1364,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: base_guard_manager = None base_guard_manager_enum = GuardManagerType.GUARD_MANAGER if isinstance(source, ChainedSource): - base_source_name = source.base.name() + base_source_name = source.base.name base_example_value = self.get(base_source_name) base_guard_manager = self.get_guard_manager_from_source(source.base) base_guard_manager_enum = self.get_guard_manager_type( @@ -1755,10 +1755,10 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: ) else: raise AssertionError( - f"missing guard manager builder {source} - {source.name()}" + f"missing guard manager builder {source} - {source.name}" ) - self._cached_guard_managers[source.name()] = out + self._cached_guard_managers[source.name] = out return out def get_guard_manager(self, guard: Guard) -> GuardManager: @@ -1857,7 +1857,7 @@ def HASATTR(self, guard: Guard) -> None: return assert isinstance(source, AttrSource), f"invalid source {guard.name}" base_source = source.base - base = base_source.name() + base = base_source.name attr = source.member ref = self.arg_ref(base) @@ -1879,7 +1879,7 @@ def HASATTR(self, guard: Guard) -> None: if val: # Just install a getattr manager. GetAttrGuardAccessor itself # acts as hasattr guard. - example_value = self.get(source.name()) + example_value = self.get(source.name) base_example_value = self.get(base) guard_manager_enum = self.get_guard_manager_type(source, example_value) @@ -1892,7 +1892,7 @@ def HASATTR(self, guard: Guard) -> None: base_example_value, example_value, base, - source.name(), + source.name, guard_manager_enum, ) else: @@ -2434,7 +2434,7 @@ def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None: self.check_fn_manager.additional_used_global_vars.add(name) ref_a = self.arg_ref(guard) - ref_b = self.arg_ref(source_b.name()) + ref_b = self.arg_ref(source_b.name) if is_from_optimizer_source( guard.originating_source @@ -2709,7 +2709,7 @@ def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: python_fallback = True else: example_value = self.get( - source.name(), + source.name, closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, ) if isinstance(example_value, int): @@ -3919,11 +3919,11 @@ def source_ref(source: Source) -> str: guard_source = source.guard_source() if guard_source is GuardSource.CONSTANT: # No need to track constants - return source.name() + return source.name assert w_builder r_builder = w_builder() assert r_builder is not None - return r_builder.arg_ref(source.name()) + return r_builder.arg_ref(source.name) builder = GuardBuilder( f_code, @@ -4087,7 +4087,7 @@ def add_code_part( if isinstance(guard, DuplicateInputs): source_a = guard.input_source_a source_b = guard.input_source_b - code_part = f"{source_a.name()} is {source_b.name()}" + code_part = f"{source_a.name} is {source_b.name}" install_object_aliasing_guard( builder.get_guard_manager_from_source(source_a), builder.get_guard_manager_from_source(source_b), @@ -4105,8 +4105,8 @@ def add_code_part( ] code_part = ( """check_overlapping(""" - f"""overlapping=[{", ".join(s.name() for s in guard.overlapping_sources)}], """ - f"""non_overlapping=[{", ".join(s.name() for s in guard.non_overlapping_sources)}])""" + f"""overlapping=[{", ".join(s.name for s in guard.overlapping_sources)}], """ + f"""non_overlapping=[{", ".join(s.name for s in guard.non_overlapping_sources)}])""" ) install_storage_overlapping_guard( overlapping_guard_managers, @@ -4585,7 +4585,7 @@ def make_dupe_guard( dupe_source ) or is_from_flatten_script_object_source(obj_source): raise exc.UnsafeScriptObjectError( - f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported." + f"{obj_source.name} is aliasing {dupe_source.name}. This is not supported." f" Please do a clone for corresponding input." ) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 6ff908ff0394f..0d409869ccec5 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1235,7 +1235,7 @@ def register_leaf_name(leaf_name: str) -> None: self.param_name_to_source[new_name] = new_source if isinstance(source, LocalSource): self.dynamo_flat_name_to_original_fqn[ - OutputGraph.module_key_name(new_source.name()) + OutputGraph.module_key_name(new_source.name) ] = leaf_name # annoying, but there are cases when we do not have parameters @@ -2566,7 +2566,7 @@ def placeholder_binds_symbol(node: fx.Node) -> Optional[sympy.Symbol]: return None def remove_unused(node: fx.Node) -> None: - log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) + log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name) # I'm not really sure why you need to delete these from the # node since the node is going to get removed del node.meta["grapharg"] @@ -2748,7 +2748,7 @@ def example_value_from_input_node(self, node: torch.fx.Node) -> Any: def add_fqn_info_for_inlined_modules( self, inlined_module: torch.nn.Module, source: Source ) -> None: - name = OutputGraph.module_key_name(source.name()) + name = OutputGraph.module_key_name(source.name) name = get_unique_name_wrt( name, self.used_inlined_inbuilt_modules_names, self.global_scope ) @@ -2761,7 +2761,7 @@ def register_leaf_name(leaf_name: str) -> None: self.param_name_to_source[new_name] = new_source if isinstance(source, LocalSource): self.dynamo_flat_name_to_original_fqn[ - OutputGraph.module_key_name(new_source.name()) + OutputGraph.module_key_name(new_source.name) ] = leaf_name # annoying, but there are cases when we do not have parameters @@ -3312,7 +3312,7 @@ def create_graph_input( log.debug( "create_graph_input %s %s %s at debug_level %s before=%s", name, - source.name() if source is not None else "(none)", + source.name if source is not None else "(none)", example_value, self.debug_level, before, @@ -3658,7 +3658,7 @@ def _lift_symbols_in_symint( log.debug( "_lift_symbols_in_symint %s from %s at debug_level %s", s0, - source.name() if source is not None else "subgraph inputs", + source.name if source is not None else "subgraph inputs", self.debug_level, ) self.lifted_freevars[parent_proxy] = ph # type: ignore[index] @@ -3684,7 +3684,7 @@ def _lift_symbols_in_symint( log.debug( "_lift_symbols_in_symint %s from %s at debug_level %s", s, - source.name() if source is not None else "subgraph inputs", + source.name if source is not None else "subgraph inputs", self.debug_level, ) ph.node.meta["grapharg"] = GraphArg( diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index a5a69cd177c27..8b42472465984 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -117,7 +117,7 @@ def _get_source_debug_name(source: Optional[Source]) -> str: return "" else: try: - return source.name() + return source.name except NotImplementedError: return "" @@ -147,6 +147,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return GuardSource.LOCAL + @functools.cached_property def name(self) -> str: return f"L[{repr(self.local_name)}]" @@ -162,6 +163,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return GuardSource.TEMP_LOCAL + @property def name(self) -> str: raise NotImplementedError( "Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub." @@ -178,6 +180,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return GuardSource.SYNTHETIC_LOCAL + @functools.cached_property def name(self) -> str: return f"SYNTHETIC_LOCAL[{self.local_name!r}]" @@ -194,6 +197,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.random_call_index)) codegen.append_output(create_binary_subscr()) + @functools.cached_property def name(self) -> str: return f"random_value_{self.random_call_index}" @@ -208,6 +212,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return GuardSource.GLOBAL + @functools.cached_property def name(self) -> str: return f"G[{repr(self.global_name)}]" @@ -227,6 +232,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return GuardSource.GLOBAL + @functools.cached_property def name(self) -> str: return f"G[{repr(self.global_name)}]()" @@ -240,8 +246,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}()" + return f"{self.base.name}()" @dataclasses.dataclass(frozen=True) @@ -269,10 +276,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: if not self.member.isidentifier(): - return f"getattr({self.base.name()}, {self.member!r})" - return f"{self.base.name()}.{self.member}" + return f"getattr({self.base.name}, {self.member!r})" + return f"{self.base.name}.{self.member}" @dataclasses.dataclass(frozen=True) @@ -295,8 +303,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"object.__getattribute__({self.base.name()}, {self.member!r})" + return f"object.__getattribute__({self.base.name}, {self.member!r})" # Represents obj.__dict__ where obj is a type object @@ -309,12 +318,13 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: # type(ob).__dict__ can return a proxy of the dict. But in the C++ # guard accessor, we are use type->tp_dict which is a dict. So, # forcefully pass a dict object to ensure that the GuardManager # registers that its working on a dict object. - return f"dict({self.base.name()}.__dict__)" + return f"dict({self.base.name}.__dict__)" # Represents obj.__mro__ where object is type object @@ -327,8 +337,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.__mro__" + return f"{self.base.name}.__mro__" @dataclasses.dataclass(frozen=True) @@ -360,8 +371,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.__code__" + return f"{self.base.name}.__code__" # Represents obj.__closure__ where object is type object @@ -374,8 +386,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.__closure__" + return f"{self.base.name}.__closure__" # Represents tensor.grad source. It could be represented by AttrSource as well. @@ -393,8 +406,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.{self.member}" + return f"{self.base.name}.{self.member}" @dataclasses.dataclass(frozen=True) @@ -425,6 +439,7 @@ class EphemeralSource(Source): def guard_source(self) -> GuardSource: return GuardSource.EPHEMERAL + @functools.cached_property def name(self) -> str: return f"" @@ -443,8 +458,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @property def name(self) -> str: - return self.base.name() + return self.base.name class TensorProperty(enum.Enum): @@ -492,14 +508,15 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: if self.prop is TensorProperty.SIZE: - return f"{self.base.name()}.size()[{self.idx}]" + return f"{self.base.name}.size()[{self.idx}]" elif self.prop is TensorProperty.STRIDE: - return f"{self.base.name()}.stride()[{self.idx}]" + return f"{self.base.name}.stride()[{self.idx}]" elif self.prop is TensorProperty.STORAGE_OFFSET: assert self.idx is None - return f"{self.base.name()}.storage_offset()" + return f"{self.base.name}.storage_offset()" else: raise AssertionError(f"unhandled {self.prop}") @@ -517,8 +534,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"({self.idx}, {self.base.name()})" + return f"({self.idx}, {self.base.name})" @dataclasses.dataclass(frozen=True) @@ -532,9 +550,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: # NB: use method call so that function stripping regexes work - return f"{self.base.name()}.__neg__()" + return f"{self.base.name}.__neg__()" @dataclasses.dataclass(frozen=True) @@ -548,8 +567,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"cast_symbool_to_symint_guardless({self.base.name()})" + return f"cast_symbool_to_symint_guardless({self.base.name})" @dataclasses.dataclass(frozen=True) @@ -571,8 +591,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"int({self.base.name()})" + return f"int({self.base.name})" @dataclasses.dataclass(frozen=True) @@ -586,8 +607,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.__obj_flatten__()" + return f"{self.base.name}.__obj_flatten__()" @dataclasses.dataclass(frozen=True) @@ -601,8 +623,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}._type().qualified_name()" + return f"{self.base.name}._type().qualified_name()" class AttrProxySource(ChainedSource): @@ -612,8 +635,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.get_base()" + return f"{self.base.name}.get_base()" @dataclasses.dataclass(frozen=True) @@ -631,13 +655,13 @@ def __post_init__(self) -> None: assert isinstance(self.idx_key, str) object.__setattr__(self, "field", "__kwdefaults__") object.__setattr__( - self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']" + self, "_name", f"{self.base.name}.{self.field}['{self.idx_key}']" ) else: assert isinstance(self.idx_key, int) object.__setattr__(self, "field", "__defaults__") object.__setattr__( - self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" + self, "_name", f"{self.base.name}.{self.field}[{self.idx_key}]" ) def reconstruct(self, codegen: "PyCodegen") -> None: @@ -649,6 +673,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: return self._name @@ -681,15 +706,16 @@ def unpack_slice(self) -> slice: slice_class, slice_args = self.index return slice_class(*slice_args) + @functools.cached_property def name(self) -> str: # Index can be of following types # 1) index is a slice - example 1:4 # 2) index is a constant - example string, integer assert not isinstance(self.index, Source) if self.index_is_slice: - return f"{self.base.name()}[{self.unpack_slice()!r}]" + return f"{self.base.name}[{self.unpack_slice()!r}]" else: - return f"{self.base.name()}[{self.index!r}]" + return f"{self.base.name}[{self.index!r}]" @dataclasses.dataclass(frozen=True) @@ -707,9 +733,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.extend_output(create_call_function(2, False)) + @functools.cached_property def name(self) -> str: # The list creation will be CSE'd by PyExprCSEPass - return f"list(dict.keys({self.base.name()}))[{self.index!r}]" + return f"list(dict.keys({self.base.name}))[{self.index!r}]" def is_dict_key(self) -> bool: return True @@ -735,9 +762,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.extend_output(create_call_function(2, False)) + @functools.cached_property def name(self) -> str: # set ordering might not be stable - return f"list({self.base.name()})[{self.index!r}]" + return f"list({self.base.name})[{self.index!r}]" def is_dict_key(self) -> bool: return False @@ -772,11 +800,12 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(create_binary_subscr()) + @functools.cached_property def name(self) -> str: if isinstance(self.index, ConstDictKeySource): - return f"{self.base.name()}[{self.index.name()}]" + return f"{self.base.name}[{self.index.name}]" else: - return f"{self.base.name()}[{self.index!r}]" + return f"{self.base.name}[{self.index!r}]" # Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that @@ -817,11 +846,12 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) + @functools.cached_property def name(self) -> str: if isinstance(self.index, ConstDictKeySource): - return f"dict.__getitem__({self.base.name()}, {self.index.name()})" + return f"dict.__getitem__({self.base.name}, {self.index.name})" else: - return f"{self.base.name()}[{self.index!r}]" + return f"{self.base.name}[{self.index!r}]" @dataclasses.dataclass(frozen=True) @@ -852,6 +882,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) + @functools.cached_property def name(self) -> str: # Index can be of following types # 1) index is a slice - example 1:4 @@ -862,7 +893,7 @@ def name(self) -> str: "List[slice] is a temporary object and should not have a source" ) else: - return f"list.__getitem__({self.base.name()}, {self.index!r})" + return f"list.__getitem__({self.base.name}, {self.index!r})" @dataclasses.dataclass(frozen=True) @@ -875,8 +906,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.extend_output(create_call_function(2, False)) + @functools.cached_property def name(self) -> str: - return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" + return f"___tuple_iterator_getitem({self.base.name}, {self.index!r})" @dataclasses.dataclass(frozen=True) @@ -888,8 +920,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"___namedtuple_fields({self.base.name()})" + return f"___namedtuple_fields({self.base.name})" @dataclasses.dataclass(frozen=True) @@ -904,8 +937,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"___dataclass_fields({self.base.name()})" + return f"___dataclass_fields({self.base.name})" @dataclasses.dataclass(frozen=True) @@ -921,8 +955,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return f"type({self.base.name()})" + return f"type({self.base.name})" @dataclasses.dataclass(frozen=True) @@ -933,8 +968,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() + @functools.cached_property def name(self) -> str: - return self.base.name() + return self.base.name @dataclasses.dataclass(frozen=True) @@ -945,8 +981,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + @functools.cached_property def name(self) -> str: - return self.base.name() + return self.base.name @dataclasses.dataclass(frozen=True) @@ -969,6 +1006,7 @@ def guard_source(self) -> GuardSource: @dataclasses.dataclass(frozen=True) class GlobalStateSource(Source): + @property def name(self) -> str: return "" @@ -987,6 +1025,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: install_guard(self.make_guard(GuardBuilder.ID_MATCH)) + @property def name(self) -> str: return "__import__('torch')" @@ -1014,6 +1053,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: install_guard(self.make_guard(GuardBuilder.ID_MATCH)) + @property def name(self) -> str: return "__import__('collections')" @@ -1034,6 +1074,7 @@ def guard_source(self) -> GuardSource: class TorchFunctionModeStackSource(Source): ind: int + @functools.cached_property def name(self) -> str: return f"___get_torch_function_mode_stack_at({self._get_index()})" @@ -1065,6 +1106,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return GuardSource.CONSTANT + @functools.cached_property def name(self) -> str: return self.source_name @@ -1074,8 +1116,9 @@ def make_guard(self, fn: Any) -> Any: @dataclasses.dataclass(frozen=True) class NumpyTensorSource(ChainedSource): + @functools.cached_property def name(self) -> str: - return f"___from_numpy({self.base.name()})" + return f"___from_numpy({self.base.name})" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1088,8 +1131,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: @dataclasses.dataclass(frozen=True) class SubclassAttrListSource(ChainedSource): + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.__tensor_flatten__()[0]" + return f"{self.base.name}.__tensor_flatten__()[0]" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1099,8 +1143,9 @@ def guard_source(self) -> GuardSource: # source, it is ephemeral @dataclasses.dataclass(frozen=True) class FloatTensorSource(ChainedSource): + @functools.cached_property def name(self) -> str: - return f"___as_tensor({self.base.name()})" + return f"___as_tensor({self.base.name})" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1108,8 +1153,9 @@ def guard_source(self) -> GuardSource: @dataclasses.dataclass(frozen=True) class CallMethodItemSource(ChainedSource): + @functools.cached_property def name(self) -> str: - return f"{self.base.name()}.item()" + return f"{self.base.name}.item()" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1120,6 +1166,7 @@ def guard_source(self) -> GuardSource: # guard contents from the ambient ShapeEnv @dataclasses.dataclass(frozen=True) class ShapeEnvSource(Source): + @property def name(self) -> str: return "" @@ -1131,6 +1178,7 @@ def guard_source(self) -> GuardSource: class CurrentStreamSource(Source): device: device_type + @functools.cached_property def name(self) -> str: return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))" @@ -1153,6 +1201,7 @@ def guard_source(self) -> GuardSource: @dataclasses.dataclass(frozen=True) class BackwardStateSource(Source): + @property def name(self) -> str: return "" diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index b41da586c799c..0be956a4cac67 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -389,7 +389,7 @@ def erase(self): self.example_strong_ref = None def __eq__(self, other): - return self.source.name() == other.source.name() + return self.source.name == other.source.name class BackwardStateGraphArg(GraphArg): @@ -444,7 +444,7 @@ def __init__( super().__init__() self.tx = tx self.source = source - self.name = source.name() + self.name = source.name def __call__(self, value): if value in self.tx.output.side_effects: @@ -1645,7 +1645,7 @@ def build_key_value(i, k, v): elif value.dynamism.type == _DimHintType.DYNAMIC: log.debug( "%s marked %s via IntWrapper", - self.source.name(), + self.source.name, DimDynamic.DYNAMIC, ) return self.wrap_symint( @@ -1658,7 +1658,7 @@ def build_key_value(i, k, v): elif value.dynamism.type == _DimHintType.AUTO: log.debug( "%s marked %s via IntWrapper", - self.source.name(), + self.source.name, DimDynamic.DYNAMIC, ) return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC) @@ -1831,7 +1831,7 @@ def mark_static_input(self, value: torch.Tensor, guard: bool): from ..decorators import mark_static_address static_inputs_log.debug( - "Marking static input %s, id: %s)", self.source.name(), id(value) + "Marking static input %s, id: %s)", self.source.name, id(value) ) mark_static_address(value, guard=guard) @@ -2003,12 +2003,12 @@ def wrap_module(self, value: torch.nn.Module): def wrap_literal(self, value): if type(value) is int: # allowlist has higher precedence over specialization control. - if is_dynamic_source(self.source.name()): - log.debug("%s marked dynamic via source whitelist", self.source.name()) + if is_dynamic_source(self.source.name): + log.debug("%s marked dynamic via source whitelist", self.source.name) return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC) - if is_unbacked_source(self.source.name()): - log.debug("%s marked unbacked via source whitelist", self.source.name()) + if is_unbacked_source(self.source.name): + log.debug("%s marked unbacked via source whitelist", self.source.name) return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED) if not config.specialize_int: @@ -2034,7 +2034,7 @@ def wrap_literal(self, value): process_automatic_dynamic( self.tx, - self.source.name(), + self.source.name, FrameStateSizeEntry.make_scalar(value), is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), ) @@ -2440,7 +2440,7 @@ def wrap_symint( self.install_guards(GuardBuilder.CONSTANT_MATCH) return ConstantVariable.create(value=value, source=self.source) - name = self.source.name() + name = self.source.name frame_state_entry = process_automatic_dynamic( self.tx, @@ -2453,7 +2453,7 @@ def wrap_symint( # know if bare integers are actually going to be sizevars # and it is inappropriate to eagerly duck size them with # real sizevars - normalized_source_name = normalize_source_name(self.source.name()) + normalized_source_name = normalize_source_name(self.source.name) base_source = self.source if isinstance(base_source, ChainedSource): base_source = base_source.get_base() @@ -2539,7 +2539,7 @@ def wrap_symfloat(self, value): frame_state_entry = process_automatic_dynamic( self.tx, - self.source.name(), + self.source.name, FrameStateSizeEntry.make_scalar(value), is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), ) @@ -3386,7 +3386,7 @@ def _automatic_dynamic( hints=[], ) - name = source.name() + name = source.name prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) shape_env_to_source_to_symbol_cache = ( prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None @@ -3509,7 +3509,7 @@ def update_dim2constraint(dim, constraint_range, name): # Reflect the user directive in the frame_state # For dynamic, apply None always - normalized_source_name = normalize_source_name(source.name()) + normalized_source_name = normalize_source_name(source.name) base_source = source if isinstance(base_source, ChainedSource): base_source = base_source.get_base() @@ -3670,7 +3670,7 @@ def wrap_to_fake_tensor_and_record( log.debug( "wrap_to_fake %s %s %s %s", - source.name(), + source.name, tuple(e.shape), symbolic_context, type(e), diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 4fe11f4dd03d1..a4543821b19b1 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1105,7 +1105,7 @@ def check_aliasing_and_input_mutation( unimplemented( gb_type="Encountered input mutation during higher order op tracing", context=context, - explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}", + explanation=f"Higher order ops do not support input mutation. Found in {source_target.name}", hints=[ "Consider using the debug context to change user code to avoid mutation.", "Please open an issue.", @@ -1119,7 +1119,7 @@ def check_aliasing_and_input_mutation( unimplemented( gb_type="Encountered aliasing during higher order op tracing", context=context, - explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}", + explanation=f"Higher order ops do not support aliasing. Found in {source_target.name}", hints=[ "Replace `return input` with `return input.clone()` to avoid aliasing.", "Consider using the debug context to change user code to avoid aliasing.", diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 5bd8ad5d075e6..748d4a0985b49 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -572,7 +572,7 @@ def call_function( gb_type="Unsupported function call (delayed)", context=f"source: {self.source}", explanation="Dynamo determined that a graph break should occur " - f"when calling `{self.source.name()}`. Reason: {self.msg}", + f"when calling `{self.source.name}`. Reason: {self.msg}", hints=[], ) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 525c42a009c1d..bb6952abf0b56 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -122,7 +122,7 @@ def convert_to_fake(x: Any) -> Any: def record_nn_module_stack( module_key: str, source: Source, tx: "InstructionTranslator", mod: torch.nn.Module ) -> Any: - fully_qualified_name = source.name() + fully_qualified_name = source.name # Remove redundant namings fully_qualified_name = re.sub( r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]", diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index fd7ccf9cc6e68..69ca37db4ef37 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -323,7 +323,7 @@ def mark_static(x: Any) -> None: # Note: to avoid spam logs only warn if perf hint artifact is enabled # (NB: artifacts are only enabled at the debug or warning level) if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG): - non_static_grad_names = [src.name() for src in non_static_grads] + non_static_grad_names = [src.name for src in non_static_grads] perf_hint_log.warning( ( "Grad tensors %s will be copied during cudagraphs execution." @@ -365,7 +365,7 @@ def wrap_tensor( # mark these tensors as static for cudagraphs mark_static_address(tensor_value, guard=True) source = self.tensor_to_source[tensor_value] - self.static_tensor_names.add(tx.output.module_key_name(source.name())) + self.static_tensor_names.add(tx.output.module_key_name(source.name)) elif tensor_value in self.grad_to_source: source = self.grad_to_source[tensor_value] else: @@ -374,7 +374,7 @@ def wrap_tensor( global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value) source = GlobalWeakRefSource(global_name) - self.static_tensor_names.add(tx.output.module_key_name(source.name())) + self.static_tensor_names.add(tx.output.module_key_name(source.name)) return VariableTracker.build(tx, tensor_value, source) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 548e69ef0262d..d47c520046d38 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -316,7 +316,7 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name): # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__, # L['mod'].model.model.encoder.embed_positions)", scope) # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope. - _input_associated_real_value = eval(self.source.name(), scope) + _input_associated_real_value = eval(self.source.name, scope) except Exception as exc: raise NotImplementedError from exc @@ -553,7 +553,7 @@ def call_id(self, tx): # For local source, we associate the real value. We use this real value scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} try: - _input_associated_real_value = eval(self.source.name(), scope) + _input_associated_real_value = eval(self.source.name, scope) except Exception as exc: unimplemented( gb_type="Error getting associated real value", diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 1c064845fe160..e80b96d1c68ce 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -280,7 +280,7 @@ def _create_symbolic_context_for_tensor(t, source, t_constraints, sources, mode) if isinstance(constraint, _RelaxedConstraint): continue symbolic_context.constraint_sizes[i] = constraint.constraint_range - mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment] + mode.shape_env.source_name_to_debug_name[src.name] = constraint.name # type: ignore[assignment] return symbolic_context diff --git a/torch/_functorch/_aot_autograd/frontend_utils.py b/torch/_functorch/_aot_autograd/frontend_utils.py index 4780fd2b8ebcc..041d321fec56d 100644 --- a/torch/_functorch/_aot_autograd/frontend_utils.py +++ b/torch/_functorch/_aot_autograd/frontend_utils.py @@ -173,7 +173,7 @@ def _try_get_metadata_from_dynamo( assert source is None or source not in seen_sources, source seen_sources.add(source) aot_autograd_arg_pos_to_source.append(source) - source_name = source.name() if source else str(source) + source_name = source.name if source else str(source) # input[i] in dynamo is now: # input[i + len(extra_params)] in AOT, diff --git a/torch/_guards.py b/torch/_guards.py index 1bd32fc7f08ec..8da885bbb683c 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -245,7 +245,7 @@ class Guard: # globals (and locals, if you create a LOCAL guard) to extract the Python # object that we want to perform guard tests on. This evaluation # typically happens in GuardBuilder.eval. In these cases, name is - # typically produced by originating_source.name() (not to be confused with + # typically produced by originating_source.name (not to be confused with # GuardSource - the property source). # # Occasionally, name is not a valid Python expression; sometimes @@ -297,7 +297,7 @@ def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]: @property def name(self) -> str: - return self.originating_source.name() + return self.originating_source.name @property def source(self) -> GuardSource: @@ -1092,6 +1092,7 @@ def reconstruct(self, codegen: PyCodegen) -> None: def guard_source(self) -> GuardSource: raise NotImplementedError + @functools.cached_property def name(self) -> str: raise NotImplementedError diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 4ede1d7234066..1db028fdbe2ee 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -870,7 +870,7 @@ def _backward_error(cls, t: _TensorT) -> _TensorT: # This function assumes that it's possible to do the conversion # NB: name here is used in a conventional way by Dynamo; it corresponds - # precisely to the Source.name() of the tensor we're fakeifying and + # precisely to the Source.name of the tensor we're fakeifying and # corresponds to a valid Python expression. When we construct sub-names # as part of this process, we will maintain this invariant! (Even though # other users of this may not need it this property to be upheld.) @@ -1937,7 +1937,7 @@ def __call__( metadata_fn=lambda: { "describer_id": self.describer.id, "id": t_desc.id, - "source": source.name(), + "source": source.name, }, ) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index afd73ce13d00b..ffcc7dff4941b 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -1707,7 +1707,7 @@ def _convert_guards_to_code(graph_module): ) } py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter( - shape_env.var_to_sources, lambda s: s.name(), shape_env.var_to_sources + shape_env.var_to_sources, lambda s: s.name, shape_env.var_to_sources ) ret = [ py_printer.doprint(guard.expr) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bacc95d4c9154..77b2681055c44 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1918,7 +1918,7 @@ class StrictMinMaxConstraint(Constraint): def render(self, source: Source) -> str: """Format the constrain equation""" # TODO: better printing for -oo and oo - return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}" + return f"{self.vr.lower} <= {source.name} <= {self.vr.upper}" @dataclass(frozen=True) @@ -1943,7 +1943,7 @@ class RelaxedUnspecConstraint(Constraint): """ def render(self, source: Source) -> str: - return f"RelaxedUnspecConstraint({source.name()})" + return f"RelaxedUnspecConstraint({source.name})" # NB: None here indicates the client constraint is whatever is implicitly @@ -2039,7 +2039,7 @@ def _rewrite(self, src: Source) -> sympy.Expr: return self._defs[src] else: # otherwise, create a symbol representing the source - return sympy.Symbol(src.name()) + return sympy.Symbol(src.name) def is_equal(self, source1: Source, source2: Source) -> bool: return ( @@ -2252,11 +2252,11 @@ class TrackedFake: symbolic_context: Optional[SymbolicContext] def __hash__(self) -> int: - return hash((self.fake, self.source.name())) + return hash((self.fake, self.source.name)) def __eq__(self, other: object) -> bool: if isinstance(other, TrackedFake): - return self.fake is other.fake and self.source.name() == other.source.name() + return self.fake is other.fake and self.source.name == other.source.name return False @@ -2712,7 +2712,7 @@ def _print_Symbol(self, expr: sympy.Symbol) -> str: def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str: return repr( { - symbol: [s.name() for s in sources] + symbol: [s.name for s in sources] for symbol, sources in src.items() } ) @@ -2820,7 +2820,7 @@ def print_source(self, source: Source) -> str: if source in self.source_to_symbol: return self.source_to_symbol[source].name - source_name = source.name() + source_name = source.name mangled_name = re.sub("[^0-9a-zA-Z_]+", "_", source_name) old_mangled_name = mangled_name count = 0 @@ -2849,7 +2849,7 @@ class _CppShapeGuardsHelper(_ShapeGuardsHelper): class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter): def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]): - super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) + super().__init__(var_to_sources, lambda n: n.name, var_to_sources) class DynamicDimConstraintPrinter(PythonPrinter): @@ -2875,7 +2875,7 @@ def _print_Symbol(self, expr: sympy.Symbol) -> str: assert self.symbol_to_source.get(expr), ( f"Unknown symbol {expr} created by constraints solver" ) - return self.symbol_to_source[expr][0].name() + return self.symbol_to_source[expr][0].name class DimConstraints: @@ -3095,7 +3095,7 @@ def add_equality(self, source: Source, expr: sympy.Expr) -> None: """Add an equality constraint""" if expr.is_number: # specialization, right here - self._static_results.add(f"{source.name()} == {expr}") + self._static_results.add(f"{source.name} == {expr}") else: # these will resolve to either specializations or dynamic equality constraints self._symbolic_equivalences.append((source, expr)) @@ -3175,7 +3175,7 @@ def solve(self) -> None: assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" # because this is univariate, the solution is a specialization self._static_results.add( - f"{self._dcp.symbol_to_source[s][0].name()} == {val}" + f"{self._dcp.symbol_to_source[s][0].name} == {val}" ) # add this as a substitution to simplify other constraints self._substitutions[s] = val # type: ignore[assignment] @@ -3200,8 +3200,8 @@ def solve(self) -> None: base, divisor = congruence.args tmp_name = "_" + str( self._dcp.source_name_to_debug_name.get( - self._dcp.symbol_to_source[s][0].name(), - self._dcp.symbol_to_source[s][0].name(), + self._dcp.symbol_to_source[s][0].name, + self._dcp.symbol_to_source[s][0].name, ) ) tmp = sympy.Symbol(tmp_name, integer=True) @@ -3243,7 +3243,7 @@ def solve(self) -> None: # remaining symbolic equivalences become dynamic equality constraints for source, expr3 in self._symbolic_equivalences: - self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr3)}") + self._dynamic_results.add(f"{source.name} == {self._dcp.doprint(expr3)}") @classmethod def _is_supported_congruence(cls, congruence: sympy.Expr) -> bool: @@ -3266,7 +3266,7 @@ def forced_specializations(self) -> dict[str, sympy.Expr]: """Returns a dictionary of the names of symbols to their specialized value""" def debug_name(src: Source) -> str: - name = src.name() + name = src.name if self._dcp.source_name_to_debug_name: return f"{self._dcp.source_name_to_debug_name[name]} = {name}" else: @@ -4011,7 +4011,7 @@ def patch_source_specialization( check_fn: A function that takes a sympy Symbol and returns a sympy expression representing a constraint/specialization to be applied """ - name = source.name() + name = source.name sym = self.source_to_var[name] expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr new_axioms = dict(self.get_implications(self.simplify(expr))) @@ -4284,7 +4284,7 @@ def freeze_runtime_asserts(self) -> None: def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]: if not self._translation_validation_enabled: return None - srcname = source.name() + srcname = source.name if source not in self.source_to_symbol: self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) return self.source_to_symbol[srcname] @@ -4874,7 +4874,7 @@ def _log_create_unbacked_symbol( if source is None: sloc, maybe_extra_debug = self._get_stack_summary(is_debug) else: - sloc, maybe_extra_debug = source.name(), "" + sloc, maybe_extra_debug = source.name, "" log.info( "%s %s [%s, %s] %s%s", prefix, @@ -5028,7 +5028,7 @@ def create_symbol( if constraint_dim.vr.lower != val: raise ConstraintViolationError( f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, " - f"for {source.name()}" + f"for {source.name}" ) if symbolic_context: from torch._dynamo.source import TensorPropertySource @@ -5041,7 +5041,7 @@ def create_symbol( constraint_dim = None # see note [Tensor Fakification and Symbol Caching] - source_name = source.name() + source_name = source.name if ( isinstance(symbolic_context, StatefulSymbolicContext) and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache @@ -5115,7 +5115,7 @@ def create_symbol( # If we're not duck shaping, we always create a new symbol # Even if we're duck shaping, if we haven't seen this particular # value before, we also create a new symbol - symbol_id = self._generate_unique_id(source.name()) + symbol_id = self._generate_unique_id(source.name) if type(val) is int or is_nested_int(val): sympy_expr = make_symbol( SymT.SIZE, symbol_id, positive=positive, integer=True @@ -5219,7 +5219,7 @@ def create_symbol( "create_symbol %s = %s for %s %s %s%s%s", sympy_expr, val, - source.name(), + source.name, range_str, sloc, maybe_more_info, @@ -5232,7 +5232,7 @@ def create_symbol( "symbol": str(sympy_expr), "val": repr(val), "vr": range_str, - "source": source.name(), + "source": source.name, "user_stack": structured.from_traceback( TracingContext.extract_stack() ), @@ -5248,7 +5248,7 @@ def create_symbol( # the same symint r = self.val_to_var[val] self.source_to_var[source_name] = r - self.log.debug("create_symbol %s duck sized %s", r, source.name()) + self.log.debug("create_symbol %s duck sized %s", r, source.name) if isinstance(r, sympy.Symbol): r_sources = self.var_to_sources[r] @@ -5275,7 +5275,7 @@ def add_var_to_val(self, expr: sympy.Symbol, val: int) -> None: self.var_to_val[expr] = sympy.Integer(val) def _debug_name(self, source: Source) -> str: - src_name = source.name() + src_name = source.name return self.source_name_to_debug_name.get(src_name, src_name) def _render_range_for_constraint_violation( @@ -5289,7 +5289,7 @@ def _render_range_for_constraint_violation( if upper >= default.upper: upper = None c_render = ( - f"{self._debug_name(source)} = {source.name()} in the specified range" + f"{self._debug_name(source)} = {source.name} in the specified range" ) if lower is not None and upper is not None: c_render += f" {lower} <= {self._debug_name(source)} <= {upper}" @@ -5311,7 +5311,7 @@ def produce_guards_verbose( self, placeholders: Sequence[FakeTensor], sources: Sequence[Source], - source_ref: Callable[[Source], str] = lambda n: n.name(), + source_ref: Callable[[Source], str] = lambda n: n.name, *, guards: Optional[list[ShapeGuard]] = None, input_contexts: Optional[DimList[SymbolicContext]] = None, @@ -5501,10 +5501,10 @@ def is_dim(src: object) -> TypeGuard[TensorPropertySource]: if equalities_inputs: source_index = {} for i, src in enumerate(sources): - source_index[src.name()] = i + source_index[src.name] = i def get_expression(tensor_dim_src: Source) -> sympy.Expr: - fake = placeholders[source_index[tensor_dim_src.base.name()]] # type: ignore[attr-defined] + fake = placeholders[source_index[tensor_dim_src.base.name]] # type: ignore[attr-defined] assert tensor_dim_src.idx is not None # type: ignore[attr-defined] symint = fake.shape[tensor_dim_src.idx] # type: ignore[attr-defined] if isinstance(symint, torch.SymInt): @@ -5521,16 +5521,16 @@ def get_expression(tensor_dim_src: Source) -> sympy.Expr: concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2)) if not concrete_val: raise ConstraintViolationError( - f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}" + f"{src1.name} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}" " is not equal to " - f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}" + f"{src2.name} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}" ) for srcEq, root, fn in equalities_inputs.derived_equalities: expr1 = get_expression(srcEq) # recall that root is either a phantom symbol or an input source if isinstance(root, sympy.Symbol): - expr2, debug_name = root, self.var_to_sources[root][0].name() + expr2, debug_name = root, self.var_to_sources[root][0].name elif isinstance(root, sympy.Integer): expr2, debug_name = root, str(root) else: @@ -5542,7 +5542,7 @@ def get_expression(tensor_dim_src: Source) -> sympy.Expr: concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_)) if not concrete_val: raise ConstraintViolationError( - f"Expected input {srcEq.name()} to be equal to " + f"Expected input {srcEq.name} to be equal to " f"{fn(sympy.Symbol(debug_name))}, " f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, " f"but got {expr1.xreplace(self.var_to_val)}" @@ -5569,7 +5569,12 @@ def get_expression(tensor_dim_src: Source) -> sympy.Expr: def track_symint( source: Source, val: IntLikeType, constraint: DimConstraint = None ) -> None: - log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint) + log.debug( + "track_symint %s %s %s", + LazyString(lambda: source.name), + val, + constraint, + ) assert not isinstance(val, SymInt) or is_symbolic(val) if isinstance(val, SymInt) and val.node.maybe_as_int() is not None: @@ -5658,7 +5663,7 @@ def hint(s: sympy.Expr) -> str: ) def track_symfloat(source: Source, val: FloatLikeType) -> None: - log.debug("track_symfloat %s %s", LazyString(source.name), val) + log.debug("track_symfloat %s %s", LazyString(lambda: source.name), val) assert not isinstance(val, SymFloat) or is_symbolic(val) if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None: @@ -5764,7 +5769,7 @@ def track_symfloat(source: Source, val: FloatLikeType) -> None: if not _simplified: for source, expr in input_guards: - srcname = source.name() + srcname = source.name if self._translation_validation_enabled: # Ignore sources that were not turned into SymInts. if srcname in self.source_to_symbol: @@ -5827,8 +5832,8 @@ def track_symfloat(source: Source, val: FloatLikeType) -> None: ) ): msg = ( - f"The values of {self._debug_name(source)} = {source.name()} and " - f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} " + f"The values of {self._debug_name(source)} = {source.name} and " + f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name} " "must always be equal." ) record_constraint_violation( @@ -5846,8 +5851,8 @@ def track_symfloat(source: Source, val: FloatLikeType) -> None: ): src = symbol_to_source[symbol][0] msg = ( - f"The values of {self._debug_name(source)} = {source.name()} must always be related to " - f"the values of {self._debug_name(src)} = {src.name()} by " + f"The values of {self._debug_name(source)} = {source.name} must always be related to " + f"the values of {self._debug_name(src)} = {src.name} by " f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}." ) record_constraint_violation( @@ -6868,7 +6873,7 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: "symbolic_shape_specialization", metadata_fn=lambda: { "symbol": repr(a), - "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "sources": [s.name for s in self.var_to_sources.get(a, [])], "value": repr(tgt), "reason": msg, "stack": structured.from_traceback( @@ -6886,7 +6891,7 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: if config.print_specializations: self.log.warning( - "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt + "Specializing %s to %s", self.var_to_sources[a][0].name, tgt ) self.log.debug("SPECIALIZATION", stack_info=True) log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) @@ -7211,7 +7216,7 @@ def go(x: Any) -> Optional[str]: if str(s) in frame_symbols: # type: ignore[operator] continue if s in self.var_to_sources: - frame_symbols[str(s)] = self.var_to_sources[s][0].name() # type: ignore[assignment] + frame_symbols[str(s)] = self.var_to_sources[s][0].name # type: ignore[assignment] return str(x) return None From bc43d5b297f207a11d83d77ddf0152bdaabe15a8 Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 3 Dec 2025 13:22:47 -0800 Subject: [PATCH 243/338] [dynamo, guards] cache GuardBuilder.get() on sources (#168203) Partial fix for https://github.com/pytorch/pytorch/issues/168118. Decreases guard build time from 16s -> 9s on a local tlparse. On the guard build benchmark, time went from 84.66s -> 57.25s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168203 Approved by: https://github.com/anijain2305 ghstack dependencies: #168131 --- test/dynamo/test_guard_manager.py | 4 +- test/dynamo/test_misc.py | 10 ++ torch/_dynamo/guards.py | 98 ++++++----- torch/_dynamo/source.py | 264 ++++++++++++++++-------------- torch/_guards.py | 51 +++++- 5 files changed, 262 insertions(+), 165 deletions(-) diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 5515500d7cda7..a563f66dc2aac 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -928,8 +928,8 @@ def hook(guard_wrapper, f_locals, builder): foo_source = LocalSource("foo") foo_x_source = AttrSource(foo_source, "x") - self.assertTrue(builder.get(foo_source.name) is foo) - self.assertTrue(builder.get(foo_x_source.name) is foo.x) + self.assertTrue(builder.get(foo_source) is foo) + self.assertTrue(builder.get(foo_x_source) is foo.x) # Check types of foo.x foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 78b5c7e4553da..b3e2e9d4fee4d 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -13304,6 +13304,16 @@ def f(*args, **kwargs): self.assertRaises(Unsupported, f, []) self.assertRaises(Unsupported, f, "1 + j") + def test_guard_string_escaped(self): + d = {frozenset({0}): {frozenset({0}): 1}} + + @torch.compile(backend="eager") + def f(x): + return x + d[frozenset({0})][frozenset({0})] + + x = torch.ones(3) + self.assertEqual(x + 1, f(x)) + def test_compiled_class_graph_break(self): counter = CompileCounter() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 0f44cabf66f4a..be327f7778723 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1003,6 +1003,9 @@ def __init__( self.source_ref = source_ref self.lookup_weakrefs = lookup_weakrefs self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope} + self.src_get_value_cache: weakref.WeakKeyDictionary[Source, object] = ( + weakref.WeakKeyDictionary() + ) self.runtime_global_scope = runtime_global_scope or global_scope self.source_get_cache = source_get_cache or {} self.scope["__builtins__"] = builtins.__dict__.copy() @@ -1043,7 +1046,7 @@ def __init__( self.key_order_guarded_dict_ids = set() assert self.check_fn_manager.output_graph is not None for source in self.check_fn_manager.output_graph.guard_on_key_order: - dict_obj = self.get(source.name) + dict_obj = self.get(source) if self.save_guards: self.source_get_cache[source.name] = dict_obj self.key_order_guarded_dict_ids.add(id(dict_obj)) @@ -1307,7 +1310,7 @@ def requires_key_order_guarding(self, source: Source) -> bool: source_name = source.name if source_name == "": return False - obj_id = id(self.get(source_name)) + obj_id = id(self.get(source)) return obj_id in self.key_order_guarded_dict_ids def get_guard_manager_type( @@ -1353,7 +1356,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: return self._cached_guard_managers[source_name] if source_name != "": - example_value = self.get(source_name) + example_value = self.get(source) self.guard_tree_values[id(example_value)] = example_value guard_manager_enum = self.get_guard_manager_type(source, example_value) @@ -1365,7 +1368,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: base_guard_manager_enum = GuardManagerType.GUARD_MANAGER if isinstance(source, ChainedSource): base_source_name = source.base.name - base_example_value = self.get(base_source_name) + base_example_value = self.get(source.base) base_guard_manager = self.get_guard_manager_from_source(source.base) base_guard_manager_enum = self.get_guard_manager_type( source.base, base_example_value @@ -1799,13 +1802,22 @@ def add_python_lambda_leaf_guard_to_root( # to this frame!) Instead, you should be reading out some property # (like its type) which is what you permanently install into the # guard code. - def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any: + def get( + self, + guard_or_source: Guard | Source, + closure_vars: Optional[dict[str, Any]] = None, + ) -> Any: + name = guard_or_source.name + if isinstance(guard_or_source, Source): + src = guard_or_source + else: + src = guard_or_source.originating_source if self.source_get_cache: if name in self.source_get_cache: return self.source_get_cache[name] if closure_vars is None: closure_vars = _get_closure_vars() - ret = eval(name, self.scope, closure_vars) + ret = src.get_value(self.scope, closure_vars, self.src_get_value_cache) if self.save_guards and ".__closure__" in name: self.source_get_cache[name] = ret return ret @@ -1861,7 +1873,7 @@ def HASATTR(self, guard: Guard) -> None: attr = source.member ref = self.arg_ref(base) - val = hasattr(self.get(base), attr) + val = hasattr(self.get(base_source), attr) code = None if val: code = f"hasattr({ref}, {attr!r})" @@ -1872,15 +1884,15 @@ def HASATTR(self, guard: Guard) -> None: return self._set_guard_export_info( - guard, [code], provided_guarded_object=self.get(base) + guard, [code], provided_guarded_object=self.get(base_source) ) base_manager = self.get_guard_manager_from_source(base_source) if val: # Just install a getattr manager. GetAttrGuardAccessor itself # acts as hasattr guard. - example_value = self.get(source.name) - base_example_value = self.get(base) + example_value = self.get(source) + base_example_value = self.get(base_source) guard_manager_enum = self.get_guard_manager_type(source, example_value) # if the base value is nn.Module, check if we can speedup the @@ -1911,7 +1923,7 @@ def NOT_PRESENT_IN_GENERIC_DICT( ) -> None: assert attr is not None ref = self.arg_ref(guard) - val = self.get(guard.name) + val = self.get(guard) base_manager = self.get_guard_manager(guard) @@ -1933,7 +1945,7 @@ def NOT_PRESENT_IN_GENERIC_DICT( def TYPE_MATCH(self, guard: Guard) -> None: # ___check_type_id is same as `id(type(x)) == y` - value = self.get(guard.name) + value = self.get(guard) if isinstance(value, torch._subclasses.FakeTensor) and value.pytype: t = value.pytype else: @@ -1956,8 +1968,8 @@ def TYPE_MATCH(self, guard: Guard) -> None: def DICT_VERSION(self, guard: Guard) -> None: # ___check_dict_version is same as `dict_version(x) == y` ref = self.arg_ref(guard) - val = self.get(guard.name) - version = dict_version(self.get(guard.name)) + val = self.get(guard) + version = dict_version(self.get(guard)) code = f"___dict_version({ref}) == {version}" self._set_guard_export_info(guard, [code]) @@ -2000,7 +2012,7 @@ def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None: def BOOL_MATCH(self, guard: Guard) -> None: # checks val == True or val == False ref = self.arg_ref(guard) - val = self.get(guard.name) + val = self.get(guard) assert istype(val, bool) code = [f"{ref} == {val!r}"] self._set_guard_export_info(guard, code) @@ -2017,7 +2029,7 @@ def BOOL_MATCH(self, guard: Guard) -> None: def NONE_MATCH(self, guard: Guard) -> None: # checks `val is None` ref = self.arg_ref(guard) - val = self.get(guard.name) + val = self.get(guard) assert val is None code = [f"{ref} is None"] self._set_guard_export_info(guard, code) @@ -2028,7 +2040,7 @@ def NONE_MATCH(self, guard: Guard) -> None: def ID_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None: # TODO - Run a CI with the following uncommented to find the remaining places - # val = self.get(guard.name) + # val = self.get(guard) # if inspect.isclass(val): # raise AssertionError(f"{guard.name} is a class, use CLASS_MATCH guard") # if inspect.ismodule(val): @@ -2046,7 +2058,7 @@ def id_match_unchecked( ) ref = self.arg_ref(guard) - val = self.get(guard.name) + val = self.get(guard) id_val = self.id_ref(val, guard.name) try: type_repr = repr(val) @@ -2074,7 +2086,7 @@ def id_match_unchecked( def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: ref = self.arg_ref(guard) - val = self.get(guard.name) + val = self.get(guard) assert isinstance(val, torch.Tensor) code = f"{ref} is not None" self._set_guard_export_info(guard, [code]) @@ -2085,7 +2097,7 @@ def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None: ref = self.arg_ref(guard) - val = self.get(guard.name) + val = self.get(guard) assert isinstance(val, torch._C.DispatchKeySet) code_parts = f"{ref}.raw_repr() == {val!r}.raw_repr()" @@ -2152,8 +2164,8 @@ def fn(x: Any) -> bool: ) def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None: - value = self.get(guard.name) - original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1]) + value = self.get(guard) + original_metadata = deepcopy(self.get(guard).__tensor_flatten__()[1]) if hasattr(value, "__metadata_guard__"): verify_guard_fn_signature(value) cls = type(value) @@ -2176,7 +2188,7 @@ def metadata_checker(x: Any) -> bool: def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None: # Copied from DTensor __metadata_guard__ # TODO - Consider moving this to C++ if stable - value = deepcopy(self.get(guard.name)) + value = deepcopy(self.get(guard)) def guard_fn(x: Any) -> bool: return x._check_equals(value, skip_shapes=True) @@ -2188,7 +2200,7 @@ def guard_fn(x: Any) -> bool: def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None: ref = self.arg_ref(guard) - val = self.get(guard.name) + val = self.get(guard) if np: np_types: tuple[type[Any], ...] = ( np.int8, @@ -2292,7 +2304,7 @@ def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> No return def CONSTANT_MATCH(self, guard: Guard) -> None: - val = self.get(guard.name) + val = self.get(guard) if istype(val, bool): self.BOOL_MATCH(guard) elif val is None: @@ -2305,7 +2317,7 @@ def CONSTANT_MATCH(self, guard: Guard) -> None: def NN_MODULE(self, guard: Guard) -> None: # don't support this in serialization because it uses unsupported ID_MATCH self.ID_MATCH(guard, "[inline-inbuilt-nn-modules-candidate]") - val = self.get(guard.name) + val = self.get(guard) if hasattr(val, "training"): assert istype(val.training, bool) if not self.guard_nn_modules: @@ -2329,7 +2341,7 @@ def FUNCTION_MATCH(self, guard: Guard) -> None: def CLASS_MATCH(self, guard: Guard) -> None: """Equals ID_MATCH on classes - better readability than directly calling ID_MATCH""" - val = self.get(guard.name) + val = self.get(guard) if not inspect.isclass(val): raise AssertionError( f"{guard.name} is not a class, but CLASS_MATCH is used" @@ -2338,7 +2350,7 @@ def CLASS_MATCH(self, guard: Guard) -> None: def MODULE_MATCH(self, guard: Guard) -> None: """Equals ID_MATCH on modules - better readability than directly calling ID_MATCH""" - val = self.get(guard.name) + val = self.get(guard) if not inspect.ismodule(val): raise AssertionError( f"{guard.name} is not a module, but MODULE_MATCH is used" @@ -2348,7 +2360,7 @@ def MODULE_MATCH(self, guard: Guard) -> None: def CLOSURE_MATCH(self, guard: Guard) -> None: """matches a closure by __code__ id.""" # don't support this in serialization because it uses unsupported FUNCTION_MATCH - val = self.get(guard.name) + val = self.get(guard) # Strictly only want user-defined functions if type(val) is types.FunctionType and hasattr(val, "__code__"): self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type] @@ -2369,7 +2381,7 @@ def SEQUENCE_LENGTH(self, guard: Guard) -> None: # This guard is used to check length of PySequence objects like list, # tuple, collections.deque etc ref = self.arg_ref(guard) - value = self.get(guard.name) + value = self.get(guard) if not isinstance(value, dict): # C++ DICT_LENGTH checks for type @@ -2393,7 +2405,7 @@ def SEQUENCE_LENGTH(self, guard: Guard) -> None: def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None: ref = self.arg_ref(guard) - value = self.get(guard.name) + value = self.get(guard) t = type(value) code = [] @@ -2409,7 +2421,7 @@ def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None: def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None: ref = self.arg_ref(guard) - value = self.get(guard.name) + value = self.get(guard) t = type(value) code = [] @@ -2477,7 +2489,7 @@ def WEAKREF_ALIVE(self, guard: Guard) -> None: def MAPPING_KEYS_CHECK(self, guard: Guard) -> None: """Guard on the key order of types.MappingProxyType object""" ref = self.arg_ref(guard) - value = self.get(guard.name) + value = self.get(guard) code = [] code.append(f"list({ref}.keys()) == {list(value.keys())}") @@ -2487,7 +2499,7 @@ def MAPPING_KEYS_CHECK(self, guard: Guard) -> None: def DICT_KEYS_MATCH(self, guard: Guard) -> None: """Insert guard to check that the keys of a dict are same""" ref = self.arg_ref(guard) - value = self.get(guard.name) + value = self.get(guard) if value is torch.utils._pytree.SUPPORTED_NODES: # For SUPPORTED_NODES, we can guard on the dictionary version (PEP509). @@ -2709,7 +2721,7 @@ def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: python_fallback = True else: example_value = self.get( - source.name, + source, closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, ) if isinstance(example_value, int): @@ -2819,7 +2831,7 @@ def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: if isinstance(value, TensorWeakRef): value = value() - value = value if value is not None else self.get(guard.name) + value = value if value is not None else self.get(guard) pytype = type(value) dispatch_keys = torch._C._dispatch_keys(value) @@ -2868,11 +2880,15 @@ def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: "dtype", "device", "requires_grad", - "ndimension()", + "ndimension", ] for term in terms: - real_value = self.get(tensor_name + "." + term) + term_src = AttrSource(guard.originating_source, term) + if term == "ndimension": + term = "ndimension()" + term_src = CallFunctionNoArgsSource(term_src) + real_value = self.get(term_src) if istype(real_value, (torch.device, torch.dtype)): # copy pasted from EQUALS_MATCH code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}") @@ -3018,7 +3034,7 @@ def _set_guard_export_info( # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD) if provided_guarded_object is None: name = guard.name - guarded_object = None if not name else self.get(name) + guarded_object = None if not name else self.get(guard) else: guarded_object = provided_guarded_object @@ -3635,7 +3651,7 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: # things like "not hasattr(x, 'foo')". In cases like this, # we don't have a well defined value because such thing # doesn't exist. - value = builder.get(guard.name) + value = builder.get(guard) has_value = True except: # noqa: B001,E722 value = MISSING @@ -3792,7 +3808,7 @@ def serialize_guards( if guard_type in ("TYPE_MATCH", "BUILTIN_MATCH"): if guard._unserializable: # Only call builder.get again if we know we're going to throw - obj = builder.get(guard.name) + obj = builder.get(guard) raise_local_type_error(obj) elif ( guard_type in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 8b42472465984..d2b91530f8e4f 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -122,6 +122,21 @@ def _get_source_debug_name(source: Optional[Source]) -> str: return "" +def _esc_str(s: Any, apply_repr: bool = False) -> str: + """ + Escapes curly brackets for format strings. + e.g. "frozenset({0})" becomes "frozenset({{0}})". + This is used by _name_template for example, because it's + expected to return a format string, but we may wish to include + strings that should not be accidentally formatted. + """ + if apply_repr: + s = repr(s) + else: + s = str(s) + return s.replace("{", "{{").replace("}", "}}") + + @dataclasses.dataclass(frozen=True) class LocalSource(Source): local_name: str @@ -148,8 +163,8 @@ def guard_source(self) -> GuardSource: return GuardSource.LOCAL @functools.cached_property - def name(self) -> str: - return f"L[{repr(self.local_name)}]" + def _name_template(self) -> str: + return f"L[{_esc_str(self.local_name, apply_repr=True)}]" @dataclasses.dataclass(frozen=True) @@ -164,7 +179,7 @@ def guard_source(self) -> GuardSource: return GuardSource.TEMP_LOCAL @property - def name(self) -> str: + def _name_template(self) -> str: raise NotImplementedError( "Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub." ) @@ -181,8 +196,8 @@ def guard_source(self) -> GuardSource: return GuardSource.SYNTHETIC_LOCAL @functools.cached_property - def name(self) -> str: - return f"SYNTHETIC_LOCAL[{self.local_name!r}]" + def _name_template(self) -> str: + return f"SYNTHETIC_LOCAL[{_esc_str(self.local_name, apply_repr=True)}]" @dataclasses.dataclass(frozen=True) @@ -198,8 +213,8 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(create_binary_subscr()) @functools.cached_property - def name(self) -> str: - return f"random_value_{self.random_call_index}" + def _name_template(self) -> str: + return f"random_value_{_esc_str(self.random_call_index)}" @dataclasses.dataclass(frozen=True) @@ -213,8 +228,8 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @functools.cached_property - def name(self) -> str: - return f"G[{repr(self.global_name)}]" + def _name_template(self) -> str: + return f"G[{_esc_str(self.global_name, apply_repr=True)}]" @dataclasses.dataclass(frozen=True) @@ -233,8 +248,8 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @functools.cached_property - def name(self) -> str: - return f"G[{repr(self.global_name)}]()" + def _name_template(self) -> str: + return f"G[{_esc_str(self.global_name, apply_repr=True)}]()" @dataclasses.dataclass(frozen=True) @@ -246,9 +261,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}()" + @property + def _name_template(self) -> str: + return "{0}()" @dataclasses.dataclass(frozen=True) @@ -277,10 +292,10 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: if not self.member.isidentifier(): - return f"getattr({self.base.name}, {self.member!r})" - return f"{self.base.name}.{self.member}" + return f"getattr({{0}}, {_esc_str(self.member, apply_repr=True)})" + return f"{{0}}.{_esc_str(self.member)}" @dataclasses.dataclass(frozen=True) @@ -304,8 +319,10 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() @functools.cached_property - def name(self) -> str: - return f"object.__getattribute__({self.base.name}, {self.member!r})" + def _name_template(self) -> str: + return ( + f"object.__getattribute__({{0}}, {_esc_str(self.member, apply_repr=True)})" + ) # Represents obj.__dict__ where obj is a type object @@ -318,13 +335,13 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: + @property + def _name_template(self) -> str: # type(ob).__dict__ can return a proxy of the dict. But in the C++ # guard accessor, we are use type->tp_dict which is a dict. So, # forcefully pass a dict object to ensure that the GuardManager # registers that its working on a dict object. - return f"dict({self.base.name}.__dict__)" + return "dict({0}.__dict__)" # Represents obj.__mro__ where object is type object @@ -337,9 +354,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.__mro__" + @property + def _name_template(self) -> str: + return "{0}.__mro__" @dataclasses.dataclass(frozen=True) @@ -371,9 +388,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.__code__" + @property + def _name_template(self) -> str: + return "{0}.__code__" # Represents obj.__closure__ where object is type object @@ -386,9 +403,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.__closure__" + @property + def _name_template(self) -> str: + return "{0}.__closure__" # Represents tensor.grad source. It could be represented by AttrSource as well. @@ -407,8 +424,8 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.{self.member}" + def _name_template(self) -> str: + return f"{{0}}.{_esc_str(self.member)}" @dataclasses.dataclass(frozen=True) @@ -440,8 +457,9 @@ def guard_source(self) -> GuardSource: return GuardSource.EPHEMERAL @functools.cached_property - def name(self) -> str: - return f"" + def _name_template(self) -> str: + desc = ": " + self.desc if self.desc is not None else "" + return f"" def make_guard(self, fn: Callable[..., Any]) -> Guard: raise NotImplementedError @@ -459,8 +477,8 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() @property - def name(self) -> str: - return self.base.name + def _name_template(self) -> str: + return "{0}" class TensorProperty(enum.Enum): @@ -476,7 +494,7 @@ def method_name(self) -> str: elif self is TensorProperty.STORAGE_OFFSET: return "storage_offset" else: - raise AssertionError(f"unhandled {self}") + raise AssertionError(f"unhandled {_esc_str(self)}") @dataclasses.dataclass(frozen=True) @@ -494,7 +512,7 @@ def __post_init__(self) -> None: def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from( - utils.__name__, f"call_{self.prop.method_name()}" + utils.__name__, f"call_{_esc_str(self.prop.method_name())}" ) ) codegen(self.base) @@ -509,16 +527,16 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: if self.prop is TensorProperty.SIZE: - return f"{self.base.name}.size()[{self.idx}]" + return f"{{0}}.size()[{_esc_str(self.idx)}]" elif self.prop is TensorProperty.STRIDE: - return f"{self.base.name}.stride()[{self.idx}]" + return f"{{0}}.stride()[{_esc_str(self.idx)}]" elif self.prop is TensorProperty.STORAGE_OFFSET: assert self.idx is None - return f"{self.base.name}.storage_offset()" + return "{0}.storage_offset()" else: - raise AssertionError(f"unhandled {self.prop}") + raise AssertionError(f"unhandled {_esc_str(self.prop)}") @dataclasses.dataclass(frozen=True) @@ -535,8 +553,8 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() @functools.cached_property - def name(self) -> str: - return f"({self.idx}, {self.base.name})" + def _name_template(self) -> str: + return f"({_esc_str(self.idx)}, {{0}})" @dataclasses.dataclass(frozen=True) @@ -550,10 +568,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: + @property + def _name_template(self) -> str: # NB: use method call so that function stripping regexes work - return f"{self.base.name}.__neg__()" + return "{0}.__neg__()" @dataclasses.dataclass(frozen=True) @@ -567,9 +585,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"cast_symbool_to_symint_guardless({self.base.name})" + @property + def _name_template(self) -> str: + return "cast_symbool_to_symint_guardless({0})" @dataclasses.dataclass(frozen=True) @@ -591,9 +609,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"int({self.base.name})" + @property + def _name_template(self) -> str: + return "int({0})" @dataclasses.dataclass(frozen=True) @@ -607,9 +625,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.__obj_flatten__()" + @property + def _name_template(self) -> str: + return "{0}.__obj_flatten__()" @dataclasses.dataclass(frozen=True) @@ -623,9 +641,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}._type().qualified_name()" + @property + def _name_template(self) -> str: + return "{0}._type().qualified_name()" class AttrProxySource(ChainedSource): @@ -635,9 +653,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.get_base()" + @property + def _name_template(self) -> str: + return "{0}.get_base()" @dataclasses.dataclass(frozen=True) @@ -655,13 +673,15 @@ def __post_init__(self) -> None: assert isinstance(self.idx_key, str) object.__setattr__(self, "field", "__kwdefaults__") object.__setattr__( - self, "_name", f"{self.base.name}.{self.field}['{self.idx_key}']" + self, + "_name", + f"{{0}}.{_esc_str(self.field)}['{_esc_str(self.idx_key)}']", ) else: assert isinstance(self.idx_key, int) object.__setattr__(self, "field", "__defaults__") object.__setattr__( - self, "_name", f"{self.base.name}.{self.field}[{self.idx_key}]" + self, "_name", f"{{0}}.{_esc_str(self.field)}[{_esc_str(self.idx_key)}]" ) def reconstruct(self, codegen: "PyCodegen") -> None: @@ -674,7 +694,7 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: return self._name @@ -707,15 +727,15 @@ def unpack_slice(self) -> slice: return slice_class(*slice_args) @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: # Index can be of following types # 1) index is a slice - example 1:4 # 2) index is a constant - example string, integer assert not isinstance(self.index, Source) if self.index_is_slice: - return f"{self.base.name}[{self.unpack_slice()!r}]" + return f"{{0}}[{_esc_str(self.unpack_slice(), apply_repr=True)}]" else: - return f"{self.base.name}[{self.index!r}]" + return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]" @dataclasses.dataclass(frozen=True) @@ -734,9 +754,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: # The list creation will be CSE'd by PyExprCSEPass - return f"list(dict.keys({self.base.name}))[{self.index!r}]" + return f"list(dict.keys({{0}}))[{_esc_str(self.index, apply_repr=True)}]" def is_dict_key(self) -> bool: return True @@ -763,9 +783,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: # set ordering might not be stable - return f"list({self.base.name})[{self.index!r}]" + return f"list({{0}})[{_esc_str(self.index, apply_repr=True)}]" def is_dict_key(self) -> bool: return False @@ -801,11 +821,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(create_binary_subscr()) @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: if isinstance(self.index, ConstDictKeySource): - return f"{self.base.name}[{self.index.name}]" + return f"{{0}}[{_esc_str(self.index.name)}]" else: - return f"{self.base.name}[{self.index!r}]" + return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]" # Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that @@ -847,11 +867,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: if isinstance(self.index, ConstDictKeySource): - return f"dict.__getitem__({self.base.name}, {self.index.name})" + return f"dict.__getitem__({{0}}, {_esc_str(self.index.name)})" else: - return f"{self.base.name}[{self.index!r}]" + return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]" @dataclasses.dataclass(frozen=True) @@ -883,7 +903,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: # Index can be of following types # 1) index is a slice - example 1:4 # 2) index is a constant - example string, integer @@ -893,7 +913,7 @@ def name(self) -> str: "List[slice] is a temporary object and should not have a source" ) else: - return f"list.__getitem__({self.base.name}, {self.index!r})" + return f"list.__getitem__({{0}}, {_esc_str(self.index, apply_repr=True)})" @dataclasses.dataclass(frozen=True) @@ -907,8 +927,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) @functools.cached_property - def name(self) -> str: - return f"___tuple_iterator_getitem({self.base.name}, {self.index!r})" + def _name_template(self) -> str: + return ( + f"___tuple_iterator_getitem({{0}}, {_esc_str(self.index, apply_repr=True)})" + ) @dataclasses.dataclass(frozen=True) @@ -920,9 +942,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"___namedtuple_fields({self.base.name})" + @property + def _name_template(self) -> str: + return "___namedtuple_fields({0})" @dataclasses.dataclass(frozen=True) @@ -937,9 +959,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"___dataclass_fields({self.base.name})" + @property + def _name_template(self) -> str: + return "___dataclass_fields({0})" @dataclasses.dataclass(frozen=True) @@ -955,9 +977,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return f"type({self.base.name})" + @property + def _name_template(self) -> str: + return "type({0})" @dataclasses.dataclass(frozen=True) @@ -968,9 +990,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return self.base.guard_source() - @functools.cached_property - def name(self) -> str: - return self.base.name + @property + def _name_template(self) -> str: + return "{0}" @dataclasses.dataclass(frozen=True) @@ -981,9 +1003,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def guard_source(self) -> GuardSource: return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] - @functools.cached_property - def name(self) -> str: - return self.base.name + @property + def _name_template(self) -> str: + return "{0}" @dataclasses.dataclass(frozen=True) @@ -1007,7 +1029,7 @@ def guard_source(self) -> GuardSource: @dataclasses.dataclass(frozen=True) class GlobalStateSource(Source): @property - def name(self) -> str: + def _name_template(self) -> str: return "" def guard_source(self) -> GuardSource: @@ -1026,7 +1048,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: install_guard(self.make_guard(GuardBuilder.ID_MATCH)) @property - def name(self) -> str: + def _name_template(self) -> str: return "__import__('torch')" def reconstruct(self, codegen: "PyCodegen") -> None: @@ -1054,7 +1076,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: install_guard(self.make_guard(GuardBuilder.ID_MATCH)) @property - def name(self) -> str: + def _name_template(self) -> str: return "__import__('collections')" def reconstruct(self, codegen: "PyCodegen") -> None: @@ -1075,8 +1097,8 @@ class TorchFunctionModeStackSource(Source): ind: int @functools.cached_property - def name(self) -> str: - return f"___get_torch_function_mode_stack_at({self._get_index()})" + def _name_template(self) -> str: + return f"___get_torch_function_mode_stack_at({_esc_str(self._get_index())})" def _get_index(self) -> int: from .variables.torch_function import TorchFunctionModeStackVariable @@ -1107,7 +1129,7 @@ def guard_source(self) -> GuardSource: return GuardSource.CONSTANT @functools.cached_property - def name(self) -> str: + def _name_template(self) -> str: return self.source_name def make_guard(self, fn: Any) -> Any: @@ -1116,9 +1138,9 @@ def make_guard(self, fn: Any) -> Any: @dataclasses.dataclass(frozen=True) class NumpyTensorSource(ChainedSource): - @functools.cached_property - def name(self) -> str: - return f"___from_numpy({self.base.name})" + @property + def _name_template(self) -> str: + return "___from_numpy({0})" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1131,9 +1153,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: @dataclasses.dataclass(frozen=True) class SubclassAttrListSource(ChainedSource): - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.__tensor_flatten__()[0]" + @property + def _name_template(self) -> str: + return "{0}.__tensor_flatten__()[0]" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1143,9 +1165,9 @@ def guard_source(self) -> GuardSource: # source, it is ephemeral @dataclasses.dataclass(frozen=True) class FloatTensorSource(ChainedSource): - @functools.cached_property - def name(self) -> str: - return f"___as_tensor({self.base.name})" + @property + def _name_template(self) -> str: + return "___as_tensor({0})" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1153,9 +1175,9 @@ def guard_source(self) -> GuardSource: @dataclasses.dataclass(frozen=True) class CallMethodItemSource(ChainedSource): - @functools.cached_property - def name(self) -> str: - return f"{self.base.name}.item()" + @property + def _name_template(self) -> str: + return "{0}.item()" def guard_source(self) -> GuardSource: return self.base.guard_source() @@ -1167,7 +1189,7 @@ def guard_source(self) -> GuardSource: @dataclasses.dataclass(frozen=True) class ShapeEnvSource(Source): @property - def name(self) -> str: + def _name_template(self) -> str: return "" def guard_source(self) -> GuardSource: @@ -1179,8 +1201,8 @@ class CurrentStreamSource(Source): device: device_type @functools.cached_property - def name(self) -> str: - return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))" + def _name_template(self) -> str: + return f"___get_current_stream(torch.device('{_esc_str(self.device.type)}', {_esc_str(self.device.index)}))" def reconstruct(self, codegen: "PyCodegen") -> None: num_args = 1 @@ -1202,7 +1224,7 @@ def guard_source(self) -> GuardSource: @dataclasses.dataclass(frozen=True) class BackwardStateSource(Source): @property - def name(self) -> str: + def _name_template(self) -> str: return "" def guard_source(self) -> GuardSource: diff --git a/torch/_guards.py b/torch/_guards.py index 8da885bbb683c..8e3dfcc482e02 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -1092,9 +1092,35 @@ def reconstruct(self, codegen: PyCodegen) -> None: def guard_source(self) -> GuardSource: raise NotImplementedError + @property + def _name_template(self) -> str: + """ + A template for the name of the source. Used to prevent code duplication between + `name` and `get_value`. + + For non-ChainedSources, `name` and `get_value` use the returned string directly. + + For ChainedSources, `name` and `get_value` expect the return to be a format string + with `{0}` present - `name` and `get_value` will apply different values to this function's + returned format string. + """ + raise NotImplementedError + @functools.cached_property def name(self) -> str: - raise NotImplementedError + return self._name_template + + def get_value( + self, + globals: dict[str, Any], + locals: dict[str, Any], + cache: weakref.WeakKeyDictionary[Source, Any], + ) -> Any: + if self in cache: + return cache[self] + value = eval(self._name_template, globals, locals) + cache[self] = value + return value def make_guard(self, fn: Callable[..., Any]) -> Guard: if self.guard_source() is GuardSource.CONSTANT: @@ -1127,6 +1153,29 @@ def get_base(self) -> Source: current = current.base return current + @functools.cached_property + def name(self) -> str: + return self._name_template.format(self.base.name) + + def get_value( + self, + globals: dict[str, Any], + locals: dict[str, Any], + cache: weakref.WeakKeyDictionary[Source, Any], + ) -> Any: + if self in cache: + return cache[self] + tmpvar = "tmp" + counter = 0 + while tmpvar in locals: + tmpvar = f"tmp{counter}" + counter += 1 + locals[tmpvar] = self.base.get_value(globals, locals, cache) + value = eval(self._name_template.format(tmpvar), globals, locals) + del locals[tmpvar] + cache[self] = value + return value + def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: """ From ea7035f462a0d2830865ee86c832bd101e1427fc Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 3 Dec 2025 13:22:48 -0800 Subject: [PATCH 244/338] [dynamo, guards] cache Source.guard_source (#168386) Reduces guard build time for a bit. Noticeable when guards get really large (e.g. DEPTH = 2000 in the repro). On the guard build benchmark, time went from 57.25s -> 50.18s Pull Request resolved: https://github.com/pytorch/pytorch/pull/168386 Approved by: https://github.com/anijain2305 ghstack dependencies: #168131, #168203 --- torch/_dynamo/guards.py | 2 +- torch/_dynamo/source.py | 125 ++++++----------------------- torch/_dynamo/utils.py | 12 +-- torch/_dynamo/variables/builder.py | 18 ++--- torch/_guards.py | 13 ++- 5 files changed, 51 insertions(+), 119 deletions(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index be327f7778723..ea720d5c49f5f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -3932,7 +3932,7 @@ def build_guards( w_builder = None def source_ref(source: Source) -> str: - guard_source = source.guard_source() + guard_source = source.guard_source if guard_source is GuardSource.CONSTANT: # No need to track constants return source.name diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index d2b91530f8e4f..776aef718e935 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -104,7 +104,7 @@ def is_constant_source(source: Source) -> bool: if isinstance(source, ConstantSource): return True try: - if source.guard_source() == GuardSource.CONSTANT: + if source.guard_source == GuardSource.CONSTANT: return True except NotImplementedError: pass @@ -159,6 +159,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: else: codegen.append_output(codegen.create_load(self.local_name)) + @property def guard_source(self) -> GuardSource: return GuardSource.LOCAL @@ -175,6 +176,7 @@ class TempLocalSource(Source): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load(self.local_name)) + @property def guard_source(self) -> GuardSource: return GuardSource.TEMP_LOCAL @@ -192,6 +194,7 @@ class SyntheticLocalSource(Source): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load(self.local_name)) + @property def guard_source(self) -> GuardSource: return GuardSource.SYNTHETIC_LOCAL @@ -204,6 +207,7 @@ def _name_template(self) -> str: class RandomValueSource(Source): random_call_index: int + @property def guard_source(self) -> GuardSource: return GuardSource.RANDOM_VALUE @@ -224,6 +228,7 @@ class GlobalSource(Source): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_global(self.global_name, add=True)) + @property def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @@ -244,6 +249,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: ) codegen.extend_output(create_call_function(0, False)) + @property def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @@ -258,9 +264,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen(self.base)) codegen.extend_output(create_call_function(0, False)) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}()" @@ -288,9 +291,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @functools.cached_property def _name_template(self) -> str: if not self.member.isidentifier(): @@ -315,9 +315,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @functools.cached_property def _name_template(self) -> str: return ( @@ -332,9 +329,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs("__dict__")) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: # type(ob).__dict__ can return a proxy of the dict. But in the C++ @@ -351,9 +345,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs("__mro__")) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}.__mro__" @@ -385,9 +376,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs("__code__")) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}.__code__" @@ -400,9 +388,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs("__closure__")) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}.__closure__" @@ -420,9 +405,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @functools.cached_property def _name_template(self) -> str: return f"{{0}}.{_esc_str(self.member)}" @@ -430,8 +412,9 @@ def _name_template(self) -> str: @dataclasses.dataclass(frozen=True) class ParamBufferSource(AttrSource): + @functools.cached_property def guard_source(self) -> GuardSource: - return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source] # Special AttrSource to differentiate module._buffers or module._parameters @@ -453,6 +436,7 @@ class UnspecializedParamBufferSource(AttrSource): class EphemeralSource(Source): desc: Optional[str] = None + @property def guard_source(self) -> GuardSource: return GuardSource.EPHEMERAL @@ -473,9 +457,6 @@ class SkipGuardSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: self.base.reconstruct(codegen) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}" @@ -523,9 +504,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: create_call_function(2 if self.idx is not None else 1, False) ) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @functools.cached_property def _name_template(self) -> str: if self.prop is TensorProperty.SIZE: @@ -549,9 +527,6 @@ def __post_init__(self) -> None: def reconstruct(self, codegen: "PyCodegen") -> None: raise NotImplementedError - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @functools.cached_property def _name_template(self) -> str: return f"({_esc_str(self.idx)}, {{0}})" @@ -565,9 +540,6 @@ def __post_init__(self) -> None: def reconstruct(self, codegen: "PyCodegen") -> None: raise NotImplementedError - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: # NB: use method call so that function stripping regexes work @@ -582,9 +554,6 @@ def __post_init__(self) -> None: def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "cast_symbool_to_symint_guardless({0})" @@ -606,9 +575,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(create_call_function(1, False)) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "int({0})" @@ -622,9 +588,6 @@ def __post_init__(self) -> None: def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}.__obj_flatten__()" @@ -638,9 +601,6 @@ def __post_init__(self) -> None: def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}._type().qualified_name()" @@ -650,9 +610,6 @@ class AttrProxySource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}.get_base()" @@ -690,9 +647,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.idx_key)) codegen.append_output(create_binary_subscr()) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @functools.cached_property def _name_template(self) -> str: return self._name @@ -718,9 +672,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(create_binary_subscr()) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - def unpack_slice(self) -> slice: assert self.index_is_slice slice_class, slice_args = self.index @@ -742,9 +693,6 @@ def _name_template(self) -> str: class ConstDictKeySource(ChainedSource): index: Any - def guard_source(self) -> GuardSource: - return self.base.guard_source() - def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") @@ -771,9 +719,6 @@ def __post_init__(self) -> None: assert ConstantVariable.is_literal(self.index) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "set_getitem") @@ -806,9 +751,6 @@ def __post_init__(self) -> None: self.index, ConstDictKeySource ) or ConstantVariable.is_literal(self.index) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - def reconstruct(self, codegen: "PyCodegen") -> None: # Load dict codegen(self.base) @@ -844,9 +786,6 @@ def __post_init__(self) -> None: self.index, ConstDictKeySource ) or ConstantVariable.is_literal(self.index) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - def reconstruct(self, codegen: "PyCodegen") -> None: # reconstruct dict.__getitem__(dct, key) @@ -939,9 +878,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs("_fields")) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "___namedtuple_fields({0})" @@ -956,9 +892,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(create_call_function(1, False)) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "___dataclass_fields({0})" @@ -974,9 +907,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(create_call_function(1, False)) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "type({0})" @@ -987,9 +917,6 @@ class OptimizerSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @property def _name_template(self) -> str: return "{0}" @@ -1000,8 +927,9 @@ class NNModuleSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) + @functools.cached_property def guard_source(self) -> GuardSource: - return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source] @property def _name_template(self) -> str: @@ -1010,20 +938,23 @@ def _name_template(self) -> str: @dataclasses.dataclass(frozen=True) class UnspecializedNNModuleSource(NNModuleSource): + @functools.cached_property def guard_source(self) -> GuardSource: - return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()] + return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source] @dataclasses.dataclass(frozen=True) class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource): + @functools.cached_property def guard_source(self) -> GuardSource: - return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()] + return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source] @dataclasses.dataclass(frozen=True) class FSDPNNModuleSource(NNModuleSource): + @functools.cached_property def guard_source(self) -> GuardSource: - return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()] + return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source] @dataclasses.dataclass(frozen=True) @@ -1032,6 +963,7 @@ class GlobalStateSource(Source): def _name_template(self) -> str: return "" + @property def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @@ -1060,6 +992,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: ] ) + @property def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @@ -1088,6 +1021,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: ] ) + @property def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @@ -1114,6 +1048,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output([codegen.create_load_const(self._get_index())]) codegen.extend_output(create_call_function(1, False)) + @property def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @@ -1125,6 +1060,7 @@ class ConstantSource(Source): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_global(self.source_name, add=False)) + @property def guard_source(self) -> GuardSource: return GuardSource.CONSTANT @@ -1142,9 +1078,6 @@ class NumpyTensorSource(ChainedSource): def _name_template(self) -> str: return "___from_numpy({0})" - def guard_source(self) -> GuardSource: - return self.base.guard_source() - def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) codegen(self.base) @@ -1157,9 +1090,6 @@ class SubclassAttrListSource(ChainedSource): def _name_template(self) -> str: return "{0}.__tensor_flatten__()[0]" - def guard_source(self) -> GuardSource: - return self.base.guard_source() - # NB: We don't expect you to actually ever generate guards against this # source, it is ephemeral @@ -1169,9 +1099,6 @@ class FloatTensorSource(ChainedSource): def _name_template(self) -> str: return "___as_tensor({0})" - def guard_source(self) -> GuardSource: - return self.base.guard_source() - @dataclasses.dataclass(frozen=True) class CallMethodItemSource(ChainedSource): @@ -1179,9 +1106,6 @@ class CallMethodItemSource(ChainedSource): def _name_template(self) -> str: return "{0}.item()" - def guard_source(self) -> GuardSource: - return self.base.guard_source() - # This is a synthetic source that is associated with the singleton # shape env guard we always register for all frames. We get the actual @@ -1192,6 +1116,7 @@ class ShapeEnvSource(Source): def _name_template(self) -> str: return "" + @property def guard_source(self) -> GuardSource: return GuardSource.SHAPE_ENV @@ -1217,6 +1142,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(num_args, False)) codegen.extend_output(create_call_function(1, False)) + @property def guard_source(self) -> GuardSource: return GuardSource.GLOBAL @@ -1227,6 +1153,7 @@ class BackwardStateSource(Source): def _name_template(self) -> str: return "" + @property def guard_source(self) -> GuardSource: return GuardSource.BACKWARD_STATE diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index d3c351e0de01a..afdd0c7aefa4d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2548,20 +2548,20 @@ def is_int_specialization_case(value: Any, source: Any) -> bool: return not TracingContext.get().force_unspec_int_unbacked_size_like and ( # Assume integers from global variables want to be specialized - not source.guard_source().is_local() + not source.guard_source.is_local() # Assume that integers that came from NN modules want to be # specialized (as we don't expect users to be changing the # NN modules on the fly), unless explicitly disabled or ( - source.guard_source().is_specialized_nn_module() + source.guard_source.is_specialized_nn_module() and not config.allow_unspec_int_on_nn_module ) or ( - source.guard_source().is_unspecialized_builtin_nn_module() + source.guard_source.is_unspecialized_builtin_nn_module() and not config.allow_unspec_int_on_nn_module ) or ( - source.guard_source().is_unspecialized_nn_module() + source.guard_source.is_unspecialized_nn_module() and not config.allow_unspec_int_on_nn_module ) or is_from_defaults(source) @@ -3856,8 +3856,8 @@ def tensor_always_has_static_shape( from .source import is_from_unspecialized_param_buffer_source if ( - tensor_source.guard_source().is_specialized_nn_module() - or tensor_source.guard_source().is_unspecialized_builtin_nn_module() + tensor_source.guard_source.is_specialized_nn_module() + or tensor_source.guard_source.is_unspecialized_builtin_nn_module() ) and config.force_nn_module_property_static_shapes: return True, TensorStaticReason.NN_MODULE_PROPERTY diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 0be956a4cac67..248ab9d5f4bab 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1699,7 +1699,7 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): if ( istype(value, tuple) and all(ConstantVariable.is_literal(item) for item in value) - and self.source.guard_source().is_unspecialized_nn_module() + and self.source.guard_source.is_unspecialized_nn_module() ): self.install_guards(GuardBuilder.CONSTANT_MATCH) return TupleVariable([ConstantVariable.create(item) for item in value]) @@ -2017,8 +2017,8 @@ def wrap_literal(self, value): if is_int_specialization_case(value, self.source): recompile_hint = None if ( - self.source.guard_source().is_unspecialized_builtin_nn_module() - or self.source.guard_source().is_unspecialized_nn_module() + self.source.guard_source.is_unspecialized_builtin_nn_module() + or self.source.guard_source.is_unspecialized_nn_module() ): # This means that it is an integer from a NN module. # Dynamo considers nn module int attributes to be static @@ -2036,7 +2036,7 @@ def wrap_literal(self, value): self.tx, self.source.name, FrameStateSizeEntry.make_scalar(value), - is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), ) self.install_guards( functools.partial( @@ -2078,7 +2078,7 @@ def wrap_tensor(self, value: torch.Tensor): isinstance(value, torch.nn.Parameter) # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules behavior # compatible with previous behavior. - or (source and source.guard_source().is_unspecialized_nn_module()) + or (source and source.guard_source.is_unspecialized_nn_module()) ) ): self.mark_static_input(value, guard=is_parameter_freezing()) @@ -2101,8 +2101,8 @@ def wrap_tensor(self, value: torch.Tensor): ) if should_install_free_tensor or ( - (source.guard_source().is_specialized_nn_module() or make_graph_attribute) - and not source.guard_source().is_fsdp_module() + (source.guard_source.is_specialized_nn_module() or make_graph_attribute) + and not source.guard_source.is_fsdp_module() ): self.assert_not_wrapped_by_this_graph(value) return self.tx.output.register_attr_or_module( @@ -2446,7 +2446,7 @@ def wrap_symint( self.tx, name, FrameStateSizeEntry.make_scalar(value), - is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), ) # TODO: This should be dynamic, as we in general do not @@ -2541,7 +2541,7 @@ def wrap_symfloat(self, value): self.tx, self.source.name, FrameStateSizeEntry.make_scalar(value), - is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), ) # NB: we specialize on nan input, because our guard modeling in diff --git a/torch/_guards.py b/torch/_guards.py index 8e3dfcc482e02..af13099773f8b 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -301,7 +301,7 @@ def name(self) -> str: @property def source(self) -> GuardSource: - return self.originating_source.guard_source() + return self.originating_source.guard_source @staticmethod def weakref_to_str(obj_weakref: object) -> str: @@ -1089,6 +1089,7 @@ def is_ephemeral(self) -> bool: def reconstruct(self, codegen: PyCodegen) -> None: raise NotImplementedError + @functools.cached_property def guard_source(self) -> GuardSource: raise NotImplementedError @@ -1123,16 +1124,16 @@ def get_value( return value def make_guard(self, fn: Callable[..., Any]) -> Guard: - if self.guard_source() is GuardSource.CONSTANT: + if self.guard_source is GuardSource.CONSTANT: raise NotImplementedError return Guard(self, fn) def is_specialized_nn_module(self) -> bool: - return self.guard_source().is_specialized_nn_module() + return self.guard_source.is_specialized_nn_module() def subguards_allowed(self) -> bool: """True if you can guard on attributes of this""" - return self.guard_source() != GuardSource.SYNTHETIC_LOCAL + return self.guard_source != GuardSource.SYNTHETIC_LOCAL # Subclasses can be found in torch/_dynamo/source.py @@ -1147,6 +1148,10 @@ def is_dict_key(self) -> bool: def is_ephemeral(self) -> bool: return self.base.is_ephemeral() + @functools.cached_property + def guard_source(self) -> GuardSource: + return self.base.guard_source + def get_base(self) -> Source: current: Source = self while isinstance(current, ChainedSource): From c55b1e8f61d041ee436d697449eb028931d574fb Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 3 Dec 2025 13:22:48 -0800 Subject: [PATCH 245/338] [dynamo, guards] cache Source hashing (#168886) Final fix for https://github.com/pytorch/pytorch/issues/168118. Decreases guard build time from 9s -> 0.5s on a local tlparse. On the guard build benchmark, time went from 50.18s -> 8.15s Pull Request resolved: https://github.com/pytorch/pytorch/pull/168886 Approved by: https://github.com/anijain2305 ghstack dependencies: #168131, #168203, #168386 --- .../pr_time_benchmarks/expected_results.csv | 2 +- torch/_dynamo/source.py | 118 +++++++++--------- torch/_guards.py | 60 ++++++++- 3 files changed, 120 insertions(+), 60 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 58dc3f82c0a4c..5c7a29bea8e37 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -82,7 +82,7 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1 -basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1 +basic_NestedModule_eager,compile_time_instruction_count,6140000000,0.1 diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 776aef718e935..dd3386f765cfe 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -24,7 +24,13 @@ from typing import Any, Optional, TYPE_CHECKING, Union from torch import device as device_type -from torch._guards import ChainedSource, Guard, GuardSource, Source +from torch._guards import ( + ChainedSource, + dataclass_with_cached_hash, + Guard, + GuardSource, + Source, +) from . import utils from .bytecode_transformation import ( @@ -137,7 +143,7 @@ def _esc_str(s: Any, apply_repr: bool = False) -> str: return s.replace("{", "{{").replace("}", "}}") -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class LocalSource(Source): local_name: str @@ -168,7 +174,7 @@ def _name_template(self) -> str: return f"L[{_esc_str(self.local_name, apply_repr=True)}]" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TempLocalSource(Source): # like LocalSource, but cannot be guarded on local_name: str @@ -187,7 +193,7 @@ def _name_template(self) -> str: ) -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class SyntheticLocalSource(Source): local_name: str @@ -203,7 +209,7 @@ def _name_template(self) -> str: return f"SYNTHETIC_LOCAL[{_esc_str(self.local_name, apply_repr=True)}]" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class RandomValueSource(Source): random_call_index: int @@ -221,7 +227,7 @@ def _name_template(self) -> str: return f"random_value_{_esc_str(self.random_call_index)}" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class GlobalSource(Source): global_name: str @@ -237,7 +243,7 @@ def _name_template(self) -> str: return f"G[{_esc_str(self.global_name, apply_repr=True)}]" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class GlobalWeakRefSource(Source): global_name: str @@ -258,7 +264,7 @@ def _name_template(self) -> str: return f"G[{_esc_str(self.global_name, apply_repr=True)}]()" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class WeakRefCallSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen(self.base)) @@ -269,12 +275,12 @@ def _name_template(self) -> str: return "{0}()" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class CallFunctionNoArgsSource(WeakRefCallSource): pass -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class AttrSource(ChainedSource): member: str @@ -298,7 +304,7 @@ def _name_template(self) -> str: return f"{{0}}.{_esc_str(self.member)}" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class GenericAttrSource(ChainedSource): member: str @@ -323,7 +329,7 @@ def _name_template(self) -> str: # Represents obj.__dict__ where obj is a type object -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TypeDictSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -339,7 +345,7 @@ def _name_template(self) -> str: # Represents obj.__mro__ where object is type object -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TypeMROSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -350,7 +356,7 @@ def _name_template(self) -> str: return "{0}.__mro__" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class LocalCellSource(Source): """ Conceptually, this class is `LocalSource` for cell objects implicitly @@ -370,7 +376,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # Represents obj.__code__ where object is type object -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class CodeSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -382,7 +388,7 @@ def _name_template(self) -> str: # Represents obj.__closure__ where object is type object -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ClosureSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -397,7 +403,7 @@ def _name_template(self) -> str: # But, we could access grad field on tensor directly in C++ without going # through the Python bytecodes. Therefore, we use a separate source for grad # field. -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class GradSource(ChainedSource): member: str = "grad" @@ -410,7 +416,7 @@ def _name_template(self) -> str: return f"{{0}}.{_esc_str(self.member)}" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ParamBufferSource(AttrSource): @functools.cached_property def guard_source(self) -> GuardSource: @@ -418,7 +424,7 @@ def guard_source(self) -> GuardSource: # Special AttrSource to differentiate module._buffers or module._parameters -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class UnspecializedParamBufferSource(AttrSource): pass @@ -432,7 +438,7 @@ class UnspecializedParamBufferSource(AttrSource): # symbolicized / fake-ified to avoid invalid specialization during view replay. This source # is useful for symbols utilized in the middle of the view chain that are not expected to be # present within the final view shape metadata. -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class EphemeralSource(Source): desc: Optional[str] = None @@ -452,7 +458,7 @@ def is_ephemeral(self) -> bool: return True -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class SkipGuardSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: self.base.reconstruct(codegen) @@ -478,7 +484,7 @@ def method_name(self) -> str: raise AssertionError(f"unhandled {_esc_str(self)}") -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TensorPropertySource(ChainedSource): prop: TensorProperty idx: Optional[int] = None # None for STORAGE_OFFSET @@ -517,7 +523,7 @@ def _name_template(self) -> str: raise AssertionError(f"unhandled {_esc_str(self.prop)}") -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class IndexedSource(ChainedSource): idx: int @@ -532,7 +538,7 @@ def _name_template(self) -> str: return f"({_esc_str(self.idx)}, {{0}})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class NegateSource(ChainedSource): def __post_init__(self) -> None: assert self.base is not None @@ -546,7 +552,7 @@ def _name_template(self) -> str: return "{0}.__neg__()" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ConvertIntSource(ChainedSource): def __post_init__(self) -> None: assert self.base is not None @@ -559,7 +565,7 @@ def _name_template(self) -> str: return "cast_symbool_to_symint_guardless({0})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class DynamicScalarSource(ChainedSource): is_int: bool @@ -580,7 +586,7 @@ def _name_template(self) -> str: return "int({0})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class FlattenScriptObjectSource(ChainedSource): def __post_init__(self) -> None: assert self.base is not None @@ -593,7 +599,7 @@ def _name_template(self) -> str: return "{0}.__obj_flatten__()" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ScriptObjectQualifiedNameSource(ChainedSource): def __post_init__(self) -> None: assert self.base is not None @@ -615,7 +621,7 @@ def _name_template(self) -> str: return "{0}.get_base()" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class DefaultsSource(ChainedSource): idx_key: Union[int, str] is_kw: bool = False @@ -652,7 +658,7 @@ def _name_template(self) -> str: return self._name -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class GetItemSource(ChainedSource): index: Any index_is_slice: bool = False @@ -689,7 +695,7 @@ def _name_template(self) -> str: return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ConstDictKeySource(ChainedSource): index: Any @@ -710,7 +716,7 @@ def is_dict_key(self) -> bool: return True -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class NonSerializableSetGetItemSource(ChainedSource): index: int @@ -737,7 +743,7 @@ def is_dict_key(self) -> bool: # Used to access an item from the dictionary -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class DictGetItemSource(ChainedSource): # Key to access in the dictionary. It can be one of the following types # 1) ConstDictKeySource @@ -772,7 +778,7 @@ def _name_template(self) -> str: # Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that # torch.compile does not run the overridden __getitem__ method -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class DictSubclassGetItemSource(ChainedSource): # Key to access in the dictionary. It can be one of the following types # 1) ConstDictKeySource @@ -813,7 +819,7 @@ def _name_template(self) -> str: return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ListGetItemSource(GetItemSource): """ Same as GetItemSource with reconstruct and name overridden to be list specific. @@ -855,7 +861,7 @@ def _name_template(self) -> str: return f"list.__getitem__({{0}}, {_esc_str(self.index, apply_repr=True)})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TupleIteratorGetItemSource(GetItemSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( @@ -872,7 +878,7 @@ def _name_template(self) -> str: ) -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class NamedTupleFieldsSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -883,7 +889,7 @@ def _name_template(self) -> str: return "___namedtuple_fields({0})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class DataclassFieldsSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( @@ -897,7 +903,7 @@ def _name_template(self) -> str: return "___dataclass_fields({0})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TypeSource(ChainedSource): def __post_init__(self) -> None: assert self.base is not None @@ -912,7 +918,7 @@ def _name_template(self) -> str: return "type({0})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class OptimizerSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -922,7 +928,7 @@ def _name_template(self) -> str: return "{0}" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class NNModuleSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -936,28 +942,28 @@ def _name_template(self) -> str: return "{0}" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class UnspecializedNNModuleSource(NNModuleSource): @functools.cached_property def guard_source(self) -> GuardSource: return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source] -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource): @functools.cached_property def guard_source(self) -> GuardSource: return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source] -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class FSDPNNModuleSource(NNModuleSource): @functools.cached_property def guard_source(self) -> GuardSource: return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source] -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class GlobalStateSource(Source): @property def _name_template(self) -> str: @@ -968,7 +974,7 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TorchSource(Source): """Points to the actual `torch` module - used instead of GlobalSource in case the user has overridden `torch` in their local namespace""" @@ -997,7 +1003,7 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class CollectionsSource(Source): """Points to the actual `collections` module - used instead of GlobalSource in case the user has overridden `collections` in their local namespace""" @@ -1026,7 +1032,7 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class TorchFunctionModeStackSource(Source): ind: int @@ -1053,7 +1059,7 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ConstantSource(Source): source_name: str @@ -1072,7 +1078,7 @@ def make_guard(self, fn: Any) -> Any: raise NotImplementedError -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class NumpyTensorSource(ChainedSource): @property def _name_template(self) -> str: @@ -1084,7 +1090,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(1, False)) -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class SubclassAttrListSource(ChainedSource): @property def _name_template(self) -> str: @@ -1093,14 +1099,14 @@ def _name_template(self) -> str: # NB: We don't expect you to actually ever generate guards against this # source, it is ephemeral -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class FloatTensorSource(ChainedSource): @property def _name_template(self) -> str: return "___as_tensor({0})" -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class CallMethodItemSource(ChainedSource): @property def _name_template(self) -> str: @@ -1110,7 +1116,7 @@ def _name_template(self) -> str: # This is a synthetic source that is associated with the singleton # shape env guard we always register for all frames. We get the actual # guard contents from the ambient ShapeEnv -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ShapeEnvSource(Source): @property def _name_template(self) -> str: @@ -1121,7 +1127,7 @@ def guard_source(self) -> GuardSource: return GuardSource.SHAPE_ENV -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class CurrentStreamSource(Source): device: device_type @@ -1147,7 +1153,7 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class BackwardStateSource(Source): @property def _name_template(self) -> str: diff --git a/torch/_guards.py b/torch/_guards.py index af13099773f8b..03b619f65ae48 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -6,6 +6,7 @@ import functools import logging import re +import sys import threading import traceback import unittest.mock @@ -14,7 +15,28 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union +from typing import ( + Any, + Generic, + NamedTuple, + Optional, + overload, + TYPE_CHECKING, + TypeVar, + Union, +) + + +if sys.version_info >= (3, 11): + from typing import dataclass_transform +else: + + def dataclass_transform(): + def decorator(fn): + return fn + + return decorator + import torch from torch.utils import _pytree as pytree @@ -1076,9 +1098,41 @@ def tracing( _TLS.tracing_context = old_context +@overload +def dataclass_with_cached_hash(cls: type[T], **kwargs: Any) -> type[T]: ... + + +@overload +def dataclass_with_cached_hash( + cls: None = None, **kwargs: Any +) -> Callable[[type[T]], type[T]]: ... + + +@dataclass_transform() +def dataclass_with_cached_hash( + cls: type[T] | None = None, **kwargs: Any +) -> type[T] | Callable[[type[T]], type[T]]: + def wrap(cls_inner: type[T]) -> type[T]: + new_cls = dataclasses.dataclass(cls_inner, **kwargs) + old_hash = cls_inner.__hash__ + + def __hash__(self) -> int: + if not hasattr(self, "_hash"): + object.__setattr__(self, "_hash", old_hash(self)) + return self._hash + + new_cls.__hash__ = __hash__ + return new_cls # type: ignore[return-value] + + if cls is None: + return wrap + + return wrap(cls) + + # Subclasses can be found in torch/_dynamo/source.py # TODO(voz): Consider a toplevel torch/_source.py -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class Source: def is_dict_key(self) -> bool: return False @@ -1137,7 +1191,7 @@ def subguards_allowed(self) -> bool: # Subclasses can be found in torch/_dynamo/source.py -@dataclasses.dataclass(frozen=True) +@dataclass_with_cached_hash(frozen=True) class ChainedSource(Source): base: Source From 1afe2832f58e24e54a5bfda5a5afa9b96fdea40e Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 3 Dec 2025 17:16:55 -0800 Subject: [PATCH 246/338] [dynamo] Support Sequence-like placement user defined objects from_local (#169531) Similar to https://github.com/pytorch/pytorch/pull/168149 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169531 Approved by: https://github.com/bdhirsh --- .../tensor/test_dtensor_compile.py | 42 +++++++++++++++++++ torch/_dynamo/variables/torch.py | 21 +++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 9b1734b9b8682..05fd187bb7576 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -842,6 +842,48 @@ def fn(x): out_test = fn_opt(dt) self.assertEqual(out_ref, out_test) + def test_dynamo_from_local_grad_placements_sequence_intermediate(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + placements = PytreeTuple(Shard(0)) + + def fn(x): + dt = DTensor.from_local( + x, + mesh, + placements=placements, + run_check=False, + ) + return dt.to_local() + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + + out_ref = fn(x) + out_test = fn_opt(x) + self.assertEqual(out_ref, out_test) + + def test_dynamo_from_local_grad_placements_sequence_intermediate_as_args(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + placements = PytreeTuple(Shard(0)) + + def fn(x): + dt = DTensor.from_local( + x, + mesh, + placements, + run_check=False, + ) + return dt.to_local() + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + + out_ref = fn(x) + out_test = fn_opt(x) + self.assertEqual(out_ref, out_test) + def test_dynamo_to_local_kwargs(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e5f21ebb72961..f2323c8b6b67c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -957,12 +957,31 @@ def handle_constant_processgroup_functions( def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs): # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function - args_as_value = [x.as_python_constant() for x in args[1:]] + placements_vt = kwargs.get("placements") + + if placements_vt is None and len(args) >= 3: + placements_vt = args[2] + + if placements_vt is None: + placements_vt = ConstantVariable.create(None) + elif isinstance(placements_vt, variables.UserDefinedObjectVariable): + placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [placements_vt], {} + ) + + new_args = list(args) + if len(new_args) >= 3: + new_args[2] = placements_vt + elif kwargs.get("placements") is not None: + kwargs["placements"] = placements_vt + + args_as_value = [x.as_python_constant() for x in new_args[1:]] kwargs_as_value = { k: v.as_python_constant() for k, v in kwargs.items() if k not in ["shape", "stride"] } + kwargs_to_be_proxied = { k: kwargs[k] for k in ["shape", "stride"] if k in kwargs } From b6b6c80379388b7f9932c3e6a0f9907bf430e417 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 05:37:11 +0000 Subject: [PATCH 247/338] Revert "[opaque obj] Improve error msg for intermediate opaques (#167742)" This reverts commit d78f52b199c547106d4cd9d2856dd0805c118bf1. Reverted https://github.com/pytorch/pytorch/pull/167742 on behalf of https://github.com/huydhn due to Sorry, I need to revert this to cleanly revert https://github.com/pytorch/pytorch/pull/169204. Please rebase and reland this ([comment](https://github.com/pytorch/pytorch/pull/167742#issuecomment-3610340589)) --- test/test_opaque_obj_v2.py | 20 +------------------- torch/_dynamo/graph_break_registry.json | 11 ----------- torch/_dynamo/variables/torch.py | 23 ----------------------- 3 files changed, 1 insertion(+), 53 deletions(-) diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py index 3015defd88349..99ff9058eda52 100644 --- a/test/test_opaque_obj_v2.py +++ b/test/test_opaque_obj_v2.py @@ -6,7 +6,6 @@ import torch from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import AotEagerAndRecordGraphs -from torch._dynamo.utils import counters as dynamo_counters from torch._functorch.aot_autograd import ( aot_compile_joint_with_descriptors, aot_export_joint_with_descriptors, @@ -377,7 +376,7 @@ def forward(self, arg0_1, arg1_1): return (add,)""", # noqa: B950 ) - def test_compile_global(self): + def test_compile_intermediate(self): counter = Counter(0) def foo(x, y): @@ -418,23 +417,6 @@ def forward(self, arg0_1, arg1_1, arg2_1): return (add,)""", # noqa: B950 ) - def test_compile_create_intermediate(self): - dynamo_counters.clear() - - def foo(x, y): - counter = Counter(0) - z = torch.ops._TestOpaqueObject.increment_counter(counter, y) - x = x * z - return x - - inp = (torch.tensor(1), torch.tensor(0)) - torch.compile(foo)(*inp) - self.assertEqual(len(dynamo_counters["graph_break"]), 1) - self.assertTrue( - "Opaque object were created in the middle of the program and passed to a custom op." - in next(iter(dynamo_counters["graph_break"].keys())), - ) - def test_compile_attribute(self): counter = Counter(0) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 5e706e77ba73d..29fd67b6c92de 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3711,16 +3711,5 @@ "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." ] } - ], - "GB0367": [ - { - "Gb_type": "Opaque object were created in the middle of the program and passed to a custom op.", - "Context": "Opaque object types: {intermediate_opaques}. Function: {self.value}", - "Explanation": "Opaque objects cannot be created inside the torch.compile region. They must be created before entering the compiled function.", - "Hints": [ - "Please create the opaque object before calling torch.compile ", - "and pass it in as an argument or as a global variable." - ] - } ] } diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index f2323c8b6b67c..89a7fee6df37f 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -41,7 +41,6 @@ import torch.fx import torch.nn from torch._guards import TracingContext -from torch._library.opaque_object import is_opaque_type from torch._logging import warning_once from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type @@ -87,7 +86,6 @@ TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) -from .user_defined import UserDefinedObjectVariable try: @@ -1509,27 +1507,6 @@ def call_function( ) return self.call_tensor_method(tx, args, kwargs) - intermediate_opaques = [ - type(x.value) - for x in args - if x.source is None - and isinstance(x, UserDefinedObjectVariable) - and is_opaque_type(type(x.value)) - ] - if len(intermediate_opaques) > 0: - unimplemented( - gb_type="Opaque object were created in the middle of the program and passed to a custom op.", - context=f"Opaque object types: {intermediate_opaques}. Function: {self.value}", - explanation=( - "Opaque objects cannot be created inside the torch.compile region. " - "They must be created before entering the compiled function." - ), - hints=[ - "Please create the opaque object before calling torch.compile " - "and pass it in as an argument or as a global variable." - ], - ) - special_handler = self._get_handlers().get(self.value) if special_handler: result = special_handler(self, tx, *args, **kwargs) From ba1412546f3082c0958c077acc2025e4dbc33f1f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 05:40:24 +0000 Subject: [PATCH 248/338] Revert "[dynamo][dicts] Decentralize and Improve key hash implementation for Dict variable tracker (#169204)" This reverts commit c04e2c656f48d82d1521b867bbbf03967b9b7564. Reverted https://github.com/pytorch/pytorch/pull/169204 on behalf of https://github.com/huydhn due to This has been reverted internally ([comment](https://github.com/pytorch/pytorch/pull/169204#issuecomment-3610355613)) --- test/dynamo/test_dicts.py | 210 +----------------- .../TestCustomOp.test_impl_device_cpu | 0 torch/_dynamo/graph_break_registry.json | 44 ---- torch/_dynamo/utils.py | 18 -- torch/_dynamo/variables/base.py | 56 ----- torch/_dynamo/variables/builtin.py | 9 - torch/_dynamo/variables/constant.py | 25 --- torch/_dynamo/variables/dicts.py | 206 +++++++++++------ torch/_dynamo/variables/functions.py | 46 ---- torch/_dynamo/variables/higher_order_ops.py | 9 - torch/_dynamo/variables/lists.py | 34 --- torch/_dynamo/variables/misc.py | 37 --- torch/_dynamo/variables/tensor.py | 28 --- torch/_dynamo/variables/torch.py | 9 - torch/_dynamo/variables/user_defined.py | 87 ++------ 15 files changed, 154 insertions(+), 664 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 4c233ea9458f3..cdaeb2d91fbfb 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -19,7 +19,6 @@ import torch._functorch.config import torch.nn import torch.utils.checkpoint -from torch._dynamo.exc import Unsupported from torch._dynamo.testing import same from torch._dynamo.utils import dict_items from torch.testing._internal.common_utils import ( @@ -90,7 +89,7 @@ def forward(self, x): inp = torch.randn(4, 4) mod = Foo() - opt_f = torch.compile(mod, backend="eager", fullgraph=True) + opt_f = torch.compile(mod) self.assertEqual(mod(inp), opt_f(inp)) def test_dict_subclass_local_with_non_dict_method(self): @@ -519,7 +518,7 @@ def fn(d): args1 = {namedtuple: None, 3: torch.randn(3)} cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1), opt_fn(args1)) self.assertEqual(cnts.frame_count, 1) # Test a failing namedtuple guard @@ -539,7 +538,7 @@ def fn(d, x): args1[3] = z cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1, x), opt_fn(args1, x)) self.assertEqual(cnts.frame_count, 1) @@ -1063,6 +1062,8 @@ def fn(b: Any): a = {"one": torch.ones(1)} return a | b + from torch._dynamo.exc import Unsupported + for arg in args: with self.assertRaisesRegex(Unsupported, "Observed exception"): _ = fn(arg) @@ -1203,156 +1204,6 @@ def f(): opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) - def test_range_as_dict_key(self): - def fn(x): - d = {range(5): x * 2, range(10, 15): x * 3} - return d[range(0, 5, 1)] + d[range(10, 15)] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_tuple_as_dict_key(self): - def fn(x): - d = {(1, 2): x * 2, (3, 4, 5): x * 3} - return d[(1, 2)] + d[(3, 4, 5)] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_enum_as_dict_key(self): - class Color(enum.Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - def fn(x): - d = {Color.RED: x * 2, Color.GREEN: x * 3, Color.BLUE: x * 4} - return d[Color.RED] + d[Color.GREEN] + d[Color.BLUE] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_intenum_as_dict_key(self): - class Priority(enum.IntEnum): - LOW = 1 - MEDIUM = 2 - HIGH = 3 - - def fn(x): - d = {Priority.LOW: x * 2, Priority.MEDIUM: x * 3, Priority.HIGH: x * 4} - return d[Priority.LOW] + d[Priority.MEDIUM] + d[Priority.HIGH] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_frozenset_as_dict_key(self): - def fn(x): - d = {frozenset([1, 2]): x * 2, frozenset([3, 4, 5]): x * 3} - return d[frozenset([1, 2])] + d[frozenset([3, 4, 5])] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_typing_union_as_dict_key(self): - from typing import Union - - def fn(x): - d = {Union[int, str]: x * 2, Union[float, bool]: x * 3} - return d[Union[int, str]] + d[Union[float, bool]] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_numpy_dtype_as_dict_key(self): - import numpy as np - - def fn(x): - d = {np.float32: x * 2, np.int64: x * 3, np.bool_: x * 4} - return d[np.float32] + d[np.int64] + d[np.bool_] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_method_wrapper_as_dict_key(self): - add_method = list.__add__ - mul_method = list.__mul__ - - def fn(x): - # Method wrappers are the type of bound methods on built-in types - d = {add_method: x * 2, mul_method: x * 3} - return d[add_method] + d[mul_method] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_torch_builtin_function_as_dict_key(self): - def fn(x, y): - # Using torch built-in functions as dictionary keys - d = {torch.add: x * 2, torch.mul: y * 3, torch.sub: x + y} - return d[torch.add] + d[torch.mul] + d[torch.sub] - - x = torch.randn(4) - y = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x, y), opt_fn(x, y)) - - def test_frozen_dataclass_as_dict_key(self): - from dataclasses import dataclass - - @dataclass(frozen=True) - class Point: - x: int - y: int - - def fn(tensor): - p1 = Point(1, 2) - p2 = Point(3, 4) - d = {p1: tensor * 2, p2: tensor * 3} - return d[Point(1, 2)] + d[Point(3, 4)] - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - - def test_list_as_dict_key_raises_typeerror(self): - def fn(x): - d = {[1, 2, 3]: x * 2} - return d[[1, 2, 3]] - - x = torch.randn(4) - - # First check that eager execution raises TypeError - with self.assertRaises(TypeError): - fn(x) - - # Also check that compiled version raises TypeError - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - with self.assertRaisesRegex(Unsupported, "Observed exception"): - opt_fn(x) - - def test_get_default_nowrap_functions_as_dict_key(self): - def fn(x): - # Get the set of default nowrap functions - nowrap_funcs = torch.overrides.get_default_nowrap_functions() - # Use the set as a dict key and search for Tensor.grad.__get__ in it - d = {frozenset(nowrap_funcs): x * 2} - # Check if Tensor.grad.__get__ is in the set - if torch.Tensor.grad.__get__ in nowrap_funcs: - return d[frozenset(nowrap_funcs)] + x - return x - - x = torch.randn(4) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(x), opt_fn(x)) - instantiate_parametrized_tests(DictTests) @@ -1887,9 +1738,7 @@ def fn(x): new_gn = partial(gn, x=1) key = Container(new_gn, 4) new_dict[key] = 5 - # Make another key that should hash to the same value - key1 = Container(new_gn, 4) - return x * new_dict[key1] + return x * new_dict[key] x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -1898,53 +1747,6 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) - def test_custom_object_as_dict_key(self): - """Test that custom objects with __hash__ as dict keys are properly handled. - - This test verifies that when using custom objects with overridden __hash__ - and __eq__ as dictionary keys, two instances with the same hash and equality - should be recognized as the same key. - """ - - class CustomKey: - def __init__(self, value, name): - self.value = value - self.name = name - - def fn(x): - d = {} - # Create first instance - key1 = CustomKey(42, "test") - d[key1] = x * 2 - - # Create second instance with same values - should hash to same value - key2 = CustomKey(42, "test") - d[key2] = x * 3 # This should overwrite the first value - - return d[key1] * d[key2] - - x = torch.randn(4) - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertTrue(same(opt_fn(x), fn(x))) - - def test_user_defined_object(self): - class A: - def __init__(self): - self.x = {} - REF[self] = {} - - REF = {} - - def f(a, x): - REF[a]["foo"] = x - return x + 1 - - opt_f = torch.compile(f, backend="eager", fullgraph=True) - - x = torch.randn(4) - self.assertTrue(same(f(A(), x), opt_f(A(), x))) - class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu b/test/dynamo_expected_failures/TestCustomOp.test_impl_device_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 29fd67b6c92de..a5c1d22eea1fd 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3667,49 +3667,5 @@ "Use custom operators instead of direct attribute/method access." ] } - ], - "GB0363": [ - { - "Gb_type": "User-defined object with overridden __hash__", - "Context": "hashing object of type={type(obj)} and variable tracker {vt}", - "Explanation": "Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", - "Hints": [ - "Dynamo does not support hashing user-defined objects with overridden __hash__", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], - "GB0364": [ - { - "Gb_type": "Dynamo cannot determine whether the underlying object is hashable", - "Context": "is_python_hashable {self}", - "Explanation": "Dynamo does not know whether the underlying python object for {self} is hashable", - "Hints": [ - "Consider using a different type of object as the dictionary key instead of {type_self}.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], - "GB0365": [ - { - "Gb_type": "Dynamo cannot determine the hash of an object", - "Context": "get_python_hash {self}", - "Explanation": "Dynamo does not know the hash of the underlying python object for {self}", - "Hints": [ - "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], - "GB0366": [ - { - "Gb_type": "Dynamo cannot determine the equality comparison of an object", - "Context": "is_python_equal {self}", - "Explanation": "Dynamo does not know the equality comparison of the underlying python object for {self}", - "Hints": [ - "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } ] } diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index afdd0c7aefa4d..d08b92de3441e 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4962,21 +4962,3 @@ def get_traced_code() -> Optional[list[CodeType]]: from torch._guards import TracingContext return TracingContext.get_traced_code() - - -def raise_on_overridden_hash(obj: Any, vt: VariableTracker) -> None: - from . import graph_break_hints - from .exc import unimplemented - - is_overridden = type(obj).__dict__.get("__hash__", False) - - if is_overridden: - unimplemented( - gb_type="User-defined object with overridden __hash__", - context=f"hashing object of type={type(obj)} and variable tracker {vt}", - explanation=f"Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", - hints=[ - "Dynamo does not support hashing user-defined objects with overridden __hash__", - *graph_break_hints.SUPPORTABLE, - ], - ) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index a794010f4083f..617f787e43d8a 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -683,62 +683,6 @@ def build( else: return variables.LazyVariableTracker.create(value, source) - def is_python_hashable(self): - """ - Unlike the variable tracker's own __hash__, this method checks whether - the underlying Python object referenced by this variable tracker is hashable. - """ - try: - type_self = self.python_type() - except NotImplementedError: - type_self = type(self) - - unimplemented( - gb_type="Dynamo cannot determine whether the underlying object is hashable", - context=f"is_python_hashable {self}", - explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable", - hints=[ - ( - f"Consider using a different type of object as the dictionary key instead of {type_self}." - ), - *graph_break_hints.SUPPORTABLE, - ], - ) - - def get_python_hash(self): - """ - Unlike the variable tracker’s own __hash__, this method is used by - ConstDictVariableTracker to compute the hash of the underlying key object. - """ - unimplemented( - gb_type="Dynamo cannot determine the hash of an object", - context=f"get_python_hash {self}", - explanation=f"Dynamo does not know the hash of the underlying python object for {self}", - hints=[ - ( - f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." - ), - *graph_break_hints.SUPPORTABLE, - ], - ) - - def is_python_equal(self, other): - """ - NB - Deliberately not overriding the __eq__ method because that can - disable the __hash__ for the vt itself. - """ - unimplemented( - gb_type="Dynamo cannot determine the equality comparison of an object", - context=f"is_python_equal {self}", - explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}", - hints=[ - ( - f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." - ), - *graph_break_hints.SUPPORTABLE, - ], - ) - def __init__( self, *, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 9bd1bae080508..40b2be0437373 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -3268,15 +3268,6 @@ def call_contains( ) -> VariableTracker: return a.call_method(tx, "__contains__", [b], {}) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.fn) - - def is_python_equal(self, other): - return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn - @contextlib.contextmanager def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 0b2eaaea80826..672fa1d804383 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -23,7 +23,6 @@ istype, np, raise_args_mismatch, - raise_on_overridden_hash, ) from .base import ValueMutationNew, VariableTracker @@ -341,20 +340,6 @@ def call_obj_hasattr( result = hasattr(self.value, name) return variables.ConstantVariable.create(result) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - # Could be an EnumVariable as well - from .tensor import SymNodeVariable - - if isinstance(other, SymNodeVariable): - return self.as_python_constant() == other.evaluate_expr() - return self.as_python_constant() == other.as_python_constant() - class EnumVariable(VariableTracker): """VariableTracker for enum.Enum and enum.IntEnum instances @@ -403,13 +388,3 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker member = getattr(self.value, name) source = self.source and AttrSource(self.source, name) return VariableTracker.build(tx, member, source=source) - - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 9b98c91723063..422cae7c4d3f1 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -20,11 +20,14 @@ import collections import functools +import inspect import operator import types -from collections.abc import Sequence +from collections.abc import Hashable as py_Hashable, Sequence from typing import Any, Optional, TYPE_CHECKING, Union +from torch._subclasses.fake_tensor import is_fake + from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import raise_observed_exception, unimplemented @@ -52,8 +55,8 @@ # [Adding a new supported class within the keys of ConstDictVariable] -# - Implement is_python_hashable() method in the VariableTracker subclass -# - Implement get_python_hash() and is_python_equal() methods for hashable types +# - Add its tracker type to is_hashable +# - (perhaps) Define how it is compared in _HashableTracker._eq_impl def was_instancecheck_override(obj: Any) -> bool: @@ -70,7 +73,7 @@ def raise_unhashable( raise_observed_exception( TypeError, tx, - msg=f"Unhashable type: {arg.python_type()!r} and variable tracker = {type(arg.realize())}", + args=[ConstantVariable(f"unhashable type: {type(arg.realize())}")], ) @@ -85,7 +88,52 @@ def is_hashable(x: VariableTracker) -> bool: and x.is_hashable() ): return True - return x.is_python_hashable() + + if isinstance(x, variables.TensorVariable): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return x.as_proxy().node.meta.get("example_value") is not None + elif isinstance(x, variables.TupleVariable): + return all(is_hashable(e) for e in x.items) + elif isinstance(x, variables.FrozenDataClassVariable): + return all(is_hashable(e) for e in x.fields.values()) + elif ( + isinstance(x, variables.UserDefinedObjectVariable) + and not was_instancecheck_override(x.value) + and inspect.getattr_static(x.value, "__hash__") is int.__hash__ + and isinstance(x.value, int) + ): + return isinstance(x.value, py_Hashable) + elif isinstance(x, variables.FunctoolsPartialVariable): + return ( + is_hashable(x.func) + and all(is_hashable(arg) for arg in x.args) + and all(is_hashable(value) for value in x.keywords.values()) + ) + else: + return isinstance( + x, + ( + variables.BuiltinVariable, + variables.SymNodeVariable, + variables.ConstantVariable, + variables.EnumVariable, + variables.FrozensetVariable, + variables.UserDefinedClassVariable, + variables.UserFunctionVariable, + variables.SkipFunctionVariable, + variables.misc.NumpyVariable, + variables.NNModuleVariable, + variables.UnspecializedNNModuleVariable, + variables.MethodWrapperVariable, + variables.TorchInGraphFunctionVariable, + variables.TypingVariable, + variables.FunctoolsPartialVariable, + variables.WeakRefVariable, + variables.TorchHigherOrderOperatorVariable, + ), + ) class ConstDictVariable(VariableTracker): @@ -106,47 +154,88 @@ class _HashableTracker: def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) - - # If Dynamo does not know the hashability of the vt, it will raise unsupported here + # TODO Temporarily remove to figure out what keys are we breaking on + # and add proper support for them if not is_hashable(vt): raise_unhashable(vt) self.vt = vt - def __hash__(self) -> int: - """ - Computes the hash value for the wrapped VariableTracker. - - For unrealized LazyVariableTrackers, uses the hash of the original value - to avoid realizing the tracker and inserting unnecessary guards. - For all other cases, delegates to the VariableTracker's get_python_hash method. - - Returns: - The hash value of the underlying variable tracker - """ + @property + def underlying_value(self) -> Any: if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() and self.vt.is_hashable() ): - return hash(self.vt.original_value()) - return self.vt.get_python_hash() - - def __eq__(self, other) -> bool: - """ - Checks equality between two _HashableTracker instances. + return self.vt.original_value() + if isinstance(self.vt, variables.TensorVariable): + x = self.vt.as_proxy().node.meta["example_value"] + elif isinstance(self.vt, variables.TupleVariable): + Hashable = ConstDictVariable._HashableTracker + x = tuple(Hashable(e).underlying_value for e in self.vt.items) + elif isinstance(self.vt, variables.NNModuleVariable): + return self.vt.value + elif isinstance(self.vt, variables.UnspecializedNNModuleVariable): + return self.vt.value + elif isinstance(self.vt, variables.UserFunctionVariable): + return self.vt.get_function() + elif isinstance(self.vt, variables.WeakRefVariable): + # Access the underlying value inside the referent_vt for the key representation + Hashable = ConstDictVariable._HashableTracker + return Hashable(self.vt.referent_vt).underlying_value + elif isinstance(self.vt, variables.FrozenDataClassVariable): + Hashable = ConstDictVariable._HashableTracker + fields_values = { + k: Hashable(v).underlying_value + for k, v in self.vt.fields.items() # type: ignore[attr-defined] + } + return variables.FrozenDataClassVariable.HashWrapper( + self.vt.python_type(), fields_values + ) + elif isinstance(self.vt, variables.UserDefinedObjectVariable): + # The re module in Python 3.13+ has a dictionary (_cache2) with + # an object as key (`class _ZeroSentinel(int): ...`): + # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual + return self.vt.value # type: ignore[attr-defined,union-attr] + elif isinstance(self.vt, variables.FunctoolsPartialVariable): + Hashable = ConstDictVariable._HashableTracker + items = (self.vt.func, *self.vt.args, *self.vt.keywords.values()) + x = tuple(Hashable(e).underlying_value for e in items) + return x + else: + x = self.vt.as_python_constant() + return x - Delegates to the VariableTracker's is_python_equal method to compare - the underlying variable trackers for Python-level equality. + def __hash__(self) -> int: + return hash(self.underlying_value) + + @staticmethod + def _eq_impl(a: Any, b: Any) -> bool: + # TODO: Put this in utils and share it between variables/builtin.py and here + type_a, type_b = type(a), type(b) + if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)): + return False + + if isinstance(a, tuple): + Hashable = ConstDictVariable._HashableTracker + return len(a) == len(b) and all( + Hashable._eq_impl(u, v) for u, v in zip(a, b) + ) + elif is_fake(a): + return a is b + else: + return a == b - Args: - other: Another _HashableTracker instance to compare with + def __eq__(self, other: object) -> bool: + Hashable = ConstDictVariable._HashableTracker + assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( + type(other) + ) + if isinstance(other, Hashable): + return Hashable._eq_impl(self.underlying_value, other.underlying_value) - Returns: - True if the underlying variable trackers are Python-equal, False otherwise - """ - if self.vt is other.vt: - return True - return self.vt.is_python_equal(other.vt) + # constant + return Hashable._eq_impl(self.underlying_value, other) def __init__( self, @@ -235,7 +324,7 @@ def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( - vt.is_python_hashable() + is_hashable(vt) and Hashable(vt) in self.items and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) @@ -447,6 +536,8 @@ def call_method( Hashable = ConstDictVariable._HashableTracker + arg_hashable = args and is_hashable(args[0]) + if name == "__init__": temp_dict_vt = variables.BuiltinVariable(dict).call_dict( tx, *args, **kwargs @@ -515,7 +606,6 @@ def call_method( self.install_dict_keys_match_guard() return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -530,21 +620,16 @@ def call_method( tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) - elif name == "__delitem__" and self.is_mutable(): - arg_hashable = args and is_hashable(args[0]) - if arg_hashable: - self.install_dict_keys_match_guard() - self.should_reconstruct_all = True - tx.output.side_effects.mutation(self) - self.items.__delitem__(Hashable(args[0])) - return ConstantVariable.create(None) - else: - return super().call_method(tx, name, args, kwargs) + elif name == "__delitem__" and arg_hashable and self.is_mutable(): + self.install_dict_keys_match_guard() + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.__delitem__(Hashable(args[0])) + return ConstantVariable.create(None) elif name == "get": if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -560,7 +645,6 @@ def call_method( if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -652,7 +736,6 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -668,7 +751,6 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) - arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) @@ -821,12 +903,6 @@ def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) - def is_python_hashable(self): - """ - Dictionaries are mutable and therefore not hashable in Python. - """ - return False - class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt @@ -1340,18 +1416,6 @@ def call_method( return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) - def is_python_hashable(self): - """ - Frozensets are immutable and hashable in Python. - """ - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class DictKeySetVariable(SetVariable): def debug_repr(self) -> str: @@ -1541,9 +1605,3 @@ def call_method( return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) return ConstantVariable.create(False) return super().call_method(tx, name, args, kwargs) - - def is_python_hashable(self): - """ - Dictionary item views are not hashable in Python. - """ - return False diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index f493e0e1fd961..fdc2f53f82383 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -810,15 +810,6 @@ def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: return collected return None - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.fn) - - def is_python_equal(self, other): - return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn - class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): _nonvar_fields = { @@ -1966,15 +1957,6 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return fn_var_getattr(tx, self.value, self.source, name) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class WrappedSkipFunctionVariable(SkipFunctionVariable): def __init__( @@ -2361,34 +2343,6 @@ def guard_as_python_constant(self) -> Any: **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, ) - def is_python_hashable(self) -> bool: - return ( - self.func.is_python_hashable() - and all(arg.is_python_hashable() for arg in self.args) - and all(value.is_python_hashable() for value in self.keywords.values()) - ) - - def get_python_hash(self): - func_hash = self.func.get_python_hash() - args_hash = (arg.get_python_hash() for arg in self.args) - values_hash = (value.get_python_hash() for value in self.keywords.values()) - return hash((func_hash, *args_hash, *values_hash)) - - def is_python_equal(self, other): - return ( - self.func.is_python_equal(other.func) - and all( - arg_a.is_python_equal(arg_b) - for (arg_a, arg_b) in zip(self.args, other.args) - ) - and all( - value_a.is_python_equal(value_b) - for (value_a, value_b) in zip( - self.keywords.values(), other.keywords.values() - ) - ) - ) - class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index a4543821b19b1..0f7491911d35b 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1813,15 +1813,6 @@ def _call_function( def as_python_constant(self): return self.value - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index a97c284f9516c..4f21e35479fb8 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -620,25 +620,6 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return self.items[fields.index(name)] return super().var_getattr(tx, name) - def is_python_hashable(self): - return True - - def get_python_hash(self): - l = self.range_length() - start = self.start() - step = self.step() - return hash((l, start, step)) - - def is_python_equal(self, other): - if not isinstance(other, variables.RangeVariable): - return False - - return ( - self.start() == other.start() - and self.step() == other.step() - and self.stop() == other.stop() - ) - class CommonListMethodsVariable(BaseListVariable): """ @@ -1000,9 +981,6 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) - def is_python_hashable(self): - return False - class DequeVariable(CommonListMethodsVariable): def __init__( @@ -1192,18 +1170,6 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) - def is_python_hashable(self): - return all(item.is_python_hashable() for item in self.items) - - def get_python_hash(self): - items = tuple(x.get_python_hash() for x in self.items) - return hash(items) - - def is_python_equal(self, other): - return isinstance(other, variables.TupleVariable) and all( - a.is_python_equal(b) for (a, b) in zip(self.items, other.items) - ) - class SizeVariable(TupleVariable): """torch.Size(...)""" diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 748d4a0985b49..c7d6e58ba4531 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1306,15 +1306,6 @@ def is_python_constant(self): def as_python_constant(self): return self.method_wrapper - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class GetSetDescriptorVariable(VariableTracker): def __init__(self, desc, **kwargs) -> None: @@ -1449,15 +1440,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # codegen.append_output(codegen.create_load_const(self.value)) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - @functools.lru_cache(maxsize=1) def get_np_to_tnp_map(): @@ -1636,15 +1618,6 @@ def as_proxy(self): return super().as_proxy() - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.as_python_constant()) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls class NullVariable(VariableTracker): @@ -2124,13 +2097,3 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(self.referent_vt) codegen(self.callback_vt) codegen.extend_output(create_call_function(2, False)) - - def is_python_hashable(self): - return self.referent_vt.is_python_hashable() - - def get_python_hash(self): - # weakref relies on the referent's hash - return self.referent_vt.get_python_hash() - - def is_python_equal(self, other): - return self.referent_vt.is_python_equal(other.referent_vt) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index d47c520046d38..47439387e0fca 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1428,20 +1428,6 @@ def set_name_hint(self, name: str): self.proxy.node._rename(name) self._is_name_set = True - def is_python_hashable(self): - # Tensors are hashable if they have an example_value (a fake tensor) - # Most VT's should have one. - # It'd be nice if at some point we could assert that they all have one - return self.as_proxy().node.meta["example_value"] is not None - - def get_python_hash(self): - return hash(self.as_proxy().node.meta["example_value"]) - - def is_python_equal(self, other): - a = self.as_proxy().node.meta["example_value"] - b = other.as_proxy().node.meta["example_value"] - return a is b - class SymNodeVariable(VariableTracker): """ @@ -1530,20 +1516,6 @@ def call_method( ), ) - def is_python_hashable(self): - return True - - def get_python_hash(self): - # Essentially convert the SymNode to a constant variable whenever its - # searched for a dict key. - return hash(self.evaluate_expr()) - - def is_python_equal(self, other): - if isinstance(other, SymNodeVariable): - return self.evaluate_expr() == other.evaluate_expr() - # could be constant variable as well - return self.evaluate_expr() == other.as_python_constant() - class NumpyNdarrayVariable(TensorVariable): """ diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 89a7fee6df37f..3d0541dacfd6f 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -2094,15 +2094,6 @@ def torch_function_override_enabled(self, tx, args, kwargs): ) ) and can_dispatch_torch_function(tx, args, kwargs) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - return self.as_python_constant() == other.as_python_constant() - class DispatchKeySetVariable(BaseTorchVariable): """represents torch.DispatchKeySet""" diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0863d8592abd2..cc377a09ab746 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -89,7 +89,6 @@ object_has_getattribute, proxy_args_kwargs, raise_args_mismatch, - raise_on_overridden_hash, set_methods, tensortype_to_dtype, tuple_methods, @@ -928,18 +927,6 @@ def const_getattr(self, tx: "InstructionTranslator", name): return self.value.__name__ return super().const_getattr(tx, name) - def is_python_hashable(self): - return True - - def get_python_hash(self): - return hash(self.value) - - def is_python_equal(self, other): - return ( - isinstance(other, variables.UserDefinedClassVariable) - and self.value is other.value - ) - class UserDefinedExceptionClassVariable(UserDefinedClassVariable): @property @@ -1756,20 +1743,26 @@ def call_obj_hasattr( handle_observed_exception(tx) return variables.ConstantVariable.create(False) - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return True - def get_python_hash(self): - # default hash - return hash(self.value) +class FrozenDataClassVariable(UserDefinedObjectVariable): + class HashWrapper: + """This class is hashed if a dataclass is used as a key in a dict. + It's necessary to avoid side effects from calling the __init__ of the dataclass class when hashing""" - def is_python_equal(self, other): - # id check - return self.value is other.value + def __init__(self, c, fields): + self.cls = c + self.fields = tuple(fields.items()) + def __eq__(self, other): + return ( + type(self) is type(other) + and self.cls == other.cls + and self.fields == other.fields + ) + + def __hash__(self): + return hash((self.cls, self.fields)) -class FrozenDataClassVariable(UserDefinedObjectVariable): @staticmethod def create(tx, value, source): from dataclasses import fields @@ -1871,22 +1864,6 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value_type.__name__})" - def is_python_hashable(self): - # TODO - Check corner cases like eq=False, hash=False etc - return True - - def get_python_hash(self): - return hash(tuple(arg.get_python_hash() for arg in self.fields.values())) - - def is_python_equal(self, other): - is_class_same = self.python_type() is other.python_type() - is_field_name_same = self.fields.keys() == other.fields.keys() - is_field_value_same = all( - value_a.is_python_equal(value_b) - for value_a, value_b in zip(self.fields.values(), other.fields.values()) - ) - return is_class_same and is_field_name_same and is_field_value_same - class SourcelessGraphModuleVariable(UserDefinedObjectVariable): def __init__( @@ -2107,10 +2084,6 @@ def install_dict_keys_match_guard(self): def install_dict_contains_guard(self): return self._dict_vt.install_dict_contains_guard() - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return False - class UserDefinedSetVariable(UserDefinedObjectVariable): """ @@ -2184,18 +2157,6 @@ def install_dict_keys_match_guard(self): def install_dict_contains_guard(self): return self._set_vt.install_dict_contains_guard() - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return self._set_vt.is_python_hashable() - - def get_python_hash(self): - return self._set_vt.get_python_hash() - - def is_python_equal(self, other): - return isinstance( - other, UserDefinedSetVariable - ) and self._set_vt.is_python_equal(other._set_vt) - class UserDefinedListVariable(UserDefinedObjectVariable): """ @@ -2237,10 +2198,6 @@ def unpack_var_sequence(self, tx): def is_underlying_vt_modified(self, side_effects): return side_effects.is_modified(self._list_vt) - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return False - class UserDefinedTupleVariable(UserDefinedObjectVariable): """ @@ -2289,18 +2246,6 @@ def unpack_var_sequence(self, tx): return self._tuple_vt.unpack_var_sequence(tx) raise NotImplementedError - def is_python_hashable(self): - raise_on_overridden_hash(self.value, self) - return self._tuple_vt.is_python_hashable() - - def get_python_hash(self): - return self._tuple_vt.get_python_hash() - - def is_python_equal(self, other): - return isinstance( - other, UserDefinedTupleVariable - ) and self._tuple_vt.is_python_equal(other._tuple_vt) - class MutableMappingVariable(UserDefinedObjectVariable): def __init__(self, value, **kwargs): From 2353a0f60eb4b4cb6675907a7fa9fbedc1c02e7f Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 3 Dec 2025 16:56:41 -0800 Subject: [PATCH 249/338] [dynamo] support tracing DeviceMesh._flatten (#169530) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169530 Approved by: https://github.com/anijain2305, https://github.com/xmfan --- test/dynamo/test_fake_distributed.py | 20 ++++++++++++++++++++ torch/_dynamo/variables/distributed.py | 8 ++++++++ 2 files changed, 28 insertions(+) diff --git a/test/dynamo/test_fake_distributed.py b/test/dynamo/test_fake_distributed.py index 41e373a50d76b..fca48c54f198d 100644 --- a/test/dynamo/test_fake_distributed.py +++ b/test/dynamo/test_fake_distributed.py @@ -135,6 +135,26 @@ def fn(x): res = fn(x) self.assertEqual(res, x) + def test_device_mesh_flatten(self): + device_mesh = init_device_mesh( + device_type="cpu", + mesh_shape=( + 1, + self.world_size, + ), + mesh_dim_names=("dp", "tp"), + ) + self.assertEqual(device_mesh.get_coordinate(), [0, 0]) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + dm = device_mesh._flatten() + return x + 1, dm.get_coordinate() + + x = torch.ones(10) + res = fn(x) + self.assertEqual(res, (x + 1, [0])) + instantiate_parametrized_tests(TestFakeDistributed) diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index f6faf4414d1da..cabb1786bed1f 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -318,6 +318,14 @@ def call_method( ) if name == "_get_or_create_default_group": return ProcessGroupVariable(self.value._get_or_create_default_group()) + if name == "_flatten": + from .builder import SourcelessBuilder + + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return SourcelessBuilder.create( + tx, self.value._flatten(*const_args, **const_kwargs) + ) return super().call_method(tx, name, args, kwargs) From b6b6d912df0b6f4082f8e50b18bd1de1dd7325f4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 06:03:36 +0000 Subject: [PATCH 250/338] Revert "[Dynamo][Guard]Add the user-friendly TYPE_MATCH for type (#169025)" This reverts commit bea4912944defdbcb8b061800caab6cbbbd01df5. Reverted https://github.com/pytorch/pytorch/pull/169025 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the new tests are failing internally D88329912 ([comment](https://github.com/pytorch/pytorch/pull/169025#issuecomment-3610438983)) --- test/dynamo/test_check_type_id.py | 123 ------------------------------ torch/_dynamo/guards.py | 10 +-- 2 files changed, 2 insertions(+), 131 deletions(-) delete mode 100644 test/dynamo/test_check_type_id.py diff --git a/test/dynamo/test_check_type_id.py b/test/dynamo/test_check_type_id.py deleted file mode 100644 index 4f63c140246ef..0000000000000 --- a/test/dynamo/test_check_type_id.py +++ /dev/null @@ -1,123 +0,0 @@ -# Owner(s): ["module: dynamo"] -""" -Test for TYPE_MATCH guard and ___check_type_id function. - -This test demonstrates how the TYPE_MATCH guard works in PyTorch Dynamo. -When a function is compiled, Dynamo installs guards to ensure the compiled -code remains valid. TYPE_MATCH guards ensure that values maintain their -exact type (using type identity, not just type equality). -""" - -import re - -import torch -import torch._dynamo -import torch._dynamo.test_case -from torch._dynamo.eval_frame import _debug_get_cache_entry_list -from torch.testing._internal.common_utils import munge_exc - - -class TestCheckTypeId(torch._dynamo.test_case.TestCase): - @staticmethod - def _find_guard_lines(guard_manager_str: str, keyword: str) -> list[str]: - # Normalize and anonymize type IDs, then return lines containing the keyword - normalized = re.sub( - r"\d{7,}", "", munge_exc(guard_manager_str), flags=re.MULTILINE - ) - pattern = re.compile(rf"^.*{re.escape(keyword)}.*$", re.MULTILINE) - return pattern.findall(normalized) - - def test_type_match_with_different_values(self): - """ - Test that TYPE_MATCH guard correctly identifies type mismatches. - - This test compiles a function that uses a global variable and verifies: - 1. The compiled function works with values of the same type - 2. The function recompiles when the type changes - 3. The ___check_type_id/check_obj_id guard is present in the generated code - 4. The check_type_id should present the user-friendly code that specify the type - """ - - # Define a global variable that we'll guard on - class Config: - multiplier = 2 # int type - - def fn(x): - # This will trigger a TYPE_MATCH guard on Config.multiplier - return x * Config.multiplier - - # Compile the function - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - - # First call - should compile and install guards - x = torch.randn(4) - result1 = opt_fn(x) - expected1 = x * 2 - self.assertTrue(torch.allclose(result1, expected1)) - - # Get the cache entry to inspect guards - cache_entries = _debug_get_cache_entry_list(fn.__code__) - self.assertEqual(len(cache_entries), 1) - - # Check that the guard string contains check_type_id - guard_str = str(cache_entries[0].guard_manager) - matches = self._find_guard_lines(guard_str, "ID_MATCH") - self.assertIn("___check_obj_id", matches[0]) - self.assertIn( - "type=.Config'>", - matches[0], - ) - self.assertEqual( - matches[0].split("#")[0], - "| | +- ID_MATCH: ___check_obj_id(L['Config'], ), type=.Config'> ", - ) - - def test_type_match_with_custom_classes(self): - """ - Test TYPE_MATCH guard with custom class instances. - - Demonstrates that the guard checks type identity, not structural equality. - """ - - class Point: - def __init__(self, x, y): - self.x = x - self.y = y - - class Point2D: - def __init__(self, x, y): - self.x = x - self.y = y - - point = Point(1, 2) - - def fn(tensor): - # Access point's attributes, triggering TYPE_MATCH guard on point - return tensor + point.x + point.y - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - - # First call with Point instance - x = torch.ones(4) - result1 = opt_fn(x) - expected1 = x + 1 + 2 - self.assertTrue(torch.allclose(result1, expected1)) - - # Verify guard contains check_type_id - cache_entries = _debug_get_cache_entry_list(fn.__code__) - self.assertEqual(len(cache_entries), 1) - - guard_str = str(cache_entries[0].guard_manager) - matches = self._find_guard_lines(guard_str, "TYPE_MATCH") - self.assertEqual( - matches[0].split("#")[0], - "| | +- TYPE_MATCH: ___check_type_id(L['point'], ), type=.Point'> ", - ) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index ea720d5c49f5f..e9097c592af9f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1958,7 +1958,7 @@ def TYPE_MATCH(self, guard: Guard) -> None: obj_id = self.id_ref(t, f"type({guard.name})") type_repr = repr(t) - code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}), type={type_repr}" + code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}) # {type_repr}" self._set_guard_export_info(guard, [code]) self.get_guard_manager(guard).add_type_match_guard( @@ -2060,13 +2060,7 @@ def id_match_unchecked( ref = self.arg_ref(guard) val = self.get(guard) id_val = self.id_ref(val, guard.name) - try: - type_repr = repr(val) - except Exception: - # During deepcopy reconstruction or other state transitions, - # objects may be in an incomplete state where repr() fails - type_repr = f"<{type(val).__name__}>" - code = f"___check_obj_id({ref}, {id_val}), type={type_repr}" + code = f"___check_obj_id({ref}, {id_val})" self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH") self.get_guard_manager(guard).add_id_match_guard( id_val, get_verbose_code_parts(code, guard, recompile_hint) From 43b94713bbf340d3c124fde02d0f73add4021247 Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 4 Dec 2025 06:32:13 +0000 Subject: [PATCH 251/338] Fix torch.fx for the newer "|" union syntax (#169453) This PR fixes torch.fx handling of the newer `|` type. Otherwise, they could be errors like ``` "/torch/package/importer.py", line 95, in get_name name = obj.__name__ AttributeError: 'types.UnionType' object has no attribute '__name__'. Did you mean: '__ne__'? ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/169453 Approved by: https://github.com/albanD, https://github.com/malfet --- test/test_fx.py | 10 ++++++++++ torch/fx/graph.py | 5 +++++ torch/fx/node.py | 2 ++ 3 files changed, 17 insertions(+) diff --git a/test/test_fx.py b/test/test_fx.py index 7fdd6552edc7b..e2584156bf730 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2381,6 +2381,16 @@ def test_typename_print_pre_pep585(self): self.assertTrue("typing.List[float]" in str(graph)) + def test_typename_print_union(self): + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,), type_expr=float|torch.Tensor|None + ) + output: torch.fx.Node = graph.output(b) + + self.assertTrue('float | torch.Tensor | None' in str(graph)) + def test_layout(self): class M(torch.nn.Module): def forward(self, x): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 36ef68a9a2e35..d4b0a1b1500d3 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -10,6 +10,7 @@ import os import pprint import re +import types import typing import warnings from collections import defaultdict @@ -499,6 +500,10 @@ def type_repr(o: Any): return "()" typename = _type_repr(o) + if isinstance(o, types.UnionType) and "|" in typename: + # str | int + args = [type_repr(arg) for arg in o.__args__] + return "|".join(args) if origin_type := getattr(o, "__origin__", None): # list[...], typing.List[...], TensorType[...] diff --git a/torch/fx/node.py b/torch/fx/node.py index 5afabe40ec341..85e6f3a82e969 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -174,6 +174,8 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: # Fixup segment_reduce mismatch if module == "torch" and name == "segment_reduce": name = "_" + name + if module == "torch.nn.functional" and name in ("_ScalingType", "_SwizzleType"): + name = name.removeprefix("_") return f"{module}.{name}" From 76aeb8c7e0f795b3fddca134cbea9a69da3ee696 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 4 Dec 2025 06:41:19 +0000 Subject: [PATCH 252/338] [DebugMode] default values for outputs, stack trace (#169499) Changes some default flag values for DebugMode: - `record_output=True` - `debug_string(show_stack_trace=...)` is set to the value of `record_stack_trace`, unless overridden. For existing tests, overrides values to the old ones to avoid wobbling expected test outputs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169499 Approved by: https://github.com/yushangdi --- .../tensor/debug/test_debug_mode.py | 24 ++++++++++--------- torch/utils/_debug_mode.py | 11 ++++++--- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index dcc50bd268faa..37147c3ca9fe0 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -68,9 +68,7 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode( - record_torchfunction=True, record_ids=True, record_output=True - ) as debug_mode: + with DebugMode(record_torchfunction=True, record_ids=True) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( @@ -121,7 +119,8 @@ def mm(x, y): ) self.assertTrue(torch.equal(sum_op.record["output"], eager_out.to_local())) self.assertTrue( - "aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string() + "aten::sum(t: f32[1, 32]) -> t: f32[] # {'hash': " + in debug_mode.debug_string() ) # check tuple hash functions @@ -169,13 +168,13 @@ def test_debug_mode_backward(self): y_dtensor = DTensor.from_local(y, mesh, [Shard(1)], run_check=False) with DebugMode( - record_torchfunction=True, record_stack_trace=True + record_torchfunction=True, record_stack_trace=True, record_output=False ) as debug_mode: z = x_dtensor + y_dtensor z.sum().backward() self.assertExpectedInline( - debug_mode.debug_string(), + debug_mode.debug_string(show_stack_trace=False), """\ (dt: f32[8, 8]| S(0), dt: f32[8, 8]| S(1)) aten::add.Tensor(dt: f32[8, 8]| S(0), dt: f32[8, 8]| S(1)) @@ -215,7 +214,7 @@ def test_debug_mode_densor_redistribution_trace(self): y_dtensor = DTensor.from_local(y, mesh, [Shard(1), Shard(1)], run_check=False) x_dtensor._spec.shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1)),) y_dtensor._spec.shard_order = (ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 1)),) - with DebugMode(record_torchfunction=False) as debug_mode: + with DebugMode(record_torchfunction=False, record_output=False) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( @@ -254,7 +253,7 @@ def test_debug_mode_einsum(self): b_dt = DTensor.from_local(b, mesh, [Replicate(), Partial()], run_check=False) # Capture the operator decomposition - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode(record_torchfunction=True, record_output=False) as debug_mode: torch.einsum("bld,dnh->blnh", a_dt, b_dt) self.assertExpectedInline( @@ -311,7 +310,7 @@ def test_real_tensor(self): x = torch.randn(8, 8, 8) linear = torch.nn.Linear(8, 8) - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode(record_torchfunction=True, record_output=False) as debug_mode: linear(x).sum() self.assertExpectedInline( @@ -331,7 +330,9 @@ def test_fake_tensor(self): x = torch.randn(8, 8) y = torch.randn(8, 8, 8) - with DebugMode(record_torchfunction=True, record_faketensor=True) as debug_mode: + with DebugMode( + record_torchfunction=True, record_faketensor=True, record_output=False + ) as debug_mode: torch.matmul(y, x) self.assertExpectedInline( @@ -355,6 +356,7 @@ def test_tensor_attributes(self): record_faketensor=True, record_tensor_attributes=["a1", "a2"], store_original_args=True, + record_output=False, ) as debug_mode: torch.matmul(y, x) @@ -454,7 +456,7 @@ def forward(self, x): mod = Bar() inp = torch.randn(4, 4) - with DebugMode(record_nn_module=True) as debug_mode: + with DebugMode(record_nn_module=True, record_output=False) as debug_mode: _ = mod(inp) self.assertExpectedInline( diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 3303f2470e4da..abe9f6aa59ae1 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -599,7 +599,7 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, - record_output=False, + record_output=True, record_ids=False, ) -> None: super().__init__() @@ -824,12 +824,17 @@ def record_triton_kernel( self.operators.append(call) return call - def debug_string(self, show_stack_trace: bool = False) -> str: + def debug_string(self, show_stack_trace: bool | None = None) -> str: """ - show_stack_trace: If True, display one-line stack trace summaries above groups + show_stack_trace: option to display one-line stack trace summaries above groups of operations (similar to gm.print_readable() style). Requires record_stack_trace=True. + if None, uses self.record_stack_trace, otherwise overrides it. """ + show_stack_trace = ( + self.record_stack_trace if show_stack_trace is None else show_stack_trace + ) + with torch._C.DisableTorchFunction(): if not show_stack_trace: result = "\n".join( From 4fefb8e7e942386ffac764a41b232241f82bea3a Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Thu, 4 Dec 2025 07:22:35 +0000 Subject: [PATCH 253/338] [OpenReg][CI][Refactor] Refactor OpenReg CI for better independency and functionality (#167958) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR isolates OpenReg test from other test cases in CI. ## Background The PrivateUse1 integration mechanism provides new vendors with a stable and convenient way to integrate their accelerators into the PyTorch ecosystem. Typically, new accelerators only require a CPU backend as fallback purpose in case certain operations or functions are not yet ready. Furthermore, some PrivateUse1-related features are limited to allowing only one accelerator. Therefore, to align with real-world use cases, we are refactoring the current OpenReg integration to ensure that OpenReg-related tests are executed using **only the CPU** (currently, OpenReg test cases are executed in environments including GPUs). ## Design - **Isolation**: Removing OpenReg tests from the default global test case set prevents these tests from running in environments containing accelerators other than CPUs and provides a dedicated entry point for supporting the execution of OpenReg tests. - **Reuse**: The integration mechanism based on PrivateUse1 needs to support multiple platforms, such as Linux, Windows and OSX. Therefore, we need to enable openreg test cases for these platforms by adding openreg configuration items to the workflow specifically for each platform. ## Validation Run ```shell python test/run_test.py --dry-run &> ./selected_tests && grep "test_openreg" ./selected_tests python test/run_test.py --openreg --dry-run &> ./selected_tests && grep "test_openreg" ./selected_tests ``` to check if `test_openreg` is included by the default test case range. image ## New CI Jobs The new OpenReg CI jobs require minimal compute resources (not include GPU resource) and typically finish in less than 10 minutes. | Name | OS | Accelerator | Python | Compiler | Arch | Status | | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----- | ----------- | ------ | -------- | ----- | --------- | | [linux-aarch64 / linux-jammy-aarch64-py3.10 / test (openreg, 1, 1, lf.linux.arm64.m7g.4xlarge) (push)](https://github.com/pytorch/pytorch/actions/runs/19557552621/job/56003633536?pr=167958)
[linux-aarch64 / linux-jammy-aarch64-py3.10 / test (openreg, 1, 1, lf.linux.arm64.m8g.4xlarge) (push)](https://github.com/pytorch/pytorch/actions/runs/19557552621/job/56003633541?pr=167958) | Linux | CPU | 3.10 | GCC 13 | arm64 | Passed ✔️ | | [pull / linux-jammy-py3.10-gcc11 / test (openreg, 1, 1, lf.linux.2xlarge) (pull_request)](https://github.com/pytorch/pytorch/actions/runs/19557475371/job/56003745878?pr=167958) | Linux | CPU | 3.10 | GCC 11 | x86 | Passed ✔️ | | [pull / linux-jammy-py3.10-clang18-asan / test (openreg, 1, 1, lf.linux.4xlarge) (pull_request)](https://github.com/pytorch/pytorch/actions/runs/19557475371/job/56004070351?pr=167958) | Linux | CPU | 3.10 | Clang 18 | x86 | Passed ✔️ | | [pull / linux-jammy-py3.10-clang12 / test (openreg, 1, 1, lf.linux.2xlarge) (pull_request)](https://github.com/pytorch/pytorch/actions/runs/19557475371/job/56003667828?pr=167958) | Linux | CPU | 3.10 | Clang 12 | x86 | Passed ✔️ | | [pull / linux-jammy-py3.13-clang12 / test (openreg, 1, 1, lf.linux.2xlarge) (pull_request)](https://github.com/pytorch/pytorch/actions/runs/19557475371/job/56003530148?pr=167958) | Linux | CPU | 3.13 | Clang 12 | x86 | Passed ✔️ | | [trunk / macos-py3-arm64 / test (openreg, 1, 1, macos-m1-stable) (push)](https://github.com/pytorch/pytorch/actions/runs/19557552684/job/56004207701?pr=167958) | OSX | CPU | 3.12 | Clang | arm64 | Passed ✔️ | | [trunk / win-vs2022-cpu-py3 / test (openreg, 1, 1, lf.windows.4xlarge.nonephemeral) (push)](https://github.com/pytorch/pytorch/actions/runs/19557552684/job/56006519084?pr=167958) | Win | CPU | 3.10 | VS 2022 | x86 | Passed ✔️ | Pull Request resolved: https://github.com/pytorch/pytorch/pull/167958 Approved by: https://github.com/albanD --- .ci/pytorch/macos-test.sh | 10 +++++++++ .ci/pytorch/test.sh | 7 +++++++ .ci/pytorch/win-test-helpers/test_openreg.bat | 21 +++++++++++++++++++ .ci/pytorch/win-test.sh | 9 ++++++-- .github/workflows/linux-aarch64.yml | 2 ++ .github/workflows/pull.yml | 8 +++++-- .github/workflows/trunk.yml | 2 ++ test/run_test.py | 15 +++++++++---- tools/testing/discover_tests.py | 1 - 9 files changed, 66 insertions(+), 9 deletions(-) create mode 100644 .ci/pytorch/win-test-helpers/test_openreg.bat diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 2687852a2c4f3..677f8318e2fa7 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -46,6 +46,14 @@ test_python_mps() { assert_git_not_dirty } +test_python_openreg() { + setup_test_python + + time python test/run_test.py --openreg --verbose + + assert_git_not_dirty +} + test_python_shard() { if [[ -z "$NUM_TEST_SHARDS" ]]; then @@ -393,6 +401,8 @@ elif [[ $TEST_CONFIG == *"perf_smoketest"* ]]; then test_torchbench_smoketest "${SHARD_NUMBER}" elif [[ $TEST_CONFIG == *"aot_inductor_perf_smoketest"* ]]; then test_aoti_torchbench_smoketest "${SHARD_NUMBER}" +elif [[ $TEST_CONFIG == *"openreg"* ]]; then + test_python_openreg elif [[ $TEST_CONFIG == *"mps"* ]]; then test_python_mps elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9118d6031a2a7..c15c72ca4fb08 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1828,6 +1828,11 @@ test_attention_microbenchmark() { --output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json" } +test_openreg() { + python test/run_test.py --openreg --verbose + assert_git_not_dirty +} + if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then (cd test && python -c "import torch; print(torch.__config__.show())") (cd test && python -c "import torch; print(torch.__config__.parallel_info())") @@ -2012,6 +2017,8 @@ elif [[ "${TEST_CONFIG}" == "b200-symm-mem" ]]; then test_h100_symm_mem elif [[ "${TEST_CONFIG}" == h100_cutlass_backend ]]; then test_h100_cutlass_backend +elif [[ "${TEST_CONFIG}" == openreg ]]; then + test_openreg else install_torchvision install_monkeytype diff --git a/.ci/pytorch/win-test-helpers/test_openreg.bat b/.ci/pytorch/win-test-helpers/test_openreg.bat new file mode 100644 index 0000000000000..0470057daf641 --- /dev/null +++ b/.ci/pytorch/win-test-helpers/test_openreg.bat @@ -0,0 +1,21 @@ +call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat +:: exit the batch once there's an error +if not errorlevel 0 ( + echo "setup pytorch env failed" + echo %errorlevel% + exit /b +) + +pushd test + +echo Run openreg tests +python run_test.py --openreg --verbose +if ERRORLEVEL 1 goto fail + +popd + +:eof +exit /b 0 + +:fail +exit /b 1 diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index a01aa0b6431cd..69b248bdac533 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -25,8 +25,8 @@ mkdir -p "$TMP_DIR"/build/torch export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers -if [[ "$TEST_CONFIG" = "force_on_cpu" ]]; then - # run the full test suite for force_on_cpu test +if [[ "$TEST_CONFIG" = "force_on_cpu" || "$TEST_CONFIG" = "openreg" ]]; then + # run the full test suite for force_on_cpu test and openreg test export USE_CUDA=0 fi @@ -49,6 +49,11 @@ run_tests() { fi done + if [[ "$TEST_CONFIG" == "openreg" ]]; then + "$SCRIPT_HELPERS_DIR"/test_openreg.bat + return + fi + if [[ $NUM_TEST_SHARDS -eq 1 ]]; then "$SCRIPT_HELPERS_DIR"/test_python_shard.bat "$SCRIPT_HELPERS_DIR"/test_custom_script_ops.bat diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index e6690b1043006..bb1a9a4f6a8b5 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -40,9 +40,11 @@ jobs: { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, ]} secrets: inherit diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index c85e2813a7f37..eb676389f86ac 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -71,6 +71,7 @@ jobs: { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -141,6 +142,7 @@ jobs: { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} sync-tag: asan-build secrets: inherit @@ -205,7 +207,8 @@ jobs: { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" } + { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -241,7 +244,8 @@ jobs: { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" } + { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 8c873f5d15162..d1fd936280e94 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -134,6 +134,7 @@ jobs: { config: "default", shard: 3, num_shards: 3, runner: "macos-m1-stable" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-15" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "macos-m1-stable" }, ]} secrets: inherit @@ -165,6 +166,7 @@ jobs: { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, ]} secrets: inherit diff --git a/test/run_test.py b/test/run_test.py index 349d0755360ec..c6a8473b5667b 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -196,7 +196,6 @@ def __contains__(self, item): "test_jit_legacy", "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", - "test_openreg", ] S390X_BLOCKLIST = [ @@ -262,13 +261,11 @@ def __contains__(self, item): # depend on z3-solver "fx/test_z3_gradual_types", "test_proxy_tensor", - "test_openreg", ] XPU_BLOCKLIST = [ "test_autograd", "profiler/test_memory_profiler", - "test_openreg", ] XPU_TEST = [ @@ -286,7 +283,6 @@ def __contains__(self, item): "test_multiprocessing", "test_multiprocessing_spawn", "test_namedtuple_return_api", - "test_openreg", "test_overrides", "test_show_pickle", "test_tensorexpr", @@ -1403,6 +1399,12 @@ def parse_args(): action="store_true", help=("If this flag is present, we will run xpu tests except XPU_BLOCK_LIST"), ) + parser.add_argument( + "--openreg", + "--openreg", + action="store_true", + help=("If this flag is present, we will only run test_openreg"), + ) parser.add_argument( "--cpp", "--cpp", @@ -1698,6 +1700,11 @@ def get_selected_tests(options) -> list[str]: # Exclude all xpu specific tests otherwise options.exclude.extend(XPU_TEST) + if options.openreg: + selected_tests = ["test_openreg"] + else: + options.exclude.append("test_openreg") + # Filter to only run onnx tests when --onnx option is specified onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS] if options.onnx: diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py index 1210326a02dbf..20d29693fd97a 100644 --- a/tools/testing/discover_tests.py +++ b/tools/testing/discover_tests.py @@ -139,7 +139,6 @@ def skip_test_p(name: str) -> bool: "doctests", "test_autoload_enable", "test_autoload_disable", - "test_openreg", ], ) From ae3a2395bf66151078e2d201716f7d63ce1c6f3e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 07:26:06 +0000 Subject: [PATCH 254/338] Revert "[export] Make RNNs exportable on GPUs (#163245)" This reverts commit b2b6b034c9fd08672c40e63ef243556ad4c49bd2. Reverted https://github.com/pytorch/pytorch/pull/163245 on behalf of https://github.com/huydhn due to Reverted internally ([comment](https://github.com/pytorch/pytorch/pull/163245#issuecomment-3610661311)) --- test/export/test_export.py | 83 +------------------------------ test/export/test_export_opinfo.py | 8 +-- torch/export/_trace.py | 2 - 3 files changed, 3 insertions(+), 90 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 0bb21f47f9381..92ea28c077e52 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -78,7 +78,6 @@ IS_WINDOWS, run_tests, skipIfCrossRef, - skipIfRocm, skipIfXpu, TEST_TRANSFORMERS, TEST_WITH_CROSSREF, @@ -8105,84 +8104,6 @@ def _patch_config(kwargs): ): _ = export(mod, inp, strict=True) - @requires_gpu - @skipIfRocm - @testing.expectedFailureSerDer - @testing.expectedFailureSerDerNonStrict - def test_export_lstm_gpu(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.rnn = torch.nn.LSTM( - input_size=4, hidden_size=5, num_layers=1, batch_first=True - ) - - def forward(self, x): - out, _ = self.rnn(x) - return out - - m = M().to(GPU_TYPE) - x = torch.randn(2, 3, 4, device=GPU_TYPE) - - ep = export(m, (x,)) - self.assertTrue(callable(ep.module())) - - eager_out = m(x) - export_out = ep.module()(x) - self.assertEqual(eager_out, export_out) - - @requires_gpu - @skipIfRocm - @testing.expectedFailureSerDer - @testing.expectedFailureSerDerNonStrict - def test_export_gru_gpu(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.rnn = torch.nn.GRU( - input_size=4, hidden_size=5, num_layers=1, batch_first=True - ) - - def forward(self, x): - out, _ = self.rnn(x) - return out - - m = M().to(GPU_TYPE) - x = torch.randn(2, 3, 4, device=GPU_TYPE) - - ep = export(m, (x,)) - self.assertTrue(callable(ep.module())) - - eager_out = m(x) - export_out = ep.module()(x) - self.assertEqual(eager_out, export_out) - - @requires_gpu - @skipIfRocm - @testing.expectedFailureSerDer - @testing.expectedFailureSerDerNonStrict - def test_export_rnn_flatten_parameters_gpu(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.lstm = torch.nn.LSTM( - input_size=3, hidden_size=4, num_layers=2, batch_first=True - ) - - def forward(self, x): - self.lstm.flatten_parameters() - out, (h, c) = self.lstm(x) - return out - - m = M().to(GPU_TYPE) - x = torch.randn(1, 5, 3, device=GPU_TYPE) - - ep = export(m, (x,), strict=False) - - eager_out = m(x) - export_out = ep.module()(x) - self.assertEqual(eager_out, export_out) - def test_device_to_static(self): class Module(torch.nn.Module): def forward(self, x): @@ -8808,7 +8729,7 @@ def forward(self, x): bn_num_batches_tracked = self.bn.num_batches_tracked; bn_num_batches_tracked = None _guards_fn = self._guards_fn(x); _guards_fn = None conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None - batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, False); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None + batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None return pytree.tree_unflatten((batch_norm,), self._out_spec)""", ) @@ -8829,7 +8750,7 @@ def forward(self, x): _guards_fn = self._guards_fn(x); _guards_fn = None conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None add_ = torch.ops.aten.add_.Tensor(bn_num_batches_tracked, 1); bn_num_batches_tracked = add_ = None - batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, False); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None + batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None return pytree.tree_unflatten((batch_norm,), self._out_spec)""", ) diff --git a/test/export/test_export_opinfo.py b/test/export/test_export_opinfo.py index 5eb42c461e574..075fd6df119b9 100644 --- a/test/export/test_export_opinfo.py +++ b/test/export/test_export_opinfo.py @@ -20,12 +20,7 @@ skipOps, xfail, ) -from torch.testing._internal.common_utils import ( - IS_FBCODE, - run_tests, - skipIfRocm, - TestCase, -) +from torch.testing._internal.common_utils import run_tests, skipIfRocm, TestCase from torch.utils import _pytree as pytree @@ -127,7 +122,6 @@ class TestExportOpInfo(TestCase): @skipOps( "TestExportOpInfo", "test_fake_export", export_failures | fake_export_failures ) - @unittest.skipIf(IS_FBCODE, "tests broken with unexpected successes internally") def test_fake_export(self, device, dtype, op): _test_export_helper(self, dtype, op) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index fdffacf512c20..856f23f68b19e 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -178,13 +178,11 @@ class ExportArtifact: def _ignore_backend_decomps(): orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False) orig_nnpack_flag = torch.backends.nnpack.set_flags(False) - orig_cudnn_flag = torch.backends.cudnn.set_flags(False) try: yield finally: torch.backends.mkldnn.set_flags(*orig_mkldnn_flag) torch.backends.nnpack.set_flags(*orig_nnpack_flag) - torch.backends.cudnn.set_flags(*orig_cudnn_flag) @contextmanager From 5634469fda9e5d98869c82c7d03bb08914245f96 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Wed, 3 Dec 2025 17:04:15 -0600 Subject: [PATCH 255/338] ci: Minimize dependency on build_environment for setup-nvidia action (#169428) I recently merged a change that makes setup-nvidia auto-detect if the runner has a NVIDIA GPU which makes most of these conditionals moot. See change here: * https://github.com/pytorch/test-infra/pull/7539 Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/169428 Approved by: https://github.com/yangw-dev --- .github/workflows/_bazel-build-test.yml | 1 - .github/workflows/_binary-test-linux.yml | 3 ++- .github/workflows/_linux-test.yml | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/_bazel-build-test.yml b/.github/workflows/_bazel-build-test.yml index 72241a772be61..fd66ccd8ea418 100644 --- a/.github/workflows/_bazel-build-test.yml +++ b/.github/workflows/_bazel-build-test.yml @@ -98,7 +98,6 @@ jobs: - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG uses: pytorch/test-infra/.github/actions/setup-nvidia@main - if: ${{ inputs.cuda-version != 'cpu' && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} - name: Output disk space left run: | diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index 476dd182db0f8..c4d4fca302e81 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -186,8 +186,9 @@ jobs: path: "${{ runner.temp }}/artifacts/" - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG + id: install-nvidia-driver uses: pytorch/test-infra/.github/actions/setup-nvidia@main - if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' && steps.filter.outputs.is-test-matrix-empty == 'False' }} + if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' }} - name: configure aws credentials id: aws_creds diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 2434a595f5420..11e8af797d3c1 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -170,12 +170,12 @@ jobs: uses: pytorch/test-infra/.github/actions/setup-nvidia@main with: driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '580.82.07' }} - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }} + if: ${{ !contains(matrix.config, 'nogpu') && !contains(matrix.runner, 'b200') }} - name: Setup GPU_FLAG for docker run id: setup-gpu-flag run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}" - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }} + if: ${{ steps.install-nvidia-driver.outputs.has-nvidia == 'true' && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }} - name: Setup SCCACHE_SERVER_PORT environment for docker run when on container id: setup-sscache-port-flag @@ -325,7 +325,7 @@ jobs: # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }} SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }} - SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }} + SHM_SIZE: ${{ steps.install-nvidia-driver.outputs.has-nvidia == 'true' && '2g' || '1g' }} DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} DOCKER_IMAGE_S390X: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} From 35b7a9a26c5923d98aebaa41a031dae21788a9ee Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Wed, 3 Dec 2025 17:04:16 -0600 Subject: [PATCH 256/338] ci: Remove errant nogpu conditional (#169431) Very unsure why this was added in the first place but this is no longer necessary since setup-nvidia now supports early exit if there are no nvidia gpus attached. Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/169431 Approved by: https://github.com/huydhn, https://github.com/atalman, https://github.com/malfet ghstack dependencies: #169428 --- .github/workflows/_linux-test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 11e8af797d3c1..ee2837a7456f5 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -170,12 +170,12 @@ jobs: uses: pytorch/test-infra/.github/actions/setup-nvidia@main with: driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '580.82.07' }} - if: ${{ !contains(matrix.config, 'nogpu') && !contains(matrix.runner, 'b200') }} + if: ${{ !contains(matrix.runner, 'b200') }} - name: Setup GPU_FLAG for docker run id: setup-gpu-flag run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}" - if: ${{ steps.install-nvidia-driver.outputs.has-nvidia == 'true' && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }} + if: ${{ steps.install-nvidia-driver.outputs.has-nvidia == 'true' && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }} - name: Setup SCCACHE_SERVER_PORT environment for docker run when on container id: setup-sscache-port-flag From ffd9b0fb4355e97af82fc42cf185c3ffa0fc0a32 Mon Sep 17 00:00:00 2001 From: tianrengao Date: Thu, 4 Dec 2025 07:56:35 +0000 Subject: [PATCH 257/338] Resolve collective autotuning test failure on arm (#168919) Fixes https://github.com/pytorch/pytorch/pull/167294 collective autotuning test failure on ARM64 This PR resolves the earlier test failure on arm64 in main: ```ModuleNotFoundError: No module named 'torch._C._distributed_c10d'; 'torch._C' is not a package``` The failure occurred on ARM due to unsupported c10 imports and nccl for collective autotuning tests. This PR updates the test logic to skip on unsupported platforms, ensuring CI passes on ARM. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168919 Approved by: https://github.com/shunting314 --- test/inductor/test_collective_autotuning.py | 8 ++++++++ torch/_inductor/select_algorithm.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_collective_autotuning.py b/test/inductor/test_collective_autotuning.py index a5a05d05a9028..c8c993c5a3016 100644 --- a/test/inductor/test_collective_autotuning.py +++ b/test/inductor/test_collective_autotuning.py @@ -1,7 +1,15 @@ # Owner(s): ["module: inductor"] +import sys + import torch import torch.distributed as dist + + +if not dist.is_available() or not dist.is_nccl_available(): + print("c10d NCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index df71bdd3db502..7fb1a5539e005 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -3586,7 +3586,7 @@ def benchmark_collective_choice( try: # Do n warmups - total_time = cls._run_collective_benchmark( + cls._run_collective_benchmark( choice, inputs, output, nwarmup, process_group, timeout ) From eabb7ad2128580ef674446027b95bcf4e21e8df3 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Tue, 2 Dec 2025 00:31:53 +0000 Subject: [PATCH 258/338] [Inductor XPU GEMM] Step 1/N: Refactor cutlass configuration. (#160174) This PR is the first step toward implementing RFC #160175. Currently, all Cutlass-related Torch Inductor configs are located in `torch._inductor.config.cuda`. This PR refactors the device-agnostic Cutlass configurations into `torch._inductor.config.cutlass`, so they can be shared and reused by XPU as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160174 Approved by: https://github.com/EikanWang, https://github.com/mlazos, https://github.com/jansel --- benchmarks/inductor_backends/cutlass.py | 2 +- test/inductor/test_cutlass_backend.py | 110 +++++++++--------- torch/_inductor/codecache.py | 30 ++--- .../codegen/cuda/cuda_cpp_scheduling.py | 2 +- torch/_inductor/codegen/cuda/cuda_template.py | 2 +- torch/_inductor/codegen/cuda/cutlass_cache.py | 2 +- torch/_inductor/codegen/cuda/cutlass_utils.py | 4 +- torch/_inductor/codegen/cuda/gemm_template.py | 24 ++-- torch/_inductor/config.py | 81 ++++++++----- torch/_inductor/fuzzer.py | 2 +- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/utils.py | 8 +- torch/utils/_config_module.py | 11 ++ 13 files changed, 160 insertions(+), 122 deletions(-) diff --git a/benchmarks/inductor_backends/cutlass.py b/benchmarks/inductor_backends/cutlass.py index b2ed506302aec..af06333038947 100644 --- a/benchmarks/inductor_backends/cutlass.py +++ b/benchmarks/inductor_backends/cutlass.py @@ -125,7 +125,7 @@ def name(self) -> str: def to_options(self) -> dict[str, Any]: return { **super().to_options(), - "cuda.cutlass_instantiation_level": self.cutlass_instantiation_level, + "cutlass.cutlass_instantiation_level": self.cutlass_instantiation_level, } diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index b4c4f6f18f1eb..5d9b421e562e0 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -133,10 +133,10 @@ def gen_args(op, shape, dtype=torch.float16): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cuda.cutlass_tma_only": True, - "cuda.cutlass_epilogue_fusion_enabled": True, + "cutlass.cutlass_tma_only": True, + "cutlass.cutlass_epilogue_fusion_enabled": True, } ) @@ -144,9 +144,9 @@ def gen_args(op, shape, dtype=torch.float16): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cuda.cutlass_tma_only": True, + "cutlass.cutlass_tma_only": True, } ) @@ -234,8 +234,8 @@ def mm(a, b): "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_backend_min_gemm_size": 100000, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_backend_min_gemm_size": 100000, + "cutlass.cutlass_max_profiling_configs": 2, } ): with mock.patch( @@ -287,7 +287,7 @@ def test_cutlass_backend_subproc_mm(self): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, } ): Y_compiled = torch.compile(torch.mm)(a, b) @@ -324,7 +324,7 @@ def test_cutlass_backend_subproc_addmm(self, dtype): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, } ): for x_shape in x_shapes: @@ -354,7 +354,7 @@ def test_cutlass_backend_subproc_bmm(self): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, } ): Y_compiled = torch.compile(torch.bmm)(a, b) @@ -386,7 +386,7 @@ def forward(self, a, b, c): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): from torch._inductor.utils import run_and_get_code @@ -428,8 +428,8 @@ def forward(self, a, b, c): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 1, - "cuda.cutlass_max_profiling_swizzle_options": [ + "cutlass.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_swizzle_options": [ 1, 2, 4, @@ -505,7 +505,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -595,9 +595,9 @@ def forward(self, x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cuda.cutlass_tma_only": True, + "cutlass.cutlass_tma_only": True, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -677,7 +677,7 @@ def forward(self, x, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -746,7 +746,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): expected = [model(*input) for input in inputs] @@ -775,8 +775,8 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk( "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, - "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): for M, K, N in ( @@ -819,7 +819,7 @@ def test_streamk_with_dynamic( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"): @@ -849,8 +849,8 @@ def test_streamk_with_static( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, - "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cutlass.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): _ = compiled_model(a, b) @@ -884,7 +884,7 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, "cuda.version": "12.2", # required to enable the Kernels we need } ): @@ -983,7 +983,7 @@ def mm(a, b): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) @@ -1002,7 +1002,7 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): model = MyModel() @@ -1040,7 +1040,7 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): model = MyModel() @@ -1073,8 +1073,8 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem", + "cutlass.cutlass_max_profiling_configs": 1, } ): model = MyModel() @@ -1117,7 +1117,7 @@ def mm(a, b): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "autotune_local_cache": True, } ): @@ -1157,9 +1157,9 @@ def my_addmm(x, a, b, alpha, beta): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, - "cuda.cutlass_op_allowlist_regex": "", - "cuda.cutlass_op_denylist_regex": "pingpong", + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_op_allowlist_regex": "", + "cutlass.cutlass_op_denylist_regex": "pingpong", } ): with mock.patch( @@ -1202,9 +1202,9 @@ def addmm(x, a, b, alpha, beta): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, - "cuda.cutlass_op_allowlist_regex": "pingpong", - "cuda.cutlass_op_denylist_regex": None, + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_op_allowlist_regex": "pingpong", + "cutlass.cutlass_op_denylist_regex": None, } ): with mock.patch( @@ -1273,7 +1273,7 @@ def run_test(use_fast_accum): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): with mock.patch( @@ -1350,7 +1350,7 @@ def test_cutlass_backend_shape_coverage_mm( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ), mock.patch( @@ -1461,8 +1461,8 @@ def test_standalone_runner(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, - "cuda.generate_test_runner": True, # put standalone runner in the generated code + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.generate_test_runner": True, # put standalone runner in the generated code } ): from tempfile import NamedTemporaryFile @@ -1544,7 +1544,7 @@ def mm(a, b): { "max_autotune": True, "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, # needed for log searching "fx_graph_cache": False, "fx_graph_remote_cache": False, @@ -1608,8 +1608,8 @@ def counting_render(self, *args, **kwargs): "max_autotune_gemm_backends": "CUTLASS", "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cuda.enable_caching_codegen": True, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.enable_caching_codegen": True, + "cutlass.cutlass_max_profiling_configs": 2, } ): compiled_model = torch.compile(model, fullgraph=True) @@ -1660,10 +1660,10 @@ def counting_render(self, *args, **kwargs): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cuda.enable_caching_codegen": True, + "cutlass.enable_caching_codegen": True, } ): # Get expected results @@ -1721,10 +1721,10 @@ def counting_render(self, *args, **kwargs): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cuda.enable_caching_codegen": True, + "cutlass.enable_caching_codegen": True, } ): # Get expected results @@ -1752,7 +1752,7 @@ def test_cutlass_backend_matmul_same_tensor(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): compiled = torch.compile(torch.mm) @@ -1771,7 +1771,7 @@ def test_cutlass_backend_matmul_nonzero_offset(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): compiled = torch.compile(torch.mm) @@ -1795,7 +1795,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1817,7 +1817,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1845,7 +1845,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1871,7 +1871,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): if use_aoti: @@ -1968,7 +1968,7 @@ def forward(self, a, b, extra_args): # baseline is cutlass kernel + triton # matches expected casting behavior - with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}): + with config.patch({"cutlass.cutlass_epilogue_fusion_enabled": False}): ref_result = torch.compile(model)(a, b, extra_args) self.assertEqual( @@ -2377,7 +2377,7 @@ def test_config_number_post_filtering(self) -> None: "max_autotune_gemm_backends": "CUTLASS", # needed for log searching "force_disable_caches": True, - "cuda.cutlass_max_profiling_swizzle_options": [2], + "cutlass.cutlass_max_profiling_swizzle_options": [2], } ): with mock.patch( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index a30644312332b..2542d5ecefd3f 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -34,7 +34,7 @@ from tempfile import _TemporaryFileWrapper from time import time, time_ns from types import ModuleType -from typing import Any, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union +from typing import Any, cast, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import override, Self import torch @@ -3741,7 +3741,7 @@ def _load_triton_kernel_from_source( return getattr(PyCodeCache.load(source_code), kernel_name) -def _cuda_compiler() -> str | None: +def _cuda_compiler() -> Optional[str]: if cuda_env.nvcc_exist(config.cuda.cuda_cxx): return config.cuda.cuda_cxx if config.is_fbcode(): @@ -3759,7 +3759,7 @@ def _cutlass_path() -> str: return parutil.get_dir_path("cutlass-4-headers") else: - return config.cuda.cutlass_dir + return config.cutlass.cutlass_dir def _cutlass_paths() -> list[str]: @@ -3807,7 +3807,7 @@ def cutlass_key() -> bytes: return resource_file.read().encode() combined_hash = hashlib.sha256() - build_code_hash([config.cuda.cutlass_dir], "", combined_hash) + build_code_hash([config.cutlass.cutlass_dir], "", combined_hash) return combined_hash.digest() @@ -3877,14 +3877,14 @@ def _nvcc_compiler_options() -> list[str]: "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", "-w", f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", - config.cuda.compile_opt_level, + config.cutlass.compile_opt_level, "-std=c++17", "--expt-relaxed-constexpr", "-DNDEBUG", ] if config.is_fbcode(): options.extend(["-ccbin", os.path.dirname(build_paths.gcc)]) - if config.cuda.enable_debug_info: + if config.cutlass.enable_debug_info: options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) if config.cuda.enable_ptxas_info: options.extend( @@ -3896,7 +3896,7 @@ def _nvcc_compiler_options() -> list[str]: "--source-in-ptx", ] ) # Annotate the ptx file with source information - if config.cuda.use_fast_math: + if config.cutlass.use_fast_math: options.extend( [ "--use_fast_math", @@ -4100,7 +4100,7 @@ def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: Returns the hash key of source code, and the path to the file. """ - if config.cuda.cutlass_hash_with_compile_cmd: + if config.cutlass.cutlass_hash_with_compile_cmd: cuda_command = repr( cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) ) @@ -4151,7 +4151,7 @@ def compile( output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext error_path = binary_error_path(output_path) binary_remote_cache = cls.get_kernel_binary_remote_cache( - caching_enabled=config.cuda.use_binary_remote_cache + caching_enabled=config.cutlass.use_binary_remote_cache and not config.force_disable_caches, caching_available=config.is_fbcode(), ) @@ -4166,13 +4166,13 @@ def compile( cmd_parts, error_output = json.loads(error_json) if ( binary_remote_cache is not None - and config.cuda.upload_to_binary_remote_cache + and config.cutlass.upload_to_binary_remote_cache ): # This ensures that a local error is uploaded to the remote cache, # as we make no assumptions about the remote cache having the same # information as the local cache binary_remote_cache.put( - error_path, config.cuda.binary_remote_cache_force_write + error_path, config.cutlass.binary_remote_cache_force_write ) cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, error_json @@ -4236,11 +4236,11 @@ def compile( # Upload to remote cache if enabled if ( binary_remote_cache is not None - and config.cuda.upload_to_binary_remote_cache + and config.cutlass.upload_to_binary_remote_cache ): # will log on errors, but not fail out binary_remote_cache.put( - output_path, config.cuda.binary_remote_cache_force_write + output_path, config.cutlass.binary_remote_cache_force_write ) cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, None @@ -4293,10 +4293,10 @@ def _record_cuda_compile_error( # Upload to remote cache directly from memory if enabled if ( binary_remote_cache is not None - and config.cuda.upload_to_binary_remote_cache + and config.cutlass.upload_to_binary_remote_cache ): binary_remote_cache.put( - error_path, config.cuda.binary_remote_cache_force_write + error_path, config.cutlass.binary_remote_cache_force_write ) diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 2496860ca1f7c..16b09d4ba80eb 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -257,7 +257,7 @@ def _can_fuse_epilogue_impl( ) return False elif ( - not config.cuda.cutlass_epilogue_fusion_enabled + not config.cutlass.cutlass_epilogue_fusion_enabled or not config.epilogue_fusion ): why("cutlass epilogue fusion is not enabled") diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 79dfa9c6c391f..92c86120570d6 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -110,7 +110,7 @@ def generate_code_and_args( args are different. """ key: Optional[str] = None - if config.cuda.enable_caching_codegen: + if config.cutlass.enable_caching_codegen: key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) if key is not None and key in self.code_cache: diff --git a/torch/_inductor/codegen/cuda/cutlass_cache.py b/torch/_inductor/codegen/cuda/cutlass_cache.py index 66db98867b413..cad4a37902304 100644 --- a/torch/_inductor/codegen/cuda/cutlass_cache.py +++ b/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -75,7 +75,7 @@ def maybe_fetch_ops() -> Optional[list[Any]]: # get_cuda_version might return "12.4.0" or "12.4" # but we want to use "12.4" version: str = ".".join(get_cuda_version().split(".")[:2]) - instantiation_level: str = config.cuda.cutlass_instantiation_level + instantiation_level: str = config.cutlass.cutlass_instantiation_level # filename and filepath request_key: str = get_config_request_key(arch, version, instantiation_level) diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index fa46e8766cd58..3ce3a49bb94e9 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -98,7 +98,7 @@ def path_join(path0, path1): # contains both cutlass and cutlass_library # we need cutlass for eVT - cutlass_python_path = path_join(config.cuda.cutlass_dir, "python") + cutlass_python_path = path_join(config.cutlass.cutlass_dir, "python") torch_root = os.path.abspath(os.path.dirname(torch.__file__)) mock_src_path = os.path.join( torch_root, @@ -252,7 +252,7 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]: ) return {} arch = _normalize_cuda_arch(arch) - instantiation_level: str = config.cuda.cutlass_instantiation_level + instantiation_level: str = config.cutlass.cutlass_instantiation_level args = CUTLASSArgs( architectures=arch, cuda_version=version, diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index c4b7188bd9e62..9148ee7877d03 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -19,7 +19,7 @@ from torch._inductor.utils import clear_on_fresh_cache from ... import ir -from ...config import cuda as inductor_cuda_config +from ...config import cutlass as inductor_cutlass_config from ...ir import ( Buffer, ChoiceCaller, @@ -578,7 +578,7 @@ def _add_cutlass_gemm_choices( for name, op in ops: for ( swizzle - ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + ) in inductor_cutlass_config.cutlass_max_profiling_swizzle_options: description = f"{name} swizzle={swizzle}" self.maybe_append_choice( choices, @@ -635,7 +635,7 @@ def header(self) -> IndentedBuffer: #include "cutlass/util/tensor_view_io.h" """ ) - if inductor_cuda_config.generate_test_runner and not is_dynamic( + if inductor_cutlass_config.generate_test_runner and not is_dynamic( *self.input_nodes, self.output_node ): res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) @@ -953,7 +953,7 @@ def filter_op( ) return None - if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op): + if inductor_cutlass_config.cutlass_tma_only and not self._has_tma_epilogue(op): return None # Set epilogue. @@ -975,14 +975,16 @@ def filter_op( return None # Apply regex filters at the end when configuration name doesn't change anymore - if inductor_cuda_config.cutlass_op_allowlist_regex: + if inductor_cutlass_config.cutlass_op_allowlist_regex: if not re.search( - inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name() + inductor_cutlass_config.cutlass_op_allowlist_regex, + op.configuration_name(), ): return None - if inductor_cuda_config.cutlass_op_denylist_regex is not None: + if inductor_cutlass_config.cutlass_op_denylist_regex is not None: if re.search( - inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() + inductor_cutlass_config.cutlass_op_denylist_regex, + op.configuration_name(), ): return None @@ -1035,7 +1037,7 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: time.time() - start_time, ) sorted_res = sorted(res.items()) - ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs] + ret_res = sorted_res[: inductor_cutlass_config.cutlass_max_profiling_configs] if len(self.filtered_ops_cache) < 50: self.filtered_ops_cache[self.cache_key] = ret_res else: @@ -1277,7 +1279,9 @@ def render( # type: ignore[override] } options.update(dict(zip(extra_names, extra_inputs))) res = self._template_from_string(self._get_template()).render(**options) - if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): + if inductor_cutlass_config.generate_test_runner and not is_dynamic( + X, W, Y, Bias + ): test_runner_code = self._template_from_string( GEMM_STANDALONE_RUNNER_TEMPLATE ).render(**options) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index fcfb8f51ae6e7..47f3fd77908c0 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -6,7 +6,12 @@ import torch import torch._inductor.custom_graph_pass from torch._environment import is_fbcode -from torch.utils._config_module import Config, get_tristate_env, install_config_module +from torch.utils._config_module import ( + Config, + get_tristate_env, + inherit_fields_from, + install_config_module, +) if TYPE_CHECKING: @@ -1844,28 +1849,13 @@ class aot_inductor_mode: compile_standalone: bool = False -class cuda: - """Settings for cuda backend, today this consists of cutlass""" - - # CUDA arch to use for CUDA template kernel compilation. - # e.g. "70", "75", "80", "90", etc. - # When arch is None, Inductor uses torch.cuda.get_device_capability(0). - arch: Optional[str] = None - - # CUDA version to use for CUDA template kernel compilation. - # e.g. "11.4", "12.1", etc. - # When version is None, Inductor uses torch.version.cuda. - version: Optional[str] = None +class cutlass: + """ + Config specific to cutlass backend. + """ - # Optimization level for the host compiler. compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1" - # Whether to enable device LTO (link-time-optimization). - enable_cuda_lto = False - - # Whether to keep intermediate files dring compilation. - enable_ptxas_info = False - # Whether to enable debug info, e.g. line number, cutlass debug info. enable_debug_info = False @@ -1877,7 +1867,10 @@ class cuda: cutlass_dir = os.path.realpath( os.environ.get( "TORCHINDUCTOR_CUTLASS_DIR", - os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"), + os.path.join( + os.path.dirname(torch.__file__), + "../third_party/cutlass/", + ), ) ) @@ -1897,14 +1890,6 @@ class cuda: # Whether to only use TMA-compatible kernels in CUTLASS cutlass_tma_only = False - # Path to CUDA NVCC. - # NVCC search order: - # 1) cuda_cxx set in this config - # 2) CUDACXX environment variable - # 3) CUDA_HOME environment variable - # 4) default system search PATH. - cuda_cxx: Optional[str] = None - # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. cutlass_backend_min_gemm_size: int = 1 @@ -1974,6 +1959,43 @@ class cuda: enable_caching_codegen: bool = True +@inherit_fields_from(cutlass) +class cuda(cutlass): + # CUDA arch to use for CUDA template kernel compilation. + # e.g. "70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + +@inherit_fields_from(cutlass) +class xpu(cutlass): + # Xe arch to use for SYCL template kernel compilation. + # eg. 12, 20, which corresponding to Xe12(PVC) and Xe20 (BMG) + arch: Optional[str] = None + # oneAPI version to use for SYCL template kernel compilation. + # e.g. "20250201". + version: Optional[str] = None + + class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. # If empty, the `native` arch is used @@ -2182,6 +2204,7 @@ class trace: # trace functions are not relevant to config caching "trace", # uses absolute path + "cutlass.cutlass_dir", "cuda.cutlass_dir", # not relevant "worker_start_method", diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 152dce2026766..2d288e683be5a 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -480,7 +480,7 @@ def keys(self) -> KeysView[ComboType]: "aot_inductor.presets": DEFAULT, # Typing "cuda.arch": DEFAULT, # Out of Scope "cuda.version": DEFAULT, # Out of Scope - "cuda.cutlass_dir": DEFAULT, # Out of Scope + "cutlass.cutlass_dir": DEFAULT, # Out of Scope "cuda.cuda_cxx": DEFAULT, # Out of Scope "rocm.arch": DEFAULT, # Out of Scope "rocm.ck_supported_arch": DEFAULT, # Out of Scope diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 7fb1a5539e005..a50277e4bcb23 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -3842,8 +3842,8 @@ def prescreen_choices( candidates = [] if ( - config.cuda.cutlass_prescreening - and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1 + config.cutlass.cutlass_prescreening + and len(config.cutlass.cutlass_max_profiling_swizzle_options) > 1 ): candidates.extend( [ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d7f3844cdf1ba..0cafbed3a00c3 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2035,7 +2035,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) - if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: + if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size: return False from .codegen.cuda.cutlass_utils import try_import_cutlass @@ -2056,9 +2056,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: if not try_import_cutlass(): log.warning( "Failed to import CUTLASS lib. Please check whether " - "_inductor.config.cuda.cutlass_dir %s is set correctly. " + "_inductor.config.cutlass.cutlass_dir %s is set correctly. " "Skipping CUTLASS backend for now.", - config.cuda.cutlass_dir, + config.cutlass.cutlass_dir, ) return False return res @@ -2066,7 +2066,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: def _use_cutlass_for_op(op_name: str) -> bool: """Check if CUTLASS should be used for the given operation.""" - enabled_ops = config.cuda.cutlass_enabled_ops.upper() + enabled_ops = config.cutlass.cutlass_enabled_ops.upper() if enabled_ops == "ALL": return True return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 16fbad73a3097..0b3189e9dfed9 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -823,3 +823,14 @@ def get_tristate_env(name: str, default: Any = None) -> bool | None: if value == "0": return False return default + + +def inherit_fields_from(parent_cls): + def wrapper(child_cls): + for k, v in parent_cls.__dict__.items(): + if not k.startswith("_") and k not in ("__module__", "__doc__"): + if k not in child_cls.__dict__: + setattr(child_cls, k, v) + return child_cls + + return wrapper From 7716da9fb23f27a65b41f9f016a2afadf281c18f Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 4 Dec 2025 09:06:23 +0000 Subject: [PATCH 259/338] [10/N] Use Python 3.10 typing (#169229) This PR applies Python 3.10 typing syntax to some files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169229 Approved by: https://github.com/Lucaskabela --- torch/__init__.py | 60 ++++----- torch/_compile.py | 6 +- torch/_guards.py | 97 +++++++------- torch/_jit_internal.py | 6 +- torch/_linalg_utils.py | 10 +- torch/_lobpcg.py | 80 ++++++------ torch/_lowrank.py | 19 ++- torch/_meta_registrations.py | 204 +++++++++++++++--------------- torch/_ops.py | 35 +---- torch/_sources.py | 8 +- torch/_tensor.py | 52 ++++---- torch/_tensor_str.py | 4 +- torch/_utils.py | 6 +- torch/_utils_internal.py | 10 +- torch/_vmap_internals.py | 12 +- torch/_weights_only_unpickler.py | 12 +- torch/functional.py | 42 +++--- torch/hub.py | 12 +- torch/library.py | 66 +++++----- torch/masked/_ops.py | 174 +++++++++++++------------ torch/nn/_reduction.py | 9 +- torch/nn/common_types.py | 8 +- torch/nn/init.py | 28 ++-- torch/nn/modules/activation.py | 41 +++--- torch/nn/modules/batchnorm.py | 18 +-- torch/nn/modules/container.py | 14 +- torch/nn/modules/conv.py | 20 +-- torch/nn/modules/lazy.py | 4 +- torch/nn/modules/loss.py | 25 ++-- torch/nn/modules/module.py | 40 +++--- torch/nn/modules/normalization.py | 6 +- torch/nn/modules/pooling.py | 34 +++-- torch/nn/modules/rnn.py | 36 +++--- torch/nn/modules/sparse.py | 29 ++--- torch/nn/modules/transformer.py | 82 ++++++------ torch/nn/modules/upsampling.py | 25 ++-- torch/overrides.py | 4 +- torch/quasirandom.py | 11 +- torch/serialization.py | 48 +++---- torch/storage.py | 52 ++++---- torch/types.py | 22 ++-- torch/xpu/__init__.py | 18 ++- torch/xpu/random.py | 7 +- 43 files changed, 717 insertions(+), 779 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index ad32f8a054dc7..e6f9cfcb54472 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -320,7 +320,7 @@ def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> ctypes.CDLL(lib_path) -def _preload_cuda_deps(err: _Optional[OSError] = None) -> None: +def _preload_cuda_deps(err: OSError | None = None) -> None: cuda_libs: list[tuple[str, str]] = [ ("cublas", "libcublas.so.*[0-9]"), ("cudnn", "libcudnn.so.*[0-9]"), @@ -1276,7 +1276,7 @@ def set_default_device(device: "Device") -> None: _GLOBAL_DEVICE_CONTEXT.device_context = device_context -def set_default_tensor_type(t: _Union[type["torch.Tensor"], str], /) -> None: +def set_default_tensor_type(t: type["torch.Tensor"] | str, /) -> None: r""" .. warning:: @@ -1524,7 +1524,7 @@ def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: return _C._get_deterministic_algorithms_warn_only() -def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None: +def set_deterministic_debug_mode(debug_mode: builtins.int | str) -> None: r"""Sets the debug mode for deterministic operations. .. note:: This is an alternative interface for @@ -1686,7 +1686,7 @@ def is_warn_always_enabled() -> builtins.bool: def _check_with( error_type, - cond: _Union[builtins.bool, SymBool], + cond: builtins.bool | SymBool, message: _Callable[[], str], ): # noqa: F811 if not isinstance(cond, (builtins.bool, SymBool)): @@ -2092,7 +2092,7 @@ def _dtype(self): return torch.quint2x4 -_storage_classes: set[type[_Union[TypedStorage, UntypedStorage]]] = { +_storage_classes: set[type[TypedStorage | UntypedStorage]] = { UntypedStorage, DoubleStorage, FloatStorage, @@ -2398,13 +2398,13 @@ def __eq__(self, other): and self.dynamic == other.dynamic ) - def apply_mode(self, mode: _Optional[str]): + def apply_mode(self, mode: str | None): if mode and mode != "default": from torch._inductor import list_mode_options self.apply_options(list_mode_options(mode, self.dynamic)) - def apply_options(self, options: _Optional[dict[str, _Any]]): + def apply_options(self, options: dict[str, _Any] | None): if not options: return @@ -2524,12 +2524,10 @@ def compile( model: _Callable[_InputT, _RetT], *, fullgraph: builtins.bool = False, - dynamic: _Optional[builtins.bool] = None, - backend: _Union[str, _Callable] = "inductor", - mode: _Union[str, None] = None, - options: _Optional[ - dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] - ] = None, + dynamic: builtins.bool | None = None, + backend: str | _Callable = "inductor", + mode: str | None = None, + options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, disable: builtins.bool = False, ) -> _Callable[_InputT, _RetT]: ... @@ -2539,31 +2537,27 @@ def compile( model: None = None, *, fullgraph: builtins.bool = False, - dynamic: _Optional[builtins.bool] = None, - backend: _Union[str, _Callable] = "inductor", - mode: _Union[str, None] = None, - options: _Optional[ - dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] - ] = None, + dynamic: builtins.bool | None = None, + backend: str | _Callable = "inductor", + mode: str | None = None, + options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, disable: builtins.bool = False, ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ... def compile( - model: _Optional[_Callable[_InputT, _RetT]] = None, + model: _Callable[_InputT, _RetT] | None = None, *, fullgraph: builtins.bool = False, - dynamic: _Optional[builtins.bool] = None, - backend: _Union[str, _Callable] = "inductor", - mode: _Union[str, None] = None, - options: _Optional[ - dict[str, _Union[str, builtins.int, builtins.bool, _Callable]] - ] = None, + dynamic: builtins.bool | None = None, + backend: str | _Callable = "inductor", + mode: str | None = None, + options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, disable: builtins.bool = False, -) -> _Union[ - _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]], - _Callable[_InputT, _RetT], -]: +) -> ( + _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]] + | _Callable[_InputT, _RetT] +): """ Optimizes given model/function using TorchDynamo and specified backend. If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile` @@ -2871,7 +2865,7 @@ def __getattr__(name): @functools.cache -def get_device_module(device: _Optional[_Union[torch.device, str]] = None): +def get_device_module(device: torch.device | str | None = None): """ Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). If no device is given, return the module for the current accelerator or CPU if none is present. @@ -2897,8 +2891,8 @@ def get_device_module(device: _Optional[_Union[torch.device, str]] = None): def _constrain_as_size( symbol, - min: _Optional[builtins.int] = None, - max: _Optional[builtins.int] = None, + min: builtins.int | None = None, + max: builtins.int | None = None, ): """ This indicates that a given int is size-like, and can be used in any context where a size is expected. diff --git a/torch/_compile.py b/torch/_compile.py index 76ddd3ccb05b4..bf7d715883d58 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -5,7 +5,7 @@ import functools from collections.abc import Callable -from typing import Optional, overload, TypeVar, Union +from typing import overload, TypeVar from typing_extensions import ParamSpec @@ -26,8 +26,8 @@ def _disable_dynamo( def _disable_dynamo( - fn: Optional[Callable[_P, _T]] = None, recursive: bool = True -) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: + fn: Callable[_P, _T] | None = None, recursive: bool = True +) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: """ This API should be only used inside torch, external users should still use torch._dynamo.disable. The main goal of this API is to avoid circular diff --git a/torch/_guards.py b/torch/_guards.py index 03b619f65ae48..2f5b41527478b 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -15,16 +15,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import ( - Any, - Generic, - NamedTuple, - Optional, - overload, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Generic, NamedTuple, overload, TYPE_CHECKING, TypeVar if sys.version_info >= (3, 11): @@ -114,7 +105,7 @@ def __str__(self) -> str: return f"{self.frame_id}/{self.frame_compile_id}" @classmethod - def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]: + def from_string(cls, compile_id: str | None) -> CompileId | None: """ Factory method that creates a CompileId from its string representation. Keep this in sync with the __str__ method. @@ -277,14 +268,14 @@ class Guard: create_fn: Callable[[GuardBuilderBase, Guard], None] # Export only. These values are written to at time of guard check_fn creation. - guard_types: Optional[list[str]] = None - code_list: Optional[list[str]] = None - obj_weakref: Optional[object] = None - guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None - - stack: Optional[CapturedTraceback] = None - user_stack: Optional[traceback.StackSummary] = None - _hash: Optional[int] = None + guard_types: list[str] | None = None + code_list: list[str] | None = None + obj_weakref: object | None = None + guarded_class_weakref: weakref.ReferenceType[Any] | None = None + + stack: CapturedTraceback | None = None + user_stack: traceback.StackSummary | None = None + _hash: int | None = None _unserializable: bool = False def __hash__(self) -> int: @@ -401,7 +392,7 @@ def create_fn_name(self) -> str: def set_export_info( self, guard_type: str, - guarded_class: Optional[weakref.ReferenceType[Any]], + guarded_class: weakref.ReferenceType[Any] | None, code_list: list[str], obj_weakref: object, ) -> None: @@ -514,7 +505,7 @@ class GuardsCheckpointState: def __init__(self, dynamo_guards: set[Guard]) -> None: self.dynamo_guards = dynamo_guards - def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]: + def diff(self, other: GuardsCheckpointState) -> set[Guard] | None: """ Produces a delta against another GuardsCheckpointState. @@ -538,7 +529,7 @@ class ModuleContextCheckpointState: def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None: self.nn_modules = nn_modules - def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]: + def diff(self, other: ModuleContextCheckpointState) -> set[str] | None: """ Produces a delta against another ModuleContextCheckpointState. @@ -574,7 +565,7 @@ class GlobalContextCheckpointState: def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None: self.global_state = global_states - def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]: + def diff(self, other: GlobalContextCheckpointState) -> set[str] | None: """ Produces a delta against another GlobalContextCheckpointState. @@ -627,7 +618,7 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: # Like a Set[Guard] but will record the user stack on all guards at the # time they were installed at their destination class GuardsSet: - def __init__(self, inner: Optional[set[Guard]] = None) -> None: + def __init__(self, inner: set[Guard] | None = None) -> None: if inner is None: inner = set() self.inner = inner @@ -705,13 +696,13 @@ def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ... def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod - def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ... + def get_autograd_key_entry(self, identifier: str) -> Callable | None: ... @abstractmethod def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod - def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ... + def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: ... @abstractmethod def add_lazy_bwd_entry( @@ -724,7 +715,7 @@ def add_lazy_bwd_entry( @abstractmethod def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] - ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ... + ) -> tuple[torch.fx.GraphModule | None, int | None]: ... class InvokeSubgraphCache(HopSubgraphCache): @@ -748,13 +739,13 @@ def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: self.autograd_cache[identifier] = key - def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: + def get_autograd_key_entry(self, identifier: str) -> Callable | None: return self.autograd_cache.get(identifier, None) def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: self.proxy_dispatch_cache[identifier] = key - def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: + def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: return self.proxy_dispatch_cache.get(identifier, None) def add_lazy_bwd_entry( @@ -770,7 +761,7 @@ def add_lazy_bwd_entry( def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] - ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: + ) -> tuple[torch.fx.GraphModule | None, int | None]: if identifier not in self.lazy_bwd_cache: return (None, None) @@ -787,7 +778,7 @@ def add_effects(self, identifier: str, effects: set) -> None: ) self.effects_cache[identifier] = effects - def get_effects(self, identifier: str) -> Optional[set]: + def get_effects(self, identifier: str) -> set | None: """Retrieve the effect types for a given invoke_subgraph identifier.""" return self.effects_cache.get(identifier, None) @@ -836,7 +827,7 @@ def get() -> CompileContext: def try_get() -> CompileContext | None: return getattr(_TLS, "compile_context", None) - def __init__(self, compile_id: Optional[CompileId]) -> None: + def __init__(self, compile_id: CompileId | None) -> None: assert compile_id is None or isinstance(compile_id, CompileId) self.compile_id: CompileId | None = compile_id self.attempt = 0 @@ -844,14 +835,14 @@ def __init__(self, compile_id: Optional[CompileId]) -> None: self.shape_env_guards: list[str] = [] @staticmethod - def current_compile_id() -> Optional[CompileId]: + def current_compile_id() -> CompileId | None: self = CompileContext.try_get() if self is None: return None return self.compile_id @staticmethod - def current_trace_id() -> Optional[TraceId]: + def current_trace_id() -> TraceId | None: self = CompileContext.try_get() if self is None: return None @@ -880,13 +871,13 @@ def get() -> TracingContext: "TracingContext.get() must be called within an ongoing trace." ) - def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: + def __init__(self, fake_mode: FakeTensorMode | None) -> None: self.guards_context = GuardsContext() self.module_context = ModuleContext() self.global_context = GlobalContext() self.previously_inlined_functions: dict[Any, Any] = dict() self.previously_cleaned_instructions: dict[Any, Any] = dict() - self.fake_mode: Optional[FakeTensorMode] = fake_mode + self.fake_mode: FakeTensorMode | None = fake_mode self.frame_summary_stack: list[traceback.FrameSummary] = [] # This is morally part of frame_summary_stack, but it is kept separate # for clarity. As we process a frame, this variable gets updated @@ -894,16 +885,16 @@ def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: # function call, this gets cleared and the frame location is pushed # to frame_summary_stack (prepping this variable for the inner frame's # progress) - self.loc_in_frame: Optional[tuple[str, int, str]] = None + self.loc_in_frame: tuple[str, int, str] | None = None # this is only set after aot_autograd - self.fw_metadata: Optional[ViewAndMutationMeta] = None + self.fw_metadata: ViewAndMutationMeta | None = None # this is only set when the DDPOptimizer is used - self.ddp_optimizer_ctx: Optional[DDPOptimizerContext] = None + self.ddp_optimizer_ctx: DDPOptimizerContext | None = None # this is only set after aot_autograd - self.aot_graph_name: Optional[list[str]] = None - self.params_flat: Optional[list[Any]] = None - self.params_flat_unwrap_subclasses: Optional[list[Any]] = None - self.params_unwrapped_to_flat_index: Optional[list[Any]] = None + self.aot_graph_name: list[str] | None = None + self.params_flat: list[Any] | None = None + self.params_flat_unwrap_subclasses: list[Any] | None = None + self.params_unwrapped_to_flat_index: list[Any] | None = None # this is for extended return calling convention from backend # compiler to aot_autograd # Per output, what the compiler specified stride of the output is, @@ -1007,7 +998,7 @@ def clear_frame() -> Generator[None, None, None]: @staticmethod @contextlib.contextmanager def current_frame( - frame_summary: Optional[traceback.FrameSummary], + frame_summary: traceback.FrameSummary | None, ) -> Generator[None, None, None]: # frame_summary can be None to solely take advantage of real_stack # attachment to thrown exceptions @@ -1030,7 +1021,7 @@ def current_frame( @staticmethod @contextlib.contextmanager def report_output_strides() -> Generator[ - Optional[list[Optional[tuple[int, ...]]]], None, None + list[tuple[int, ...] | None] | None, None, None ]: tc = TracingContext.try_get() if tc is None: @@ -1050,7 +1041,7 @@ def set_current_loc(filename: str, lineno: int, frame_name: str) -> None: TracingContext.get().loc_in_frame = (filename, lineno, frame_name) @staticmethod - def get_traced_code() -> Optional[list[CodeType]]: + def get_traced_code() -> list[CodeType] | None: tc = TracingContext.try_get() if tc is None: return None @@ -1059,8 +1050,8 @@ def get_traced_code() -> Optional[list[CodeType]]: @contextmanager def compile_context( - context: Optional[CompileContext], -) -> Generator[Optional[CompileContext], None, None]: + context: CompileContext | None, +) -> Generator[CompileContext | None, None, None]: old_context = getattr(_TLS, "compile_context", None) _TLS.compile_context = context try: @@ -1071,8 +1062,8 @@ def compile_context( @contextmanager def tracing( - context: Optional[TracingContext], -) -> Generator[Optional[TracingContext], None, None]: + context: TracingContext | None, +) -> Generator[TracingContext | None, None, None]: """ This function installs the passed in tracing context as a dynamic scoped global variable. @@ -1236,7 +1227,7 @@ def get_value( return value -def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: +def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None: """ Attempts to "detect" what the current fake mode is. If there is one ambiently available from TracingContext, we preferentially use that. Otherwise, we @@ -1273,7 +1264,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: # pyrefly: ignore [bad-argument-type] fake_modes.append((flat_input.fake_mode, "fake tensor input", i)) if is_traceable_wrapper_subclass(flat_input): - out: list[Union[torch.Tensor, int, torch.SymInt]] = [] + out: list[torch.Tensor | int | torch.SymInt] = [] get_plain_tensors(flat_input, out=out) # type: ignore[arg-type] fake_tensors: list[FakeTensor] = [ x for x in out if isinstance(x, FakeTensor) @@ -1302,7 +1293,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: return None -def active_fake_mode() -> Optional[FakeTensorMode]: +def active_fake_mode() -> FakeTensorMode | None: """ Inspects the dispatch mode stack for an active fake mode and returns it. Returns None if no fake mode is active. diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 9efa0583cdea7..27c5768477dab 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -52,7 +52,7 @@ _P = ParamSpec("_P") _R = TypeVar("_R") -BuiltinUnionType: Union[type, tuple[type, ...]] = types.UnionType +BuiltinUnionType: type | tuple[type, ...] = types.UnionType LockType: type try: @@ -1236,7 +1236,7 @@ def _try_get_dispatched_fn(fn): def _get_named_tuple_properties( obj, - loc: Optional[torch._C._jit_tree_views.SourceRange] = None, + loc: torch._C._jit_tree_views.SourceRange | None = None, rcb=None, ): if loc is None: @@ -1531,7 +1531,7 @@ def _extract_tensors(obj): return tensors -def _get_model_id(obj) -> Optional[str]: +def _get_model_id(obj) -> str | None: if isinstance(obj, torch.jit.ScriptModule): return str(obj._c._type()) elif isinstance(obj, torch.jit.ScriptFunction): diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index 43c8b65767e00..213393da9aa99 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs """Various linear algebra utility methods for internal use.""" -from typing import Optional - import torch from torch import Tensor @@ -29,7 +27,7 @@ def get_floating_dtype(A): return torch.float32 -def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: +def matmul(A: Tensor | None, B: Tensor) -> Tensor: """Multiply two matrices. If A is None, return B. A can be sparse or dense. B is always @@ -42,12 +40,12 @@ def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: return torch.matmul(A, B) -def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: +def bform(X: Tensor, A: Tensor | None, Y: Tensor) -> Tensor: """Return bilinear form of matrices: :math:`X^T A Y`.""" return matmul(X.mT, matmul(A, Y)) -def qform(A: Optional[Tensor], S: Tensor): +def qform(A: Tensor | None, S: Tensor): """Return quadratic form :math:`S^T A S`.""" return bform(S, A, S) @@ -57,7 +55,7 @@ def basis(A): return torch.linalg.qr(A).Q -def symeig(A: Tensor, largest: Optional[bool] = False) -> tuple[Tensor, Tensor]: +def symeig(A: Tensor, largest: bool | None = False) -> tuple[Tensor, Tensor]: """Return eigenpairs of A with specified ordering.""" if largest is None: largest = False diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 1137efdc5f63a..cdc426047c33f 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -3,8 +3,6 @@ # Author: Pearu Peterson # Created: February 2020 -from typing import Optional - import torch from torch import _linalg_utils as _utils, Tensor from torch.overrides import handle_torch_function, has_torch_function @@ -258,19 +256,19 @@ class LOBPCGAutogradFunction(torch.autograd.Function): def forward( # type: ignore[override] ctx, A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, + k: int | None = None, + B: Tensor | None = None, + X: Tensor | None = None, + n: int | None = None, + iK: Tensor | None = None, + niter: int | None = None, + tol: float | None = None, + largest: bool | None = None, + method: str | None = None, tracker: None = None, - ortho_iparams: Optional[dict[str, int]] = None, - ortho_fparams: Optional[dict[str, float]] = None, - ortho_bparams: Optional[dict[str, bool]] = None, + ortho_iparams: dict[str, int] | None = None, + ortho_fparams: dict[str, float] | None = None, + ortho_bparams: dict[str, bool] | None = None, ) -> tuple[Tensor, Tensor]: # makes sure that input is contiguous for efficiency. # Note: autograd does not support dense gradients for sparse input yet. @@ -344,19 +342,19 @@ def backward(ctx, D_grad, U_grad): # pyrefly: ignore # bad-override def lobpcg( A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, + k: int | None = None, + B: Tensor | None = None, + X: Tensor | None = None, + n: int | None = None, + iK: Tensor | None = None, + niter: int | None = None, + tol: float | None = None, + largest: bool | None = None, + method: str | None = None, tracker: None = None, - ortho_iparams: Optional[dict[str, int]] = None, - ortho_fparams: Optional[dict[str, float]] = None, - ortho_bparams: Optional[dict[str, bool]] = None, + ortho_iparams: dict[str, int] | None = None, + ortho_fparams: dict[str, float] | None = None, + ortho_bparams: dict[str, bool] | None = None, ) -> tuple[Tensor, Tensor]: """Find the k largest (or smallest) eigenvalues and the corresponding eigenvectors of a symmetric positive definite generalized @@ -584,19 +582,19 @@ def lobpcg( def _lobpcg( A: Tensor, - k: Optional[int] = None, - B: Optional[Tensor] = None, - X: Optional[Tensor] = None, - n: Optional[int] = None, - iK: Optional[Tensor] = None, - niter: Optional[int] = None, - tol: Optional[float] = None, - largest: Optional[bool] = None, - method: Optional[str] = None, + k: int | None = None, + B: Tensor | None = None, + X: Tensor | None = None, + n: int | None = None, + iK: Tensor | None = None, + niter: int | None = None, + tol: float | None = None, + largest: bool | None = None, + method: str | None = None, tracker: None = None, - ortho_iparams: Optional[dict[str, int]] = None, - ortho_fparams: Optional[dict[str, float]] = None, - ortho_bparams: Optional[dict[str, bool]] = None, + ortho_iparams: dict[str, int] | None = None, + ortho_fparams: dict[str, float] | None = None, + ortho_bparams: dict[str, bool] | None = None, ) -> tuple[Tensor, Tensor]: # A must be square: assert A.shape[-2] == A.shape[-1], A.shape @@ -696,10 +694,10 @@ class LOBPCG: def __init__( self, - A: Optional[Tensor], - B: Optional[Tensor], + A: Tensor | None, + B: Tensor | None, X: Tensor, - iK: Optional[Tensor], + iK: Tensor | None, iparams: dict[str, int], fparams: dict[str, float], bparams: dict[str, bool], diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 182883cfc5e59..25089d66d35ea 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -2,7 +2,6 @@ __all__ = ["svd_lowrank", "pca_lowrank"] -from typing import Optional import torch from torch import _linalg_utils as _utils, Tensor @@ -12,8 +11,8 @@ def get_approximate_basis( A: Tensor, q: int, - niter: Optional[int] = 2, - M: Optional[Tensor] = None, + niter: int | None = 2, + M: Tensor | None = None, ) -> Tensor: """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is @@ -85,9 +84,9 @@ def get_approximate_basis( def svd_lowrank( A: Tensor, - q: Optional[int] = 6, - niter: Optional[int] = 2, - M: Optional[Tensor] = None, + q: int | None = 6, + niter: int | None = 2, + M: Tensor | None = None, ) -> tuple[Tensor, Tensor, Tensor]: r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, batches of matrices, or a sparse matrix :math:`A` such that @@ -149,9 +148,9 @@ def svd_lowrank( def _svd_lowrank( A: Tensor, - q: Optional[int] = 6, - niter: Optional[int] = 2, - M: Optional[Tensor] = None, + q: int | None = 6, + niter: int | None = 2, + M: Tensor | None = None, ) -> tuple[Tensor, Tensor, Tensor]: # Algorithm 5.1 in Halko et al., 2009 @@ -183,7 +182,7 @@ def _svd_lowrank( def pca_lowrank( A: Tensor, - q: Optional[int] = None, + q: int | None = None, center: bool = True, niter: int = 2, ) -> tuple[Tensor, Tensor, Tensor]: diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 75e14dbc86b96..61533797f2dbe 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Sequence from enum import Enum from functools import wraps -from typing import Optional, TypeVar, Union +from typing import TypeVar from typing_extensions import ParamSpec import torch @@ -547,9 +547,9 @@ def meta_sparse_structured_linear( input: Tensor, weight: Tensor, _meta: Tensor, - bias: Optional[Tensor] = None, - _activation_opt: Optional[str] = None, - out_dtype: Optional[torch.dtype] = None, + bias: Tensor | None = None, + _activation_opt: str | None = None, + out_dtype: torch.dtype | None = None, ): output_sizes = list(input.shape) if bias is not None: @@ -581,7 +581,7 @@ def meta_sparse_structured_mm( mat1: Tensor, mat1_meta: Tensor, mat2: Tensor, - out_dtype: Optional[torch.dtype] = None, + out_dtype: torch.dtype | None = None, ): assert len(mat1.shape) == 2 assert len(mat1_meta.shape) == 2 @@ -610,7 +610,7 @@ def meta_sparse_structured_addmm( *, alpha=1, beta=1, - out_dtype: Optional[torch.dtype] = None, + out_dtype: torch.dtype | None = None, ): assert len(input.shape) == 1, ( "only input broadcasted to columns of mat1 * mat2 product is supported" @@ -640,9 +640,9 @@ def meta_sparse_structured_addmm( def meta__cslt_sparse_mm( compressed_A: torch.Tensor, dense_B: torch.Tensor, - bias: Optional[Tensor] = None, - alpha: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + bias: Tensor | None = None, + alpha: Tensor | None = None, + out_dtype: torch.dtype | None = None, transpose_result: bool = False, alg_id: int = 0, split_k: int = 1, @@ -724,9 +724,9 @@ def meta_segment_reduce( data: Tensor, reduce: str, *, - lengths: Optional[Tensor] = None, - indices: Optional[Tensor] = None, - offsets: Optional[Tensor] = None, + lengths: Tensor | None = None, + indices: Tensor | None = None, + offsets: Tensor | None = None, axis: int = 0, unsafe: bool = False, initial=None, @@ -1468,7 +1468,7 @@ def _linalg_svd_meta( A: Tensor, full_matrices: bool = False, compute_uv: bool = True, - driver: Optional[str] = None, + driver: str | None = None, ): checkIsMatrix(A, "linalg.svd") checkFloatingOrComplex(A, "linalg.svd") @@ -1521,7 +1521,7 @@ def _linalg_broadcast_batch_dims( def _linalg_broadcast_batch_dims_name( arg1: Tensor, arg2: Tensor, - name: Optional[str], + name: str | None, ) -> tuple[Tensor, Tensor]: # If there's no name we assume we don't want to check the errors if name: @@ -1553,10 +1553,10 @@ def _linalg_solve_ex( *, left: bool = True, check_errors: bool = False, - result: Optional[Tensor] = None, - LU: Optional[Tensor] = None, - pivots: Optional[Tensor] = None, - info: Optional[Tensor] = None, + result: Tensor | None = None, + LU: Tensor | None = None, + pivots: Tensor | None = None, + info: Tensor | None = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: checkFloatingOrComplex(A, "linalg.solve") torch._check( @@ -1613,7 +1613,7 @@ def linalg_solve_triangular_meta( upper: bool, left: bool = True, unitriangular: bool = False, - out: Optional[Tensor] = None, + out: Tensor | None = None, ) -> Tensor: if out is None: out = A.new_empty([0]) @@ -2264,7 +2264,7 @@ def meta__fused_moving_avg_obs_fq_helper( @register_meta(aten.mm) @out_wrapper(exact_dtype=True) -def meta_mm(a, b, out_dtype: Optional[torch.dtype] = None): +def meta_mm(a, b, out_dtype: torch.dtype | None = None): torch._check(a.dim() == 2, lambda: "a must be 2D") torch._check(b.dim() == 2, lambda: "b must be 2D") N, M1 = a.shape @@ -2313,12 +2313,12 @@ def device_hint(tensor) -> "str": def calc_conv_nd_return_shape( input_tensor: torch.Tensor, weight: torch.Tensor, - stride: Union[list[int], int], - padding: Union[list[int], int], - dilation: Union[list[int], int], + stride: list[int] | int, + padding: list[int] | int, + dilation: list[int] | int, is_transposed: bool, groups: int, - output_padding: Optional[Union[list[int], int]] = None, + output_padding: list[int] | int | None = None, ): def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ @@ -2384,7 +2384,7 @@ def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int elif len(dilation) == 1: dilation = [dilation[0]] * len(dims) - output_padding_list: Optional[list[int]] = None + output_padding_list: list[int] | None = None if output_padding: if isinstance(output_padding, IntLike): # pyrefly: ignore [bad-assignment] @@ -2435,9 +2435,9 @@ def is_channels_last(ten): def meta_miopen_batch_norm( input_tensor: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor], - running_mean: Optional[torch.Tensor], - running_var: Optional[torch.Tensor], + bias: torch.Tensor | None, + running_mean: torch.Tensor | None, + running_var: torch.Tensor | None, training: bool, exponential_average_factor: float, epsilon: float, @@ -3383,7 +3383,7 @@ def meta_index_Tensor(self, indices): torch._check(bool(indices), lambda: "at least one index must be provided") # aten::index is the internal advanced indexing implementation # checkIndexTensorTypes and expandTensors - result: list[Optional[Tensor]] = [] + result: list[Tensor | None] = [] for i, index in enumerate(indices): if index is not None: torch._check( @@ -3853,7 +3853,7 @@ def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs): @register_meta([aten._dyn_quant_pack_4bit_weight]) def meta__dyn_quant_pack_4bit_weight( - weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features + weights, scales_zeros, bias: Tensor | None, block_size, in_features, out_features ): torch._check( weights.dtype is torch.uint8, @@ -5655,7 +5655,7 @@ def meta__scaled_dot_product_flash_attention( dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -5737,12 +5737,12 @@ def meta__scaled_dot_product_cudnn_attention( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor], + attn_bias: Tensor | None, compute_log_sumexp: bool, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): B = query.size(0) H = query.size(1) @@ -5781,11 +5781,11 @@ def meta__scaled_dot_product_fused_attention_overrideable( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor] = None, + attn_bias: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): B = query.size(0) H_Q = query.size(1) @@ -5839,7 +5839,7 @@ def meta__scaled_dot_product_flash_backward( is_causal: bool, philox_seed: Tensor, philox_offset: Tensor, - scale: Optional[float] = None, + scale: float | None = None, ): grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2) grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2) @@ -5858,8 +5858,8 @@ def meta__scaled_dot_product_flash_attention_for_cpu( value: Tensor, dropout_p: float = 0.0, is_causal: bool = False, - attn_mask: Optional[Tensor] = None, - scale: Optional[float] = None, + attn_mask: Tensor | None = None, + scale: float | None = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -5895,8 +5895,8 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( logsumexp: Tensor, dropout_p: float, is_causal: bool, - attn_mask: Optional[Tensor] = None, - scale: Optional[float] = None, + attn_mask: Tensor | None = None, + scale: float | None = None, ): # cpus's grad layout is different from cuda's, # i.e. (batch_size, seq_len, num_heads, head_dim) @@ -5927,11 +5927,11 @@ def meta__scaled_dot_product_attention_math_for_mps( query: Tensor, key: Tensor, value: Tensor, - attn_mask: Optional[Tensor] = None, + attn_mask: Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, - dropout_mask: Optional[Tensor] = None, - scale: Optional[float] = None, + dropout_mask: Tensor | None = None, + scale: float | None = None, ) -> tuple[Tensor, Tensor]: def ensure_4d(x): if x.dim() == 3: @@ -5982,11 +5982,11 @@ def meta__scaled_dot_product_efficient_attention( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor], + attn_bias: Tensor | None, compute_log_sumexp: bool, dropout_p=0.0, is_causal: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -6032,7 +6032,7 @@ def meta__scaled_dot_product_efficient_backward( query: Tensor, key: Tensor, value: Tensor, - attn_bias: Optional[Tensor], + attn_bias: Tensor | None, out: Tensor, logsumexp: Tensor, philox_seed: Tensor, @@ -6040,7 +6040,7 @@ def meta__scaled_dot_product_efficient_backward( dropout_p: float, grad_input_mask: list[bool], is_causal: bool = False, - scale: Optional[float] = None, + scale: float | None = None, ): batch_size = query.size(0) num_heads = query.size(1) @@ -6103,7 +6103,7 @@ def meta__scaled_dot_product_cudnn_backward( max_k: int, dropout_p: float, is_causal: bool, - scale: Optional[float] = None, + scale: float | None = None, ): grad_q = torch.empty_like(query) grad_k = torch.empty_like(key) @@ -6120,18 +6120,18 @@ def meta__flash_attention_forward( query: Tensor, key: Tensor, value: Tensor, - cum_seq_q: Optional[Tensor], - cum_seq_k: Optional[Tensor], + cum_seq_q: Tensor | None, + cum_seq_k: Tensor | None, max_q: int, max_k: int, dropout_p: float, is_causal: bool, return_debug_mask: bool, - scale: Optional[float] = None, - window_size_left: Optional[int] = None, - window_size_right: Optional[int] = None, - seqused_k: Optional[Tensor] = None, - alibi_slopes: Optional[Tensor] = None, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, + seqused_k: Tensor | None = None, + alibi_slopes: Tensor | None = None, ): # NB: there are two underlying paths: # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) @@ -6211,9 +6211,9 @@ def meta__flash_attention_backward( is_causal: bool, philox_seed: Tensor, philox_offset: Tensor, - scale: Optional[float] = None, - window_size_left: Optional[int] = None, - window_size_right: Optional[int] = None, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, ): grad_query = torch.empty_like(query) grad_key = torch.empty_like(key) @@ -6231,18 +6231,18 @@ def meta__efficient_attention_forward( query: Tensor, key: Tensor, value: Tensor, - bias: Optional[Tensor], - cu_seqlens_q: Optional[Tensor], - cu_seqlens_k: Optional[Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], + bias: Tensor | None, + cu_seqlens_q: Tensor | None, + cu_seqlens_k: Tensor | None, + max_seqlen_q: int | None, + max_seqlen_k: int | None, dropout_p: float, custom_mask_type: int, compute_log_sumexp: bool = False, - scale: Optional[float] = None, - causal_diagonal: Optional[Tensor] = None, - seqlen_k: Optional[Tensor] = None, - window_size: Optional[int] = None, + scale: float | None = None, + causal_diagonal: Tensor | None = None, + seqlen_k: Tensor | None = None, + window_size: int | None = None, ): B = query.size(0) M = query.size(1) @@ -6284,9 +6284,9 @@ def meta__efficient_attention_backward( query: Tensor, key: Tensor, value: Tensor, - bias: Optional[Tensor], - cu_seqlens_q: Optional[Tensor], - cu_seqlens_k: Optional[Tensor], + bias: Tensor | None, + cu_seqlens_q: Tensor | None, + cu_seqlens_k: Tensor | None, max_seqlen_q: torch.SymInt, max_seqlen_k: torch.SymInt, logsumexp: Tensor, @@ -6295,8 +6295,8 @@ def meta__efficient_attention_backward( philox_offset: Tensor, custom_mask_type: int, bias_requires_grad: bool, - scale: Optional[float] = None, - num_splits_key: Optional[int] = None, + scale: float | None = None, + num_splits_key: int | None = None, shared_storage_dqdkdv: bool = False, ): if shared_storage_dqdkdv: @@ -6339,9 +6339,9 @@ def _check_scaled_mm_sizes( mat2: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - bias: Optional[torch.Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): def is_fp8_or_fp4_type(dtype): @@ -6520,9 +6520,9 @@ def meta_scaled_mm( mat2: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - bias: Optional[torch.Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): return _check_scaled_mm_sizes( @@ -6537,10 +6537,10 @@ def _check_scaled_mm_sizes_v2( scale_recipe_a: list[ScalingType], scale_b: list[torch.Tensor], scale_recipe_b: list[ScalingType], - bias: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - swizzle_a: Optional[list[SwizzleType]] = None, - swizzle_b: Optional[list[SwizzleType]] = None, + bias: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + swizzle_a: list[SwizzleType] | None = None, + swizzle_b: list[SwizzleType] | None = None, use_fast_accum: bool = False, ): def is_fp8_or_fp4_type(dtype): @@ -6872,9 +6872,9 @@ def meta_scaled_mm_v2( scale_b: list[torch.Tensor], scale_recipe_b: list[ScalingType], swizzle_b: list[SwizzleType], - bias: Optional[torch.Tensor] = None, - output_dtype: Optional[torch.dtype] = None, - contraction_dims: Optional[list[int]] = None, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, + contraction_dims: list[int] | None = None, use_fast_accum: bool = False, ): return _check_scaled_mm_sizes_v2( @@ -6997,10 +6997,10 @@ def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None): ) def upsample_nearest2d_backward( grad_output: Tensor, - output_size: Sequence[Union[int, torch.SymInt]], - input_size: Sequence[Union[int, torch.SymInt]], - scales_h: Optional[float] = None, - scales_w: Optional[float] = None, + output_size: Sequence[int | torch.SymInt], + input_size: Sequence[int | torch.SymInt], + scales_h: float | None = None, + scales_w: float | None = None, ): full_output_size = upsample_common_check( input_size, output_size, num_spatial_dims=2 @@ -7842,12 +7842,12 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): def _meta_grouped_mm_common( mat_a: Tensor, mat_b: Tensor, - scale_a: Optional[torch.Tensor], - scale_b: Optional[torch.Tensor], - offs: Optional[Tensor] = None, - bias: Optional[Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + scale_a: torch.Tensor | None, + scale_b: torch.Tensor | None, + offs: Tensor | None = None, + bias: Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): torch._check( @@ -8055,9 +8055,9 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): def meta_grouped_mm( mat_a: Tensor, mat_b: Tensor, - offs: Optional[Tensor] = None, - bias: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + offs: Tensor | None = None, + bias: Tensor | None = None, + out_dtype: torch.dtype | None = None, ) -> Tensor: return _meta_grouped_mm_common( mat_a, @@ -8077,10 +8077,10 @@ def meta_scaled_grouped_mm( mat_b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - offs: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - scale_result: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, + offs: torch.Tensor | None = None, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, use_fast_accum: bool = False, ): # matching _scaled_grouped_mm_cuda Blas.cpp implementation diff --git a/torch/_ops.py b/torch/_ops.py index 75905d78da5b5..8d02767daf466 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -8,16 +8,7 @@ import types from collections.abc import Callable, Iterator from functools import cached_property -from typing import ( - Any, - ClassVar, - Concatenate, - final, - Generic, - Optional, - TYPE_CHECKING, - Union, -) +from typing import Any, ClassVar, Concatenate, final, Generic, TYPE_CHECKING from typing_extensions import ParamSpec, TypeVar import torch @@ -79,9 +70,7 @@ def __init__(self): # for use with OpOverload; cache lookup is done entirely from C++ # for speed. # TODO: The cache is NOT currently used by HigherOrderOperator, but it should! - self._dispatch_cache: dict[ - DispatchKey, Union[DispatchKey, Callable[..., Any]] - ] = {} + self._dispatch_cache: dict[DispatchKey, DispatchKey | Callable[..., Any]] = {} # This table allows you to override the behavior of a particular # dispatch key to call a custom Python function, rather than the @@ -99,7 +88,7 @@ def __init__(self): # makes sense that you should be able to register them, the same # way you can register dispatch keys. self.python_key_table: dict[ - type[Union[TorchDispatchMode, torch.Tensor]], Callable[..., Any] + type[TorchDispatchMode | torch.Tensor], Callable[..., Any] ] = {} # This table allows you to override the behavior of functorch @@ -121,12 +110,7 @@ def has_kernel_for_any_dispatch_key(self, ks): def py_impl( self, - k: Union[ - type[TorchDispatchMode], - type[torch.Tensor], - TransformType, - DispatchKey, - ], + k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]: if inspect.isclass(k) and ( @@ -185,7 +169,7 @@ def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: return fn(CppFunctionalizeAPI(), *args, **kwargs) def functionalize_dispatch_mode_fn( - mode: Optional[FunctionalTensorMode], *args: _P.args, **kwargs: _P.kwargs + mode: FunctionalTensorMode | None, *args: _P.args, **kwargs: _P.kwargs ) -> _T: return fn(PythonFunctionalizeAPI(mode), *args, **kwargs) @@ -307,12 +291,7 @@ def __init__(self, name, *, cacheable=False): def py_impl( self, - k: Union[ - type[TorchDispatchMode], - type[torch.Tensor], - TransformType, - DispatchKey, - ], + k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k): self.non_fallthrough_keys = self.non_fallthrough_keys.add(k) @@ -894,7 +873,7 @@ def _uncache_dispatch(self, key: DispatchKey) -> None: self._dispatch_cache.pop(key, None) # This implements the pre-computation logic for the Python dispatcher. - def _get_dispatch(self, key: DispatchKey) -> Union[DispatchKey, Callable[_P, _T]]: + def _get_dispatch(self, key: DispatchKey) -> DispatchKey | Callable[_P, _T]: # This is only called upon a cache miss assert key not in self._dispatch_cache, f"{self} {key}" diff --git a/torch/_sources.py b/torch/_sources.py index 1327729a717b1..e0ab883a8b46c 100644 --- a/torch/_sources.py +++ b/torch/_sources.py @@ -3,7 +3,7 @@ import functools import inspect from textwrap import dedent -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple from torch._C import ErrorReport from torch._C._jit_tree_views import SourceRangeFactory @@ -11,8 +11,8 @@ def get_source_lines_and_file( obj: Any, - error_msg: Optional[str] = None, -) -> tuple[list[str], int, Optional[str]]: + error_msg: str | None = None, +) -> tuple[list[str], int, str | None]: """ Wrapper around inspect.getsourcelines and inspect.getsourcefile. @@ -113,7 +113,7 @@ class ParsedDef(NamedTuple): ast: ast.Module ctx: SourceContext source: str - filename: Optional[str] + filename: str | None file_lineno: int diff --git a/torch/_tensor.py b/torch/_tensor.py index c6351ed75ffcb..c1093f35aa984 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -8,7 +8,7 @@ from collections.abc import Callable from copy import deepcopy from numbers import Number -from typing import Any, cast, Concatenate, Optional, TypeVar, Union +from typing import Any, cast, Concatenate, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -180,10 +180,10 @@ def __deepcopy__(self, memo): new_storage = self._typed_storage()._deepcopy(memo) if self.is_quantized: # quantizer_params can be different type based on torch attribute - quantizer_params: Union[ - tuple[torch.qscheme, float, int], - tuple[torch.qscheme, Tensor, Tensor, int], - ] + quantizer_params: ( + tuple[torch.qscheme, float, int] + | tuple[torch.qscheme, Tensor, Tensor, int] + ) if self.qscheme() == torch.per_tensor_affine: quantizer_params = ( self.qscheme(), @@ -366,9 +366,9 @@ def _reduce_ex_internal(self, proto): "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" ) # quantizer_params can be different type based on torch attribute - quantizer_params: Union[ - tuple[torch.qscheme, float, int], tuple[Any, Tensor, Tensor, int] - ] + quantizer_params: ( + tuple[torch.qscheme, float, int] | tuple[Any, Tensor, Tensor, int] + ) if self.qscheme() == torch.per_tensor_affine: quantizer_params = ( torch.per_tensor_affine, @@ -893,7 +893,7 @@ def __reversed__(self): def norm( self, - p: Optional[Union[float, str]] = "fro", + p: float | str | None = "fro", dim=None, keepdim=False, dtype=None, @@ -944,15 +944,15 @@ def lu(self, pivot=True, get_infos=False): def stft( self, n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: "Optional[Tensor]" = None, + hop_length: int | None = None, + win_length: int | None = None, + window: "Tensor | None" = None, center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - onesided: Optional[bool] = None, - return_complex: Optional[bool] = None, - align_to_window: Optional[bool] = None, + onesided: bool | None = None, + return_complex: bool | None = None, + align_to_window: bool | None = None, ): r"""See :func:`torch.stft` @@ -993,13 +993,13 @@ def stft( def istft( self, n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: "Optional[Tensor]" = None, + hop_length: int | None = None, + win_length: int | None = None, + window: "Tensor | None" = None, center: bool = True, normalized: bool = False, - onesided: Optional[bool] = None, - length: Optional[int] = None, + onesided: bool | None = None, + length: int | None = None, return_complex: bool = False, ): r"""See :func:`torch.istft`""" @@ -1528,9 +1528,7 @@ def to_sparse_coo(self): """ return self.to_sparse() - def dim_order( - self, *, ambiguity_check: Union[bool, list[torch.memory_format]] = False - ): + def dim_order(self, *, ambiguity_check: bool | list[torch.memory_format] = False): """ dim_order(ambiguity_check=False) -> tuple @@ -1712,10 +1710,10 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def __dlpack__( self, *, - stream: Optional[Any] = -1, - max_version: Optional[tuple[int, int]] = None, - dl_device: Optional[tuple[enum.IntEnum, int]] = None, - copy: Optional[bool] = None, + stream: Any | None = -1, + max_version: tuple[int, int] | None = None, + dl_device: tuple[enum.IntEnum, int] | None = None, + copy: bool | None = None, ): """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 613fa9ad6ff95..46af738829312 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -3,7 +3,7 @@ import dataclasses import math import textwrap -from typing import Any, Optional +from typing import Any import torch from torch import inf @@ -15,7 +15,7 @@ class __PrinterOptions: threshold: float = 1000 edgeitems: int = 3 linewidth: int = 80 - sci_mode: Optional[bool] = None + sci_mode: bool | None = None PRINT_OPTS = __PrinterOptions() diff --git a/torch/_utils.py b/torch/_utils.py index 01cf9d393188b..70641a7c534d7 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Callable from types import ModuleType -from typing import Any, Generic, Optional, TYPE_CHECKING +from typing import Any, Generic, TYPE_CHECKING from typing_extensions import deprecated, ParamSpec import torch @@ -856,7 +856,7 @@ def _get_device_index( """ if isinstance(device, str): device = torch.device(device) - device_idx: Optional[int] = None + device_idx: int | None = None if isinstance(device, torch.device): if not allow_cpu and device.type == "cpu": raise ValueError(f"Expected a non cpu device, but got: {device}") @@ -1054,7 +1054,7 @@ def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None: ) -def try_import(module_name: str) -> Optional[ModuleType]: +def try_import(module_name: str) -> ModuleType | None: # Implementation based on # https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported if (module := sys.modules.get(module_name, None)) is not None: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 3a172a814e2e5..6f95511b5ce80 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -6,7 +6,7 @@ import tempfile import typing_extensions from collections.abc import Callable -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch @@ -255,7 +255,7 @@ def max_clock_rate(): return 1100 -def get_mast_job_name_version() -> Optional[tuple[str, int]]: +def get_mast_job_name_version() -> tuple[str, int] | None: return None @@ -274,7 +274,7 @@ def get_mast_job_name_version() -> Optional[tuple[str, int]]: REQUIRES_SET_PYTHON_MODULE = False -def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]: +def maybe_upload_prof_stats_to_manifold(profile_path: str) -> str | None: print("Uploading profile stats (fb-only otherwise no-op)") return None @@ -367,11 +367,11 @@ def get_default_numa_options(): return None -def log_triton_builds(fail: Optional[str]): +def log_triton_builds(fail: str | None): pass -def find_compile_subproc_binary() -> Optional[str]: +def find_compile_subproc_binary() -> str | None: """ Allows overriding the binary used for subprocesses """ diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 3f303f78a4713..861d4fd4b4153 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools from collections.abc import Callable -from typing import Any, Optional, Union +from typing import Any from typing_extensions import deprecated import torch @@ -9,13 +9,13 @@ from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten -in_dims_t = Union[int, tuple] -out_dims_t = Union[int, tuple[int, ...]] +in_dims_t = int | tuple +out_dims_t = int | tuple[int, ...] # Checks that all args-to-be-batched have the same batch dim size def _validate_and_get_batch_size( - flat_in_dims: list[Optional[int]], + flat_in_dims: list[int | None], flat_args: list, ) -> int: batch_sizes = [ @@ -31,7 +31,7 @@ def _validate_and_get_batch_size( return batch_sizes[0] -def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int: +def _num_outputs(batched_outputs: Tensor | tuple[Tensor, ...]) -> int: if isinstance(batched_outputs, tuple): return len(batched_outputs) return 1 @@ -115,7 +115,7 @@ def _create_batched_inputs( # Undos the batching (and any batch dimensions) associated with the `vmap_level`. def _unwrap_batched( - batched_outputs: Union[Tensor, tuple[Tensor, ...]], + batched_outputs: Tensor | tuple[Tensor, ...], out_dims: out_dims_t, vmap_level: int, batch_size: int, diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 5aaa77b25697a..a4c8aaafa351b 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -69,7 +69,7 @@ ) from struct import unpack from sys import maxsize -from typing import Any, Union +from typing import Any import torch from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING @@ -84,15 +84,15 @@ "nt", ] -_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set() +_marked_safe_globals_set: set[Callable | tuple[Callable, str]] = set() -def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]): +def _add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]): global _marked_safe_globals_set _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals)) -def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: +def _get_safe_globals() -> list[Callable | tuple[Callable, str]]: global _marked_safe_globals_set return list(_marked_safe_globals_set) @@ -103,14 +103,14 @@ def _clear_safe_globals(): def _remove_safe_globals( - globals_to_remove: list[Union[Callable, tuple[Callable, str]]], + globals_to_remove: list[Callable | tuple[Callable, str]], ): global _marked_safe_globals_set _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove) class _safe_globals: - def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]): + def __init__(self, safe_globals: list[Callable | tuple[Callable, str]]): self.safe_globals = safe_globals def __enter__(self): diff --git a/torch/functional.py b/torch/functional.py index 013832d59cfb3..33b0ada75324c 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -2,7 +2,7 @@ import itertools import operator from collections.abc import Sequence -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import torch import torch.nn.functional as F @@ -120,7 +120,7 @@ def broadcast_shapes(*shapes): def split( tensor: Tensor, - split_size_or_sections: Union[int, list[int]], + split_size_or_sections: int | list[int], dim: int = 0, ) -> tuple[Tensor, ...]: r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. @@ -387,13 +387,13 @@ def parse_subscript(n: int) -> str: if TYPE_CHECKING: # The JIT doesn't understand Union, so only add type annotation for mypy def meshgrid( - *tensors: Union[Tensor, list[Tensor]], indexing: Optional[str] = None + *tensors: Tensor | list[Tensor], indexing: str | None = None ) -> tuple[Tensor, ...]: return _meshgrid(*tensors, indexing=indexing) else: - def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]: + def meshgrid(*tensors, indexing: str | None = None) -> tuple[Tensor, ...]: r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors. This is helpful when you want to visualize data over some @@ -490,7 +490,7 @@ def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]: return _meshgrid(*tensors, indexing=indexing) -def _meshgrid(*tensors, indexing: Optional[str]): +def _meshgrid(*tensors, indexing: str | None): if has_torch_function(tensors): return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing) if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): @@ -508,15 +508,15 @@ def _meshgrid(*tensors, indexing: Optional[str]): def stft( input: Tensor, n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: Optional[Tensor] = None, + hop_length: int | None = None, + win_length: int | None = None, + window: Tensor | None = None, center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - onesided: Optional[bool] = None, - return_complex: Optional[bool] = None, - align_to_window: Optional[bool] = None, + onesided: bool | None = None, + return_complex: bool | None = None, + align_to_window: bool | None = None, ) -> Tensor: r"""Short-time Fourier transform (STFT). @@ -788,7 +788,7 @@ def _unique_impl( sorted: bool = True, return_inverse: bool = False, return_counts: bool = False, - dim: Optional[int] = None, + dim: int | None = None, ) -> _unique_impl_out: r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor] @@ -956,7 +956,7 @@ def _unique_consecutive_impl( input: Tensor, return_inverse: bool = False, return_counts: bool = False, - dim: Optional[int] = None, + dim: int | None = None, ) -> _unique_impl_out: r"""Eliminates all but the first element from every consecutive group of equivalent elements. @@ -1201,7 +1201,7 @@ def tensordot( a, b, dims: int = 2, - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1210,7 +1210,7 @@ def tensordot( # noqa: F811 a, b, dims: tuple[list[int], list[int]], - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1219,7 +1219,7 @@ def tensordot( # noqa: F811 a, b, dims: list[list[int]], - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1228,7 +1228,7 @@ def tensordot( # noqa: F811 a, b, dims: torch.Tensor, - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): pass @@ -1237,7 +1237,7 @@ def tensordot( # noqa: F811 a, b, dims=2, - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ): r"""Returns a contraction of a and b over multiple dimensions. @@ -1659,7 +1659,7 @@ def norm( # noqa: F811 def norm( # noqa: F811 input, - p: Optional[Union[float, str]] = "fro", + p: float | str | None = "fro", dim=None, keepdim=False, out=None, @@ -1882,7 +1882,7 @@ def norm( # noqa: F811 def unravel_index( indices: Tensor, - shape: Union[int, Sequence[int], torch.Size], + shape: int | Sequence[int] | torch.Size, ) -> tuple[Tensor, ...]: r"""Converts a tensor of flat indices into a tuple of coordinate tensors that index into an arbitrary tensor of the specified shape. @@ -1938,7 +1938,7 @@ def unravel_index( return res_tensor.unbind(-1) -def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: +def _unravel_index(indices: Tensor, shape: int | Sequence[int]) -> Tensor: torch._check_type( not indices.is_complex() and not indices.is_floating_point() diff --git a/torch/hub.py b/torch/hub.py index bf138f7784347..3ec285fcb3a9e 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -12,7 +12,7 @@ import warnings import zipfile from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from typing_extensions import deprecated from urllib.error import HTTPError, URLError from urllib.parse import urlparse # noqa: F401 @@ -91,7 +91,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): VAR_DEPENDENCY = "dependencies" MODULE_HUBCONF = "hubconf.py" READ_DATA_CHUNK = 128 * 1024 -_hub_dir: Optional[str] = None +_hub_dir: str | None = None @contextlib.contextmanager @@ -417,7 +417,7 @@ def get_dir() -> str: return os.path.join(_get_torch_home(), "hub") -def set_dir(d: Union[str, os.PathLike]) -> None: +def set_dir(d: str | os.PathLike) -> None: r""" Optionally set the Torch Hub directory used to save downloaded models & weights. @@ -694,7 +694,7 @@ def _load_local(hubconf_dir, model, *args, **kwargs): def download_url_to_file( url: str, dst: str, - hash_prefix: Optional[str] = None, + hash_prefix: str | None = None, progress: bool = True, ) -> None: r"""Download object at the given URL to a local path. @@ -816,11 +816,11 @@ def _legacy_zip_load( def load_state_dict_from_url( url: str, - model_dir: Optional[str] = None, + model_dir: str | None = None, map_location: MAP_LOCATION = None, progress: bool = True, check_hash: bool = False, - file_name: Optional[str] = None, + file_name: str | None = None, weights_only: bool = False, ) -> dict[str, Any]: r"""Loads the Torch serialized object at the given URL. diff --git a/torch/library.py b/torch/library.py index 76e5d27aae434..5305d647bc613 100644 --- a/torch/library.py +++ b/torch/library.py @@ -7,7 +7,7 @@ import traceback import weakref from collections.abc import Callable, Sequence -from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union +from typing import Any, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import deprecated, ParamSpec import torch @@ -98,7 +98,7 @@ def __init__(self, ns, kind, dispatch_key=""): frame = traceback.extract_stack(limit=2)[0] filename, lineno = frame.filename, frame.lineno - self.m: Optional[Any] = torch._C._dispatch_library( + self.m: Any | None = torch._C._dispatch_library( kind, ns, dispatch_key, filename, lineno ) self.ns = ns @@ -399,7 +399,7 @@ def fallback(self, fn, dispatch_key="", *, with_keyset=False): self.m.fallback(dispatch_key, fn, with_keyset) - def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]): + def _register_effectful_op(self, op_name: str, effect: EffectType | None): """ Registers an effect to an operator. This is used to register an op that has side effects that is not capturable by the schema. @@ -570,20 +570,20 @@ def wrap(f): @overload def impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> Callable[[Callable[..., object]], None]: ... @overload def impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: Callable[..., object], *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> None: ... @@ -599,10 +599,10 @@ def impl( @functools.singledispatch def impl( qualname: str, - types: Union[str, Sequence[str]], - func: Optional[Callable[_P, _T]] = None, + types: str | Sequence[str], + func: Callable[_P, _T] | None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> object: """Register an implementation for a device type for this operator. @@ -683,10 +683,10 @@ def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]: @overload def _impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, disable_dynamo: bool = False, ) -> Callable[[Callable[..., object]], None]: ... @@ -694,22 +694,22 @@ def _impl( @overload def _impl( qualname: str, - types: Union[str, Sequence[str]], + types: str | Sequence[str], func: Callable[..., object], *, - lib: Optional[Library] = None, + lib: Library | None = None, disable_dynamo: bool = False, ) -> None: ... def _impl( qualname: str, - types: Union[str, Sequence[str]], - func: Optional[Callable[..., object]] = None, + types: str | Sequence[str], + func: Callable[..., object] | None = None, *, - lib: Optional[Library] = None, + lib: Library | None = None, disable_dynamo: bool = False, -) -> Optional[Callable[[Callable[..., object]], None]]: +) -> Callable[[Callable[..., object]], None] | None: # See impl() if isinstance(types, str): types = (types,) @@ -786,10 +786,10 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): def register_kernel( op: _op_identifier, device_types: device_types_t, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, ): """Register an implementation for a device type for this operator. @@ -857,7 +857,7 @@ def register_autocast( cast_inputs: _dtype, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, ): r"""Register an autocast dispatch rule for this custom op. @@ -948,10 +948,10 @@ def kernel(_, *args, **kwargs): def register_fake( op: _op_identifier, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, _stacklevel: int = 1, allow_override: bool = False, ): @@ -1084,9 +1084,9 @@ def register(func): def _register_effectful_op( op: _op_identifier, - effect: Optional[EffectType], + effect: EffectType | None, *, - lib: Optional[Library] = None, + lib: Library | None = None, ) -> None: r""" To specify that an operator has side-effects, we must register an effect @@ -1125,7 +1125,7 @@ def register_autograd( backward: Callable, /, *, - setup_context: Optional[Callable] = None, + setup_context: Callable | None = None, lib=None, ) -> None: r"""Register a backward formula for this custom op. @@ -1253,10 +1253,10 @@ def register_autograd( def register_torch_dispatch( op: _op_identifier, torch_dispatch_class: Any, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, - lib: Optional[Library] = None, + lib: Library | None = None, ): r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. @@ -1333,7 +1333,7 @@ def register(func): def register_vmap( op: _op_identifier, - func: Optional[Callable] = None, + func: Callable | None = None, /, *, lib=None, @@ -1525,7 +1525,7 @@ def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": def get_kernel( - op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey] + op: _op_identifier, dispatch_key: str | torch.DispatchKey ) -> torch._C._SafeKernelFunction: """Returns the computed kernel for a given operator and dispatch key. @@ -1607,11 +1607,11 @@ def get_kernel( def opcheck( - op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], + op: torch._ops.OpOverload | torch._ops.OpOverloadPacket | CustomOpDef, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, *, - test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS, + test_utils: str | Sequence[str] = _OPCHECK_DEFAULT_UTILS, raise_exception: bool = True, atol=None, rtol=None, diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 4bae914f0292b..dd3ff69fd6af8 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import warnings from collections.abc import Callable -from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar, Union +from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar from typing_extensions import ParamSpec import torch @@ -16,7 +16,7 @@ from torch._prims_common import DimsType from torch.types import _dtype as DType - DimOrDims: TypeAlias = Optional[DimsType] + DimOrDims: TypeAlias = DimsType | None else: # The JIT doesn't understand Union, nor torch.dtype here DType = int @@ -624,7 +624,7 @@ def _sparse_coo_scatter_reduction_helper( mask_input: Tensor, dims: tuple[int, ...], keepdim: bool, - dtype: Optional[DType] = None, + dtype: DType | None = None, ) -> Tensor: reduce = op.__name__ valid_reductions = ["sum", "prod", "amax", "amin"] @@ -744,7 +744,7 @@ def _sparse_csr_segment_reduction_helper( mask_input: Tensor, dims: tuple[int, ...], keepdim: bool, - dtype: Optional[DType] = None, + dtype: DType | None = None, ) -> Tensor: # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True # FIXME: when dense dimensions are implemented for CSR tensors @@ -869,7 +869,7 @@ def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: ) -def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor: +def _input_mask(input: Tensor | MaskedTensor, *args, **kwargs) -> Tensor: """Return canonical input mask. A canonical input mask is defined as a boolean mask tensor that @@ -1000,9 +1000,7 @@ def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor: ) -def _combine_input_and_mask( - op, input: Union[MaskedTensor, Tensor], mask, *args -) -> Tensor: +def _combine_input_and_mask(op, input: MaskedTensor | Tensor, mask, *args) -> Tensor: def helper(input, mask): if mask is None: return input @@ -1046,12 +1044,12 @@ def backward(ctx, grad_output): @_apply_docstring_templates def sum( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: @@ -1099,12 +1097,12 @@ def sum( @_apply_docstring_templates def prod( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: # __doc__ is generated by _apply_docstring_templates decorator if dtype is None: @@ -1179,8 +1177,8 @@ def cumsum( input: Tensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1199,8 +1197,8 @@ def cumprod( input: Tensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1216,12 +1214,12 @@ def cumprod( @_apply_docstring_templates def amax( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1266,12 +1264,12 @@ def amax( @_apply_docstring_templates def amin( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1316,12 +1314,12 @@ def amin( @_apply_docstring_templates def argmax( - input: Union[Tensor, MaskedTensor], - dim: Optional[int] = None, + input: Tensor | MaskedTensor, + dim: int | None = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1342,12 +1340,12 @@ def argmax( @_apply_docstring_templates def argmin( - input: Union[Tensor, MaskedTensor], - dim: Optional[int] = None, + input: Tensor | MaskedTensor, + dim: int | None = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1368,12 +1366,12 @@ def argmin( @_apply_docstring_templates def mean( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1435,12 +1433,12 @@ def mean( @_apply_docstring_templates def median( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int = -1, *, keepdim: bool = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1482,8 +1480,8 @@ def logsumexp( dim: DimOrDims = None, *, keepdim: bool = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1499,12 +1497,12 @@ def logsumexp( # Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations def logaddexp( - input: Union[Tensor, MaskedTensor], - other: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, + other: Tensor | MaskedTensor, *, - dtype: Optional[DType] = None, - input_mask: Optional[Tensor] = None, - other_mask: Optional[Tensor] = None, + dtype: DType | None = None, + input_mask: Tensor | None = None, + other_mask: Tensor | None = None, ) -> Tensor: """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor @@ -1561,13 +1559,13 @@ def logaddexp( @_apply_docstring_templates def norm( - input: Union[Tensor, MaskedTensor], - ord: Optional[float] = 2.0, + input: Tensor | MaskedTensor, + ord: float | None = 2.0, dim: DimOrDims = None, *, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1596,15 +1594,15 @@ def norm( def _std_var( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims, - unbiased: Optional[bool], + unbiased: bool | None, *, - correction_opt: Optional[Union[int, float]], - keepdim: Optional[bool], - dtype: Optional[DType], - mask: Optional[Tensor], - take_sqrt: Optional[bool], + correction_opt: int | float | None, + keepdim: bool | None, + dtype: DType | None, + mask: Tensor | None, + take_sqrt: bool | None, ) -> Tensor: assert unbiased is None or correction_opt is None, ( "Only one of unbiased and correction may be given" @@ -1677,14 +1675,14 @@ def _std_var( @_apply_docstring_templates def var( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, - unbiased: Optional[bool] = None, + unbiased: bool | None = None, *, - correction: Optional[Union[int, float]] = None, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + correction: int | float | None = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1708,14 +1706,14 @@ def var( @_apply_docstring_templates def std( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: DimOrDims = None, - unbiased: Optional[bool] = None, + unbiased: bool | None = None, *, - correction: Optional[int] = None, - keepdim: Optional[bool] = False, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + correction: int | None = None, + keepdim: bool | None = False, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: """\ {reduction_signature} @@ -1739,11 +1737,11 @@ def std( @_apply_docstring_templates def softmax( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1759,11 +1757,11 @@ def softmax( @_apply_docstring_templates def log_softmax( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1779,11 +1777,11 @@ def log_softmax( @_apply_docstring_templates def softmin( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, dim: int, *, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype @@ -1799,13 +1797,13 @@ def softmin( @_apply_docstring_templates def normalize( - input: Union[Tensor, MaskedTensor], + input: Tensor | MaskedTensor, ord: float, dim: int, *, eps: float = 1e-12, - dtype: Optional[DType] = None, - mask: Optional[Tensor] = None, + dtype: DType | None = None, + mask: Tensor | None = None, ) -> Tensor: if dtype is None: dtype = input.dtype diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index 9764f935b7c3d..a3ca62929a3b5 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -1,5 +1,4 @@ import warnings -from typing import Optional # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h @@ -31,8 +30,8 @@ def get_enum(reduction: str) -> int: # We use these functions in torch/legacy as well, in which case we'll silence the warning def legacy_get_string( - size_average: Optional[bool], - reduce: Optional[bool], + size_average: bool | None, + reduce: bool | None, emit_warning: bool = True, ) -> str: warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." @@ -54,8 +53,8 @@ def legacy_get_string( def legacy_get_enum( - size_average: Optional[bool], - reduce: Optional[bool], + size_average: bool | None, + reduce: bool | None, emit_warning: bool = True, ) -> int: return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/common_types.py b/torch/nn/common_types.py index 9262c45472271..e1928414a396e 100644 --- a/torch/nn/common_types.py +++ b/torch/nn/common_types.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeAlias as _TypeAlias, TypeVar +from typing import TypeAlias as _TypeAlias, TypeVar from torch import Tensor @@ -29,9 +29,9 @@ _size_6_t: _TypeAlias = _scalar_or_tuple_6_t[int] # For arguments which represent optional size parameters (eg, adaptive pool parameters) -_size_any_opt_t: _TypeAlias = _scalar_or_tuple_any_t[Optional[int]] -_size_2_opt_t: _TypeAlias = _scalar_or_tuple_2_t[Optional[int]] -_size_3_opt_t: _TypeAlias = _scalar_or_tuple_3_t[Optional[int]] +_size_any_opt_t: _TypeAlias = _scalar_or_tuple_any_t[int | None] +_size_2_opt_t: _TypeAlias = _scalar_or_tuple_2_t[int | None] +_size_3_opt_t: _TypeAlias = _scalar_or_tuple_3_t[int | None] # For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) _ratio_2_t: _TypeAlias = _scalar_or_tuple_2_t[float] diff --git a/torch/nn/init.py b/torch/nn/init.py index 3956d9399876e..900b2d34bc08f 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -3,7 +3,7 @@ import math import warnings from collections.abc import Callable -from typing import Literal, Optional as _Optional, TypeVar +from typing import Literal, TypeVar from typing_extensions import ParamSpec import torch @@ -67,7 +67,7 @@ # managers, so these need to be implemented as builtins. Using these wrappers # lets us keep those builtins small and reusable. def _no_grad_uniform_( - tensor: Tensor, a: float, b: float, generator: _Optional[torch.Generator] = None + tensor: Tensor, a: float, b: float, generator: torch.Generator | None = None ) -> Tensor: with torch.no_grad(): return tensor.uniform_(a, b, generator=generator) @@ -77,7 +77,7 @@ def _no_grad_normal_( tensor: Tensor, mean: float, std: float, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: with torch.no_grad(): return tensor.normal_(mean, std, generator=generator) @@ -89,7 +89,7 @@ def _no_grad_trunc_normal_( std: float, a: float, b: float, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x: float) -> float: @@ -138,7 +138,7 @@ def _no_grad_zero_(tensor: Tensor) -> Tensor: def calculate_gain( - nonlinearity: _NonlinearityType, param: _Optional[int | float] = None + nonlinearity: _NonlinearityType, param: int | float | None = None ) -> float: r"""Return the recommended gain value for the given nonlinearity function. @@ -215,7 +215,7 @@ def uniform_( tensor: Tensor, a: float = 0.0, b: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from the uniform distribution. @@ -242,7 +242,7 @@ def normal_( tensor: Tensor, mean: float = 0.0, std: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from the normal distribution. @@ -271,7 +271,7 @@ def trunc_normal_( std: float = 1.0, a: float = -2.0, b: float = 2.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from a truncated normal distribution. @@ -438,7 +438,7 @@ def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]: def xavier_uniform_( tensor: Tensor, gain: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Xavier uniform distribution. @@ -471,7 +471,7 @@ def xavier_uniform_( def xavier_normal_( tensor: Tensor, gain: float = 1.0, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Xavier normal distribution. @@ -515,7 +515,7 @@ def kaiming_uniform_( a: float = 0, mode: _FanMode = "fan_in", nonlinearity: _NonlinearityType = "leaky_relu", - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. @@ -580,7 +580,7 @@ def kaiming_normal_( a: float = 0, mode: _FanMode = "fan_in", nonlinearity: _NonlinearityType = "leaky_relu", - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming normal distribution. @@ -631,7 +631,7 @@ def kaiming_normal_( def orthogonal_( tensor: Tensor, gain: float = 1, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the input `Tensor` with a (semi) orthogonal matrix. @@ -683,7 +683,7 @@ def sparse_( tensor: Tensor, sparsity: float, std: float = 0.01, - generator: _Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ) -> Tensor: r"""Fill the 2D input `Tensor` as a sparse matrix. diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index edd65601db985..dac27cdb0d246 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import warnings -from typing import Optional import torch import torch.nn.functional as F @@ -261,8 +260,8 @@ def __init__( min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False, - min_value: Optional[float] = None, - max_value: Optional[float] = None, + min_value: float | None = None, + max_value: float | None = None, ) -> None: super().__init__() if min_value is not None: @@ -1053,7 +1052,7 @@ def extra_repr(self) -> str: return str(self.lambd) -def _check_arg_device(x: Optional[torch.Tensor]) -> bool: +def _check_arg_device(x: torch.Tensor | None) -> bool: if x is not None: return x.device.type in [ "cpu", @@ -1063,7 +1062,7 @@ def _check_arg_device(x: Optional[torch.Tensor]) -> bool: return True -def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool: +def _arg_requires_grad(x: torch.Tensor | None) -> bool: if x is not None: return x.requires_grad return False @@ -1156,8 +1155,8 @@ class MultiheadAttention(Module): """ __constants__ = ["batch_first"] - bias_k: Optional[torch.Tensor] - bias_v: Optional[torch.Tensor] + bias_k: torch.Tensor | None + bias_v: torch.Tensor | None def __init__( self, @@ -1258,12 +1257,12 @@ def forward( query: Tensor, key: Tensor, value: Tensor, - key_padding_mask: Optional[Tensor] = None, + key_padding_mask: Tensor | None = None, need_weights: bool = True, - attn_mask: Optional[Tensor] = None, + attn_mask: Tensor | None = None, average_attn_weights: bool = True, is_causal: bool = False, - ) -> tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: r"""Compute attention outputs using query, key, and value embeddings. Supports optional parameters for padding, masks and attention weights. @@ -1517,10 +1516,10 @@ def forward( def merge_masks( self, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, query: Tensor, - ) -> tuple[Optional[Tensor], Optional[int]]: + ) -> tuple[Tensor | None, int | None]: r"""Determine mask type and combine masks if necessary. If only one mask is provided, that mask @@ -1535,8 +1534,8 @@ def merge_masks( merged_mask: merged mask mask_type: merged mask type (0, 1, or 2) """ - mask_type: Optional[int] = None - merged_mask: Optional[Tensor] = None + mask_type: int | None = None + merged_mask: Tensor | None = None if key_padding_mask is not None: mask_type = 1 @@ -1732,9 +1731,9 @@ class Softmin(Module): """ __constants__ = ["dim"] - dim: Optional[int] + dim: int | None - def __init__(self, dim: Optional[int] = None) -> None: + def __init__(self, dim: int | None = None) -> None: super().__init__() self.dim = dim @@ -1797,9 +1796,9 @@ class Softmax(Module): """ __constants__ = ["dim"] - dim: Optional[int] + dim: int | None - def __init__(self, dim: Optional[int] = None) -> None: + def __init__(self, dim: int | None = None) -> None: super().__init__() self.dim = dim @@ -1882,9 +1881,9 @@ class LogSoftmax(Module): """ __constants__ = ["dim"] - dim: Optional[int] + dim: int | None - def __init__(self, dim: Optional[int] = None) -> None: + def __init__(self, dim: int | None = None) -> None: super().__init__() self.dim = dim diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 2ac05f2e8f933..40a912b4f0568 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Optional +from typing import Any import torch from torch import Tensor @@ -29,7 +29,7 @@ class _NormBase(Module): __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] num_features: int eps: float - momentum: Optional[float] + momentum: float | None affine: bool track_running_stats: bool # WARNING: weight and bias purposely not defined here. @@ -39,7 +39,7 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: Optional[float] = 0.1, + momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, @@ -65,8 +65,8 @@ def __init__( self.register_buffer( "running_var", torch.ones(num_features, **factory_kwargs) ) - self.running_mean: Optional[Tensor] - self.running_var: Optional[Tensor] + self.running_mean: Tensor | None + self.running_var: Tensor | None self.register_buffer( "num_batches_tracked", torch.tensor( @@ -76,7 +76,7 @@ def __init__( **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, ), ) - self.num_batches_tracked: Optional[Tensor] + self.num_batches_tracked: Tensor | None else: self.register_buffer("running_mean", None) self.register_buffer("running_var", None) @@ -146,7 +146,7 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: Optional[float] = 0.1, + momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, @@ -718,10 +718,10 @@ def __init__( self, num_features: int, eps: float = 1e-5, - momentum: Optional[float] = 0.1, + momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, - process_group: Optional[Any] = None, + process_group: Any | None = None, device=None, dtype=None, ) -> None: diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index f062c4bcbd12b..d99151369e18e 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -4,7 +4,7 @@ import operator from collections import abc as container_abcs, OrderedDict from itertools import chain, islice -from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar +from typing import Any, overload, TYPE_CHECKING, TypeVar from typing_extensions import deprecated, Self import torch @@ -358,7 +358,7 @@ def forward(self, x): _modules: dict[str, Module] # type: ignore[assignment] - def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + def __init__(self, modules: Iterable[Module] | None = None) -> None: super().__init__() if modules is not None: self += modules @@ -545,7 +545,7 @@ def forward(self, x, choice, act): _modules: dict[str, Module] # type: ignore[assignment] - def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + def __init__(self, modules: Mapping[str, Module] | None = None) -> None: super().__init__() if modules is not None: self.update(modules) @@ -673,7 +673,7 @@ def forward(self, x): return x """ - def __init__(self, values: Optional[Iterable[Any]] = None) -> None: + def __init__(self, values: Iterable[Any] | None = None) -> None: super().__init__() self._size = 0 if values is not None: @@ -888,7 +888,7 @@ def copy(self) -> ParameterDict: def __contains__(self, key: str) -> bool: return key in self._keys - def setdefault(self, key: str, default: Optional[Any] = None) -> Any: + def setdefault(self, key: str, default: Any | None = None) -> Any: """Set the default for a key in the Parameterdict. If key is in the ParameterDict, return its value. @@ -927,7 +927,7 @@ def popitem(self) -> tuple[str, Any]: del self[k] return k, val - def get(self, key: str, default: Optional[Any] = None) -> Any: + def get(self, key: str, default: Any | None = None) -> Any: r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not. Args: @@ -937,7 +937,7 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: return self[key] if key in self else default # noqa: SIM401 def fromkeys( - self, keys: Iterable[str], default: Optional[Any] = None + self, keys: Iterable[str], default: Any | None = None ) -> ParameterDict: r"""Return a new ParameterDict with the keys provided. diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index b539203f6fedd..8b74b6a5a39e8 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -67,7 +67,7 @@ class _ConvNd(Module): __annotations__ = {"bias": Optional[torch.Tensor]} def _conv_forward( # type: ignore[empty-body] - self, input: Tensor, weight: Tensor, bias: Optional[Tensor] + self, input: Tensor, weight: Tensor, bias: Tensor | None ) -> Tensor: ... in_channels: int @@ -82,7 +82,7 @@ def _conv_forward( # type: ignore[empty-body] groups: int padding_mode: Literal["zeros", "reflect", "replicate", "circular"] weight: Tensor - bias: Optional[Tensor] + bias: Tensor | None def __init__( self, @@ -353,7 +353,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): if self.padding_mode != "zeros": return F.conv1d( F.pad( @@ -531,7 +531,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): if self.padding_mode != "zeros": return F.conv2d( F.pad( @@ -701,7 +701,7 @@ def __init__( **factory_kwargs, ) - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None): if self.padding_mode != "zeros": return F.conv3d( F.pad( @@ -766,12 +766,12 @@ def __init__( def _output_padding( self, input: Tensor, - output_size: Optional[list[int]], + output_size: list[int] | None, stride: list[int], padding: list[int], kernel_size: list[int], num_spatial_dims: int, - dilation: Optional[list[int]] = None, + dilation: list[int] | None = None, ) -> list[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already @@ -965,7 +965,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: if self.padding_mode != "zeros": raise ValueError( "Only `zeros` padding mode is supported for ConvTranspose1d" @@ -1153,7 +1153,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: """ Performs the forward pass. @@ -1344,7 +1344,7 @@ def __init__( **factory_kwargs, ) - def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: if self.padding_mode != "zeros": raise ValueError( "Only `zeros` padding mode is supported for ConvTranspose3d" diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index d4c192ee8ce4a..72d90d1c10364 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import itertools -from typing import Any, Optional, Protocol +from typing import Any, Protocol import torch from torch.nn.parameter import is_lazy @@ -167,7 +167,7 @@ class LazyModuleMixin: # modules inheriting from this will change their __class__ to the specified # one after they are fully initialized - cls_to_become: Optional[type[Any]] = None + cls_to_become: type[Any] | None = None def __init__(self: _LazyProtocol, *args, **kwargs): # Mypy doesn't like this super call in a mixin diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 05b39ba762f47..00ada62febded 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs from collections.abc import Callable -from typing import Optional from typing_extensions import deprecated from torch import Tensor @@ -50,14 +49,14 @@ def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> N class _WeightedLoss(_Loss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, reduce=None, reduction: str = "mean", ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) - self.weight: Optional[Tensor] + self.weight: Tensor | None class L1Loss(_Loss): @@ -241,7 +240,7 @@ class NLLLoss(_WeightedLoss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -272,7 +271,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: class NLLLoss2d(NLLLoss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -817,17 +816,17 @@ class BCEWithLogitsLoss(_Loss): def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, reduce=None, reduction: str = "mean", - pos_weight: Optional[Tensor] = None, + pos_weight: Tensor | None = None, ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) self.register_buffer("pos_weight", pos_weight) - self.weight: Optional[Tensor] - self.pos_weight: Optional[Tensor] + self.weight: Tensor | None + self.pos_weight: Tensor | None def forward(self, input: Tensor, target: Tensor) -> Tensor: """Runs the forward pass.""" @@ -1347,7 +1346,7 @@ class probabilities only when a single class label per minibatch item is too res def __init__( self, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, ignore_index: int = -100, reduce=None, @@ -1626,7 +1625,7 @@ def __init__( self, p: int = 1, margin: float = 1.0, - weight: Optional[Tensor] = None, + weight: Tensor | None = None, size_average=None, reduce=None, reduction: str = "mean", @@ -1869,7 +1868,7 @@ class TripletMarginWithDistanceLoss(_Loss): def __init__( self, *, - distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + distance_function: Callable[[Tensor, Tensor], Tensor] | None = None, margin: float = 1.0, swap: bool = False, reduction: str = "mean", @@ -1879,7 +1878,7 @@ def __init__( raise ValueError( f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead" ) - self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ( + self.distance_function: Callable[[Tensor, Tensor], Tensor] | None = ( distance_function if distance_function is not None else PairwiseDistance() ) self.margin = margin diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 6557f60389964..f9795cc1c74aa 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -115,7 +115,7 @@ def __setstate__(self, state: dict): purposes""" _global_backward_pre_hooks: dict[int, Callable] = OrderedDict() _global_backward_hooks: dict[int, Callable] = OrderedDict() -_global_is_full_backward_hook: Optional[bool] = None +_global_is_full_backward_hook: bool | None = None _global_forward_pre_hooks: dict[int, Callable] = OrderedDict() _global_forward_hooks: dict[int, Callable] = OrderedDict() _global_forward_hooks_always_called: dict[int, bool] = OrderedDict() @@ -453,12 +453,12 @@ def forward(self, x): the change.""" training: bool - _parameters: dict[str, Optional[Parameter]] - _buffers: dict[str, Optional[Tensor]] + _parameters: dict[str, Parameter | None] + _buffers: dict[str, Tensor | None] _non_persistent_buffers_set: set[str] _backward_pre_hooks: dict[int, Callable] _backward_hooks: dict[int, Callable] - _is_full_backward_hook: Optional[bool] + _is_full_backward_hook: bool | None _forward_hooks: dict[int, Callable] # Marks whether the corresponding _forward_hooks accept kwargs or not. # As JIT does not support set[int], this dict is used as a set, where all @@ -477,7 +477,7 @@ def forward(self, x): _load_state_dict_post_hooks: dict[int, Callable] _modules: dict[str, Optional["Module"]] call_super_init: bool = False - _compiled_call_impl: Optional[Callable] = None + _compiled_call_impl: Callable | None = None def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" @@ -526,7 +526,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: forward: Callable[..., Any] = _forward_unimplemented def register_buffer( - self, name: str, tensor: Optional[Tensor], persistent: bool = True + self, name: str, tensor: Tensor | None, persistent: bool = True ) -> None: r"""Add a buffer to the module. @@ -589,7 +589,7 @@ def register_buffer( else: self._non_persistent_buffers_set.add(name) - def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + def register_parameter(self, name: str, param: Parameter | None) -> None: r"""Add a parameter to the module. The parameter can be accessed as an attribute using given name. @@ -1073,7 +1073,7 @@ def apply(self, fn: Callable[["Module"], None]) -> Self: fn(self) return self - def cuda(self, device: Optional[int | device] = None) -> Self: + def cuda(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So @@ -1092,7 +1092,7 @@ def cuda(self, device: Optional[int | device] = None) -> Self: """ return self._apply(lambda t: t.cuda(device)) - def ipu(self, device: Optional[int | device] = None) -> Self: + def ipu(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the IPU. This also makes associated parameters and buffers different objects. So @@ -1111,7 +1111,7 @@ def ipu(self, device: Optional[int | device] = None) -> Self: """ return self._apply(lambda t: t.ipu(device)) - def xpu(self, device: Optional[int | device] = None) -> Self: + def xpu(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the XPU. This also makes associated parameters and buffers different objects. So @@ -1130,7 +1130,7 @@ def xpu(self, device: Optional[int | device] = None) -> Self: """ return self._apply(lambda t: t.xpu(device)) - def mtia(self, device: Optional[int | device] = None) -> Self: + def mtia(self, device: int | device | None = None) -> Self: r"""Move all model parameters and buffers to the MTIA. This also makes associated parameters and buffers different objects. So @@ -1218,9 +1218,7 @@ def bfloat16(self) -> Self: """ return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) - def to_empty( - self, *, device: Optional[DeviceLikeType], recurse: bool = True - ) -> Self: + def to_empty(self, *, device: DeviceLikeType | None, recurse: bool = True) -> Self: r"""Move the parameters and buffers to the specified device without copying storage. Args: @@ -1239,8 +1237,8 @@ def to_empty( @overload def to( self, - device: Optional[DeviceLikeType] = ..., - dtype: Optional[dtype] = ..., + device: DeviceLikeType | None = ..., + dtype: dtype | None = ..., non_blocking: bool = ..., ) -> Self: ... @@ -1623,9 +1621,9 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn) -> None: def register_forward_pre_hook( self, - hook: Callable[[T, tuple[Any, ...]], Optional[Any]] + hook: Callable[[T, tuple[Any, ...]], Any | None] | Callable[ - [T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]] + [T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None ], *, prepend: bool = False, @@ -1686,8 +1684,8 @@ def register_forward_pre_hook( def register_forward_hook( self, - hook: Callable[[T, tuple[Any, ...], Any], Optional[Any]] - | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]], + hook: Callable[[T, tuple[Any, ...], Any], Any | None] + | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None], *, prepend: bool = False, with_kwargs: bool = False, @@ -2830,7 +2828,7 @@ def modules(self) -> Iterator["Module"]: def named_modules( self, - memo: Optional[set["Module"]] = None, + memo: set["Module"] | None = None, prefix: str = "", remove_duplicate: bool = True, ): diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 4a7302d5cae33..d492cdb3cf5a0 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import numbers -from typing import Optional, Union +from typing import Union import torch from torch import Size, Tensor @@ -375,13 +375,13 @@ class RMSNorm(Module): __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: tuple[int, ...] - eps: Optional[float] + eps: float | None elementwise_affine: bool def __init__( self, normalized_shape: _shape_t, - eps: Optional[float] = None, + eps: float | None = None, elementwise_affine: bool = True, device=None, dtype=None, diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 777e6b0abd8c4..1dc57c25b1683 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch.nn.functional as F from torch import Tensor from torch.nn.common_types import ( @@ -57,7 +55,7 @@ class _MaxPoolNd(Module): def __init__( self, kernel_size: _size_any_t, - stride: Optional[_size_any_t] = None, + stride: _size_any_t | None = None, padding: _size_any_t = 0, dilation: _size_any_t = 1, return_indices: bool = False, @@ -389,7 +387,7 @@ class MaxUnpool1d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_1_t, - stride: Optional[_size_1_t] = None, + stride: _size_1_t | None = None, padding: _size_1_t = 0, ) -> None: super().__init__() @@ -398,7 +396,7 @@ def __init__( self.padding = _single(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool1d( @@ -485,7 +483,7 @@ class MaxUnpool2d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_2_t, - stride: Optional[_size_2_t] = None, + stride: _size_2_t | None = None, padding: _size_2_t = 0, ) -> None: super().__init__() @@ -494,7 +492,7 @@ def __init__( self.padding = _pair(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool2d( @@ -564,7 +562,7 @@ class MaxUnpool3d(_MaxUnpoolNd): def __init__( self, kernel_size: _size_3_t, - stride: Optional[_size_3_t] = None, + stride: _size_3_t | None = None, padding: _size_3_t = 0, ) -> None: super().__init__() @@ -573,7 +571,7 @@ def __init__( self.padding = _triple(padding) def forward( - self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None + self, input: Tensor, indices: Tensor, output_size: list[int] | None = None ) -> Tensor: """Runs the forward pass.""" return F.max_unpool3d( @@ -762,11 +760,11 @@ class AvgPool2d(_AvgPoolNd): def __init__( self, kernel_size: _size_2_t, - stride: Optional[_size_2_t] = None, + stride: _size_2_t | None = None, padding: _size_2_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, + divisor_override: int | None = None, ) -> None: super().__init__() self.kernel_size = kernel_size @@ -879,11 +877,11 @@ class AvgPool3d(_AvgPoolNd): def __init__( self, kernel_size: _size_3_t, - stride: Optional[_size_3_t] = None, + stride: _size_3_t | None = None, padding: _size_3_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, + divisor_override: int | None = None, ) -> None: super().__init__() self.kernel_size = kernel_size @@ -964,8 +962,8 @@ class FractionalMaxPool2d(Module): def __init__( self, kernel_size: _size_2_t, - output_size: Optional[_size_2_t] = None, - output_ratio: Optional[_ratio_2_t] = None, + output_size: _size_2_t | None = None, + output_ratio: _ratio_2_t | None = None, return_indices: bool = False, _random_samples=None, ) -> None: @@ -1050,8 +1048,8 @@ class FractionalMaxPool3d(Module): def __init__( self, kernel_size: _size_3_t, - output_size: Optional[_size_3_t] = None, - output_ratio: Optional[_ratio_3_t] = None, + output_size: _size_3_t | None = None, + output_ratio: _ratio_3_t | None = None, return_indices: bool = False, _random_samples=None, ) -> None: @@ -1106,7 +1104,7 @@ def __init__( self, norm_type: float, kernel_size: _size_any_t, - stride: Optional[_size_any_t] = None, + stride: _size_any_t | None = None, ceil_mode: bool = False, ) -> None: super().__init__() diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 13cd9ec08cb55..68e8292870fc8 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -4,7 +4,7 @@ import numbers import warnings import weakref -from typing import Optional, overload +from typing import overload from typing_extensions import deprecated import torch @@ -106,7 +106,7 @@ def __init__( self.dropout = float(dropout) self.bidirectional = bidirectional self.proj_size = proj_size - self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = [] + self._flat_weight_refs: list[weakref.ReferenceType[Parameter] | None] = [] num_directions = 2 if bidirectional else 1 if ( @@ -298,7 +298,7 @@ def reset_parameters(self) -> None: for weight in self.parameters(): init.uniform_(weight, -stdv, stdv) - def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + def check_input(self, input: Tensor, batch_sizes: Tensor | None) -> None: if not torch.jit.is_scripting(): if ( input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr] @@ -318,7 +318,7 @@ def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: ) def get_expected_hidden_size( - self, input: Tensor, batch_sizes: Optional[Tensor] + self, input: Tensor, batch_sizes: Tensor | None ) -> tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) @@ -362,14 +362,14 @@ def _weights_have_changed(self): return weights_changed def check_forward_args( - self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + self, input: Tensor, hidden: Tensor, batch_sizes: Tensor | None ) -> None: self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) self.check_hidden_size(hidden, expected_hidden_size) - def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): + def permute_hidden(self, hx: Tensor, permutation: Tensor | None): if permutation is None: return hx return _apply_permutation(hx, permutation) @@ -645,7 +645,7 @@ def __init__(self, *args, **kwargs): def forward( self, input: Tensor, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[Tensor, Tensor]: pass @@ -654,7 +654,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[PackedSequence, Tensor]: pass @@ -990,7 +990,7 @@ def __init__(self, *args, **kwargs): super().__init__("LSTM", *args, **kwargs) def get_expected_cell_size( - self, input: Tensor, batch_sizes: Optional[Tensor] + self, input: Tensor, batch_sizes: Tensor | None ) -> tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) @@ -1010,7 +1010,7 @@ def check_forward_args( self, input: Tensor, hidden: tuple[Tensor, Tensor], # type: ignore[override] - batch_sizes: Optional[Tensor], + batch_sizes: Tensor | None, ) -> None: self.check_input(input, batch_sizes) self.check_hidden_size( @@ -1028,7 +1028,7 @@ def check_forward_args( def permute_hidden( # type: ignore[override] self, hx: tuple[Tensor, Tensor], - permutation: Optional[Tensor], + permutation: Tensor | None, ) -> tuple[Tensor, Tensor]: if permutation is None: return hx @@ -1042,7 +1042,7 @@ def permute_hidden( # type: ignore[override] def forward( self, input: Tensor, - hx: Optional[tuple[Tensor, Tensor]] = None, + hx: tuple[Tensor, Tensor] | None = None, ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1052,7 +1052,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Optional[tuple[Tensor, Tensor]] = None, + hx: tuple[Tensor, Tensor] | None = None, ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1338,7 +1338,7 @@ def __init__(self, *args, **kwargs): def forward( self, input: Tensor, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[Tensor, Tensor]: # noqa: F811 pass @@ -1347,7 +1347,7 @@ def forward( def forward( self, input: PackedSequence, - hx: Optional[Tensor] = None, + hx: Tensor | None = None, ) -> tuple[PackedSequence, Tensor]: # noqa: F811 pass @@ -1584,7 +1584,7 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) self.nonlinearity = nonlinearity - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" @@ -1704,7 +1704,7 @@ def __init__( super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) def forward( - self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None ) -> tuple[Tensor, Tensor]: if input.dim() not in (1, 2): raise ValueError( @@ -1815,7 +1815,7 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: if input.dim() not in (1, 2): raise ValueError( f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index 83a8d6ef334bb..8ec531abce695 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch from torch import Tensor @@ -124,8 +123,8 @@ class Embedding(Module): num_embeddings: int embedding_dim: int - padding_idx: Optional[int] - max_norm: Optional[float] + padding_idx: int | None + max_norm: float | None norm_type: float scale_grad_by_freq: bool weight: Tensor @@ -136,12 +135,12 @@ def __init__( self, num_embeddings: int, embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, + padding_idx: int | None = None, + max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, - _weight: Optional[Tensor] = None, + _weight: Tensor | None = None, _freeze: bool = False, device=None, dtype=None, @@ -362,27 +361,27 @@ class EmbeddingBag(Module): num_embeddings: int embedding_dim: int - max_norm: Optional[float] + max_norm: float | None norm_type: float scale_grad_by_freq: bool weight: Tensor mode: str sparse: bool include_last_offset: bool - padding_idx: Optional[int] + padding_idx: int | None def __init__( self, num_embeddings: int, embedding_dim: int, - max_norm: Optional[float] = None, + max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = "mean", sparse: bool = False, - _weight: Optional[Tensor] = None, + _weight: Tensor | None = None, include_last_offset: bool = False, - padding_idx: Optional[int] = None, + padding_idx: int | None = None, device=None, dtype=None, ) -> None: @@ -431,8 +430,8 @@ def _fill_padding_idx_with_zero(self) -> None: def forward( self, input: Tensor, - offsets: Optional[Tensor] = None, - per_sample_weights: Optional[Tensor] = None, + offsets: Tensor | None = None, + per_sample_weights: Tensor | None = None, ) -> Tensor: """Forward pass of EmbeddingBag. @@ -496,13 +495,13 @@ def from_pretrained( cls, embeddings: Tensor, freeze: bool = True, - max_norm: Optional[float] = None, + max_norm: float | None = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, mode: str = "mean", sparse: bool = False, include_last_offset: bool = False, - padding_idx: Optional[int] = None, + padding_idx: int | None = None, ) -> "EmbeddingBag": r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor. diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index ed35224423aa6..6841e85ed6d2e 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -2,7 +2,7 @@ import copy import warnings from collections.abc import Callable -from typing import Any, Optional +from typing import Any import torch import torch.nn.functional as F @@ -28,8 +28,8 @@ def _generate_square_subsequent_mask( sz: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. @@ -41,7 +41,7 @@ def _generate_square_subsequent_mask( ) -def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: +def _get_seq_len(src: Tensor, batch_first: bool) -> int | None: if src.is_nested: return None else: @@ -106,8 +106,8 @@ def __init__( dim_feedforward: int = 2048, dropout: float = 0.1, activation: str | Callable[[Tensor], Tensor] = F.relu, - custom_encoder: Optional[Any] = None, - custom_decoder: Optional[Any] = None, + custom_encoder: Any | None = None, + custom_decoder: Any | None = None, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, @@ -182,14 +182,14 @@ def forward( self, src: Tensor, tgt: Tensor, - src_mask: Optional[Tensor] = None, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - src_is_causal: Optional[bool] = None, - tgt_is_causal: Optional[bool] = None, + src_mask: Tensor | None = None, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + src_is_causal: bool | None = None, + tgt_is_causal: bool | None = None, memory_is_causal: bool = False, ) -> Tensor: r"""Take in and process masked source/target sequences. @@ -301,8 +301,8 @@ def forward( @staticmethod def generate_square_subsequent_mask( sz: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. @@ -354,7 +354,7 @@ def __init__( self, encoder_layer: "TransformerEncoderLayer", num_layers: int, - norm: Optional[Module] = None, + norm: Module | None = None, enable_nested_tensor: bool = True, mask_check: bool = True, ) -> None: @@ -407,9 +407,9 @@ def __init__( def forward( self, src: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - is_causal: Optional[bool] = None, + mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, + is_causal: bool | None = None, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -588,7 +588,7 @@ def __init__( self, decoder_layer: "TransformerDecoderLayer", num_layers: int, - norm: Optional[Module] = None, + norm: Module | None = None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") @@ -600,11 +600,11 @@ def forward( self, tgt: Tensor, memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - tgt_is_causal: Optional[bool] = None, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, + tgt_is_causal: bool | None = None, memory_is_causal: bool = False, ) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer in turn. @@ -799,8 +799,8 @@ def __setstate__(self, state): def forward( self, src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + src_mask: Tensor | None = None, + src_key_padding_mask: Tensor | None = None, is_causal: bool = False, ) -> Tensor: r"""Pass the input through the encoder layer. @@ -961,8 +961,8 @@ def forward( def _sa_block( self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, is_causal: bool = False, ) -> Tensor: x = self.self_attn( @@ -1090,10 +1090,10 @@ def forward( self, tgt: Tensor, memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, + tgt_mask: Tensor | None = None, + memory_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + memory_key_padding_mask: Tensor | None = None, tgt_is_causal: bool = False, memory_is_causal: bool = False, ) -> Tensor: @@ -1158,8 +1158,8 @@ def forward( def _sa_block( self, x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, is_causal: bool = False, ) -> Tensor: x = self.self_attn( @@ -1178,8 +1178,8 @@ def _mha_block( self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], + attn_mask: Tensor | None, + key_padding_mask: Tensor | None, is_causal: bool = False, ) -> Tensor: x = self.multihead_attn( @@ -1214,9 +1214,9 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: def _detect_is_causal_mask( - mask: Optional[Tensor], - is_causal: Optional[bool] = None, - size: Optional[int] = None, + mask: Tensor | None, + is_causal: bool | None = None, + size: int | None = None, ) -> bool: """Return whether the given attention mask is causal. diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index 7fd102a768225..29e58bc6a9f37 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch.nn.functional as F from torch import Tensor @@ -143,19 +142,19 @@ class Upsample(Module): "recompute_scale_factor", ] name: str - size: Optional[_size_any_t] - scale_factor: Optional[_ratio_any_t] + size: _size_any_t | None + scale_factor: _ratio_any_t | None mode: str - align_corners: Optional[bool] - recompute_scale_factor: Optional[bool] + align_corners: bool | None + recompute_scale_factor: bool | None def __init__( self, - size: Optional[_size_any_t] = None, - scale_factor: Optional[_ratio_any_t] = None, + size: _size_any_t | None = None, + scale_factor: _ratio_any_t | None = None, mode: str = "nearest", - align_corners: Optional[bool] = None, - recompute_scale_factor: Optional[bool] = None, + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, ) -> None: super().__init__() self.name = type(self).__name__ @@ -242,8 +241,8 @@ class UpsamplingNearest2d(Upsample): def __init__( self, - size: Optional[_size_2_t] = None, - scale_factor: Optional[_ratio_2_t] = None, + size: _size_2_t | None = None, + scale_factor: _ratio_2_t | None = None, ) -> None: super().__init__(size, scale_factor, mode="nearest") @@ -293,7 +292,7 @@ class UpsamplingBilinear2d(Upsample): def __init__( self, - size: Optional[_size_2_t] = None, - scale_factor: Optional[_ratio_2_t] = None, + size: _size_2_t | None = None, + scale_factor: _ratio_2_t | None = None, ) -> None: super().__init__(size, scale_factor, mode="bilinear", align_corners=True) diff --git a/torch/overrides.py b/torch/overrides.py index e0597eafd8107..b1193bab3d6dc 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -30,7 +30,7 @@ import warnings from collections.abc import Callable, Iterable from functools import wraps -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch @@ -1609,7 +1609,7 @@ def wrapped(*args, **kwargs): def _get_overloaded_args( relevant_args: Iterable[Any], - get_type_fn: Optional[Callable[[Any], type]] = None, + get_type_fn: Callable[[Any], type] | None = None, ) -> list[Any]: """Returns a list of arguments on which to call __torch_function__. diff --git a/torch/quasirandom.py b/torch/quasirandom.py index b5d4540e592f1..f9e6619cab180 100644 --- a/torch/quasirandom.py +++ b/torch/quasirandom.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional import torch @@ -78,8 +77,8 @@ def __init__(self, dimension, scramble=False, seed=None): def draw( self, n: int = 1, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, + out: torch.Tensor | None = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: r""" Function to draw a sequence of :attr:`n` points from a Sobol sequence. @@ -131,8 +130,8 @@ def draw( def draw_base2( self, m: int, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, + out: torch.Tensor | None = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: r""" Function to draw a sequence of :attr:`2**m` points from a Sobol sequence. @@ -187,7 +186,7 @@ def fast_forward(self, n): return self def _scramble(self): - g: Optional[torch.Generator] = None + g: torch.Generator | None = None if self.seed is not None: g = torch.Generator() g.manual_seed(self.seed) diff --git a/torch/serialization.py b/torch/serialization.py index 398d011f324b5..1a6acc8010634 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -16,7 +16,7 @@ from collections.abc import Callable from contextlib import closing, contextmanager from enum import Enum -from typing import Any, cast, Generic, IO, Optional, TypeAlias, TypeVar, Union +from typing import Any, cast, Generic, IO, TypeAlias, TypeVar from typing_extensions import TypeIs import torch @@ -66,10 +66,10 @@ PROTOCOL_VERSION = 1001 STORAGE_KEY_SEPARATOR = "," -MAP_LOCATION: TypeAlias = Optional[ - Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]] -] -STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] +MAP_LOCATION: TypeAlias = ( + Callable[[Storage, str], Storage] | torch.device | str | dict[str, str] | None +) +STORAGE: TypeAlias = Storage | torch.storage.TypedStorage | torch.UntypedStorage IS_WINDOWS = sys.platform == "win32" @@ -99,7 +99,7 @@ def _default_to_weights_only(pickle_module): class _SerializationLocal(threading.local): def __init__(self): super().__init__() - self.map_location: Optional[MAP_LOCATION] = None + self.map_location: MAP_LOCATION | None = None self.skip_data: bool = False self.materialize_fake_tensors: bool = False @@ -123,8 +123,8 @@ def mkdtemp(): _package_registry: list[ tuple[ int, - Callable[[STORAGE], Optional[str]], - Callable[[STORAGE, str], Optional[STORAGE]], + Callable[[STORAGE], str | None], + Callable[[STORAGE, str], STORAGE | None], ] ] = [] @@ -135,7 +135,7 @@ class LoadEndianness(Enum): BIG = 3 -def get_default_load_endianness() -> Optional[LoadEndianness]: +def get_default_load_endianness() -> LoadEndianness | None: """ Get fallback byte order for loading files @@ -197,7 +197,7 @@ def set_crc32_options(compute_crc32: bool): config.save.compute_crc32 = compute_crc32 -def get_default_mmap_options() -> Optional[int]: +def get_default_mmap_options() -> int | None: """ Get default mmap options for :func:`torch.load` with ``mmap=True``. @@ -272,14 +272,14 @@ def clear_safe_globals() -> None: _weights_only_unpickler._clear_safe_globals() -def get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]: +def get_safe_globals() -> list[Callable | tuple[Callable, str]]: """ Returns the list of user-added globals that are safe for ``weights_only`` load. """ return _weights_only_unpickler._get_safe_globals() -def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]) -> None: +def add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]) -> None: """ Marks the given globals as safe for ``weights_only`` load. For example, functions added to this list can be called during unpickling, classes could be instantiated @@ -443,8 +443,8 @@ def _is_zipfile(f) -> bool: def register_package( priority: int, - tagger: Callable[[STORAGE], Optional[str]], - deserializer: Callable[[STORAGE, str], Optional[STORAGE]], + tagger: Callable[[STORAGE], str | None], + deserializer: Callable[[STORAGE, str], STORAGE | None], ): """ Registers callables for tagging and deserializing storage objects with an associated priority. @@ -672,7 +672,7 @@ def _deserialize(backend_name, obj, location): def location_tag( - storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage], + storage: Storage | torch.storage.TypedStorage | torch.UntypedStorage, ): for _, tagger, _ in _package_registry: location = tagger(storage) @@ -726,7 +726,7 @@ def storage_to_tensor_type(storage): return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) -def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]: +def _is_path(name_or_buffer: object) -> TypeIs[str | os.PathLike]: return isinstance(name_or_buffer, (str, os.PathLike)) @@ -745,7 +745,7 @@ def __exit__(self, *args): class _open_file(_opener[IO[bytes]]): - def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None: + def __init__(self, name: str | os.PathLike[str], mode: str) -> None: super().__init__(open(name, mode)) # noqa: SIM115 def __exit__(self, *args): @@ -776,7 +776,7 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]: class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]): - def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None: + def __init__(self, name_or_buffer: str | IO[bytes]) -> None: super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) @@ -829,7 +829,7 @@ def __exit__(self, *args) -> None: self.buffer.flush() -def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener: +def _open_zipfile_writer(name_or_buffer: str | IO[bytes]) -> _opener: container: type[_opener] if _is_path(name_or_buffer): container = _open_zipfile_writer_file @@ -1004,7 +1004,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: # TODO: This feature could be added in the future storage_dtypes: dict[int, torch.dtype] = {} - def persistent_id(obj: Any) -> Optional[tuple]: + def persistent_id(obj: Any) -> tuple | None: # FIXME: the docs say that persistent_id should only return a string # but torch store returns tuples. This works only in the binary protocol # see @@ -1064,7 +1064,7 @@ def persistent_id(obj: Any) -> Optional[tuple]: else: storage_dtypes[storage.data_ptr()] = storage_dtype - view_metadata: Optional[tuple[str, int, int]] + view_metadata: tuple[str, int, int] | None # Offset is always 0, but we keep it for backwards compatibility # with the old serialization format (which supported storage views) @@ -1291,8 +1291,8 @@ def load( map_location: MAP_LOCATION = None, pickle_module: Any = None, *, - weights_only: Optional[bool] = None, - mmap: Optional[bool] = None, + weights_only: bool | None = None, + mmap: bool | None = None, **pickle_load_args: Any, ) -> Any: # Reference: https://github.com/pytorch/pytorch/issues/54354 @@ -1852,7 +1852,7 @@ def persistent_load(saved_id): return result -def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: +def _maybe_decode_ascii(bytes_str: bytes | str) -> str: # When using encoding='bytes' in Py3, some **internal** keys stored as # strings in Py2 are loaded as bytes. This function decodes them with # ascii encoding, one that Py3 uses by default. diff --git a/torch/storage.py b/torch/storage.py index 1b9023121ddfb..29847d958523d 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -8,7 +8,7 @@ import io import threading import warnings -from typing import Any, cast, Optional as _Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, cast, TYPE_CHECKING, TypeVar from typing_extensions import Self import torch @@ -35,7 +35,7 @@ _share_memory_lock = threading.Lock() _share_memory_map: dict[int, threading.RLock] = {} -T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]") +T = TypeVar("T", bound="_StorageBase | TypedStorage") class _StorageBase: @@ -46,9 +46,9 @@ class _StorageBase: # Used when # (1) stashing FakeTensor device onto storage in torch.serialization.skip_data # (2) stashing device onto storage to propagate to FakeTensor when torch.load under FakeTensorMode - _fake_device: _Optional[torch.device] = None + _fake_device: torch.device | None = None # Used when loading with FakeTensorMode to give information about offset of storage in torch.saved-file - _checkpoint_offset: _Optional[int] = None + _checkpoint_offset: int | None = None def __init__(self, *args, **kwargs): pass @@ -62,10 +62,10 @@ def __getitem__(self, idx): def __setitem__(self, *args, **kwargs): raise NotImplementedError - def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: + def copy_(self, source: T, non_blocking: _bool | None = None) -> T: raise NotImplementedError - def new(self) -> Union[_StorageBase, TypedStorage]: + def new(self) -> _StorageBase | TypedStorage: raise NotImplementedError def nbytes(self) -> _int: @@ -75,13 +75,11 @@ def size(self) -> _int: return self.nbytes() def type( - self, dtype: _Optional[str] = None, non_blocking: _bool = False - ) -> Union[_StorageBase, TypedStorage]: + self, dtype: str | None = None, non_blocking: _bool = False + ) -> _StorageBase | TypedStorage: return _type(self, dtype, non_blocking) - def cuda( - self, device=None, non_blocking=False - ) -> Union[_StorageBase, TypedStorage]: + def cuda(self, device=None, non_blocking=False) -> _StorageBase | TypedStorage: """Returns a copy of this object in CUDA memory. If this object is already in CUDA memory and on the correct device, then @@ -96,7 +94,7 @@ def cuda( device2 = torch.device("cuda", device) if device else torch.device("cuda") return self.to(device=device2, non_blocking=non_blocking) - def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: + def hpu(self, device=None, non_blocking=False) -> _StorageBase | TypedStorage: """Returns a copy of this object in HPU memory. If this object is already in HPU memory and on the correct device, then @@ -166,7 +164,7 @@ def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self: def _new_with_weak_ptr(cls, *args, **kwargs) -> Self: raise NotImplementedError - def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: + def _shared_decref(self) -> _StorageBase | TypedStorage: raise NotImplementedError def _write_file(self, *args, **kwargs): @@ -175,7 +173,7 @@ def _write_file(self, *args, **kwargs): def resize_(self, size: _int): raise NotImplementedError - def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + def _weak_ref(self, *args, **kwargs) -> _StorageBase | TypedStorage: raise NotImplementedError def _set_from_file(self, *args, **kwargs): @@ -210,17 +208,17 @@ def is_hpu(self): raise NotImplementedError @classmethod - def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: + def from_file(cls, filename, shared, nbytes) -> _StorageBase | TypedStorage: raise NotImplementedError @classmethod - def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: + def _expired(cls, *args, **kwargs) -> _StorageBase | TypedStorage: raise NotImplementedError def _byteswap(self, *args, **kwargs): raise NotImplementedError - def _get_filename(self, *args, **kwargs) -> _Optional[str]: + def _get_filename(self, *args, **kwargs) -> str | None: raise NotImplementedError def __repr__(self): @@ -354,7 +352,7 @@ def float8_e4m3fnuz(self): """Casts this storage to float8_e4m3fnuz type""" return self._to(torch.float8_e4m3fnuz) - def is_pinned(self, device: Union[str, torch.device] = "cuda"): + def is_pinned(self, device: str | torch.device = "cuda"): r"""Determine whether the CPU storage is already pinned on device. Args: @@ -370,7 +368,7 @@ def is_pinned(self, device: Union[str, torch.device] = "cuda"): .is_pinned(device) ) - def pin_memory(self, device: Union[str, torch.device] = "cuda"): + def pin_memory(self, device: str | torch.device = "cuda"): r"""Copy the CPU storage to pinned memory, if it's not already pinned. Args: @@ -478,7 +476,7 @@ def is_hpu(self): return self.device.type == "hpu" @property - def filename(self) -> _Optional[str]: + def filename(self) -> str | None: """Returns the file name associated with this storage. The file name will be a string if the storage is on CPU and was created via @@ -671,7 +669,7 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None + _fake_device: torch.device | None = None dtype: torch.dtype @@ -680,7 +678,7 @@ def _dtype(self): return self.dtype @property - def filename(self) -> _Optional[str]: + def filename(self) -> str | None: """Returns the file name associated with this storage if the storage was memory mapped from a file. or ``None`` if the storage was not created by memory mapping a file.""" return self._untyped_storage.filename @@ -1018,7 +1016,7 @@ def _getitem(self, idx): ).set_(self) return tmp_tensor[idx_wrapped].item() - def copy_(self, source: T, non_blocking: _Optional[bool] = None): + def copy_(self, source: T, non_blocking: bool | None = None): _warn_typed_storage_removal() if isinstance(source, TypedStorage): self._untyped_storage.copy_(source._untyped_storage, non_blocking) @@ -1036,9 +1034,9 @@ def _nbytes(self): def type( self, - dtype: _Optional[str] = None, + dtype: str | None = None, non_blocking: bool = False, - ) -> Union[_StorageBase, TypedStorage, str]: + ) -> _StorageBase | TypedStorage | str: _warn_typed_storage_removal() if dtype is None: legacy_class = self._get_legacy_storage_class() @@ -1157,7 +1155,7 @@ def cpu(self): _warn_typed_storage_removal() return self._new_wrapped_storage(self._untyped_storage.cpu()) - def is_pinned(self, device: Union[str, torch.device] = "cuda"): + def is_pinned(self, device: str | torch.device = "cuda"): r"""Determine whether the CPU TypedStorage is already pinned on device. Args: @@ -1170,7 +1168,7 @@ def is_pinned(self, device: Union[str, torch.device] = "cuda"): _warn_typed_storage_removal() return self._untyped_storage.is_pinned(device) - def pin_memory(self, device: Union[str, torch.device] = "cuda"): + def pin_memory(self, device: str | torch.device = "cuda"): r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned. Args: diff --git a/torch/types.py b/torch/types.py index 0388c9c66aefe..9ed69a859b1ee 100644 --- a/torch/types.py +++ b/torch/types.py @@ -38,7 +38,7 @@ # Convenience aliases for common composite types that we need # to talk about in PyTorch -_TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]] # noqa: PYI047 +_TensorOrTensors: TypeAlias = Tensor | Sequence[Tensor] # noqa: PYI047 _TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047 Tensor, Sequence[Tensor], @@ -46,32 +46,32 @@ Sequence["GradientEdge"], ] -_size: TypeAlias = Union[Size, list[int], tuple[int, ...]] # noqa: PYI042,PYI047 -_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047 -_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047 +_size: TypeAlias = Size | list[int] | tuple[int, ...] # noqa: PYI042,PYI047 +_symsize: TypeAlias = Size | Sequence[int | SymInt] # noqa: PYI042,PYI047 +_dispatchkey: TypeAlias = str | DispatchKey # noqa: PYI042,PYI047 # int or SymInt -IntLikeType: TypeAlias = Union[int, SymInt] +IntLikeType: TypeAlias = int | SymInt # float or SymFloat -FloatLikeType: TypeAlias = Union[float, SymFloat] +FloatLikeType: TypeAlias = float | SymFloat # bool or SymBool -BoolLikeType: TypeAlias = Union[bool, SymBool] +BoolLikeType: TypeAlias = bool | SymBool py_sym_types = (SymInt, SymFloat, SymBool) # left un-annotated intentionally -PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool] +PySymType: TypeAlias = SymInt | SymFloat | SymBool # Meta-type for "numeric" things; matches our docs -Number: TypeAlias = Union[int, float, bool] +Number: TypeAlias = int | float | bool # tuple for isinstance(x, Number) checks. # FIXME: refactor once python 3.9 support is dropped. _Number = (int, float, bool) -FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]] +FileLike: TypeAlias = str | os.PathLike[str] | IO[bytes] # Meta-type for "device-like" things. Not to be confused with 'device' (a # literal device object). This nomenclature is consistent with PythonArgParser. # None means use the default device (typically CPU) -Device: TypeAlias = Union[_device, str, int, None] +Device: TypeAlias = _device | str | int | None # Storage protocol implemented by ${Type}StorageBase classes diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 93481a622494b..1dd8d6684f0e2 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -218,7 +218,7 @@ def set_device(device: _device_t) -> None: torch._C._xpu_setDevice(device) -def get_device_name(device: Optional[_device_t] = None) -> str: +def get_device_name(device: _device_t | None = None) -> str: r"""Get the name of a device. Args: @@ -234,7 +234,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str: @lru_cache(None) -def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: +def get_device_capability(device: _device_t | None = None) -> dict[str, Any]: r"""Get the xpu capability of a device. Args: @@ -259,7 +259,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: def get_device_properties( - device: Optional[_device_t] = None, + device: _device_t | None = None, ) -> _XpuDeviceProperties: # pyrefly: ignore # not-a-type r"""Get the properties of a device. @@ -281,7 +281,7 @@ def current_device() -> int: return torch._C._xpu_getDevice() -def _get_device(device: Union[int, str, torch.device]) -> torch.device: +def _get_device(device: int | str | torch.device) -> torch.device: r"""Return the torch.device type object from the passed in device. Args: @@ -395,7 +395,7 @@ def set_stream(stream: Stream) -> None: ) -def current_stream(device: Optional[_device_t] = None) -> Stream: +def current_stream(device: _device_t | None = None) -> Stream: r"""Return the currently selected :class:`Stream` for a given device. Args: @@ -413,9 +413,7 @@ def current_stream(device: Optional[_device_t] = None) -> Stream: ) -def get_stream_from_external( - data_ptr: int, device: Optional[_device_t] = None -) -> Stream: +def get_stream_from_external(data_ptr: int, device: _device_t | None = None) -> Stream: r"""Return a :class:`Stream` from an external SYCL queue. This function is used to wrap SYCL queue created in other libraries in order @@ -484,7 +482,7 @@ def _get_generator(device: torch.device) -> torch._C.Generator: def _set_rng_state_offset( - offset: int, device: Union[int, str, torch.device] = "xpu" + offset: int, device: int | str | torch.device = "xpu" ) -> None: r"""Set the random number generator state offset of the specified GPU. @@ -502,7 +500,7 @@ def cb() -> None: _lazy_call(cb) -def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: +def _get_rng_state_offset(device: int | str | torch.device = "xpu") -> int: r"""Return the random number generator state offset of the specified GPU. Args: diff --git a/torch/xpu/random.py b/torch/xpu/random.py index ec770225aef39..8b489e871f7c5 100644 --- a/torch/xpu/random.py +++ b/torch/xpu/random.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs from collections.abc import Iterable -from typing import Union import torch from torch import Tensor @@ -8,7 +7,7 @@ from . import _lazy_call, _lazy_init, current_device, device_count -def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor: +def get_rng_state(device: int | str | torch.device = "xpu") -> Tensor: r"""Return the random number generator state of the specified GPU as a ByteTensor. Args: @@ -36,9 +35,7 @@ def get_rng_state_all() -> list[Tensor]: return results -def set_rng_state( - new_state: Tensor, device: Union[int, str, torch.device] = "xpu" -) -> None: +def set_rng_state(new_state: Tensor, device: int | str | torch.device = "xpu") -> None: r"""Set the random number generator state of the specified GPU. Args: From c0cb6e78404416d418350632bfc554710a5f7281 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 3 Dec 2025 17:44:33 -0800 Subject: [PATCH 260/338] [DTensor] ExplicitRedistributionContext warning mode (#169452) Allow issuing warnings instead of raising exceptions. - easier to capture all issues in one pass Pull Request resolved: https://github.com/pytorch/pytorch/pull/169452 Approved by: https://github.com/malfet --- test/distributed/tensor/test_utils.py | 16 +++++++++++++ torch/distributed/tensor/_dispatch.py | 9 ++++---- torch/distributed/tensor/_utils.py | 33 +++++++++++++++++++++++---- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 5f3225d174cb2..871b8e19f4c41 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -1102,6 +1102,22 @@ def test_explicit_matmul(self): with ExplicitRedistributionContext(): with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"): torch.matmul(dx, dA) + with ExplicitRedistributionContext(mode="warn"): + with self.assertLogs( + torch.distributed.tensor._utils.logger, level="WARN" + ) as captured: + torch.matmul(dx, dA) + self.assertEqual(len(captured.output), 1) + self.assertRegex( + captured.output[0], + r"WARNING:.*Implicit redistribution occurred", + ) + # TODO enable this once fixing the issue that op_info.schema is None in some calls to + # redistribute_local_tensor + # self.assertRegex( + # captured.output[0], + # r".*aten\.mm\.default.*", + # ) # explicit redistribute allows manual redistribute with ExplicitRedistributionContext(): diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 56c9cb1a94783..4c05b52428198 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -459,14 +459,13 @@ def redistribute_local_args( if debug_mode is not None else contextlib.nullcontext() ) - if not ExplicitRedistributionContext.is_redistribute_allowed( + ExplicitRedistributionContext.observe_redistribution( arg_spec, # pyrefly: ignore [bad-argument-type] reshard_arg_spec, - ): - raise RuntimeError( - f"Implicit redistribution occurred for {op_info.schema} while ExplicitRedistributionContext was active" - ) + message=f"Implicit redistribution occurred for {op_info.schema} " + "while ExplicitRedistributionContext was active", + ) with redistribute_context: resharded_local_tensor = redistribute_local_tensor( local_tensor, diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index aa65dbc08529f..9dc9d188faf61 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,3 +1,4 @@ +import logging import threading from collections.abc import Sequence from typing import Any, cast, Optional @@ -19,6 +20,9 @@ ) +logger = logging.getLogger(__name__) + + class ExplicitRedistributionContext: """ Within this context manager, DTensor will refuse to perform implicit redistribution, @@ -29,22 +33,41 @@ class ExplicitRedistributionContext: may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-op during forward and perform a manual redistribution during backwards. + + enable (bool) if False, disables the context manager. Can be used nested inside an enabled region. + + strict (bool) if True, triggers on any redistribution. If False, only triggers on redistributions that perform + communication. + + mode (str) Determines what happens when ExplicitRedistributionContext triggers: + "raise": raises an exceptoin, "warn" issues a warning """ _local = threading.local() - def __init__(self, enable: bool = True, strict: bool = False): + def __init__(self, enable: bool = True, strict: bool = False, mode="raise"): self._enable = enable self._strict = strict + if mode not in ("raise", "warn"): + raise RuntimeError(f"Invalid mode {mode}") + self._raise_on_redistribution = mode == "raise" @classmethod - def is_redistribute_allowed(cls, src_spec: DTensorSpec, dst_spec: DTensorSpec): + def observe_redistribution( + cls, src_spec: DTensorSpec, dst_spec: DTensorSpec, message: str + ): if instance := getattr(cls._local, "_active", None): + allowed = True if instance._enable: if instance._strict: - return False - return redistribute_cost(src_spec, dst_spec) <= 0 - return True + allowed = False + else: + allowed = redistribute_cost(src_spec, dst_spec) <= 0 + if not allowed: + if instance._raise_on_redistribution: + raise RuntimeError(message) + else: + logger.warning(message) def __enter__(self): self._prev = getattr(ExplicitRedistributionContext._local, "_active", None) From 685ba6bc01170c0cb793f872e56164df89b63573 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 3 Dec 2025 22:07:11 -0800 Subject: [PATCH 261/338] add back legalize_graph for BC reason (#169541) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169541 Approved by: https://github.com/huydhn --- test/allowlist_for_publicAPI.json | 1 + torch/fx/passes/utils/fuser_utils.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index b6c203aea4ab6..bd6f29d37fbb3 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2090,6 +2090,7 @@ "SimpleQueue", "Tuple", "compatibility", + "legalize_graph", "stable_topological_sort", "lift_subgraph_as_module" ], diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index e5509187b39dd..0571c92f61b76 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -7,7 +7,12 @@ from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx.passes.tools_common import NodeList, NodeSet, stable_topological_sort +from torch.fx.passes.tools_common import ( # noqa: F401 + legalize_graph, + NodeList, + NodeSet, + stable_topological_sort, +) from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] From a2b5dfb956aed182f6aefce1ff2eda70c35049e1 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 3 Dec 2025 10:27:39 -0800 Subject: [PATCH 262/338] Restore ability to use global stableivalue from without template (#169475) https://github.com/pytorch/pytorch/pull/168155 was needed to fix Windows CI in torchaudio that looked like such
click for example of torchaudio windows CI error
``` 2025-11-15T21:11:03.9005985Z C:/actions-runner/_work/audio/audio/pytorch/audio/env/Lib/site-packages/torch/include\torch/csrc/stable/stableivalue_conversions.h(244): error: more than one instance of overloaded function "torch::stable::detail::from" matches the argument list: 2025-11-15T21:11:03.9007831Z function template "StableIValue from(T)" (declared at line 593) 2025-11-15T21:11:03.9008639Z function template "StableIValue torch::stable::detail::from(T)" (declared at line 528) 2025-11-15T21:11:03.9009336Z argument types are: (StableListHandle) 2025-11-15T21:11:03.9009839Z return from(new_list_handle); 2025-11-15T21:11:03.9010244Z ^ 2025-11-15T21:11:03.9011886Z C:/actions-runner/_work/audio/audio/pytorch/audio/env/Lib/site-packages/torch/include\torch/csrc/stable/stableivalue_conversions.h(541): note #3326-D: function "torch::stable::detail::from(const torch::stable::Tensor &)" does not match because argument #1 does not match parameter 2025-11-15T21:11:03.9013826Z [[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) { 2025-11-15T21:11:03.9014403Z ^ 2025-11-15T21:11:03.9016129Z C:/actions-runner/_work/audio/audio/pytorch/audio/env/Lib/site-packages/torch/include\torch/csrc/stable/stableivalue_conversions.h(534): note #3327-D: candidate function template "torch::stable::detail::from(const std::optional &)" failed deduction 2025-11-15T21:11:03.9017869Z inline StableIValue from(const std::optional& val) { 2025-11-15T21:11:03.9018335Z ^ 2025-11-15T21:11:03.9019885Z C:/actions-runner/_work/audio/audio/pytorch/audio/env/Lib/site-packages/torch/include\torch/csrc/stable/stableivalue_conversions.h(609): note #3326-D: function "from(const torch::stable::Tensor &)" does not match because argument #1 does not match parameter 2025-11-15T21:11:03.9021652Z from(const torch::stable::Tensor& val) { 2025-11-15T21:11:03.9022058Z ^ 2025-11-15T21:11:03.9023430Z C:/actions-runner/_work/audio/audio/pytorch/audio/env/Lib/site-packages/torch/include\torch/csrc/stable/stableivalue_conversions.h(601): note #3327-D: candidate function template "from(const std::optional &)" failed deduction 2025-11-15T21:11:03.9025327Z inline StableIValue from(const std::optional& val) { 2025-11-15T21:11:03.9025793Z ^ 2025-11-15T21:11:03.9026102Z detected during: 2025-11-15T21:11:03.9027321Z instantiation of "StableIValue torch::stable::detail::FromImpl>::call(const c10::HeaderOnlyArrayRef &, uint64_t, __nv_bool) [with T=int64_t]" at line 529 2025-11-15T21:11:03.9029527Z instantiation of "StableIValue torch::stable::detail::from(T) [with T=torch::headeronly::IntHeaderOnlyArrayRef]" at line 319 of C:/actions-runner/_work/audio/audio/pytorch/audio/env/Lib/site-packages/torch/include\torch/csrc/stable/ops.h 2025-11-15T21:11:03.9030992Z 2025-11-15T21:11:03.9031753Z 1 error detected in the compilation of "C:/actions-runner/_work/audio/audio/pytorch/audio/src/libtorchaudio/forced_align/gpu/compute.cu" ```
But this broke BC in that after that PR `from(...)` is no longer usable without template arguments, which makes the code in fa3 https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api_stable.cpp#L1797-L1800 no longer compilable in 2.10 We could update the code in FA3, but that might require ifdefs for 2.9 vs 2.10 -- as a general principle for stable extensions, I'm not sure whether updating the extension code or not breaking BC of the headers is what we should go with here. But I'm leaning towards the latter. This PR takes the alternative approach of restoring torchaudio Windows CI sanity by replacing all `{from/to}` in torch/csrc/stable/stableivalue_conversions.h with `torch::stable::detail::{from/to}` rather than making the `from`/`to` in the global namespace a function pointer Confirmed that audio CI passes https://github.com/pytorch/audio/pull/4133 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169475 Approved by: https://github.com/albanD --- .../libtorch_agnostic_2_9/csrc/kernel.cpp | 13 +- torch/csrc/stable/stableivalue_conversions.h | 155 +++++++++++------- 2 files changed, 104 insertions(+), 64 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp index 0304dfd8f0f4c..cf50a4d70e6d7 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp @@ -67,12 +67,23 @@ Tensor sgd_out_of_place( return out; } +void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor res = sgd_out_of_place( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1]), + float(torch::stable::detail::to(stack[2])), + torch::stable::detail::to(stack[3]), + torch::stable::detail::to(stack[4])); + + stack[0] = from(res); +} + STABLE_TORCH_LIBRARY(libtorch_agnostic_2_9, m) { m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CPU, m) { - m.impl("sgd_out_of_place", TORCH_BOX(&sgd_out_of_place)); + m.impl("sgd_out_of_place", &boxed_sgd_out_of_place); } Tensor identity(Tensor t) { diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index c44e656d88e11..708139836411a 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -111,45 +111,45 @@ struct FromImpl { [[maybe_unused]] bool is_internal) { switch (val) { case ScalarType::Byte: - return from(aoti_torch_dtype_uint8()); + return torch::stable::detail::from(aoti_torch_dtype_uint8()); case ScalarType::Char: - return from(aoti_torch_dtype_int8()); + return torch::stable::detail::from(aoti_torch_dtype_int8()); case ScalarType::Short: - return from(aoti_torch_dtype_int16()); + return torch::stable::detail::from(aoti_torch_dtype_int16()); case ScalarType::Int: - return from(aoti_torch_dtype_int32()); + return torch::stable::detail::from(aoti_torch_dtype_int32()); case ScalarType::Long: - return from(aoti_torch_dtype_int64()); + return torch::stable::detail::from(aoti_torch_dtype_int64()); case ScalarType::Half: - return from(aoti_torch_dtype_float16()); + return torch::stable::detail::from(aoti_torch_dtype_float16()); case ScalarType::Float: - return from(aoti_torch_dtype_float32()); + return torch::stable::detail::from(aoti_torch_dtype_float32()); case ScalarType::Double: - return from(aoti_torch_dtype_float64()); + return torch::stable::detail::from(aoti_torch_dtype_float64()); case ScalarType::ComplexHalf: - return from(aoti_torch_dtype_complex32()); + return torch::stable::detail::from(aoti_torch_dtype_complex32()); case ScalarType::ComplexFloat: - return from(aoti_torch_dtype_complex64()); + return torch::stable::detail::from(aoti_torch_dtype_complex64()); case ScalarType::ComplexDouble: - return from(aoti_torch_dtype_complex128()); + return torch::stable::detail::from(aoti_torch_dtype_complex128()); case ScalarType::Bool: - return from(aoti_torch_dtype_bool()); + return torch::stable::detail::from(aoti_torch_dtype_bool()); case ScalarType::BFloat16: - return from(aoti_torch_dtype_bfloat16()); + return torch::stable::detail::from(aoti_torch_dtype_bfloat16()); case ScalarType::Float8_e5m2: - return from(aoti_torch_dtype_float8_e5m2()); + return torch::stable::detail::from(aoti_torch_dtype_float8_e5m2()); case ScalarType::Float8_e4m3fn: - return from(aoti_torch_dtype_float8_e4m3fn()); + return torch::stable::detail::from(aoti_torch_dtype_float8_e4m3fn()); case ScalarType::Float8_e5m2fnuz: - return from(aoti_torch_dtype_float8_e5m2fnuz()); + return torch::stable::detail::from(aoti_torch_dtype_float8_e5m2fnuz()); case ScalarType::Float8_e4m3fnuz: - return from(aoti_torch_dtype_float8_e4m3fnuz()); + return torch::stable::detail::from(aoti_torch_dtype_float8_e4m3fnuz()); case ScalarType::UInt16: - return from(aoti_torch_dtype_uint16()); + return torch::stable::detail::from(aoti_torch_dtype_uint16()); case ScalarType::UInt32: - return from(aoti_torch_dtype_uint32()); + return torch::stable::detail::from(aoti_torch_dtype_uint32()); case ScalarType::UInt64: - return from(aoti_torch_dtype_uint64()); + return torch::stable::detail::from(aoti_torch_dtype_uint64()); default: STD_TORCH_CHECK( false, @@ -182,17 +182,18 @@ struct FromImpl { [[maybe_unused]] bool is_internal) { switch (val) { case DeviceType::CPU: - return from(aoti_torch_device_type_cpu()); + return torch::stable::detail::from(aoti_torch_device_type_cpu()); case DeviceType::CUDA: - return from(aoti_torch_device_type_cuda()); + return torch::stable::detail::from(aoti_torch_device_type_cuda()); case DeviceType::Meta: - return from(aoti_torch_device_type_meta()); + return torch::stable::detail::from(aoti_torch_device_type_meta()); case DeviceType::XPU: - return from(aoti_torch_device_type_xpu()); + return torch::stable::detail::from(aoti_torch_device_type_xpu()); case DeviceType::MPS: - return from(aoti_torch_device_type_mps()); + return torch::stable::detail::from(aoti_torch_device_type_mps()); case DeviceType::PrivateUse1: - return from(aoti_torch_device_type_privateuse1()); + return torch::stable::detail::from( + aoti_torch_device_type_privateuse1()); default: STD_TORCH_CHECK( false, @@ -208,7 +209,7 @@ struct FromImpl { std::nullopt_t val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - return from(nullptr); + return torch::stable::detail::from(nullptr); } }; @@ -248,10 +249,11 @@ struct FromImpl> { uint64_t extension_build_version, bool is_internal) { if (!val.has_value()) { - return from(std::nullopt); + return torch::stable::detail::from(std::nullopt); } - return from(new StableIValue(detail::FromImpl::call( - val.value(), extension_build_version, is_internal))); + return torch::stable::detail::from( + new StableIValue(detail::FromImpl::call( + val.value(), extension_build_version, is_internal))); } }; @@ -265,7 +267,7 @@ struct FromImpl { [[maybe_unused]] bool is_internal) { AtenTensorHandle new_ath; TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath)); - return from(new_ath); + return torch::stable::detail::from(new_ath); } }; @@ -286,21 +288,21 @@ struct FromImpl { [[maybe_unused]] bool is_internal) { switch (val) { case Layout::Strided: - return from(aoti_torch_layout_strided()); + return torch::stable::detail::from(aoti_torch_layout_strided()); case Layout::Sparse: - return from(aoti_torch_layout_sparse_coo()); + return torch::stable::detail::from(aoti_torch_layout_sparse_coo()); case Layout::SparseCsr: - return from(aoti_torch_layout_sparse_csr()); + return torch::stable::detail::from(aoti_torch_layout_sparse_csr()); case Layout::SparseCsc: - return from(aoti_torch_layout_sparse_csc()); + return torch::stable::detail::from(aoti_torch_layout_sparse_csc()); case Layout::SparseBsr: - return from(aoti_torch_layout_sparse_bsr()); + return torch::stable::detail::from(aoti_torch_layout_sparse_bsr()); case Layout::SparseBsc: - return from(aoti_torch_layout_sparse_bsc()); + return torch::stable::detail::from(aoti_torch_layout_sparse_bsc()); case Layout::Mkldnn: - return from(aoti_torch_layout__mkldnn()); + return torch::stable::detail::from(aoti_torch_layout__mkldnn()); case Layout::Jagged: - return from(aoti_torch_layout_jagged()); + return torch::stable::detail::from(aoti_torch_layout_jagged()); default: STD_TORCH_CHECK( false, @@ -321,13 +323,17 @@ struct FromImpl { [[maybe_unused]] bool is_internal) { switch (val) { case MemoryFormat::Contiguous: - return from(aoti_torch_memory_format_contiguous_format()); + return torch::stable::detail::from( + aoti_torch_memory_format_contiguous_format()); case MemoryFormat::Preserve: - return from(aoti_torch_memory_format_preserve_format()); + return torch::stable::detail::from( + aoti_torch_memory_format_preserve_format()); case MemoryFormat::ChannelsLast: - return from(aoti_torch_memory_format_channels_last()); + return torch::stable::detail::from( + aoti_torch_memory_format_channels_last()); case MemoryFormat::ChannelsLast3d: - return from(aoti_torch_memory_format_channels_last_3d()); + return torch::stable::detail::from( + aoti_torch_memory_format_channels_last_3d()); default: STD_TORCH_CHECK( false, @@ -349,10 +355,10 @@ struct FromImpl> { TORCH_ERROR_CODE_CHECK( torch_new_list_reserve_size(val.size(), &new_list_handle)); for (const auto& elem : val) { - TORCH_ERROR_CODE_CHECK( - torch_list_push_back(new_list_handle, from(elem))); + TORCH_ERROR_CODE_CHECK(torch_list_push_back( + new_list_handle, torch::stable::detail::from(elem))); } - return from(new_list_handle); + return torch::stable::detail::from(new_list_handle); } catch (const std::runtime_error&) { if (new_list_handle != nullptr) { // clean up memory if an error was thrown @@ -372,7 +378,8 @@ struct FromImpl> { const std::vector& val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - return from>(val); + return torch::stable::detail::from< + torch::headeronly::HeaderOnlyArrayRef>(val); } }; @@ -388,7 +395,7 @@ struct FromImpl { [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { // Convert DeviceType to shim representation (int32_t) - StableIValue device_type_shim = from(val.type()); + StableIValue device_type_shim = torch::stable::detail::from(val.type()); // Pack: lower 32 bits = device index, upper 32 bits = device type (shim) uint64_t device_index_bits = static_cast(static_cast(val.index())); @@ -409,7 +416,7 @@ struct FromImpl { StringHandle handle; TORCH_ERROR_CODE_CHECK( torch_new_string_handle(val.c_str(), val.length(), &handle)) - return from(handle); + return torch::stable::detail::from(handle); } }; @@ -478,7 +485,7 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - int32_t shim_scalartype = to(val); + int32_t shim_scalartype = torch::stable::detail::to(val); if (shim_scalartype == aoti_torch_dtype_uint8()) { return ScalarType::Byte; } else if (shim_scalartype == aoti_torch_dtype_int8()) { @@ -537,7 +544,7 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - int32_t shim_devicetype = to(val); + int32_t shim_devicetype = torch::stable::detail::to(val); if (shim_devicetype == aoti_torch_device_type_cpu()) { return DeviceType::CPU; } else if (shim_devicetype == aoti_torch_device_type_cuda()) { @@ -581,7 +588,7 @@ struct ToImpl> { StableIValue val, uint64_t extension_build_version, bool is_internal) { - auto sivp = to(val); + auto sivp = torch::stable::detail::to(val); // sivp is either nullptr or a pointer to a StableIValue if (sivp == nullptr) { @@ -606,7 +613,8 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - return torch::stable::Tensor(to(val)); + return torch::stable::Tensor( + torch::stable::detail::to(val)); } }; @@ -622,7 +630,7 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - int32_t shim_layout = to(val); + int32_t shim_layout = torch::stable::detail::to(val); if (shim_layout == aoti_torch_layout_strided()) { return Layout::Strided; } else if (shim_layout == aoti_torch_layout_sparse_coo()) { @@ -656,7 +664,7 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - int32_t shim_memory_format = to(val); + int32_t shim_memory_format = torch::stable::detail::to(val); if (shim_memory_format == aoti_torch_memory_format_contiguous_format()) { return MemoryFormat::Contiguous; } else if ( @@ -688,7 +696,7 @@ struct ToImpl> { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - auto list_handle = to(val); + auto list_handle = torch::stable::detail::to(val); size_t size; try { TORCH_ERROR_CODE_CHECK(torch_list_size(list_handle, &size)); @@ -697,7 +705,7 @@ struct ToImpl> { for (size_t i = 0; i < size; i++) { StableIValue element; TORCH_ERROR_CODE_CHECK(torch_list_get_item(list_handle, i, &element)); - result.push_back(to(element)); + result.push_back(torch::stable::detail::to(element)); } TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); return result; @@ -722,7 +730,8 @@ struct ToImpl { // Unpack: lower 32 bits = device index, upper 32 bits = device type (shim) int32_t device_index = static_cast(val & 0xFFFFFFFF); StableIValue device_type_shim = (val >> 32) & 0xFFFFFFFF; - DeviceType device_type = to(device_type_shim); + DeviceType device_type = + torch::stable::detail::to(device_type_shim); return torch::stable::Device(device_type, device_index); } }; @@ -735,7 +744,7 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - StringHandle handle = to(val); + StringHandle handle = torch::stable::detail::to(val); size_t length; TORCH_ERROR_CODE_CHECK(torch_string_length(handle, &length)); const char* data; @@ -822,11 +831,31 @@ HIDDEN_NAMESPACE_END(torch, stable, detail) // WARNING! Will be removed. Only exists for BC. See [global from/to deprecation // note] template -C10_DEPRECATED_MESSAGE("Use torch::stable::detail::from instead.") -auto from = &torch::stable::detail::from; +[[deprecated("Use torch::stable::detail::from instead.")]] +inline StableIValue from(T val) { + return torch::stable::detail::from(val); +} + +// WARNING! Will be removed. Only exists for BC. See [global from/to deprecation +// note] +template +[[deprecated("Use torch::stable::detail::from instead.")]] +inline StableIValue from(const std::optional& val) { + return torch::stable::detail::from(val); +} + +// WARNING! Will be removed. Only exists for BC. See [global from/to deprecation +// note] +[[deprecated( + "Use torch::stable::detail::from instead.")]] [[maybe_unused]] inline StableIValue +from(const torch::stable::Tensor& val) { + return torch::stable::detail::from(val); +} // WARNING! Will be removed. Only exists for BC. See [global from/to deprecation // note] template -C10_DEPRECATED_MESSAGE("Use torch::stable::detail::to instead.") -auto to = &torch::stable::detail::to; +[[deprecated("Use torch::stable::detail::to instead.")]] +inline T to(StableIValue val) { + return torch::stable::detail::to(val); +} From f53e14c73469f20725b1288eae370a2e3f3a18aa Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Thu, 4 Dec 2025 09:03:21 +0000 Subject: [PATCH 263/338] [ci] Update typing-extensions for python 3.14 (#169515) This fixes `test/inductor/test_caching.py::ContextTest::test_select_runtime_context_matches_forms_of_context` in python 3.14. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169515 Approved by: https://github.com/williamwen42 --- .ci/docker/requirements-ci.txt | 3 ++- pyproject.toml | 2 +- requirements.txt | 2 +- test/inductor/test_caching.py | 3 --- torch/testing/_internal/common_utils.py | 4 ---- 5 files changed, 4 insertions(+), 10 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index a32161cae6a34..15f8bde53f9bc 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -270,7 +270,8 @@ scipy==1.16.2 ; python_version >= "3.14" #test that import: # needed by torchgen utils -typing-extensions==4.12.2 +typing-extensions==4.12.2 ; python_version < "3.14" +typing-extensions==4.15.0 ; python_version >= "3.14" #Description: type hints for python #Pinned versions: #test that import: diff --git a/pyproject.toml b/pyproject.toml index dfc622650f5e7..6474ddd8f5027 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dev = [ "optree>=0.13.0", "psutil", "sympy>=1.13.3", - "typing-extensions>=4.13.2", + "typing-extensions>=4.15.0", "wheel", ] diff --git a/requirements.txt b/requirements.txt index e9b5d4482bc5c..8cc2f17fac395 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,5 +16,5 @@ optree>=0.13.0 psutil spin sympy>=1.13.3 -typing-extensions>=4.13.2 +typing-extensions>=4.15.0 wheel diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index 17527ffb79c1d..aa4c3a1f229f1 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -33,7 +33,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - xfailIfPy314Plus, ) @@ -375,7 +374,6 @@ def test_isolation_key_is_repeatable(self) -> None: """ self.assertEqual(context._isolation_key(), context._isolation_key()) - @xfailIfPy314Plus def test_select_runtime_context_matches_forms_of_context(self) -> None: """ Tests that the selected runtime context matches the forms of context. @@ -389,7 +387,6 @@ def test_select_runtime_context_matches_forms_of_context(self) -> None: set(context._RuntimeContext.forms_of_context()), ) - @xfailIfPy314Plus def test_select_compile_context_matches_forms_of_context(self) -> None: """ Tests that the selected compile context matches the forms of context. diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index b6904fd760982..df3ca03b76242 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1709,10 +1709,6 @@ def xfailIfPy312Plus(func): return unittest.expectedFailure(func) if sys.version_info >= (3, 12) else func -def xfailIfPy314Plus(func): - return unittest.expectedFailure(func) if sys.version_info >= (3, 14) else func - - def xfailIfLinux(func): return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM and not IS_FBCODE else func From 2da3bafb30b91de44fba1b9aecce1147cb64e679 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Thu, 4 Dec 2025 13:36:34 +0000 Subject: [PATCH 264/338] [Inductor] ReLU/GELU(Addmm) fusions (#168157) Similar to https://github.com/pytorch/pytorch/pull/158137 (thank you, @AaronWang04, for the instructional tips and answering my questions!), but performs `Activation(Addmm) -> _addmm_activation` replacement instead of `Activation(add(mm)) -> _addmm_activation`. The reasons as to why this mapping over the one in https://github.com/pytorch/pytorch/pull/158137 are: - Prior work done to extend cuBLASLt coverage in `addmm` beyond just 1D bias and `beta=1, alpha=1`. As long as there is an activation after `addmm`, we can call Lt. This makes the check for pattern replacement leaner and agnostic to the inputs' meta-data (`addmm`'s checks for free). - Inductor intercepts `addmm` and replaces it with `alpha * [alpha != 1] * m1 @ m2 + beta * [beta != 1] * input` when followed by point-wise consumers (including activation functions). So it is way easier and cleaner to intercept just `addmm` (and not combinatorial set of patterns) before such replacements. Re-run of the benchmark script in https://github.com/pytorch/pytorch/pull/158137 on H100 yields: `float16`: ``` ============================================================ Testing with M=1024, N=1024, K=1024, dtype=float16 ============================================================ Average Time per Iteration (cublas): 0.0096 ms Average Time per Iteration (torch compile): 0.0407 ms ============================================================ Testing with M=2048, N=2048, K=2048, dtype=float16 ============================================================ Average Time per Iteration (cublas): 0.0270 ms Average Time per Iteration (torch compile): 0.0409 ms ============================================================ Testing with M=4096, N=4096, K=4096, dtype=float16 ============================================================ Average Time per Iteration (cublas): 0.1828 ms Average Time per Iteration (torch compile): 0.2415 ms ============================================================ Testing with M=8192, N=8192, K=8192, dtype=float16 ============================================================ Average Time per Iteration (cublas): 1.5971 ms Average Time per Iteration (torch compile): 1.9723 ms ``` `bfloat16`: ``` ============================================================ Testing with M=1024, N=1024, K=1024, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 0.0093 ms Average Time per Iteration (torch compile): 0.0416 m ============================================================ Testing with M=2048, N=2048, K=2048, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 0.0264 ms Average Time per Iteration (torch compile): 0.0411 ms ============================================================ Testing with M=4096, N=4096, K=4096, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 0.1768 ms Average Time per Iteration (torch compile): 0.2430 ms ============================================================ Testing with M=8192, N=8192, K=8192, dtype=bfloat16 ============================================================ Average Time per Iteration (cublas): 1.5564 ms Average Time per Iteration (torch compile): 1.8916 ms ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168157 Approved by: https://github.com/eellison, https://github.com/eqy --- test/inductor/test_pattern_matcher.py | 68 ++++++++++++++++++- torch/_inductor/fx_passes/post_grad.py | 63 +++++++++++++++++ .../serialized_patterns/addmm_gelu_pattern.py | 43 ++++++++++++ .../serialized_patterns/addmm_relu_pattern.py | 35 ++++++++++ torchgen/fuse/gen_patterns.py | 3 +- 5 files changed, 209 insertions(+), 3 deletions(-) create mode 100644 torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py create mode 100644 torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 9928b89b81e64..f7e795f53f90d 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1216,6 +1216,70 @@ def fn2(inp, a, b): _, (code) = run_and_get_code(fn2, args[0], args[1], args[2]) FileCheck().check_not("extern_kernels.addmm(").run(code[0]) + @skipIfRocm + def test_addmm_activation_fusion(self): + """ + Test whether Activation(Addmm) implies _addmm_activation + """ + + b = torch.rand(4, device=GPU_TYPE) + m1 = torch.rand(3, 2, device=GPU_TYPE) + m2 = torch.rand(2, 4, device=GPU_TYPE) + alphas = ({"alpha": 0.8}, {}) # **{} -> alpha=1 + betas = ({"beta": 1}, {}) # **{} -> beta=1 + + # Cases Activation(Addmm) -> _addmm_activation + fusable_activations = ( + torch.nn.functional.relu, + # NOTE: only approximate="tanh" is fusable + lambda *args, **kwargs: torch.nn.functional.gelu( + *args, approximate="tanh", **kwargs + ), + ) + for activation in fusable_activations: + + def f(b, m1, m2, beta, alpha): + return activation(torch.addmm(b, m1, m2, **beta, **alpha)) + + fc = torch.compile(f) + + for beta, alpha in itertools.product(betas, alphas): + expected = f(b, m1, m2, beta, alpha) + actual = fc(b, m1, m2, beta, alpha) + torch.testing.assert_close(expected, actual) + + _, (code) = run_and_get_code(fc, b, m1, m2, beta, alpha) + self.assertIn("_addmm_activation", code[0]) + + # Check no disruptions in the gemm autotune process + _, (code) = run_and_get_code( + torch.compile(f, options={"max_autotune_gemm": True}), + b, + m1, + m2, + beta, + alpha, + ) + self.assertNotIn("_addmm_activation", code[0]) + + # Cases Activation(Addmm) -> Activation(Addmm) + non_fusable_activations = ( + torch.nn.functional.gelu, # implies approximate="none" + lambda *args, **kwargs: torch.nn.functional.gelu( + *args, approximate="none", **kwargs + ), + ) + for activation in non_fusable_activations: + + def f(b, m1, m2, beta, alpha): + return activation(torch.addmm(b, m1, m2, **beta, **alpha)) + + fc = torch.compile(f) + + for beta, alpha in itertools.product(betas, alphas): + _, (code) = run_and_get_code(fc, b, m1, m2, beta, alpha) + self.assertNotIn("_addmm_activation", code[0]) + def test_addmm_alpha_beta_with_pointwise(self): # Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops # See https://github.com/pytorch/pytorch/issues/167313 @@ -1224,7 +1288,7 @@ def test_addmm_alpha_beta_with_pointwise(self): b = torch.rand(3, 2, device=GPU_TYPE) def f(x, a, b): - return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2)) + return torch.abs(torch.addmm(x, a, b, alpha=0.8, beta=0.2)) fc = torch.compile(f) @@ -1241,7 +1305,7 @@ def f(x, a, b): # Test with alpha=1, beta=1 (default) - should also unfuse def f_default(x, a, b): - return torch.nn.functional.relu(torch.addmm(x, a, b)) + return torch.abs(torch.addmm(x, a, b)) fc_default = torch.compile(f_default) expected_default = f_default(x, a, b) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index a21e78821e52b..2f73bf6ae86c4 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -34,6 +34,7 @@ CallFunctionVarArgs, filter_nodes, fwd_only, + gen_register_replacement, get_arg_value, get_mutation_region_id, Ignored, @@ -688,6 +689,66 @@ def body_fn(*flat_args): raise AssertionError("scan is not lowered to while_loop") +@functools.cache +def register_addmm_activation_fusions(): + def is_valid_addmm_activation_fusion(match: Match) -> bool: + # Exclude ROCm + if torch.version.hip: + return False + + if config.max_autotune_gemm: + return False + + inp = match.kwargs["inp"].meta["val"] + + if not inp.is_cuda: + return False + + output = match.output_node() + return not all( + is_pointwise_use(use, lambda target: torch.Tag.reduction in target.tags) + for use in output.users + ) + + args = [torch.empty(3), torch.empty(4, 2), torch.empty(2, 3)] + beta_alpha_workaround = {"beta": 1.3, "alpha": 1.2} + + def addmm_relu_pattern(inp, m1, m2, beta, alpha): + return aten.relu(aten.addmm(inp, m1, m2, beta=beta, alpha=alpha)) + + def addmm_gelu_pattern(inp, m1, m2, beta, alpha): + return aten.gelu( + aten.addmm(inp, m1, m2, beta=beta, alpha=alpha), approximate="tanh" + ) + + def addmm_relu_replacement(inp, m1, m2, beta, alpha): + return aten._addmm_activation(inp, m1, m2, beta=beta, alpha=alpha) + + def addmm_gelu_replacement(inp, m1, m2, beta, alpha): + return aten._addmm_activation( + inp, m1, m2, beta=beta, alpha=alpha, use_gelu=True + ) + + patterns = (addmm_relu_pattern, addmm_gelu_pattern) + replacements = (addmm_relu_replacement, addmm_gelu_replacement) + for pattern, replacement in zip(patterns, replacements): + key = f"{pattern.__name__}" + gen_register_replacement( + key, + # pyrefly: ignore [bad-argument-type] + pattern, + # pyrefly: ignore [bad-argument-type] + replacement, + args, + # pyrefly: ignore [bad-argument-type] + trace_fn=fwd_only, + # pyrefly: ignore [bad-argument-type] + pass_dicts=pass_patterns[1], + extra_check=is_valid_addmm_activation_fusion, + scalar_workaround=beta_alpha_workaround, + ) + + @init_once_fakemode def lazy_init(): if torch._C._has_mkldnn: @@ -713,6 +774,8 @@ def lazy_init(): extra_check=prepare_softmax_extra_check, ) + register_addmm_activation_fusions() + def reorder_for_locality(graph: torch.fx.Graph): if torch.distributed.is_available(): diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py new file mode 100644 index 0000000000000..f991015b4de69 --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py @@ -0,0 +1,43 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('inp'), KeywordArg('m1'), KeywordArg('m2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=4) +mul_Tensor = CallFunction(aten.mul.Tensor, addmm_default, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, addmm_default, addmm_default) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, addmm_default) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, addmm_default, mul_Tensor_3) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor, Ignored()) +tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4) +add_Tensor_1 = CallFunction(aten.add.Tensor, tanh_default, Ignored()) +addmm_gelu_pattern = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_1, _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py new file mode 100644 index 0000000000000..e9729a7787131 --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py @@ -0,0 +1,35 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('inp'), KeywordArg('m1'), KeywordArg('m2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) +addmm_relu_pattern = CallFunction(aten.relu.default, addmm_default, _users=0) diff --git a/torchgen/fuse/gen_patterns.py b/torchgen/fuse/gen_patterns.py index 0861c882e3fff..b4bdf022202ba 100644 --- a/torchgen/fuse/gen_patterns.py +++ b/torchgen/fuse/gen_patterns.py @@ -2,7 +2,7 @@ import os from torch._inductor import pattern_matcher -from torch._inductor.fx_passes import joint_graph +from torch._inductor.fx_passes import joint_graph, post_grad if __name__ == "__main__": @@ -17,3 +17,4 @@ # to serialize the patterns as it goes. os.environ["PYTORCH_GEN_PATTERNS"] = "1" joint_graph.lazy_init() + post_grad.lazy_init() From da2e3c472b60451e098e99c564af7cdab1f40add Mon Sep 17 00:00:00 2001 From: zhudada Date: Thu, 4 Dec 2025 14:22:20 +0000 Subject: [PATCH 265/338] Strengthen the implementation of OpenReg Stream (#166115) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR redesigns the implementation of OpenRegStream, enabling it to support default streams and normal stream pools with multiple priorities. It also adds robustness checks such as device ID validation and supplements relevant test cases for stream context management. **Changes and Reasons** 1. OpenReg runtime (OpenRegStream.cpp) 1)Redesigned the ID of OpenReg Stream. In the new implementation: - StreamIdType=6 indicates the default stream, and StreamIdType=7 represents the external stream. Values 0-5 are reserved as priority codes for normal stream pools. OpenReg currently supports stream pools with priorities 0-1. This design makes the priority of stream pools more intuitive and easier to understand. - Modified StreamIdType from an enum to a regular class to support multi-priority scenarios. - Updated the implementation of OpenRegStream::stream() for retrieving orStream_t. - Adjusted the code in the original implementation that restricted stream pools to single-priority, enabling multi-priority support. 2)Enhanced robustness: added device validation, DeviceGuard initialization. 2. Tests (test_streams.py) Added assertions for "original stream restoration" in context exception paths and stream switching checks, covering key code paths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166115 Approved by: https://github.com/fffrog --- .../csrc/runtime/OpenRegStream.cpp | 172 ++++++++++-------- .../csrc/runtime/OpenRegStream.h | 3 +- .../torch_openreg/tests/test_streams.py | 16 +- .../third_party/openreg/csrc/stream.cpp | 5 +- .../openreg/tests/stream_tests.cpp | 2 +- 5 files changed, 119 insertions(+), 79 deletions(-) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.cpp index 4821f416ce749..7dca21eada6ae 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.cpp @@ -14,12 +14,13 @@ namespace c10::openreg { namespace { // Global stream state and constants -static c10::once_flag init_flag; +c10::once_flag init_flag; +DeviceIndex num_devices = -1; +constexpr int kStreamsPerPoolBits = 5; +constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; +constexpr int kStreamTypeBits = 3; -static DeviceIndex num_devices = -1; -static constexpr int kStreamsPerPoolBits = 5; -static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; -static constexpr int kStreamTypeBits = 2; +int max_stream_priorities; /* * The stream pools are lazily initialized when the first queue is requested @@ -27,30 +28,33 @@ static constexpr int kStreamTypeBits = 2; * a queue is requested, the next queue in the pool to be returned in a * round-robin fashion, see Note [Stream Management]. */ -static std::deque device_flags; -static std::vector device_flags; +std::vector, c10::openreg::max_compile_time_stream_priorities>> streams; -static std::deque< +std::deque< std::array, max_compile_time_stream_priorities>> priority_counters; -static thread_local std::unique_ptr current_streams = nullptr; +thread_local std::unique_ptr current_streams = nullptr; /* * Note [StreamId assignment] * ~~~~~~~~~~~~~~~~~~~~~~~~~~ * How do we assign stream IDs? * - * -- 56 bits -- -- 5 bits -- -- 2 bits -- -- 1 bit -- - * zeros StreamIdIndex StreamIdType Ext/native stream + * -- 55 bits -- -- 5 bits -- -- 3 bits -- -- 1 bit -- + * zeros StreamIdIndex StreamIdType Ext/Native stream * ignored for ext ignored for ext * - * Where StreamIdType: - * 00 = default stream - * 01 = normal stream - * 11 = external stream + * StreamIdType: + * 000 = normal stream + * 001 = high stream + * 110 = default stream + * 111 = external stream + + * The range 000 to 101 is reserved for stream pools of different priorities and can be expanded as needed. (OpenReg currently supports two priorities: 0 and 1) * * For external stream, StreamID is a orStream_t pointer. This means that last * bit will always be 0. So when constructing StreamId for a native stream we @@ -60,95 +64,104 @@ static thread_local std::unique_ptr current_streams = nullptr; * We rely on StreamIdIndex and StreamIdType being non-negative; */ using StreamIdIndex = uint8_t; -enum class StreamIdType : uint8_t { - DEFAULT = 0x0, - NORMAL = 0x1, - EXT = 0x3, +class StreamIdType { + private: + uint8_t stream_type; + + public: + static const uint8_t DEFAULT = 0x6; + static const uint8_t EXT = 0x7; + + public: + StreamIdType(const uint8_t _stream_type) : stream_type(_stream_type) {} + + bool isExt() const { + return EXT == stream_type; + } + + bool isDefault() const { + return DEFAULT == stream_type; + } + + uint8_t getStreamType() const { + return stream_type; + } }; inline std::ostream& operator<<(std::ostream& stream, StreamIdType s) { - switch (s) { + switch (s.getStreamType()) { case StreamIdType::DEFAULT: return stream << "DEFAULT"; - case StreamIdType::NORMAL: - return stream << "NORMAL"; case StreamIdType::EXT: return stream << "EXT"; default: - break; + return stream << "PRIORITY" << static_cast(s.getStreamType()); } - - return stream << static_cast(s); } -static inline StreamIdType streamIdType(StreamId s) { - // Externally allocated streams have their id being the orStream_ptr - // so the last bit will be 0 +inline StreamIdType streamIdType(StreamId s) { if (!(s & 1)) { return StreamIdType(StreamIdType::EXT); } - int mask_for_type = (1 << kStreamTypeBits) - 1; - auto st = static_cast((s >> 1) & mask_for_type); + auto st = (s >> 1) & mask_for_type; TORCH_CHECK( - st == StreamIdType::DEFAULT || st == StreamIdType::NORMAL, - "invalid StreamId: ", - s); + st == StreamIdType::DEFAULT || (st >= 0 && st < max_stream_priorities), + "invalid StreamIdType: ", + st); return st; } -static inline size_t streamIdIndex(StreamId s) { +inline size_t streamIdIndex(StreamId s) { return static_cast( (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1)); } StreamId makeStreamId(StreamIdType st, size_t si) { - if (st == StreamIdType::EXT) { - return static_cast(0); - } - return (static_cast(si) << (kStreamTypeBits + 1)) | - (static_cast(st) << 1) | 1; + (static_cast(st.getStreamType()) << 1) | 1; } -static void initGlobalStreamState() { +void initGlobalStreamState() { num_devices = device_count(); device_flags.resize(num_devices); streams.resize(num_devices); priority_counters.resize(num_devices); + int leastPriority = -1, greatestPriority = -1; + OPENREG_CHECK( + orDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority)); + auto range = greatestPriority - leastPriority + 1; + max_stream_priorities = range >= c10::openreg::max_compile_time_stream_priorities + ? c10::openreg::max_compile_time_stream_priorities + : range; } -static void initSingleDeviceStream( - int priority, - DeviceIndex device_index, - int i) { +void initSingleDeviceStream(int priority, DeviceIndex device_index, int i) { auto& stream = streams[device_index][priority][i]; - OPENREG_CHECK(orStreamCreateWithPriority(&stream, 0, priority)); priority_counters[device_index][priority] = 0; } + // Creates stream pools for the specified device. It should be call only once. -static void initDeviceStreamState(DeviceIndex device_index) { +void initDeviceStreamState(DeviceIndex device_index) { + DeviceGuard device_guard{Device(DeviceType::PrivateUse1, device_index)}; for (const auto i : c10::irange(kStreamsPerPool)) { - for (const auto p : c10::irange(max_compile_time_stream_priorities)) { + for (const auto p : c10::irange(max_stream_priorities)) { initSingleDeviceStream(p, device_index, i); } } } -static void initOpenRegStreamsOnce() { +void initOpenRegStreamsOnce() { c10::call_once(init_flag, initGlobalStreamState); - for (const auto i : c10::irange(num_devices)) { c10::call_once( device_flags[i], initDeviceStreamState, static_cast(i)); } - if (current_streams) { return; } - // Inits current streams (thread local) to the last queue in the "normal // priority" queue pool. Note: the queue pool have not been initialized yet. // It will be initialized in initDeviceStreamState for the specified device. @@ -158,9 +171,19 @@ static void initOpenRegStreamsOnce() { } } -static uint32_t get_idx(std::atomic& counter) { - auto raw_idx = counter++; - return raw_idx % kStreamsPerPool; +inline void check_device(DeviceIndex device_index) { + TORCH_CHECK( + device_index >= 0 && device_index < num_devices, + "Device index value ", + static_cast(device_index), + " is out of index range [0, ", + static_cast(num_devices), + ")"); +} + +uint32_t get_idx(std::atomic& counter) { + auto raw = counter++; + return raw % kStreamsPerPool; } OpenRegStream OpenRegStreamForId(DeviceIndex device_index, StreamId stream_id) { @@ -180,22 +203,24 @@ orStream_t OpenRegStream::stream() const { StreamId stream_id = stream_.id(); StreamIdType st = streamIdType(stream_id); size_t si = streamIdIndex(stream_id); - switch (st) { - // The index 0 stream is default as well. - case StreamIdType::DEFAULT: - case StreamIdType::NORMAL: - return streams[device_index][static_cast(st)][si]; - case StreamIdType::EXT: - return reinterpret_cast(stream_id); - default: - TORCH_CHECK( - false, - "Unrecognized stream ", - stream_, - " (I didn't recognize the stream type, ", - st, - ").", - " Did you manufacture the StreamId yourself? Don't do that;"); + // OpenReg does not support a default stream natively. + // Here, we designate stream 0 from the priority 0 stream pool to serve as the default stream. + if(st.isDefault()){ + return streams[device_index][0][0]; + }else if(st.isExt()){ + return reinterpret_cast(stream_id); + }else{ + auto streamType = st.getStreamType(); + TORCH_CHECK( + streamType >= 0 && streamType <= max_stream_priorities, + "Unrecognized stream ", + stream_, + " (I didn't recognize the stream type, ", + st, + " with the value ", + streamType, + ")"); + return streams[device_index][streamType][si]; } } @@ -207,8 +232,7 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) { if (device_index == -1) { device_index = current_device(); } - auto pri_idx = - std::clamp(priority, 0, max_compile_time_stream_priorities - 1); + auto pri_idx = std::clamp(priority, 0, max_stream_priorities - 1); const auto idx = get_idx(priority_counters[device_index][pri_idx]); auto id_type = static_cast(pri_idx); return OpenRegStreamForId(device_index, makeStreamId(id_type, idx)); @@ -216,7 +240,7 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) { OpenRegStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) { initOpenRegStreamsOnce(); - int priority = 0; + int priority = isHighPriority ? max_stream_priorities - 1 : 0; return getStreamFromPool(priority, device); } @@ -232,6 +256,7 @@ OpenRegStream getDefaultOpenRegStream(DeviceIndex device_index) { if (device_index == -1) { device_index = current_device(); } + check_device(device_index); return OpenRegStreamForId( device_index, makeStreamId(StreamIdType::DEFAULT, 0)); } @@ -241,6 +266,7 @@ OpenRegStream getCurrentOpenRegStream(DeviceIndex device_index) { if (device_index == -1) { device_index = current_device(); } + check_device(device_index); return OpenRegStreamForId(device_index, current_streams[device_index]); } diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.h index e1fd0c719f5a1..bca5f697a4ab0 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegStream.h @@ -11,7 +11,8 @@ namespace c10::openreg { -static constexpr int max_compile_time_stream_priorities = 1; +// Derive compile-time priority count from shared openreg backend constant. +static constexpr int max_compile_time_stream_priorities = 2; class OpenRegStream { public: diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_streams.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_streams.py index 20bb3df09d9fa..e0b0b749ba23c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_streams.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_streams.py @@ -9,7 +9,6 @@ class TestStream(TestCase): def test_stream_create(self): stream = torch.Stream(device="openreg") self.assertEqual(stream.device_index, torch.openreg.current_device()) - stream = torch.Stream(device="openreg:1") self.assertEqual(stream.device.type, "openreg") self.assertEqual(stream.device_index, 1) @@ -30,6 +29,19 @@ def test_stream_context(self): with torch.Stream(device="openreg:1") as stream: self.assertEqual(torch.accelerator.current_stream(), stream) + def test_stream_context_exception_restore(self): + prev = torch.accelerator.current_stream() + inner_stream = torch.Stream(device="openreg:1") + try: + with inner_stream: + # inside the context we should be on the inner stream + self.assertEqual(torch.accelerator.current_stream(), inner_stream) + raise RuntimeError("forced") + except RuntimeError: + pass + # After the exception, the current stream should be restored. + self.assertEqual(torch.accelerator.current_stream(), prev) + @skipIfTorchDynamo() def test_stream_switch(self): stream1 = torch.Stream(device="openreg:0") @@ -38,6 +50,8 @@ def test_stream_switch(self): self.assertEqual(current_stream, stream1) stream2 = torch.Stream(device="openreg:1") + current_stream = torch.accelerator.current_stream() + self.assertEqual(current_stream, stream1) torch.accelerator.set_stream(stream2) current_stream = torch.accelerator.current_stream() self.assertEqual(current_stream, stream2) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/stream.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/stream.cpp index 30f50b1aa2895..1a9fb83c407c1 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/stream.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/stream.cpp @@ -1,5 +1,4 @@ #include - #include #include #include @@ -283,9 +282,9 @@ orError_t orDeviceGetStreamPriorityRange( return orErrorUnknown; } - // OpenReg have only one priority now. + // OpenReg priority levels are 0 and 1 *leastPriority = 0; - *greatestPriority = 0; + *greatestPriority = 1; return orSuccess; } diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp index fbf5cb900a811..65b3fe9b0c60e 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp @@ -127,7 +127,7 @@ TEST_F(StreamTest, StreamPriorityRange) { // OpenReg currently exposes only one priority level; verify the fixed range. EXPECT_EQ(orDeviceGetStreamPriorityRange(&min_p, &max_p), orSuccess); EXPECT_EQ(min_p, 0); - EXPECT_EQ(max_p, 0); + EXPECT_EQ(max_p, 1); } } // namespace From a36e1d39ebbf60976fec5a0d8a96763e6adfbea3 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 4 Dec 2025 15:09:20 +0000 Subject: [PATCH 266/338] Triton 3.6 pin update (#168096) Required for release 2.10 Rocm wheel build fix provided by: https://github.com/pytorch/pytorch/pull/169369 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168096 Approved by: https://github.com/njriasan, https://github.com/malfet, https://github.com/huydhn --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- .github/scripts/amd/package_triton_wheel.sh | 1 + .../rocm/dynamic_inductor_timm_training.csv | 2 +- test/inductor/test_cooperative_reductions.py | 2 ++ test/test_sparse.py | 4 +++- test/test_sparse_csr.py | 3 ++- 7 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 7aab8bed1c108..263fcf2e0bdbb 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 +5261b27331eb1dd09df9ec1bd6acc21cbb184481 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index d5c0c99142898..40c341bdcdbe8 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.1 +3.6.0 diff --git a/.github/scripts/amd/package_triton_wheel.sh b/.github/scripts/amd/package_triton_wheel.sh index fe8d915422dac..501e50e2fe2f1 100755 --- a/.github/scripts/amd/package_triton_wheel.sh +++ b/.github/scripts/amd/package_triton_wheel.sh @@ -87,6 +87,7 @@ done cp -r $ROCM_HOME/include/hip $TRITON_ROCM_DIR/include cp -r $ROCM_HOME/include/roctracer $TRITON_ROCM_DIR/include cp -r $ROCM_HOME/include/hsa $TRITON_ROCM_DIR/include +cp -r $ROCM_HOME/include/hipblas-common $TRITON_ROCM_DIR/include # Copy linker mkdir -p $TRITON_ROCM_DIR/llvm/bin diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv index 2d087e6595526..702da0cb57f89 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_timm_training.csv @@ -10,7 +10,7 @@ beit_base_patch16_224,pass,7 -convnextv2_nano.fcmae_ft_in22k_in1k,pass,7 +convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7 diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index 4548a819b07aa..45a79bbbc73f0 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -17,6 +17,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + slowTest, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -198,6 +199,7 @@ def fn(x, y): self.assertEqual(before.count("if rsplit_id == ("), 0) self.assertEqual(after.count("if rsplit_id == ("), 6) + @slowTest @parametrize("bs", [1, 2, 5, 15]) @parametrize("count", [1024**2 + 1, 1024**2 - 1, 1024]) def test_non_power_of_2(self, bs, count): diff --git a/test/test_sparse.py b/test/test_sparse.py index 25d46892de258..91b6d82fe34f1 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -12,7 +12,7 @@ load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \ DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \ parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \ - skipIfCrossRef + skipIfCrossRef, slowTest from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_mps import mps_ops_modifier from numbers import Number @@ -4934,6 +4934,7 @@ def test_generate_simple_inputs(self): f' contiguous_indices{contiguous_indices}, contiguous_values={contiguous_values}') assert not untested_combinations, untested_combinations + @slowTest @all_sparse_layouts('layout', include_strided=False) def test_constructor_autograd(self, device, layout): @@ -5490,6 +5491,7 @@ def test_sparse_mask(self, mask_layout, device, dtype): result = mask.to_dense().sparse_mask(mask) self.assertEqual(result, mask) + @slowTest @all_sparse_layouts('layout', include_strided=False) @parametrize("masked", [subtest(False, name='nonmasked'), subtest(True, name='masked')]) @parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')]) diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 9315154614cca..4061978e35157 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import \ (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, - skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings) + skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings, slowTest) from torch.testing._internal.common_device_type import \ (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric, precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse, largeTensorTest) @@ -3848,6 +3848,7 @@ def test_triton_scatter_mm(self, device, dtype): @parametrize("blocksize", [2, '2x3', 16, '16x32', 32, 64]) @onlyCUDA + @slowTest @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") From c19ecc7b073b128aaec212c05c608b1949c5bcd0 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 4 Dec 2025 15:33:13 +0000 Subject: [PATCH 267/338] fix typo: ommunication -> communication (#169557) As per title Pull Request resolved: https://github.com/pytorch/pytorch/pull/169557 Approved by: https://github.com/zou3519 --- torch/_inductor/comm_lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index 1f6cc5ee3e726..f20ca66c2de34 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -47,7 +47,7 @@ # # For eligible collective ops, we identify communication buffers at lowering # time and optionally choose to lower the op to a different kernel -# (ommunication libraries like NCCL handle both registered and non-registered +# (communication libraries like NCCL handle both registered and non-registered # buffers transparently within the same op, though some may require different # ops for different cases). Later, the codegen will perform "persistent # allocation" to satisfy the aforementioned constraints, and optionally, From 005f9fb6b9a283fe4811f157ccd73bf7ca3873b0 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 3 Dec 2025 23:22:43 +0000 Subject: [PATCH 268/338] Switch off of deprecaated API (#169517) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169517 Approved by: https://github.com/mlazos --- torch/_inductor/codegen/cutedsl/_cutedsl_utils.py | 4 ++-- torch/_inductor/codegen/cutedsl/cutedsl_kernel.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py b/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py index 173d122781016..17f850c8078c8 100644 --- a/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py +++ b/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py @@ -11,7 +11,7 @@ def ssa_to_indexable(ssa_value: cute.TensorSSA, dtype: str) -> cute.Numeric: Workaround for lack of gather support: SSA values cannot be used directly as indices in tensor loads. This converts SSA → fragment → scalar for indexing. """ - frag = cute.make_fragment(1, dtype) + frag = cute.make_rmem_tensor(1, dtype) frag.store(ssa_value) return frag[0] @@ -24,6 +24,6 @@ def result_to_ssa(value: cute.Numeric, dtype: str) -> cute.TensorSSA: After performing operations with non-SSA values (like indexed loads), convert the result back to SSA form for further computation. """ - frag = cute.make_fragment(1, dtype) + frag = cute.make_rmem_tensor(1, dtype) frag[0] = value return frag.load() diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py index 883517e2d3cdb..8d7f6bb337cc7 100644 --- a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py +++ b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py @@ -433,7 +433,7 @@ def load(self, name: str, index: sympy.Expr): val_frag = self.kernel.cse.newvar(dtype=var_dtype) self.kernel.body.writeline( - f"{val_frag} = cute.make_fragment(1, {cute_dtype})" + f"{val_frag} = cute.make_rmem_tensor(1, {cute_dtype})" ) self.kernel.body.writeline(f"{val_frag}[0] = ({var}[{idx_var}])") From 8b4f89ee0dce31b2162b0732e970784b10b3fa6b Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:35:10 +0000 Subject: [PATCH 269/338] [CI] Simplify CPython installation (#169510) 3.14 workarounds are no longer necessary as it reached GA Partially rollsback changes from https://github.com/pytorch/pytorch/pull/164791 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169510 Approved by: https://github.com/seemethere, https://github.com/atalman --- .ci/docker/common/install_conda.sh | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 41335a0dc370f..481de54a50f2c 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -49,20 +49,12 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then export SYSROOT_DEP="sysroot_linux-64=2.17" fi -# Install correct Python version -# Also ensure sysroot is using a modern GLIBC to match system compilers -if [ "$ANACONDA_PYTHON_VERSION" = "3.14" ]; then - as_jenkins conda create -n py_$ANACONDA_PYTHON_VERSION -y\ - python="3.14.0" \ - ${SYSROOT_DEP} \ - -c conda-forge -else # Install correct Python version # Also ensure sysroot is using a modern GLIBC to match system compilers as_jenkins conda create -n py_$ANACONDA_PYTHON_VERSION -y\ python="$ANACONDA_PYTHON_VERSION" \ ${SYSROOT_DEP} -fi + # libstdcxx from conda default channels are too old, we need GLIBCXX_3.4.30 # which is provided in libstdcxx 12 and up. conda_install libstdcxx-ng=12.3.0 --update-deps -c conda-forge From 324a8280712ee6ba0ebddc569964334d36137b98 Mon Sep 17 00:00:00 2001 From: Yuhang Yang Date: Thu, 4 Dec 2025 17:53:59 +0000 Subject: [PATCH 270/338] [c10d][Sym mem] Add set_signal_pad_size API for SymmetricMemory (#169156) Summary: The signal pad size for symmetric memory was previously hardcoded as a constexpr, which may be too small for workloads that launch a large number of blocks. This change exposes `set_signal_pad_size` and `get_signal_pad_size` APIs to allow users to configure the signal pad size before making symmetric memory allocations. ### Changes: #### 1. Core API (C++) - Renamed `signal_pad_size` constexpr to `default_signal_pad_size` in CUDASymmetricMemoryTypes.hpp - Added `get_signal_pad_size()` and `set_signal_pad_size(size_t)` function declarations in CUDASymmetricMemoryTypes.hpp - Implemented the functions in SymmetricMemory.cpp using `std::optional` to distinguish between default and user-configured values - Added TORCH_API exports in SymmetricMemory.hpp for public API access #### 2. Backend Updates - Updated CUDASymmetricMemory.cu to call `get_signal_pad_size()` instead of using the hardcoded constant - Updated NCCLSymmetricMemory.cu to use configurable signal pad size with local variable `signal_pad_size` - Updated NVSHMEMSymmetricMemory.cu to use configurable signal pad size with local variable `signal_pad_size` #### 3. Python Bindings - Added Python bindings in init.cpp with comprehensive docstrings explaining usage - Added Python wrapper functions in torch/distributed/_symmetric_memory/__init__.py - Updated `__all__` to export the new API functions #### 4. Tests - Added `test_get_signal_pad_size()` to verify the API returns a positive integer and Python/C++ consistency - Added `test_set_signal_pad_size()` to verify setting, getting, and restoring signal pad size values Test Plan: `PYTHONPATH=. python3 test/distributed/test_symmetric_memory.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/169156 Approved by: https://github.com/ngimel --- test/distributed/test_nvshmem_triton.py | 10 +++ test/distributed/test_symmetric_memory.py | 69 +++++++++++++++++ torch/_C/_distributed_c10d.pyi | 3 +- torch/csrc/distributed/c10d/init.cpp | 10 ++- .../c10d/symm_mem/CUDASymmetricMemory.cu | 9 +-- .../c10d/symm_mem/CUDASymmetricMemory.hpp | 1 - .../symm_mem/CUDASymmetricMemoryTypes.hpp | 8 +- .../c10d/symm_mem/NCCLSymmetricMemory.cu | 10 +-- .../c10d/symm_mem/NVSHMEMSymmetricMemory.cu | 76 ++++++++++--------- .../c10d/symm_mem/SymmetricMemory.cpp | 21 +++++ .../c10d/symm_mem/SymmetricMemory.hpp | 12 ++- .../distributed/_symmetric_memory/__init__.py | 54 ++++++++++++- 12 files changed, 226 insertions(+), 57 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index ad30a7df5d43a..a33d3bdb9e866 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -4,6 +4,16 @@ import sys +# Import TEST_WITH_ROCM first to check for ROCm before importing NVSHMEM modules +from torch.testing._internal.common_utils import TEST_WITH_ROCM + + +# Skip entire module on ROCm before importing NVSHMEM-specific modules +# NVSHMEM is NVIDIA-specific and can cause crashes during import on ROCm +if TEST_WITH_ROCM: + print("NVSHMEM not available on ROCm, skipping tests") + sys.exit(0) + import triton.language as tl import torch diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index f589339f1944a..8c0d780cbbb82 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -98,6 +98,75 @@ def test_cuda_nvlink_connectivity_detection(self) -> None: for row in connectivity.matrix: self.assertEqual(len(row), torch.cuda.device_count()) + @skipIf( + not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch" + ) + @skip_if_lt_x_gpu(2) + def test_get_signal_pad_size(self) -> None: + # Test that get_signal_pad_size returns a positive integer + signal_pad_size = symm_mem.get_signal_pad_size() + self.assertIsInstance(signal_pad_size, int) + self.assertGreater(signal_pad_size, 0) + + # Test that the C++ API returns the same value + cpp_signal_pad_size = _SymmetricMemory.signal_pad_size + self.assertEqual(signal_pad_size, cpp_signal_pad_size) + + @skipIf( + not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch" + ) + @skip_if_lt_x_gpu(2) + def test_set_signal_pad_size(self) -> None: + # Save the original signal pad size + original_size = symm_mem.get_signal_pad_size() + + # Test setting a new signal pad size + new_size = 1024 * 1024 # 1MB + symm_mem.set_signal_pad_size(new_size) + self.assertEqual(symm_mem.get_signal_pad_size(), new_size) + + # Test that the C++ API reflects the change + self.assertEqual(_SymmetricMemory.signal_pad_size, new_size) + + # Restore original size for other tests + symm_mem.set_signal_pad_size(original_size) + self.assertEqual(symm_mem.get_signal_pad_size(), original_size) + + @skipIf( + not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch" + ) + @skip_if_lt_x_gpu(2) + def test_set_signal_pad_size_with_allocation(self) -> None: + """Test that custom signal pad size is actually used in allocations.""" + self._init_process() + + # Save the original signal pad size + original_size = symm_mem.get_signal_pad_size() + + # Test with a custom signal pad size (2x the default) + custom_size = original_size * 2 + symm_mem.set_signal_pad_size(custom_size) + + # Allocate symmetric memory and verify the signal pad size + t = symm_mem.empty(64, device="cuda") + symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) + + # Verify the allocated symmetric memory uses the custom signal pad size + self.assertEqual(symm_mem_hdl.signal_pad_size, custom_size) + + # Test that signal pad operations work with the custom size + signal_pad = symm_mem_hdl.get_signal_pad(self.rank) + expected_numel = custom_size // 4 # uint32_t + self.assertEqual(signal_pad.numel(), expected_numel) + + # Verify we can use the full custom signal pad + signal_pad.fill_(0) + signal_pad[0] = 42 + self.assertEqual(signal_pad[0].item(), 42) + + # Restore original settings + symm_mem.set_signal_pad_size(original_size) + @skipIf( not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch" ) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 477b35b1811e4..1f50ee578a80a 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -793,6 +793,7 @@ class _SymmetricMemory: def get_backend(device: torch.device) -> Optional[str]: ... @staticmethod def get_mempool_allocator(device: torch.device) -> Any: ... + signal_pad_size: int @property def rank(self) -> int: ... @property @@ -854,8 +855,6 @@ class _SymmetricMemory: def multicast_ptr(self) -> int: ... @property def buffer_size(self) -> int: ... - @property - def signal_pad_size(self) -> int: ... class ProcessGroupXCCL(Backend): class Options(Backend.Options): diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 255e793eaa4df..4b18e8f6552db 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1137,6 +1137,14 @@ This class does not support ``__members__`` property.)"); &::c10d::symmetric_memory::has_multicast_support) .def_static("set_backend", &::c10d::symmetric_memory::set_backend) .def_static("get_backend", &::c10d::symmetric_memory::get_backend) + .def_property_static( + "signal_pad_size", + [](py::object /* self */) { + return ::c10d::symmetric_memory::get_signal_pad_size(); + }, + [](py::object /* self */, size_t size) { + ::c10d::symmetric_memory::set_signal_pad_size(size); + }) .def_static( "get_mempool_allocator", &::c10d::symmetric_memory::get_mempool_allocator) @@ -1177,8 +1185,6 @@ This class does not support ``__members__`` property.)"); return reinterpret_cast(symm_mem->get_multicast_ptr()); }) .def_property_readonly("buffer_size", &SymmetricMemory::get_buffer_size) - .def_property_readonly( - "signal_pad_size", &SymmetricMemory::get_signal_pad_size) .def_property_readonly("offset", &SymmetricMemory::get_offset) .def( "get_buffer", diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index 6352330c3872c..67eb13d24539a 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -134,10 +134,6 @@ size_t CUDASymmetricMemory::get_buffer_size() { return buffer_size_; } -size_t CUDASymmetricMemory::get_signal_pad_size() { - return signal_pad_size; -} - bool CUDASymmetricMemory::has_multicast_support() { return mc_addr_ != nullptr; } @@ -153,7 +149,8 @@ void check_channel(int channel, int world_size) { "must be greater than 0 (got ", channel, ")"); - const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + const size_t num_channels = c10d::symmetric_memory::get_signal_pad_size() / + sizeof(uint32_t) * world_size; TORCH_CHECK( static_cast(channel) < num_channels, "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", @@ -348,7 +345,7 @@ void* CUDASymmetricMemoryAllocator::alloc( int device_idx, const std::optional& group_name) { size_t signal_pad_offset = at::round_up(size, 16UL); - size_t block_size = signal_pad_offset + signal_pad_size; + size_t block_size = signal_pad_offset + get_signal_pad_size(); c10::cuda::CUDAGuard guard(device_idx); device_idx = static_cast(guard.current_device().index()); #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp index 39a6122bcdb27..e0e343da3a981 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp @@ -47,7 +47,6 @@ class CUDASymmetricMemory : public SymmetricMemory { void** get_buffer_ptrs_dev() override; void** get_signal_pad_ptrs_dev() override; size_t get_buffer_size() override; - size_t get_signal_pad_size() override; bool has_multicast_support() override; void* get_multicast_ptr() override; diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp index daf273446ef3a..7c255fa283ec9 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp @@ -1,7 +1,12 @@ #pragma once +#include #include +#if defined(USE_ROCM) +#include +#endif + namespace c10d::symmetric_memory { // Covers NVL72 @@ -11,7 +16,8 @@ constexpr int symm_max_nblocks = 32; // Maximally, a rank will need to sync with all other ranks, over all // channels. Each signal is 32 bits, which is the minimum unit for atomic cas. -constexpr size_t signal_pad_size = +// Default signal pad size, can be overridden via set_signal_pad_size(). +constexpr size_t default_signal_pad_size = symm_max_nblocks * max_cuda_p2p_domain_size * sizeof(uint32_t); #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) diff --git a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu index c099e2d72ecfd..a1d83ea702226 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu @@ -7,8 +7,8 @@ #endif #ifdef NCCL_HAS_SYMMEM_SUPPORT -#include #include +#include #include #include #include @@ -79,10 +79,6 @@ class NCCLSymmetricMemory : public SymmetricMemory { return buffer_size_; } - size_t get_signal_pad_size() override { - return signal_pad_size; - }; - bool has_multicast_support() override { // TODO return false; @@ -229,7 +225,9 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { comm)); void* signal_pad_ptr; - C10D_NCCL_CHECK(ncclMemAlloc(&signal_pad_ptr, signal_pad_size), "ncclMemAlloc failed"); + const size_t signal_pad_size = get_signal_pad_size(); + C10D_NCCL_CHECK( + ncclMemAlloc(&signal_pad_ptr, signal_pad_size), "ncclMemAlloc failed"); C10D_NCCL_CHECK( ncclCommWindowRegister(comm, signal_pad_ptr, signal_pad_size, (ncclWindow_t*)&signal_handle, NCCL_WIN_COLL_SYMMETRIC), c10::str( diff --git a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu index 510f5c4dd1b32..62adf88d4384e 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu @@ -1,8 +1,8 @@ #include -#include #include #include #include +#include #include #include @@ -39,7 +39,7 @@ struct NVSHMEMAllocation { return; } c10::cuda::CUDAGuard guard(device_idx); - nvshmem_free(ptr); // nvshmem_free has no return value + nvshmem_free(ptr); // nvshmem_free has no return value } }; @@ -53,8 +53,7 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { NVSHMEMPeerAllocInfo( NVSHMEMAllocation* allocation, const std::string& group_name) - : base_ptr_(allocation->ptr), - buffer_size_(allocation->buffer_size) { + : base_ptr_(allocation->ptr), buffer_size_(allocation->buffer_size) { // For logging only static int exchanged_n_times = 0; c10::cuda::CUDAGuard guard(allocation->device_idx); @@ -82,8 +81,7 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { world_within_cuda_p2p_ = true; for (int r = 0; r < world_size_; ++r) { - auto peer_ptr = nvshmem_ptr( - base_ptr_, rank_to_global_rank_[r]); + auto peer_ptr = nvshmem_ptr(base_ptr_, rank_to_global_rank_[r]); buffers_.push_back(peer_ptr); // If a peer is over network, `nvshmem_ptr` returns null if (peer_ptr == nullptr) { @@ -92,13 +90,14 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { } // TODO: use the same allocation for signal pad + const size_t signal_pad_size = get_signal_pad_size(); void* signal_pad_ptr = nvshmem_malloc(signal_pad_size); TORCH_CHECK(signal_pad_ptr != nullptr, "nvshmem_malloc failed"); AT_CUDA_CHECK(cudaMemset(signal_pad_ptr, 0, signal_pad_size)); for (int r = 0; r < world_size_; ++r) { - signal_pads_.push_back(nvshmem_ptr( - signal_pad_ptr, rank_to_global_rank_[r])); + signal_pads_.push_back( + nvshmem_ptr(signal_pad_ptr, rank_to_global_rank_[r])); } const size_t arr_size = sizeof(void*) * world_size_; @@ -146,8 +145,7 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { NVSHMEMSymmetricMemory( NVSHMEMAllocation* allocation, const std::string& group_name) - : device_idx_(allocation->device_idx), - group_name_(group_name) { + : device_idx_(allocation->device_idx), group_name_(group_name) { // A handle stores two types of info: // (i) allocation's base ptrs and base signal pads, ours and peers' pai_ = c10::make_intrusive(allocation, group_name); @@ -159,14 +157,17 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { NVSHMEMSymmetricMemory(const NVSHMEMSymmetricMemory& other) = delete; // Copy with offset is allowed - // This is mostly a shallow copy that shares the pointer to `NVSHMEMPeerAllocInfo` which has been created by `other` + // This is mostly a shallow copy that shares the pointer to + // `NVSHMEMPeerAllocInfo` which has been created by `other` NVSHMEMSymmetricMemory(const NVSHMEMSymmetricMemory& other, size_t offset) - : device_idx_(other.device_idx_), group_name_(other.group_name_), pai_(other.pai_) { + : device_idx_(other.device_idx_), + group_name_(other.group_name_), + pai_(other.pai_) { offset_ = offset; } - ~NVSHMEMSymmetricMemory() override{ - // TODO + ~NVSHMEMSymmetricMemory() override { + // TODO }; std::vector get_buffer_ptrs() override { @@ -189,10 +190,6 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { return pai_->buffer_size_; } - size_t get_signal_pad_size() override { - return signal_pad_size; - }; - bool has_multicast_support() override { // TODO return false; @@ -247,7 +244,7 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { int device_idx_; std::string group_name_; c10::intrusive_ptr pai_; - size_t offset_{0}; // in byte + size_t offset_{0}; // in byte }; // Bootstrap based on user's setting for NCCL @@ -295,7 +292,8 @@ static void initialize_nvshmem_with_store( // Using an existing store_all_gather due to laziness. // TODO(yifu): should use broadcast - auto unique_ids = storeExchange.all_gather(store, rank, world_size, unique_id); + auto unique_ids = + storeExchange.all_gather(store, rank, world_size, unique_id); nvshmemx_init_attr_t attr; nvshmemx_set_attr_uniqueid_args(rank, world_size, &unique_ids[0], &attr); @@ -335,8 +333,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { TORCH_CHECK(ptr != nullptr || size == 0, "nvshmem_malloc failed"); // TODO: thread safety allocations_.try_emplace( - ptr, - std::make_unique(ptr, size, device_idx)); + ptr, std::make_unique(ptr, size, device_idx)); return ptr; } @@ -367,19 +364,23 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { // In case of MemPool, tensor.storage().data_ptr() may not match // exactly an allocation's base address. Thus we perform the search by // testing if the former is within an allocation's range. - auto alloc_it = std::find_if(allocations_.begin(), allocations_.end(), - [&](const auto& pair){ - auto& allocation = pair.second; - auto ptr_int = reinterpret_cast(ptr); - auto base_ptr = reinterpret_cast(allocation->ptr); - return ptr_int >= base_ptr && ptr_int < base_ptr + allocation->buffer_size; }); - TORCH_CHECK(alloc_it != allocations_.end(), + auto alloc_it = std::find_if( + allocations_.begin(), allocations_.end(), [&](const auto& pair) { + auto& allocation = pair.second; + auto ptr_int = reinterpret_cast(ptr); + auto base_ptr = reinterpret_cast(allocation->ptr); + return ptr_int >= base_ptr && + ptr_int < base_ptr + allocation->buffer_size; + }); + TORCH_CHECK( + alloc_it != allocations_.end(), "Pointer not within any SymmetricMemory allocation, " "is the tensor allocated from SymmetricMemory?"); auto& allocation = alloc_it->second; - // Search again using allocation base ptr (which is the key we use for caching, see below) + // Search again using allocation base ptr (which is the key we use for + // caching, see below) auto it = symm_mems_.find(std::make_tuple(allocation->ptr, *group_name)); c10::intrusive_ptr symm_mem; if (it != symm_mems_.end()) { @@ -387,8 +388,8 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { symm_mem = it->second; } else { // Create a new rendezvous - symm_mem = - c10::make_intrusive(allocation.get(), *group_name); + symm_mem = c10::make_intrusive( + allocation.get(), *group_name); } // Cache rendezvous using allocation's base address as key @@ -404,7 +405,8 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { } else { // Return a copy of the SymmetricMemory with an offset. This is a // "shallow" copy adjusting the offset field in the handle. - return c10::make_intrusive(*symm_mem, (uintptr_t)ptr - (uintptr_t)allocation->ptr); + return c10::make_intrusive( + *symm_mem, (uintptr_t)ptr - (uintptr_t)allocation->ptr); } }; @@ -423,7 +425,9 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { private: std::unordered_map> allocations_; - std::map, c10::intrusive_ptr> + std::map< + std::tuple, + c10::intrusive_ptr> symm_mems_; }; @@ -433,9 +437,7 @@ struct RegisterNVSHMEMSymmetricMemoryAllocator { // Query backend used for CUDA tensor if (getSymmMemBackendCUDA() == "NVSHMEM") { // Direct set (static registration) - register_allocator( - c10::DeviceType::CUDA, - allocator); + register_allocator(c10::DeviceType::CUDA, allocator); } else { // Register availability in case `set_backend` is called dynamically register_availability("NVSHMEM", allocator); diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp index ac9b1e1a69ca2..09925546aa368 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp @@ -1,11 +1,19 @@ +#include #include +#include + namespace { using namespace c10d::symmetric_memory; static bool is_finalizing_ = false; +// Signal pad size configuration - uses default if not explicitly set. +// A value of 0 indicates "not set" (use default). +// Using std::atomic for thread safety when accessed from C++ without GIL. +static std::atomic configured_signal_pad_size_{0}; + // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) class AllocatorMap { public: @@ -186,6 +194,15 @@ std::optional get_backend(c10::Device device) { return AllocatorMap::get().get_backend(device.type()); } +size_t get_signal_pad_size() { + size_t val = configured_signal_pad_size_.load(std::memory_order_acquire); + return val == 0 ? default_signal_pad_size : val; +} + +void set_signal_pad_size(size_t size) { + configured_signal_pad_size_.store(size, std::memory_order_release); +} + bool has_allocator(c10::DeviceType device_type) { return AllocatorMap::get().has_allocator(device_type); } @@ -385,6 +402,10 @@ at::Tensor SymmetricMemory::get_remote_tensor( return get_buffer_at_byte_offset(this, peer, sizes, dtype, get_offset()); } +size_t SymmetricMemory::get_signal_pad_size() { + return c10d::symmetric_memory::get_signal_pad_size(); +} + at::Tensor SymmetricMemory::get_signal_pad( int rank, c10::IntArrayRef sizes, diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp index d2cb70e1b1ae9..f2b07d21a5ef5 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp @@ -48,7 +48,7 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { virtual void** get_buffer_ptrs_dev() = 0; virtual void** get_signal_pad_ptrs_dev() = 0; virtual size_t get_buffer_size() = 0; - virtual size_t get_signal_pad_size() = 0; + size_t get_signal_pad_size(); virtual size_t get_offset() { TORCH_CHECK(false, "NYI"); @@ -200,6 +200,16 @@ TORCH_API void set_backend(const std::string& name); TORCH_API std::optional get_backend(c10::Device device); +// Get the current signal pad size for symmetric memory allocations. +// Returns the user-configured size if set, otherwise returns the default size. +TORCH_API size_t get_signal_pad_size(); + +// Set the signal pad size for future symmetric memory allocations. +// This must be called before any symmetric memory allocations are made. +// The size should be proportional to the number of blocks the user launches +// and the world size. +TORCH_API void set_signal_pad_size(size_t size); + C10_EXPORT void register_mempool_allocator( c10::DeviceType device_type, std::shared_ptr allocator); diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 48f22902ff98b..5e153a6a29db7 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -2012,4 +2012,56 @@ def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def] return _SymmetricMemory.get_mempool_allocator(torch.device(device)) -__all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"] +def set_signal_pad_size(size: int) -> None: + r""" + Set the signal pad size for future symmetric memory allocations. + + Signal pads are P2P-accessible memory regions used for synchronization in + symmetric memory. This function allows users to configure + the signal pad size to be proportional to their workload requirements. + + .. warning:: + This must be called before any symmetric memory allocations are made. + The size cannot be changed after allocations have been performed. + + Args: + size (int): the signal pad size in bytes. The size should be + proportional to the number of blocks launched and the world size. + + Example:: + + >>> # doctest: +SKIP + >>> # Set a larger signal pad size before any allocations + >>> torch.distributed._symmetric_memory.set_signal_pad_size(1024 * 1024) # 1MB + """ + _SymmetricMemory.signal_pad_size = size + + +def get_signal_pad_size() -> int: + r""" + Get the current signal pad size for symmetric memory allocations. + + Returns the user-configured size if set via :func:`set_signal_pad_size`, + otherwise returns the default size. + + Returns: + int: the signal pad size in bytes. + + Example:: + + >>> # doctest: +SKIP + >>> size = torch.distributed._symmetric_memory.get_signal_pad_size() + >>> print(f"Signal pad size: {size} bytes") + """ + return _SymmetricMemory.signal_pad_size + + +__all__ = [ + "empty", + "rendezvous", + "is_nvshmem_available", + "set_backend", + "get_backend", + "set_signal_pad_size", + "get_signal_pad_size", +] From e770c95f620c201b6ba915baa35adad305454152 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Thu, 4 Dec 2025 03:47:16 +0000 Subject: [PATCH 271/338] [inductor] require shape in TritonCSEVariable (#162275) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162275 Approved by: https://github.com/mlazos --- torch/_inductor/codegen/triton.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index cba36a25aad8d..782948b0f4021 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1005,8 +1005,7 @@ def __init__( # We'll use this to track which masks the variable needs when used for indirect indexing self.mask_vars: OrderedSet[str] = OrderedSet() assert dtype is not None, "TritonCSEVariable must have dtype" - # TODO: uncomment this and fix the few failures left - # assert shape is not None, "TritonCSEVariable must have shape" + assert shape is not None, "TritonCSEVariable must have shape" def update_on_args(self, name, args, kwargs): for arg in args: @@ -4773,7 +4772,9 @@ def codegen_body(self): self.body.writeline( f"{name} = tl.full([R0_BLOCK], {default}, tl.float32)[None, :]" ) - accumname2var[name] = self.cse.namedvar(name, dtype=torch.float) + accumname2var[name] = self.cse.namedvar( + name, dtype=torch.float, shape=("1", "R0_BLOCK") + ) self.body.writeline("split_size = min(RSPLIT_SIZE, xnumel - xoffset)") self.body.writeline( "for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):" @@ -4810,6 +4811,7 @@ def codegen_body(self): self.body, f"{triton_reduction_function}({var}, 0)", dtype=var.dtype, + shape=("R0_BLOCK",), ) import unittest From d4ad7a13725a34a5542c58a17f76e3d807e9d2cc Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 4 Dec 2025 18:36:18 +0000 Subject: [PATCH 272/338] [BE] Don't make typos in build-environment between build and test (#169423) Claude Coded, and inspired by fantastic mismatch between https://github.com/pytorch/pytorch/blob/45d14e2497292be06ad36eaa1aaaf7c630a2586a/.github/workflows/inductor-unittest.yml#L92-L93 And https://github.com/pytorch/pytorch/blob/45d14e2497292be06ad36eaa1aaaf7c630a2586a/.github/workflows/inductor-unittest.yml#L108-L109 Which resulted in using halide build artifacts while attempting to run pallas unit tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/169423 Approved by: https://github.com/seemethere --- .github/workflows/_linux-build.yml | 4 ++++ .github/workflows/_mac-build.yml | 4 ++++ .github/workflows/_win-build.yml | 4 ++++ .github/workflows/attention_op_microbenchmark.yml | 4 ++-- .github/workflows/b200-distributed.yml | 2 +- .github/workflows/b200-symm-mem.yml | 2 +- .github/workflows/h100-cutlass-backend.yml | 2 +- .github/workflows/h100-distributed.yml | 2 +- .github/workflows/h100-symm-mem.yml | 2 +- .github/workflows/inductor-micro-benchmark-x86.yml | 2 +- .github/workflows/inductor-micro-benchmark.yml | 2 +- .github/workflows/inductor-nightly.yml | 2 +- .github/workflows/inductor-perf-compare.yml | 2 +- .github/workflows/inductor-perf-test-b200.yml | 6 +++--- .../inductor-perf-test-nightly-aarch64.yml | 4 ++-- .../workflows/inductor-perf-test-nightly-h100.yml | 6 +++--- .../workflows/inductor-perf-test-nightly-macos.yml | 2 +- .../inductor-perf-test-nightly-rocm-mi300.yml | 2 +- .../inductor-perf-test-nightly-rocm-mi355.yml | 2 +- .../inductor-perf-test-nightly-x86-zen.yml | 4 ++-- .../workflows/inductor-perf-test-nightly-x86.yml | 4 ++-- .../workflows/inductor-perf-test-nightly-xpu.yml | 4 ++-- .github/workflows/inductor-perf-test-nightly.yml | 6 +++--- .github/workflows/inductor-periodic.yml | 8 ++++---- .github/workflows/inductor-rocm-mi200.yml | 2 +- .github/workflows/inductor-rocm-mi300.yml | 2 +- .github/workflows/inductor-unittest.yml | 10 +++++----- .github/workflows/inductor.yml | 4 ++-- .github/workflows/linux-aarch64.yml | 2 +- .github/workflows/mac-mps.yml | 2 +- .github/workflows/operator_benchmark.yml | 4 ++-- .github/workflows/operator_microbenchmark.yml | 6 +++--- .github/workflows/periodic-rocm-mi200.yml | 2 +- .github/workflows/periodic-rocm-mi300.yml | 2 +- .github/workflows/periodic.yml | 10 +++++----- .github/workflows/pull.yml | 14 +++++++------- .github/workflows/quantization-periodic.yml | 2 +- .github/workflows/rocm-mi200.yml | 2 +- .github/workflows/rocm-mi300.yml | 2 +- .github/workflows/rocm-mi355.yml | 2 +- .github/workflows/rocm-navi31.yml | 2 +- .github/workflows/s390x-periodic.yml | 2 +- .github/workflows/slow-rocm-mi200.yml | 2 +- .github/workflows/slow.yml | 6 +++--- .github/workflows/test-b200.yml | 2 +- .github/workflows/test-h100.yml | 2 +- .github/workflows/torchbench.yml | 2 +- .github/workflows/trunk-rocm-mi300.yml | 2 +- .github/workflows/trunk.yml | 14 +++++++------- .github/workflows/xpu.yml | 2 +- 50 files changed, 99 insertions(+), 87 deletions(-) diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index cc0064391fdef..7a375a0f81f25 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -121,6 +121,9 @@ on: test-matrix: value: ${{ jobs.build.outputs.test-matrix }} description: An optional JSON description of what test configs to run later on. + build-environment: + value: ${{ jobs.build.outputs.build-environment }} + description: Top-level label for what's being built/tested. jobs: build: @@ -132,6 +135,7 @@ jobs: outputs: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} test-matrix: ${{ steps.filter.outputs.test-matrix }} + build-environment: ${{ inputs.build-environment }} steps: - name: Setup SSH (Click me for login details) uses: pytorch/test-infra/.github/actions/setup-ssh@main diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 24fe510f0fb59..4fd7874ee0c4d 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -53,6 +53,9 @@ on: build-outcome: value: ${{ jobs.build.outputs.build-outcome }} description: The outcome of the build step. This is used to influence test filtering logic later on. + build-environment: + value: ${{ jobs.build.outputs.build-environment }} + description: Top-level label for what's being built/tested. jobs: build: @@ -65,6 +68,7 @@ jobs: outputs: build-outcome: ${{ steps.build.outcome }} test-matrix: ${{ steps.filter.outputs.test-matrix }} + build-environment: ${{ inputs.build-environment }} steps: - name: Clean up disk space before running MacOS workflow uses: pytorch/test-infra/.github/actions/check-disk-space@main diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index 0fd3cf7f3972e..005d68ece857d 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -55,6 +55,9 @@ on: test-matrix: value: ${{ jobs.build.outputs.test-matrix }} description: An optional JSON description of what test configs to run later on. + build-environment: + value: ${{ jobs.build.outputs.build-environment }} + description: Top-level label for what's being built/tested. env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -67,6 +70,7 @@ jobs: timeout-minutes: 240 outputs: test-matrix: ${{ steps.filter.outputs.test-matrix }} + build-environment: ${{ inputs.build-environment }} defaults: run: shell: bash diff --git a/.github/workflows/attention_op_microbenchmark.yml b/.github/workflows/attention_op_microbenchmark.yml index eec4d21fe2616..cd04a48223ce1 100644 --- a/.github/workflows/attention_op_microbenchmark.yml +++ b/.github/workflows/attention_op_microbenchmark.yml @@ -39,7 +39,7 @@ jobs: needs: attn-microbenchmark-build with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.attn-microbenchmark-build.outputs.build-environment }} docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }} secrets: inherit @@ -66,7 +66,7 @@ jobs: needs: opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + build-environment: ${{ needs.opmicrobenchmark-build-b200.outputs.build-environment }} docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/b200-distributed.yml b/.github/workflows/b200-distributed.yml index bb85a4ddfc85e..e52c7a4b5f5c5 100644 --- a/.github/workflows/b200-distributed.yml +++ b/.github/workflows/b200-distributed.yml @@ -55,7 +55,7 @@ jobs: - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200 with: timeout-minutes: 1200 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/b200-symm-mem.yml b/.github/workflows/b200-symm-mem.yml index ba28066dd5602..62367b61b07b9 100644 --- a/.github/workflows/b200-symm-mem.yml +++ b/.github/workflows/b200-symm-mem.yml @@ -53,7 +53,7 @@ jobs: needs: - linux-jammy-cuda12_8-py3_10-gcc11-sm100-build-symm with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build-symm.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build-symm.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build-symm.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/h100-cutlass-backend.yml b/.github/workflows/h100-cutlass-backend.yml index edf4c2e0e807c..e5406f7600133 100644 --- a/.github/workflows/h100-cutlass-backend.yml +++ b/.github/workflows/h100-cutlass-backend.yml @@ -55,7 +55,7 @@ jobs: needs: - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/h100-distributed.yml b/.github/workflows/h100-distributed.yml index c05b61e30a635..0e5370a51c160 100644 --- a/.github/workflows/h100-distributed.yml +++ b/.github/workflows/h100-distributed.yml @@ -52,7 +52,7 @@ jobs: needs: - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/h100-symm-mem.yml b/.github/workflows/h100-symm-mem.yml index c75ca569fc7df..09c362a546024 100644 --- a/.github/workflows/h100-symm-mem.yml +++ b/.github/workflows/h100-symm-mem.yml @@ -52,7 +52,7 @@ jobs: needs: - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-symm + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-micro-benchmark-x86.yml b/.github/workflows/inductor-micro-benchmark-x86.yml index c6cc075e6b270..6936a9a9aa44f 100644 --- a/.github/workflows/inductor-micro-benchmark-x86.yml +++ b/.github/workflows/inductor-micro-benchmark-x86.yml @@ -37,7 +37,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-py3.9-gcc11 + build-environment: ${{ needs.inductor-build.outputs.build-environment }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index c10327c8f548c..5813aa28365e7 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -50,7 +50,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.build.outputs.build-environment }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index 78602e05586b7..4258e8fdb0c84 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -56,7 +56,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: nightly-dynamo-benchmarks-build with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.nightly-dynamo-benchmarks-build.outputs.build-environment }} docker-image: ${{ needs.nightly-dynamo-benchmarks-build.outputs.docker-image }} test-matrix: ${{ needs.nightly-dynamo-benchmarks-build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index d38818eef4000..5e721e2f6ee1f 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -51,7 +51,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.build.outputs.build-environment }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} # disable monitor in perf tests for more investigation diff --git a/.github/workflows/inductor-perf-test-b200.yml b/.github/workflows/inductor-perf-test-b200.yml index 11f5f10a55ad8..fb297377f78b8 100644 --- a/.github/workflows/inductor-perf-test-b200.yml +++ b/.github/workflows/inductor-perf-test-b200.yml @@ -109,7 +109,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -126,7 +126,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -142,7 +142,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index 46a1966570c63..f7b3517dccc06 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -126,7 +126,7 @@ jobs: needs: linux-jammy-aarch64-py3_10-inductor-build if: github.event.schedule == '0 7 * * *' with: - build-environment: linux-jammy-aarch64-py3.10 + build-environment: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.build-environment }} dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true docker-image: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.test-matrix }} @@ -144,7 +144,7 @@ jobs: needs: linux-jammy-aarch64-py3_10-inductor-build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-aarch64-py3.10 + build-environment: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.build-environment }} dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }} docker-image: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 1c35fc6794537..8d9b342daf3ef 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -132,7 +132,7 @@ jobs: needs: build if: github.event.schedule == '15 0 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -149,7 +149,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -168,7 +168,7 @@ jobs: # needs one round of benchmark if: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' }} with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }}-cudagraphs_low_precision-${{ inputs.cudagraphs || 'false' }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-macos.yml b/.github/workflows/inductor-perf-test-nightly-macos.yml index 81c1c27b76439..56d2976f179c0 100644 --- a/.github/workflows/inductor-perf-test-nightly-macos.yml +++ b/.github/workflows/inductor-perf-test-nightly-macos.yml @@ -59,7 +59,7 @@ jobs: uses: ./.github/workflows/_mac-test.yml needs: macos-perf-py3-arm64-build with: - build-environment: macos-py3-arm64-distributed + build-environment: ${{ needs.macos-perf-py3-arm64-build.outputs.build-environment }} # Same as the build job python-version: 3.12.7 test-matrix: ${{ needs.macos-perf-py3-arm64-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml b/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml index 8d6da18503001..484219b3b019b 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml @@ -120,7 +120,7 @@ jobs: uses: ./.github/workflows/_rocm-test.yml needs: linux-jammy-rocm-py3_10-inductor-benchmark-build with: - build-environment: linux-jammy-rocm-py3_10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml b/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml index 24872d2b1f110..ed253e9fdda68 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml @@ -120,7 +120,7 @@ jobs: uses: ./.github/workflows/_rocm-test.yml needs: linux-jammy-rocm-py3_10-inductor-benchmark-build with: - build-environment: linux-jammy-rocm-py3_10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-x86-zen.yml b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml index a7110b0fd9328..eee51b7ff8889 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86-zen.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml @@ -106,7 +106,7 @@ jobs: needs: inductor-build if: github.event.schedule == '0 7 * * *' with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.inductor-build.outputs.build-environment }} dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} @@ -122,7 +122,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.inductor-build.outputs.build-environment }} dashboard-tag: training-${{ inputs.training || 'false' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'true' }}-aotinductor-${{ inputs.aotinductor || 'true' }}-freezing-${{ inputs.freezing || 'true' }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 0533184df2e0e..87875831e2a0b 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -107,7 +107,7 @@ jobs: needs: inductor-build if: github.event.schedule == '0 7 * * *' with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.inductor-build.outputs.build-environment }} dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} @@ -124,7 +124,7 @@ jobs: needs: inductor-build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.inductor-build.outputs.build-environment }} dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-freezing-${{ inputs.freezing }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-xpu.yml b/.github/workflows/inductor-perf-test-nightly-xpu.yml index 28b10996bf38a..30eaa3b942af5 100644 --- a/.github/workflows/inductor-perf-test-nightly-xpu.yml +++ b/.github/workflows/inductor-perf-test-nightly-xpu.yml @@ -117,7 +117,7 @@ jobs: uses: ./.github/workflows/_xpu-test.yml needs: xpu-n-py3_10-inductor-benchmark-build with: - build-environment: linux-noble-xpu-n-py3.10 + build-environment: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }} test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }} @@ -137,7 +137,7 @@ jobs: uses: ./.github/workflows/_xpu-test.yml needs: xpu-n-py3_10-inductor-benchmark-build with: - build-environment: linux-noble-xpu-n-py3.10 + build-environment: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.build-environment }} dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }} test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 2617fc990b933..10df5cf523456 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -122,7 +122,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -138,7 +138,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -155,7 +155,7 @@ jobs: needs: build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.build.outputs.build-environment }} dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 2a2f9049da99b..d3152cf8dcdb5 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -76,7 +76,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: periodic-dynamo-benchmarks-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + build-environment: ${{ needs.periodic-dynamo-benchmarks-build.outputs.build-environment }} docker-image: ${{ needs.periodic-dynamo-benchmarks-build.outputs.docker-image }} test-matrix: ${{ needs.periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit @@ -176,7 +176,7 @@ jobs: uses: ./.github/workflows/_rocm-test.yml needs: rocm-periodic-dynamo-benchmarks-build with: - build-environment: linux-jammy-rocm-py3_10 + build-environment: ${{ needs.rocm-periodic-dynamo-benchmarks-build.outputs.build-environment }} docker-image: ${{ needs.rocm-periodic-dynamo-benchmarks-build.outputs.docker-image }} test-matrix: ${{ needs.rocm-periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit @@ -203,7 +203,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-smoke-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.inductor-smoke-build.outputs.build-environment }} docker-image: ${{ needs.inductor-smoke-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-smoke-build.outputs.test-matrix }} secrets: inherit @@ -260,7 +260,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: periodic-dynamo-benchmarks-cpu-build with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.periodic-dynamo-benchmarks-cpu-build.outputs.build-environment }} docker-image: ${{ needs.periodic-dynamo-benchmarks-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.periodic-dynamo-benchmarks-cpu-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-rocm-mi200.yml b/.github/workflows/inductor-rocm-mi200.yml index 55de9a2121cf6..ed4df2868cb12 100644 --- a/.github/workflows/inductor-rocm-mi200.yml +++ b/.github/workflows/inductor-rocm-mi200.yml @@ -53,7 +53,7 @@ jobs: uses: ./.github/workflows/_rocm-test.yml needs: linux-jammy-rocm-py3_10-inductor-build with: - build-environment: linux-jammy-rocm-py3.10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index 57e5cb856729a..f3c73af51670e 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -61,7 +61,7 @@ jobs: uses: ./.github/workflows/_rocm-test.yml needs: linux-noble-rocm-py3_12-inductor-build with: - build-environment: linux-noble-rocm-py3.12-mi300 + build-environment: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.build-environment }} docker-image: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 308e3bedf2ea0..3f4b5173d0689 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -55,7 +55,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + build-environment: ${{ needs.inductor-build.outputs.build-environment }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit @@ -79,7 +79,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-halide-build with: - build-environment: linux-jammy-py3.12-gcc11 + build-environment: ${{ needs.inductor-halide-build.outputs.build-environment }} docker-image: ${{ needs.inductor-halide-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }} secrets: inherit @@ -105,7 +105,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-pallas-build with: - build-environment: linux-jammy-cuda12.8-py3.12-gcc11 + build-environment: ${{ needs.inductor-pallas-build.outputs.build-environment }} docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }} secrets: inherit @@ -129,7 +129,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-triton-cpu-build with: - build-environment: linux-jammy-py3.12-gcc11 + build-environment: ${{ needs.inductor-triton-cpu-build.outputs.build-environment }} docker-image: ${{ needs.inductor-triton-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-triton-cpu-build.outputs.test-matrix }} secrets: inherit @@ -156,7 +156,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-cpu-build with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.inductor-cpu-build.outputs.build-environment }} docker-image: ${{ needs.inductor-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-cpu-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index b54910164fe62..77e27ffcf669f 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -69,7 +69,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + build-environment: ${{ needs.inductor-build.outputs.build-environment }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit @@ -131,7 +131,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-cpu-build with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.inductor-cpu-build.outputs.build-environment }} docker-image: ${{ needs.inductor-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-cpu-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index bb1a9a4f6a8b5..0cca30b7be009 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -56,7 +56,7 @@ jobs: id-token: write contents: read with: - build-environment: linux-jammy-aarch64-py3.10 + build-environment: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml index c80599fe89988..d0caf9aba3965 100644 --- a/.github/workflows/mac-mps.yml +++ b/.github/workflows/mac-mps.yml @@ -39,7 +39,7 @@ jobs: needs: macos-py3-arm64-build with: sync-tag: macos-py3-arm64-mps-test - build-environment: macos-py3-arm64 + build-environment: ${{ needs.macos-py3-arm64-build.outputs.build-environment }} # Same as the build job python-version: 3.12.7 test-matrix: ${{ needs.macos-py3-arm64-build.outputs.test-matrix }} diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index 758147f5fe18e..e682e1eb06c24 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -48,7 +48,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: x86-opbenchmark-build with: - build-environment: linux-jammy-py3.10-gcc11-build + build-environment: ${{ needs.x86-opbenchmark-build.outputs.build-environment }} docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }} secrets: inherit @@ -72,7 +72,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: aarch64-opbenchmark-build with: - build-environment: linux-jammy-aarch64-py3.10 + build-environment: ${{ needs.aarch64-opbenchmark-build.outputs.build-environment }} docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/operator_microbenchmark.yml b/.github/workflows/operator_microbenchmark.yml index cd27b3a8a97db..19c8b0865437a 100644 --- a/.github/workflows/operator_microbenchmark.yml +++ b/.github/workflows/operator_microbenchmark.yml @@ -52,7 +52,7 @@ jobs: needs: opmicrobenchmark-build with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.opmicrobenchmark-build.outputs.build-environment }} docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} secrets: inherit @@ -81,7 +81,7 @@ jobs: needs: opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + build-environment: ${{ needs.opmicrobenchmark-build-b200.outputs.build-environment }} docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only @@ -107,7 +107,7 @@ jobs: needs: opmicrobenchmark-build-rocm with: timeout-minutes: 500 - build-environment: linux-jammy-rocm-py3_10 + build-environment: ${{ needs.opmicrobenchmark-build-rocm.outputs.build-environment }} docker-image: ${{ needs.opmicrobenchmark-build-rocm.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-rocm.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/periodic-rocm-mi200.yml b/.github/workflows/periodic-rocm-mi200.yml index 18e7b60570bf8..c0c75d9b7d68c 100644 --- a/.github/workflows/periodic-rocm-mi200.yml +++ b/.github/workflows/periodic-rocm-mi200.yml @@ -77,7 +77,7 @@ jobs: - linux-jammy-rocm-py3_10-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/periodic-rocm-mi300.yml b/.github/workflows/periodic-rocm-mi300.yml index f3356cfa4fc77..04a1cbceeac28 100644 --- a/.github/workflows/periodic-rocm-mi300.yml +++ b/.github/workflows/periodic-rocm-mi300.yml @@ -76,7 +76,7 @@ jobs: - linux-noble-rocm-py3_12-build - target-determination with: - build-environment: linux-noble-rocm-py3.12-mi300 + build-environment: ${{ needs.linux-noble-rocm-py3_12-build.outputs.build-environment }} docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 325050392a393..783b9656f508f 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -77,7 +77,7 @@ jobs: - linux-jammy-cuda12_4-py3_10-gcc11-build - target-determination with: - build-environment: linux-jammy-cuda12.4-py3.10-gcc11 + build-environment: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit @@ -111,7 +111,7 @@ jobs: - linux-jammy-cuda12_8-py3_10-gcc11-build - target-determination with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit @@ -144,7 +144,7 @@ jobs: - linux-jammy-cuda12_8-py3_10-gcc11-debug-build - target-determination with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.test-matrix }} secrets: inherit @@ -176,7 +176,7 @@ jobs: - linux-jammy-cuda13_0-py3_10-gcc11-build - target-determination with: - build-environment: linux-jammy-cuda13.0-py3.10-gcc11 + build-environment: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit @@ -210,7 +210,7 @@ jobs: - linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build - target-determination with: - build-environment: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck + build-environment: ${{ needs.linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build.outputs.test-matrix }} timeout-minutes: 300 diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index eb676389f86ac..be98711f4858a 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -82,7 +82,7 @@ jobs: - linux-jammy-py3_10-gcc11-build - target-determination with: - build-environment: linux-jammy-py3.10-gcc11 + build-environment: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit @@ -92,7 +92,7 @@ jobs: uses: ./.github/workflows/_docs.yml needs: linux-jammy-py3_10-gcc11-build with: - build-environment: linux-jammy-py3.10-gcc11 + build-environment: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.docker-image }} secrets: inherit @@ -154,7 +154,7 @@ jobs: - linux-jammy-py3_10-clang18-asan-build - target-determination with: - build-environment: linux-jammy-py3.10-clang18-asan + build-environment: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.test-matrix }} sync-tag: asan-test @@ -182,7 +182,7 @@ jobs: - linux-jammy-py3_10-clang12-onnx-build - target-determination with: - build-environment: linux-jammy-py3.10-clang12-onnx + build-environment: ${{ needs.linux-jammy-py3_10-clang12-onnx-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-clang12-onnx-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang12-onnx-build.outputs.test-matrix }} secrets: inherit @@ -219,7 +219,7 @@ jobs: - linux-jammy-py3_10-clang12-build - target-determination with: - build-environment: linux-jammy-py3.10-clang12 + build-environment: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }} secrets: inherit @@ -254,7 +254,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: linux-jammy-py3_14-clang12-build with: - build-environment: linux-jammy-py3.14-clang12 + build-environment: ${{ needs.linux-jammy-py3_14-clang12-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_14-clang12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_14-clang12-build.outputs.test-matrix }} secrets: inherit @@ -342,7 +342,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: linux-jammy-cuda12_8-py3_10-gcc11-inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/quantization-periodic.yml b/.github/workflows/quantization-periodic.yml index 688f557eaf0e4..8dd97ff9308db 100644 --- a/.github/workflows/quantization-periodic.yml +++ b/.github/workflows/quantization-periodic.yml @@ -48,7 +48,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: periodic-quantization-build with: - build-environment: linux-jammy-cuda12.8-cudnn9-py3-gcc11 + build-environment: ${{ needs.periodic-quantization-build.outputs.build-environment }} docker-image: ${{ needs.periodic-quantization-build.outputs.docker-image }} test-matrix: ${{ needs.periodic-quantization-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/rocm-mi200.yml b/.github/workflows/rocm-mi200.yml index c947e361bfcb5..78c88b85fb1fe 100644 --- a/.github/workflows/rocm-mi200.yml +++ b/.github/workflows/rocm-mi200.yml @@ -68,7 +68,7 @@ jobs: - linux-jammy-rocm-py3_10-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index 99059a1ff857c..3718bf6fadfec 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -67,7 +67,7 @@ jobs: - linux-noble-rocm-py3_12-build - target-determination with: - build-environment: linux-noble-rocm-py3.12-mi300 + build-environment: ${{ needs.linux-noble-rocm-py3_12-build.outputs.build-environment }} docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/rocm-mi355.yml b/.github/workflows/rocm-mi355.yml index be46dfeeadb1f..0a229b233875e 100644 --- a/.github/workflows/rocm-mi355.yml +++ b/.github/workflows/rocm-mi355.yml @@ -63,7 +63,7 @@ jobs: - linux-noble-rocm-py3_12-build - target-determination with: - build-environment: linux-noble-rocm-py3.12-mi355 + build-environment: ${{ needs.linux-noble-rocm-py3_12-build.outputs.build-environment }} docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }} tests-to-include: >- diff --git a/.github/workflows/rocm-navi31.yml b/.github/workflows/rocm-navi31.yml index 4596f44d252d2..bf1661b35e210 100644 --- a/.github/workflows/rocm-navi31.yml +++ b/.github/workflows/rocm-navi31.yml @@ -63,7 +63,7 @@ jobs: - linux-jammy-rocm-py3_10-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} tests-to-include: >- diff --git a/.github/workflows/s390x-periodic.yml b/.github/workflows/s390x-periodic.yml index 405e3e1a581cc..17c656f7f9742 100644 --- a/.github/workflows/s390x-periodic.yml +++ b/.github/workflows/s390x-periodic.yml @@ -69,7 +69,7 @@ jobs: - linux-manylinux-2_28-py3-cpu-s390x-build - target-determination with: - build-environment: linux-s390x-binary-manywheel + build-environment: ${{ needs.linux-manylinux-2_28-py3-cpu-s390x-build.outputs.build-environment }} docker-image: pytorch/manylinuxs390x-builder:cpu-s390x test-matrix: ${{ needs.linux-manylinux-2_28-py3-cpu-s390x-build.outputs.test-matrix }} timeout-minutes: 600 diff --git a/.github/workflows/slow-rocm-mi200.yml b/.github/workflows/slow-rocm-mi200.yml index c564857dca9ce..937f04980522e 100644 --- a/.github/workflows/slow-rocm-mi200.yml +++ b/.github/workflows/slow-rocm-mi200.yml @@ -75,7 +75,7 @@ jobs: - linux-jammy-rocm-py3_10-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index c14caee9a336c..0edb2ce3093b7 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -73,7 +73,7 @@ jobs: - linux-jammy-cuda12_8-py3_10-gcc11-sm86-build - target-determination with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm86-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm86-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm86-build.outputs.test-matrix }} secrets: inherit @@ -100,7 +100,7 @@ jobs: - linux-jammy-py3_10-clang12-build - target-determination with: - build-environment: linux-jammy-py3.10-clang12 + build-environment: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }} secrets: inherit @@ -130,7 +130,7 @@ jobs: - linux-jammy-py3_10-clang18-asan-build - target-determination with: - build-environment: linux-jammy-py3.10-clang18-asan + build-environment: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.test-matrix }} sync-tag: asan-test diff --git a/.github/workflows/test-b200.yml b/.github/workflows/test-b200.yml index 54acc686d1ae4..19dcb07c29844 100644 --- a/.github/workflows/test-b200.yml +++ b/.github/workflows/test-b200.yml @@ -71,7 +71,7 @@ jobs: needs: - linux-jammy-cuda12_8-py3_10-gcc11-sm100-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/test-h100.yml b/.github/workflows/test-h100.yml index 510473d5306ad..4351b427b0b8a 100644 --- a/.github/workflows/test-h100.yml +++ b/.github/workflows/test-h100.yml @@ -56,7 +56,7 @@ jobs: needs: - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 5a0273f0b745e..508c39a653600 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -46,7 +46,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + build-environment: ${{ needs.build.outputs.build-environment }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/trunk-rocm-mi300.yml b/.github/workflows/trunk-rocm-mi300.yml index 23ab5e9260a3e..373cc91c440c3 100644 --- a/.github/workflows/trunk-rocm-mi300.yml +++ b/.github/workflows/trunk-rocm-mi300.yml @@ -77,7 +77,7 @@ jobs: - linux-jammy-rocm-py3_10-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index d1fd936280e94..dc66e362a4e6e 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -95,7 +95,7 @@ jobs: - target-determination with: timeout-minutes: 360 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit @@ -145,7 +145,7 @@ jobs: - macos-py3-arm64-build - target-determination with: - build-environment: macos-py3-arm64 + build-environment: ${{ needs.macos-py3-arm64-build.outputs.build-environment }} # Same as the build job python-version: 3.12.7 test-matrix: ${{ needs.macos-py3-arm64-build.outputs.test-matrix }} @@ -177,7 +177,7 @@ jobs: - win-vs2022-cpu-py3-build - target-determination with: - build-environment: win-vs2022-cpu-py3 + build-environment: ${{ needs.win-vs2022-cpu-py3-build.outputs.build-environment }} cuda-version: cpu test-matrix: ${{ needs.win-vs2022-cpu-py3-build.outputs.test-matrix }} disable-monitor: false @@ -228,7 +228,7 @@ jobs: - linux-jammy-rocm-py3_10-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10 + build-environment: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit @@ -262,7 +262,7 @@ jobs: - get-label-type - win-vs2022-cuda12_8-py3-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} test-matrix: | { include: [ @@ -291,7 +291,7 @@ jobs: - verify-cachebench-cpu-build - target-determination with: - build-environment: linux-jammy-py3.10-gcc11 + build-environment: ${{ needs.verify-cachebench-cpu-build.outputs.build-environment }} docker-image: ${{ needs.verify-cachebench-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.verify-cachebench-cpu-build.outputs.test-matrix }} secrets: inherit @@ -316,7 +316,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: linux-jammy-py3-clang12-executorch-build with: - build-environment: linux-jammy-py3-clang12-executorch + build-environment: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index d9a1ba13d2b59..8799743809a77 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -82,7 +82,7 @@ jobs: id-token: write contents: read with: - build-environment: linux-noble-xpu-n-py3.10 + build-environment: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }} secrets: inherit From 31987d0eda56179bfbed565b8cbb937844cd300c Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Tue, 25 Nov 2025 11:47:55 -0800 Subject: [PATCH 273/338] [inductor] skip the r2r determ test in fbcode (#169074) Claude generate the change and I review/test/publish it. The tests will fail in fbcode (thanks Ed for flagging). Disabling it in fbcode for now. I think for now oss signal is enough. If we really want signals in fbcode, we probably need replace the 'python' command running the benchmark script with 'buck' command. One instance of the failure in fbcode: https://www.internalfb.com/tasks/?t=246383644 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169074 Approved by: https://github.com/eellison, https://github.com/v0i0 --- test/inductor/test_deterministic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_deterministic.py b/test/inductor/test_deterministic.py index d7e4313f5fe3b..b03e5b1ffad78 100644 --- a/test/inductor/test_deterministic.py +++ b/test/inductor/test_deterministic.py @@ -21,6 +21,7 @@ GPU_TYPE, HAS_GPU_AND_TRITON, IS_BIG_GPU, + IS_FBCODE, ) @@ -114,6 +115,7 @@ def foo(x): else: self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0) + @unittest.skipIf(IS_FBCODE, "Skipping run2run determinism test in fbcode") @parametrize("model_name", ["GoogleFnet", "BertForMaskedLM", "DistillGPT2"]) @parametrize("training_or_inference", ["training", "inference"]) @parametrize("precision", ["float32", "bfloat16", "float16", "amp"]) From 87b97449565d1b3cd158d1df99b4339a9a8ee8b9 Mon Sep 17 00:00:00 2001 From: Jesse Rusak Date: Thu, 4 Dec 2025 19:34:43 +0000 Subject: [PATCH 274/338] Clarify checkpointing docs (#169007) Clarifies the checkpointing docs by noting the semantics that the tensors produced inside `function` are not kept alive, that the tensors in `args` *are* kept alive, and be explicit about the use of "checkpointed" as referring to the region of code in which the tensors are not saved. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169007 Approved by: https://github.com/soulitzer --- torch/utils/checkpoint.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 71a67ed751fd8..da74334025111 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -359,11 +359,14 @@ def checkpoint( r"""Checkpoint a model or part of the model. Activation checkpointing is a technique that trades compute for memory. - Instead of keeping tensors needed for backward alive until they are used in - gradient computation during backward, forward computation in checkpointed - regions omits saving tensors for backward and recomputes them during the - backward pass. Activation checkpointing can be applied to any part of a - model. + By default, tensors computed during the forward pass are kept alive until + they are used in gradient computations in the backward pass. To reduce this + memory usage, tensors produced in the passed :attr:`function` are not kept + alive until the backward pass. Instead, any passed tensors in :attr:`args` + are kept alive, and the unsaved tensors are recomputed by re-invoking + :attr:`function` in the backward pass as needed for gradient computation. + Activation checkpointing can be applied to any part of a model -- this is + sometimes described as "checkpointing" that part of the model. There are currently two checkpointing implementations available, determined by the :attr:`use_reentrant` parameter. It is recommended that you use From ada0665ea39934328e8c31a3f6c9f991b67beacd Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 4 Dec 2025 04:20:21 -0800 Subject: [PATCH 275/338] Add C++ wrapper for shim from_blob in stable/csrc/ops.h (#168380) One caveat is that I only test the normal from_blob args (so no Layout and no storage_offset) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168380 Approved by: https://github.com/albanD Co-authored-by: Mikayla Gawarecki --- .../csrc/my_from_blob.cpp | 32 ++++ .../libtorch_agnostic_2_10/ops.py | 20 +++ test/cpp_extensions/test_libtorch_agnostic.py | 56 +++++++ test/test_cpp_extensions_jit.py | 138 ++++++++++++++++++ torch/csrc/stable/ops.h | 33 ++++- torch/csrc/stable/stableivalue_conversions.h | 12 +- 6 files changed, 284 insertions(+), 7 deletions(-) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_from_blob.cpp diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_from_blob.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_from_blob.cpp new file mode 100644 index 0000000000000..124b6cb7f2263 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_from_blob.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include + +using torch::stable::Tensor; + +// Wrapper for torch::stable::from_blob with all parameters +// Note: We pass data_ptr as int64_t since we can't pass void* through the +// dispatcher +Tensor my_from_blob( + int64_t data_ptr, + torch::headeronly::HeaderOnlyArrayRef sizes, + torch::headeronly::HeaderOnlyArrayRef strides, + torch::stable::Device device, + torch::headeronly::ScalarType dtype) { + void* data = reinterpret_cast(data_ptr); + return torch::stable::from_blob( + data, sizes, strides, device, dtype); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def( + "my_from_blob(int data_ptr, int[] sizes, int[] strides, Device device, ScalarType dtype) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL( + libtorch_agnostic_2_10, + CompositeExplicitAutograd, + m) { + m.impl("my_from_blob", TORCH_BOX(&my_from_blob)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index b1fca47322e1b..2815256d6a63b 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -316,3 +316,23 @@ def my_cuda_stream_synchronize(stream: int, device_index: int): return torch.ops.libtorch_agnostic_2_10.my_cuda_stream_synchronize( stream, device_index ) + + +def my_from_blob(data_ptr, sizes, strides, device, dtype) -> Tensor: + """ + Creates a Tensor from existing memory using torch::stable::from_blob. + + Args: + data_ptr: int - pointer to the data buffer + sizes: tuple[int] - size of the tensor + strides: tuple[int] - strides of the tensor + device: Device - device on which the tensor resides + dtype: ScalarType - data type of the tensor + storage_offset: int - offset in the storage + layout: Layout - layout of the tensor + + Returns: Tensor - tensor wrapping the existing memory + """ + return torch.ops.libtorch_agnostic_2_10.my_from_blob.default( + data_ptr, sizes, strides, device, dtype + ) diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index 10f1bba1e3179..6a7bf6e7594c7 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -922,6 +922,62 @@ def test_my_cuda_stream_synchronize(self, device): # sanity check for torch_cuda_stream_synchronize: libtorch_agnostic.ops.my_cuda_stream_synchronize(stream, device_index) + @skipIfTorchVersionLessThan(2, 10) + @skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor") + def test_my_from_blob(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + # Create reference implementation using unstable torch::from_blob via load_inline + source = """ + #include + + at::Tensor reference_from_blob(at::Tensor t) { + void* data_ptr = t.storage().data_ptr().get(); + auto options = torch::TensorOptions() + .dtype(t.dtype()) + .device(t.device()); + + return torch::from_blob( + data_ptr, + t.sizes(), + t.strides(), + options); + } + """ + + module = torch.utils.cpp_extension.load_inline( + name="test_from_blob_reference", + cpp_sources=[source], + functions=["reference_from_blob"], + ) + + # Test basic from_blob with contiguous tensor + original = torch.rand(2, 3, device=device, dtype=torch.float32) + stable_result = libtorch_agnostic.ops.my_from_blob( + original.data_ptr(), + original.size(), + original.stride(), + device, + torch.float32, + ) + reference_result = module.reference_from_blob(original) + self.assertEqual(stable_result, reference_result) + self.assertEqual(stable_result.data_ptr(), original.data_ptr()) + + # Test with non-contiguous strides + transposed = torch.rand(4, 6, device=device, dtype=torch.float32).t() + + stable_transposed = libtorch_agnostic.ops.my_from_blob( + transposed.data_ptr(), + transposed.size(), + transposed.stride(), + device, + transposed.dtype, + ) + + reference_transposed = module.reference_from_blob(transposed) + self.assertEqual(stable_transposed, reference_transposed) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 541aef8499b6b..9d03cbda766a2 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -1271,6 +1271,144 @@ def test_aoti_torch_call_dispatcher(self): self.assertEqual(abs_t, torch.abs(t)) self.assertEqual(floor_t, torch.floor(t)) + def test_from_blob_stable_api(self): + source = """ + #include + #include + #include + + // Test using the stable API torch::stable::from_blob + at::Tensor test_stable_from_blob() { + // Allocate data buffer with known values + static std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Create tensor using stable API + torch::stable::Tensor stable_tensor = torch::stable::from_blob( + data.data(), + {2, 3}, + {3, 1}, + torch::stable::Device(torch::headeronly::DeviceType::CPU, 0), + torch::headeronly::ScalarType::Float + ); + + // Convert stable::Tensor to at::Tensor for return + // The stable::Tensor wraps an AtenTensorHandle, we need to extract the underlying tensor + AtenTensorHandle handle = stable_tensor.get(); + return *reinterpret_cast(handle); + } + + // Test using the standard torch::from_blob as reference + at::Tensor test_reference_from_blob() { + // Use the same data buffer + static std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Create tensor using standard API + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU); + at::Tensor ref_tensor = torch::from_blob( + data.data(), + {2, 3}, + {3, 1}, + options + ); + + return ref_tensor; + } + + // Test with non-contiguous strides + at::Tensor test_stable_from_blob_strided() { + static std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Create a non-contiguous view: shape [2, 2] with stride [3, 1] + // This will select elements at indices [0,1] and [3,4] + torch::stable::Tensor stable_tensor = torch::stable::from_blob( + data.data(), + {2, 2}, + {3, 1}, + torch::stable::Device(torch::headeronly::DeviceType::CPU, 0), + torch::headeronly::ScalarType::Float + ); + + AtenTensorHandle handle = stable_tensor.get(); + return *reinterpret_cast(handle); + } + + at::Tensor test_reference_from_blob_strided() { + static std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU); + at::Tensor ref_tensor = torch::from_blob( + data.data(), + {2, 2}, + {3, 1}, + options + ); + + return ref_tensor; + } + + // Test with storage offset + at::Tensor test_stable_from_blob_offset() { + static std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Create tensor starting from offset 2 (third element) + torch::stable::Tensor stable_tensor = torch::stable::from_blob( + data.data(), + {2, 2}, + {2, 1}, + torch::stable::Device(torch::headeronly::DeviceType::CPU, 0), + torch::headeronly::ScalarType::Float, + 2 // storage_offset - start from data[2] + ); + + AtenTensorHandle handle = stable_tensor.get(); + return *reinterpret_cast(handle); + } + + at::Tensor test_reference_from_blob_offset() { + static std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU); + // Note: torch::from_blob doesn't support storage_offset directly, + // so we create from blob and then apply offset + at::Tensor ref_tensor = torch::from_blob( + data.data() + 2, // pointer offset instead + {2, 2}, + {2, 1}, + options + ); + + return ref_tensor; + } + """ + + module = torch.utils.cpp_extension.load_inline( + name="test_from_blob_stable", + cpp_sources=[source], + functions=[ + "test_stable_from_blob", + "test_reference_from_blob", + "test_stable_from_blob_strided", + "test_reference_from_blob_strided", + "test_stable_from_blob_offset", + "test_reference_from_blob_offset", + ], + ) + + # Test basic from_blob + stable_result = module.test_stable_from_blob() + reference_result = module.test_reference_from_blob() + self.assertEqual(stable_result, reference_result) + + # Test with non-contiguous strides + stable_strided = module.test_stable_from_blob_strided() + reference_strided = module.test_reference_from_blob_strided() + self.assertEqual(stable_strided, reference_strided) + + # Test with storage offset + stable_offset = module.test_stable_from_blob_offset() + reference_offset = module.test_reference_from_blob_offset() + self.assertEqual(stable_offset, reference_offset) + @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found") def test_cuda_pluggable_allocator_include(self): """ diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 923cbf398a104..1199dc03135fc 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -379,6 +379,37 @@ inline torch::stable::Tensor view( return torch::stable::detail::to(stack[0]); } -#endif +inline torch::stable::Tensor from_blob( + void* data, + torch::headeronly::IntHeaderOnlyArrayRef sizes, + torch::headeronly::IntHeaderOnlyArrayRef strides, + torch::stable::Device device, + torch::headeronly::ScalarType dtype, + int64_t storage_offset = 0, + torch::headeronly::Layout layout = torch::headeronly::Layout::Strided) { + auto shim_dtype = + torch::stable::detail::to(torch::stable::detail::from(dtype)); + auto shim_device_type = torch::stable::detail::to( + torch::stable::detail::from(device.type())); + auto shim_layout = + torch::stable::detail::to(torch::stable::detail::from(layout)); + AtenTensorHandle ath; + TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + storage_offset, + shim_dtype, + shim_device_type, + device.index(), + &ath, + shim_layout, + nullptr, + 0)); + return torch::stable::Tensor(ath); +} + +#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 HIDDEN_NAMESPACE_END(torch, stable) diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 708139836411a..c4f10486ec779 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -276,10 +276,10 @@ struct FromImpl { // ============================================================================= #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 -// Specialization for c10::Layout => StableIValue +// Specialization for torch::headeronly::Layout => StableIValue // Note that we call into the shim to translate between the user's // Layout and libtorch's Layout, which can be different! -using c10::Layout; +using torch::headeronly::Layout; template <> struct FromImpl { static StableIValue call( @@ -311,10 +311,10 @@ struct FromImpl { } }; -// Specialization for c10::MemoryFormat => StableIValue +// Specialization for torch::headeronly::MemoryFormat => StableIValue // Note that we call into the shim to translate between the user's // MemoryFormat and libtorch's MemoryFormat, which can be different! -using c10::MemoryFormat; +using torch::headeronly::MemoryFormat; template <> struct FromImpl { static StableIValue call( @@ -623,7 +623,7 @@ struct ToImpl { // ============================================================================= #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 -// Specialization for StableIValue => c10::Layout +// Specialization for StableIValue => torch::headeronly::Layout template <> struct ToImpl { static Layout call( @@ -657,7 +657,7 @@ struct ToImpl { } }; -// Specialization for StableIValue => c10::MemoryFormat +// Specialization for StableIValue => torch::headeronly::MemoryFormat template <> struct ToImpl { static MemoryFormat call( From d3b4475277441128e201af810a4f1be3844c0a13 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 4 Dec 2025 07:40:47 -0800 Subject: [PATCH 276/338] Add STD_CUDA_{KERNEL_LAUNCH_}CHECK (#169385) This PR adds a shim function `torch_c10_cuda_check_msg` that calls [`c10::cuda::c10_cuda_check_implementation`](https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAException.cpp?brid=c7GuulbbejHwCKQd9WHwYg&fbclid=IwY2xjawOdEuBleHRuA2FlbQIxMQBicmlkETFRQVN6d2Q0b1QxeUtTVU9Bc3J0YwZhcHBfaWQBMAABHugjOOQgUIFG8UXP9yHu7Hb6o2obrmdoLygi9Ei-dTKoHqxpx3YziphkBy1L_aem_0SEdyiUJ90nUKEpNkDJKyA#L10) and returns the formatted error message if an error was thrown. `STD_CUDA_CHECK` calls this shim and (following the approach of STD_TORCH_CHECK) throws a std::runtime_error with the message. | `C10_CUDA_CHECK` | `STD_CUDA_CHECK` | | ----------------------|-------------------------| | throws `c10::AcceleratorError` (propagated to python as `torch.AcceleratorError`)| throws `std::runtime_error` (propagated to python as `RuntimeError`)| **Note that `torch.AcceleratorError` has an `error_code` attribute, but the RuntimeError propagated by STD_CUDA_CHECK will not have this attribute. Otherwise the error messages propagated are identical.** https://github.com/pytorch/pytorch/blob/9b3e34d8589b29f7b4e7fab6f78711b7ca6e4639/torch/csrc/Exceptions.cpp#L339-L341 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169385 Approved by: https://github.com/albanD, https://github.com/eqy ghstack dependencies: #168380 --- .../csrc/test_std_cuda_check.cu | 61 +++++++++ .../libtorch_agnostic_2_10/ops.py | 32 +++++ test/cpp_extensions/test_libtorch_agnostic.py | 127 ++++++++++++++++++ torch/csrc/cuda/shim_common.cpp | 56 ++++++++ torch/csrc/stable/c/shim.h | 14 ++ torch/csrc/stable/macros.h | 26 ++++ 6 files changed, 316 insertions(+) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_std_cuda_check.cu create mode 100644 torch/csrc/stable/macros.h diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_std_cuda_check.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_std_cuda_check.cu new file mode 100644 index 0000000000000..0ad02aa1666e0 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_std_cuda_check.cu @@ -0,0 +1,61 @@ +#include +#include +#include + +__global__ void dummy_kernel(int /*unused*/) { + // Intentionally empty +} + +__global__ void invalid_kernel(int /*unused*/) { + // This kernel itself is fine, but we'll launch it with invalid config +} + +int test_std_cuda_check_success() { + // cudaGetDevice should succeed if CUDA is available + int device; + STD_CUDA_CHECK(cudaGetDevice(&device)); + return device; +} + +void test_std_cuda_check_error() { + // cudaSetDevice with an invalid device ID should fail + // Using 99999 as an invalid device ID to trigger an error + STD_CUDA_CHECK(cudaSetDevice(99999)); +} + +void test_std_cuda_kernel_launch_check_success() { + // Launch a simple kernel with valid configuration + dummy_kernel<<<1, 1>>>(0); + + STD_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void test_std_cuda_kernel_launch_check_error() { + // Launch a kernel with invalid configuration + // Using more blocks than allowed (2^31) will trigger a launch error + invalid_kernel<<<2147483648, 1>>>(0); + + STD_CUDA_KERNEL_LAUNCH_CHECK(); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("test_std_cuda_check_success() -> int"); + m.def("test_std_cuda_check_error() -> ()"); + m.def("test_std_cuda_kernel_launch_check_success() -> ()"); + m.def("test_std_cuda_kernel_launch_check_error() -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL( + libtorch_agnostic_2_10, + CompositeExplicitAutograd, + m) { + m.impl( + "test_std_cuda_check_success", TORCH_BOX(&test_std_cuda_check_success)); + m.impl("test_std_cuda_check_error", TORCH_BOX(&test_std_cuda_check_error)); + m.impl( + "test_std_cuda_kernel_launch_check_success", + TORCH_BOX(&test_std_cuda_kernel_launch_check_success)); + m.impl( + "test_std_cuda_kernel_launch_check_error", + TORCH_BOX(&test_std_cuda_kernel_launch_check_error)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index 2815256d6a63b..b063961575cb7 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -336,3 +336,35 @@ def my_from_blob(data_ptr, sizes, strides, device, dtype) -> Tensor: return torch.ops.libtorch_agnostic_2_10.my_from_blob.default( data_ptr, sizes, strides, device, dtype ) + + +def test_std_cuda_check_success() -> int: + """ + Test STD_CUDA_CHECK macro with a successful CUDA operation. + Returns the current CUDA device index. + """ + return torch.ops.libtorch_agnostic_2_10.test_std_cuda_check_success.default() + + +def test_std_cuda_check_error() -> None: + """ + Test STD_CUDA_CHECK macro with a failing CUDA operation. + This should raise a RuntimeError with the CUDA error message. + """ + torch.ops.libtorch_agnostic_2_10.test_std_cuda_check_error.default() + + +def test_std_cuda_kernel_launch_check_success() -> None: + """ + Test STD_CUDA_KERNEL_LAUNCH_CHECK macro with a successful kernel launch. + Launches a simple kernel and checks for errors. + """ + torch.ops.libtorch_agnostic_2_10.test_std_cuda_kernel_launch_check_success.default() + + +def test_std_cuda_kernel_launch_check_error() -> None: + """ + Test STD_CUDA_KERNEL_LAUNCH_CHECK macro with an invalid kernel launch. + This should raise a RuntimeError with the CUDA kernel launch error message. + """ + torch.ops.libtorch_agnostic_2_10.test_std_cuda_kernel_launch_check_error.default() diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index 6a7bf6e7594c7..d06099a6b7cf7 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -16,6 +16,7 @@ IS_WINDOWS, parametrize, run_tests, + skipIfRocm, skipIfTorchDynamo, TestCase, xfailIfTorchDynamo, @@ -978,6 +979,132 @@ def test_my_from_blob(self, device): reference_transposed = module.reference_from_blob(transposed) self.assertEqual(stable_transposed, reference_transposed) + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_std_cuda_check_success(self, device): + """Test that STD_CUDA_CHECK works correctly for successful CUDA calls.""" + import libtorch_agnostic_2_10 as libtorch_agnostic + + result = libtorch_agnostic.ops.test_std_cuda_check_success() + expected_device = torch.cuda.current_device() + self.assertEqual(result, expected_device) + + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + @skipIfRocm(msg="TODO: @mikaylagawarecki fix after branch cut") + @parametrize("show_cpp_stacktraces", [False, True]) + def test_std_cuda_check_error(self, device, show_cpp_stacktraces): + """Test that STD_CUDA_CHECK throws std::runtime_error with CUDA error message. + + When TORCH_SHOW_CPP_STACKTRACES=1, the error should include a C++ stack trace. + Since this env var is cached on first use, we use subprocess to test both cases. + """ + import os + import subprocess + import sys + + test_script = """ +import torch +import libtorch_agnostic_2_10 as libtorch_agnostic + +try: + libtorch_agnostic.ops.test_std_cuda_check_error() +except RuntimeError as e: + print(str(e)) +""" + env = os.environ.copy() + env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if show_cpp_stacktraces else "0" + # Pass the current sys.path to subprocess so it can find the locally installed extension + env["PYTHONPATH"] = os.pathsep.join(sys.path) + + result = subprocess.run( + [sys.executable, "-c", test_script], + capture_output=True, + text=True, + env=env, + ) + + error_message = result.stdout + result.stderr + + self.assertTrue( + "CUDA error: invalid device ordinal" in error_message + or "HIP error: invalid device ordinal" in error_message, + f"Expected 'CUDA/HIP error: invalid device ordinal' in error message, got: {error_message}", + ) + self.assertIn( + "GPU device may be out of range, do you have enough GPUs?", + error_message, + ) + + if show_cpp_stacktraces: + self.assertIn("C++ CapturedTraceback:", error_message) + self.assertRegex( + error_message, + r"Exception raised from test_std_.*_check_error at .*test_std_.*check\..*:\d+", + ) + else: + self.assertNotIn("C++ CapturedTraceback:", error_message) + + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_std_cuda_kernel_launch_check_success(self, device): + """Test that STD_CUDA_KERNEL_LAUNCH_CHECK works correctly for successful kernel launches.""" + import libtorch_agnostic_2_10 as libtorch_agnostic + + libtorch_agnostic.ops.test_std_cuda_kernel_launch_check_success() + + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + @parametrize("show_cpp_stacktraces", [False, True]) + @skipIfRocm(msg="TODO: @mikaylagawarecki fix after branch cut") + def test_std_cuda_kernel_launch_check_error(self, device, show_cpp_stacktraces): + """Test that STD_CUDA_KERNEL_LAUNCH_CHECK throws std::runtime_error for invalid kernel launches. + + When TORCH_SHOW_CPP_STACKTRACES=1, the error should include a C++ stack trace. + Since this env var is cached on first use, we use subprocess to test both cases. + """ + import os + import subprocess + import sys + + test_script = """ +import torch +import libtorch_agnostic_2_10 as libtorch_agnostic + +try: + libtorch_agnostic.ops.test_std_cuda_kernel_launch_check_error() +except RuntimeError as e: + print(str(e)) +""" + env = os.environ.copy() + env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if show_cpp_stacktraces else "0" + # Pass the current sys.path to subprocess so it can find the locally installed extension + env["PYTHONPATH"] = os.pathsep.join(sys.path) + + result = subprocess.run( + [sys.executable, "-c", test_script], + capture_output=True, + text=True, + env=env, + ) + + error_message = result.stdout + result.stderr + + self.assertTrue( + "CUDA error: invalid configuration argument" in error_message + or "HIP error: invalid configuration argument" in error_message, + f"Expected 'CUDA|HIP error: invalid configuration argument' in error message, got: {error_message}", + ) + + if show_cpp_stacktraces: + self.assertIn("C++ CapturedTraceback:", error_message) + self.assertRegex( + error_message, + r"Exception raised from test_std_.*_kernel_launch_check_error at .*test_std_.*_check\..*:\d+", + ) + else: + self.assertNotIn("C++ CapturedTraceback:", error_message) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/cuda/shim_common.cpp b/torch/csrc/cuda/shim_common.cpp index 24cee443bb1aa..c58230958d68d 100644 --- a/torch/csrc/cuda/shim_common.cpp +++ b/torch/csrc/cuda/shim_common.cpp @@ -1,7 +1,32 @@ #include +#include #include +#include #include #include +#include +#include + +namespace { +// Helper to call the appropriate check implementation for CUDA vs ROCm. +// This is done in a separate function to avoid preprocessor directives inside +// macro (AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE) arguments, which +// is undefined behavior and fails on MSVC. +inline void call_c10_accelerator_check_implementation( + int32_t err, + const char* filename, + const char* function_name, + uint32_t line_number, + bool include_device_assertions) { +#ifdef USE_ROCM + c10::hip::c10_hip_check_implementation( + err, filename, function_name, line_number, include_device_assertions); +#else + c10::cuda::c10_cuda_check_implementation( + err, filename, function_name, line_number, include_device_assertions); +#endif +} +} // namespace AOTITorchError torch_get_current_cuda_blas_handle(void** ret_handle) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ @@ -37,3 +62,34 @@ AOTITorchError torch_cuda_stream_synchronize( .synchronize(); }); } + +AOTITorchError torch_c10_cuda_check_msg( + int32_t err, + const char* filename, + const char* function_name, + uint32_t line_number, + bool include_device_assertions, + char** error_msg) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + *error_msg = nullptr; + + try { + call_c10_accelerator_check_implementation( + err, filename, function_name, line_number, include_device_assertions); + } catch (const c10::AcceleratorError& e) { + // Match the behavior of Python exception translation: + // use what() if C++ stacktraces are enabled, otherwise + // what_without_backtrace() + const char* what_str = torch::get_cpp_stacktraces_enabled() + ? e.what() + : e.what_without_backtrace(); + size_t msg_len = std::strlen(what_str); + *error_msg = new char[msg_len + 1]; + std::memcpy(*error_msg, what_str, msg_len + 1); + } + }); +} + +void torch_c10_cuda_free_error_msg(char* error_msg) { + delete[] error_msg; +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 545cb3eeb2c56..384d9369b7bc4 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -133,6 +133,20 @@ AOTI_TORCH_EXPORT AOTITorchError torch_get_cuda_stream_from_pool( AOTI_TORCH_EXPORT AOTITorchError torch_cuda_stream_synchronize(void* stream, int32_t device_index); +// Wrapper around c10_cuda_check_implementation that captures the error message +// without propagating the exception. The caller must free error_msg using +// torch_c10_cuda_free_error_msg if it is non-null. +AOTI_TORCH_EXPORT AOTITorchError torch_c10_cuda_check_msg( + int32_t err, + const char* filename, + const char* function_name, + uint32_t line_number, + bool include_device_assertions, + char** error_msg); + +// Free error message allocated by torch_c10_cuda_check_msg +AOTI_TORCH_EXPORT void torch_c10_cuda_free_error_msg(char* error_msg); + #endif // USE_CUDA #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 diff --git a/torch/csrc/stable/macros.h b/torch/csrc/stable/macros.h new file mode 100644 index 0000000000000..c06e9f0f541c8 --- /dev/null +++ b/torch/csrc/stable/macros.h @@ -0,0 +1,26 @@ +#include + +#include +#include + +// Users of this macro are expected to include cuda_runtime.h +#define STD_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + char* __error_msg = nullptr; \ + torch_c10_cuda_check_msg( \ + static_cast(__err), \ + __FILE__, \ + __func__, \ + static_cast(__LINE__), \ + true, \ + &__error_msg); \ + if (__error_msg != nullptr) { \ + std::string __msg(__error_msg); \ + torch_c10_cuda_free_error_msg(__error_msg); \ + throw std::runtime_error(__msg); \ + } \ + } while (0) + +// Users of this macro are expected to include cuda_runtime.h +#define STD_CUDA_KERNEL_LAUNCH_CHECK() STD_CUDA_CHECK(cudaGetLastError()) From 6f4baa5abe393cba7792edca4d3b418a630e1ba2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 20:17:50 +0000 Subject: [PATCH 277/338] Revert "[ROCm] Enable StaticCudaLauncher for ROCm (#166492)" This reverts commit 18f3ca08f13b8de61307f5e8cd7d4cccb67e9d11. Reverted https://github.com/pytorch/pytorch/pull/166492 on behalf of https://github.com/huydhn due to Sorry for reverthing this but the change has been reverted by an internal meta team ([comment](https://github.com/pytorch/pytorch/pull/166492#issuecomment-3614162658)) --- test/inductor/test_ck_backend.py | 1 - test/inductor/test_codecache.py | 9 +- test/inductor/test_static_cuda_launcher.py | 21 +++- .../_inductor/runtime/static_cuda_launcher.py | 56 ++-------- torch/_inductor/runtime/triton_heuristics.py | 11 +- torch/_inductor/utils.py | 5 - torch/csrc/Module.cpp | 2 +- torch/csrc/inductor/static_cuda_launcher.cpp | 102 ++---------------- torch/csrc/inductor/static_cuda_launcher.h | 2 +- 9 files changed, 41 insertions(+), 168 deletions(-) diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 405e46d8ded52..079be79fcc9d8 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -235,7 +235,6 @@ def mm(a, b): Y_eager = a @ b torch.testing.assert_close(Y_compiled, Y_eager, equal_nan=True) - @unittest.skip("Autotune Mismatch being investigated") @unittest.skipIf(not torch.version.hip, "ROCM only") @unittest.mock.patch.dict(os.environ, _test_env) @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index e86a673ad813f..1ab261051f4c6 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -479,17 +479,14 @@ def test_remote_cache_load_function( if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if ( - device == "cuda" - and torch.version.hip is None - and dtype == torch.bfloat16 - and not SM80OrLater - ): + if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: raise unittest.SkipTest("requires SM80 or later") if use_static_cuda_launcher and not (device == "cuda" and bundle_triton): raise unittest.SkipTest( "Static cuda launcher requires cuda and triton bundling" ) + if use_static_cuda_launcher and TEST_WITH_ROCM: + raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM") def fn(x, y): return (x * 2, y @ y) diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index ec9586197d085..654bfd269f761 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -12,6 +12,7 @@ from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase +from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -38,9 +39,8 @@ def write_cubin_to_tmp(self, kernel: CompiledKernel) -> str: # Just used by tests for now. # TODO: derive cubin_path from wherever triton stores the cubin file on disk. tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False) - binary_key = "hsaco" if torch.version.hip else "cubin" with tmp_file: - tmp_file.write(kernel.asm[binary_key]) + tmp_file.write(kernel.asm["cubin"]) self.tmp_files.append(tmp_file) return tmp_file.name @@ -64,6 +64,7 @@ def _make_launcher( result.load_kernel(device_interface.current_device()) return result + @skipIfRocm def test_basic(self): @triton.jit def simple_kernel(arg0, arg1): @@ -90,6 +91,7 @@ def simple_kernel(arg0, arg1): # 2. triton relies on inspect.get_source to get the type annotations # so I can't even use exec() to generate the test cases. # So we'll just make a few kernels by hand + @skipIfRocm def test_unsigned_integers(self): @triton.jit def unsigned_integers( @@ -113,6 +115,7 @@ def unsigned_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_signed_integers(self): @triton.jit def signed_integers( @@ -136,6 +139,7 @@ def signed_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_basic_1arg(self): @triton.jit def simple_kernel_1_arg(arg0): @@ -160,6 +164,7 @@ def simple_kernel_1_arg(arg0): ) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_constexpr(self): # Constexprs are compiled directly into the cubin file, # so we never need to pass it to StaticCudaLauncher. @@ -188,6 +193,7 @@ def kernel_constexpr(arg0, CONSTANT: tl.constexpr): ) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_implied_constant(self): """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr""" @@ -240,6 +246,7 @@ def triton_red_fused_any_isinf_0( launcher.run(1, 1, 1, stream, arg0, arg2, 128) self.assertEqual(arg1, arg2) + @skipIfRocm def test_kernel_no_args(self): # Just an easy way to test incompatible number of arguments @triton.jit @@ -252,6 +259,7 @@ def kernel_no_op(): stream = device_interface.get_raw_stream(device_interface.current_device()) launcher.run(1, 1, 1, stream) + @skipIfRocm def test_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -275,6 +283,7 @@ def simple_kernel(arg0, arg1): launcher.run(1, 1, 1, stream, new_arg0, arg1) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_too_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -294,6 +303,7 @@ def simple_kernel(arg0, arg1): lambda: self._make_launcher(compiled_kernel), ) + @skipIfRocm def test_kernel_empty_tensor(self): # Triton kernel generated by torch.compile of the following: # @torch.compile() @@ -354,6 +364,7 @@ def triton_poi_fused_cat_0( launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel) self.assertEqual(buf0, buf1) + @skipIfRocm def test_kernel_many_args(self): N = 200 # Make 200 arguments @@ -394,6 +405,7 @@ class TestStaticTritonCompileResult(TestCase): Tests static cuda launcher with torch.compile() """ + @skipIfRocm def test_basic_compile(self): @torch.compile def foo(x, y): @@ -403,6 +415,7 @@ def foo(x, y): y = torch.randn(10, device="cuda") self.assertEqual(foo(x, y), x + y) + @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch("compile_threads", 1) def test_incompatible_code(self): @@ -425,6 +438,7 @@ def foo(x): lambda: foo(x), ) + @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch( {"compile_threads": 1, "static_launch_user_defined_triton_kernels": True} @@ -446,6 +460,7 @@ def foo(x): x2 = x.clone().detach_() self.assertEqual(foo(x), x2 + 5) + @skipIfRocm def test_empty_tensor(self): @torch.compile() def foo(x, y): @@ -457,6 +472,7 @@ def foo(x, y): result = foo(x, y) self.assertEqual(result, torch.cat(((x * 4), y + 10))) + @skipIfRocm def test_any(self): def fn(x): return ( @@ -476,6 +492,7 @@ def fn(x): compiled_result = compiled_fn(arg) self.assertEqual(eager_result, compiled_result) + @skipIfRocm def test_disable_static_cuda_launcher(self): @torch.compile def fn(x, y): diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index a53ef35f4cf83..f48f351ce823a 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -3,7 +3,6 @@ from typing import Any from typing_extensions import Unpack -from ..utils import is_rocm from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs from .triton_helpers import get_constexprs @@ -39,20 +38,7 @@ def __init__(self, kernel: CompiledKernel) -> None: # pyrefly: ignore [missing-attribute] self.name = kernel.src.fn.__name__ # pyrefly: ignore [missing-attribute] - if "hsaco" in kernel.asm: - # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm["hsaco"] - - # pyrefly: ignore [missing-attribute] - elif "cubin" in kernel.asm: - # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm["cubin"] - - else: - raise RuntimeError( - "Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm" - ) - + self.cubin_raw = kernel.asm.get("cubin", None) # pyrefly: ignore [missing-attribute] self.cubin_path = kernel._cubin_path @@ -259,42 +245,12 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher + # Add a None if triton wants extra parameters for scratch spaces arg_tys = self.arg_tys - - if is_rocm(): - # ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both - # global_scratch and profile_scratch parameters in the kernel signature, - # even when the kernel doesn't use them (i.e., when has_*_scratch is False). - # - # This differs fundamentally from CUDA, where these parameters are only - # present in the signature if the corresponding has_*_scratch flag is True. - # - # The flags indicate whether memory will be allocated/used: - # - has_global_scratch: Whether global scratch workspace is needed - # - has_profile_scratch: Whether profiling instrumentation is enabled - # - # However, regardless of flag values, we MUST always pass both parameters - # to match the HIP kernel ABI. Passing None is safe: - # - # - If scratch is not needed (has_*_scratch=False or scratch_size=0): - # The None becomes nullptr, which the kernel never dereferences - # - # - If scratch is needed (has_*_scratch=True and scratch_size>0): - # The None becomes nullptr initially, but the HIP runtime intercepts - # the kernel launch, allocates the required scratch memory based on - # kernel metadata, and replaces the nullptr with a valid pointer before - # the kernel actually executes - # - # Not passing both parameters causes segmentation faults because the kernel - # expects them at specific positions in the argument array. - arg_tys = arg_tys + "OO" - args = (*args, None, None) - - else: - for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: - if has_scratch: - arg_tys = arg_tys + "O" - args = (*args, None) + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) # pyrefly: ignore [bad-argument-type] assert len(args) == len(arg_tys) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 2e2cd8a8db780..5a37a0afccb34 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1613,8 +1613,9 @@ def can_statically_launch( return None def check_can_launch() -> StaticallyLaunchedCudaKernel: - if triton_meta.get("device_type") not in ("cuda", "hip"): - raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device") + if triton_meta.get("device_type") != "cuda": + # Only cuda kernels + raise CannotStaticallyLaunchKernel("Non-cuda device") if torch._inductor.config.cpp_wrapper: # If we're running with cpp wrapper, it doesn't @@ -1640,11 +1641,10 @@ def check_can_launch() -> StaticallyLaunchedCudaKernel: "static launch does not support launch attributes" ) - binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin" cubin_location = os.path.join( triton_cache_dir(triton_meta.get("device", 0)), triton_hash_to_path_key(kernel.hash), - f"{kernel.src.fn.__name__}.{binary_ext}", + f"{kernel.src.fn.__name__}.cubin", ) if not os.path.exists(cubin_location): @@ -1676,11 +1676,10 @@ def reload_cubin_path(self): When loading from cache on disk, we want to reload cubin files from their appropriate location on disc. """ - binary_ext = "hsaco" if torch.version.hip else "cubin" cubin_location = os.path.join( triton_cache_dir(self.compile_meta.get("device", 0)), triton_hash_to_path_key(self.kernel.hash), - f"{self.kernel.name}.{binary_ext}", + f"{self.kernel.name}.cubin", ) if not os.path.exists(cubin_location): if self.kernel.cubin_raw is not None: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0cafbed3a00c3..a91c350a522c8 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3053,11 +3053,6 @@ def is_gpu(device: Optional[str]) -> bool: return device in GPU_TYPES -def is_rocm() -> bool: - """Check if we're running on ROCm/HIP platform.""" - return torch.version.hip is not None - - def device_need_guard(device: str) -> bool: return device != "mps" and is_gpu(device) # TODO: MPS does not expose streams now diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 6a9e2ca842050..4c304c27bfa19 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2154,7 +2154,7 @@ PyObject* initModule() { #ifdef USE_CUDA torch::cuda::initModule(module); #endif -#if defined(USE_CUDA) +#if defined(USE_CUDA) && !defined(USE_ROCM) ASSERT_TRUE(StaticCudaLauncher_init(module)); #endif #ifdef USE_MPS diff --git a/torch/csrc/inductor/static_cuda_launcher.cpp b/torch/csrc/inductor/static_cuda_launcher.cpp index da61cd28c1b6f..59916b6763bfa 100644 --- a/torch/csrc/inductor/static_cuda_launcher.cpp +++ b/torch/csrc/inductor/static_cuda_launcher.cpp @@ -1,4 +1,7 @@ -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) && !defined(USE_ROCM) +// We disable this file from being hipified because there are CUDA drivers hip +// has not implemented yet. Also, we're passing in a cubin file directly, so it +// would take more work to support ROCM anyway. #include #include @@ -13,11 +16,6 @@ #include #include #include - -#if defined(USE_ROCM) -#include -#endif - /** Implements a static launcher for triton compiled CUDA kernels. Given a path to a cubin file, a function name, and some metadata, @@ -58,14 +56,8 @@ const at::cuda::NVRTC& nvrtc() { CUdeviceptr getPointer(PyObject* obj) { CUdeviceptr data_ptr = 0; - if (THPUtils_checkLong(obj)) { -#if defined(USE_ROCM) - data_ptr = reinterpret_cast(THPUtils_unpackUInt64(obj)); -#else data_ptr = THPUtils_unpackUInt64(obj); -#endif - return data_ptr; } if (obj == Py_None) { @@ -81,25 +73,13 @@ CUdeviceptr getPointer(PyObject* obj) { TORCH_CHECK( THPUtils_checkLong(ret), "data_ptr method of Pointer object must return 64-bit int"); - -#if defined(USE_ROCM) - data_ptr = reinterpret_cast(THPUtils_unpackUInt64(ret)); -#else data_ptr = THPUtils_unpackUInt64(ret); -#endif - if (!data_ptr) return data_ptr; CUdeviceptr dev_ptr = 0; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute( - &dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute( &dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); -#endif - return dev_ptr; } @@ -118,15 +98,6 @@ CUfunction loadKernel( } CUmodule mod = nullptr; CUfunction func = nullptr; - -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str())); - AT_CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str())); - int shared_optin = 0; - AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( - &shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device)); - -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str())); AT_CUDA_DRIVER_CHECK( nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str())); @@ -135,9 +106,6 @@ CUfunction loadKernel( &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); - -#endif - // Shared memory logic from triton/third-party/nvidia/backend/driver.c // If we're using more than 48 KB of shared memory, and we have // access to more than 48 KB of shared memory on the device, @@ -156,21 +124,6 @@ CUfunction loadKernel( " Reducing block sizes or `num_stages` may help."); if (sharedMemBytes > SHARED_MEM_STATIC_MAX && shared_optin > SHARED_MEM_STATIC_MAX) { -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipFuncSetCacheConfig(func, hipFuncCachePreferShared)); - int shared_total = 0, shared_static = 0; - AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( - &shared_total, - hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, - device)); - AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( - &shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); - AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute( - func, - hipFuncAttributeMaxDynamicSharedMemorySize, - shared_optin - shared_static)); - -#else AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED)); int shared_total = 0, shared_static = 0; @@ -184,7 +137,6 @@ CUfunction loadKernel( func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); -#endif } return func; } @@ -200,27 +152,6 @@ inline void launchKernel( cudaStream_t stream) { // cta_args is always 1 for inductor generated triton kernels, // so we don't need to figure out grid dimension here -#if defined(USE_ROCM) - int device = 0; - AT_CUDA_DRIVER_CHECK(hipGetDevice(&device)); - int warp_size = 0; - AT_CUDA_DRIVER_CHECK( - hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device)); - - AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel( - func, - gridX, - gridY, - gridZ, - warp_size * numWarps, // blockDim.x - 1, // blockDim.y - 1, // blockDim.z - sharedMemBytes, - stream, - args, - nullptr)); - -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( func, gridX, @@ -233,7 +164,6 @@ inline void launchKernel( stream, args, nullptr)); -#endif } template @@ -339,20 +269,11 @@ PyObject* load_kernel(PyObject* self, PyObject* args) { CUdevice device = static_cast(device_ptr); // NOLINT CUfunction func = nullptr; func = loadKernel(filePath, funcName, sharedMemBytes, device); - -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK( - hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func)); - AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( - &n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); - -#else + // Taken from triton/nvidia/backend/driver.c AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func)); AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute( &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); - -#endif n_spills /= 4; // Return a tuple of CUFunction, n_regs, n_spills return Py_BuildValue( @@ -378,6 +299,7 @@ PyObject* launch_kernel_inner( std::array argStorage = {}; std::array kernelArgs = {}; parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data()); + launchKernel( func, gridX, @@ -464,25 +386,13 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) { Py_RETURN_NONE; } CUcontext pctx = nullptr; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); -#endif - if (!pctx) { // Ensure device context exists CUdevice device = 0; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0)); - AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device)); - AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0)); AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device)); AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx)); - -#endif } CUfunction func = reinterpret_cast(func_ptr); // NOLINT cudaStream_t cudaStream = reinterpret_cast(stream); // NOLINT diff --git a/torch/csrc/inductor/static_cuda_launcher.h b/torch/csrc/inductor/static_cuda_launcher.h index 6f3980172275b..517036b9975e6 100644 --- a/torch/csrc/inductor/static_cuda_launcher.h +++ b/torch/csrc/inductor/static_cuda_launcher.h @@ -1,5 +1,5 @@ #pragma once -#if defined(USE_CUDA) +#if defined(USE_CUDA) && !defined(USE_ROCM) #include #include From e5fd7b7ac828e9b7abbae96243e3dd0b26ce488e Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 4 Dec 2025 20:25:02 +0000 Subject: [PATCH 278/338] Add a single GPU variant of modded-nanogpt to torchbench (#169502) (#169505) Summary: ## Tests Standalone: `python -m torchbenchmark.models.modded_nanogpt.main` Through dynamo benchmarks: `python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --only modded_nanogpt --disable-cudagraphs` This PR adds a tweaked version of the Aug 23rd record for the nanogpt speedrun (GPT-2 small variant): https://github.com/KellerJordan/modded-nanogpt/blob/9d9dc969c451c87b7ad3c84f807db2c2d9109f41/train_gpt.py. The later records can not be ran without building FA3 from source, so we will ommit them until the dynamo FA3 PR is merged. The tweaks are to library-ify the script by commenting out everything other than the model class definitions, to change the pg initialization to use fake pg, and constant-ify some hyperparameters. The tests run locally, but this model specifically requires H100. I wasn't sure how to filter for that, so I skipped all the tests. This will be tested on the dynamo benchmark side: https://github.com/pytorch/pytorch/pull/169449. X-link: https://github.com/pytorch/benchmark/pull/2660 Differential Revision: D88233265 Pulled By: xmfan Pull Request resolved: https://github.com/pytorch/pytorch/pull/169505 Approved by: https://github.com/BoyuanFeng --- benchmarks/dynamo/torchbench.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index ac4ddb4088416..f836dff3e52ec 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -172,6 +172,10 @@ def force_amp_for_fp16_bf16_models(self): def force_fp16_for_bf16_models(self): return self._config["dtype"]["force_fp16_for_bf16_models"] + @property + def amp_dtype_bfloat16(self): + return self._config["dtype"]["amp_dtype_bfloat16"] + @property def skip_accuracy_checks_large_models_dashboard(self): if self.args.dashboard or self.args.accuracy: From 36f60a9a0b532a580cd76638760b73849f47a7dd Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Thu, 4 Dec 2025 20:27:50 +0000 Subject: [PATCH 279/338] [dde] use is_contiguous_or_false for function with out= arg (#169305) Summary: Avoid DDE scenario by replacing `is_contiguous` -> `is_contiguous_or_false`. ### Codex prompt ``` Write a small unit test that tests that hits a data-dependent exception in the function is_contiguous specifically at `if maybe_guard_or_false(a.numel() < 2):` https://github.com/pytorch/pytorch/blob/481e5ab336275bd3acd5fa8a611b05b4469012af/torch/_prims_common/__init__.py#L314 Use torch.fmod to help you write the unit test. Write the unit test in https://github.com/pytorch/pytorch/blob/481e5ab336275bd3acd5fa8a611b05b4469012af/test/inductor/test_unbacked_symints.py#L4 I expect the unit test to fail with "Could not guard on data-dependent expression" Requirements: - Try to make the unit test the smallest reproducible scenario. - Follow the conventions in test_unbacked_symints.py. - The test must reproduce the error that says: "Could not guard on data-dependent expression" - It must call torch.fmod Here is a stack trace to help you reproduce the error: File "/re_cwd/re-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 1285, in call_function self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] File "/re_cwd/re-inplace#link-tree/torch/_dynamo/variables/lazy.py", line 218, in realize_and_forward return getattr(self.realize(), name)(*args, **kwargs) File "/re_cwd/re-inplace#link-tree/torch/_dynamo/variables/torch.py", line 1663, in call_function if not torch._prims_common.is_contiguous(fake_out): File "/re_cwd/re-inplace#link-tree/torch/_prims_common/__init__.py", line 314, in is_contiguous if maybe_guard_or_false(a.numel() < 2): File "/re_cwd/re-inplace#link-tree/torch/_prims_common/__init__.py", line 310, in eval_eager return bool(x) File "/re_cwd/re-inplace#link-tree/torch/__init__.py", line 763, in __bool__ return self.node.bool_() File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/sym_node.py", line 602, in bool_ return self.guard_bool("", 0) File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/sym_node.py", line 538, in guard_bool r = self.evaluate() File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/sym_node.py", line 512, in evaluate return self.shape_env.evaluate_sym_node(self, size_oblivious) File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7308, in evaluate_sym_node return self.evaluate_expr( File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7408, in evaluate_expr return self._inner_evaluate_expr( File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/recording.py", line 273, in wrapper return retlog(fn(*args, **kwargs)) File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7431, in _inner_evaluate_expr return self._evaluate_expr( File "/re_cwd/re-inplace#link-tree/torch/fx/experimental/symbolic_shapes.py", line 7650, in _evaluate_expr raise self._make_data_dependent_error( torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression u11 < 2 (unhinted: u11 < 2). (Size-like symbols: u11) Caused by: fmod = torch.fmod(hash, 100, out = hash) in forward (_prims_common/__init__.py:310 in eval_eager) Verify it works by running this test cmd: python test/inductor/test_unbacked_symints.py -k ``` Differential Revision: D88087007 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169305 Approved by: https://github.com/Lucaskabela --- test/inductor/test_unbacked_symints.py | 12 ++++++++++++ torch/_dynamo/variables/torch.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 04c8c0573e99d..702d7c932748c 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -674,6 +674,18 @@ def fn(x, y, a): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") + @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) + def test_fmod_with_out_arg(self, device): + def fn(x): + nz = torch.nonzero(x).float() + return torch.fmod(nz, 2.0, out=nz) + + example_inputs = (torch.randn(32, device=device),) + actual = torch.compile(fn, fullgraph=True)(*example_inputs) + expected = fn(*example_inputs) + torch.testing.assert_close(actual, expected) + instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 3d0541dacfd6f..edc0921156394 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1679,7 +1679,7 @@ def call_function( *graph_break_hints.SUPPORTABLE, ], ) - if not torch._prims_common.is_contiguous(fake_out): + if not torch._prims_common.is_contiguous_or_false(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( From f575ecb83c0fe1ebad53f39c78c85e2701a1332b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 20:34:06 +0000 Subject: [PATCH 280/338] Revert "Remove unnecessary uses of thrust::tuple (#168936)" This reverts commit d19f1e8cab6810bb2e99141f9976665954c67a50. Reverted https://github.com/pytorch/pytorch/pull/168936 on behalf of https://github.com/malfet due to It'll break internal ROCM builds again ([comment](https://github.com/pytorch/pytorch/pull/168936#issuecomment-3614214438)) --- aten/src/ATen/native/cuda/ActivationEluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationGeluKernel.cu | 1 + aten/src/ATen/native/cuda/ActivationGluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardswishKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationMishKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationSiluKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu | 2 ++ aten/src/ATen/native/cuda/ActivationThresholdKernel.cu | 2 ++ aten/src/ATen/native/cuda/Loops.cuh | 2 +- aten/src/ATen/native/cuda/group_norm_kernel.cu | 1 + aten/src/ATen/native/cuda/layer_norm_kernel.cu | 3 ++- 17 files changed, 31 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/cuda/ActivationEluKernel.cu b/aten/src/ATen/native/cuda/ActivationEluKernel.cu index 9fc29aa5539b5..5ad1f806f9ba5 100644 --- a/aten/src/ATen/native/cuda/ActivationEluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationEluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu index 87781c44e3348..cd5a0ae85e61c 100644 --- a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu @@ -5,6 +5,7 @@ #include +#include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGluKernel.cu b/aten/src/ATen/native/cuda/ActivationGluKernel.cu index 8a782a129c9fb..e28a6d61ea152 100644 --- a/aten/src/ATen/native/cuda/ActivationGluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu index f0968b957aa6d..2a0be3f5d27bf 100644 --- a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu index 813a8c07ccfac..fcacef37ceaf0 100644 --- a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu index 651cdef82543b..1642d0909f7f0 100644 --- a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu index 85aa7ccd22a9e..a18072f7a27bc 100644 --- a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu index 340a6f97d00de..72130739898fe 100644 --- a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu index 2175920917852..9a1d672428b48 100644 --- a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationMishKernel.cu b/aten/src/ATen/native/cuda/ActivationMishKernel.cu index 25ba9810e37cf..0db0e96bb180a 100644 --- a/aten/src/ATen/native/cuda/ActivationMishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationMishKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu index ebdfe245b6166..f7ddfd8502a18 100644 --- a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu index 65f4f3679f862..64ffc21123707 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu index 712c86e0e5216..0c2dc63dbcf45 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu index 430f9cbfa78bb..2d1cb4a47d7d8 100644 --- a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu @@ -5,6 +5,8 @@ #include +#include + #include #include #include diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index e739d7d2ecee2..a80c51fa6a9cb 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -282,7 +282,7 @@ void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) using traits = function_traits; using output_t = typename traits::result_type; static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); - constexpr int num_outputs = std::tuple_size::value; + constexpr int num_outputs = thrust::tuple_size::value; constexpr int num_inputs = traits::arity; constexpr int ntensors = num_outputs + num_inputs; diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index 0ef6434f909de..77d26e915b65a 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -3,6 +3,7 @@ #include +#include #include #include diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 6f5112c605fab..84812eb22125f 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -1,9 +1,10 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include #include +#include + #include #include #include From 00279ab2660d3d7a350e11be93efd72cdd1a0579 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 4 Dec 2025 20:48:59 +0000 Subject: [PATCH 281/338] [annotate] Change AutogradFunctionApply HOP to use Interpreter to preserve annotation (#169528) Similar to https://github.com/pytorch/pytorch/pull/165336 We have an issue when using fx_traceback.annotate and tracing HOPs. HOPs have bodies that have already been traced by Dynamo, and have the annotations. But when we lower that Dynamo HOP body to aten, we need to propagate the annotations to the aten nodes. To do this, we need to run the graph in Interpreter mode so the node meta can be propagated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169528 Approved by: https://github.com/mlazos, https://github.com/xmfan --- test/dynamo/test_streams.py | 81 +++++++++++++++++++++++++++ torch/_functorch/autograd_function.py | 13 ++++- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index c594c87b7f1b7..ba151f63c5d3c 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -872,6 +872,87 @@ def forward(self, primals_1: "f32[2, 2]", mul: "f32[2, 2]", tangents_1: "f32[2, """, ) + @requires_cuda + def test_epilogue_copy_stream_tracking(self): + """ + Test that epilogue copies for mutated inputs use the correct stream. + This verifies that ViewAndMutationMeta.mutated_inp_stream_indices is + properly populated and used at runtime. + Uses a custom autograd.Function where the backward mutates a saved + tensor on a specific stream. + """ + + class BwMutationWithStream(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x) + ctx.s1 = torch.Stream(device="cuda:0") + ctx.s2 = torch.Stream(device="cuda:0") + # Do computation on stream s2 + with ctx.s2: + result = x * 2 + y + return result + + @staticmethod + def backward(ctx, grad_output): + (x,) = ctx.saved_tensors + # Mutate saved tensor x on stream s1 in backward + with ctx.s1: + x.mul_(2) + # Compute gradients on stream s2 + with ctx.s2: + grad_x = grad_output * 2 + grad_y = grad_output.clone() + return grad_x, grad_y, None, None + + def fn(x, y): + result = BwMutationWithStream.apply(x, y) + return result + + x = torch.ones(2, 2, requires_grad=True, device="cuda:0") + y = torch.ones(2, 2, requires_grad=True, device="cuda:0") + ( + actual, + _, + fw_graphs, + bw_graphs, + ) = extract_graph(fn, x.clone(), y.clone()) + self.assertEqual(len(fw_graphs), 1) + # Forward graph should show computation on stream 1 (s2) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): + # Annotation: {'stream': 1} + mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2) + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None + return (add, primals_1, mul) +""", + ) + # Run backward and check that the epilogue copy uses stream 0 (s1) + actual.sum().backward() + # The backward graph should show: + # 1. Mutation happening on stream 0 (s1) + # 2. Gradient computation on stream 1 (s2) + # 3. Epilogue copy for the mutated tensor on stream 0 (s1) + self.assertExpectedInline( + print_graph(bw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[2, 2]", mul: "f32[2, 2]", tangents_1: "f32[2, 2]"): + # Annotation: {'stream': 1} + mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2) + + # Annotation: {'stream': 1} + clone: "f32[2, 2]" = torch.ops.aten.clone.default(tangents_1); tangents_1 = None + + # No stacktrace found for following nodes + copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = copy_ = None + return (mul_2, clone) +""", + ) + @requires_cuda def test_inductor_lowering(self): with patch("torch._inductor.config.implicit_fallbacks", False): diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 3f4c1a4979446..ca7376cf9620c 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -756,7 +756,11 @@ class ApplyTemplate(torch.autograd.Function): # pyrefly: ignore [bad-override] def forward(ctx, *args): nonlocal saved_values - output, saved_values = fwd(None, *fwd_args) + + # The Interpreter here is required to propagate metadata + # from the dynamo graph body to the local_map graph body. + # This is required for fx_traceback.annotate for work. + output, saved_values = torch.fx.Interpreter(fwd).run(None, *fwd_args) # If users call ctx.mark_non_differentiable() in the original fwd function. if len(non_differentiable_idx) > 0: @@ -770,7 +774,12 @@ def forward(ctx, *args): @staticmethod def backward(ctx, *grad): - return bwd(None, *grad, *saved_values) + # The Interpreter here is required to propagate metadata + # from the dynamo graph body to the local_map graph body. + # This is required for fx_traceback.annotate for work. + + # pyrefly: ignore [not-iterable] + return torch.fx.Interpreter(bwd).run(None, *grad, *saved_values) return ApplyTemplate.apply(*new_fwd_args) From e64f1eece02ceb068ed0f5bbe5acda5259840a09 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Thu, 4 Dec 2025 20:49:37 +0000 Subject: [PATCH 282/338] expandable_segments + memory pool (#169491) Fixes #147851 Please also see #165419 and #148378 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169491 Approved by: https://github.com/ngimel --- c10/cuda/CUDACachingAllocator.cpp | 46 +++++++++++++++++++++---------- test/test_cuda.py | 26 +++++++++++++---- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 3d1837061e7b2..9e637f4f6997e 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -863,8 +863,12 @@ struct AllocParams { size_t size, cudaStream_t stream, BlockPool* pool, - size_t alloc_size) - : search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {} + size_t alloc_size, + bool is_expandable_segments_active) + : search_key(device, stream, size), + pool(pool), + alloc_size(alloc_size), + is_expandable_segments_active(is_expandable_segments_active) {} c10::DeviceIndex device() const { return search_key.device; @@ -879,6 +883,7 @@ struct AllocParams { Block search_key; BlockPool* pool; size_t alloc_size; + bool is_expandable_segments_active; Block* block{nullptr}; StatTypes stat_types = {false}; cudaError_t err{cudaSuccess}; @@ -1381,7 +1386,18 @@ class DeviceCachingAllocator { size_t size = round_size(orig_size); auto& pool = get_pool(size, stream); const size_t alloc_size = get_allocation_size(size); - AllocParams params(device_id, size, stream, &pool, alloc_size); + bool active_user_pool = + pool.owner_PrivatePool && pool.owner_PrivatePool->allocator(); + // The expandable segments are only active on the default pool. + bool is_expandable_segments_active = + CUDAAllocatorConfig::expandable_segments() && !active_user_pool; + AllocParams params( + device_id, + size, + stream, + &pool, + alloc_size, + is_expandable_segments_active); params.stat_types = get_stat_types_for_pool(pool); // First, try to get a block from the existing pool. @@ -1429,7 +1445,7 @@ class DeviceCachingAllocator { beginAllocateToPool(mempool_id, filter); auto& mempool = get_pool(size, stream); AllocParams mempool_params( - device_id, size, stream, &mempool, alloc_size); + device_id, size, stream, &mempool, alloc_size, false); mempool_params.stat_types = get_stat_types_for_pool(mempool); block_found = get_free_block(mempool_params); endAllocateToPool(mempool_id); @@ -1565,7 +1581,8 @@ class DeviceCachingAllocator { " (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"); } - bool split_remainder = should_split(params.block, params.size()); + bool split_remainder = should_split( + params.block, params.size(), params.is_expandable_segments_active); return alloc_found_block( params, orig_size, std::move(context), split_remainder); } @@ -2222,7 +2239,8 @@ class DeviceCachingAllocator { block_state.size, block_state.stream, &pool, - block_state.size); + block_state.size, + curr_block->expandable_segment_ != nullptr); pool.blocks.erase(curr_block); params.block = curr_block; params.stat_types = get_stat_types_for_pool(pool); @@ -2993,9 +3011,12 @@ class DeviceCachingAllocator { return stat_types; } - bool should_split(const Block* block, size_t size) { + bool should_split( + const Block* block, + size_t size, + bool is_expandable_segments_active) { size_t remaining = block->size - size; - if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) { + if (block->pool->is_small || is_expandable_segments_active) { return remaining >= kMinBlockSize; } else { return (size < AcceleratorAllocatorConfig::max_split_size()) && @@ -3027,7 +3048,7 @@ class DeviceCachingAllocator { return false; if ((*it)->expandable_segment_) { - if (CUDAAllocatorConfig::expandable_segments()) { + if (p.is_expandable_segments_active) { // if we are allocated to the part of the block that is expandable // for the purposes of "best fit" we consider its size to be the size it // can expand to, not the size it currently is. This means that we @@ -3166,19 +3187,14 @@ class DeviceCachingAllocator { bool in_fbcode = false; #endif - bool active_pool = - p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator(); if (allowed_memory_maximum.has_value() && total_allocated_memory + size > allowed_memory_maximum.value()) { p.err = cudaErrorMemoryAllocation; return false; // Temporarily disable checkpointing & cudagraphs internally } else if ( - CUDAAllocatorConfig::expandable_segments() && + p.is_expandable_segments_active && !(in_fbcode && p.pool->owner_PrivatePool)) { - TORCH_CHECK( - !active_pool, - "torch.cuda.MemPool doesn't currently support expandable_segments."); p.block = try_allocate_expandable_block( p.device(), p.stream(), p.pool, p.size(), ctx); if (p.block) { diff --git a/test/test_cuda.py b/test/test_cuda.py index 1ad9769072c23..21098ae096cc9 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -5887,15 +5887,31 @@ def test_graph_capture_reclaim_4_streams(self): @skipIfRocm(msg="expandable_segments mode is not supported on ROCm") @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode") def test_mempool_expandable(self): + torch.cuda.empty_cache() torch.cuda.memory._set_allocator_settings("expandable_segments:True") allocator, _ = self.get_dummy_allocator(check_vars=False) pool = torch.cuda.MemPool(allocator.allocator()) - # torch.cuda.MemPool doesn't work with expandable segments - with self.assertRaises(RuntimeError): - nelem_1mb = 1024 * 1024 // 4 - with torch.cuda.use_mem_pool(pool): - out_0 = torch.randn(nelem_1mb, device="cuda") + data = [] + nelem = 1024 * 1024 // 4 + with torch.cuda.use_mem_pool(pool): + data.append(torch.empty(nelem, device="cuda")) + + # the second allocation should be in expandable segment + data.append(torch.empty(nelem, device="cuda")) + + segments = torch.cuda.memory.memory_snapshot() + + num_expandable_segments = 0 + for segment in segments: + if segment["is_expandable"]: + num_expandable_segments += 1 + + self.assertEqual(len(segments), 2, "Expected to have 2 segment") + self.assertEqual( + num_expandable_segments, 1, "Expected to have 1 expandable segment only" + ) + torch.cuda.memory._set_allocator_settings("expandable_segments:False") @serialTest() From 5aad0e1d19529d6a95c3cefb087b94d5b6f688c3 Mon Sep 17 00:00:00 2001 From: angelayi Date: Thu, 4 Dec 2025 09:25:54 -0800 Subject: [PATCH 283/338] [opaque obj] Improve error msg for intermediate opaques (#167742) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167742 Approved by: https://github.com/zou3519 --- test/test_opaque_obj_v2.py | 20 +++++++++++++++++++- torch/_dynamo/graph_break_registry.json | 11 +++++++++++ torch/_dynamo/variables/torch.py | 23 +++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py index 99ff9058eda52..3015defd88349 100644 --- a/test/test_opaque_obj_v2.py +++ b/test/test_opaque_obj_v2.py @@ -6,6 +6,7 @@ import torch from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import AotEagerAndRecordGraphs +from torch._dynamo.utils import counters as dynamo_counters from torch._functorch.aot_autograd import ( aot_compile_joint_with_descriptors, aot_export_joint_with_descriptors, @@ -376,7 +377,7 @@ def forward(self, arg0_1, arg1_1): return (add,)""", # noqa: B950 ) - def test_compile_intermediate(self): + def test_compile_global(self): counter = Counter(0) def foo(x, y): @@ -417,6 +418,23 @@ def forward(self, arg0_1, arg1_1, arg2_1): return (add,)""", # noqa: B950 ) + def test_compile_create_intermediate(self): + dynamo_counters.clear() + + def foo(x, y): + counter = Counter(0) + z = torch.ops._TestOpaqueObject.increment_counter(counter, y) + x = x * z + return x + + inp = (torch.tensor(1), torch.tensor(0)) + torch.compile(foo)(*inp) + self.assertEqual(len(dynamo_counters["graph_break"]), 1) + self.assertTrue( + "Opaque object were created in the middle of the program and passed to a custom op." + in next(iter(dynamo_counters["graph_break"].keys())), + ) + def test_compile_attribute(self): counter = Counter(0) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index a5c1d22eea1fd..dd012a239bb23 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3667,5 +3667,16 @@ "Use custom operators instead of direct attribute/method access." ] } + ], + "GB0363": [ + { + "Gb_type": "Opaque object were created in the middle of the program and passed to a custom op.", + "Context": "Opaque object types: {intermediate_opaques}. Function: {self.value}", + "Explanation": "Opaque objects cannot be created inside the torch.compile region. They must be created before entering the compiled function.", + "Hints": [ + "Please create the opaque object before calling torch.compile ", + "and pass it in as an argument or as a global variable." + ] + } ] } diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index edc0921156394..a4f940cb2adaf 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -41,6 +41,7 @@ import torch.fx import torch.nn from torch._guards import TracingContext +from torch._library.opaque_object import is_opaque_type from torch._logging import warning_once from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type @@ -86,6 +87,7 @@ TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) +from .user_defined import UserDefinedObjectVariable try: @@ -1507,6 +1509,27 @@ def call_function( ) return self.call_tensor_method(tx, args, kwargs) + intermediate_opaques = [ + type(x.value) + for x in args + if x.source is None + and isinstance(x, UserDefinedObjectVariable) + and is_opaque_type(type(x.value)) + ] + if len(intermediate_opaques) > 0: + unimplemented( + gb_type="Opaque object were created in the middle of the program and passed to a custom op.", + context=f"Opaque object types: {intermediate_opaques}. Function: {self.value}", + explanation=( + "Opaque objects cannot be created inside the torch.compile region. " + "They must be created before entering the compiled function." + ), + hints=[ + "Please create the opaque object before calling torch.compile " + "and pass it in as an argument or as a global variable." + ], + ) + special_handler = self._get_handlers().get(self.value) if special_handler: result = special_handler(self, tx, *args, **kwargs) From 5c3874bf204360256d05df8be51337a451dc1e15 Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Thu, 4 Dec 2025 21:18:18 +0000 Subject: [PATCH 284/338] [ROCm][CI] Enable TD for all ROCm default and distributed config workflows (#168225) Ensures that Target Determination (TD) is enabled for all ROCm workflows that run default or distributed config tests on PRs. NOTE: Excluding inductor-rocm workflows to keep parity with CUDA inductor workflows, which also do not seem to have TD enabled. Example of TD not being enabled on some ROCm workflows (TD enabled will have "[Running 25% of tests based on TD](https://github.com/pytorch/pytorch/blob/6fa7791bab2785bdcae096bd2f80b2528112b859/test/run_test.py#L2104)" in the log; while TD disabled will have "[Running all tests](https://github.com/pytorch/pytorch/blob/6fa7791bab2785bdcae096bd2f80b2528112b859/test/run_test.py#L2106)" in the log): periodic-rocm-mi300: https://hud.pytorch.org/pr/pytorch/pytorch/167548#periodic-rocm-mi300 periodic-rocm-mi200: https://hud.pytorch.org/pr/pytorch/pytorch/167548#periodic-rocm-mi200 rocm-mi200: https://hud.pytorch.org/pr/167183#rocm-mi200 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168225 Approved by: https://github.com/jeffdaily --- .github/actions/filter-test-configs/action.yml | 3 +++ test/run_test.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/actions/filter-test-configs/action.yml b/.github/actions/filter-test-configs/action.yml index 338fc0c2a844c..a9e2be53c6935 100644 --- a/.github/actions/filter-test-configs/action.yml +++ b/.github/actions/filter-test-configs/action.yml @@ -156,5 +156,8 @@ runs: echo echo "Is keep-going label set? ${{ steps.filter.outputs.keep-going }}" + echo + echo "Is ci-no-td label set? ${{ steps.filter.outputs.ci-no-td }}" + echo echo "Reenabled issues? ${{ steps.filter.outputs.reenabled-issues }}" diff --git a/test/run_test.py b/test/run_test.py index c6a8473b5667b..ac36d5db27e35 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1488,6 +1488,7 @@ def parse_args(): help="Set a timeout based on the test times json file. Only works if there are test times available", default=IS_CI and not strtobool(os.environ.get("NO_TEST_TIMEOUT", "False")), ) + GITHUB_WORKFLOW = os.environ.get("GITHUB_WORKFLOW", "slow") parser.add_argument( "--enable-td", action="store_true", @@ -1498,8 +1499,10 @@ def parse_args(): and not IS_MACOS and "xpu" not in BUILD_ENVIRONMENT and "onnx" not in BUILD_ENVIRONMENT - and os.environ.get("GITHUB_WORKFLOW", "slow") - in ("trunk", "pull", "rocm", "rocm-mi300"), + and ( + GITHUB_WORKFLOW in ("trunk", "pull") + or GITHUB_WORKFLOW.startswith(("rocm-", "periodic-rocm-")) + ), ) parser.add_argument( "--shard", From 2912ea3bb32dd95a78c52dd22ba472c9a1d44d24 Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Thu, 4 Dec 2025 21:47:37 +0000 Subject: [PATCH 285/338] Add huggingface storage reader for MXFP4 quantized GPT-OSS checkpoint (#167672) As titled, the updated `QuantizedHuggingFaceStorageReader` can be used to load MXFP4 quantized GPT-OSS HF checkpoint. For example, this feature enables TorchTitan to load GPT-OSS models with de-quantization happening under the hood. 1. Test 1. We use `dcp.load(hf_state_dict, storage_reader=QuantizedHuggingFaceStorageReader(path=input_dir))` to load from GPT-OSS HF checkpoint, and map the `hf_state_dict` back to TorchTitan state dict. We build one test input, and compare two outputs: 1. Using `transformer` library to load GPT-OSS HF checkpoint and run inference on the test input; 2. We use the converted TorchTitan model to run inference on the test input. We compare the outputs by comparing the KL divergence of two output probability distributions. The result shows two models are very similar. Pasted Graphic 2. Test 2. With TorchTitan, we load the model directly from quantized GPT-OSS HF checkpoint, and do a test training. Pasted Graphic 1 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167672 Approved by: https://github.com/LucasLLC, https://github.com/ankitageorge --- .../checkpoint/quantized_hf_storage.py | 230 ++++++++++++++++-- 1 file changed, 207 insertions(+), 23 deletions(-) diff --git a/torch/distributed/checkpoint/quantized_hf_storage.py b/torch/distributed/checkpoint/quantized_hf_storage.py index 36f4ddf937fee..464052d99062a 100644 --- a/torch/distributed/checkpoint/quantized_hf_storage.py +++ b/torch/distributed/checkpoint/quantized_hf_storage.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import json import logging +import math from pathlib import Path from typing import Any @@ -56,14 +57,28 @@ def __init__( def read_metadata(self) -> Any: metadata = super().read_metadata() - # Build a cache of FQN -> full tensor shape for faster lookups. - for fqn, tensor_metadata in metadata.state_dict_metadata.items(): - # Only process TensorStorageMetadata which has size attribute - if isinstance(tensor_metadata, TensorStorageMetadata): - self._tensor_full_shapes[fqn] = tensor_metadata.size + # Load quantization metadata first. self._load_quantization_metadata() + # Build a cache of FQN -> full tensor shape, correcting for quantized tensors. + for fqn, tensor_metadata in metadata.state_dict_metadata.items(): + # Only process TensorStorageMetadata which has size attribute. + if isinstance(tensor_metadata, TensorStorageMetadata): + # Check if this is a MXFP4 quantized tensor that needs shape correction. + if fqn.endswith("_blocks"): + # Save the quantized tensor shapes for lookup when dequantization. + self._tensor_full_shapes[fqn + "_quantized"] = tensor_metadata.size + *prefix_shape, G, B = tensor_metadata.size + dequantized_size = torch.Size([*prefix_shape, G * B * 2]) + + # Update the metadata with the size after dequantization. + # Metadata used by planner to slice state dict. + tensor_metadata.size = dequantized_size + self._tensor_full_shapes[fqn] = dequantized_size + else: + self._tensor_full_shapes[fqn] = tensor_metadata.size + return metadata def _load_quantization_metadata(self): @@ -79,7 +94,7 @@ def _load_quantization_metadata(self): def _build_weight_scale_mapping(self, weight_map: dict[str, str]): """Analyze and build weight-scale tensor pairs from weight mapping.""" - # Store the complete weight map for file location lookups + # Store the complete weight map for file location lookups. self._weight_map = weight_map for tensor_name in weight_map: @@ -87,6 +102,11 @@ def _build_weight_scale_mapping(self, weight_map: dict[str, str]): weight_name = tensor_name.replace(".weight_scale_inv", ".weight") if weight_name in weight_map: self._weight_scale_mapping[weight_name] = tensor_name + # Handle MXFP4 format: _blocks and _scales. + elif tensor_name.endswith("_scales"): + blocks_name = tensor_name.replace("_scales", "_blocks") + if blocks_name in weight_map: + self._weight_scale_mapping[blocks_name] = tensor_name def _process_read_request( self, f: Any, req: ReadItem, planner: LoadPlanner @@ -149,6 +169,112 @@ def _get_slice_to_block_mapping( col_slice, ) + def _dequantize_tensor_mxfp4( + self, + blocks: torch.Tensor, + scales: torch.Tensor, + req: ReadItem, + group_start: int, + offset_in_first_group: int, + ) -> torch.Tensor: + """ + Dequantize a 4D tensor using MXFP4 format. + Adapted from openai's implementation: + https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68 + + Args: + blocks: Sliced quantized weight tensor of shape [a_slice, b_slice, groups_slice, B] in uint8 + scales: FULL scale tensor of shape [a, b, c] in uint8 (will be converted to exponents) + req: Read request containing slice information + group_start: The starting group index in the checkpoint + offset_in_first_group: Offset in values within the first group + + Returns: + Dequantized tensor matching the requested shape + """ + # FP4 lookup table + FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + + # blocks: [a_slice, b_slice, groups_slice, B] uint8. + # Read slightly more groups than needed, and slice at the end. + + # Slice the scales to match the blocks dimensions. + # [a_full, b_full, c_full] -> [a_slice, b_slice, groups_slice] + dim0_start = req.storage_offsets[0] + dim0_end = dim0_start + req.lengths[0] + dim1_start = req.storage_offsets[1] + dim1_end = dim1_start + req.lengths[1] + num_groups = blocks.shape[2] + scales = scales[ + dim0_start:dim0_end, + dim1_start:dim1_end, + group_start : group_start + num_groups, + ] + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, ( + f"{blocks.shape=} does not match {scales.shape=}" + ) + + lut = torch.tensor(FP4_VALUES, dtype=self.target_dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty( + rows_total, B * 2, dtype=self.target_dtype, device=blocks.device + ) + + rows_per_chunk = 16384 * 512 + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + + del idx_lo, idx_hi, blk, exp + + result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + + # Slice the last dimension to match the requested range. + if offset_in_first_group > 0 or result.shape[-1] > req.lengths[2]: + end_offset = offset_in_first_group + req.lengths[2] + result = result[..., offset_in_first_group:end_offset] + + return result + def _dequantize_tensor( self, weight: torch.Tensor, @@ -245,7 +371,7 @@ def _is_tensor_quantized(self, tensor_fqn: str) -> bool: False otherwise """ # Skip scale tensors themselves - if tensor_fqn.endswith(".weight_scale_inv"): + if tensor_fqn.endswith((".weight_scale_inv", "_scales")): return False # Check if this weight tensor has a corresponding scale tensor @@ -271,12 +397,59 @@ def _read_quantized_tensor_with_block_alignment( scale_fqn = self._weight_scale_mapping[tensor_fqn] try: - # Load the sliced quantized weight - weight_slices = tuple( - slice(offset, offset + length) - for offset, length in zip(req.storage_offsets, req.lengths) - ) - quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices] + group_start = 0 + offset_in_first_group = 0 + if tensor_fqn.endswith("_blocks"): + # Full tensor is a 4D MXFP4 quantized tensor: [..., G, B]. + # Each group G produces B * 2 dequantized values. + # Checkpoint [..., G, B] -> dequantized [..., G*B*2]. + + # The planner gives 3D requests based on the dequantized shape. + # Need to figure out which groups (dimension 2 in checkpoint) to read. + + # Use the quantized checkpoint shape to get the correct B. + *prefix_shape, B = self._tensor_full_shapes[tensor_fqn + "_quantized"] + values_per_group = B * 2 # Each byte has 2 nibbles (4-bit values). + + # Calculate which groups we need based on the requested range in dim 2. + # Ensure the reequest is in 3D. + assert len(req.storage_offsets) == 3 + + # Positions in dequantized space. + dim2_start_deq = req.storage_offsets[2] + dim2_length_deq = req.lengths[2] + dim2_end_deq = dim2_start_deq + dim2_length_deq + + # Convert to group indices. + group_start = dim2_start_deq // values_per_group + group_end = (dim2_end_deq + values_per_group - 1) // values_per_group + + # Read only the necessary groups from checkpoint. + weight_slices_4d = ( + slice( + req.storage_offsets[0], req.storage_offsets[0] + req.lengths[0] + ), + slice( + req.storage_offsets[1], req.storage_offsets[1] + req.lengths[1] + ), + slice(group_start, group_end), + slice(None), # Read all B values for each group. + ) + quantized_tensor = safetensor_file.get_slice(tensor_fqn)[ + weight_slices_4d + ] + + # Also track the offset within the first group + offset_in_first_group = dim2_start_deq - ( + group_start * values_per_group + ) + else: + # 2D quantized tensor, use 2d block partition. + weight_slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) + ) + quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices] # Load the corresponding scale inverse tensor (full tensor) scale_file_name = self._weight_map.get(scale_fqn) @@ -304,16 +477,27 @@ def _read_quantized_tensor_with_block_alignment( if full_tensor_shape is None: raise ValueError(f"Could not find full tensor shape for {tensor_fqn}") - # Get slice to block mapping - slice_info = self._get_slice_to_block_mapping(req) - - # Perform dequantization with proper block alignment - dequantized_tensor = self._dequantize_tensor( - weight=quantized_tensor, - scale_inv=scale_inv, - full_tensor_shape=full_tensor_shape, - slice_info=slice_info, - ) + # Determine which dequantization function to use. + if len(full_tensor_shape) == 2: + # 2D block-wise quantization, e.g., used in deepseek v3.1 + slice_info = self._get_slice_to_block_mapping(req) + dequantized_tensor = self._dequantize_tensor( + weight=quantized_tensor, + scale_inv=scale_inv, + full_tensor_shape=full_tensor_shape, + slice_info=slice_info, + ) + elif tensor_fqn.endswith("_blocks"): + # 4D with blocks along dimension 2, used in MXFP4, e.g. gpt-oss + dequantized_tensor = self._dequantize_tensor_mxfp4( + blocks=quantized_tensor, + scales=scale_inv, + req=req, + group_start=group_start, + offset_in_first_group=offset_in_first_group, + ) + else: + raise ValueError("Unsupported quantization types") return dequantized_tensor From 31bb133faf6a88c3adab11fa06d9a5fe4f6c9c85 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 3 Dec 2025 23:13:08 -0800 Subject: [PATCH 286/338] [dynamo] Refactor isinstance(x, ConstantVariable) to x.is_python_constant() (#169006) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169006 Approved by: https://github.com/zou3519, https://github.com/anijain2305 --- test/dynamo/test_list.py | 6 ++++ test/dynamo/test_modules.py | 3 +- torch/_dynamo/comptime.py | 3 +- torch/_dynamo/output_graph.py | 2 +- torch/_dynamo/side_effects.py | 5 +-- torch/_dynamo/symbolic_convert.py | 28 ++++++++------- torch/_dynamo/utils.py | 8 ++--- torch/_dynamo/variables/base.py | 26 +++++++++++--- torch/_dynamo/variables/builder.py | 4 +-- torch/_dynamo/variables/builtin.py | 33 +++++++---------- torch/_dynamo/variables/constant.py | 16 ++++++--- torch/_dynamo/variables/ctx_manager.py | 4 +-- torch/_dynamo/variables/dicts.py | 25 ++++++------- torch/_dynamo/variables/functions.py | 20 ++++------- torch/_dynamo/variables/higher_order_ops.py | 39 +++++++++++---------- torch/_dynamo/variables/iter.py | 4 +-- torch/_dynamo/variables/lists.py | 13 ++++--- torch/_dynamo/variables/misc.py | 6 ++-- torch/_dynamo/variables/nn_module.py | 2 +- torch/_dynamo/variables/optimizer.py | 2 +- torch/_dynamo/variables/tensor.py | 17 +++++---- torch/_dynamo/variables/torch.py | 39 +++++++++++---------- torch/_dynamo/variables/torch_function.py | 4 +-- torch/_dynamo/variables/user_defined.py | 10 +++--- 24 files changed, 166 insertions(+), 153 deletions(-) diff --git a/test/dynamo/test_list.py b/test/dynamo/test_list.py index 41e5da15b5378..85415244db69c 100644 --- a/test/dynamo/test_list.py +++ b/test/dynamo/test_list.py @@ -176,6 +176,12 @@ def test___iter__(self): it = p.__iter__().__iter__() self.assertEqual(next(it), 1) + @make_dynamo_test + def test_list_mul_constant_tuple(self): + tree = (1, 2) + result = [tree] * 2 + self.assertEqual(result, [tree, tree]) + class ListTests(TupleTests): # List methods diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 6fd1e6b477f36..959a32ff17a10 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3383,7 +3383,8 @@ def __init__(self): def __bool__(self): self.bool_invoked += 1 - return len(self.key_cache) + # __bool__ must return a real bool; use truthiness of cache size + return len(self.key_cache) > 0 @torch.compile(fullgraph=True, backend="eager") def f(x): diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 34eec572ce550..f53c753365b63 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -49,7 +49,6 @@ def my_model(x): from .exc import unimplemented from .variables import CellVariable -from .variables.constant import ConstantVariable from .variables.tensor import SymNodeVariable @@ -143,7 +142,7 @@ def force_static(self) -> None: """ if isinstance(self.__variable, SymNodeVariable): self.__variable.evaluate_expr() - elif isinstance(self.__variable, ConstantVariable): + elif self.__variable.is_python_constant(): # TODO: Maybe complain if this isn't a int/bool/float variable pass else: diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 0d409869ccec5..4fc288c9bf546 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1685,7 +1685,7 @@ def compile_subgraph( "input", vt.source, ) - elif isinstance(vt, torch._dynamo.variables.ConstantVariable): + elif vt.is_python_constant(): self.export_metadata.output_return_type[idx] = ( "constant", vt.as_python_constant(), diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 999bd145c3e57..df9716339b661 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -881,10 +881,7 @@ def codegen_update_mutated(self, cg: PyCodegen) -> None: elif isinstance(var, variables.lists.DequeVariable): # For limited maxlen, the order of operations matter for side # effect, but we currently don't track the order, so no support. - if not ( - isinstance(var.maxlen, variables.ConstantVariable) - and var.maxlen.value is None - ): + if not var.maxlen.is_constant_none(): unimplemented( gb_type="Side effect on existing deque with limited maxlen", context="", diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index f401b9d6178b9..487346940dfdf 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -761,10 +761,15 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment] - if isinstance(result, ConstantVariable) and isinstance( - result.value, (bool, int) - ): - if truth_fn(result.value): + method_name = getattr(getattr(x, "fn", None), "__name__", None) + if result.is_python_constant(): + result_value = result.as_python_constant() + if method_name == "__bool__" and not isinstance(result_value, bool): + msg = variables.ConstantVariable.create( + f"__bool__ should return bool, returned {type(result_value).__name__}" + ) + exc.raise_observed_exception(TypeError, self, args=[msg]) + if isinstance(result_value, (bool, int)) and truth_fn(result_value): if push: self.push(value) self.jump(inst) @@ -2633,7 +2638,7 @@ def STORE_ATTR(self, inst: Instruction) -> None: return self.store_attr_graph_break(inst) val, obj = self.popn(2) - if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable): + if isinstance(obj, NNModuleVariable) and not val.is_python_constant(): # We don't allow side effects during export on non-constant values # https://github.com/pytorch/torchdynamo/issues/1475 assert not self.export, ( @@ -3548,7 +3553,7 @@ def BUILD_STRING(self, inst: Instruction) -> None: kwargs: dict[str, VariableTracker] = {} assert inst.arg is not None for part in self.popn(inst.arg): - if isinstance(part, ConstantVariable): + if part.is_python_constant(): format_string_parts.append("{}") args.append(part) elif isinstance(part, variables.StringFormatVariable): @@ -4980,10 +4985,7 @@ def inline_call_(self) -> VariableTracker: assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration args = [] - if not ( - isinstance(self.symbolic_result, ConstantVariable) - and self.symbolic_result.value is None - ): + if not self.symbolic_result.is_constant_none(): args = [self.symbolic_result] exc.raise_observed_exception(StopIteration, self, args=args) else: @@ -4991,7 +4993,7 @@ def inline_call_(self) -> VariableTracker: else: if is_generator(code): assert isinstance(self, InliningGeneratorInstructionTranslator) - assert self.symbolic_result.as_python_constant() is None + assert self.symbolic_result.is_constant_none() return ListIteratorVariable( self.generated_items, mutation_type=ValueMutationNew(), @@ -5223,7 +5225,7 @@ def YIELD_FROM(self, inst: Instruction) -> None: assert len(self.stack) >= 2 val = self.pop() tos = self.stack[-1] - if not (isinstance(val, ConstantVariable) and val.value is None): + if not val.is_constant_none(): # invoke send # Unreachable code - if you hit this, you are implementing generator support and have # lifted the `unimplemented("generator")` in frame conversion. This codepath handles @@ -5265,7 +5267,7 @@ def SEND(self, inst: Instruction) -> None: isinstance(tos, UserDefinedObjectVariable) and isinstance(tos.value, collections.abc.Iterator) ): - if isinstance(val, ConstantVariable) and val.value is None: + if val.is_constant_none(): try: val = tos.next_variable(self) # type: ignore[arg-type] except (StopIteration, exc.ObservedUserStopIteration) as ex: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index d08b92de3441e..b0ad5d2bf5118 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2600,11 +2600,11 @@ def specialize_symnode(arg: Any) -> Any: def guard_if_dyn(arg: Any) -> Any: - from .variables import ConstantVariable + from .variables import VariableTracker arg = specialize_symnode(arg) - if isinstance(arg, ConstantVariable): + if isinstance(arg, VariableTracker) and arg.is_python_constant(): return arg.as_python_constant() return arg @@ -2615,14 +2615,14 @@ def check_constant_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool: def check_unspec_python_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool: - from .variables.constant import ConstantVariable + from .variables import VariableTracker from .variables.tensor import UnspecializedPythonVariable unspec_count = 0 for x in itertools.chain(args, kwargs.values()): if isinstance(x, UnspecializedPythonVariable): unspec_count += 1 - elif not isinstance(x, ConstantVariable): + elif not (isinstance(x, VariableTracker) and x.is_python_constant()): return False return unspec_count > 0 diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 617f787e43d8a..982e0fccc5ca6 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -366,6 +366,21 @@ def is_python_constant(self) -> bool: except NotImplementedError: return False + def is_constant_match(self, *values: Any) -> bool: + """ + Check if this variable is a python constant matching one of the given values. + + Examples: + var.is_constant_match(None) # True if var is constant None + var.is_constant_match(True, False) # True if var is constant True or False + var.is_constant_match(NotImplemented) # True if var is constant NotImplemented + """ + return False + + def is_constant_none(self) -> bool: + """Check if this variable is a constant None value.""" + return False + def make_guard(self, fn: Callable[..., Any]) -> Guard: if self.source: return self.source.make_guard(fn) @@ -377,13 +392,17 @@ def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: """getattr(self, name) returning a python constant""" raise NotImplementedError + def is_symnode_like(self) -> bool: + """Return True for values that can participate in SymNode operations""" + return False + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": """getattr(self, name) returning a new variable""" value = self.const_getattr(tx, name) if not variables.ConstantVariable.is_literal(value): raise NotImplementedError source = self.source and AttrSource(self.source, name) - if source and not isinstance(self, variables.ConstantVariable): + if source and not self.is_python_constant(): # The second condition is to avoid guards on const getattr objects # like __code__.co_argcount install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) @@ -572,10 +591,7 @@ def call_tree_map( ) -> "VariableTracker": """Performance optimization to implement optree.tree_map faster than tracing it""" is_leaf_var = tree_map_kwargs.get("is_leaf") - if is_leaf_var is not None and not ( - is_leaf_var.is_python_constant() - and is_leaf_var.as_python_constant() is None - ): + if is_leaf_var is not None and not is_leaf_var.is_constant_none(): pred_result = is_leaf_var.call_function(tx, [self], {}) try: leaf_decision = pred_result.as_python_constant() diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 248ab9d5f4bab..968321be56a51 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2703,9 +2703,9 @@ def wrap_unspecialized_primitive(self, value): f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" ) fake_tensor_value = None - if isinstance(unspec_var, ConstantVariable): + if unspec_var.is_python_constant(): # TODO: when can this happen? - example_value = unspec_var.value + example_value = unspec_var.as_python_constant() else: example_value = unspec_var.proxy.node.meta["example_value"] assert is_fake(example_value) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 40b2be0437373..7143d4d8f3b3f 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -690,7 +690,7 @@ def list_iadd_handler( def expand_list_like( tx: "InstructionTranslator", lst: VariableTracker, const: VariableTracker ) -> VariableTracker: - if isinstance(lst, ConstantVariable): + if not isinstance(lst, BaseListVariable) and lst.is_python_constant(): lst, const = const, lst try: assert isinstance(lst, BaseListVariable) @@ -1031,8 +1031,7 @@ def create_exception_class_object( kwargs: dict[str, VariableTracker], ) -> VariableTracker: if fn is AssertionError and not all( - isinstance(x, variables.ConstantVariable) - and isinstance(x.value, str) + x.is_python_constant() and isinstance(x.as_python_constant(), str) for x in args ): unimplemented( @@ -1504,7 +1503,7 @@ def call_method( ) if self.fn is float and len(args) == 1 and name in ("fromhex", "hex"): - if isinstance(args[0], ConstantVariable): + if args[0].is_python_constant(): try: fn = getattr(float, name) res = fn(args[0].as_python_constant()) @@ -1550,10 +1549,12 @@ def call_method( if self.fn is str and len(args) >= 1: resolved_fn = getattr(self.fn, name) if resolved_fn in str_methods: + # Only delegate to ConstantVariable, not other types that happen to be constants if isinstance(args[0], ConstantVariable): return args[0].call_method(tx, name, args[1:], kwargs) if self.fn is float and len(args) >= 1: + # Only delegate to ConstantVariable, not other types that happen to be constants if isinstance(args[0], ConstantVariable): return ConstantVariable.create( getattr(float, name)(args[0].as_python_constant()) @@ -1802,7 +1803,7 @@ def _call_min_max_binary( "call_function", py_fn, *proxy_args_kwargs([a, b], {}) ) return SymNodeVariable.create(tx, proxy, None) - elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + elif a.is_python_constant() and b.is_python_constant(): value = self.fn( a.as_python_constant(), b.as_python_constant(), @@ -2587,7 +2588,7 @@ def call_getattr( if default is not None: hasattr_var = self.call_hasattr(tx, obj, name_var) if hasattr_var is not None: - assert hasattr_var.as_python_constant() in (True, False) + assert hasattr_var.is_constant_match(True, False) if not hasattr_var.as_python_constant(): return default else: @@ -3094,9 +3095,7 @@ def call_xor( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( - b, (SymNodeVariable, ConstantVariable) - ): + if a.is_symnode_like() and b.is_symnode_like(): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3139,9 +3138,7 @@ def call_and_( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( - b, (SymNodeVariable, ConstantVariable) - ): + if a.is_symnode_like() and b.is_symnode_like(): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3160,9 +3157,7 @@ def call_iand( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( - b, (SymNodeVariable, ConstantVariable) - ): + if a.is_symnode_like() and b.is_symnode_like(): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3180,9 +3175,7 @@ def call_or_( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( - b, (SymNodeVariable, ConstantVariable) - ): + if a.is_symnode_like() and b.is_symnode_like(): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3216,9 +3209,7 @@ def call_ior( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( - b, (SymNodeVariable, ConstantVariable) - ): + if a.is_symnode_like() and b.is_symnode_like(): return SymNodeVariable.create( tx, tx.output.create_proxy( diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 672fa1d804383..2b7a7661a1182 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -115,6 +115,15 @@ def as_python_constant(self) -> Any: def is_python_constant(self) -> Literal[True]: return True + def is_symnode_like(self) -> bool: + return isinstance(self.value, (int, bool)) + + def is_constant_match(self, *values: Any) -> bool: + return self.value in values + + def is_constant_none(self) -> bool: + return self.value is None + @property def items(self) -> list[VariableTracker]: """ @@ -311,10 +320,7 @@ def call_tree_map( return map_fn.call_function(tx, [self, *rest], {}) else: for other in rest: - if not ( - other.is_python_constant() - and other.as_python_constant() is None - ): + if not other.is_constant_none(): return self._tree_map_fallback( tx, tree_map_fn, @@ -356,7 +362,7 @@ def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None def create( cls, cls_type: Any, value_vt: VariableTracker, options: Any ) -> "EnumVariable": - if isinstance(value_vt, variables.ConstantVariable): + if value_vt.is_python_constant(): for member in list(cls_type): if member.value == value_vt.as_python_constant(): return cls(member, **options) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index c79f19216f68b..64ec27cf9e430 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -1143,9 +1143,7 @@ def __init__( # The context manager accepts Union[Tensor, Tuple[Tensor]] if isinstance(self.tensors, variables.TensorVariable): self.tensors = variables.TupleVariable([self.tensors]) - if isinstance( - self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable) - ): + if self.prev_versions.is_symnode_like(): self.prev_versions = variables.TupleVariable([self.prev_versions]) def enter(self, tx: "InstructionTranslator") -> VariableTracker: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 422cae7c4d3f1..b794cd2735a38 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -493,7 +493,6 @@ def install_dict_contains_guard( # 3b) contains=False. There is no easy way to selectively apply this # DICT_NOT_CONTAINS guard because our guard are represented via trees. # Be conservative and add DICT_KEYS_MATCH guard. - from . import ConstantVariable if not self.source: return @@ -502,12 +501,12 @@ def install_dict_contains_guard( return contains = args[0] in self - if args[0].source is None and isinstance(args[0], ConstantVariable): + if args[0].source is None and args[0].is_python_constant(): install_guard( self.make_guard( functools.partial( type(self).CONTAINS_GUARD, - key=args[0].value, + key=args[0].as_python_constant(), invert=not contains, ) ) @@ -674,10 +673,10 @@ def call_method( if self.user_cls is collections.OrderedDict and ( len(args) == 1 or "last" in kwargs ): - if len(args) == 1 and isinstance(args[0], ConstantVariable): - last = args[0].value - elif (v := kwargs.get("last")) and isinstance(v, ConstantVariable): - last = v.value + if len(args) == 1 and args[0].is_python_constant(): + last = args[0].as_python_constant() + elif (v := kwargs.get("last")) and v.is_python_constant(): + last = v.as_python_constant() else: raise_args_mismatch(tx, name) k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] @@ -780,15 +779,11 @@ def call_method( raise_observed_exception(KeyError, tx) last = True - if len(args) == 2 and isinstance(args[1], ConstantVariable): - last = args[1].value + if len(args) == 2 and args[1].is_python_constant(): + last = args[1].as_python_constant() - if ( - kwargs - and "last" in kwargs - and isinstance(kwargs["last"], ConstantVariable) - ): - last = kwargs.get("last").value # type: ignore[union-attr] + if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): + last = kwargs.get("last").as_python_constant() # type: ignore[union-attr] key = Hashable(args[0]) self.items.move_to_end(key, last=last) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index fdc2f53f82383..c43866c62809c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1047,10 +1047,7 @@ def call_method( if self._is_generator_just_started() and len(args): # can't send non-None value to a just-started generator # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen - if not all( - isinstance(arg, ConstantVariable) and arg.value is None - for arg in args - ): + if not all(arg.is_constant_none() for arg in args): raise_observed_exception(TypeError, tx) tracer = self.inline_tracer tracer.push_many(args) @@ -2427,7 +2424,7 @@ def call_function( and not kwargs and isinstance(args[0], (variables.ListVariable, variables.TupleVariable)) and all( - (isinstance(x, variables.ConstantVariable) and isinstance(x.value, int)) + (x.is_python_constant() and isinstance(x.as_python_constant(), int)) or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int) for x in args[0].items ) @@ -2443,8 +2440,8 @@ def call_function( sym_num=torch.sym_sum( [ ( - x.value - if isinstance(x, variables.ConstantVariable) + x.as_python_constant() + if x.is_python_constant() else x.sym_num # type: ignore[attr-defined] ) for x in args[0].items @@ -2649,7 +2646,6 @@ def call_HOP( combined_args_raw: dict[str, Any], tx: "InstructionTranslator", ) -> "variables.ConstantVariable": - from .constant import ConstantVariable from .dicts import ConstDictVariable # as we can only pass tensors as non-const args in fx graph, @@ -2683,12 +2679,12 @@ def call_HOP( constant_args = { k: v.as_python_constant() for k, v in combined_args_raw.items() - if isinstance(v, ConstantVariable) + if isinstance(v, VariableTracker) and v.is_python_constant() } non_constant_args = { k: v for k, v in combined_args.items() - if not isinstance(v, ConstantVariable) + if not (isinstance(v, VariableTracker) and v.is_python_constant()) } for v in non_constant_args.values(): @@ -2989,9 +2985,7 @@ def call_function( if len(args) == 2: is_leaf = args[1] - if not ( - isinstance(is_leaf, variables.ConstantVariable) and is_leaf.value is None - ): + if not is_leaf.is_constant_none(): return super().call_function(tx, args, kwargs) # Optimize the case where is_leaf is None diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 0f7491911d35b..14dec2f9c45ea 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -160,7 +160,7 @@ def _unwrap_var(var): return var.proxy.node.meta["example_value"] elif isinstance(var, SymNodeVariable): return var.sym_num - elif isinstance(var, ConstantVariable): + elif var.is_python_constant(): return var.as_python_constant() else: unimplemented( @@ -225,11 +225,7 @@ def find_mismatched_vars(var, types, allow_none=False): for value in var.items.values(): mismatched_vars.update(find_mismatched_vars(value, types, allow_none)) else: - - def _is_none(var): - return var.is_python_constant() and var.as_python_constant() is None - - if not isinstance(var, types) and not (allow_none and _is_none(var)): + if not isinstance(var, types) and not (allow_none and var.is_constant_none()): mismatched_vars.add(var) return mismatched_vars @@ -503,7 +499,8 @@ def _call_while_loop( def unspecialize_carried_inputs(tx, carry) -> VariableTracker: # See NOTE [unspecialize int carry with unbacked symints] if ( - isinstance(carry, ConstantVariable) and carry.python_type() is int + carry.is_python_constant() + and isinstance(carry.as_python_constant(), int) ) or isinstance(carry, SymNodeVariable): example_value = _create_unbacked_symint( tx.output.fake_mode, ignore_fresh_unbacked_symbols=True @@ -601,7 +598,7 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: *graph_break_hints.USER_ERROR, ], ) - elif isinstance(cond_r, ConstantVariable): + elif cond_r.is_python_constant(): # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False pred = cond_r.as_python_constant() if pred: @@ -1976,7 +1973,9 @@ def speculate_branch(branch): ], ) for ret in ret_val.unpack_var_sequence(tx): - if isinstance(ret, ConstantVariable) and ret.python_type() is not int: + if ret.is_python_constant() and not isinstance( + ret.as_python_constant(), int + ): unimplemented( gb_type="torch.cond: unsupported branch return type (constant non-int)", context=str(ret_val), @@ -2103,7 +2102,8 @@ def validate_subgraph_output_types(output: VariableTracker): if ( isinstance(out, SymNodeVariable) and out.python_type() in (int, bool) ) or ( - isinstance(out, ConstantVariable) and out.python_type() in (int, bool) + out.is_python_constant() + and isinstance(out.as_python_constant(), (int, bool)) ): continue unimplemented( @@ -2719,10 +2719,7 @@ def _call_function( # Check all outputs of map are tensors. # For map, outputting None is OK, thus ignore None values in the check body_r_vars = body_r.unpack_var_sequence(tx) - none_mask = [ - type(x.realize()) is ConstantVariable and x.as_python_constant() is None - for x in body_r_vars - ] + none_mask = [x.is_constant_none() for x in body_r_vars] _check_all_tensorvariable( [br for bm, br in zip(none_mask, body_r_vars) if not bm] ) @@ -3011,7 +3008,7 @@ def call_function( grad_enabled, fn_var, *rest_args = args - if not isinstance(grad_enabled, ConstantVariable): + if not grad_enabled.is_python_constant(): unimplemented( gb_type="wrap_with_set_grad_enabled: non-constant grad_enabled", context=str(grad_enabled), @@ -3099,7 +3096,7 @@ def call_function( device_type, dtype, enabled, cache_enabled, fn_var, *rest_args = args for arg in [device_type, dtype, enabled, cache_enabled]: - if not isinstance(arg, ConstantVariable): + if not arg.is_python_constant(): unimplemented( gb_type="wrap_with_autocast: expected constant arg", context=str(args), @@ -3717,8 +3714,14 @@ def _call_function( tx, query, score_mod, "score_mod" ) mask_fn = block_mask.items[-1] - if isinstance(mask_fn, ConstantVariable): - mask_fn = UserFunctionVariable(torch.nn.attention._flex_attention._no_mask) + if mask_fn.is_python_constant(): + mask_callable = mask_fn.as_python_constant() + if mask_callable is None: + mask_callable = torch.nn.attention.flex_attention.noop_mask + mask_fn = UserFunctionVariable( + mask_callable, + source=mask_fn.source, + ) mask_fn_node, mask_fn_lifted_args = self.create_wrapped_node( tx, query, mask_fn, "mask_fn" ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 2689d5e094977..4a3c0247add1b 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -116,7 +116,7 @@ def call_function( def retrieve_const_key(key: VariableTracker) -> Any: if isinstance(key, variables.SymNodeVariable): return key.evaluate_expr() - elif isinstance(key, variables.ConstantVariable): + elif key.is_python_constant(): return key.as_python_constant() else: unimplemented( @@ -595,7 +595,7 @@ def _next() -> VariableTracker: while True: item = _next() self.index += 1 - if isinstance(self.fn, ConstantVariable) and self.fn.value is None: + if self.fn.is_constant_none(): res = item else: res = self.fn.call_function(tx, [item], {}) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 4f21e35479fb8..c6f9448c9b6cf 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -365,7 +365,9 @@ def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None: def maybe_as_int(x: VariableTracker) -> VariableTracker: return ( - ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x + ConstantVariable.create(int(x.as_python_constant())) + if x.is_python_constant() + else x ) # cast each argument to an integer @@ -903,10 +905,7 @@ def call_method( if len(kwargs) != 0: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") - if ( - key_fn_var.is_python_constant() - and key_fn_var.as_python_constant() is None - ): + if key_fn_var.is_constant_none(): keys = self.items.copy() else: keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items] @@ -1260,8 +1259,8 @@ def numel(self, tx: "InstructionTranslator") -> VariableTracker: sym_sizes = [] for v in self.items: - if isinstance(v, ConstantVariable): - const_result *= v.value + if v.is_python_constant(): + const_result *= v.as_python_constant() else: assert isinstance(v, SymNodeVariable), type(v) # Delay proxy calls until we know it will be necessary diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index c7d6e58ba4531..466b7a757d829 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -474,7 +474,7 @@ def raise_error(msg): if name == "__context__": self.set_context(val) elif name == "__cause__": - if (isinstance(val, ConstantVariable) and val.value is None) or isinstance( + if val.is_constant_none() or isinstance( val, ( variables.BuiltinVariable, @@ -488,12 +488,12 @@ def raise_error(msg): else: raise_error("exception cause must be None or derive from BaseException") elif name == "__suppress_context__": - if isinstance(val, ConstantVariable) and val.value in (True, False): + if val.is_constant_match(True, False): self.__suppress_context__ = val else: raise_error("exception cause must be None or derive from BaseException") elif name == "__traceback__": - if isinstance(val, ConstantVariable) and val.value is None: + if val.is_constant_none(): self.__traceback__ = val else: unimplemented( diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index bb6952abf0b56..c9fe1e2802264 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -859,7 +859,7 @@ def gen_source(source: Source, name: str) -> Source: # pyrefly: ignore[missing-attribute] if type(module).__getitem__ not in builtin_supported: if not ( - isinstance(args[0], variables.ConstantVariable) + args[0].is_python_constant() and isinstance(args[0].as_python_constant(), (str, int)) ): unimplemented( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 69ca37db4ef37..3e29b9a08347e 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -218,7 +218,7 @@ def get_python_args( """Get python values equivalent to the variable tracker args""" def map_arg(arg: Any) -> Any: - if isinstance(arg, ConstantVariable): + if isinstance(arg, VariableTracker) and arg.is_python_constant(): return arg.as_python_constant() elif isinstance(arg, ListVariable) and not arg.items: return [] diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 47439387e0fca..8002e41a42631 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -597,11 +597,14 @@ def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {}) # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through # symbolic_shapes, but that end up as int/sympy.Integer - assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable)) + assert ( + isinstance(dyn_length, SymNodeVariable) + or dyn_length.is_python_constant() + ) if isinstance(dyn_length, SymNodeVariable): length = dyn_length.evaluate_expr(tx.output) else: - length = dyn_length.value + length = dyn_length.as_python_constant() if idxes is None: idxes = range(length) @@ -1409,7 +1412,8 @@ def method_new(self, *args, **kwargs): if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( len(args) >= 1 and all( - isinstance(a, ConstantVariable) and a.python_type() is int for a in args + a.is_python_constant() and isinstance(a.as_python_constant(), int) + for a in args ) ): from ..symbolic_convert import InstructionTranslator @@ -1475,6 +1479,9 @@ def python_type(self): else: return type(self.sym_num) + def is_symnode_like(self) -> bool: + return True + def as_proxy(self): return self.proxy @@ -1632,9 +1639,7 @@ def call_method( dtype_arg = kwargs["dtype"] elif len(args) > 0: dtype_arg = args[0] - is_object_str = ( - isinstance(dtype_arg, ConstantVariable) and dtype_arg.value == "O" - ) + is_object_str = dtype_arg is not None and dtype_arg.is_constant_match("O") is_object_type = ( isinstance(dtype_arg, BuiltinVariable) and dtype_arg.fn is object ) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index a4f940cb2adaf..8a621aedd5a2c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -934,11 +934,15 @@ def handle_constant_processgroup_functions( # bake the result into the trace if len(args) == 1: # group or group name - assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable)) + assert ( + isinstance(args[0], ProcessGroupVariable) + or args[0].is_python_constant() + ) elif len(args) == 2: # ranks + tag - assert isinstance(args[0], ListVariable) and isinstance( - args[1], ConstantVariable + assert ( + isinstance(args[0], ListVariable) + and args[1].is_python_constant() ) else: raise AssertionError( @@ -1017,7 +1021,7 @@ def handle_nested_tensor( ): from .lists import BaseListVariable - if layout and layout.as_python_constant() == torch.strided: + if layout and layout.is_constant_match(torch.strided): unimplemented( gb_type="Attempted to use strided NestedTensor", context=f"layout={layout}", @@ -1041,9 +1045,7 @@ def handle_nested_tensor( @register(torch.nn.functional.one_hot) def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): if len(args) + len(kwargs) == 1 or ( - len(args) == 2 - and args[1].is_python_constant() - and args[1].as_python_constant() == -1 + len(args) == 2 and args[1].is_constant_match(-1) ): unimplemented( gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", @@ -1065,7 +1067,7 @@ def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr): expr.sym_num ) ) - elif isinstance(expr, ConstantVariable): + elif expr.is_python_constant(): return expr @register(torch.fx.experimental.symbolic_shapes.guard_or_true) @@ -1076,7 +1078,7 @@ def handle_guard_or_true(self, tx: "InstructionTranslator", expr): return variables.ConstantVariable.create( torch.fx.experimental.symbolic_shapes.guard_or_true(expr.sym_num) ) - elif isinstance(expr, ConstantVariable): + elif expr.is_python_constant(): return expr @register(torch.fx.experimental.symbolic_shapes.guard_or_false) @@ -1087,7 +1089,7 @@ def handle_guard_or_false(self, tx: "InstructionTranslator", expr): return variables.ConstantVariable.create( torch.fx.experimental.symbolic_shapes.guard_or_false(expr.sym_num) ) - elif isinstance(expr, ConstantVariable): + elif expr.is_python_constant(): return expr @register(torch.fx.experimental.symbolic_shapes.statically_known_false) @@ -1098,15 +1100,15 @@ def handle_statically_known_false(self, tx: "InstructionTranslator", expr): expr.sym_num ) ) - elif isinstance(expr, ConstantVariable): + elif expr.is_python_constant(): return expr @register(torch.fx.experimental.symbolic_shapes.guard_scalar) def guard_scalar(self, tx: "InstructionTranslator", expr): if isinstance(expr, SymNodeVariable): val = expr.sym_num - elif isinstance(expr, ConstantVariable): - val = expr.value + elif expr.is_python_constant(): + val = expr.as_python_constant() else: unimplemented( gb_type="torch.fx.experimental.symbolic_shapes.guard_scalar branch not supported", @@ -1127,7 +1129,7 @@ def handle_statically_known_true(self, tx: "InstructionTranslator", expr): expr.sym_num ) ) - elif isinstance(expr, ConstantVariable): + elif expr.is_python_constant(): return expr @register(torch.fx.experimental.symbolic_shapes.sym_and) @@ -1156,8 +1158,8 @@ def handle_sym_or(self, tx: "InstructionTranslator", *terms): def handle_has_static_value(self, tx: "InstructionTranslator", expr): if isinstance(expr, SymNodeVariable): val = expr.sym_num - elif isinstance(expr, ConstantVariable): - val = expr.value + elif expr.is_python_constant(): + val = expr.as_python_constant() else: return @@ -1357,7 +1359,7 @@ def handle_set_default_device( # Running the graph will ensure that the DeviceContext mode is # at the correct position in the stack TorchFunctionModeStackVariable.register_mutation(tx) - if args[0].is_python_constant() and args[0].as_python_constant() is None: + if args[0].is_constant_none(): TorchFunctionModeStackVariable.clear_default_device(tx) else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) @@ -1539,8 +1541,7 @@ def call_function( any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args + isinstance(x, SymNodeVariable) or x.is_python_constant() for x in args ) if ( getattr(self.value, "__module__", "") == "torch" diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index c7254afdfebfc..b2a86eb4f017f 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -543,7 +543,7 @@ def dispatch_torch_function( res = tx.symbolic_torch_function_state.call_torch_function_mode( tx, fn, types, args, kwargs ) - if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + if not res.is_constant_match(NotImplemented): return res for arg in overloaded_args: @@ -555,7 +555,7 @@ def dispatch_torch_function( kwargs, ) - if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + if not res.is_constant_match(NotImplemented): return res unimplemented( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index cc377a09ab746..d227e9fa453ad 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -881,8 +881,8 @@ def deque_signature(iterable=None, maxlen=None): return tensor_variable elif self.value is random.Random: - if len(args) == 1 and isinstance(args[0], variables.ConstantVariable): - seed = args[0].value + if len(args) == 1 and args[0].is_python_constant(): + seed = args[0].as_python_constant() else: seed = None random_object = random.Random(seed) @@ -1911,9 +1911,9 @@ def call_method(self, tx, name, args, kwargs): elif ( name == "__setattr__" and len(args) == 2 - and isinstance(args[0], variables.ConstantVariable) - and args[0].value - in ("__cause__", "__context__", "__suppress_context__", "__traceback__") + and args[0].is_constant_match( + "__cause__", "__context__", "__suppress_context__", "__traceback__" + ) ): self.exc_vt.call_setattr(tx, args[0], args[1]) elif name == "with_traceback": From 45cd3b9f155c09b921de010b1867d606fac28a62 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 4 Dec 2025 00:27:10 +0000 Subject: [PATCH 287/338] Re-enable torch.compile tests for Python 3.12 and Windows (#169387) These skips are no longer needed now that torch.compile is supported on Python 3.12 and Windows. Test plan: - Test on a windows machine if these tests are working as expected Pull Request resolved: https://github.com/pytorch/pytorch/pull/169387 Approved by: https://github.com/williamwen42 --- test/export/test_sparse.py | 4 --- test/quantization/pt2e/test_graph_utils.py | 10 +------ test/test_nestedtensor.py | 32 ++-------------------- test/test_transformers.py | 2 -- torch/testing/_internal/common_utils.py | 4 +++ 5 files changed, 8 insertions(+), 44 deletions(-) diff --git a/test/export/test_sparse.py b/test/export/test_sparse.py index c8d799a0254b0..975e9979982f5 100644 --- a/test/export/test_sparse.py +++ b/test/export/test_sparse.py @@ -3,7 +3,6 @@ # Test to ensure sparsity information propagates properly into traced graph. # -import sys import unittest import torch @@ -91,9 +90,6 @@ def forward(self, x): @unittest.skipIf(is_fbcode(), "See torch._dynamo.config") -@unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" -) class TestSparseProp(TestCase): def setUp(self): super().setUp() diff --git a/test/quantization/pt2e/test_graph_utils.py b/test/quantization/pt2e/test_graph_utils.py index 2a26ff682b93f..ee2603dae84cd 100644 --- a/test/quantization/pt2e/test_graph_utils.py +++ b/test/quantization/pt2e/test_graph_utils.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: quantization"] import copy -import unittest import torch import torch._dynamo as torchdynamo @@ -9,15 +8,10 @@ get_equivalent_types, update_equivalent_types_dict, ) -from torch.testing._internal.common_utils import ( - IS_WINDOWS, - raise_on_run_directly, - TestCase, -) +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class TestGraphUtils(TestCase): - @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_conv_bn_conv_relu(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -63,7 +57,6 @@ def x(): self.assertRaises(ValueError, x) - @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_conv_bn_relu(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -98,7 +91,6 @@ def forward(self, x): ) self.assertEqual(len(fused_partitions), 0) - @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_customized_equivalet_types_dict(self): class M(torch.nn.Module): def __init__(self) -> None: diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 8e9d1ed0217ae..909cbad423588 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -52,7 +52,6 @@ gradcheck, instantiate_parametrized_tests, IS_FBCODE, - IS_WINDOWS, markDynamoStrictTest, NestedTensorTestCase, parametrize, @@ -63,6 +62,7 @@ subtest, TEST_WITH_ROCM, xfailIfTorchDynamo, + xfailIfWindows, ) from torch.testing._internal.opinfo.core import ( BinaryUfuncInfo, @@ -6672,7 +6672,6 @@ def check_size(nt1, nt2, nt3, nt4): check_size(nt1_t, nt2_t, nt3_t, nt4_t) @skipIfTorchDynamo("compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") def test_specialize_dynamic_shape(self, device): values = torch.randn((18, 16), device=device) @@ -6694,7 +6693,6 @@ def fn(values, same_size): ) @skipIfTorchDynamo("compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") def test_specialize_dynamic_shape_recompile(self, device): def generate_inp(total_len): @@ -7005,9 +7003,9 @@ def check_forward_backward(skip_backward=False): check_forward_backward() @skipIfTorchDynamo("SDPA test compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") # Guarding with sqrt() doesn't work on ROCm? + @xfailIfWindows @skipCUDAIfRocm @onlyCUDA @dtypes( @@ -7192,8 +7190,8 @@ def in_proj(input_packed, qkv_linear=qkv_linear): out, out_component, atol=output_ref_atol, rtol=output_ref_rtol ) + @decorateIf(xfailIfWindows, lambda params: params["dtype"] == torch.float32) @skipIfTorchDynamo("SDPA test compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") # mha_varlen_fwd not supported on ROCm @skipCUDAIfRocm @@ -7230,7 +7228,6 @@ def f(values, offsets): @skipCUDAIfRocm @onlyCUDA @skipIfTorchDynamo() - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") def test_sdpa_autocast(self, device): def fn_nt(values32, values16, offsets): nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16) @@ -7377,7 +7374,6 @@ def fn(values, lengths): # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated. @skipCUDAIfRocm # not needed @skipIfTorchDynamo("compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @parametrize("use_legacy_api", [True, False]) @skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644") @@ -7734,10 +7730,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_preserves_metadata_cache(self, device, dtype): @@ -7765,10 +7757,6 @@ def f(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_with_dynamic_max_seq_len(self, device, dtype): @@ -7802,10 +7790,6 @@ def f(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_with_dynamic_min_seq_len(self, device, dtype): @@ -7839,10 +7823,6 @@ def f(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): @@ -7970,7 +7950,6 @@ def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad): # blows up due to test parametrization otherwise @torch._dynamo.utils.disable_cache_limit() @skipIfTorchDynamo("SDPA test compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm @dtypes(torch.float32, torch.double, torch.half) @@ -8073,10 +8052,6 @@ def _g(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_padded_dense_conversion_preserves_metadata_cache( @@ -8166,7 +8141,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None): self.assertEqual(res.shape, (4, nt.shape[1], 6)) @skipIfTorchDynamo("compiles internally") - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @dtypes(torch.float32) @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) diff --git a/test/test_transformers.py b/test/test_transformers.py index 1897548f560cf..d9818f3330184 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -37,7 +37,6 @@ gradcheck, make_tensor, NOTEST_CPU, - IS_WINDOWS, TEST_WITH_TORCHDYNAMO, TEST_XPU, ) @@ -4811,7 +4810,6 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: lis "shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)], ) - @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows") @skipIfTorchDynamo("This function already calls torch.compile.") def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]): cnts = CompileCounterWithBackend("aot_eager") diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index df3ca03b76242..ad4918ef9103a 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1713,6 +1713,10 @@ def xfailIfLinux(func): return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM and not IS_FBCODE else func +def xfailIfWindows(func): + return unittest.expectedFailure(func) if IS_WINDOWS else func + + def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"): """ Usage: From 8da5d29de7feb165047246464d09c4c2b2318987 Mon Sep 17 00:00:00 2001 From: Kathryn-cat Date: Thu, 4 Dec 2025 22:52:41 +0000 Subject: [PATCH 288/338] [DLPack] C Functions for DLPack Speed Exchange and Stream Handling (#165483) ## Addressed Issue Issue #162845 ## Summary of Changes This PR introduces a unified `DLPackExchangeAPI` struct as described in proposal [175](https://github.com/dmlc/dlpack/issues/175). This new convention replaces the previous mechanism of separate function pointers, and aligns with the latest DLPack standard as shown in PR [174](https://github.com/dmlc/dlpack/pull/174). Specifically, the new `DLPackExchangeAPI` struct is exposed as `torch.Tensor.__c_dlpack_exchange_api__`, which stores and exposes the following function pointers: * `managed_tensor_allocator` * `managed_tensor_from_py_object_no_sync` * `managed_tensor_to_py_object_no_sync` * `dltensor_from_py_object_no_sync` * `current_work_stream` Within the new `DLPackExchangeAPI` struct, the new `current_work_stream` function pointer allows more robust and integrated querying of the current device stream (e.g., CUDA stream) during DLPack tensor exchanges. All the conversion from/to DLPack has been updated to `_no_sync`, meaning you should use `current_work_stream` to explicitly handle stream synchronization. It also includes a non-owning DLTensor conversion `dltensor_from_py_object_no_sync` to avoid unnecessary reference counting. Following this change, the `dlpack.h` has been updated to the latest DLPack. Unit tests are added using `torch.utils.cpp_extension.load_inline` to avoid GIL release issues when calling `THPVariable_Wrap`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165483 Approved by: https://github.com/tqchen, https://github.com/albanD --- aten/src/ATen/DLConvertor.cpp | 25 ++- aten/src/ATen/DLConvertor.h | 7 + aten/src/ATen/dlpack.h | 292 ++++++++++++++++++++++++++++++++-- test/test_dlpack.py | 233 +++++++++++++++++++++++++++ torch/_C/__init__.pyi.in | 1 + torch/_tensor.py | 1 + torch/csrc/Module.cpp | 119 ++++++++++++++ 7 files changed, 666 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index b39f3eafa32df..9d7ebb3a86cfb 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -152,7 +152,10 @@ DLDevice torchDeviceToDLDevice(at::Device device) { return ctx; } -static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) { +Device dlDeviceToTorchDevice( + DLDeviceType type, + c10::DeviceIndex index, + void* data) { switch (type) { case DLDeviceType::kDLCPU: return at::Device(DeviceType::CPU); @@ -437,7 +440,8 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { } DLTensor& dl_tensor = src->dl_tensor; - Device device = getATenDevice(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data); + Device device = dlDeviceToTorchDevice( + dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data); ScalarType stype = toScalarType(dl_tensor.dtype); if (!dl_tensor.strides) { @@ -465,6 +469,21 @@ template at::Tensor fromDLPackImpl(DLManagedTensorVers } // namespace +void toDLPackNonOwning(const Tensor& src, DLTensor* out) { + // Fill in the pre-allocated DLTensor struct with direct pointers + // This is a non-owning conversion - the caller owns the tensor + // and must keep it alive for the duration of DLTensor usage + out->data = src.data_ptr(); + out->device = torchDeviceToDLDevice(src.device()); + out->ndim = static_cast(src.dim()); + out->dtype = getDLDataType(src); + // sizes() and strides() return pointers to TensorImpl's stable storage + // which remains valid as long as the tensor is alive + out->shape = const_cast(src.sizes().data()); + out->strides = const_cast(src.strides().data()); + out->byte_offset = 0; +} + DLManagedTensor* toDLPack(const Tensor& src) { return toDLPackImpl(src); } @@ -489,7 +508,7 @@ Tensor maybeCopyTensor( bool force_move = copy.has_value() && !*copy; if (optional_dl_device.has_value()) { - auto device = at::getATenDevice( + auto device = at::dlDeviceToTorchDevice( optional_dl_device->device_type, static_cast(optional_dl_device->device_id)); diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index 928731fafb2f6..46a7cb202e5b4 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -13,6 +13,7 @@ namespace at { TORCH_API ScalarType toScalarType(const DLDataType& dtype); TORCH_API DLManagedTensor* toDLPack(const Tensor& src); TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src); +TORCH_API void toDLPackNonOwning(const Tensor& src, DLTensor* out); TORCH_API Tensor fromDLPack(DLManagedTensor* src, std::function deleter = {}); TORCH_API Tensor fromDLPackVersioned( @@ -31,6 +32,12 @@ TORCH_API Tensor maybeCopyTensor( // Converts the given at::Device into a DLDevice. TORCH_API DLDevice torchDeviceToDLDevice(at::Device device); +// Converts the DLDevice to an ATen device. +TORCH_API Device dlDeviceToTorchDevice( + DLDeviceType type, + c10::DeviceIndex index, + void* data = nullptr); + // This trait class is used for retrieving different attributes, such as the // PyCapsule names and conversion functions for both DLPack tensor classes: // `DLManagedTensor` and `DLManagedTensorVersioned`. diff --git a/aten/src/ATen/dlpack.h b/aten/src/ATen/dlpack.h index f1b3ae2b7760b..63fd0d0f4df33 100644 --- a/aten/src/ATen/dlpack.h +++ b/aten/src/ATen/dlpack.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2017 by Contributors + * Copyright (c) 2017 - by Contributors * \file dlpack.h * \brief The common header of DLPack. */ @@ -19,7 +19,7 @@ #define DLPACK_MAJOR_VERSION 1 /*! \brief The current minor version of dlpack */ -#define DLPACK_MINOR_VERSION 1 +#define DLPACK_MINOR_VERSION 2 /*! \brief DLPACK_DLL prefix for windows */ #ifdef _WIN32 @@ -118,6 +118,8 @@ typedef enum { kDLHexagon = 16, /*! \brief Microsoft MAIA devices */ kDLMAIA = 17, + /*! \brief AWS Trainium */ + kDLTrn = 18, } DLDeviceType; /*! @@ -222,7 +224,7 @@ typedef struct { * types. This pointer is always aligned to 256 bytes as in CUDA. The * `byte_offset` field should be used to point to the beginning of the data. * - * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, + * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow, * TVM, perhaps others) do not adhere to this 256 byte alignment requirement * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed * (after which this note will be updated); at the moment it is recommended @@ -252,11 +254,23 @@ typedef struct { int32_t ndim; /*! \brief The data type of the pointer*/ DLDataType dtype; - /*! \brief The shape of the tensor */ + /*! + * \brief The shape of the tensor + * + * When ndim == 0, shape can be set to NULL. + */ int64_t* shape; /*! - * \brief strides of the tensor (in number of elements, not bytes) - * can be NULL, indicating tensor is compact and row-majored. + * \brief strides of the tensor (in number of elements, not bytes), + * can not be NULL if ndim != 0, must points to + * an array of ndim elements that specifies the strides, + * so consumer can always rely on strides[dim] being valid for 0 <= dim < ndim. + * + * When ndim == 0, strides can be set to NULL. + * + * \note Before DLPack v1.2, strides can be NULL to indicate contiguous data. + * This is not allowed in DLPack v1.2 and later. The rationale + * is to simplify the consumer handling. */ int64_t* strides; /*! \brief The offset in bytes to the beginning pointer to data */ @@ -306,7 +320,7 @@ typedef struct DLManagedTensor { */ #define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) -/* +/*! * \brief bit mask to indicate that whether a sub-byte type is packed or padded. * * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can @@ -324,7 +338,7 @@ typedef struct DLManagedTensor { * * \note This is the current standard DLPack exchange data structure. */ -struct DLManagedTensorVersioned { +typedef struct DLManagedTensorVersioned { /*! * \brief The API and ABI version of the current managed Tensor */ @@ -358,7 +372,267 @@ struct DLManagedTensorVersioned { uint64_t flags; /*! \brief DLTensor which is being memory managed */ DLTensor dl_tensor; -}; +} DLManagedTensorVersioned; + +//---------------------------------------------------------------------- +// DLPack `__c_dlpack_exchange_api__` fast exchange protocol definitions +//---------------------------------------------------------------------- +/*! + * \brief Request a producer library to create a new tensor. + * + * Create a new `DLManagedTensorVersioned` within the context of the producer + * library. The allocation is defined via the prototype DLTensor. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param prototype The prototype DLTensor. Only the dtype, ndim, shape, + * and device fields are used. + * \param out The output DLManagedTensorVersioned. + * \param error_ctx Context for `SetError`. + * \param SetError The function to set the error. + * \return The owning DLManagedTensorVersioned* or NULL on failure. + * SetError is called exactly when NULL is returned (the implementer + * must ensure this). + * \note - As a C function, must not thrown C++ exceptions. + * - Error propagation via SetError to avoid any direct need + * of Python API. Due to this `SetError` may have to ensure the GIL is + * held since it will presumably set a Python error. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackManagedTensorAllocator)( // + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // + void (*SetError)(void* error_ctx, const char* kind, const char* message) // +); + +/*! + * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned. + * + * This function does not perform any stream synchronization. The consumer should query + * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param py_object The Python object to convert. Must have the same type + * as the one the `DLPackExchangeAPI` was discovered from. + * \return The owning DLManagedTensorVersioned* or NULL on failure with a + * Python exception set. If the data cannot be described using DLPack + * this should be a BufferError if possible. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI, DLPackCurrentWorkStream + */ +typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // + void* py_object, // + DLManagedTensorVersioned** out // +); + +/*! + * \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor. + * + * This function provides a faster interface for temporary, non-owning, + * exchange. The producer (implementer) still owns the memory of data, strides, + * shape. The liveness of the DLTensor and the data it views is only guaranteed + * until control is returned. + * + * This function currently assumes that the producer (implementer) can fill + * in the DLTensor shape and strides without the need for temporary allocations. + * + * This function does not perform any stream synchronization. The consumer + * should query DLPackCurrentWorkStream to get the current work stream and + * launch kernels on it. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param py_object The Python object to convert. Must have the same type + * as the one the `DLPackExchangeAPI` was discovered from. + * \param out The output DLTensor, whose space is pre-allocated on stack. + * \return 0 on success, -1 on failure with a Python exception set. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI, DLPackCurrentWorkStream + */ +typedef int (*DLPackDLTensorFromPyObjectNoSync)( // + void* py_object, // + DLTensor* out // +); + +/*! + * \brief Obtain the current work stream of a device. + * + * Obtain the current work stream of a device from the producer framework. + * For example, it should map to torch.cuda.current_stream in PyTorch. + * + * When device_type is kDLCPU, the consumer do not have to query the stream + * and the producer can simply return NULL when queried. + * The consumer do not have to do anything on stream sync or setting. + * So CPU only framework can just provide a dummy implementation that + * always set out_current_stream[0] to NULL. + * + * \param device_type The device type. + * \param device_id The device id. + * \param out_current_stream The output current work stream. + * + * \return 0 on success, -1 on failure with a Python exception set. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackCurrentWorkStream)( // + DLDeviceType device_type, // + int32_t device_id, // + void** out_current_stream // +); + +/*! + * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray. + * + * Convert an owning DLManagedTensorVersioned* to the Python tensor of the + * producer (implementer) library with the correct type. + * + * This function does not perform any stream synchronization. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param tensor The DLManagedTensorVersioned to convert the ownership of the + * tensor is stolen. + * \param out_py_object The output Python object. + * \return 0 on success, -1 on failure with a Python exception set. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackManagedTensorToPyObjectNoSync)( // + DLManagedTensorVersioned* tensor, // + void** out_py_object // +); + +/*! + * \brief DLPackExchangeAPI stable header. + * \sa DLPackExchangeAPI + */ +typedef struct DLPackExchangeAPIHeader { + /*! + * \brief The provided DLPack version the consumer must check major version + * compatibility before using this struct. + */ + DLPackVersion version; + /*! + * \brief Optional pointer to an older DLPackExchangeAPI in the chain. + * + * It must be NULL if the framework does not support older versions. + * If the current major version is larger than the one supported by the + * consumer, the consumer may walk this to find an earlier supported version. + * + * \sa DLPackExchangeAPI + */ + struct DLPackExchangeAPIHeader* prev_api; +} DLPackExchangeAPIHeader; + +/*! + * \brief Framework-specific function pointers table for DLPack exchange. + * + * Additionally to `__dlpack__()` we define a C function table sharable by + * Python implementations via `__c_dlpack_exchange_api__`. + * This attribute must be set on the type as a Python integer compatible + * with `PyLong_FromVoidPtr`/`PyLong_AsVoidPtr`. + * + * A consumer library may use a pattern such as: + * + * \code + * + * PyObject *api_obj = type(tensor_obj).__c_dlpack_exchange_api__; // as C-code + * MyDLPackExchangeAPI *api = PyLong_AsVoidPtr(api_obj); + * if (api == NULL && PyErr_Occurred()) { goto handle_error; } + * + * \endcode + * + * Note that this must be defined on the type. The consumer should look up the + * attribute on the type and may cache the result for each unique type. + * + * The precise API table is given by: + * \code + * struct MyDLPackExchangeAPI : public DLPackExchangeAPI { + * MyDLPackExchangeAPI() { + * header.version.major = DLPACK_MAJOR_VERSION; + * header.version.minor = DLPACK_MINOR_VERSION; + * header.prev_version_api = nullptr; + * + * managed_tensor_allocator = MyDLPackManagedTensorAllocator; + * managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync; + * managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync; + * dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync; + * current_work_stream = MyDLPackCurrentWorkStream; + * } + * + * static const DLPackExchangeAPI* Global() { + * static MyDLPackExchangeAPI inst; + * return &inst; + * } + * }; + * \endcode + * + * Guidelines for leveraging DLPackExchangeAPI: + * + * There are generally two kinds of consumer needs for DLPack exchange: + * - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel + * with the data from x, y, z. The consumer is also expected to run the kernel with the same + * stream context as the producer. For example, when x, y, z is torch.Tensor, + * consumer should query exchange_api->current_work_stream to get the + * current stream and launch the kernel with the same stream. + * This setup is necessary for no synchronization in kernel launch and maximum compatibility + * with CUDA graph capture in the producer. + * This is the desirable behavior for library extension support for frameworks like PyTorch. + * - N1: data ingestion and retention + * + * Note that obj.__dlpack__() API should provide useful ways for N1. + * The primary focus of the current DLPackExchangeAPI is to enable faster exchange N0 + * with the support of the function pointer current_work_stream. + * + * Array/Tensor libraries should statically create and initialize this structure + * then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array. + * The DLPackExchangeAPI* must stay alive throughout the lifetime of the process. + * + * One simple way to do so is to create a static instance of DLPackExchangeAPI + * within the framework and return a pointer to it. The following code + * shows an example to do so in C++. It should also be reasonably easy + * to do so in other languages. + */ +typedef struct DLPackExchangeAPI { + /*! + * \brief The header that remains stable across versions. + */ + DLPackExchangeAPIHeader header; + /*! + * \brief Producer function pointer for DLPackManagedTensorAllocator + * This function must not be NULL. + * \sa DLPackManagedTensorAllocator + */ + DLPackManagedTensorAllocator managed_tensor_allocator; + /*! + * \brief Producer function pointer for DLPackManagedTensorFromPyObject + * This function must be not NULL. + * \sa DLPackManagedTensorFromPyObject + */ + DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackManagedTensorToPyObject + * This function must be not NULL. + * \sa DLPackManagedTensorToPyObject + */ + DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackDLTensorFromPyObject + * This function can be NULL when the producer does not support this function. + * \sa DLPackDLTensorFromPyObjectNoSync + */ + DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackCurrentWorkStream + * This function must be not NULL. + * \sa DLPackCurrentWorkStream + */ + DLPackCurrentWorkStream current_work_stream; +} DLPackExchangeAPI; #ifdef __cplusplus } // DLPACK_EXTERN_C diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 3d27678b5864a..7abd5ea475b70 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -537,6 +537,239 @@ def test_dlpack_unsupported_dtype_error(self, device): ): from_dlpack(inp) + @skipMeta + @onlyNativeDeviceTypes + def test_dlpack_exchange_api(self, device): + """Comprehensive test of all DLPack Exchange API functions using inline C++""" + # Check that the C API capsule exists and get it + self.assertTrue(hasattr(torch.Tensor, "__c_dlpack_exchange_api__")) + api_capsule = torch.Tensor.__c_dlpack_exchange_api__ + self.assertEqual( + type(api_capsule).__name__, "PyCapsule", "API should be a PyCapsule" + ) + self.assertRegex(str(api_capsule), r'capsule object "dlpack_exchange_api"') + tensor = torch.arange(24, dtype=torch.float32, device=device).reshape(2, 3, 4) + + source = """ + #include + #include + #include + #include + + namespace py = pybind11; + + void test_dlpack_exchange_api(at::Tensor tensor, py::object api_obj, bool test_stream_exchange) { + PyObject* api_capsule = api_obj.ptr(); + TORCH_CHECK(PyCapsule_IsValid(api_capsule, "dlpack_exchange_api"), + "Invalid or mismatched DLPack exchange API capsule"); + const DLPackExchangeAPI* api = + static_cast( + PyCapsule_GetPointer(api_capsule, "dlpack_exchange_api")); + + // Test 1: API structure and version + { + TORCH_CHECK(api != nullptr, "API pointer is NULL"); + TORCH_CHECK(api->header.version.major == DLPACK_MAJOR_VERSION, + "Expected major version ", DLPACK_MAJOR_VERSION, + ", got ", api->header.version.major); + TORCH_CHECK(api->header.version.minor == DLPACK_MINOR_VERSION, + "Expected minor version ", DLPACK_MINOR_VERSION, + ", got ", api->header.version.minor); + TORCH_CHECK(api->managed_tensor_allocator != nullptr, + "managed_tensor_allocator is NULL"); + TORCH_CHECK(api->managed_tensor_from_py_object_no_sync != nullptr, + "managed_tensor_from_py_object_no_sync is NULL"); + TORCH_CHECK(api->managed_tensor_to_py_object_no_sync != nullptr, + "managed_tensor_to_py_object_no_sync is NULL"); + TORCH_CHECK(api->dltensor_from_py_object_no_sync != nullptr, + "dltensor_from_py_object_no_sync is NULL"); + TORCH_CHECK(api->current_work_stream != nullptr, + "current_work_stream is NULL"); + } + + // Test 2: managed_tensor_allocator + { + DLTensor prototype; + prototype.device.device_type = kDLCPU; + prototype.device.device_id = 0; + prototype.ndim = 3; + int64_t shape[3] = {3, 4, 5}; + prototype.shape = shape; + prototype.strides = nullptr; + DLDataType dtype; + dtype.code = kDLFloat; + dtype.bits = 32; + dtype.lanes = 1; + prototype.dtype = dtype; + prototype.data = nullptr; + prototype.byte_offset = 0; + + DLManagedTensorVersioned* out_tensor = nullptr; + int result = api->managed_tensor_allocator( + &prototype, &out_tensor, nullptr, nullptr); + TORCH_CHECK(result == 0, "Allocator failed with code ", result); + TORCH_CHECK(out_tensor != nullptr, "Allocator returned NULL"); + TORCH_CHECK(out_tensor->dl_tensor.ndim == 3, + "Expected ndim 3, got ", out_tensor->dl_tensor.ndim); + TORCH_CHECK(out_tensor->dl_tensor.shape[0] == 3, + "Expected shape[0] = 3, got ", out_tensor->dl_tensor.shape[0]); + TORCH_CHECK(out_tensor->dl_tensor.shape[1] == 4, + "Expected shape[1] = 4, got ", out_tensor->dl_tensor.shape[1]); + TORCH_CHECK(out_tensor->dl_tensor.shape[2] == 5, + "Expected shape[2] = 5, got ", out_tensor->dl_tensor.shape[2]); + TORCH_CHECK(out_tensor->dl_tensor.dtype.code == kDLFloat, + "Expected dtype code kDLFloat, got ", + out_tensor->dl_tensor.dtype.code); + TORCH_CHECK(out_tensor->dl_tensor.dtype.bits == 32, + "Expected dtype bits 32, got ", out_tensor->dl_tensor.dtype.bits); + TORCH_CHECK(out_tensor->dl_tensor.device.device_type == kDLCPU, + "Expected device type kDLCPU, got ", + out_tensor->dl_tensor.device.device_type); + if (out_tensor->deleter) { + out_tensor->deleter(out_tensor); + } + } + + // Test 3: managed_tensor_from_py_object_no_sync + { + std::unique_ptr py_obj( + THPVariable_Wrap(tensor), &Py_DecRef); + TORCH_CHECK(py_obj.get() != nullptr, "Failed to wrap tensor to PyObject"); + + DLManagedTensorVersioned* out_tensor = nullptr; + int result = api->managed_tensor_from_py_object_no_sync( + py_obj.get(), &out_tensor); + + TORCH_CHECK(result == 0, + "from_py_object_no_sync failed with code ", result); + TORCH_CHECK(out_tensor != nullptr, + "from_py_object_no_sync returned NULL"); + TORCH_CHECK(out_tensor->version.major == DLPACK_MAJOR_VERSION, + "Expected major version ", DLPACK_MAJOR_VERSION, + ", got ", out_tensor->version.major); + TORCH_CHECK(out_tensor->version.minor == DLPACK_MINOR_VERSION, + "Expected minor version ", DLPACK_MINOR_VERSION, + ", got ", out_tensor->version.minor); + TORCH_CHECK(out_tensor->dl_tensor.ndim == 3, + "Expected ndim 3, got ", out_tensor->dl_tensor.ndim); + TORCH_CHECK(out_tensor->dl_tensor.shape[0] == 2, + "Expected shape[0] = 2, got ", out_tensor->dl_tensor.shape[0]); + TORCH_CHECK(out_tensor->dl_tensor.shape[1] == 3, + "Expected shape[1] = 3, got ", out_tensor->dl_tensor.shape[1]); + TORCH_CHECK(out_tensor->dl_tensor.shape[2] == 4, + "Expected shape[2] = 4, got ", out_tensor->dl_tensor.shape[2]); + TORCH_CHECK(out_tensor->dl_tensor.dtype.code == kDLFloat, + "Expected dtype code kDLFloat, got ", + out_tensor->dl_tensor.dtype.code); + TORCH_CHECK(out_tensor->dl_tensor.dtype.bits == 32, + "Expected dtype bits 32, got ", + out_tensor->dl_tensor.dtype.bits); + TORCH_CHECK(out_tensor->dl_tensor.data != nullptr, + "Data pointer is NULL"); + + if (out_tensor->deleter) { + out_tensor->deleter(out_tensor); + } + } + + // Test 4: managed_tensor_to_py_object_no_sync + { + std::unique_ptr py_obj( + THPVariable_Wrap(tensor), &Py_DecRef); + TORCH_CHECK(py_obj.get() != nullptr, "Failed to wrap tensor to PyObject"); + + DLManagedTensorVersioned* managed_tensor = nullptr; + int result = api->managed_tensor_from_py_object_no_sync( + py_obj.get(), &managed_tensor); + TORCH_CHECK(result == 0, "from_py_object_no_sync failed"); + TORCH_CHECK(managed_tensor != nullptr, + "from_py_object_no_sync returned NULL"); + + std::unique_ptr py_obj_out( + nullptr, &Py_DecRef); + PyObject* py_obj_out_raw = nullptr; + result = api->managed_tensor_to_py_object_no_sync( + managed_tensor, reinterpret_cast(&py_obj_out_raw)); + py_obj_out.reset(py_obj_out_raw); + + TORCH_CHECK(result == 0, + "to_py_object_no_sync failed with code ", result); + TORCH_CHECK(py_obj_out.get() != nullptr, + "to_py_object_no_sync returned NULL"); + TORCH_CHECK(THPVariable_Check(py_obj_out.get()), + "Returned PyObject is not a Tensor"); + + at::Tensor result_tensor = THPVariable_Unpack(py_obj_out.get()); + TORCH_CHECK(result_tensor.dim() == 3, + "Expected 3 dimensions, got ", result_tensor.dim()); + TORCH_CHECK(result_tensor.size(0) == 2, + "Expected size(0) = 2, got ", result_tensor.size(0)); + TORCH_CHECK(result_tensor.size(1) == 3, + "Expected size(1) = 3, got ", result_tensor.size(1)); + TORCH_CHECK(result_tensor.size(2) == 4, + "Expected size(2) = 4, got ", result_tensor.size(2)); + TORCH_CHECK(result_tensor.scalar_type() == at::kFloat, + "Expected dtype kFloat, got ", result_tensor.scalar_type()); + } + + // Test 5: dltensor_from_py_object_no_sync (non-owning conversion) + DLDeviceType device_type; + int32_t device_id; + { + std::unique_ptr py_obj( + THPVariable_Wrap(tensor), &Py_DecRef); + TORCH_CHECK(py_obj.get() != nullptr, "Failed to wrap tensor to PyObject"); + + DLTensor dltensor; + int result = api->dltensor_from_py_object_no_sync(py_obj.get(), &dltensor); + TORCH_CHECK(result == 0, + "dltensor_from_py_object_no_sync failed with code ", result); + TORCH_CHECK(dltensor.ndim == 3, "Expected ndim 3, got ", dltensor.ndim); + TORCH_CHECK(dltensor.shape[0] == 2, + "Expected shape[0] = 2, got ", dltensor.shape[0]); + TORCH_CHECK(dltensor.shape[1] == 3, + "Expected shape[1] = 3, got ", dltensor.shape[1]); + TORCH_CHECK(dltensor.shape[2] == 4, + "Expected shape[2] = 4, got ", dltensor.shape[2]); + TORCH_CHECK(dltensor.dtype.code == kDLFloat, + "Expected dtype code kDLFloat, got ", dltensor.dtype.code); + TORCH_CHECK(dltensor.dtype.bits == 32, + "Expected dtype bits 32, got ", dltensor.dtype.bits); + TORCH_CHECK(dltensor.data != nullptr, "Data pointer is NULL"); + + // Capture device info for stream test + device_type = dltensor.device.device_type; + device_id = dltensor.device.device_id; + } + + // Test 6: current_work_stream + { + if (test_stream_exchange) { + void* stream_out = nullptr; + int result = api->current_work_stream(device_type, device_id, &stream_out); + TORCH_CHECK(result == 0, + "current_work_stream failed with code ", result); + TORCH_CHECK(stream_out != nullptr, + "Expected stream to be non-NULL"); + } + } + } + """ + + # Load and compile the inline C++ test + from torch.utils import cpp_extension + + module = cpp_extension.load_inline( + name="test_dlpack_exchange_api", + cpp_sources=[source], + functions=["test_dlpack_exchange_api"], + verbose=False, + with_cuda=device.startswith("cuda"), + ) + + # Run the comprehensive C++ test + module.test_dlpack_exchange_api(tensor, api_capsule, device.startswith("cuda")) + instantiate_device_type_tests(TestTorchDlPack, globals(), allow_mps=True) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9ad00753fe25c..f69272a5f0142 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1329,6 +1329,7 @@ def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack def _torchDeviceToDLDevice( device: torch.device, ) -> tuple[_int, _int]: ... # THPModule_torchDeviceToDLDevice +def _dlpack_exchange_api() -> object: ... # THPModule_DLPackExchangeAPI def _get_cpp_backtrace( frames_to_skip: _int, maximum_number_of_frames: _int, diff --git a/torch/_tensor.py b/torch/_tensor.py index c1093f35aa984..6acc8af9dab7c 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -109,6 +109,7 @@ def _dtype_to_typestr(dtype): # otherwise, it will not show up in autocomplete. class Tensor(torch._C.TensorBase): _is_param: bool + __c_dlpack_exchange_api__: object = torch._C._dlpack_exchange_api() def _clear_non_serializable_cached_data(self): r"""Clears any data cached in the tensor's ``__dict__`` that would prevent the tensor diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 4c304c27bfa19..00206bf827ee9 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -689,6 +689,124 @@ static PyObject* THPModule_torchDeviceToDLDevice( END_HANDLE_TH_ERRORS } +struct TorchDLPackExchangeAPI : public DLPackExchangeAPI { + TorchDLPackExchangeAPI() { + header.version.major = DLPACK_MAJOR_VERSION; + header.version.minor = DLPACK_MINOR_VERSION; + header.prev_api = nullptr; + managed_tensor_allocator = ManagedTensorAllocator; + managed_tensor_from_py_object_no_sync = ManagedTensorFromPyObjectNoSync; + managed_tensor_to_py_object_no_sync = ManagedTensorToPyObjectNoSync; + dltensor_from_py_object_no_sync = DLTensorFromPyObjectNoSync; + current_work_stream = CurrentWorkStream; + } + + static const DLPackExchangeAPI* Global() { + static TorchDLPackExchangeAPI inst; + return &inst; + } + + private: + // Fast non-owning PyObject→DLTensor conversion + static int DLTensorFromPyObjectNoSync(void* py_obj, DLTensor* out) { + try { + // Use handle (non-owning) to avoid unnecessary refcount operations + py::handle handle(static_cast(py_obj)); + at::Tensor tensor = handle.cast(); + at::toDLPackNonOwning(tensor, out); + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } + } + + // PyObject→DLManagedTensorVersioned conversion + static int ManagedTensorFromPyObjectNoSync( + void* py_obj, + DLManagedTensorVersioned** out) { + try { + py::handle handle(static_cast(py_obj)); + at::Tensor tensor = handle.cast(); + *out = at::toDLPackVersioned(tensor); + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } + } + + // DLManagedTensorVersioned→PyObject conversion + static int ManagedTensorToPyObjectNoSync( + DLManagedTensorVersioned* src, + void** py_obj_out) { + try { + at::Tensor tensor = at::fromDLPackVersioned(src, nullptr); + *py_obj_out = THPVariable_Wrap(tensor); + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } + } + + // Allocate new tensor from prototype + static int ManagedTensorAllocator( + DLTensor* prototype, + DLManagedTensorVersioned** out, + void* error_ctx, + void ( + *SetError)(void* error_ctx, const char* kind, const char* message)) { + try { + at::IntArrayRef shape( + prototype->shape, prototype->shape + prototype->ndim); + at::TensorOptions options = + at::TensorOptions() + .dtype(at::toScalarType(prototype->dtype)) + .device(at::dlDeviceToTorchDevice( + prototype->device.device_type, prototype->device.device_id)); + at::Tensor tensor = at::empty(shape, options); + *out = at::toDLPackVersioned(tensor); + return 0; + } catch (const std::exception& e) { + SetError(error_ctx, "MemoryError", e.what()); + return -1; + } + } + + // Get current CUDA/ROCm work stream + static int CurrentWorkStream( + DLDeviceType device_type, + int32_t device_id, + void** out_stream) { + try { +#ifdef USE_CUDA + if (device_type == kDLCUDA || device_type == kDLROCM) { + *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream(); + return 0; + } +#endif + // For CPU and other devices, return NULL (no stream concept) + *out_stream = nullptr; + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } + } +}; + +static PyObject* THPModule_DLPackExchangeAPI( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return PyCapsule_New( + const_cast(TorchDLPackExchangeAPI::Global()), + "dlpack_exchange_api", + nullptr); + END_HANDLE_TH_ERRORS +} + static PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS size_t frames_to_skip = 0; @@ -1864,6 +1982,7 @@ static std::initializer_list TorchMethods = { THPModule_torchDeviceToDLDevice, METH_O, nullptr}, + {"_dlpack_exchange_api", THPModule_DLPackExchangeAPI, METH_NOARGS, nullptr}, {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, {"_rename_privateuse1_backend", THModule_rename_privateuse1_backend, From 77b90b703e0e56d2e70c5dc58fe8d967281c5705 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Dec 2025 22:53:43 +0000 Subject: [PATCH 289/338] Revert "Re-enable torch.compile tests for Python 3.12 and Windows (#169387)" This reverts commit 45cd3b9f155c09b921de010b1867d606fac28a62. Reverted https://github.com/pytorch/pytorch/pull/169387 on behalf of https://github.com/guilhermeleobas due to needs additional work ([comment](https://github.com/pytorch/pytorch/pull/169387#issuecomment-3614647730)) --- test/export/test_sparse.py | 4 +++ test/quantization/pt2e/test_graph_utils.py | 10 ++++++- test/test_nestedtensor.py | 32 ++++++++++++++++++++-- test/test_transformers.py | 2 ++ torch/testing/_internal/common_utils.py | 4 --- 5 files changed, 44 insertions(+), 8 deletions(-) diff --git a/test/export/test_sparse.py b/test/export/test_sparse.py index 975e9979982f5..c8d799a0254b0 100644 --- a/test/export/test_sparse.py +++ b/test/export/test_sparse.py @@ -3,6 +3,7 @@ # Test to ensure sparsity information propagates properly into traced graph. # +import sys import unittest import torch @@ -90,6 +91,9 @@ def forward(self, x): @unittest.skipIf(is_fbcode(), "See torch._dynamo.config") +@unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" +) class TestSparseProp(TestCase): def setUp(self): super().setUp() diff --git a/test/quantization/pt2e/test_graph_utils.py b/test/quantization/pt2e/test_graph_utils.py index ee2603dae84cd..2a26ff682b93f 100644 --- a/test/quantization/pt2e/test_graph_utils.py +++ b/test/quantization/pt2e/test_graph_utils.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: quantization"] import copy +import unittest import torch import torch._dynamo as torchdynamo @@ -8,10 +9,15 @@ get_equivalent_types, update_equivalent_types_dict, ) -from torch.testing._internal.common_utils import raise_on_run_directly, TestCase +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + raise_on_run_directly, + TestCase, +) class TestGraphUtils(TestCase): + @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_conv_bn_conv_relu(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -57,6 +63,7 @@ def x(): self.assertRaises(ValueError, x) + @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_conv_bn_relu(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -91,6 +98,7 @@ def forward(self, x): ) self.assertEqual(len(fused_partitions), 0) + @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_customized_equivalet_types_dict(self): class M(torch.nn.Module): def __init__(self) -> None: diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 909cbad423588..8e9d1ed0217ae 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -52,6 +52,7 @@ gradcheck, instantiate_parametrized_tests, IS_FBCODE, + IS_WINDOWS, markDynamoStrictTest, NestedTensorTestCase, parametrize, @@ -62,7 +63,6 @@ subtest, TEST_WITH_ROCM, xfailIfTorchDynamo, - xfailIfWindows, ) from torch.testing._internal.opinfo.core import ( BinaryUfuncInfo, @@ -6672,6 +6672,7 @@ def check_size(nt1, nt2, nt3, nt4): check_size(nt1_t, nt2_t, nt3_t, nt4_t) @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") def test_specialize_dynamic_shape(self, device): values = torch.randn((18, 16), device=device) @@ -6693,6 +6694,7 @@ def fn(values, same_size): ) @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") def test_specialize_dynamic_shape_recompile(self, device): def generate_inp(total_len): @@ -7003,9 +7005,9 @@ def check_forward_backward(skip_backward=False): check_forward_backward() @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") # Guarding with sqrt() doesn't work on ROCm? - @xfailIfWindows @skipCUDAIfRocm @onlyCUDA @dtypes( @@ -7190,8 +7192,8 @@ def in_proj(input_packed, qkv_linear=qkv_linear): out, out_component, atol=output_ref_atol, rtol=output_ref_rtol ) - @decorateIf(xfailIfWindows, lambda params: params["dtype"] == torch.float32) @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") # mha_varlen_fwd not supported on ROCm @skipCUDAIfRocm @@ -7228,6 +7230,7 @@ def f(values, offsets): @skipCUDAIfRocm @onlyCUDA @skipIfTorchDynamo() + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") def test_sdpa_autocast(self, device): def fn_nt(values32, values16, offsets): nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16) @@ -7374,6 +7377,7 @@ def fn(values, lengths): # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated. @skipCUDAIfRocm # not needed @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @parametrize("use_legacy_api", [True, False]) @skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644") @@ -7730,6 +7734,10 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_preserves_metadata_cache(self, device, dtype): @@ -7757,6 +7765,10 @@ def f(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_with_dynamic_max_seq_len(self, device, dtype): @@ -7790,6 +7802,10 @@ def f(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_with_dynamic_min_seq_len(self, device, dtype): @@ -7823,6 +7839,10 @@ def f(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): @@ -7950,6 +7970,7 @@ def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad): # blows up due to test parametrization otherwise @torch._dynamo.utils.disable_cache_limit() @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm @dtypes(torch.float32, torch.double, torch.half) @@ -8052,6 +8073,10 @@ def _g(nt): @dtypes(torch.float32) @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @skipCUDAIfRocm def test_compile_padded_dense_conversion_preserves_metadata_cache( @@ -8141,6 +8166,7 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None): self.assertEqual(res.shape, (4, nt.shape[1], 6)) @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") @dtypes(torch.float32) @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) diff --git a/test/test_transformers.py b/test/test_transformers.py index d9818f3330184..1897548f560cf 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -37,6 +37,7 @@ gradcheck, make_tensor, NOTEST_CPU, + IS_WINDOWS, TEST_WITH_TORCHDYNAMO, TEST_XPU, ) @@ -4810,6 +4811,7 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: lis "shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)], ) + @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows") @skipIfTorchDynamo("This function already calls torch.compile.") def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]): cnts = CompileCounterWithBackend("aot_eager") diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index ad4918ef9103a..df3ca03b76242 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1713,10 +1713,6 @@ def xfailIfLinux(func): return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM and not IS_FBCODE else func -def xfailIfWindows(func): - return unittest.expectedFailure(func) if IS_WINDOWS else func - - def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"): """ Usage: From 9ff16b89d744afa7184f3b2df1f277b9f1327908 Mon Sep 17 00:00:00 2001 From: Chris Leonard Date: Thu, 4 Dec 2025 22:55:53 +0000 Subject: [PATCH 290/338] Updated hypot to accept cpu scalar Tensor even on GPU. (#169302) Updated hypot to accept cpu scalar Tensor even on GPU. Updated test_hypot unit test to make sure it is working. The hypot CUDA kernel was actually already capable of taking in both since it used 'opmath_symmetric_gpu_kernel_with_scalars'. However, there was a check before the kernel to make sure both input and output were on the same device. I just added 'device_check: NoCheck' to each of the hypot function variants in native_functions.yaml. @albanD Fixes #167567 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169302 Approved by: https://github.com/albanD --- aten/src/ATen/native/native_functions.yaml | 3 +++ test/test_binary_ufuncs.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1759951b68bdc..50192342ff331 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10081,6 +10081,7 @@ tags: pointwise - func: hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator structured: True structured_inherits: TensorIteratorBase dispatch: @@ -10088,11 +10089,13 @@ tags: pointwise - func: hypot(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator structured_delegate: hypot.out variants: method, function tags: pointwise - func: hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) + device_check: NoCheck # TensorIterator structured_delegate: hypot.out variants: method tags: pointwise diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index d448f95319416..c8cc95222fa97 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -2899,6 +2899,18 @@ def test_hypot(self, device, dtype): expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy()) self.assertEqual(actual, expected, exact_dtype=False) + if torch.device(device).type == "cuda": + # test using cpu scalar with cuda. + x = torch.randn(10, device=device).to(dtype) + y = torch.tensor(2.0).to(dtype) + actual1 = torch.hypot(x, y) + actual2 = torch.hypot(y, x) + expected = np.hypot(x.cpu().numpy(), 2.0) + self.assertTrue(actual1.is_cuda) + self.assertTrue(actual2.is_cuda) + self.assertEqual(actual1, expected, exact_dtype=False) + self.assertEqual(actual2, expected, exact_dtype=False) + @onlyNativeDeviceTypes @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) def test_gcd(self, device, dtype): From 5a9322c076f30031ae13f4e3f39ecadd180672cd Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 4 Dec 2025 23:15:53 +0000 Subject: [PATCH 291/338] Remove unnecessary code in exception registration (#169365) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/169365 Approved by: https://github.com/zou3519 --- torch/csrc/Exceptions.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index cf74ddff576c3..32b9a4664f613 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -65,9 +65,6 @@ could not be completed because the input matrix is singular.", "Exception raised when device is out of memory", PyExc_RuntimeError, nullptr)); - PyTypeObject* type = - reinterpret_cast(THPException_OutOfMemoryError); - type->tp_name = "torch.OutOfMemoryError"; ASSERT_TRUE( PyModule_AddObject( module, "OutOfMemoryError", THPException_OutOfMemoryError) == 0); @@ -134,7 +131,6 @@ could not be completed because the input matrix is singular.", "Exception raised while executing on device", PyExc_RuntimeError, nullptr)); - type = reinterpret_cast(THPException_AcceleratorError); ASSERT_TRUE( PyModule_AddObject( module, "AcceleratorError", THPException_AcceleratorError) == 0); From b74e8231d2703d4a89bf8a6af55ea7e38e5dc83d Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 4 Dec 2025 23:20:20 +0000 Subject: [PATCH 292/338] Remove outdated Makefile targets (#164679) The build scripts of these targets are inexistent. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164679 Approved by: https://github.com/XuehaiPan, https://github.com/albanD --- Makefile | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/Makefile b/Makefile index 3db2b7aa44e76..9791630a881b1 100644 --- a/Makefile +++ b/Makefile @@ -11,18 +11,6 @@ all: @cmake -S . -B build $(shell $(PYTHON) ./scripts/get_python_cmake_flags.py) && \ cmake --build build --parallel -- -.PHONY: local -local: - @./scripts/build_local.sh - -.PHONY: android -android: - @./scripts/build_android.sh - -.PHONY: ios -ios: - @./scripts/build_ios.sh - .PHONY: triton triton: $(PIP) uninstall -y triton From b5e5ee907ee6b23a726e7a211302b503d5a22304 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Thu, 4 Dec 2025 23:33:37 +0000 Subject: [PATCH 293/338] improve accuracy of division by scalar (#169507) Pytorch emulates division by scalar as multiplication by 1/scalar, which already can throw results off by 1 ulp. However, since it computes 1/scalar in float as opposed to double, the results are different compared to what you'd get in python by writing `x * (1/scalar)`. This PR remedies that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169507 Approved by: https://github.com/ezyang --- aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu | 6 +++++- test/test_binary_ufuncs.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu index f3dfc2ba11a60..c7345633edd15 100644 --- a/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu @@ -39,7 +39,11 @@ void div_true_kernel_cuda(TensorIteratorBase& iter) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { using opmath_t = at::opmath_type; - auto inv_b = opmath_t(1.0) / iter.scalar_value(2); + using high_prec_t = std::conditional_t< + c10::is_complex::value, + c10::complex, + double>; + auto inv_b = static_cast(high_prec_t(1.0) / iter.scalar_value(2)); iter.remove_operand(2); gpu_kernel( iter, diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index c8cc95222fa97..ff4684c5f945c 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1149,6 +1149,16 @@ def test_complex_div_underflow_overflow(self, device, dtype): res = nom / denom self.assertEqual(res, expected) + @onlyCUDA + @dtypes(torch.float, torch.bfloat16) + def test_division_by_scalar(self, device, dtype): + num = torch.rand(1024, device=device, dtype=dtype) + denom = torch.logspace(-4, 4, steps=20) + denom = [d.item() for d in denom] + res = [num / d for d in denom] + ref = [num * (1 / d) for d in denom] + self.assertEqual(res, ref, atol=0, rtol=0) + # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor # throws the correct error message @onlyCUDA From 6211bfd3ca87d64d7155da32b4548322ee5e6be8 Mon Sep 17 00:00:00 2001 From: karthickai Date: Thu, 4 Dec 2025 10:30:52 -0800 Subject: [PATCH 294/338] [Inductor] Fix combo kernels for cpu backend (#167781) This PR fixes two issues Fixes: #167780 combo_kernel fails with CppScheduling backend Fixes: #168067 combo_kernel fails with mixed cpu/cuda nodes Pull Request resolved: https://github.com/pytorch/pytorch/pull/167781 Approved by: https://github.com/mlazos --- test/inductor/test_cpu_cpp_wrapper.py | 3 +- test/inductor/test_gpu_cpp_wrapper.py | 3 +- test/inductor/test_torchinductor.py | 101 +++++++++++++++++--------- torch/_inductor/scheduler.py | 34 ++++++--- 4 files changed, 96 insertions(+), 45 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 47a8f3aa063e3..e96651dba3e35 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -171,7 +171,8 @@ class BaseTest(NamedTuple): BaseTest("test_add_complex4"), BaseTest("test_add_complex4", test_build_separate=True), BaseTest("test_as_strided"), # buffer reuse - BaseTest("test_bernoulli1"), + BaseTest("test_bernoulli1_combo_kernels_False"), + BaseTest("test_bernoulli1_combo_kernels_True"), BaseTest("test_bitwise"), # int32 BaseTest("test_bmm1"), BaseTest("test_bmm1", test_build_separate=True), diff --git a/test/inductor/test_gpu_cpp_wrapper.py b/test/inductor/test_gpu_cpp_wrapper.py index 832b119c8455d..db5e9bd6429ed 100644 --- a/test/inductor/test_gpu_cpp_wrapper.py +++ b/test/inductor/test_gpu_cpp_wrapper.py @@ -196,7 +196,8 @@ class BaseTest(NamedTuple): BaseTest("test_add_complex4"), BaseTest("test_as_strided"), # buffer reuse BaseTest("test_batch_norm_2d_2"), - BaseTest("test_bernoulli1"), + BaseTest("test_bernoulli1_combo_kernels_False"), + BaseTest("test_bernoulli1_combo_kernels_True"), BaseTest("test_bitwise"), # int32 BaseTest("test_bmm1"), BaseTest("test_bmm2"), diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index c5bdab4135b0f..09831d4ae28b6 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4988,29 +4988,32 @@ def test_conv3d_channels_last(self, use_block_ptr: bool): @skip_if_gpu_halide # slow @xfail_if_mps # Non-divisible input sizes are not implemented on MPS device - def test_adaptive_avg_pool2d1(self): - def fn(x): - return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d( - x + 1, (2, 5) - ) + @parametrize("combo_kernels", (False, True)) + def test_adaptive_avg_pool2d1(self, combo_kernels): + with config.patch(combo_kernels=combo_kernels): - self.common( - fn, - (torch.randn(2, 4, 16, 16),), - check_lowp=False, - ) + def fn(x): + return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d( + x + 1, (2, 5) + ) - # lowering to avg_pool2d case - self.common( - fn, - (torch.randn(2, 4, 3, 3),), - ) + self.common( + fn, + (torch.randn(2, 4, 16, 16),), + check_lowp=False, + ) - # no-op case - self.common( - fn, - (torch.randn(2, 4, 6, 6),), - ) + # lowering to avg_pool2d case + self.common( + fn, + (torch.randn(2, 4, 3, 3),), + ) + + # no-op case + self.common( + fn, + (torch.randn(2, 4, 6, 6),), + ) @xfail_if_mps # Non-divisible input sizes are not implemented on MPS device def test_adaptive_avg_pool2d2(self): @@ -8594,22 +8597,25 @@ def fn(x, y): self.common(fn, [torch.randn(1, 1024), torch.randn(1, 1024, 2)]) + @parametrize("combo_kernels", (False, True)) @config.patch(fallback_random=True) - def test_bernoulli1(self): - def fn(a): - b = a.clone() - # aten.bernoulli_() uses aten.bernoulli.p() behind the scene, so it will be decomposed. - return aten.bernoulli_(b).sum() / torch.prod(torch.tensor(a.size())) + def test_bernoulli1(self, combo_kernels): + with config.patch(combo_kernels=combo_kernels): - p = 0.3 - self.common( - fn, - [ - torch.ones(200, 200) * p, - ], - atol=p * 0.06, - rtol=0.06, - ) + def fn(a): + b = a.clone() + # aten.bernoulli_() uses aten.bernoulli.p() behind the scene, so it will be decomposed. + return aten.bernoulli_(b).sum() / torch.prod(torch.tensor(a.size())) + + p = 0.3 + self.common( + fn, + [ + torch.ones(200, 200) * p, + ], + atol=p * 0.06, + rtol=0.06, + ) @skip_if_triton_cpu def test_bernoulli2(self): @@ -14818,6 +14824,33 @@ def fn(x, max_val): ), ) + @config.patch(combo_kernels=True) + def test_combo_kernel_filter_cpu(self): + def fn(a, b, c, d): + a = a * 4 + b = b + 8 + return a, b, c.min(-1), d.max(-1) + + inps = [ + torch.rand(20, 20, device=self.device), + torch.rand(30, 30, device=self.device), + torch.rand(256, 256, device=self.device), + torch.rand(256, 256, device=self.device), + ] + torch._inductor.metrics.reset() + compiled_fn = torch.compile(fn) + result = compiled_fn(*inps) + expected = fn(*inps) + + self.assertEqual(result, expected) + # on cuda combo kernel fuses (a, b) into one kernel and (c, d) into another (total 2) + # on cpu combo kernel is skipped (a), (b), (c), (d) each run as separate kernels (total 4) + if self.device.lower() == "cpu": + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) + + if self.device.lower() == "cuda": + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + # end of class CommonTemplate - add new tests here diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index aeaed244cab2f..4e77af7075eaf 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2330,12 +2330,24 @@ def _default_group_nodes_for_combo_kernels( grouped_nodes = [] max_num_nodes = 8 for nodes in sorted_nodes: - grouped_nodes.extend( - [ - nodes[i : i + max_num_nodes] - for i in range(0, len(nodes), max_num_nodes) - ] + # Group nodes by device first to avoid mixed-device fusion + device_groups: dict[Optional[torch.device], list[BaseSchedulerNode]] = ( + defaultdict(list) ) + for node in nodes: + device = node.get_device() + if device and (device.type == "mps" or device.type == "cpu"): + continue + device_groups[device].append(node) + + # Chunk each device group separately + for device_nodes in device_groups.values(): + grouped_nodes.extend( + [ + device_nodes[i : i + max_num_nodes] + for i in range(0, len(device_nodes), max_num_nodes) + ] + ) return grouped_nodes @@ -2480,6 +2492,9 @@ def estimate_flops(self) -> int | None: def get_nodes(self) -> Sequence[BaseSchedulerNode]: return self.snodes + def get_device(self) -> Optional[torch.device]: + return self.snodes[0].get_device() if self.snodes else None + @classmethod def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: # GroupedSchedulerNode cannot be fused with another node @@ -6147,14 +6162,15 @@ def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool: If config.benchmark_fusion is False, always return True. Otherwise, return True if fusion can brings speedup. """ - if not config.benchmark_combo_kernel: - return True subkernel_nodes = nodes device = subkernel_nodes[0].get_device() - # don't support benchmark fusion for CPU C++ backend right now. - if device is None or (device.type == "cpu" and config.cpu_backend != "triton"): + assert all(node.get_device() == device for node in subkernel_nodes), ( + "All nodes in a combo kernel group must be on the same device" + ) + + if not config.benchmark_combo_kernel: return True from triton.compiler.errors import CompilationError From 607a53592ba70ed7302f3f88bd9d4fb0eac5ae88 Mon Sep 17 00:00:00 2001 From: angelayi Date: Thu, 4 Dec 2025 23:46:16 +0000 Subject: [PATCH 295/338] [effects] Various effect/dce fixes (#169141) In torch.compile there are 2 places where DCE happens -- once in post-grad passes with a call to FX's `eliminate_dead_code`, and a second time where inductor's scheduler checks if a node `has_side_effects`. To prevent a node from being DCEd in FX's DCE, we can register `torch.fx.node.has_side_effect(op)`. However this does not propagate to inductor. 1. Fixes FX DCE such that if an operator is wrapped with auto_functionalize, then we check if the operator passed to auto_functionalized has side effects to determine if that call should be DCEd 2. Fixes inductor's FallbackKernel's has_side_effect function to also check if the operator is in FX's _side_effectful_fns. This way operators registered as having side effects through `torch.fx.node.has_side_effect` will show up 5. Update FX DCE to log what nodes have been DCEd. Inductor's nodes that are DCEd will show up in `TORCH_LOGS="+inductor"` logs with the prefix, `removed dead buffer: ...` 6. If an operator has side effects, marked through torch.library._register_effectful_op, this will also make sure FX's DCE does not DCE the operator. a. Some additional changes needed to be made -- We mark all operators with ScriptObjects as inputs as being side effectful by default. However the above change causes an issue where if a graph contains a quantization op using a PackedParamsBase, then the op will not get DCEd, causing some unittest failures. So I marked some of the quantization ScriptObject classes as being not effectful. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169141 Approved by: https://github.com/yushangdi, https://github.com/zou3519 --- test/higher_order_ops/test_with_effects.py | 37 +++++++++ test/inductor/test_torchinductor.py | 55 +++++++++++++ torch/_inductor/ir.py | 9 ++- torch/_inductor/lowering.py | 2 + torch/_library/effects.py | 16 ++++ torch/_library/utils.py | 89 ++++++++++++++++++++++ torch/fx/graph.py | 11 ++- torch/fx/node.py | 68 ++++++++--------- 8 files changed, 245 insertions(+), 42 deletions(-) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index c612c3a65ce0b..d6304810143e7 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -561,6 +561,43 @@ def forward(self, tangents_1, tangents_2, tangents_token): return (clone, clone_1, tangents_1, tangents_2, getitem_6)""", ) + def test_dce(self): + # If an operator is marked as side effectful, it should not get DCEd by + # FX's eliminate_dead_code + + with torch.library._scoped_library("mylib", "FRAGMENT") as m: + log3 = [] + + @torch.library.custom_op( + "mylib::my_logger3", + mutates_args=(), + ) + def my_logger3(s: str, t: torch.Tensor) -> torch.Tensor: + log3.append(s) + return torch.zeros(1) + + @my_logger3.register_fake + def my_logger3(s, t) -> torch.Tensor: + return torch.zeros(1) + + # Registering an op as being effectful should also prevent FX DCE + from torch._library.effects import EffectType + + torch.library._register_effectful_op( + "mylib::my_logger3", EffectType.ORDERED + ) + + def foo(x): + b = torch.scalar_tensor(x.shape[0]) + torch.ops.mylib.my_logger3("moo", b) + return x + x + + gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 3)) + gm.graph.eliminate_dead_code() + gm.recompile() + gm(torch.ones(3, 3)) + self.assertTrue(len(log3), 1) + def test_effects_and_input_mutation_return(self): def fn(a, b): torch.ops.aten._print("effect") diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 09831d4ae28b6..16d2fc706fb6e 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14737,6 +14737,61 @@ def fn(repeat, output_size, data): "Generated Triton code should use triton_helpers.minimum for clamping", ) + @config.patch(implicit_fallbacks=True) + def test_custom_op_dce(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as m: + # CASE 1: The op should get wrapped with auto_functionalized, and + # FX's DCE should not remove it because this op is registered as + # effectful + + log1 = [] + + @torch.library.custom_op( + "mylib::my_logger1", + mutates_args="unknown", + ) + def my_logger1(s: str, t: torch.Tensor) -> torch.Tensor: + log1.append(s) + return torch.zeros(1) + + @my_logger1.register_fake + def my_logger1(s, t) -> torch.Tensor: + return torch.zeros(1) + + def foo(x): + b = torch.scalar_tensor(x.shape[0]) + torch.ops.mylib.my_logger1("moo", b) + return x + x + + torch.fx.node.has_side_effect(torch.ops.mylib.my_logger1.default) + torch.compile(foo, fullgraph=True)(torch.ones(3, 3)) + self.assertTrue(len(log1), 1) + + # CASE 2: The op should not get DCEd by TorchInductor + + log2 = [] + + @torch.library.custom_op( + "mylib::my_logger2", + mutates_args=(), + ) + def my_logger2(s: str, t: torch.Tensor) -> torch.Tensor: + log2.append(s) + return torch.zeros(1) + + @my_logger2.register_fake + def my_logger2(s, t) -> torch.Tensor: + return torch.zeros(1) + + def foo(x): + b = torch.scalar_tensor(x.shape[0]) + torch.ops.mylib.my_logger2("moo", b) + return x + x + + torch.fx.node.has_side_effect(torch.ops.mylib.my_logger2.default) + torch.compile(foo, fullgraph=True)(torch.ones(3, 3)) + self.assertTrue(len(log2), 1) + @skipIfMPS # Accuracy issue on MPS def test_weight_norm_conv2d(self): """ diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8ba7ab9311b6c..eef842251e531 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -50,7 +50,6 @@ make_channels_last_strides_for, StrideType, ) -from torch._subclasses.fake_tensor import get_schema_info from torch.fx.experimental.symbolic_shapes import ( _remove_effect_token_unbacked_bindings, compute_unbacked_bindings, @@ -7882,9 +7881,11 @@ def find_device( return None def has_side_effects(self) -> bool: - if isinstance(self.op_overload, torch._ops.HigherOrderOperator): - return False - return get_schema_info(self.op_overload).is_mutable() + from torch._library.utils import is_impure + + # Note: We don't pass args/kwargs here because they're IRNodes, not actual values + # The check is done on the op_overload itself + return is_impure(self.op_overload) # pyrefly: ignore[bad-argument-type] def get_inputs_that_alias_output(self) -> Sequence[str]: assert isinstance( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 45f660f04674b..bfd8ca2c05efa 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2705,6 +2705,8 @@ def require_channels_last(_, *args, **kwargs): def constrain_to_fake_tensor(arg, fake_arg): + if fake_arg is None: + return arg if isinstance(fake_arg, FakeScriptObject): return arg if isinstance(arg, ir.IRNode): diff --git a/torch/_library/effects.py b/torch/_library/effects.py index 41fbaa4c1c7b4..e69c361789b5d 100644 --- a/torch/_library/effects.py +++ b/torch/_library/effects.py @@ -11,6 +11,19 @@ class EffectType(Enum): from torch._library.utils import RegistrationHandle +# These classes do not have side effects as they just store quantization +# params, so we dont need to mark them as ordered +skip_classes = ( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase", + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase", + "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase", + "__torch__.torch.classes.quantized.LinearPackedParamsBase", + "__torch__.torch.classes.xnnpack.Conv2dOpContext", + "__torch__.torch.classes.xnnpack.LinearOpContext", + "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext", +) + + class EffectHolder: """A holder where one can register an effect impl to.""" @@ -42,6 +55,9 @@ def _set_default_effect(self) -> None: schema = torch._C._get_schema(opname, overload) for arg in schema.arguments: if isinstance(arg.type, torch.ClassType): + type_str = arg.type.str() # pyrefly: ignore[missing-attribute] + if type_str in skip_classes: + continue self._effect = EffectType.ORDERED return diff --git a/torch/_library/utils.py b/torch/_library/utils.py index edbe86992b6ad..d5d2eee465886 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -554,3 +554,92 @@ def get_layout_constraint_tag(fn, *, with_default=True): return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) return None + + +# List of random functions that should be treated as impure +_RANDOM_FUNCTIONS = { + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.rand_like, + torch.randn_like, + torch.randint_like, + torch.normal, + torch.poisson, + torch.bernoulli, + torch.multinomial, +} + + +def is_impure( + op: Callable, + *, + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, + impure_random: bool = True, +) -> bool: + """ + An operator is impure if it: + - Mutates its inputs (has a mutable schema) + - Has nondeterministic/random behavior that mutates RNG state + - Is explicitly marked as effectful via torch.library._register_effectful_op + + Args: + op: The operator to check (function, OpOverload, HigherOrderOperator, etc.) + args: Optional arguments that would be passed to the callable + kwargs: Optional keyword arguments that would be passed to the callable + impure_random: Whether to treat random operations as impure (default: True) + + Returns: + bool: True if the callable has side effects, False otherwise + """ + # Import here to avoid circular dependencies + from torch.fx.node import _side_effectful_functions + + if isinstance(op, torch._ops.OpOverload): + schema = getattr(op, "_schema", None) + if schema is not None and schema.is_mutable: + return True + + if op in _side_effectful_functions: + return True + + from torch._higher_order_ops.effects import _get_effect + + if _get_effect(op) is not None: + return True + + if isinstance(op, torch._ops.HigherOrderOperator): + if op in ( + torch.ops.higher_order.auto_functionalized, + torch.ops.higher_order.auto_functionalized_v2, + ): + # Check if the auto-functionalized operator (the first argument) is + # side-effectful + if args and len(args) > 0: + return args[0] in _side_effectful_functions + + return False + + # Impure since it mutates RNG state + if impure_random and getattr(op, "_nondeterministic_seeded", False): + return True + + # Handle Python random functions that don't have _nondeterministic_seeded + # but still affect global RNG state (issue #151524) + # These should be impure regardless of impure_random setting to maintain + # consistency between eager and compiled execution + # All random operations are impure to ensure consistent behavior + # between eager and compiled execution, regardless of generator usage + if op in _RANDOM_FUNCTIONS: + return True + + schema = getattr(op, "_schema", None) + if schema is not None and schema.is_mutable: + return True + + if op in _side_effectful_functions: + return True + + return False diff --git a/torch/fx/graph.py b/torch/fx/graph.py index d4b0a1b1500d3..db863f1987289 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -6,6 +6,7 @@ import functools import inspect import keyword +import logging import math import os import pprint @@ -30,6 +31,8 @@ from .node import _get_qualified_name, _type_repr, Argument, Node, Target +log = logging.getLogger(__name__) + __all__ = ["PythonCode", "CodeGen", "Graph"] if TYPE_CHECKING: @@ -2085,11 +2088,15 @@ def has_side_effect(node): # Reverse iterate so that when we remove a node, any nodes used as an # input to that node have an updated user count that no longer reflects # the removed node. - changed = False + removed_nodes = set() for node in reversed(self.nodes): if not has_side_effect(node) and len(node.users) == 0: self.erase_node(node) - changed = True + removed_nodes.add(node.name) + + changed = len(removed_nodes) > 0 + if changed: + log.info("The following nodes were dead code eliminated: %s", removed_nodes) # Call DCE on the subgraphs if self.owning_module is not None: diff --git a/torch/fx/node.py b/torch/fx/node.py index 85e6f3a82e969..9e07ba824aff3 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -87,6 +87,14 @@ # TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, # or add logic to correctly mark all inplace ops as side effectful. +# +# NOTE: For new operators, please do not add to this set! +# Instead, consider using the effects system via +# torch.library._register_effectful_op() for operators. +# +# This _side_effectful_functions set is only for: +# - Legacy functions that aren't operators (e.g., profiler ops, asserts) +# - Things that cannot be marked via the normal effects system _side_effectful_functions: set[Callable[..., Any]] = { torch._assert, torch._assert_async, @@ -109,6 +117,18 @@ @compatibility(is_backward_compatible=False) def has_side_effect(fn: Callable[_P, _R]) -> Callable[_P, _R]: + """ + Registers a function to not be dead code eliminated by + fx.graph.eliminate_dead_code + + NOTE: For new operators, please do not add to this set! + Instead, consider using the effects system via + torch.library._register_effectful_op() for operators. + + This _side_effectful_functions set is only for: + - Legacy functions that aren't operators (e.g., profiler ops, asserts) + - Things that cannot be marked via the normal effects system + """ _side_effectful_functions.add(fn) return fn @@ -717,45 +737,10 @@ def is_impure(self, impure_random: bool = True) -> bool: bool: If the op is impure or not. """ + # Placeholders and outputs are always impure for DCE purposes if self.op in {"placeholder", "output"}: return True - if self.op == "call_function": - schema = getattr(self.target, "_schema", None) - if schema is not None and schema.is_mutable: - # impure since it mutates inputs - return True - - if impure_random: - if getattr(self.target, "_nondeterministic_seeded", False): - # impure since it mutates RNG state - return True - - # Handle Python random functions that don't have _nondeterministic_seeded - # but still affect global RNG state (issue #151524) - # These should be impure regardless of impure_random setting to maintain - # consistency between eager and compiled execution - _random_functions = { - torch.rand, - torch.randn, - torch.randint, - torch.randperm, - torch.rand_like, - torch.randn_like, - torch.randint_like, - torch.normal, - torch.poisson, - torch.bernoulli, - torch.multinomial, - } - - if self.target in _random_functions: - # All random operations are impure to ensure consistent behavior - # between eager and compiled execution, regardless of generator usage - return True - - return self.target in _side_effectful_functions - # Check if an impure module. if self.op == "call_module": assert self.graph.owning_module is not None, ( @@ -771,6 +756,17 @@ def is_impure(self, impure_random: bool = True) -> bool: # and some users depend on current elimination behavior. return getattr(target_mod, "_is_impure", False) + # For call_function, delegate to the unified has_side_effects function + if self.op == "call_function": + from torch._library.utils import is_impure + + return is_impure( + self.target, # pyrefly: ignore[bad-argument-type] + args=self.args, + kwargs=self.kwargs, + impure_random=impure_random, + ) + return False @compatibility(is_backward_compatible=False) From e7a86c4be00505195e5932abfe946211d081ac99 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:49:09 +0000 Subject: [PATCH 296/338] [CI] Pin CPython to 3.14.0 (#169592) As updating to 3.14.1 causes number of regressions Fixes https://github.com/pytorch/pytorch/issues/169587 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169592 Approved by: https://github.com/atalman, https://github.com/huydhn --- .ci/docker/common/install_conda.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 481de54a50f2c..bfcf19947e905 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -86,6 +86,12 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then conda_install_through_forge libstdcxx-ng=14 fi + # NS: Workaround for https://github.com/pytorch/pytorch/issues/169586 + # Downgrade cpython to 3.14.0 + if [ "$ANACONDA_PYTHON_VERSION" = "3.14" ]; then + conda_install python==3.14.0 + fi + # Install some other packages, including those needed for Python test reporting pip_install -r /opt/conda/requirements-ci.txt From 3199a3eaab5171eb70bab31510311f64707a5b6e Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:51:59 +0000 Subject: [PATCH 297/338] localtensor: test_utils.py (#169532) adds localtensor support for these tests. don't touch the ones that are localtensor-native Pull Request resolved: https://github.com/pytorch/pytorch/pull/169532 Approved by: https://github.com/dzmitry-huba --- test/distributed/tensor/test_utils.py | 236 +++++++++++++++----- torch/distributed/_local_tensor/__init__.py | 53 +++++ torch/distributed/tensor/_random.py | 8 +- torch/distributed/tensor/_utils.py | 35 ++- torch/distributed/tensor/placement_types.py | 6 +- 5 files changed, 264 insertions(+), 74 deletions(-) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 871b8e19f4c41..d76da428bc32a 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] + import itertools from contextlib import nullcontext from typing import Any @@ -10,6 +11,7 @@ local_tensor_mode, LocalTensor, LocalTensorMode, + maybe_run_for_local_tensor, ) from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor @@ -32,6 +34,7 @@ ) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, generate_shard_orders, LocalDTensorTestBase, @@ -309,11 +312,17 @@ def test_compute_global_tensor_shape_1D(self): for placements in one_d_placements: if isinstance(placements[0], Shard): uneven_dim = list(range(self.world_size)) - local_shape = ( - torch.Size([5, uneven_dim[self.rank]]) - if placements[0].dim == 1 - else torch.Size([uneven_dim[self.rank], 5]) - ) + + @maybe_run_for_local_tensor + def get_local_shape(rank): + local_shape = ( + torch.Size([5, uneven_dim[rank]]) + if placements[0].dim == 1 + else torch.Size([uneven_dim[rank], 5]) + ) + return local_shape + + local_shape = get_local_shape(self.rank) expected_global_shape = ( torch.Size([5, sum(uneven_dim)]) if placements[0].dim == 1 @@ -322,6 +331,7 @@ def test_compute_global_tensor_shape_1D(self): else: expected_global_shape = torch.Size([5, 5]) local_shape = torch.Size([5, 5]) + global_shape = compute_global_tensor_shape( local_shape, device_mesh, placements ) @@ -332,11 +342,18 @@ def test_compute_global_tensor_shape_1D_invalid_shape(self): one_d_placement = [Shard(1)] device_mesh = init_device_mesh(self.device_type, (self.world_size,)) uneven_dim = list(range(self.world_size)) - local_shape = ( - torch.Size([5, uneven_dim[self.rank]]) - if self.rank % 2 == 0 - else torch.Size([6, uneven_dim[self.rank]]) - ) + + @maybe_run_for_local_tensor + def get_local_shape(rank): + local_shape = ( + torch.Size([5, uneven_dim[rank]]) + if rank % 2 == 0 + else torch.Size([6, uneven_dim[rank]]) + ) + return local_shape + + local_shape = get_local_shape(self.rank) + with self.assertRaisesRegex( RuntimeError, "Non-sharded dimensions should have identical size across ranks.", @@ -424,11 +441,29 @@ def test_compute_local_shape_and_global_offset_2D(self): dim0_start, dim0_end = dim[0][0], dim[0][1] dim1_start, dim1_end = dim[1][0], dim[1][1] - # Check the local tensor of dtensor is exactly the same - # if we slice the global_tensor with local_size and global_offset - self.assertEqual( + @maybe_run_for_local_tensor + def maybe_compute_rankwise( + dim0_start, + dim0_end, + dim1_start, + dim1_end, + local_tensor, + global_tensor, + ): + # Check the local tensor of dtensor is exactly the same + # if we slice the global_tensor with local_size and global_offset + self.assertEqual( + local_tensor, + global_tensor[dim0_start:dim0_end, dim1_start:dim1_end], + ) + + maybe_compute_rankwise( + dim0_start, + dim0_end, + dim1_start, + dim1_end, dtensor.to_local(), - global_tensor[dim0_start:dim0_end, dim1_start:dim1_end], + global_tensor, ) @with_comms @@ -543,8 +578,13 @@ def test_uneven_fsdp_tp_meta_compute(self): rank = global_mesh.get_rank() expected_shapes = [2, 2, 2, 2, 2, 2, 2, 1] expected_offsets = [0, 8, 2, 10, 4, 12, 6, 14] - self.assertEqual(local_shape[0], expected_shapes[rank]) - self.assertEqual(global_offset[0], expected_offsets[rank]) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise(rank, local_shape, global_offset): + self.assertEqual(local_shape[0], expected_shapes[rank]) + self.assertEqual(global_offset[0], expected_offsets[rank]) + + maybe_compute_rankwise(rank, local_shape, global_offset) @with_comms def test_hsdp_tp_meta_compute(self): @@ -688,8 +728,15 @@ def test_1d_mesh_strided_sharding(self): """ shard_placement = _StridedShard(0, split_factor=1) # same as Shard(0) tensor_list, _ = shard_placement._split_tensor(x, self.world_size) - shard_x = tensor_list[self.rank] - self.assertEqual(shard_x, x.view(self.world_size, -1)[self.rank]) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise(rank, tensor_list, x): + shard_x = tensor_list[rank] + self.assertEqual(shard_x, x.view(self.world_size, -1)[rank]) + return shard_x + + shard_x = maybe_compute_rankwise(self.rank, tensor_list, x) + # shard_to_replicate full_tensor = shard_placement._to_replicate_tensor( shard_x, @@ -704,10 +751,15 @@ def test_1d_mesh_strided_sharding(self): """ shard_placement = _StridedShard(0, split_factor=2) tensor_list, _ = shard_placement._split_tensor(x, self.world_size) - shard_x = tensor_list[self.rank] - self.assertEqual( - shard_x, x.view(-1, self.world_size).swapdims(-1, 0)[self.rank] - ) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise(rank, tensor_list, x): + shard_x = tensor_list[rank] + self.assertEqual(shard_x, x.view(-1, self.world_size).swapdims(-1, 0)[rank]) + return shard_x + + shard_x = maybe_compute_rankwise(self.rank, tensor_list, x) + # shard_to_replicate full_tensor = shard_placement._to_replicate_tensor( shard_x, @@ -737,16 +789,31 @@ def test_2d_mesh_strided_sharding(self): # shard on mesh dim-0 shard_placement_dim0 = _StridedShard(0, split_factor=1) # same as Shard(0) tensor_list, _ = shard_placement_dim0._split_tensor(x, mesh_dim0_size) - expected_shard_dim0 = x.view(mesh_dim0_size, -1)[mesh_dim0_local_rank] - shard_x = tensor_list[mesh_dim0_local_rank] - self.assertEqual(shard_x, expected_shard_dim0) - - # shard on mesh dim-1 shard_placement_dim1 = _StridedShard(0, split_factor=1) # same as Shard(0) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise_strided(mesh_dim0_local_rank): + expected_shard_dim0 = x.view(mesh_dim0_size, -1)[mesh_dim0_local_rank] + shard_x = tensor_list[mesh_dim0_local_rank] + self.assertEqual(shard_x, expected_shard_dim0) + return shard_x, expected_shard_dim0 + + shard_x, expected_shard_dim0 = maybe_compute_rankwise_strided( + mesh_dim0_local_rank + ) tensor_list, _ = shard_placement_dim1._split_tensor(shard_x, mesh_dim1_size) - expected_shard_dim1 = shard_x.view(mesh_dim1_size, -1)[mesh_dim1_local_rank] - shard_x = tensor_list[mesh_dim1_local_rank] - self.assertEqual(shard_x, expected_shard_dim1) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise_strided(mesh_dim1_local_rank): + expected_shard_dim1 = shard_x.view(mesh_dim1_size, -1)[mesh_dim1_local_rank] + shard_x2 = tensor_list[mesh_dim1_local_rank] + self.assertEqual(shard_x2, expected_shard_dim1) + + return shard_x2, expected_shard_dim0 + + shard_x, expected_shard_dim0 = maybe_compute_rankwise_strided( + mesh_dim1_local_rank + ) # shard_to_replicate on mesh dim-1 full_tensor = shard_placement_dim1._to_replicate_tensor( @@ -759,11 +826,12 @@ def test_2d_mesh_strided_sharding(self): # shard_to_replicate on mesh dim-0 full_tensor = shard_placement_dim0._to_replicate_tensor( - full_tensor, + full_tensor.reconcile() if self.is_local_tensor_enabled else full_tensor, mesh_2d, mesh_dim=0, current_logical_shape=list(x.shape), ) + self.assertEqual(full_tensor, x) """ @@ -776,22 +844,36 @@ def test_2d_mesh_strided_sharding(self): # shard on mesh dim-0 shard_placement_dim0 = _StridedShard(0, split_factor=split_factor) tensor_list, _ = shard_placement_dim0._split_tensor(x, mesh_dim0_size) - shard_x = tensor_list[mesh_dim0_local_rank] - expected_shard_dim0 = ( - torch.tensor([0, 1, 4, 5], device=self.device_type) - if mesh_dim0_local_rank == 0 - else torch.tensor([2, 3, 6, 7], device=self.device_type) - ) - self.assertEqual(shard_x, expected_shard_dim0) - - # shard on mesh dim-1 shard_placement_dim1 = _StridedShard(0, split_factor=1) # same as Shard(0) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise_strided(mesh_dim0_local_rank): + shard_x = tensor_list[mesh_dim0_local_rank] + expected_shard_dim0 = ( + torch.tensor([0, 1, 4, 5], device=self.device_type) + if mesh_dim0_local_rank == 0 + else torch.tensor([2, 3, 6, 7], device=self.device_type) + ) + self.assertEqual(shard_x, expected_shard_dim0) + return shard_x, expected_shard_dim0 + + shard_x, expected_shard_dim0 = maybe_compute_rankwise_strided( + mesh_dim0_local_rank + ) tensor_list, _ = shard_placement_dim1._split_tensor(shard_x, mesh_dim1_size) - shard_x = tensor_list[mesh_dim1_local_rank] - expected_shard_dim1 = expected_shard_dim0.view(mesh_dim1_size, -1)[ + + @maybe_run_for_local_tensor + def maybe_compute_rankwise_strided(mesh_dim1_local_rank): + shard_x2 = tensor_list[mesh_dim1_local_rank] + expected_shard_dim1 = expected_shard_dim0.view(mesh_dim1_size, -1)[ + mesh_dim1_local_rank + ] + self.assertEqual(shard_x2, expected_shard_dim1) + return shard_x2, expected_shard_dim0 + + shard_x, expected_shard_dim0 = maybe_compute_rankwise_strided( mesh_dim1_local_rank - ] - self.assertEqual(shard_x, expected_shard_dim1) + ) # shard_to_replicate on mesh dim-1 full_tensor = shard_placement_dim1._to_replicate_tensor( @@ -804,7 +886,7 @@ def test_2d_mesh_strided_sharding(self): # shard_to_replicate on mesh dim-0 full_tensor = shard_placement_dim0._to_replicate_tensor( - full_tensor, + full_tensor.reconcile() if self.is_local_tensor_enabled else full_tensor, mesh_2d, mesh_dim=0, current_logical_shape=list(x.shape), @@ -833,23 +915,40 @@ def test_2d_mesh_2d_tensor_strided_sharding(self): # shard on mesh dim-0 shard_placement_dim0 = _StridedShard(1, split_factor=split_factor) tensor_list, _ = shard_placement_dim0._split_tensor(x, mesh_dim0_size) - shard_x = tensor_list[mesh_dim0_local_rank] - expected_shard_dim0 = ( - torch.tensor([[0, 2], [4, 6]], device=self.device_type) - if mesh_dim0_local_rank == 0 - else torch.tensor([[1, 3], [5, 7]], device=self.device_type) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise_strided(mesh_dim0_local_rank, tensor_list): + shard_x2 = tensor_list[mesh_dim0_local_rank] + expected_shard_dim0 = ( + torch.tensor([[0, 2], [4, 6]], device=self.device_type) + if mesh_dim0_local_rank == 0 + else torch.tensor([[1, 3], [5, 7]], device=self.device_type) + ) + self.assertEqual(shard_x2, expected_shard_dim0) + return shard_x2, expected_shard_dim0 + + shard_x, expected_shard_dim0 = maybe_compute_rankwise_strided( + mesh_dim0_local_rank, tensor_list ) - self.assertEqual(shard_x, expected_shard_dim0) # shard on mesh dim-1 shard_placement_dim1 = _StridedShard(1, split_factor=1) # same as Shard(1) tensor_list, _ = shard_placement_dim1._split_tensor(shard_x, mesh_dim1_size) - shard_x = tensor_list[mesh_dim1_local_rank] - expected_shard_dim1 = [ - torch.tensor(value, device=self.device_type) - for value in [[[0], [4]], [[2], [6]], [[1], [5]], [[3], [7]]] - ][self.rank] - self.assertEqual(shard_x, expected_shard_dim1) + + @maybe_run_for_local_tensor + def maybe_compute_rankwise_strided(mesh_dim1_local_rank, rank, tensor_list): + shard_x = tensor_list[mesh_dim1_local_rank] + expected_shard_dim1 = [ + torch.tensor(value, device=self.device_type) + for value in [[[0], [4]], [[2], [6]], [[1], [5]], [[3], [7]]] + ][rank] + self.assertEqual(shard_x, expected_shard_dim1) + + return shard_x, expected_shard_dim0 + + shard_x, expected_shard_dim0 = maybe_compute_rankwise_strided( + mesh_dim1_local_rank, self.rank, tensor_list + ) # shard_to_replicate on mesh dim-1 full_tensor = shard_placement_dim1._to_replicate_tensor( @@ -858,7 +957,13 @@ def test_2d_mesh_2d_tensor_strided_sharding(self): mesh_dim=1, current_logical_shape=list(expected_shard_dim0.shape), ) - self.assertEqual(full_tensor, expected_shard_dim0) + + self.assertEqual( + full_tensor, + expected_shard_dim0.reconcile() + if self.is_local_tensor_enabled + else expected_shard_dim0, + ) # shard_to_replicate on mesh dim-0 full_tensor = shard_placement_dim0._to_replicate_tensor( @@ -1006,8 +1111,13 @@ def test_fsdp2_tp_2d_dtensor_local_shards_and_offsets(self): global_tensor, tp_mesh, placements=[Shard(0)] ) chunks = list(torch.chunk(dtensor_tp.to_local(), 2, dim=0)) - shard_rank = 0 if self.rank // 2 == 0 else 1 - sharded_param = chunks[shard_rank] + + @maybe_run_for_local_tensor + def get_sharded_param(rank, chunks): + shard_rank = 0 if rank // 2 == 0 else 1 + return chunks[shard_rank] + + sharded_param = get_sharded_param(self.rank, chunks) spec_2d = DTensorSpec( mesh=mesh_2d, placements=(_StridedShard(0, split_factor=2), Shard(0)), @@ -1147,5 +1257,11 @@ def test_explicit_matmul(self): loss.backward(retain_graph=True) +UtilTestWithLocalTensor = create_local_tensor_test_class(UtilTest) +TestStridedShardingWithLocalTensor = create_local_tensor_test_class(TestStridedSharding) +Test2DStridedLocalShardWithLocalTensor = create_local_tensor_test_class( + Test2DStridedLocalShard +) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index c780e1ef7cb8a..7382fa4f934af 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -300,6 +300,14 @@ def _combine_any_rank_results(rank_results: dict[int, Any]) -> Any: if isinstance(any_v, int): return _combine_int_rank_results(rank_results) + if isinstance(any_v, torch.device): + assert all(v.type == any_v.type for v in rank_results.values()), ( + "device type should be the same" + ) + # Just use the first device - the device type is what matters, + # and LocalTensorMode runs on a single physical device anyway + return any_v + assert all(v == any_v for v in rank_results.values()), ( "Non Tensor or int rank results must be equal for all ranks" ) @@ -1167,6 +1175,8 @@ def __init__(self, ranks: Union[int, frozenset[int]]): self.ranks = ranks self._disable = True self._old_get_coordinate = None + self._old_get_rank = None + self._old_get_local_rank = None self._old_torch_manual_seed: Any = None self._old_torch_initial_seed: Any = None self._per_rank_rng_states: dict[ @@ -1383,14 +1393,28 @@ def _any_local_rng_state(self) -> tuple[torch.Tensor, dict[int, torch.Tensor]]: def _patch_device_mesh(self) -> None: assert self._old_get_coordinate is None + assert self._old_get_rank is None + assert self._old_get_local_rank is None self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment] + self._old_get_rank = DeviceMesh.get_rank # type: ignore[assignment] + self._old_get_local_rank = DeviceMesh.get_local_rank # type: ignore[assignment] DeviceMesh.get_coordinate = _LocalDeviceMesh.get_coordinate # type: ignore[method-assign] + DeviceMesh.get_rank = _LocalDeviceMesh.get_rank # type: ignore[method-assign] + DeviceMesh.get_local_rank = _LocalDeviceMesh.get_local_rank # type: ignore[method-assign] def _unpatch_device_mesh(self) -> None: assert self._old_get_coordinate is not None + assert self._old_get_rank is not None + assert self._old_get_local_rank is not None DeviceMesh.get_coordinate = self._old_get_coordinate + DeviceMesh.get_rank = self._old_get_rank + DeviceMesh.get_local_rank = self._old_get_local_rank # pyrefly: ignore [bad-assignment] self._old_get_coordinate = None + # pyrefly: ignore [bad-assignment] + self._old_get_rank = None + # pyrefly: ignore [bad-assignment] + self._old_get_local_rank = None def _patch_random_functions(self) -> None: import torch.random @@ -1507,6 +1531,35 @@ def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]: # as the current mesh. return out # type: ignore[return-value] + @staticmethod + def get_rank(self) -> int | SymInt: + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + return torch.SymInt(LocalIntNode(local_ints={r: r for r in lm.ranks})) + + @staticmethod + def get_local_rank(self, mesh_dim: int | str | None = None) -> int | SymInt: + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + if isinstance(mesh_dim, str): + mesh_dim = self._mesh_dim_names.index(mesh_dim) + + # Compute local rank for each global rank + # get_coordinate returns a list of SymInt, one per mesh dimension + # We need to extract the coordinate for the specified mesh_dim + coords = _LocalDeviceMesh.get_coordinate(self) + assert coords is not None + return coords[mesh_dim] + def reconcile_args(args: Any, kwargs: dict[str, Any] | None = None) -> Any: """ diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 4c3d51381f541..c07e0f6522189 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -6,6 +6,7 @@ from typing import Optional import torch +from torch.distributed._local_tensor import maybe_run_for_local_tensor from torch.distributed.device_mesh import _get_device_handle, DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor.placement_types import Shard @@ -469,4 +470,9 @@ def _resolve_device(device_mesh: DeviceMesh) -> torch.device: device_handle = _get_device_handle(device_type) assert device_handle is not None device_idx = device_mesh.get_rank() % device_handle.device_count() - return torch.device(f"{device_type}:{device_idx:d}") + + @maybe_run_for_local_tensor + def get_device(device_idx): + return torch.device(f"{device_type}:{device_idx:d}") + + return get_device(device_idx) diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 9dc9d188faf61..f085b681f9491 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -361,22 +361,35 @@ def compute_global_tensor_shape( if isinstance(placements[0], Replicate): return shape elif isinstance(placements[0], Shard): - local_shape = torch.tensor(list(shape), device=mesh.device_type) + + @maybe_run_for_local_tensor + def _create_local_shape_tensor(shape): + return torch.tensor(list(shape), device=mesh.device_type) + + local_shape = _create_local_shape_tensor(shape) gathered_shaped_tensors = [ torch.empty_like(local_shape, device=local_shape.device) for _ in range(mesh.size()) ] funcol.all_gather_inplace(gathered_shaped_tensors, local_shape, mesh) - sharded_dim_sum = 0 - shard_dim = placements[0].dim - other_dims = [d for d in range(mesh.ndim) if d != shard_dim] - for shape_tensor in gathered_shaped_tensors: - if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): - raise RuntimeError( - "Non-sharded dimensions should have identical size across ranks." - ) - shape_tensor_list = shape_tensor.tolist() - sharded_dim_sum += shape_tensor_list[shard_dim] + + @maybe_run_for_local_tensor + def _validate_and_compute_global_shape(local_shape, gathered_shaped_tensors): + sharded_dim_sum = 0 + shard_dim = placements[0].dim # type: ignore[union-attr] + other_dims = [d for d in range(len(shape)) if d != shard_dim] + for shape_tensor in gathered_shaped_tensors: + if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): + raise RuntimeError( + "Non-sharded dimensions should have identical size across ranks." + ) + shape_tensor_list = shape_tensor.tolist() + sharded_dim_sum += shape_tensor_list[shard_dim] + return sharded_dim_sum + + sharded_dim_sum = _validate_and_compute_global_shape( + local_shape, gathered_shaped_tensors + ) global_shape = list(shape) global_shape[placements[0].dim] = sharded_dim_sum return torch.Size(global_shape) diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 590ec80b8f009..1f6910ddfe632 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -749,8 +749,10 @@ def local_shard_size_and_offset( # pyre-ignore[bad-override] else: offsets = [] - if return_first_offset and len(offsets) > 0: - offsets = offsets[0] + if return_first_offset: + # Always return an int for consistency across ranks. + # For empty shards, return -1 as an invalid offset indicator. + offsets = offsets[0] if len(offsets) > 0 else -1 return local_shard_size, offsets From d2b0d8148c255e2147c1d16480d189fea7f27f2c Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Fri, 5 Dec 2025 00:02:39 +0000 Subject: [PATCH 298/338] Add check_lowerbound config for AOTI lowering (#169430) Summary: ## Why: PT has a 0/1 specialization assumption. Thus, unless backed_size_oblivious is on, the greatest lowerbound is [2+, ...]. So, this trips up a lot of models who can have a dynamic size with 0 or 1. I want a way for AOTI lowering users to easily check their dynamic shape spec is correct via `AOTI_RUNTIME_CHECK_INPUTS=1` which is a runtime env var. In-order to do that, I need to avoid any errors triggered by a wrong lowerbound. ## What does it look like? ``` // this is a lowerbound check if (arg0_1_size[0] < 2) { std::stringstream ss; ss << "input_handles[0]: dim value is too small at 0, " << "expected it to be >= 2, " << "but got: " << arg0_1_size[0] << "\n"; throw std::runtime_error(ss.str()); } ``` Differential Revision: D88203028 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169430 Approved by: https://github.com/yushangdi, https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 46 ++++++++++++++++++++++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 4 +- torch/_inductor/config.py | 6 +++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 8e4102a57d682..2f1feedf6dd47 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -7999,6 +7999,52 @@ class AOTInductorTestABICompatibleMps(TestCase): ) +class TestCheckLowerboundConfig(TestCase): + def test_aoti_check_lowerbound_codegen(self): + """ + Test that check_lowerbound config controls lowerbound check codegen. + When check_lowerbound=False, no lowerbound checks should be generated. + """ + + class Model(torch.nn.Module): + def forward(self, x): + return x + 1 + + model = Model() + batch = Dim("batch", min=2, max=10) + example_inputs = (torch.randn(4, 3),) + + # Test with check_lowerbound=True (default) + with config.patch({"aot_inductor.check_lowerbound": True}): + result, code = run_and_get_cpp_code( + AOTIRunnerUtil.legacy_compile, + model, + example_inputs, + dynamic_shapes={"x": {0: batch}}, + ) + # Should have lowerbound checks + FileCheck().check_count( + "dim value is too small", + 1, + exactly=True, + ).run(code) + + # Test with check_lowerbound=False + with config.patch({"aot_inductor.check_lowerbound": False}): + result, code = run_and_get_cpp_code( + AOTIRunnerUtil.legacy_compile, + model, + example_inputs, + dynamic_shapes={"x": {0: batch}}, + ) + # Should NOT have lowerbound checks + FileCheck().check_count( + "dim value is too small", + 0, + exactly=True, + ).run(code) + + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 9ec44c6c2790f..16522d9832ec0 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -423,7 +423,9 @@ def gen_check(handle_kind, idx, name, tensor): from torch.utils._sympy.value_ranges import bound_sympy sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range) - if not math.isinf(sym_range.lower): + if config.aot_inductor.check_lowerbound and not math.isinf( + sym_range.lower + ): self.prefix.splice( f""" if ({name}_size[{dim_idx}] < {sym_range.lower}) {{ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 47f3fd77908c0..4297880dbdbcf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1739,6 +1739,12 @@ class aot_inductor: os.environ.get("AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION", "1") == "1" ) + # Whether to check lowerbound constraints on dynamic shapes during runtime. + # When disabled, allows models with dynamic sizes of 0 or 1 to work with + # AOTI_RUNTIME_CHECK_INPUTS=1, avoiding errors from the [2+, ...] lowerbound + # restriction when backed_size_oblivious is off. + check_lowerbound: bool = True + # dump an aoti minifier if program errors dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1" From 7143efbb74130e189476a76ad12089f0799f815d Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 5 Dec 2025 00:16:03 +0000 Subject: [PATCH 299/338] Remove outdated skip conditons of CUDA and ROCm (#166391) This PR removes outdated skip conditions of CUDA 11 and ROCm 5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166391 Approved by: https://github.com/albanD --- test/test_jiterator.py | 6 +----- torch/testing/_internal/common_methods_invocations.py | 8 ++------ torch/testing/_internal/opinfo/definitions/linalg.py | 10 +--------- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/test/test_jiterator.py b/test/test_jiterator.py index 55ad64adb6b34..7adc8a1df0c87 100644 --- a/test/test_jiterator.py +++ b/test/test_jiterator.py @@ -8,7 +8,7 @@ from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest from torch.testing._internal.common_dtype import all_types_and_complex_and from torch.testing._internal.common_device_type import ( - skipCUDAIfVersionLessThan, instantiate_device_type_tests, dtypes, toleranceOverride, tol) + instantiate_device_type_tests, dtypes, toleranceOverride, tol) if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) @@ -39,10 +39,6 @@ def test_all_dtype_contiguous(self, device, dtypes, shape_strides): self.assertEqual(expected, result) - # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details - # On cuda 11.3, nvrtcCompileProgram is taking too long to - # compile jiterator generated kernels for non-contiguous input that requires dynamic-casting. - @skipCUDAIfVersionLessThan((11, 6)) @parametrize("shape_strides", [ (([3, 3], [1, 3]), ([3, 1], [1, 3])), # non-contiguous ]) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4578789eddf22..86320ed763204 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -31,8 +31,7 @@ toleranceOverride, tol, skipXPU) from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, - SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version, - _get_torch_rocm_version, + SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, ) from torch.testing._internal.common_quantized import ( _bfloat16_to_float4_e2m1fn_x2, @@ -13822,9 +13821,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_autograd=True, sample_inputs_func=sample_inputs_sparse_sampled_addmm, decorators=[ - skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3)) - or (_get_torch_rocm_version() >= (5, 2))), - "cusparseSDDMM was added in 11.2.1"), skipCPUIfNoMklSparse, skipXPU], skips=( @@ -16628,7 +16624,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cuda'), # This is only failing on Linux Bionic 3.10 Cuda 11.6 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', - device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)), + device_type='cuda'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', dtypes=(torch.float32,)), # AssertionError: JIT Test does not execute any logic diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index 87071c439f8e0..f41cadad67eb7 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -12,15 +12,10 @@ import torch from torch.testing import make_tensor -from torch.testing._internal.common_cuda import ( - _get_magma_version, - _get_torch_cuda_version, - with_tf32_off, -) +from torch.testing._internal.common_cuda import _get_magma_version, with_tf32_off from torch.testing._internal.common_device_type import ( has_cusolver, skipCPUIfNoLapack, - skipCUDAIf, skipCUDAIfNoCusolver, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -1484,9 +1479,6 @@ def make_input(): supports_autograd=False, sample_inputs_func=sample_inputs_linalg_ldl_solve, decorators=[ - skipCUDAIf( - _get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1" - ), skipCUDAIfNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack, From 4380129736e058d7baf8d3a97f54232b3f35c175 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 3 Dec 2025 15:39:07 -0800 Subject: [PATCH 300/338] [DTensor] Fix slow sharding prop for stack (#169519) As identified in the original issue, there is quadratic complexity in the number of input tensors, due to an improperly written sharding prop rule. The previous code generated N output strategies for the stack op, one based on each of the original N input strategies. However, Each of the N output strategies was the same. The heuristic in the stack rule is to find one of the N inputs and follow that one. We now just generate one output strategy. Fixes #169445 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169519 Approved by: https://github.com/zpcore, https://github.com/malfet, https://github.com/albanD --- torch/distributed/tensor/_ops/_tensor_ops.py | 34 +++++++++----------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index a6ff33a12a189..5253a37952ea4 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -759,9 +759,11 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType: input_tuple_strategy = args_schema[0] if not isinstance(input_tuple_strategy, TupleStrategy): raise AssertionError(f"Expected TupleStrategy, got {input_tuple_strategy}") - first_input_strategy = input_tuple_strategy.children[0] - if not isinstance(first_input_strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {first_input_strategy}") + input_strategies: list[OpStrategy] = [] + for child in input_tuple_strategy.children: + assert isinstance(child, OpStrategy), f"Expected OpStrategy, got {child}" + input_strategies.append(child) + first_input_strategy = input_strategies[0] common_input_ndim = first_input_strategy.ndim dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 # normalize the dim to be within the common input ndim @@ -784,22 +786,18 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType: # stack op would "insert" new dim, so all sharded dim >= the inserted dim need to # be normalized with the new Shard placement follow_placements = shift_shard_dims_after_insert(follow_placements, dim) - - for strategy in input_tuple_strategy.children: - if not isinstance(strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(strategy)}") - output_spec = DTensorSpec(mesh, tuple(follow_placements)) - redistribute_cost = [] - for input_spec in input_specs: - cost = generate_redistribute_costs(strategy, input_spec) - redistribute_cost.append(cost) - op_strategy.strategies.append( - OpSpec( - output_specs=output_spec, - input_specs=input_specs, - redistribute_cost=redistribute_cost, - ) + output_spec = DTensorSpec(mesh, tuple(follow_placements)) + redistribute_cost = [ + generate_redistribute_costs(input_strategies[i], input_specs[i]) + for i in range(len(input_specs)) + ] + op_strategy.strategies.append( + OpSpec( + output_specs=output_spec, + input_specs=input_specs, + redistribute_cost=redistribute_cost, ) + ) return op_strategy From 6be8f42812e0c472ba5070da925c38764be1577d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 5 Dec 2025 00:53:57 +0000 Subject: [PATCH 301/338] Revert "[dynamo] Refactor isinstance(x, ConstantVariable) to x.is_python_constant() (#169006)" This reverts commit 31bb133faf6a88c3adab11fa06d9a5fe4f6c9c85. Reverted https://github.com/pytorch/pytorch/pull/169006 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break 3.10 dynamo tests ([comment](https://github.com/pytorch/pytorch/pull/169006#issuecomment-3614888197)) --- test/dynamo/test_list.py | 6 ---- test/dynamo/test_modules.py | 3 +- torch/_dynamo/comptime.py | 3 +- torch/_dynamo/output_graph.py | 2 +- torch/_dynamo/side_effects.py | 5 ++- torch/_dynamo/symbolic_convert.py | 28 +++++++-------- torch/_dynamo/utils.py | 8 ++--- torch/_dynamo/variables/base.py | 26 +++----------- torch/_dynamo/variables/builder.py | 4 +-- torch/_dynamo/variables/builtin.py | 33 ++++++++++------- torch/_dynamo/variables/constant.py | 16 +++------ torch/_dynamo/variables/ctx_manager.py | 4 ++- torch/_dynamo/variables/dicts.py | 25 +++++++------ torch/_dynamo/variables/functions.py | 20 +++++++---- torch/_dynamo/variables/higher_order_ops.py | 39 ++++++++++----------- torch/_dynamo/variables/iter.py | 4 +-- torch/_dynamo/variables/lists.py | 13 +++---- torch/_dynamo/variables/misc.py | 6 ++-- torch/_dynamo/variables/nn_module.py | 2 +- torch/_dynamo/variables/optimizer.py | 2 +- torch/_dynamo/variables/tensor.py | 17 ++++----- torch/_dynamo/variables/torch.py | 39 ++++++++++----------- torch/_dynamo/variables/torch_function.py | 4 +-- torch/_dynamo/variables/user_defined.py | 10 +++--- 24 files changed, 153 insertions(+), 166 deletions(-) diff --git a/test/dynamo/test_list.py b/test/dynamo/test_list.py index 85415244db69c..41e5da15b5378 100644 --- a/test/dynamo/test_list.py +++ b/test/dynamo/test_list.py @@ -176,12 +176,6 @@ def test___iter__(self): it = p.__iter__().__iter__() self.assertEqual(next(it), 1) - @make_dynamo_test - def test_list_mul_constant_tuple(self): - tree = (1, 2) - result = [tree] * 2 - self.assertEqual(result, [tree, tree]) - class ListTests(TupleTests): # List methods diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 959a32ff17a10..6fd1e6b477f36 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3383,8 +3383,7 @@ def __init__(self): def __bool__(self): self.bool_invoked += 1 - # __bool__ must return a real bool; use truthiness of cache size - return len(self.key_cache) > 0 + return len(self.key_cache) @torch.compile(fullgraph=True, backend="eager") def f(x): diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index f53c753365b63..34eec572ce550 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -49,6 +49,7 @@ def my_model(x): from .exc import unimplemented from .variables import CellVariable +from .variables.constant import ConstantVariable from .variables.tensor import SymNodeVariable @@ -142,7 +143,7 @@ def force_static(self) -> None: """ if isinstance(self.__variable, SymNodeVariable): self.__variable.evaluate_expr() - elif self.__variable.is_python_constant(): + elif isinstance(self.__variable, ConstantVariable): # TODO: Maybe complain if this isn't a int/bool/float variable pass else: diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 4fc288c9bf546..0d409869ccec5 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1685,7 +1685,7 @@ def compile_subgraph( "input", vt.source, ) - elif vt.is_python_constant(): + elif isinstance(vt, torch._dynamo.variables.ConstantVariable): self.export_metadata.output_return_type[idx] = ( "constant", vt.as_python_constant(), diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index df9716339b661..999bd145c3e57 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -881,7 +881,10 @@ def codegen_update_mutated(self, cg: PyCodegen) -> None: elif isinstance(var, variables.lists.DequeVariable): # For limited maxlen, the order of operations matter for side # effect, but we currently don't track the order, so no support. - if not var.maxlen.is_constant_none(): + if not ( + isinstance(var.maxlen, variables.ConstantVariable) + and var.maxlen.value is None + ): unimplemented( gb_type="Side effect on existing deque with limited maxlen", context="", diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 487346940dfdf..f401b9d6178b9 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -761,15 +761,10 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment] - method_name = getattr(getattr(x, "fn", None), "__name__", None) - if result.is_python_constant(): - result_value = result.as_python_constant() - if method_name == "__bool__" and not isinstance(result_value, bool): - msg = variables.ConstantVariable.create( - f"__bool__ should return bool, returned {type(result_value).__name__}" - ) - exc.raise_observed_exception(TypeError, self, args=[msg]) - if isinstance(result_value, (bool, int)) and truth_fn(result_value): + if isinstance(result, ConstantVariable) and isinstance( + result.value, (bool, int) + ): + if truth_fn(result.value): if push: self.push(value) self.jump(inst) @@ -2638,7 +2633,7 @@ def STORE_ATTR(self, inst: Instruction) -> None: return self.store_attr_graph_break(inst) val, obj = self.popn(2) - if isinstance(obj, NNModuleVariable) and not val.is_python_constant(): + if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable): # We don't allow side effects during export on non-constant values # https://github.com/pytorch/torchdynamo/issues/1475 assert not self.export, ( @@ -3553,7 +3548,7 @@ def BUILD_STRING(self, inst: Instruction) -> None: kwargs: dict[str, VariableTracker] = {} assert inst.arg is not None for part in self.popn(inst.arg): - if part.is_python_constant(): + if isinstance(part, ConstantVariable): format_string_parts.append("{}") args.append(part) elif isinstance(part, variables.StringFormatVariable): @@ -4985,7 +4980,10 @@ def inline_call_(self) -> VariableTracker: assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration args = [] - if not self.symbolic_result.is_constant_none(): + if not ( + isinstance(self.symbolic_result, ConstantVariable) + and self.symbolic_result.value is None + ): args = [self.symbolic_result] exc.raise_observed_exception(StopIteration, self, args=args) else: @@ -4993,7 +4991,7 @@ def inline_call_(self) -> VariableTracker: else: if is_generator(code): assert isinstance(self, InliningGeneratorInstructionTranslator) - assert self.symbolic_result.is_constant_none() + assert self.symbolic_result.as_python_constant() is None return ListIteratorVariable( self.generated_items, mutation_type=ValueMutationNew(), @@ -5225,7 +5223,7 @@ def YIELD_FROM(self, inst: Instruction) -> None: assert len(self.stack) >= 2 val = self.pop() tos = self.stack[-1] - if not val.is_constant_none(): + if not (isinstance(val, ConstantVariable) and val.value is None): # invoke send # Unreachable code - if you hit this, you are implementing generator support and have # lifted the `unimplemented("generator")` in frame conversion. This codepath handles @@ -5267,7 +5265,7 @@ def SEND(self, inst: Instruction) -> None: isinstance(tos, UserDefinedObjectVariable) and isinstance(tos.value, collections.abc.Iterator) ): - if val.is_constant_none(): + if isinstance(val, ConstantVariable) and val.value is None: try: val = tos.next_variable(self) # type: ignore[arg-type] except (StopIteration, exc.ObservedUserStopIteration) as ex: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index b0ad5d2bf5118..d08b92de3441e 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2600,11 +2600,11 @@ def specialize_symnode(arg: Any) -> Any: def guard_if_dyn(arg: Any) -> Any: - from .variables import VariableTracker + from .variables import ConstantVariable arg = specialize_symnode(arg) - if isinstance(arg, VariableTracker) and arg.is_python_constant(): + if isinstance(arg, ConstantVariable): return arg.as_python_constant() return arg @@ -2615,14 +2615,14 @@ def check_constant_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool: def check_unspec_python_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool: - from .variables import VariableTracker + from .variables.constant import ConstantVariable from .variables.tensor import UnspecializedPythonVariable unspec_count = 0 for x in itertools.chain(args, kwargs.values()): if isinstance(x, UnspecializedPythonVariable): unspec_count += 1 - elif not (isinstance(x, VariableTracker) and x.is_python_constant()): + elif not isinstance(x, ConstantVariable): return False return unspec_count > 0 diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 982e0fccc5ca6..617f787e43d8a 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -366,21 +366,6 @@ def is_python_constant(self) -> bool: except NotImplementedError: return False - def is_constant_match(self, *values: Any) -> bool: - """ - Check if this variable is a python constant matching one of the given values. - - Examples: - var.is_constant_match(None) # True if var is constant None - var.is_constant_match(True, False) # True if var is constant True or False - var.is_constant_match(NotImplemented) # True if var is constant NotImplemented - """ - return False - - def is_constant_none(self) -> bool: - """Check if this variable is a constant None value.""" - return False - def make_guard(self, fn: Callable[..., Any]) -> Guard: if self.source: return self.source.make_guard(fn) @@ -392,17 +377,13 @@ def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: """getattr(self, name) returning a python constant""" raise NotImplementedError - def is_symnode_like(self) -> bool: - """Return True for values that can participate in SymNode operations""" - return False - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": """getattr(self, name) returning a new variable""" value = self.const_getattr(tx, name) if not variables.ConstantVariable.is_literal(value): raise NotImplementedError source = self.source and AttrSource(self.source, name) - if source and not self.is_python_constant(): + if source and not isinstance(self, variables.ConstantVariable): # The second condition is to avoid guards on const getattr objects # like __code__.co_argcount install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) @@ -591,7 +572,10 @@ def call_tree_map( ) -> "VariableTracker": """Performance optimization to implement optree.tree_map faster than tracing it""" is_leaf_var = tree_map_kwargs.get("is_leaf") - if is_leaf_var is not None and not is_leaf_var.is_constant_none(): + if is_leaf_var is not None and not ( + is_leaf_var.is_python_constant() + and is_leaf_var.as_python_constant() is None + ): pred_result = is_leaf_var.call_function(tx, [self], {}) try: leaf_decision = pred_result.as_python_constant() diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 968321be56a51..248ab9d5f4bab 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2703,9 +2703,9 @@ def wrap_unspecialized_primitive(self, value): f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" ) fake_tensor_value = None - if unspec_var.is_python_constant(): + if isinstance(unspec_var, ConstantVariable): # TODO: when can this happen? - example_value = unspec_var.as_python_constant() + example_value = unspec_var.value else: example_value = unspec_var.proxy.node.meta["example_value"] assert is_fake(example_value) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 7143d4d8f3b3f..40b2be0437373 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -690,7 +690,7 @@ def list_iadd_handler( def expand_list_like( tx: "InstructionTranslator", lst: VariableTracker, const: VariableTracker ) -> VariableTracker: - if not isinstance(lst, BaseListVariable) and lst.is_python_constant(): + if isinstance(lst, ConstantVariable): lst, const = const, lst try: assert isinstance(lst, BaseListVariable) @@ -1031,7 +1031,8 @@ def create_exception_class_object( kwargs: dict[str, VariableTracker], ) -> VariableTracker: if fn is AssertionError and not all( - x.is_python_constant() and isinstance(x.as_python_constant(), str) + isinstance(x, variables.ConstantVariable) + and isinstance(x.value, str) for x in args ): unimplemented( @@ -1503,7 +1504,7 @@ def call_method( ) if self.fn is float and len(args) == 1 and name in ("fromhex", "hex"): - if args[0].is_python_constant(): + if isinstance(args[0], ConstantVariable): try: fn = getattr(float, name) res = fn(args[0].as_python_constant()) @@ -1549,12 +1550,10 @@ def call_method( if self.fn is str and len(args) >= 1: resolved_fn = getattr(self.fn, name) if resolved_fn in str_methods: - # Only delegate to ConstantVariable, not other types that happen to be constants if isinstance(args[0], ConstantVariable): return args[0].call_method(tx, name, args[1:], kwargs) if self.fn is float and len(args) >= 1: - # Only delegate to ConstantVariable, not other types that happen to be constants if isinstance(args[0], ConstantVariable): return ConstantVariable.create( getattr(float, name)(args[0].as_python_constant()) @@ -1803,7 +1802,7 @@ def _call_min_max_binary( "call_function", py_fn, *proxy_args_kwargs([a, b], {}) ) return SymNodeVariable.create(tx, proxy, None) - elif a.is_python_constant() and b.is_python_constant(): + elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): value = self.fn( a.as_python_constant(), b.as_python_constant(), @@ -2588,7 +2587,7 @@ def call_getattr( if default is not None: hasattr_var = self.call_hasattr(tx, obj, name_var) if hasattr_var is not None: - assert hasattr_var.is_constant_match(True, False) + assert hasattr_var.as_python_constant() in (True, False) if not hasattr_var.as_python_constant(): return default else: @@ -3095,7 +3094,9 @@ def call_xor( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if a.is_symnode_like() and b.is_symnode_like(): + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3138,7 +3139,9 @@ def call_and_( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if a.is_symnode_like() and b.is_symnode_like(): + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3157,7 +3160,9 @@ def call_iand( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if a.is_symnode_like() and b.is_symnode_like(): + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3175,7 +3180,9 @@ def call_or_( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if a.is_symnode_like() and b.is_symnode_like(): + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): return SymNodeVariable.create( tx, tx.output.create_proxy( @@ -3209,7 +3216,9 @@ def call_ior( # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): return None - if a.is_symnode_like() and b.is_symnode_like(): + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): return SymNodeVariable.create( tx, tx.output.create_proxy( diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 2b7a7661a1182..672fa1d804383 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -115,15 +115,6 @@ def as_python_constant(self) -> Any: def is_python_constant(self) -> Literal[True]: return True - def is_symnode_like(self) -> bool: - return isinstance(self.value, (int, bool)) - - def is_constant_match(self, *values: Any) -> bool: - return self.value in values - - def is_constant_none(self) -> bool: - return self.value is None - @property def items(self) -> list[VariableTracker]: """ @@ -320,7 +311,10 @@ def call_tree_map( return map_fn.call_function(tx, [self, *rest], {}) else: for other in rest: - if not other.is_constant_none(): + if not ( + other.is_python_constant() + and other.as_python_constant() is None + ): return self._tree_map_fallback( tx, tree_map_fn, @@ -362,7 +356,7 @@ def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None def create( cls, cls_type: Any, value_vt: VariableTracker, options: Any ) -> "EnumVariable": - if value_vt.is_python_constant(): + if isinstance(value_vt, variables.ConstantVariable): for member in list(cls_type): if member.value == value_vt.as_python_constant(): return cls(member, **options) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 64ec27cf9e430..c79f19216f68b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -1143,7 +1143,9 @@ def __init__( # The context manager accepts Union[Tensor, Tuple[Tensor]] if isinstance(self.tensors, variables.TensorVariable): self.tensors = variables.TupleVariable([self.tensors]) - if self.prev_versions.is_symnode_like(): + if isinstance( + self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable) + ): self.prev_versions = variables.TupleVariable([self.prev_versions]) def enter(self, tx: "InstructionTranslator") -> VariableTracker: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index b794cd2735a38..422cae7c4d3f1 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -493,6 +493,7 @@ def install_dict_contains_guard( # 3b) contains=False. There is no easy way to selectively apply this # DICT_NOT_CONTAINS guard because our guard are represented via trees. # Be conservative and add DICT_KEYS_MATCH guard. + from . import ConstantVariable if not self.source: return @@ -501,12 +502,12 @@ def install_dict_contains_guard( return contains = args[0] in self - if args[0].source is None and args[0].is_python_constant(): + if args[0].source is None and isinstance(args[0], ConstantVariable): install_guard( self.make_guard( functools.partial( type(self).CONTAINS_GUARD, - key=args[0].as_python_constant(), + key=args[0].value, invert=not contains, ) ) @@ -673,10 +674,10 @@ def call_method( if self.user_cls is collections.OrderedDict and ( len(args) == 1 or "last" in kwargs ): - if len(args) == 1 and args[0].is_python_constant(): - last = args[0].as_python_constant() - elif (v := kwargs.get("last")) and v.is_python_constant(): - last = v.as_python_constant() + if len(args) == 1 and isinstance(args[0], ConstantVariable): + last = args[0].value + elif (v := kwargs.get("last")) and isinstance(v, ConstantVariable): + last = v.value else: raise_args_mismatch(tx, name) k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] @@ -779,11 +780,15 @@ def call_method( raise_observed_exception(KeyError, tx) last = True - if len(args) == 2 and args[1].is_python_constant(): - last = args[1].as_python_constant() + if len(args) == 2 and isinstance(args[1], ConstantVariable): + last = args[1].value - if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): - last = kwargs.get("last").as_python_constant() # type: ignore[union-attr] + if ( + kwargs + and "last" in kwargs + and isinstance(kwargs["last"], ConstantVariable) + ): + last = kwargs.get("last").value # type: ignore[union-attr] key = Hashable(args[0]) self.items.move_to_end(key, last=last) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index c43866c62809c..fdc2f53f82383 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1047,7 +1047,10 @@ def call_method( if self._is_generator_just_started() and len(args): # can't send non-None value to a just-started generator # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen - if not all(arg.is_constant_none() for arg in args): + if not all( + isinstance(arg, ConstantVariable) and arg.value is None + for arg in args + ): raise_observed_exception(TypeError, tx) tracer = self.inline_tracer tracer.push_many(args) @@ -2424,7 +2427,7 @@ def call_function( and not kwargs and isinstance(args[0], (variables.ListVariable, variables.TupleVariable)) and all( - (x.is_python_constant() and isinstance(x.as_python_constant(), int)) + (isinstance(x, variables.ConstantVariable) and isinstance(x.value, int)) or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int) for x in args[0].items ) @@ -2440,8 +2443,8 @@ def call_function( sym_num=torch.sym_sum( [ ( - x.as_python_constant() - if x.is_python_constant() + x.value + if isinstance(x, variables.ConstantVariable) else x.sym_num # type: ignore[attr-defined] ) for x in args[0].items @@ -2646,6 +2649,7 @@ def call_HOP( combined_args_raw: dict[str, Any], tx: "InstructionTranslator", ) -> "variables.ConstantVariable": + from .constant import ConstantVariable from .dicts import ConstDictVariable # as we can only pass tensors as non-const args in fx graph, @@ -2679,12 +2683,12 @@ def call_HOP( constant_args = { k: v.as_python_constant() for k, v in combined_args_raw.items() - if isinstance(v, VariableTracker) and v.is_python_constant() + if isinstance(v, ConstantVariable) } non_constant_args = { k: v for k, v in combined_args.items() - if not (isinstance(v, VariableTracker) and v.is_python_constant()) + if not isinstance(v, ConstantVariable) } for v in non_constant_args.values(): @@ -2985,7 +2989,9 @@ def call_function( if len(args) == 2: is_leaf = args[1] - if not is_leaf.is_constant_none(): + if not ( + isinstance(is_leaf, variables.ConstantVariable) and is_leaf.value is None + ): return super().call_function(tx, args, kwargs) # Optimize the case where is_leaf is None diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 14dec2f9c45ea..0f7491911d35b 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -160,7 +160,7 @@ def _unwrap_var(var): return var.proxy.node.meta["example_value"] elif isinstance(var, SymNodeVariable): return var.sym_num - elif var.is_python_constant(): + elif isinstance(var, ConstantVariable): return var.as_python_constant() else: unimplemented( @@ -225,7 +225,11 @@ def find_mismatched_vars(var, types, allow_none=False): for value in var.items.values(): mismatched_vars.update(find_mismatched_vars(value, types, allow_none)) else: - if not isinstance(var, types) and not (allow_none and var.is_constant_none()): + + def _is_none(var): + return var.is_python_constant() and var.as_python_constant() is None + + if not isinstance(var, types) and not (allow_none and _is_none(var)): mismatched_vars.add(var) return mismatched_vars @@ -499,8 +503,7 @@ def _call_while_loop( def unspecialize_carried_inputs(tx, carry) -> VariableTracker: # See NOTE [unspecialize int carry with unbacked symints] if ( - carry.is_python_constant() - and isinstance(carry.as_python_constant(), int) + isinstance(carry, ConstantVariable) and carry.python_type() is int ) or isinstance(carry, SymNodeVariable): example_value = _create_unbacked_symint( tx.output.fake_mode, ignore_fresh_unbacked_symbols=True @@ -598,7 +601,7 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: *graph_break_hints.USER_ERROR, ], ) - elif cond_r.is_python_constant(): + elif isinstance(cond_r, ConstantVariable): # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False pred = cond_r.as_python_constant() if pred: @@ -1973,9 +1976,7 @@ def speculate_branch(branch): ], ) for ret in ret_val.unpack_var_sequence(tx): - if ret.is_python_constant() and not isinstance( - ret.as_python_constant(), int - ): + if isinstance(ret, ConstantVariable) and ret.python_type() is not int: unimplemented( gb_type="torch.cond: unsupported branch return type (constant non-int)", context=str(ret_val), @@ -2102,8 +2103,7 @@ def validate_subgraph_output_types(output: VariableTracker): if ( isinstance(out, SymNodeVariable) and out.python_type() in (int, bool) ) or ( - out.is_python_constant() - and isinstance(out.as_python_constant(), (int, bool)) + isinstance(out, ConstantVariable) and out.python_type() in (int, bool) ): continue unimplemented( @@ -2719,7 +2719,10 @@ def _call_function( # Check all outputs of map are tensors. # For map, outputting None is OK, thus ignore None values in the check body_r_vars = body_r.unpack_var_sequence(tx) - none_mask = [x.is_constant_none() for x in body_r_vars] + none_mask = [ + type(x.realize()) is ConstantVariable and x.as_python_constant() is None + for x in body_r_vars + ] _check_all_tensorvariable( [br for bm, br in zip(none_mask, body_r_vars) if not bm] ) @@ -3008,7 +3011,7 @@ def call_function( grad_enabled, fn_var, *rest_args = args - if not grad_enabled.is_python_constant(): + if not isinstance(grad_enabled, ConstantVariable): unimplemented( gb_type="wrap_with_set_grad_enabled: non-constant grad_enabled", context=str(grad_enabled), @@ -3096,7 +3099,7 @@ def call_function( device_type, dtype, enabled, cache_enabled, fn_var, *rest_args = args for arg in [device_type, dtype, enabled, cache_enabled]: - if not arg.is_python_constant(): + if not isinstance(arg, ConstantVariable): unimplemented( gb_type="wrap_with_autocast: expected constant arg", context=str(args), @@ -3714,14 +3717,8 @@ def _call_function( tx, query, score_mod, "score_mod" ) mask_fn = block_mask.items[-1] - if mask_fn.is_python_constant(): - mask_callable = mask_fn.as_python_constant() - if mask_callable is None: - mask_callable = torch.nn.attention.flex_attention.noop_mask - mask_fn = UserFunctionVariable( - mask_callable, - source=mask_fn.source, - ) + if isinstance(mask_fn, ConstantVariable): + mask_fn = UserFunctionVariable(torch.nn.attention._flex_attention._no_mask) mask_fn_node, mask_fn_lifted_args = self.create_wrapped_node( tx, query, mask_fn, "mask_fn" ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 4a3c0247add1b..2689d5e094977 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -116,7 +116,7 @@ def call_function( def retrieve_const_key(key: VariableTracker) -> Any: if isinstance(key, variables.SymNodeVariable): return key.evaluate_expr() - elif key.is_python_constant(): + elif isinstance(key, variables.ConstantVariable): return key.as_python_constant() else: unimplemented( @@ -595,7 +595,7 @@ def _next() -> VariableTracker: while True: item = _next() self.index += 1 - if self.fn.is_constant_none(): + if isinstance(self.fn, ConstantVariable) and self.fn.value is None: res = item else: res = self.fn.call_function(tx, [item], {}) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index c6f9448c9b6cf..4f21e35479fb8 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -365,9 +365,7 @@ def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None: def maybe_as_int(x: VariableTracker) -> VariableTracker: return ( - ConstantVariable.create(int(x.as_python_constant())) - if x.is_python_constant() - else x + ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x ) # cast each argument to an integer @@ -905,7 +903,10 @@ def call_method( if len(kwargs) != 0: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") - if key_fn_var.is_constant_none(): + if ( + key_fn_var.is_python_constant() + and key_fn_var.as_python_constant() is None + ): keys = self.items.copy() else: keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items] @@ -1259,8 +1260,8 @@ def numel(self, tx: "InstructionTranslator") -> VariableTracker: sym_sizes = [] for v in self.items: - if v.is_python_constant(): - const_result *= v.as_python_constant() + if isinstance(v, ConstantVariable): + const_result *= v.value else: assert isinstance(v, SymNodeVariable), type(v) # Delay proxy calls until we know it will be necessary diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 466b7a757d829..c7d6e58ba4531 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -474,7 +474,7 @@ def raise_error(msg): if name == "__context__": self.set_context(val) elif name == "__cause__": - if val.is_constant_none() or isinstance( + if (isinstance(val, ConstantVariable) and val.value is None) or isinstance( val, ( variables.BuiltinVariable, @@ -488,12 +488,12 @@ def raise_error(msg): else: raise_error("exception cause must be None or derive from BaseException") elif name == "__suppress_context__": - if val.is_constant_match(True, False): + if isinstance(val, ConstantVariable) and val.value in (True, False): self.__suppress_context__ = val else: raise_error("exception cause must be None or derive from BaseException") elif name == "__traceback__": - if val.is_constant_none(): + if isinstance(val, ConstantVariable) and val.value is None: self.__traceback__ = val else: unimplemented( diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index c9fe1e2802264..bb6952abf0b56 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -859,7 +859,7 @@ def gen_source(source: Source, name: str) -> Source: # pyrefly: ignore[missing-attribute] if type(module).__getitem__ not in builtin_supported: if not ( - args[0].is_python_constant() + isinstance(args[0], variables.ConstantVariable) and isinstance(args[0].as_python_constant(), (str, int)) ): unimplemented( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 3e29b9a08347e..69ca37db4ef37 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -218,7 +218,7 @@ def get_python_args( """Get python values equivalent to the variable tracker args""" def map_arg(arg: Any) -> Any: - if isinstance(arg, VariableTracker) and arg.is_python_constant(): + if isinstance(arg, ConstantVariable): return arg.as_python_constant() elif isinstance(arg, ListVariable) and not arg.items: return [] diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 8002e41a42631..47439387e0fca 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -597,14 +597,11 @@ def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {}) # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through # symbolic_shapes, but that end up as int/sympy.Integer - assert ( - isinstance(dyn_length, SymNodeVariable) - or dyn_length.is_python_constant() - ) + assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable)) if isinstance(dyn_length, SymNodeVariable): length = dyn_length.evaluate_expr(tx.output) else: - length = dyn_length.as_python_constant() + length = dyn_length.value if idxes is None: idxes = range(length) @@ -1412,8 +1409,7 @@ def method_new(self, *args, **kwargs): if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( len(args) >= 1 and all( - a.is_python_constant() and isinstance(a.as_python_constant(), int) - for a in args + isinstance(a, ConstantVariable) and a.python_type() is int for a in args ) ): from ..symbolic_convert import InstructionTranslator @@ -1479,9 +1475,6 @@ def python_type(self): else: return type(self.sym_num) - def is_symnode_like(self) -> bool: - return True - def as_proxy(self): return self.proxy @@ -1639,7 +1632,9 @@ def call_method( dtype_arg = kwargs["dtype"] elif len(args) > 0: dtype_arg = args[0] - is_object_str = dtype_arg is not None and dtype_arg.is_constant_match("O") + is_object_str = ( + isinstance(dtype_arg, ConstantVariable) and dtype_arg.value == "O" + ) is_object_type = ( isinstance(dtype_arg, BuiltinVariable) and dtype_arg.fn is object ) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 8a621aedd5a2c..a4f940cb2adaf 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -934,15 +934,11 @@ def handle_constant_processgroup_functions( # bake the result into the trace if len(args) == 1: # group or group name - assert ( - isinstance(args[0], ProcessGroupVariable) - or args[0].is_python_constant() - ) + assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable)) elif len(args) == 2: # ranks + tag - assert ( - isinstance(args[0], ListVariable) - and args[1].is_python_constant() + assert isinstance(args[0], ListVariable) and isinstance( + args[1], ConstantVariable ) else: raise AssertionError( @@ -1021,7 +1017,7 @@ def handle_nested_tensor( ): from .lists import BaseListVariable - if layout and layout.is_constant_match(torch.strided): + if layout and layout.as_python_constant() == torch.strided: unimplemented( gb_type="Attempted to use strided NestedTensor", context=f"layout={layout}", @@ -1045,7 +1041,9 @@ def handle_nested_tensor( @register(torch.nn.functional.one_hot) def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): if len(args) + len(kwargs) == 1 or ( - len(args) == 2 and args[1].is_constant_match(-1) + len(args) == 2 + and args[1].is_python_constant() + and args[1].as_python_constant() == -1 ): unimplemented( gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", @@ -1067,7 +1065,7 @@ def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr): expr.sym_num ) ) - elif expr.is_python_constant(): + elif isinstance(expr, ConstantVariable): return expr @register(torch.fx.experimental.symbolic_shapes.guard_or_true) @@ -1078,7 +1076,7 @@ def handle_guard_or_true(self, tx: "InstructionTranslator", expr): return variables.ConstantVariable.create( torch.fx.experimental.symbolic_shapes.guard_or_true(expr.sym_num) ) - elif expr.is_python_constant(): + elif isinstance(expr, ConstantVariable): return expr @register(torch.fx.experimental.symbolic_shapes.guard_or_false) @@ -1089,7 +1087,7 @@ def handle_guard_or_false(self, tx: "InstructionTranslator", expr): return variables.ConstantVariable.create( torch.fx.experimental.symbolic_shapes.guard_or_false(expr.sym_num) ) - elif expr.is_python_constant(): + elif isinstance(expr, ConstantVariable): return expr @register(torch.fx.experimental.symbolic_shapes.statically_known_false) @@ -1100,15 +1098,15 @@ def handle_statically_known_false(self, tx: "InstructionTranslator", expr): expr.sym_num ) ) - elif expr.is_python_constant(): + elif isinstance(expr, ConstantVariable): return expr @register(torch.fx.experimental.symbolic_shapes.guard_scalar) def guard_scalar(self, tx: "InstructionTranslator", expr): if isinstance(expr, SymNodeVariable): val = expr.sym_num - elif expr.is_python_constant(): - val = expr.as_python_constant() + elif isinstance(expr, ConstantVariable): + val = expr.value else: unimplemented( gb_type="torch.fx.experimental.symbolic_shapes.guard_scalar branch not supported", @@ -1129,7 +1127,7 @@ def handle_statically_known_true(self, tx: "InstructionTranslator", expr): expr.sym_num ) ) - elif expr.is_python_constant(): + elif isinstance(expr, ConstantVariable): return expr @register(torch.fx.experimental.symbolic_shapes.sym_and) @@ -1158,8 +1156,8 @@ def handle_sym_or(self, tx: "InstructionTranslator", *terms): def handle_has_static_value(self, tx: "InstructionTranslator", expr): if isinstance(expr, SymNodeVariable): val = expr.sym_num - elif expr.is_python_constant(): - val = expr.as_python_constant() + elif isinstance(expr, ConstantVariable): + val = expr.value else: return @@ -1359,7 +1357,7 @@ def handle_set_default_device( # Running the graph will ensure that the DeviceContext mode is # at the correct position in the stack TorchFunctionModeStackVariable.register_mutation(tx) - if args[0].is_constant_none(): + if args[0].is_python_constant() and args[0].as_python_constant() is None: TorchFunctionModeStackVariable.clear_default_device(tx) else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) @@ -1541,7 +1539,8 @@ def call_function( any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) all_ints_or_floats = all( - isinstance(x, SymNodeVariable) or x.is_python_constant() for x in args + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args ) if ( getattr(self.value, "__module__", "") == "torch" diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index b2a86eb4f017f..c7254afdfebfc 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -543,7 +543,7 @@ def dispatch_torch_function( res = tx.symbolic_torch_function_state.call_torch_function_mode( tx, fn, types, args, kwargs ) - if not res.is_constant_match(NotImplemented): + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): return res for arg in overloaded_args: @@ -555,7 +555,7 @@ def dispatch_torch_function( kwargs, ) - if not res.is_constant_match(NotImplemented): + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): return res unimplemented( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index d227e9fa453ad..cc377a09ab746 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -881,8 +881,8 @@ def deque_signature(iterable=None, maxlen=None): return tensor_variable elif self.value is random.Random: - if len(args) == 1 and args[0].is_python_constant(): - seed = args[0].as_python_constant() + if len(args) == 1 and isinstance(args[0], variables.ConstantVariable): + seed = args[0].value else: seed = None random_object = random.Random(seed) @@ -1911,9 +1911,9 @@ def call_method(self, tx, name, args, kwargs): elif ( name == "__setattr__" and len(args) == 2 - and args[0].is_constant_match( - "__cause__", "__context__", "__suppress_context__", "__traceback__" - ) + and isinstance(args[0], variables.ConstantVariable) + and args[0].value + in ("__cause__", "__context__", "__suppress_context__", "__traceback__") ): self.exc_vt.call_setattr(tx, args[0], args[1]) elif name == "with_traceback": From ff43f43d5981ece3b1f04aafa7db36fc3dd33cc5 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Fri, 5 Dec 2025 01:12:37 +0000 Subject: [PATCH 302/338] =?UTF-8?q?[Profiler][PrivateUse1]=20Fix=20Profile?= =?UTF-8?q?rState=20typo=20('Disable'=E2=86=92'Disabled')=20and=20expose?= =?UTF-8?q?=20PRIVATEUSE1=20in=20ActiveProfilerType=20(#169166)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Compare the implementation in `torch/csrc/profiler/python/init.cpp` with `torch/_C/_profiler.pyi` and find some inconsistencies. https://github.com/pytorch/pytorch/blob/a5436a5e8e4ee42d1debf52c2786c7ae0043a434/torch/csrc/profiler/python/init.cpp#L339-L358 https://github.com/pytorch/pytorch/blob/a5436a5e8e4ee42d1debf52c2786c7ae0043a434/torch/_C/_profiler.pyi#L20-L36 In a nutshell, this pr fixes the inconsistency. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169166 Approved by: https://github.com/fffrog, https://github.com/albanD --- torch/_C/_profiler.pyi | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index de12af50c1855..ae8121e4b71d2 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -18,15 +18,15 @@ class RecordScope(Enum): STATIC_RUNTIME_MODEL = ... class ProfilerState(Enum): - Disable = ... + Disabled = ... CPU = ... CUDA = ... NVTX = ... ITT = ... + PRIVATEUSE1 = ... KINETO = ... KINETO_GPU_FALLBACK = ... KINETO_PRIVATEUSE1_FALLBACK = ... - KINETO_PRIVATEUSE1 = ... class ActiveProfilerType(Enum): NONE = ... @@ -34,6 +34,7 @@ class ActiveProfilerType(Enum): KINETO = ... NVTX = ... ITT = ... + PRIVATEUSE1 = ... class ProfilerActivity(Enum): CPU = ... From 64f29c08daff1f20d0efb1b0d5d6f3965357703f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 4 Dec 2025 13:23:47 -0800 Subject: [PATCH 303/338] [dynamo][dicts] Prepare for dict refactor for fewer conflicts (#169589) Just sending changes that impact graph break registry. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169589 Approved by: https://github.com/williamwen42 --- torch/_dynamo/graph_break_registry.json | 44 +++++++++++++++++++ torch/_dynamo/utils.py | 18 ++++++++ torch/_dynamo/variables/base.py | 56 +++++++++++++++++++++++++ 3 files changed, 118 insertions(+) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index dd012a239bb23..38125b59fcc5e 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3678,5 +3678,49 @@ "and pass it in as an argument or as a global variable." ] } + ], + "GB0364": [ + { + "Gb_type": "User-defined object with overridden __hash__", + "Context": "hashing object of type={type(obj)} and variable tracker {vt}", + "Explanation": "Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", + "Hints": [ + "Dynamo does not support hashing user-defined objects with overridden __hash__", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0365": [ + { + "Gb_type": "Dynamo cannot determine whether the underlying object is hashable", + "Context": "is_python_hashable {self}", + "Explanation": "Dynamo does not know whether the underlying python object for {self} is hashable", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {type_self}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0366": [ + { + "Gb_type": "Dynamo cannot determine the hash of an object", + "Context": "get_python_hash {self}", + "Explanation": "Dynamo does not know the hash of the underlying python object for {self}", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0367": [ + { + "Gb_type": "Dynamo cannot determine the equality comparison of an object", + "Context": "is_python_equal {self}", + "Explanation": "Dynamo does not know the equality comparison of the underlying python object for {self}", + "Hints": [ + "Consider using a different type of object as the dictionary key instead of {self.python_type()}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] } diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index d08b92de3441e..afdd0c7aefa4d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4962,3 +4962,21 @@ def get_traced_code() -> Optional[list[CodeType]]: from torch._guards import TracingContext return TracingContext.get_traced_code() + + +def raise_on_overridden_hash(obj: Any, vt: VariableTracker) -> None: + from . import graph_break_hints + from .exc import unimplemented + + is_overridden = type(obj).__dict__.get("__hash__", False) + + if is_overridden: + unimplemented( + gb_type="User-defined object with overridden __hash__", + context=f"hashing object of type={type(obj)} and variable tracker {vt}", + explanation=f"Found a user-defined object {vt} with overridden __hash__ when attempting to hash it", + hints=[ + "Dynamo does not support hashing user-defined objects with overridden __hash__", + *graph_break_hints.SUPPORTABLE, + ], + ) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 617f787e43d8a..a794010f4083f 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -683,6 +683,62 @@ def build( else: return variables.LazyVariableTracker.create(value, source) + def is_python_hashable(self): + """ + Unlike the variable tracker's own __hash__, this method checks whether + the underlying Python object referenced by this variable tracker is hashable. + """ + try: + type_self = self.python_type() + except NotImplementedError: + type_self = type(self) + + unimplemented( + gb_type="Dynamo cannot determine whether the underlying object is hashable", + context=f"is_python_hashable {self}", + explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {type_self}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def get_python_hash(self): + """ + Unlike the variable tracker’s own __hash__, this method is used by + ConstDictVariableTracker to compute the hash of the underlying key object. + """ + unimplemented( + gb_type="Dynamo cannot determine the hash of an object", + context=f"get_python_hash {self}", + explanation=f"Dynamo does not know the hash of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def is_python_equal(self, other): + """ + NB - Deliberately not overriding the __eq__ method because that can + disable the __hash__ for the vt itself. + """ + unimplemented( + gb_type="Dynamo cannot determine the equality comparison of an object", + context=f"is_python_equal {self}", + explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + def __init__( self, *, From 1a8de0293273d8802c383a8e926b3bd106a6e40b Mon Sep 17 00:00:00 2001 From: Slawomir Siwek Date: Fri, 5 Dec 2025 01:37:43 +0000 Subject: [PATCH 304/338] Fix log_sigmoid_backward_batch_rule on XPU (#169215) Fixes https://github.com/intel/torch-xpu-ops/issues/2240, submitted as alternative approach to https://github.com/intel/torch-xpu-ops/pull/2373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169215 Approved by: https://github.com/guangyey, https://github.com/cyyever, https://github.com/EikanWang, https://github.com/albanD --- aten/src/ATen/functorch/BatchRulesBinaryOps.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp index 5426e50e7100a..937f39273ab57 100644 --- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp @@ -322,11 +322,13 @@ static std::tuple> log_sigmoid_backward_batch_rul Tensor& self, std::optional self_bdim, Tensor& buffer, std::optional buffer_bdim) { // NB: This emulates handle_pointwise_ops except we ignore the last argument, buffer - // when any of the inputs are on cuda. - // We do this because on cuda, buffer is a dummy tensor always of logical rank 1 and + // when any of the inputs are on cuda/xpu. + // We do this because on cuda/xpu, buffer is a dummy tensor always of logical rank 1 and // it becomes an issue when the rest of the inputs are scalar int64_t out_logical_rank = std::max(rankWithoutBatchDim(grad, grad_bdim), rankWithoutBatchDim(self, self_bdim)); - if (!grad.is_cuda() && !self.is_cuda() && !buffer.is_cuda()) { + bool inputs_on_cuda = grad.is_cuda() || self.is_cuda() || buffer.is_cuda(); + bool inputs_on_xpu = grad.is_xpu() || self.is_xpu() || buffer.is_xpu(); + if (!inputs_on_cuda && !inputs_on_xpu) { out_logical_rank = std::max(out_logical_rank, rankWithoutBatchDim(buffer, buffer_bdim)); } Tensor out_grad = maybePadToLogicalRank(moveBatchDimToFront(grad, grad_bdim), grad_bdim, out_logical_rank); From 8c13e8bc03cbc7807eef8fe72b3bc94bfa04e0c8 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Mon, 1 Dec 2025 21:43:50 +0000 Subject: [PATCH 305/338] Revert getAllOperatorsFor changes (#167860, #162218) (#169281) This reverts commit 567dcdba757aebd92b8d2b4b1604d66f55eb5e02. This reverts commit 0e7ccc09db936d3154b5d70ce4255f2e6065cf98. These changes break downstream builds and also are the suspected culprit of a memory corruption. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169281 Approved by: https://github.com/zou3519, https://github.com/albanD --- test/cpp/jit/test_custom_operators.cpp | 13 ++-- test/custom_operator/test_custom_ops.cpp | 2 +- torch/csrc/jit/frontend/schema_matching.cpp | 2 +- torch/csrc/jit/ir/alias_analysis.cpp | 2 +- torch/csrc/jit/ir/ir.cpp | 2 +- torch/csrc/jit/python/init.cpp | 2 +- torch/csrc/jit/runtime/operator.cpp | 70 ++++++++----------- torch/csrc/jit/runtime/operator.h | 5 +- .../jit/runtime/symbolic_shape_registry.cpp | 2 +- 9 files changed, 44 insertions(+), 56 deletions(-) diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index 66295d0380629..58f87717844de 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -15,7 +15,7 @@ namespace jit { TEST(CustomOperatorTest, InferredSchema) { torch::RegisterOperators reg( "foo::bar", [](double a, at::Tensor b) { return a + b; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -43,7 +43,8 @@ TEST(CustomOperatorTest, ExplicitSchema) { "foo::bar_with_schema(float a, Tensor b) -> Tensor", [](double a, at::Tensor b) { return a + b; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); + auto& ops = + getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -76,7 +77,7 @@ TEST(CustomOperatorTest, ListParameters) { torch::List> complexdoubles, torch::List tensors) { return floats; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -122,7 +123,7 @@ TEST(CustomOperatorTest, ListParameters2) { "foo::lists2(Tensor[] tensors) -> Tensor[]", [](torch::List tensors) { return tensors; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -212,7 +213,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) { }, aliasAnalysisFromSchema())}); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); ASSERT_EQ(ops.size(), 0); } @@ -231,7 +232,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) { }, aliasAnalysisFromSchema())}); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp index 9791006d1498f..a526bebd26144 100644 --- a/test/custom_operator/test_custom_ops.cpp +++ b/test/custom_operator/test_custom_ops.cpp @@ -22,7 +22,7 @@ void check_all_parameters( template Result get_operator_from_registry_and_execute(const char* op_name, Args&&... args) { - auto ops = torch::jit::getAllOperatorsFor( + auto& ops = torch::jit::getAllOperatorsFor( torch::jit::Symbol::fromQualString(op_name)); TORCH_INTERNAL_ASSERT(ops.size() == 1); diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index 83742b40ae9cc..68e169824bce8 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -678,7 +678,7 @@ Value* emitBuiltinCall( at::ArrayRef args, at::ArrayRef kwargs, const std::optional& self) { - auto variants = getAllOperatorsFor(name); + const auto& variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); // first let's set the graph's version diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 51dbb09db9ea0..c55d87e5c1772 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -616,7 +616,7 @@ void AliasDb::analyzeImpl(Node* node) { oss << input->type()->str() << ", "; } oss << "\n\nCandidates:"; - auto candidates = getAllOperatorsFor(node->kind()); + const auto& candidates = getAllOperatorsFor(node->kind()); for (const auto& candidate : candidates) { oss << "\n\t" << candidate->schema(); } diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index c5dfa56b48a2e..5a9abcab8e82a 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1085,7 +1085,7 @@ const FunctionSchema* Node::maybeSchema() const { const Operator* Node::maybeOperator() const { if (!op_) { - auto candidates = getAllOperatorsFor(kind()); + const auto& candidates = getAllOperatorsFor(kind()); for (const auto& candidate : candidates) { if (matches(candidate->schema())) { op_ = candidate.get(); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 82a11af3714b4..671aa5454ae5e 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -2109,7 +2109,7 @@ void initJITBindings(PyObject* module) { m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); - auto operations = getAllOperatorsFor(symbol); + const auto& operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 30105754c5ee2..478e595e78ce7 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -52,16 +52,6 @@ struct OperatorRegistry { to_register.clear(); } - const std::vector>& getOperatorsWithLockHeld( - Symbol name) { - registerPendingOperators(); - static std::vector> empty; - auto it = operators.find(name); - if (it != operators.end()) - return it->second; - return empty; - } - public: void registerOperator(Operator&& op) { std::lock_guard guard(lock); @@ -152,35 +142,14 @@ struct OperatorRegistry { return it->second; } - // This function returns internal lock-protected state. We need to - // copy it to avoid race conditions. - std::vector> getOperators(Symbol name) { + const std::vector>& getOperators(Symbol name) { std::lock_guard guard(lock); - return getOperatorsWithLockHeld(name); - } - - std::vector> getSortedOperators(Symbol name) { - std::lock_guard guard(lock); - const auto& unsortedOps = getOperatorsWithLockHeld(name); - // Depending on the order of registration, aten or jit ops may be - // registered first. This sorting is helpful in cases where - // deterministic (i.e. not dependent on build config) behavior is - // desired; e.g. torch.ops.aten.* uses this function, and tries to - // find the "first" op that matches input args. Without the sorting, - // the "first" op may change depending on registration order. - std::vector> sortedOps; - sortedOps.reserve(unsortedOps.size()); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return op->isC10Op(); }); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return !op->isC10Op(); }); - return sortedOps; + registerPendingOperators(); + static std::vector> empty; + auto it = operators.find(name); + if (it != operators.end()) + return it->second; + return empty; } std::vector findSimilarOperators(Symbol input_op) { @@ -417,16 +386,35 @@ void deregisterOperator(const FunctionSchema& schema) { getRegistry().deregisterOperator(schema); } -std::vector> getAllOperators() { +const std::vector> getAllOperators() { return getRegistry().getAllOperators(); } -std::vector> getAllOperatorsFor(Symbol name) { +const std::vector>& getAllOperatorsFor(Symbol name) { return getRegistry().getOperators(name); } std::vector> getAllSortedOperatorsFor(Symbol name) { - return getRegistry().getSortedOperators(name); + const auto& unsortedOps = getAllOperatorsFor(name); + // Depending on the order of registration, aten or jit ops may be + // registered first. This sorting is helpful in cases where + // deterministic (i.e. not dependent on build config) behavior is + // desired; e.g. torch.ops.aten.* uses this function, and tries to + // find the "first" op that matches input args. Without the sorting, + // the "first" op may change depending on registration order. + std::vector> sortedOps; + sortedOps.reserve(unsortedOps.size()); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return op->isC10Op(); }); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return !op->isC10Op(); }); + return sortedOps; } std::shared_ptr findOperatorFor(const c10::OperatorName& full_name) { diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 6b6972deeebf0..bde3825f5ea38 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -260,9 +260,8 @@ struct TORCH_API Operator { TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); -TORCH_API std::vector> getAllOperators(); -// This function returns a copy for thread safety. -TORCH_API std::vector> getAllOperatorsFor( +TORCH_API const std::vector> getAllOperators(); +TORCH_API const std::vector>& getAllOperatorsFor( Symbol name); // Returns operators in the order which OpOverloadPacket resolves them. TORCH_API std::vector> getAllSortedOperatorsFor( diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index 6fb34bc2027b4..d1a42a9c16faa 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -78,7 +78,7 @@ auto compilation_unit = std::make_shared(); const std::optional getInplaceVariant( const FunctionSchema& base_schema) { - auto inplace_variants = + auto& inplace_variants = getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_")); for (const auto& variant : inplace_variants) { From 6470af76a93a007b7886a5ade7f711ea221a4a3b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 4 Dec 2025 13:23:47 -0800 Subject: [PATCH 306/338] [dynamo] Prep for dict key refactor - add hash related methods to VT (#169610) Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/169610 Approved by: https://github.com/williamwen42, https://github.com/zou3519 ghstack dependencies: #169589 --- torch/_dynamo/variables/builtin.py | 9 ++++ torch/_dynamo/variables/constant.py | 25 +++++++++++ torch/_dynamo/variables/functions.py | 46 +++++++++++++++++++++ torch/_dynamo/variables/higher_order_ops.py | 9 ++++ torch/_dynamo/variables/lists.py | 34 +++++++++++++++ torch/_dynamo/variables/misc.py | 37 +++++++++++++++++ torch/_dynamo/variables/tensor.py | 28 +++++++++++++ torch/_dynamo/variables/torch.py | 9 ++++ 8 files changed, 197 insertions(+) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 40b2be0437373..9bd1bae080508 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -3268,6 +3268,15 @@ def call_contains( ) -> VariableTracker: return a.call_method(tx, "__contains__", [b], {}) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn + @contextlib.contextmanager def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 672fa1d804383..0b2eaaea80826 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -23,6 +23,7 @@ istype, np, raise_args_mismatch, + raise_on_overridden_hash, ) from .base import ValueMutationNew, VariableTracker @@ -340,6 +341,20 @@ def call_obj_hasattr( result = hasattr(self.value, name) return variables.ConstantVariable.create(result) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + # Could be an EnumVariable as well + from .tensor import SymNodeVariable + + if isinstance(other, SymNodeVariable): + return self.as_python_constant() == other.evaluate_expr() + return self.as_python_constant() == other.as_python_constant() + class EnumVariable(VariableTracker): """VariableTracker for enum.Enum and enum.IntEnum instances @@ -388,3 +403,13 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker member = getattr(self.value, name) source = self.source and AttrSource(self.source, name) return VariableTracker.build(tx, member, source=source) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index fdc2f53f82383..f493e0e1fd961 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -810,6 +810,15 @@ def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: return collected return None + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn + class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): _nonvar_fields = { @@ -1957,6 +1966,15 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return fn_var_getattr(tx, self.value, self.source, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class WrappedSkipFunctionVariable(SkipFunctionVariable): def __init__( @@ -2343,6 +2361,34 @@ def guard_as_python_constant(self) -> Any: **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, ) + def is_python_hashable(self) -> bool: + return ( + self.func.is_python_hashable() + and all(arg.is_python_hashable() for arg in self.args) + and all(value.is_python_hashable() for value in self.keywords.values()) + ) + + def get_python_hash(self): + func_hash = self.func.get_python_hash() + args_hash = (arg.get_python_hash() for arg in self.args) + values_hash = (value.get_python_hash() for value in self.keywords.values()) + return hash((func_hash, *args_hash, *values_hash)) + + def is_python_equal(self, other): + return ( + self.func.is_python_equal(other.func) + and all( + arg_a.is_python_equal(arg_b) + for (arg_a, arg_b) in zip(self.args, other.args) + ) + and all( + value_a.is_python_equal(value_b) + for (value_a, value_b) in zip( + self.keywords.values(), other.keywords.values() + ) + ) + ) + class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 0f7491911d35b..a4543821b19b1 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1813,6 +1813,15 @@ def _call_function( def as_python_constant(self): return self.value + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 4f21e35479fb8..a97c284f9516c 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -620,6 +620,25 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return self.items[fields.index(name)] return super().var_getattr(tx, name) + def is_python_hashable(self): + return True + + def get_python_hash(self): + l = self.range_length() + start = self.start() + step = self.step() + return hash((l, start, step)) + + def is_python_equal(self, other): + if not isinstance(other, variables.RangeVariable): + return False + + return ( + self.start() == other.start() + and self.step() == other.step() + and self.stop() == other.stop() + ) + class CommonListMethodsVariable(BaseListVariable): """ @@ -981,6 +1000,9 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) + def is_python_hashable(self): + return False + class DequeVariable(CommonListMethodsVariable): def __init__( @@ -1170,6 +1192,18 @@ def call_obj_hasattr( return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) + def is_python_hashable(self): + return all(item.is_python_hashable() for item in self.items) + + def get_python_hash(self): + items = tuple(x.get_python_hash() for x in self.items) + return hash(items) + + def is_python_equal(self, other): + return isinstance(other, variables.TupleVariable) and all( + a.is_python_equal(b) for (a, b) in zip(self.items, other.items) + ) + class SizeVariable(TupleVariable): """torch.Size(...)""" diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index c7d6e58ba4531..748d4a0985b49 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1306,6 +1306,15 @@ def is_python_constant(self): def as_python_constant(self): return self.method_wrapper + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class GetSetDescriptorVariable(VariableTracker): def __init__(self, desc, **kwargs) -> None: @@ -1440,6 +1449,15 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # codegen.append_output(codegen.create_load_const(self.value)) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + @functools.lru_cache(maxsize=1) def get_np_to_tnp_map(): @@ -1618,6 +1636,15 @@ def as_proxy(self): return super().as_proxy() + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls class NullVariable(VariableTracker): @@ -2097,3 +2124,13 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(self.referent_vt) codegen(self.callback_vt) codegen.extend_output(create_call_function(2, False)) + + def is_python_hashable(self): + return self.referent_vt.is_python_hashable() + + def get_python_hash(self): + # weakref relies on the referent's hash + return self.referent_vt.get_python_hash() + + def is_python_equal(self, other): + return self.referent_vt.is_python_equal(other.referent_vt) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 47439387e0fca..d47c520046d38 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1428,6 +1428,20 @@ def set_name_hint(self, name: str): self.proxy.node._rename(name) self._is_name_set = True + def is_python_hashable(self): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return self.as_proxy().node.meta["example_value"] is not None + + def get_python_hash(self): + return hash(self.as_proxy().node.meta["example_value"]) + + def is_python_equal(self, other): + a = self.as_proxy().node.meta["example_value"] + b = other.as_proxy().node.meta["example_value"] + return a is b + class SymNodeVariable(VariableTracker): """ @@ -1516,6 +1530,20 @@ def call_method( ), ) + def is_python_hashable(self): + return True + + def get_python_hash(self): + # Essentially convert the SymNode to a constant variable whenever its + # searched for a dict key. + return hash(self.evaluate_expr()) + + def is_python_equal(self, other): + if isinstance(other, SymNodeVariable): + return self.evaluate_expr() == other.evaluate_expr() + # could be constant variable as well + return self.evaluate_expr() == other.as_python_constant() + class NumpyNdarrayVariable(TensorVariable): """ diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index a4f940cb2adaf..19f98ea6a13b0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -2117,6 +2117,15 @@ def torch_function_override_enabled(self, tx, args, kwargs): ) ) and can_dispatch_torch_function(tx, args, kwargs) + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + class DispatchKeySetVariable(BaseTorchVariable): """represents torch.DispatchKeySet""" From 3a3156f4c3b109968164fd10812780532a178cba Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 5 Dec 2025 01:51:04 +0000 Subject: [PATCH 307/338] Skip NVIDIA cleanup on CPU runners after #169431 (#169625) Fix this issue https://github.com/pytorch/pytorch/actions/runs/19940322037/job/57177019555 that is showing up on CPU jobs after #169431 lands Pull Request resolved: https://github.com/pytorch/pytorch/pull/169625 Approved by: https://github.com/seemethere, https://github.com/atalman --- .github/workflows/_linux-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index ee2837a7456f5..b6d49617df652 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -534,7 +534,7 @@ jobs: # As both the root cause and recovery path are unclear, let's take the runner out of # service so that it doesn't get any more jobs - name: Check NVIDIA driver installation step - if: failure() && steps.install-nvidia-driver.outcome && steps.install-nvidia-driver.outcome != 'skipped' + if: failure() && steps.install-nvidia-driver.outputs.has-nvidia == 'true' && !contains(matrix.runner, 'b200') shell: bash run: | set +e From 5213a72bd376dfae74a39184b96999566acff1d2 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 5 Dec 2025 01:58:13 +0000 Subject: [PATCH 308/338] Enable ruff SIM115 check (#169437) This PR enables the ruff check for using context managers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169437 Approved by: https://github.com/Lucaskabela, https://github.com/albanD --- .../dynamo/microbenchmarks/operatorbench.py | 32 ++-- benchmarks/sparse/spmm.py | 2 +- benchmarks/sparse/spmv.py | 2 +- benchmarks/sparse/triton_ops.py | 2 +- pyproject.toml | 1 - .../elastic/multiprocessing/tail_log_test.py | 42 ++--- test/distributed/fsdp/test_wrap.py | 15 +- test/distributed/test_c10d_common.py | 16 +- test/distributed/test_c10d_gloo.py | 3 +- test/distributed/test_c10d_nccl.py | 50 +++--- test/distributed/test_c10d_spawn.py | 2 +- test/distributed/test_c10d_spawn_gloo.py | 3 +- test/distributed/test_store.py | 59 ++++--- test/dynamo/test_guard_serialization.py | 9 +- test/dynamo/test_structured_trace.py | 2 +- test/inductor/test_cutlass_backend.py | 4 +- test/inductor/test_profiler.py | 2 +- test/inductor/test_static_cuda_launcher.py | 7 +- test/package/common.py | 2 +- test/profiler/test_execution_trace.py | 161 ++++++++++-------- test/scripts/run_cuda_memcheck.py | 2 +- test/test_serialization.py | 7 +- tools/jit/test/test_gen_unboxing.py | 44 +++-- torch/profiler/profiler.py | 8 +- torch/utils/data/dataloader.py | 7 +- 25 files changed, 257 insertions(+), 227 deletions(-) diff --git a/benchmarks/dynamo/microbenchmarks/operatorbench.py b/benchmarks/dynamo/microbenchmarks/operatorbench.py index 779bb80a454c4..31772faf619d9 100644 --- a/benchmarks/dynamo/microbenchmarks/operatorbench.py +++ b/benchmarks/dynamo/microbenchmarks/operatorbench.py @@ -261,22 +261,22 @@ def benchmark( output_csv = None if op == "all": filename = f"operatorbench_{suite}_{dtype}.csv" - output_fd = open(filename, "w") - output_csv = csv.writer(output_fd) - output_csv.writerow( - [ - "operator", - *[ - f"{a} {b}" - for a, b in itertools.product( - backend_names, - [f"{x * 100:.0f}th" for x in quantiles_thresholds], - ) - ], - "elapsed", - *map("{} abs".format, ["eager", *backend_names]), - ] - ) + with open(filename, "w") as output_fd: + output_csv = csv.writer(output_fd) + output_csv.writerow( + [ + "operator", + *[ + f"{a} {b}" + for a, b in itertools.product( + backend_names, + [f"{x * 100:.0f}th" for x in quantiles_thresholds], + ) + ], + "elapsed", + *map("{} abs".format, ["eager", *backend_names]), + ] + ) dtype = torch.float16 if dtype == "float16" else torch.float32 diff --git a/benchmarks/sparse/spmm.py b/benchmarks/sparse/spmm.py index b2c658d6faeb6..e3a505eda73c3 100644 --- a/benchmarks/sparse/spmm.py +++ b/benchmarks/sparse/spmm.py @@ -88,7 +88,7 @@ def test_sparse_coo_and_csr(m, n, k, nnz, test_count): outfile = sys.stderr need_close = False else: - outfile = open(args.outfile, "a") + outfile = open(args.outfile, "a") # noqa: SIM115 need_close = True test_count = args.test_count diff --git a/benchmarks/sparse/spmv.py b/benchmarks/sparse/spmv.py index 3e9502686a884..0166fcb15abb8 100644 --- a/benchmarks/sparse/spmv.py +++ b/benchmarks/sparse/spmv.py @@ -87,7 +87,7 @@ def test_sparse_coo_and_csr(m, nnz, test_count): outfile = sys.stderr need_close = False else: - outfile = open(args.outfile, "a") + outfile = open(args.outfile, "a") # noqa: SIM115 need_close = True test_count = args.test_count diff --git a/benchmarks/sparse/triton_ops.py b/benchmarks/sparse/triton_ops.py index a49a53bcd207c..e087eaa714b33 100644 --- a/benchmarks/sparse/triton_ops.py +++ b/benchmarks/sparse/triton_ops.py @@ -184,7 +184,7 @@ def integer_or_float_list(a): outfile = sys.stderr need_close = False else: - outfile = open(args.outfile, "a") + outfile = open(args.outfile, "a") # noqa: SIM115 need_close = True ops = args.ops.split(",") diff --git a/pyproject.toml b/pyproject.toml index 6474ddd8f5027..0d065d21aef2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,6 @@ ignore = [ "SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression "SIM110", # Checks for for loops that can be replaced with a builtin function, like any or all. "SIM114", # Combine `if` branches using logical `or` operator - "SIM115", # Checks for cases where files are opened without using a context manager. "SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM117", "SIM300", # Yoda condition detected diff --git a/test/distributed/elastic/multiprocessing/tail_log_test.py b/test/distributed/elastic/multiprocessing/tail_log_test.py index 1ed0d5e292106..a0db8cdf12fe3 100644 --- a/test/distributed/elastic/multiprocessing/tail_log_test.py +++ b/test/distributed/elastic/multiprocessing/tail_log_test.py @@ -100,28 +100,30 @@ def test_tail_write_to_dst_file(self): } dst = os.path.join(self.test_dir, "tailed_stdout.log") - dst_file = open(dst, "w", buffering=1) - tail = TailLog( - name="writer", log_files=log_files, dst=dst_file, interval_sec=interval_sec - ).start() - # sleep here is intentional to ensure that the log tail - # can gracefully handle and wait for non-existent log files - time.sleep(interval_sec * 10) - - futs = [] - for local_rank, file in log_files.items(): - f = self.threadpool.submit( - write, max=max, sleep=interval_sec * local_rank, file=file - ) - futs.append(f) - - wait(futs, return_when=ALL_COMPLETED) - self.assertFalse(tail.stopped()) - tail.stop() - dst_file.close() + with open(dst, "w", encoding="utf8", buffering=1) as dst_file: + tail = TailLog( + name="writer", + log_files=log_files, + dst=dst_file, + interval_sec=interval_sec, + ).start() + # sleep here is intentional to ensure that the log tail + # can gracefully handle and wait for non-existent log files + time.sleep(interval_sec * 10) + + futs = [] + for local_rank, file in log_files.items(): + f = self.threadpool.submit( + write, max=max, sleep=interval_sec * local_rank, file=file + ) + futs.append(f) + + wait(futs, return_when=ALL_COMPLETED) + self.assertFalse(tail.stopped()) + tail.stop() actual: dict[int, set[int]] = {} - with open(dst) as read_dst_file: + with open(dst, encoding="utf8") as read_dst_file: for line in read_dst_file: header, num = line.split(":") nums = actual.setdefault(header, set()) diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py index aa224edaefa1d..a98b567bebf97 100644 --- a/test/distributed/fsdp/test_wrap.py +++ b/test/distributed/fsdp/test_wrap.py @@ -761,13 +761,14 @@ def test_auto_wrap_smoke_test(self, device_init_mode, cpu_offload, use_device_id os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(find_free_port()) - file_name = tempfile.NamedTemporaryFile(delete=False).name - torch.distributed.init_process_group( - backend=backend, - init_method=f"{FILE_SCHEMA}_{file_name}", - rank=0, - world_size=1, - ) + with tempfile.NamedTemporaryFile(delete=False) as f: + file_name = f.name + torch.distributed.init_process_group( + backend=backend, + init_method=f"{FILE_SCHEMA}_{file_name}", + rank=0, + world_size=1, + ) # NOTE: We move model to GPU after init with FSDP to simulate real use # cases where full model cannot be loaded onto GPU, but their shards can. diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 0d11725829d26..2eceeb1098003 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -107,14 +107,13 @@ def _test_store_timeout(self, backend, init_method, c2p): c2p.append(e) def _init_methods(self): - f = tempfile.NamedTemporaryFile(delete=False) - if sys.platform == "win32": - yield "file:///{}".format(f.name.replace("\\", "/")) + with tempfile.NamedTemporaryFile(delete=False) as f: f.close() - else: - yield f"file://{f.name}" - f.close() - yield f"tcp://127.0.0.1:{common.find_free_port():d}" + if sys.platform == "win32": + yield "file:///{}".format(f.name.replace("\\", "/")) + else: + yield f"file://{f.name}" + yield f"tcp://127.0.0.1:{common.find_free_port():d}" def _test_default_store_timeout(self, backend): for init_method in self._init_methods(): @@ -140,7 +139,8 @@ def _test_default_store_timeout(self, backend): class TimeoutTest(TestCase): @retry_on_connect_failures def test_store_based_barrier(self): - f = tempfile.NamedTemporaryFile(delete=False) + f = tempfile.NamedTemporaryFile(delete=False) # noqa: SIM115 + f.close() port = common.find_free_port() def thread_work(timeout, init_type, world_size, rank, error_list): diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 07c68d5c0a465..604a6156e30eb 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2356,7 +2356,8 @@ def forward(self, x, use_fc3=True): class ReducerTest(TestCase): def setUp(self): super().setUp() - self.file = tempfile.NamedTemporaryFile(delete=False) + with tempfile.NamedTemporaryFile(delete=False) as f: + self.file = f world_size = 1 self.store = c10d.FileStore(self.file.name, world_size) c10d.init_process_group( diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 60deb3654df27..c681e0e5226fa 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -258,7 +258,8 @@ def setUp(self): super().setUp() self.rank = self.MAIN_PROCESS_RANK self.world_size = 1 - self.file = tempfile.NamedTemporaryFile(delete=False) + with tempfile.NamedTemporaryFile(delete=False) as f: + self.file = f def tearDown(self): pass @@ -4001,8 +4002,9 @@ def test_restart_pg_after_error(self): self.assertEqual(nccl_backend.get_error(), ErrorType.TIMEOUT) # we need a brand new fileStore for the new PG # the new file name is shared through the old fileStore - new_file_name = tempfile.NamedTemporaryFile(delete=False).name - store.set("file", new_file_name) + with tempfile.NamedTemporaryFile(delete=False) as f: + new_file_name = f.name + store.set("file", new_file_name) else: # other ranks not exiting before rank 0 timeout, this is to avoid # nccl error happening before rank 0 timeouts @@ -4059,21 +4061,21 @@ def test_invalid_nccl_blocking_wait_env(self): class NcclUserBufferRegistrationTest(MultiProcessTestCase): def setUp(self): super().setUp() - nccl_debug_file = tempfile.NamedTemporaryFile() - nccl_env = { - # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests - # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected. - "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1", - "NCCL_ALGO": "NVLS", - "NCCL_DEBUG": "INFO", - "NCCL_DEBUG_SUBSYS": "NVLS", - "NCCL_DEBUG_FILE": nccl_debug_file.name, - } - if torch.cuda.nccl.version() >= (2, 24, 3): - nccl_env["NCCL_DEBUG_SUBSYS"] = "REG,TUNING" - self.env_patcher = mock.patch.dict(os.environ, nccl_env) - self.env_patcher.start() - self._spawn_processes() + with tempfile.NamedTemporaryFile(delete=False) as nccl_debug_file: + nccl_env = { + # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests + # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected. + "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1", + "NCCL_ALGO": "NVLS", + "NCCL_DEBUG": "INFO", + "NCCL_DEBUG_SUBSYS": "NVLS", + "NCCL_DEBUG_FILE": nccl_debug_file.name, + } + if torch.cuda.nccl.version() >= (2, 24, 3): + nccl_env["NCCL_DEBUG_SUBSYS"] = "REG,TUNING" + self.env_patcher = mock.patch.dict(os.environ, nccl_env) + self.env_patcher.start() + self._spawn_processes() def tearDown(self): self.env_patcher.stop() @@ -4445,15 +4447,15 @@ def test_pass_nccl_options_config(self): pg_opts.config.cga_cluster_size = 2 pg_opts.config.net_name = "Socket" pg_opts.config.split_share = 1 - nccl_debug_file = tempfile.NamedTemporaryFile() os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name + with tempfile.NamedTemporaryFile() as nccl_debug_file: + os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name - # Tests functionality when passing nccl config - self._test_pass_nccl_options(pg_opts) + # Tests functionality when passing nccl config + self._test_pass_nccl_options(pg_opts) - # Tests if comms were configured - nccl_debug_file_content = nccl_debug_file.read() + # Tests if comms were configured + nccl_debug_file_content = nccl_debug_file.read() max_ctas = re.search(rb"Max CTAs.*(\d+)|$", nccl_debug_file_content).group(1) min_ctas = re.search(rb"Min CTAs.*(\d+)|$", nccl_debug_file_content).group(1) split_share = re.search( diff --git a/test/distributed/test_c10d_spawn.py b/test/distributed/test_c10d_spawn.py index 26e20a4f45dbe..5efa3dc2deb2d 100644 --- a/test/distributed/test_c10d_spawn.py +++ b/test/distributed/test_c10d_spawn.py @@ -34,7 +34,7 @@ class AbstractProcessGroupShareTensorTest: def _test_multiprocess(self, f, shared_tensors, init_pg, n_output): ws = self.world_size # file store will delete the test file on destruction - file = tempfile.NamedTemporaryFile(delete=False) + file = tempfile.NamedTemporaryFile(delete=False) # noqa: SIM115 ctx = mp.get_context("spawn") c2p = ctx.Queue(2) p2c = ctx.Queue(2) diff --git a/test/distributed/test_c10d_spawn_gloo.py b/test/distributed/test_c10d_spawn_gloo.py index c4667bb5dd486..97b60528f13a5 100644 --- a/test/distributed/test_c10d_spawn_gloo.py +++ b/test/distributed/test_c10d_spawn_gloo.py @@ -26,7 +26,8 @@ class DistributedDataParallelSingleProcessTest(TestCase): def setUp(self): self.rank = 0 self.world_size = 1 - self.file = tempfile.NamedTemporaryFile(delete=False) # noqa: P201 + with tempfile.NamedTemporaryFile(delete=False) as f: + self.file = f def tearDown(self): try: diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index e1412701807b6..310e41f5829a3 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -273,7 +273,8 @@ def num_keys_total(self): class FileStoreTest(TestCase, StoreTestBase): def setUp(self): super().setUp() - self.file = tempfile.NamedTemporaryFile(delete=False) + with tempfile.NamedTemporaryFile(delete=False) as f: + self.file = f def _create_store(self): store = dist.FileStore(self.file.name, 1) @@ -281,34 +282,34 @@ def _create_store(self): return store def test_init_pg_and_rpc_with_same_file(self): - file = tempfile.NamedTemporaryFile(delete=False) - # Init RPC using file - rpc_backend_options = rpc.TensorPipeRpcBackendOptions() - rpc_backend_options.init_method = f"file://{file.name}" - rpc_backend_options._transports = tp_transports() - rpc.init_rpc( - "worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options - ) + with tempfile.NamedTemporaryFile(delete=False) as file: + # Init RPC using file + rpc_backend_options = rpc.TensorPipeRpcBackendOptions() + rpc_backend_options.init_method = f"file://{file.name}" + rpc_backend_options._transports = tp_transports() + rpc.init_rpc( + "worker", rank=0, world_size=1, rpc_backend_options=rpc_backend_options + ) - # Init PG using file - dist.init_process_group( - "gloo", rank=0, world_size=1, init_method=f"file://{file.name}" - ) - dist.destroy_process_group() - assert os.path.exists(file.name) + # Init PG using file + dist.init_process_group( + "gloo", rank=0, world_size=1, init_method=f"file://{file.name}" + ) + dist.destroy_process_group() + assert os.path.exists(file.name) - rpc.shutdown() - os.remove(file.name) + rpc.shutdown() + os.remove(file.name) def test_refcount(self): - file = tempfile.NamedTemporaryFile(delete=False) - store = dist.FileStore(file.name, 1) - store2 = dist.FileStore(file.name, 1) + with tempfile.NamedTemporaryFile(delete=False) as file: + store = dist.FileStore(file.name, 1) + store2 = dist.FileStore(file.name, 1) - del store - assert os.path.exists(file.name) - del store2 - assert not os.path.exists(file.name) + del store + assert os.path.exists(file.name) + del store2 + assert not os.path.exists(file.name) @property def num_keys_total(self): @@ -327,7 +328,8 @@ class PrefixStoreTest(TestCase): def setUp(self): super().setUp() # delete is false as FileStore will automatically clean up the file - self.file = tempfile.NamedTemporaryFile(delete=False) + with tempfile.NamedTemporaryFile(delete=False) as f: + self.file = f def test_get_underlying_store(self): tcp_store = dist.TCPStore( @@ -348,7 +350,8 @@ def test_get_underlying_store(self): class PrefixFileStoreTest(TestCase, StoreTestBase): def setUp(self): super().setUp() - self.file = tempfile.NamedTemporaryFile(delete=False) + with tempfile.NamedTemporaryFile(delete=False) as f: + self.file = f self.filestore = dist.FileStore(self.file.name, 1) self.prefix = "test_prefix" self.filestore.set_timeout(timedelta(seconds=300)) @@ -977,7 +980,7 @@ def test_extended_methods_fallbacks(self): class TestMultiThreadedWait(MultiThreadedTestCase): - file_store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1) + file_store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1) # noqa: SIM115 hash_store = dist.HashStore() tcp_store = create_tcp_store(use_libuv=False) @@ -1058,7 +1061,7 @@ def run(rank, my_store): else: my_store.wait(["foo"], datetime.timedelta(seconds=10)) rank_res[rank] = True - except Error as e: # noqa: F821 + except BaseException as e: # noqa: B036,E261 rank_res[rank] = e time.sleep(1) diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 927040c1836ce..f98f758d929bd 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -1476,7 +1476,7 @@ def test_ddp_module(self): self.skipTest("Torch distributed is not available") from torch.nn.parallel import DistributedDataParallel as DDP - tmpfile = tempfile.NamedTemporaryFile() + tmpfile = tempfile.NamedTemporaryFile() # noqa: SIM115 dist.init_process_group( backend="gloo", rank=0, world_size=1, init_method=f"file://{tmpfile.name}" ) @@ -1501,6 +1501,7 @@ def foo(ddp, x): ) finally: dist.destroy_process_group() + tmpfile.close() def test_dict_keys_serialization(self): d = {1: 2, 3: 4} @@ -1526,7 +1527,7 @@ def test_unserializable_sharded_tensor(self): if not dist.is_available(): self.skipTest("Torch distributed is not available") - tmpfile = tempfile.NamedTemporaryFile() + tmpfile = tempfile.NamedTemporaryFile() # noqa:SIM115 dist.init_process_group( backend="gloo", rank=0, world_size=1, init_method=f"file://{tmpfile.name}" ) @@ -1558,6 +1559,7 @@ def foo(inputs): ) finally: dist.destroy_process_group() + tmpfile.close() def test_function_with_wrong_fqn(self): def foo(inputs): @@ -1645,7 +1647,7 @@ def test_unused_process_group(self): def foo(inputs): return inputs.x + 1 - tmpfile = tempfile.NamedTemporaryFile() + tmpfile = tempfile.NamedTemporaryFile() # noqa: SIM115 dist.init_process_group( backend="gloo", init_method=f"file://{tmpfile.name}", @@ -1660,6 +1662,7 @@ def foo(inputs): self._test_check_fn(ref, loaded, {"inputs": Inputs(x, pg)}, True) finally: dist.destroy_process_group() + tmpfile.close() def test_unserializable_submodule(self): def foo(mod, x): diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 4bd1b251f86d4..a3954f01e2045 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -222,7 +222,7 @@ def setUp(self): self.handler.addFilter(chrome_event_filter) trace_log.addHandler(self.handler) - self.raw_file = tempfile.NamedTemporaryFile( + self.raw_file = tempfile.NamedTemporaryFile( # noqa: SIM115 mode="w", delete=True ) # set this to False to keep temporary files self.raw_handler = logging.StreamHandler(self.raw_file) diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 5d9b421e562e0..828c5099ff044 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -1487,9 +1487,9 @@ def test_standalone_runner(self): assert len(sources) >= 1 # Get names for temporary source and executable files. - cu_file = NamedTemporaryFile("w", suffix=".cu", delete=False) + cu_file = NamedTemporaryFile("w", suffix=".cu", delete=False) # noqa: SIM115 cu_file.close() - exe_file = NamedTemporaryFile("w", suffix="", delete=False) + exe_file = NamedTemporaryFile("w", suffix="", delete=False) # noqa: SIM115 exe_file.close() # Save the generated code into the .cu file. diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index b4e671c9ba68e..7b13d03a209bd 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -246,7 +246,7 @@ def fn(a, b, c): with config.patch(compile_threads=1): fn(*inputs) - fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=not debug) + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=not debug) # noqa: SIM115 fp.close() with torch.profiler.profile( diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 654bfd269f761..b63ebcb2d1e79 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -38,11 +38,10 @@ def write_cubin_to_tmp(self, kernel: CompiledKernel) -> str: return # Just used by tests for now. # TODO: derive cubin_path from wherever triton stores the cubin file on disk. - tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False) - with tmp_file: + with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp_file: tmp_file.write(kernel.asm["cubin"]) - self.tmp_files.append(tmp_file) - return tmp_file.name + self.tmp_files.append(tmp_file) + return tmp_file.name def _make_launcher( self, diff --git a/test/package/common.py b/test/package/common.py index f522c37e17894..9328ab06faf28 100644 --- a/test/package/common.py +++ b/test/package/common.py @@ -12,7 +12,7 @@ def __init__(self, *args, **kwargs): self._temporary_files = [] def temp(self): - t = NamedTemporaryFile() + t = NamedTemporaryFile() # noqa: SIM115 name = t.name if IS_WINDOWS: t.close() # can't read an open file in windows diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index 1bc07c187fdc3..b66fa9999d2e5 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -149,23 +149,25 @@ def trace_handler(p): or torch.profiler.ProfilerActivity.HPU in supported_activities() ) # Create a temp file to save execution trace and kineto data. - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) - fp.close() - kt = tempfile.NamedTemporaryFile( - mode="w+t", suffix=".kineto.json", delete=False - ) - kt.close() - with profile( - activities=supported_activities(), - schedule=torch.profiler.schedule( - skip_first=3, wait=1, warmup=1, active=2, repeat=1 - ), - on_trace_ready=trace_handler, - execution_trace_observer=( - ExecutionTraceObserver().register_callback(fp.name) - ), - ) as p: + with ( + tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) as fp, + tempfile.NamedTemporaryFile( + mode="w+t", suffix=".kineto.json", delete=False + ) as kt, + profile( + activities=supported_activities(), + schedule=torch.profiler.schedule( + skip_first=3, wait=1, warmup=1, active=2, repeat=1 + ), + on_trace_ready=trace_handler, + execution_trace_observer=( + ExecutionTraceObserver().register_callback(fp.name) + ), + ) as p, + ): + trace_name = fp.name + kt_name = kt.name for idx in range(10): with record_function(f"## LOOP {idx} ##"): self.payload(device, use_device=use_device) @@ -176,10 +178,11 @@ def trace_handler(p): # print("Output kineto = ", kt.name) # print("Output ET = ", fp.name) - p.export_chrome_trace(kt.name) + p.export_chrome_trace(kt_name) self.assertEqual(trace_called_num, 1) - nodes = self.get_execution_trace_root(fp.name) + nodes = self.get_execution_trace_root(trace_name) + os.remove(trace_name) loop_count = 0 found_root_node = False for n in nodes: @@ -196,9 +199,10 @@ def trace_handler(p): # in terms of record func ID (rf_id) and External IDs # both of these should match for the same trace window. - with open(kt.name) as f: + with open(kt_name) as f: kineto = json.load(f) events = kineto["traceEvents"] + os.remove(kt_name) # Look up rf_ids in both Execution and Kineto trace as two lists. rf_ids_et = self.get_execution_trace_rf_ids(nodes) @@ -233,18 +237,20 @@ def trace_handler(p): or torch.profiler.ProfilerActivity.HPU in supported_activities() ) # Create a temp file to save kineto data. - kt = tempfile.NamedTemporaryFile( - mode="w+t", suffix=".kineto.json", delete=False - ) - kt.close() - with profile( - activities=supported_activities(), - schedule=torch.profiler.schedule( - skip_first=3, wait=1, warmup=1, active=2, repeat=1 - ), - on_trace_ready=trace_handler, - ) as p: + with ( + tempfile.NamedTemporaryFile( + mode="w+t", suffix=".kineto.json", delete=False + ) as kt, + profile( + activities=supported_activities(), + schedule=torch.profiler.schedule( + skip_first=3, wait=1, warmup=1, active=2, repeat=1 + ), + on_trace_ready=trace_handler, + ) as p, + ): + kt_name = kt.name for idx in range(10): with record_function(f"## LOOP {idx} ##"): self.payload(device, use_device=use_device) @@ -254,7 +260,8 @@ def trace_handler(p): # print("Output kineto = ", kt.name) # print("Output ET = ", fp.name) - p.export_chrome_trace(kt.name) + p.export_chrome_trace(kt_name) + self.assertEqual(trace_called_num, 1) et_path = p.execution_trace_observer.get_output_file_path() et_res_path = p.execution_trace_observer.get_resources_dir(et_path) @@ -282,9 +289,10 @@ def trace_handler(p): # in terms of record func ID (rf_id) and External IDs # both of these should match for the same trace window. - with open(kt.name) as f: + with open(kt_name) as f: kineto = json.load(f) events = kineto["traceEvents"] + os.remove(kt_name) # Look up rf_ids in both Execution and Kineto trace as two lists. rf_ids_et = self.get_execution_trace_rf_ids(nodes) @@ -307,11 +315,11 @@ def test_execution_trace_alone(self, device): ) # Create a temp file to save execution trace data. # Use a gzip file to test compression codepath - fp = tempfile.NamedTemporaryFile("w", suffix=".et.json.gz", delete=False) - fp.close() + with tempfile.NamedTemporaryFile("w", suffix=".et.json.gz", delete=False) as fp: + filename = fp.name expected_loop_events = 0 - et = ExecutionTraceObserver().register_callback(fp.name) + et = ExecutionTraceObserver().register_callback(filename) et.start() for idx in range(5): @@ -320,9 +328,10 @@ def test_execution_trace_alone(self, device): self.payload(device, use_device=use_device) et.stop() - assert fp.name == et.get_output_file_path() + assert filename == et.get_output_file_path() et.unregister_callback() - nodes = self.get_execution_trace_root(fp.name) + nodes = self.get_execution_trace_root(filename) + os.remove(filename) loop_count = 0 # Expected tensor object tuple size, in th form of: # [tensor_id, storage_id, offset, numel, itemsize, device_str] @@ -388,10 +397,10 @@ def fn(a, b, c): fn(*inputs) # Create a temp file to save execution trace data. - fp = tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False) - fp.close() + with tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False) as fp: + filename = fp.name et = ExecutionTraceObserver() - et.register_callback(fp.name) + et.register_callback(filename) et.set_extra_resource_collection(True) with profile( @@ -407,7 +416,8 @@ def fn(a, b, c): fn(*inputs) p.step() - nodes = self.get_execution_trace_root(fp.name) + nodes = self.get_execution_trace_root(filename) + os.remove(filename) found_captured_triton_kernel_node = False found_call_compiled_fx_graph = False for n in nodes: @@ -520,10 +530,11 @@ def fn(a, b, c): ): fn(*inputs) - fp = tempfile.NamedTemporaryFile("w+t", suffix="fx_graph_et.json", delete=False) - fp.close() et = ExecutionTraceObserver() - et.register_callback(fp.name) + with tempfile.NamedTemporaryFile( + "w+t", suffix="fx_graph_et.json", delete=False + ) as fp: + et.register_callback(fp.name) et.set_extra_resource_collection(True) with profile( activities=torch.profiler.supported_activities(), @@ -592,6 +603,7 @@ def fn(a, b, c): == '# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})' # noqa: B950 ) assert fx_graph[7] == "# return %cos" + os.remove(file_path) def test_execution_trace_start_stop(self, device): use_device = ( @@ -600,10 +612,10 @@ def test_execution_trace_start_stop(self, device): or torch.profiler.ProfilerActivity.HPU in supported_activities() ) # Create a temp file to save execution trace data. - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) - fp.close() + with tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) as fp: + filename = fp.name expected_loop_events = 0 - et = ExecutionTraceObserver().register_callback(fp.name) + et = ExecutionTraceObserver().register_callback(filename) for idx in range(10): if idx == 3: et.start() @@ -618,9 +630,10 @@ def test_execution_trace_start_stop(self, device): with record_function(f"## LOOP {idx} ##"): self.payload(device, use_device=use_device) - assert fp.name == et.get_output_file_path() + assert filename == et.get_output_file_path() et.unregister_callback() - nodes = self.get_execution_trace_root(fp.name) + nodes = self.get_execution_trace_root(filename) + os.remove(filename) loop_count = 0 found_root_node = False for n in nodes: @@ -644,10 +657,11 @@ def test_execution_trace_repeat_in_loop(self, device): for idx in range(10): if idx in iter_list: # Create a temp file to save execution trace data. - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) - fp.close() - output_files.append(fp.name) - et = ExecutionTraceObserver().register_callback(fp.name) + with tempfile.NamedTemporaryFile( + "w+t", suffix=".et.json", delete=False + ) as fp: + output_files.append(fp.name) + et = ExecutionTraceObserver().register_callback(fp.name) et.start() with record_function(f"## LOOP {idx} ##"): self.payload(device, use_device=use_device) @@ -685,14 +699,12 @@ def test_execution_trace_no_capture(self): @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500") def test_execution_trace_nested_tensor(self): - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) - fp.close() - - observer = ExecutionTraceObserver().register_callback(fp.name) - def fn(nt): return nt.sin().cos() + with tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) as fp: + observer = ExecutionTraceObserver().register_callback(fp.name) + filename = fp.name with torch.profiler.profile(execution_trace_observer=observer): for i in range(3): values = torch.rand((8 + i, 4 + i)) @@ -700,7 +712,8 @@ def fn(nt): nt = torch.nested.nested_tensor_from_jagged(values, offsets) fn(nt) - nodes = self.get_execution_trace_root(fp.name) + nodes = self.get_execution_trace_root(filename) + os.remove(filename) found_cos = False for n in nodes: assert "name" in n @@ -713,26 +726,28 @@ def fn(nt): "need CUDA device availability to run", ) def test_execution_trace_record_integral_tensor_range(self): - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) - fp.close() - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_RANGE"] = "1" t1 = torch.tensor([[1, 2], [3, 4]]).cuda() t2 = torch.tensor([[0, 0], [1, 0]]).cuda() - with profile( - activities=supported_activities(), - schedule=torch.profiler.schedule( - skip_first=0, wait=0, warmup=0, active=1, repeat=1 - ), - record_shapes=True, - execution_trace_observer=( - ExecutionTraceObserver().register_callback(fp.name) - ), - ) as p: + with ( + tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) as fp, + profile( + activities=supported_activities(), + schedule=torch.profiler.schedule( + skip_first=0, wait=0, warmup=0, active=1, repeat=1 + ), + record_shapes=True, + execution_trace_observer=( + ExecutionTraceObserver().register_callback(fp.name) + ), + ) as p, + ): + filename = fp.name torch.gather(t1, 1, t2) p.step() - nodes = self.get_execution_trace_root(fp.name) + nodes = self.get_execution_trace_root(filename) + os.remove(filename) for n in nodes: assert "name" in n if "aten::gather" in n["name"]: diff --git a/test/scripts/run_cuda_memcheck.py b/test/scripts/run_cuda_memcheck.py index ca3196f4f4910..df17a89747d26 100755 --- a/test/scripts/run_cuda_memcheck.py +++ b/test/scripts/run_cuda_memcheck.py @@ -137,7 +137,7 @@ def is_cpu_only(name): # or as specified by the user progress = 0 if not args.ci: - logfile = open("result.log", "w") + logfile = open("result.log", "w") # noqa:SIM115 progressbar = tqdm.tqdm(total=len(ALL_TESTS)) else: logfile = sys.stdout diff --git a/test/test_serialization.py b/test/test_serialization.py index 39f8b7735663f..da6512d456609 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -384,15 +384,14 @@ def test_serialization_dill(self): def test_serialization_offset_gzip(self): a = torch.randn(5, 5) i = 41 - f2 = tempfile.NamedTemporaryFile(delete=False) - with tempfile.NamedTemporaryFile() as f1: + with TemporaryFileName() as tmp_file, tempfile.NamedTemporaryFile() as f1: pickle.dump(i, f1) torch.save(a, f1) f1.seek(0) - with gzip.open(f2.name, 'wb') as f_out: + with gzip.open(tmp_file, 'wb') as f_out: shutil.copyfileobj(f1, f_out) - with gzip.open(f2.name, 'rb') as f: + with gzip.open(tmp_file, 'rb') as f: j = pickle.load(f) b = torch.load(f) self.assertTrue(torch.equal(a, b)) diff --git a/tools/jit/test/test_gen_unboxing.py b/tools/jit/test/test_gen_unboxing.py index 975342aad0f7a..6e2aa23495d08 100644 --- a/tools/jit/test/test_gen_unboxing.py +++ b/tools/jit/test/test_gen_unboxing.py @@ -28,18 +28,17 @@ def test_get_custom_build_selector_with_allowlist_yaml( mock_parse_native_yaml: NonCallableMock, mock_get_custom_build_selector: NonCallableMock, ) -> None: - temp_file = tempfile.NamedTemporaryFile() - temp_file.write(b"- aten::add.Tensor") - temp_file.seek(0) - args = [ - f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}", - "--op-selection-yaml-path=path2", - ] - gen_unboxing.main(args) - mock_get_custom_build_selector.assert_called_once_with( - ["aten::add.Tensor"], "path2" - ) - temp_file.close() + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(b"- aten::add.Tensor") + temp_file.seek(0) + args = [ + f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}", + "--op-selection-yaml-path=path2", + ] + gen_unboxing.main(args) + mock_get_custom_build_selector.assert_called_once_with( + ["aten::add.Tensor"], "path2" + ) def test_get_custom_build_selector_with_both_allowlist_and_yaml( self, @@ -48,17 +47,16 @@ def test_get_custom_build_selector_with_both_allowlist_and_yaml( mock_parse_native_yaml: NonCallableMock, mock_get_custom_build_selector: NonCallableMock, ) -> None: - temp_file = tempfile.NamedTemporaryFile() - temp_file.write(b"- aten::add.Tensor") - temp_file.seek(0) - args = [ - "--op-registration-allowlist=op1", - f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}", - "--op-selection-yaml-path=path2", - ] - gen_unboxing.main(args) - mock_get_custom_build_selector.assert_called_once_with(["op1"], "path2") - temp_file.close() + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(b"- aten::add.Tensor") + temp_file.seek(0) + args = [ + "--op-registration-allowlist=op1", + f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}", + "--op-selection-yaml-path=path2", + ] + gen_unboxing.main(args) + mock_get_custom_build_selector.assert_called_once_with(["op1"], "path2") if __name__ == "__main__": diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 151a41af919e4..be2cddd7f3cf7 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -965,16 +965,18 @@ def build_execution_trace_obs_from_env() -> Optional["ExecutionTraceObserver"]: """ if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE", "0") == "1": try: - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) # noqa:SIM115 + with tempfile.NamedTemporaryFile( + "w+t", suffix=".et.json", delete=False + ) as fp: + filename = fp.name except Exception as e: warn( f"Execution trace will not be recorded. Exception on creating default temporary file: {e}", stacklevel=2, ) return None - fp.close() et = ExecutionTraceObserver() - et.register_callback(fp.name) + et.register_callback(filename) # additionally, check if the env requires us to collect extra resources if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS", "0") == "1": et.set_extra_resource_collection(True) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index e01422708f791..9f2cd710faf6e 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -8,6 +8,7 @@ from __future__ import annotations +import contextlib import functools import itertools import logging @@ -1334,7 +1335,11 @@ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): # test. # See NOTE [ DataLoader on Linux and open files limit ] fds_limit_margin = 10 - [tempfile.NamedTemporaryFile() for _ in range(fds_limit_margin)] # noqa: SIM115 + with contextlib.ExitStack() as stack: + for _ in range(fds_limit_margin): + stack.enter_context( + tempfile.NamedTemporaryFile() # pyrefly: ignore [bad-argument-type] + ) except OSError as e: if e.errno == errno.EMFILE: raise RuntimeError( From d914547caf98b021539fb295c67fcbd2975ed80d Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 4 Dec 2025 14:36:53 -0800 Subject: [PATCH 309/338] Make TORCH_BOX support const ref arguments (#169563) This PR removes const and reference for arguments in `unbox_type_t` before the type mapper (`UnboxType`) is applied For a given argument `a` of type `T` in the function schema, `unbox_type_t` is used to map `T->T1` in order to determine which `to(a)` to call to convert that StableIValue into a value in the tuple of arguments passed to the function. - stripping & from T& to call `to` rather than `to` is ok because the boxing kernel needs to take ownership of the arguments in the argument tuple - stripping `const` from `const T` is ok because if the function has the type `const T` for an arg, the StableIValue passed in should be `T` (`to` is not meaningful) Pull Request resolved: https://github.com/pytorch/pytorch/pull/169563 Approved by: https://github.com/albanD ghstack dependencies: #168380, #169385 --- .../make_tensor_clones_and_call_foreach.cpp | 2 +- .../csrc/my__foreach_mul.cpp | 3 +- .../csrc/my__foreach_mul_vec.cpp | 27 ++++++ .../libtorch_agnostic_2_10/csrc/my_empty.cpp | 4 +- .../csrc/my_string_op_variants.cpp | 64 +++++++++++++ .../csrc/test_device_is_cpu.cpp | 3 +- .../csrc/test_device_is_cuda.cpp | 3 +- .../libtorch_agnostic_2_10/ops.py | 62 +++++++++++++ .../libtorch_agnostic_2_9/csrc/kernel.cpp | 38 +++++++- .../libtorch_agnostic_2_9/ops.py | 16 ++++ test/cpp_extensions/test_libtorch_agnostic.py | 93 +++++++++++++++++++ torch/csrc/stable/library.h | 5 +- 12 files changed, 311 insertions(+), 9 deletions(-) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_vec.cpp create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op_variants.cpp diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp index d3dbab5891394..57607c3ffa0f7 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/make_tensor_clones_and_call_foreach.cpp @@ -8,7 +8,7 @@ using torch::stable::Tensor; // Declare my__foreach_mul (defined in my__foreach_mul.cpp) extern std::vector my__foreach_mul( - torch::headeronly::HeaderOnlyArrayRef self, + const torch::headeronly::HeaderOnlyArrayRef& self, torch::headeronly::HeaderOnlyArrayRef other); // Helper function for cloning diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp index 834a63afea646..69d8dda388b0f 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul.cpp @@ -5,7 +5,8 @@ using torch::stable::Tensor; -std::vector my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) { +// This is used to test const torch::headeronly::HeaderOnlyArrayRef& with TORCH_BOX +std::vector my__foreach_mul(const torch::headeronly::HeaderOnlyArrayRef& self, torch::headeronly::HeaderOnlyArrayRef other) { std::array stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)}; aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data()); return torch::stable::detail::to>(stack[0]); diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_vec.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_vec.cpp new file mode 100644 index 0000000000000..f857de94fa32f --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my__foreach_mul_vec.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include + +using torch::stable::Tensor; + +// This is used to test const std::vector& with TORCH_BOX +std::vector my__foreach_mul_vec( + const std::vector& self, + const std::vector& other) { + std::array stack = { + torch::stable::detail::from(self), torch::stable::detail::from(other)}; + aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data()); + return torch::stable::detail::to>(stack[0]); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my__foreach_mul_vec(Tensor[] self, Tensor[] other) -> Tensor[]"); +} + +STABLE_TORCH_LIBRARY_IMPL( + libtorch_agnostic_2_10, + CompositeExplicitAutograd, + m) { + m.impl("my__foreach_mul_vec", TORCH_BOX(&my__foreach_mul_vec)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp index 4b17b113135e6..0e78d484bf9df 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp @@ -10,8 +10,8 @@ using torch::stable::Tensor; Tensor my_empty( torch::headeronly::HeaderOnlyArrayRef size, std::optional dtype, - std::optional layout, - std::optional device, + std::optional& layout, + const std::optional& device, std::optional pin_memory, std::optional memory_format) { return empty(size, dtype, layout, device, pin_memory, memory_format); diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op_variants.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op_variants.cpp new file mode 100644 index 0000000000000..c60d8bcaaf711 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op_variants.cpp @@ -0,0 +1,64 @@ +// This file is intended to test (const) std::string& and const std::string_view& arguments with TORCH_BOX +#include +#include +#include + +#include +#include + +using torch::stable::Tensor; + +// Helper function to process accessor +static int64_t process_accessor(Tensor t, std::string_view accessor) { + if (accessor == "dim") { + return t.dim(); + } else if (accessor == "size") { + return t.size(0); + } else if (accessor == "stride") { + return t.stride(0); + } else { + STD_TORCH_CHECK(false, "Unsupported accessor value: ", std::string(accessor).c_str()) + } +} + +// Test const std::string& +std::tuple, int64_t> my_string_op_const_string_ref( + Tensor t, + const std::string& accessor, + const std::string& passthru) { + int64_t res = process_accessor(t, accessor); + auto vec = std::vector({accessor, std::to_string(res), passthru}); + return std::make_tuple(vec, res); +} + +// Test const std::string_view& +std::tuple, int64_t> my_string_op_const_string_view_ref( + Tensor t, + const std::string_view& accessor, + const std::string_view& passthru) { + int64_t res = process_accessor(t, accessor); + auto vec = std::vector({std::string(accessor), std::to_string(res), std::string(passthru)}); + return std::make_tuple(vec, res); +} + +// Test std::string& (non-const) +std::tuple, int64_t> my_string_op_string_ref( + Tensor t, + std::string& accessor, + std::string& passthru) { + int64_t res = process_accessor(t, accessor); + auto vec = std::vector({accessor, std::to_string(res), passthru}); + return std::make_tuple(vec, res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_string_op_const_string_ref(Tensor t, str accessor, str passthru) -> (str[], int)"); + m.def("my_string_op_const_string_view_ref(Tensor t, str accessor, str passthru) -> (str[], int)"); + m.def("my_string_op_string_ref(Tensor t, str accessor, str passthru) -> (str[], int)"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { + m.impl("my_string_op_const_string_ref", TORCH_BOX(&my_string_op_const_string_ref)); + m.impl("my_string_op_const_string_view_ref", TORCH_BOX(&my_string_op_const_string_view_ref)); + m.impl("my_string_op_string_ref", TORCH_BOX(&my_string_op_string_ref)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cpu.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cpu.cpp index 58e1af91dfd50..020eb427847d1 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cpu.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cpu.cpp @@ -1,7 +1,8 @@ #include #include -bool test_device_is_cpu(torch::stable::Device device) { +// This is used to test torch::stable::Device& with TORCH_BOX +bool test_device_is_cpu(torch::stable::Device& device) { return device.is_cpu(); } diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cuda.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cuda.cpp index e08709f30c2d7..61e6cd801046b 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cuda.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_device_is_cuda.cpp @@ -1,7 +1,8 @@ #include #include -bool test_device_is_cuda(torch::stable::Device device) { +// This is used to test const torch::stable::Device& with TORCH_BOX +bool test_device_is_cuda(const torch::stable::Device& device) { return device.is_cuda(); } diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index b063961575cb7..f429b48851620 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -368,3 +368,65 @@ def test_std_cuda_kernel_launch_check_error() -> None: This should raise a RuntimeError with the CUDA kernel launch error message. """ torch.ops.libtorch_agnostic_2_10.test_std_cuda_kernel_launch_check_error.default() + + +def my__foreach_mul_vec(tensors, others) -> list[Tensor]: + """ + Returns a list of tensors that are the results of pointwise multiplying + tensors and others. This variant tests const std::vector& parameters. + + Args: + tensors: list of tensors + others: list of tensors (with the same corresponding shapes as tensors) + + Returns: list of multiplied tensors + """ + return torch.ops.libtorch_agnostic_2_10.my__foreach_mul_vec.default(tensors, others) + + +def my_string_op_const_string_ref(t, accessor, passthru) -> tuple[list[str], int]: + """ + Tests TORCH_BOX with const std::string& parameters. + + Args: + t: Tensor - input tensor to query + accessor: str - which property to access ("dim", "size", or "stride") + passthru: str - a string that gets returned as the last element of the list + + Returns: tuple - (list of [accessor, value, passthru] as strings, value) + """ + return torch.ops.libtorch_agnostic_2_10.my_string_op_const_string_ref.default( + t, accessor, passthru + ) + + +def my_string_op_const_string_view_ref(t, accessor, passthru) -> tuple[list[str], int]: + """ + Tests TORCH_BOX with const std::string_view& parameters. + + Args: + t: Tensor - input tensor to query + accessor: str - which property to access ("dim", "size", or "stride") + passthru: str - a string that gets returned as the last element of the list + + Returns: tuple - (list of [accessor, value, passthru] as strings, value) + """ + return torch.ops.libtorch_agnostic_2_10.my_string_op_const_string_view_ref.default( + t, accessor, passthru + ) + + +def my_string_op_string_ref(t, accessor, passthru) -> tuple[list[str], int]: + """ + Tests TORCH_BOX with std::string& (non-const) parameters. + + Args: + t: Tensor - input tensor to query + accessor: str - which property to access ("dim", "size", or "stride") + passthru: str - a string that gets returned as the last element of the list + + Returns: tuple - (list of [accessor, value, passthru] as strings, value) + """ + return torch.ops.libtorch_agnostic_2_10.my_string_op_string_ref.default( + t, accessor, passthru + ) diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp index cf50a4d70e6d7..9f7ecacb1d3ed 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp @@ -228,11 +228,13 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { return transpose(t, dim0, dim1); } -Tensor my_empty_like(Tensor t) { +// This is used to test const torch::stable::Tensor& with TORCH_BOX +Tensor my_empty_like(const Tensor& t) { return empty_like(t); } -bool my_is_cpu(Tensor t) { +// This is used to test torch::stable::Tensor& with TORCH_BOX +bool my_is_cpu(Tensor& t) { return t.is_cpu(); } @@ -444,3 +446,35 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { m.impl("my_flatten", TORCH_BOX(&my_flatten)); } + +// Test function for const std::optional& with TORCH_BOX +// Returns the tensor if present, otherwise returns a zeros tensor of specified size +Tensor my_optional_tensor_ref( + const std::optional& maybe_tensor, + int64_t default_size) { + if (maybe_tensor.has_value()) { + return maybe_tensor.value(); + } + // Create a zeros tensor as default + AtenTensorHandle zeros_ath; + int64_t sizes[] = {default_size}; + int64_t strides[] = {1}; + aoti_torch_empty_strided( + 1, + sizes, + strides, + aoti_torch_dtype_float32(), + aoti_torch_device_type_cpu(), + 0, + &zeros_ath); + Tensor zeros_tensor(zeros_ath); + return zero_(zeros_tensor); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) { + m.def("my_optional_tensor_ref(Tensor? maybe_tensor, int default_size) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) { + m.impl("my_optional_tensor_ref", TORCH_BOX(&my_optional_tensor_ref)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py index 04a1377836554..488b53be13bd9 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py @@ -361,3 +361,19 @@ def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor: Returns: Tensor - flattened tensor """ return torch.ops.libtorch_agnostic_2_9.my_flatten.default(t, start_dim, end_dim) + + +def my_optional_tensor_ref(maybe_tensor, default_size) -> Tensor: + """ + Tests TORCH_BOX with const std::optional& parameter. + Returns the tensor if present, otherwise returns a zeros tensor of specified size. + + Args: + maybe_tensor: Optional[Tensor] - optional input tensor + default_size: int - size of the default zeros tensor if maybe_tensor is None + + Returns: Tensor - the input tensor or a zeros tensor + """ + return torch.ops.libtorch_agnostic_2_9.my_optional_tensor_ref.default( + maybe_tensor, default_size + ) diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index d06099a6b7cf7..370faced72634 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -726,6 +726,23 @@ def test_my_flatten(self, device): expected_range = torch.flatten(t, 2, -1) self.assertEqual(result_range, expected_range) + @onlyCPU + @xfailIfTorchDynamo + def test_my_optional_tensor_ref(self, device): + """Test TORCH_BOX with const std::optional& parameter.""" + import libtorch_agnostic_2_9 as libtorch_agnostic + + # Test with a tensor provided + t = torch.randn(5, device=device) + result = libtorch_agnostic.ops.my_optional_tensor_ref(t, 10) + self.assertEqual(result, t) + + # Test with None (should return zeros tensor of specified size) + result_none = libtorch_agnostic.ops.my_optional_tensor_ref(None, 7) + expected_zeros = torch.zeros(7) + self.assertEqual(result_none, expected_zeros) + self.assertEqual(result_none.shape, (7,)) + @skipIfTorchVersionLessThan(2, 10) def test_my_reshape(self, device): import libtorch_agnostic_2_10 as libtorch_agnostic @@ -860,6 +877,82 @@ def test_my_string_op(self, device): with self.assertRaisesRegex(RuntimeError, "Unsupported accessor value: "): libtorch_agnostic.ops.my_string_op(t, "invalid", "") + @skipIfTorchVersionLessThan(2, 10) + def test_my__foreach_mul_vec(self, device): + """Test my__foreach_mul_vec which uses const std::vector& parameters.""" + import libtorch_agnostic_2_10 as libtorch_agnostic + + N = 5 + tensors = [torch.rand(32, 16, device=device) for _ in range(N)] + others = [torch.rand(32, 16, device=device) for _ in range(N)] + + result = libtorch_agnostic.ops.my__foreach_mul_vec(tensors, others) + expected = torch._foreach_mul(tensors, others) + + for result_t, expected_t in zip(result, expected): + self.assertEqual(result_t, expected_t) + + @skipIfTorchVersionLessThan(2, 10) + def test_my_string_op_const_string_ref(self, device): + """Test my_string_op_const_string_ref which uses const std::string& parameters.""" + import libtorch_agnostic_2_10 as libtorch_agnostic + + t = torch.empty(3, 4, 5, device=device) + + dim_vec, result_dim = libtorch_agnostic.ops.my_string_op_const_string_ref( + t, "dim", "test1" + ) + self.assertEqual(dim_vec, ["dim", str(t.dim()), "test1"]) + self.assertEqual(result_dim, t.dim()) + + size_vec, result_size = libtorch_agnostic.ops.my_string_op_const_string_ref( + t, "size", "test2" + ) + self.assertEqual(size_vec, ["size", str(t.size(0)), "test2"]) + self.assertEqual(result_size, t.size(0)) + + @skipIfTorchVersionLessThan(2, 10) + def test_my_string_op_const_string_view_ref(self, device): + """Test my_string_op_const_string_view_ref which uses const std::string_view& parameters.""" + import libtorch_agnostic_2_10 as libtorch_agnostic + + t = torch.empty(3, 4, 5, device=device) + + dim_vec, result_dim = ( + libtorch_agnostic.ops.my_string_op_const_string_view_ref( + t, "dim", "view1" + ) + ) + self.assertEqual(dim_vec, ["dim", str(t.dim()), "view1"]) + self.assertEqual(result_dim, t.dim()) + + stride_vec, result_stride = ( + libtorch_agnostic.ops.my_string_op_const_string_view_ref( + t, "stride", "view2" + ) + ) + self.assertEqual(stride_vec, ["stride", str(t.stride(0)), "view2"]) + self.assertEqual(result_stride, t.stride(0)) + + @skipIfTorchVersionLessThan(2, 10) + def test_my_string_op_string_ref(self, device): + """Test my_string_op_string_ref which uses std::string& (non-const) parameters.""" + import libtorch_agnostic_2_10 as libtorch_agnostic + + t = torch.empty(3, 4, 5, device=device) + + dim_vec, result_dim = libtorch_agnostic.ops.my_string_op_string_ref( + t, "dim", "ref1" + ) + self.assertEqual(dim_vec, ["dim", str(t.dim()), "ref1"]) + self.assertEqual(result_dim, t.dim()) + + size_vec, result_size = libtorch_agnostic.ops.my_string_op_string_ref( + t, "size", "ref2" + ) + self.assertEqual(size_vec, ["size", str(t.size(0)), "ref2"]) + self.assertEqual(result_size, t.size(0)) + @skipIfTorchVersionLessThan(2, 10) @onlyCUDA def test_my_get_current_cuda_stream(self, device): diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h index ac6d252f757a1..39377cad87437 100644 --- a/torch/csrc/stable/library.h +++ b/torch/csrc/stable/library.h @@ -136,8 +136,11 @@ struct UnboxType { using type = std::string; }; +// const and reference are stripped before UnboxType is applied +// in order to avoid ambiguous template matches template -using unbox_type_t = typename UnboxType::type; +using unbox_type_t = + typename UnboxType>>::type; template std::tuple unbox_to_tuple_impl( From 9f7c46bbb85e112e92de53bab6dbc2ceb1af53da Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Thu, 4 Dec 2025 11:39:49 -0800 Subject: [PATCH 310/338] [Dynamo][Guard]Add the user-friendly TYPE_MATCH for type (#169025) Fix #168160 after the opensource PR https://github.com/pytorch/pytorch/pull/168272 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169025 Approved by: https://github.com/anijain2305 --- test/dynamo/test_check_type_id.py | 139 ++++++++++++++++++++++++++++++ torch/_dynamo/guards.py | 10 ++- 2 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 test/dynamo/test_check_type_id.py diff --git a/test/dynamo/test_check_type_id.py b/test/dynamo/test_check_type_id.py new file mode 100644 index 0000000000000..3d9c5efb38c1a --- /dev/null +++ b/test/dynamo/test_check_type_id.py @@ -0,0 +1,139 @@ +# Owner(s): ["module: dynamo"] +""" +Test for TYPE_MATCH guard and ___check_type_id function. + +This test demonstrates how the TYPE_MATCH guard works in PyTorch Dynamo. +When a function is compiled, Dynamo installs guards to ensure the compiled +code remains valid. TYPE_MATCH guards ensure that values maintain their +exact type (using type identity, not just type equality). +""" + +import re + +import torch +import torch._dynamo +import torch._dynamo.test_case +from torch._dynamo.eval_frame import _debug_get_cache_entry_list +from torch.testing._internal.common_utils import munge_exc + + +class TestCheckTypeId(torch._dynamo.test_case.TestCase): + @staticmethod + def _find_guard_lines(guard_manager_str: str, keyword: str) -> list[str]: + # Normalize and anonymize type IDs, then return lines containing the keyword + normalized = re.sub( + r"\d{7,}", "", munge_exc(guard_manager_str), flags=re.MULTILINE + ) + pattern = re.compile(rf"^.*{re.escape(keyword)}.*$", re.MULTILINE) + return pattern.findall(normalized) + + def test_type_match_with_different_values(self): + """ + Test that TYPE_MATCH guard correctly identifies type mismatches. + + This test compiles a function that uses a global variable and verifies: + 1. The compiled function works with values of the same type + 2. The function recompiles when the type changes + 3. The ___check_type_id/check_obj_id guard is present in the generated code + 4. The check_type_id should present the user-friendly code that specify the type + """ + + # Define a global variable that we'll guard on + class Config: + multiplier = 2 # int type + + def fn(x): + # This will trigger a TYPE_MATCH guard on Config.multiplier + return x * Config.multiplier + + # Compile the function + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + # First call - should compile and install guards + x = torch.randn(4) + result1 = opt_fn(x) + expected1 = x * 2 + self.assertTrue(torch.allclose(result1, expected1)) + + # Get the cache entry to inspect guards + cache_entries = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(cache_entries), 1) + + # Check that the guard string contains check_type_id + guard_str = str(cache_entries[0].guard_manager) + matches = self._find_guard_lines(guard_str, "ID_MATCH") + self.assertIn("___check_obj_id", matches[0]) + self.assertIn( + "type=.Config'>", + matches[0], + ) + # Match the first part (everything before "type=") + first_part = matches[0].split("type=")[0] + expected_first_part = ( + "| | +- ID_MATCH: ___check_obj_id(L['Config'], ), " + ) + self.assertEqual(first_part, expected_first_part) + + # Match the second part (the type string) + second_part = matches[0].split("type=")[1].rstrip() + expected_second_part = ( + "TestCheckTypeId.test_type_match_with_different_values..Config'>" + ) + self.assertIn(expected_second_part, second_part) + + def test_type_match_with_custom_classes(self): + """ + Test TYPE_MATCH guard with custom class instances. + + Demonstrates that the guard checks type identity, not structural equality. + """ + + class Point: + def __init__(self, x, y): + self.x = x + self.y = y + + class Point2D: + def __init__(self, x, y): + self.x = x + self.y = y + + point = Point(1, 2) + + def fn(tensor): + # Access point's attributes, triggering TYPE_MATCH guard on point + return tensor + point.x + point.y + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + # First call with Point instance + x = torch.ones(4) + result1 = opt_fn(x) + expected1 = x + 1 + 2 + self.assertTrue(torch.allclose(result1, expected1)) + + # Verify guard contains check_type_id + cache_entries = _debug_get_cache_entry_list(fn.__code__) + self.assertEqual(len(cache_entries), 1) + + guard_str = str(cache_entries[0].guard_manager) + matches = self._find_guard_lines(guard_str, "TYPE_MATCH") + # Match the first part (everything before "type=") + first_part = matches[0].split("type=")[0] + expected_first_part = ( + "| | +- TYPE_MATCH: ___check_type_id(L['point'], ), " + ) + self.assertEqual(first_part, expected_first_part) + + # Match the second part (the type string) + second_part = matches[0].split("type=")[1].rstrip() + expected_second_part = ( + "TestCheckTypeId.test_type_match_with_custom_classes..Point'>" + ) + self.assertIn(expected_second_part, second_part) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index e9097c592af9f..ea720d5c49f5f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1958,7 +1958,7 @@ def TYPE_MATCH(self, guard: Guard) -> None: obj_id = self.id_ref(t, f"type({guard.name})") type_repr = repr(t) - code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}) # {type_repr}" + code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}), type={type_repr}" self._set_guard_export_info(guard, [code]) self.get_guard_manager(guard).add_type_match_guard( @@ -2060,7 +2060,13 @@ def id_match_unchecked( ref = self.arg_ref(guard) val = self.get(guard) id_val = self.id_ref(val, guard.name) - code = f"___check_obj_id({ref}, {id_val})" + try: + type_repr = repr(val) + except Exception: + # During deepcopy reconstruction or other state transitions, + # objects may be in an incomplete state where repr() fails + type_repr = f"<{type(val).__name__}>" + code = f"___check_obj_id({ref}, {id_val}), type={type_repr}" self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH") self.get_guard_manager(guard).add_id_match_guard( id_val, get_verbose_code_parts(code, guard, recompile_hint) From 79e8d6a8674b30e69c2c4773c5efe75d6696980c Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 5 Dec 2025 02:28:47 +0000 Subject: [PATCH 311/338] [12/N] Use Python 3.10 typing (#169355) This PR applies Python 3.10 typing syntax to some files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169355 Approved by: https://github.com/Lucaskabela, https://github.com/albanD --- torch/distributed/_pycute/int_tuple.py | 16 +++++----- torch/distributed/_pycute/layout.py | 31 ++++++++++--------- .../_shard/sharding_spec/_internals.py | 7 ++--- .../sharding_spec/chunk_sharding_spec.py | 10 +++--- torch/multiprocessing/reductions.py | 5 +-- torch/multiprocessing/spawn.py | 7 ++--- torch/utils/_cxx_pytree.py | 6 ++-- torch/utils/_debug_mode.py | 12 +++---- 8 files changed, 42 insertions(+), 52 deletions(-) diff --git a/torch/distributed/_pycute/int_tuple.py b/torch/distributed/_pycute/int_tuple.py index b060edde22817..bb3406a7399b1 100644 --- a/torch/distributed/_pycute/int_tuple.py +++ b/torch/distributed/_pycute/int_tuple.py @@ -36,14 +36,14 @@ from functools import reduce from itertools import chain -from typing import Optional, TypeAlias, Union +from typing import TypeAlias from typing_extensions import TypeIs from .typing import Integer # Type aliases for better readability -IntTuple: TypeAlias = Union[int, tuple["IntTuple", ...]] +IntTuple: TypeAlias = int | tuple["IntTuple", ...] def is_int(x: object) -> TypeIs[int]: @@ -168,9 +168,7 @@ def suffix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple: return init -def idx2crd( - idx: IntTuple, shape: IntTuple, stride: Optional[IntTuple] = None -) -> IntTuple: +def idx2crd(idx: IntTuple, shape: IntTuple, stride: IntTuple | None = None) -> IntTuple: if stride is None: stride = suffix_product(shape) @@ -190,7 +188,7 @@ def idx2crd( def crd2idx( - crd: Optional[IntTuple], shape: IntTuple, stride: Optional[IntTuple] = None + crd: IntTuple | None, shape: IntTuple, stride: IntTuple | None = None ) -> int: if stride is None: stride = suffix_product(shape) @@ -222,7 +220,7 @@ def crd2idx( # Transform crd into the dst_shape's iteration space def crd2crd( - crd: IntTuple, dst_shape: IntTuple, src_shape: Optional[IntTuple] = None + crd: IntTuple, dst_shape: IntTuple, src_shape: IntTuple | None = None ) -> IntTuple: if is_tuple(crd): if is_tuple(dst_shape): # tuple tuple @@ -241,7 +239,7 @@ def crd2crd( # Filter trg according to crd: keep only elements of trg that are paired with None -def slice_(crd: Union[None, tuple, int], trg: Union[tuple, int]) -> Union[tuple, int]: +def slice_(crd: None | tuple | int, trg: tuple | int) -> tuple | int: if is_tuple(crd): if is_tuple(trg): # tuple tuple assert len(crd) == len(trg) @@ -264,7 +262,7 @@ def slice_(crd: Union[None, tuple, int], trg: Union[tuple, int]) -> Union[tuple, # Determine if None appears at any of an int_tuples' terminals -def has_none(a: Union[None, tuple, int]) -> bool: +def has_none(a: None | tuple | int) -> bool: if is_tuple(a): return any(has_none(v) for v in a) else: diff --git a/torch/distributed/_pycute/layout.py b/torch/distributed/_pycute/layout.py index 04ae5d1fa5fdb..0adf94b5b142b 100644 --- a/torch/distributed/_pycute/layout.py +++ b/torch/distributed/_pycute/layout.py @@ -36,8 +36,8 @@ """ from itertools import chain -from typing import Optional, TypeAlias, Union -from typing_extensions import TypeIs +from typing import TypeAlias +from typing_extensions import Self, TypeIs from .int_tuple import ( crd2idx, @@ -53,12 +53,9 @@ # Type aliases -LayoutOrIntTuple: TypeAlias = Union["Layout", IntTuple] -LayoutProfile: TypeAlias = Optional[Union[tuple[object, ...], "Layout"]] -LayoutInput: TypeAlias = Optional[Union["Layout", IntTuple, tuple[object, ...]]] -CoordinateType: TypeAlias = Optional[ - Union[int, IntTuple, tuple[object, ...]] -] # Input for slice_ and crd2idx functions +CoordinateType: TypeAlias = ( + int | IntTuple | tuple[object, ...] | None +) # Input for slice_ and crd2idx functions class LayoutBase: @@ -70,7 +67,7 @@ def is_layout(x: object) -> TypeIs["Layout"]: class Layout(LayoutBase): - def __init__(self, _shape: IntTuple, _stride: Optional[IntTuple] = None) -> None: + def __init__(self, _shape: IntTuple, _stride: IntTuple | None = None) -> None: self.shape = _shape if _stride is None: self.stride = suffix_product(self.shape) @@ -91,7 +88,7 @@ def __len__(self) -> int: return 1 # operator () (map coord to idx) - def __call__(self, *args: CoordinateType) -> Union["Layout", int]: + def __call__(self, *args: CoordinateType) -> Self | int: """ Map a logical coordinate to a linear index (Coord has no Underscore slice operators) OR @@ -111,7 +108,7 @@ def __call__(self, *args: CoordinateType) -> Union["Layout", int]: return crd2idx(args, self.shape, self.stride) # type: ignore[arg-type] # operator [] (get-i like tuples) - def __getitem__(self, i: int) -> "Layout": + def __getitem__(self, i: int) -> Self: if is_tuple(self.shape): return Layout(self.shape[i], self.stride[i]) # type: ignore[index] else: @@ -135,8 +132,14 @@ def __repr__(self) -> str: return f"Layout({self.shape},{self.stride})" +# Type aliases +LayoutOrIntTuple: TypeAlias = Layout | IntTuple +LayoutProfile: TypeAlias = tuple[object, ...] | Layout | None +LayoutInput: TypeAlias = Layout | IntTuple | tuple[object, ...] | None + + # Make Layout from a list of layouts (each layout it's own mode in the result) -def make_layout(*layouts: Union[Layout, tuple[Layout, ...]]) -> Layout: +def make_layout(*layouts: Layout | tuple[Layout, ...]) -> Layout: if len(layouts) == 1 and not is_layout(layouts[0]): layouts = layouts[0] @@ -321,7 +324,7 @@ def complement(layout: LayoutOrIntTuple, max_idx: int = 1) -> Layout: # Layout right inverse -def right_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]: +def right_inverse(layout: LayoutOrIntTuple | None) -> Layout | None: if layout is None: return None elif is_int(layout): @@ -350,7 +353,7 @@ def right_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]: # Layout left inverse -def left_inverse(layout: Optional[LayoutOrIntTuple]) -> Optional[Layout]: +def left_inverse(layout: LayoutOrIntTuple | None) -> Layout | None: if layout is None: return None elif is_int(layout): diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index 9825edd352c1f..486c62a18cd7b 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -2,7 +2,6 @@ import math import sys from bisect import bisect_right, insort -from typing import Optional from torch.distributed._shard.metadata import ShardMetadata @@ -28,7 +27,7 @@ def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetad def _find_nd_overlapping_shards( shards: list[ShardMetadata], sharded_dims: list[int] -) -> Optional[tuple[int, int]]: +) -> tuple[int, int] | None: """Find overlapping shards using sweep-line algorithm.""" if len(shards) <= 1: return None @@ -76,7 +75,7 @@ def _find_nd_overlapping_shards( def _find_1d_overlapping_shards( shards: list[ShardMetadata], dim: int -) -> Optional[tuple[int, int]]: +) -> tuple[int, int] | None: # (begin, end, index_in_shards). Begin and end are inclusive. intervals = [ (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i) @@ -112,7 +111,7 @@ def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]): sharded_dims.append(dim) break - pair: Optional[tuple[int, int]] = None + pair: tuple[int, int] | None = None if len(sharded_dims) == 0: # if shard is all zeros, we should consider as pass all_zeros: bool = all( diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py index d4cd5728b2a16..4d7b11b7c16c5 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from dataclasses import dataclass -from typing import cast, Optional, TYPE_CHECKING, Union +from typing import cast, TYPE_CHECKING import torch import torch.distributed as dist @@ -50,10 +50,10 @@ class ChunkShardingSpec(ShardingSpec): :class:`torch.distributed._remote_device` """ - ShardingDim = Union[int, str] + ShardingDim = int | str dim: ShardingDim - placements: list[Union[torch.distributed._remote_device, str]] + placements: list[torch.distributed._remote_device | str] def __post_init__(self): self._verify_dim(self.dim) @@ -134,7 +134,7 @@ def shard( local_metadata = None tensors_to_scatter = cast( - list[Optional[torch.Tensor]], + list[torch.Tensor | None], [None] * dist.get_world_size(process_group), ) @@ -196,7 +196,7 @@ def shard( process_group, src_for_scatter ) - tensors_to_scatter_: Optional[list[torch.Tensor]] = None + tensors_to_scatter_: list[torch.Tensor] | None = None if current_rank == src_rank: tensors_to_scatter_ = [] for t in tensors_to_scatter: diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index cbd6eee571f13..b0b6d9468bcb5 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -4,7 +4,6 @@ import threading from multiprocessing import reduction from multiprocessing.util import register_after_fork -from typing import Union import torch from torch._namedtensor_internals import check_serializing_named_tensor @@ -551,9 +550,7 @@ def rebuild_storage_fd(cls, df, size): def rebuild_storage_filename(cls, manager, handle, size, dtype=None): - storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache( - cls, handle - ) + storage: torch.TypedStorage | torch.UntypedStorage = storage_from_cache(cls, handle) if storage is not None: return storage._shared_decref() if dtype is None: diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index f53be2ebe0392..83cfea4b80d33 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -10,7 +10,6 @@ import time import warnings from concurrent.futures import as_completed, ThreadPoolExecutor -from typing import Optional from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] @@ -66,7 +65,7 @@ def __init__( error_index: int, error_pid: int, exit_code: int, - signal_name: Optional[str] = None, + signal_name: str | None = None, ): super().__init__(msg, error_index, error_pid) self.exit_code = exit_code @@ -118,9 +117,7 @@ def _join_procs_with_timeout(self, timeout: float): time_to_wait = max(0, end - time.monotonic()) process.join(time_to_wait) - def join( - self, timeout: Optional[float] = None, grace_period: Optional[float] = None - ): + def join(self, timeout: float | None = None, grace_period: float | None = None): r"""Join one or more processes within spawn context. Attempt to join one or more processes in this spawn context. diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index e88209398302b..3c6f79bfe2243 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -16,7 +16,7 @@ import sys import types from collections.abc import Callable, Iterable, Mapping -from typing import Any, overload, TypeAlias, TypeVar, Union +from typing import Any, overload, TypeAlias, TypeVar from typing_extensions import deprecated, Self, TypeIs import torch.utils._pytree as python_pytree @@ -270,7 +270,7 @@ def _private_register_pytree_node( def _is_pytreespec_instance( obj: Any, /, -) -> TypeIs[Union[TreeSpec, python_pytree.PyTreeSpec]]: +) -> TypeIs[TreeSpec | python_pytree.PyTreeSpec]: if isinstance(obj, (TreeSpec, python_pytree.PyTreeSpec)): return True if "torch._dynamo.polyfills.pytree" in sys.modules: @@ -612,7 +612,7 @@ def tree_map_( Type2 = tuple[type[T], type[S]] Type3 = tuple[type[T], type[S], type[U]] -TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] +TypeAny = type[Any] | tuple[type[Any], ...] | types.UnionType Fn2 = Callable[[T | S], R] Fn3 = Callable[[T | S | U], R] diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index abe9f6aa59ae1..d579e957e0234 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -39,7 +39,7 @@ import traceback import weakref from collections.abc import Callable -from typing import Any, TYPE_CHECKING, Union # noqa: F401 +from typing import Any, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode @@ -157,9 +157,7 @@ def to_str(x): return str(arg) -def norm_hash_fn( - t: torch.Tensor, use_scalar: bool = False -) -> Union[torch.Tensor, float]: +def norm_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor | float: """ from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous, replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128. @@ -188,9 +186,7 @@ def _compute_rel_diff(hash1, hash2): return numerator / denominator -def hash_tensor_fn( - t: torch.Tensor, use_scalar: bool = False -) -> Union[torch.Tensor, int]: +def hash_tensor_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor | int: """ wrapper over torch.hash_tensor """ @@ -933,7 +929,7 @@ def dispatch_hook(func, types, args, kwargs, result): @staticmethod @contextlib.contextmanager def log_tensor_hashes( - hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False + hash_fn: Callable | str | list[str] = "norm", hash_inputs: bool = False ): """ Installs hook for tensor hash logging. From 2169e666d2424a84c266b694c9f97a42f786e30e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 5 Dec 2025 02:38:01 +0000 Subject: [PATCH 312/338] Revert "[Inductor] ReLU/GELU(Addmm) fusions (#168157)" This reverts commit 2da3bafb30b91de44fba1b9aecce1147cb64e679. Reverted https://github.com/pytorch/pytorch/pull/168157 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it fails test_linear_relu ([comment](https://github.com/pytorch/pytorch/pull/168157#issuecomment-3615072958)) --- test/inductor/test_pattern_matcher.py | 68 +------------------ torch/_inductor/fx_passes/post_grad.py | 63 ----------------- .../serialized_patterns/addmm_gelu_pattern.py | 43 ------------ .../serialized_patterns/addmm_relu_pattern.py | 35 ---------- torchgen/fuse/gen_patterns.py | 3 +- 5 files changed, 3 insertions(+), 209 deletions(-) delete mode 100644 torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py delete mode 100644 torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index f7e795f53f90d..9928b89b81e64 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1216,70 +1216,6 @@ def fn2(inp, a, b): _, (code) = run_and_get_code(fn2, args[0], args[1], args[2]) FileCheck().check_not("extern_kernels.addmm(").run(code[0]) - @skipIfRocm - def test_addmm_activation_fusion(self): - """ - Test whether Activation(Addmm) implies _addmm_activation - """ - - b = torch.rand(4, device=GPU_TYPE) - m1 = torch.rand(3, 2, device=GPU_TYPE) - m2 = torch.rand(2, 4, device=GPU_TYPE) - alphas = ({"alpha": 0.8}, {}) # **{} -> alpha=1 - betas = ({"beta": 1}, {}) # **{} -> beta=1 - - # Cases Activation(Addmm) -> _addmm_activation - fusable_activations = ( - torch.nn.functional.relu, - # NOTE: only approximate="tanh" is fusable - lambda *args, **kwargs: torch.nn.functional.gelu( - *args, approximate="tanh", **kwargs - ), - ) - for activation in fusable_activations: - - def f(b, m1, m2, beta, alpha): - return activation(torch.addmm(b, m1, m2, **beta, **alpha)) - - fc = torch.compile(f) - - for beta, alpha in itertools.product(betas, alphas): - expected = f(b, m1, m2, beta, alpha) - actual = fc(b, m1, m2, beta, alpha) - torch.testing.assert_close(expected, actual) - - _, (code) = run_and_get_code(fc, b, m1, m2, beta, alpha) - self.assertIn("_addmm_activation", code[0]) - - # Check no disruptions in the gemm autotune process - _, (code) = run_and_get_code( - torch.compile(f, options={"max_autotune_gemm": True}), - b, - m1, - m2, - beta, - alpha, - ) - self.assertNotIn("_addmm_activation", code[0]) - - # Cases Activation(Addmm) -> Activation(Addmm) - non_fusable_activations = ( - torch.nn.functional.gelu, # implies approximate="none" - lambda *args, **kwargs: torch.nn.functional.gelu( - *args, approximate="none", **kwargs - ), - ) - for activation in non_fusable_activations: - - def f(b, m1, m2, beta, alpha): - return activation(torch.addmm(b, m1, m2, **beta, **alpha)) - - fc = torch.compile(f) - - for beta, alpha in itertools.product(betas, alphas): - _, (code) = run_and_get_code(fc, b, m1, m2, beta, alpha) - self.assertNotIn("_addmm_activation", code[0]) - def test_addmm_alpha_beta_with_pointwise(self): # Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops # See https://github.com/pytorch/pytorch/issues/167313 @@ -1288,7 +1224,7 @@ def test_addmm_alpha_beta_with_pointwise(self): b = torch.rand(3, 2, device=GPU_TYPE) def f(x, a, b): - return torch.abs(torch.addmm(x, a, b, alpha=0.8, beta=0.2)) + return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2)) fc = torch.compile(f) @@ -1305,7 +1241,7 @@ def f(x, a, b): # Test with alpha=1, beta=1 (default) - should also unfuse def f_default(x, a, b): - return torch.abs(torch.addmm(x, a, b)) + return torch.nn.functional.relu(torch.addmm(x, a, b)) fc_default = torch.compile(f_default) expected_default = f_default(x, a, b) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 2f73bf6ae86c4..a21e78821e52b 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -34,7 +34,6 @@ CallFunctionVarArgs, filter_nodes, fwd_only, - gen_register_replacement, get_arg_value, get_mutation_region_id, Ignored, @@ -689,66 +688,6 @@ def body_fn(*flat_args): raise AssertionError("scan is not lowered to while_loop") -@functools.cache -def register_addmm_activation_fusions(): - def is_valid_addmm_activation_fusion(match: Match) -> bool: - # Exclude ROCm - if torch.version.hip: - return False - - if config.max_autotune_gemm: - return False - - inp = match.kwargs["inp"].meta["val"] - - if not inp.is_cuda: - return False - - output = match.output_node() - return not all( - is_pointwise_use(use, lambda target: torch.Tag.reduction in target.tags) - for use in output.users - ) - - args = [torch.empty(3), torch.empty(4, 2), torch.empty(2, 3)] - beta_alpha_workaround = {"beta": 1.3, "alpha": 1.2} - - def addmm_relu_pattern(inp, m1, m2, beta, alpha): - return aten.relu(aten.addmm(inp, m1, m2, beta=beta, alpha=alpha)) - - def addmm_gelu_pattern(inp, m1, m2, beta, alpha): - return aten.gelu( - aten.addmm(inp, m1, m2, beta=beta, alpha=alpha), approximate="tanh" - ) - - def addmm_relu_replacement(inp, m1, m2, beta, alpha): - return aten._addmm_activation(inp, m1, m2, beta=beta, alpha=alpha) - - def addmm_gelu_replacement(inp, m1, m2, beta, alpha): - return aten._addmm_activation( - inp, m1, m2, beta=beta, alpha=alpha, use_gelu=True - ) - - patterns = (addmm_relu_pattern, addmm_gelu_pattern) - replacements = (addmm_relu_replacement, addmm_gelu_replacement) - for pattern, replacement in zip(patterns, replacements): - key = f"{pattern.__name__}" - gen_register_replacement( - key, - # pyrefly: ignore [bad-argument-type] - pattern, - # pyrefly: ignore [bad-argument-type] - replacement, - args, - # pyrefly: ignore [bad-argument-type] - trace_fn=fwd_only, - # pyrefly: ignore [bad-argument-type] - pass_dicts=pass_patterns[1], - extra_check=is_valid_addmm_activation_fusion, - scalar_workaround=beta_alpha_workaround, - ) - - @init_once_fakemode def lazy_init(): if torch._C._has_mkldnn: @@ -774,8 +713,6 @@ def lazy_init(): extra_check=prepare_softmax_extra_check, ) - register_addmm_activation_fusions() - def reorder_for_locality(graph: torch.fx.Graph): if torch.distributed.is_available(): diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py deleted file mode 100644 index f991015b4de69..0000000000000 --- a/torch/_inductor/fx_passes/serialized_patterns/addmm_gelu_pattern.py +++ /dev/null @@ -1,43 +0,0 @@ -# mypy: ignore-errors - -# noqa: F401, E501 -# This is an auto-generated file. Please do not modify it by hand. -# To re-generate, run: -# cd ~/pytorch && python torchgen/fuse/gen_patterns.py - -import torch -import torch._inductor -import operator - -aten = torch.ops.aten -prims = torch.ops.prims - -from torch._inductor.pattern_matcher import ( - Arg, - CallFunction, - CallFunctionVarArgs, - CallMethod, - CallMethodVarArgs, - CallModule, - CallModuleVarArgs, - ExclusiveKeywordArg, - Ignored, - KeywordArg, - ListOf, - MultiOutputPattern, - PatternExpr, - RepeatedExpr, - _TargetArgsExpr, - _TargetExpr, - _TargetExprVarArgs, -) -addmm_default = CallFunction(aten.addmm.default, KeywordArg('inp'), KeywordArg('m1'), KeywordArg('m2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=4) -mul_Tensor = CallFunction(aten.mul.Tensor, addmm_default, Ignored()) -mul_Tensor_1 = CallFunction(aten.mul.Tensor, addmm_default, addmm_default) -mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, addmm_default) -mul_Tensor_3 = CallFunction(aten.mul.Tensor, mul_Tensor_2, Ignored()) -add_Tensor = CallFunction(aten.add.Tensor, addmm_default, mul_Tensor_3) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, add_Tensor, Ignored()) -tanh_default = CallFunction(aten.tanh.default, mul_Tensor_4) -add_Tensor_1 = CallFunction(aten.add.Tensor, tanh_default, Ignored()) -addmm_gelu_pattern = CallFunction(aten.mul.Tensor, mul_Tensor, add_Tensor_1, _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py b/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py deleted file mode 100644 index e9729a7787131..0000000000000 --- a/torch/_inductor/fx_passes/serialized_patterns/addmm_relu_pattern.py +++ /dev/null @@ -1,35 +0,0 @@ -# mypy: ignore-errors - -# noqa: F401, E501 -# This is an auto-generated file. Please do not modify it by hand. -# To re-generate, run: -# cd ~/pytorch && python torchgen/fuse/gen_patterns.py - -import torch -import torch._inductor -import operator - -aten = torch.ops.aten -prims = torch.ops.prims - -from torch._inductor.pattern_matcher import ( - Arg, - CallFunction, - CallFunctionVarArgs, - CallMethod, - CallMethodVarArgs, - CallModule, - CallModuleVarArgs, - ExclusiveKeywordArg, - Ignored, - KeywordArg, - ListOf, - MultiOutputPattern, - PatternExpr, - RepeatedExpr, - _TargetArgsExpr, - _TargetExpr, - _TargetExprVarArgs, -) -addmm_default = CallFunction(aten.addmm.default, KeywordArg('inp'), KeywordArg('m1'), KeywordArg('m2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) -addmm_relu_pattern = CallFunction(aten.relu.default, addmm_default, _users=0) diff --git a/torchgen/fuse/gen_patterns.py b/torchgen/fuse/gen_patterns.py index b4bdf022202ba..0861c882e3fff 100644 --- a/torchgen/fuse/gen_patterns.py +++ b/torchgen/fuse/gen_patterns.py @@ -2,7 +2,7 @@ import os from torch._inductor import pattern_matcher -from torch._inductor.fx_passes import joint_graph, post_grad +from torch._inductor.fx_passes import joint_graph if __name__ == "__main__": @@ -17,4 +17,3 @@ # to serialize the patterns as it goes. os.environ["PYTORCH_GEN_PATTERNS"] = "1" joint_graph.lazy_init() - post_grad.lazy_init() From 053cbc1d694ae10b3ace786b024d5f3ee7033116 Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 4 Dec 2025 13:59:32 -0800 Subject: [PATCH 313/338] [dynamo] assign random IDs to graph break registry (#169617) This should help reduce merge conflicts and revert pain. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169617 Approved by: https://github.com/anijain2305 --- tools/dynamo/gb_id_mapping.py | 20 +++++- tools/linter/adapters/gb_registry_linter.py | 34 ++++++++++ tools/test/test_gb_registry_linter.py | 71 ++++++++++++++++++++- 3 files changed, 120 insertions(+), 5 deletions(-) diff --git a/tools/dynamo/gb_id_mapping.py b/tools/dynamo/gb_id_mapping.py index f7ec2347ba92e..ede34ce642090 100644 --- a/tools/dynamo/gb_id_mapping.py +++ b/tools/dynamo/gb_id_mapping.py @@ -1,6 +1,7 @@ import argparse import ast import json +import random import re from pathlib import Path from typing import Any @@ -23,8 +24,23 @@ def save_registry(reg: dict[str, Any], path: Path) -> None: def next_gb_id(reg: dict[str, Any]) -> str: - ids = [int(x[2:]) for x in reg if x.startswith("GB") and x[2:].isdigit()] - return f"GB{(max(ids, default=-1) + 1):04d}" + """Generate a random unused GB ID from GB0000-GB9999 range.""" + used_ids = set(reg.keys()) + max_attempts = 100 + + # Try random selection first + for _ in range(max_attempts): + candidate = f"GB{random.randint(0, 9999):04d}" + if candidate not in used_ids: + return candidate + + # Fallback: find first available ID if random selection keeps colliding + for i in range(10000): + candidate = f"GB{i:04d}" + if candidate not in used_ids: + return candidate + + raise RuntimeError("No available GB IDs in range GB0000-GB9999") def clean_string(s: Any) -> Any: diff --git a/tools/linter/adapters/gb_registry_linter.py b/tools/linter/adapters/gb_registry_linter.py index ac6bfc3264d51..442ca2a3f5fae 100644 --- a/tools/linter/adapters/gb_registry_linter.py +++ b/tools/linter/adapters/gb_registry_linter.py @@ -181,6 +181,40 @@ def check_registry_sync(dynamo_dir: Path, registry_path: Path) -> list[LintMessa calls = {gb_type: calls[0] for gb_type, calls in all_calls.items()} registry = load_registry(registry_path) + + # Check for duplicate gb_types across different GB IDs in the registry + gb_type_to_ids: dict[str, list[str]] = {} + for gb_id, entries in registry.items(): + gb_type = entries[0]["Gb_type"] + if gb_type not in gb_type_to_ids: + gb_type_to_ids[gb_type] = [] + gb_type_to_ids[gb_type].append(gb_id) + + duplicate_gb_types_in_registry = [ + (gb_type, ids) for gb_type, ids in gb_type_to_ids.items() if len(ids) > 1 + ] + + if duplicate_gb_types_in_registry: + for gb_type, ids in duplicate_gb_types_in_registry: + description = ( + f"The gb_type '{gb_type}' appears in multiple GB IDs: {', '.join(sorted(ids))}. " + f"Each gb_type must map to exactly one GB ID. Please manually fix the registry." + ) + lint_messages.append( + LintMessage( + path=str(registry_path), + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.ERROR, + name="Duplicate gb_type in registry", + original=None, + replacement=None, + description=description, + ) + ) + return lint_messages + latest_entry: dict[str, Any] = { entries[0]["Gb_type"]: entries[0] for entries in registry.values() } diff --git a/tools/test/test_gb_registry_linter.py b/tools/test/test_gb_registry_linter.py index 837e5910a4abb..dd619deaa1b7c 100644 --- a/tools/test/test_gb_registry_linter.py +++ b/tools/test/test_gb_registry_linter.py @@ -51,8 +51,13 @@ def test_case1_new_gb_type(self): messages = check_registry_sync(self.test_data_dir, self.registry_path) + # Parse the replacement to get the actual GB ID that was generated + self.assertEqual(len(messages), 1) + replacement_registry = json.loads(messages[0].replacement) + gb_id = next(iter(replacement_registry.keys())) + expected_registry = { - "GB0000": [ + gb_id: [ { "Gb_type": "testing", "Context": "testing", @@ -271,6 +276,12 @@ def test(self): original_content = f.read() messages = check_registry_sync(self.test_data_dir, self.registry_path) + + # Parse the replacement to get the actual GB ID that was generated + self.assertEqual(len(messages), 1) + replacement_registry = json.loads(messages[0].replacement) + new_gb_id = next(k for k in replacement_registry if k != "GB0000") + expected_registry = { "GB0000": [ { @@ -280,7 +291,7 @@ def test(self): "Hints": ["original_hint"], } ], - "GB0001": [ + new_gb_id: [ { "Gb_type": "completely_new_testing", "Context": "completely_new_context", @@ -349,8 +360,13 @@ def test(self): messages = check_registry_sync(self.test_data_dir, self.registry_path) + # Parse the replacement to get the actual GB ID that was generated + self.assertEqual(len(messages), 1) + replacement_registry = json.loads(messages[0].replacement) + gb_id = next(iter(replacement_registry.keys())) + expected_registry = { - "GB0000": [ + gb_id: [ { "Gb_type": "testing_with_graph_break_hints", "Context": "testing_with_graph_break_hints", @@ -392,6 +408,55 @@ def test(self): mock_hints_file.unlink() init_py.unlink() + def test_case7_duplicate_gb_type_in_registry(self): + """Test Case 7: Detecting duplicate gb_types across different GB IDs in the registry.""" + registry_data = { + "GB0000": [ + { + "Gb_type": "duplicate_type", + "Context": "context1", + "Explanation": "explanation1", + "Hints": ["hint1"], + } + ], + "GB0042": [ + { + "Gb_type": "duplicate_type", + "Context": "context2", + "Explanation": "explanation2", + "Hints": ["hint2"], + } + ], + } + with open(self.registry_path, "w") as f: + json.dump(registry_data, f, indent=2) + + # Create a callsite with one of the duplicate types + callsite_content = """from torch._dynamo.exc import unimplemented +def test(self): + unimplemented(gb_type="duplicate_type", context="context1", explanation="explanation1", hints=["hint1"]) +""" + with open(self.callsite_file, "w") as f: + f.write(callsite_content) + + messages = check_registry_sync(self.test_data_dir, self.registry_path) + + expected_msg = LintMessage( + path=str(self.registry_path), + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.ERROR, + name="Duplicate gb_type in registry", + original=None, + replacement=None, + description=( + "The gb_type 'duplicate_type' appears in multiple GB IDs: GB0000, GB0042. " + "Each gb_type must map to exactly one GB ID. Please manually fix the registry." + ), + ) + self.assertEqual(messages, [expected_msg]) + if __name__ == "__main__": unittest.main() From 240780a8bf14a4abe0f88053ae52a1502d0a5697 Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 4 Dec 2025 15:25:20 -0800 Subject: [PATCH 314/338] [dynamo] add new graph break registry entries via linter to random spots (#169624) Randomize graph break registry linter placement to further reduce merge conflicts. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169624 Approved by: https://github.com/anijain2305 ghstack dependencies: #169617 --- tools/dynamo/gb_id_mapping.py | 3 +- tools/linter/adapters/gb_registry_linter.py | 27 +++++++++++++ tools/test/test_gb_registry_linter.py | 42 +++++++++++---------- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/tools/dynamo/gb_id_mapping.py b/tools/dynamo/gb_id_mapping.py index ede34ce642090..d5da4703dcd19 100644 --- a/tools/dynamo/gb_id_mapping.py +++ b/tools/dynamo/gb_id_mapping.py @@ -203,7 +203,8 @@ def create_registry(dynamo_dir: str, registry_path: str) -> None: for info in calls: gb_types[info["gb_type"]] = info - GB_ID_INDEX = 0000 + # Use sequential IDs for initial registry creation + GB_ID_INDEX = 0 for i, (gb_type, info) in enumerate(sorted(gb_types.items()), GB_ID_INDEX): gb_id = f"GB{i:04d}" hints = info["hints"] diff --git a/tools/linter/adapters/gb_registry_linter.py b/tools/linter/adapters/gb_registry_linter.py index 442ca2a3f5fae..e71ec83646df6 100644 --- a/tools/linter/adapters/gb_registry_linter.py +++ b/tools/linter/adapters/gb_registry_linter.py @@ -4,6 +4,7 @@ import argparse import json +import random import sys from enum import Enum from pathlib import Path @@ -109,6 +110,9 @@ def _update_registry_with_changes( del latest_entry[old_gb_type] del gb_type_to_key[old_gb_type] + # Collect new entries separately to insert them all at once + new_entries: list[tuple[str, list[dict[str, Any]]]] = [] + for gb_type, (call, file_path) in calls.items(): if gb_type in latest_entry: existing_entry = latest_entry[gb_type] @@ -126,12 +130,35 @@ def _update_registry_with_changes( registry_key ] else: + # Collect new entries to add later new_key = next_gb_id(updated_registry) new_entry = _create_registry_entry( gb_type, call["context"], call["explanation"], call["hints"] ) + new_entries.append((new_key, [new_entry])) + # Temporarily add to updated_registry so next_gb_id works correctly updated_registry[new_key] = [new_entry] + # Insert all new entries at the same random position to reduce merge conflicts + if new_entries: + # Remove temporarily added entries + for new_key, _ in new_entries: + del updated_registry[new_key] + + registry_items = list(updated_registry.items()) + if registry_items: + # Pick one random position for all new entries + insert_pos = random.randint(0, len(registry_items)) + # Insert all new entries at the same position + for new_key, new_entry in new_entries: + registry_items.insert(insert_pos, (new_key, new_entry)) + insert_pos += 1 # Keep them together + updated_registry = dict(registry_items) + else: + # Empty registry, just add all entries + for new_key, new_entry in new_entries: + updated_registry[new_key] = new_entry + return updated_registry diff --git a/tools/test/test_gb_registry_linter.py b/tools/test/test_gb_registry_linter.py index dd619deaa1b7c..2a4cc7e65be6c 100644 --- a/tools/test/test_gb_registry_linter.py +++ b/tools/test/test_gb_registry_linter.py @@ -280,26 +280,30 @@ def test(self): # Parse the replacement to get the actual GB ID that was generated self.assertEqual(len(messages), 1) replacement_registry = json.loads(messages[0].replacement) - new_gb_id = next(k for k in replacement_registry if k != "GB0000") - expected_registry = { - "GB0000": [ - { - "Gb_type": "original_testing", - "Context": "original_context", - "Explanation": "original_explanation", - "Hints": ["original_hint"], - } - ], - new_gb_id: [ - { - "Gb_type": "completely_new_testing", - "Context": "completely_new_context", - "Explanation": "completely_new_explanation", - "Hints": ["completely_new_hint"], - } - ], - } + # Build expected_registry in the same order as replacement_registry + # since random insertion means order is not deterministic + expected_registry = {} + for gb_id in replacement_registry: + if gb_id == "GB0000": + expected_registry[gb_id] = [ + { + "Gb_type": "original_testing", + "Context": "original_context", + "Explanation": "original_explanation", + "Hints": ["original_hint"], + } + ] + else: + expected_registry[gb_id] = [ + { + "Gb_type": "completely_new_testing", + "Context": "completely_new_context", + "Explanation": "completely_new_explanation", + "Hints": ["completely_new_hint"], + } + ] + expected_replacement = ( json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n" ) From 365a6c84db516f244b7234b7aa3c8843af52936b Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Fri, 5 Dec 2025 02:48:23 +0000 Subject: [PATCH 315/338] [BE] Upgrade XPU support package to 2025.3 (#166829) Follows #166723. Including below changes, - Add XPU support package 2025.3 build and test in CI for both Linux and Windows - Keep XPU support package 2025.2 build in CI to ensure no break issue until PyTorch 2.10 release - Upgrade XPU support package from 2025.2 to 2025.3 in CD for both Linux and Windows - Update XPU runtime pypi packages dependencies of CD wheels Pull Request resolved: https://github.com/pytorch/pytorch/pull/166829 Approved by: https://github.com/atalman --- .ci/docker/build.sh | 4 +- .ci/docker/common/install_xpu.sh | 8 ++-- .ci/docker/manywheel/Dockerfile_2_28 | 2 +- .ci/pytorch/windows/internal/xpu_install.bat | 10 ++--- .circleci/scripts/binary_windows_build.sh | 2 +- .circleci/scripts/binary_windows_test.sh | 2 +- .../scripts/generate_binary_build_matrix.py | 41 ++++++++++--------- ...nerated-linux-binary-manywheel-nightly.yml | 14 +++---- ...generated-windows-binary-wheel-nightly.yml | 14 +++---- .github/workflows/xpu.yml | 4 +- 10 files changed, 51 insertions(+), 50 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index e175be2a6df4d..0e8caf69b3192 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -204,7 +204,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 VISION=yes - XPU_VERSION=2025.1 + XPU_VERSION=2025.2 NINJA_VERSION=1.9.0 TRITON=yes ;; @@ -212,7 +212,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 VISION=yes - XPU_VERSION=2025.2 + XPU_VERSION=2025.3 NINJA_VERSION=1.9.0 TRITON=yes if [[ $tag =~ "benchmarks" ]]; then diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index a29de2cecb870..806272fcd0ee8 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -148,11 +148,11 @@ if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then XPU_DRIVER_VERSION="/lts/2523" fi -# Default use Intel® oneAPI Deep Learning Essentials 2025.1 -if [[ "$XPU_VERSION" == "2025.2" ]]; then - XPU_PACKAGES="intel-deep-learning-essentials-2025.2" +# Default use Intel® oneAPI Deep Learning Essentials 2025.2 +if [[ "$XPU_VERSION" == "2025.3" ]]; then + XPU_PACKAGES="intel-deep-learning-essentials-2025.3" else - XPU_PACKAGES="intel-deep-learning-essentials-2025.1" + XPU_PACKAGES="intel-deep-learning-essentials-2025.2" fi # The installation depends on the base OS diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index bcc249633faa5..452096630ffc8 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -176,6 +176,6 @@ ENV XPU_DRIVER_TYPE ROLLING RUN python3 -m pip install --upgrade pip && \ python3 -mpip install cmake==3.28.4 ADD ./common/install_xpu.sh install_xpu.sh -ENV XPU_VERSION 2025.2 +ENV XPU_VERSION 2025.3 RUN bash ./install_xpu.sh && rm install_xpu.sh RUN pushd /opt/_internal && tar -xJf static-libs-for-embedding-only.tar.xz && popd diff --git a/.ci/pytorch/windows/internal/xpu_install.bat b/.ci/pytorch/windows/internal/xpu_install.bat index f143571a56922..c6b377037f607 100644 --- a/.ci/pytorch/windows/internal/xpu_install.bat +++ b/.ci/pytorch/windows/internal/xpu_install.bat @@ -13,9 +13,9 @@ if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" :xpu_bundle_install_start set XPU_BUNDLE_PARENT_DIR=C:\Program Files (x86)\Intel\oneAPI -set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/75d4eb97-914a-4a95-852c-7b9733d80f74/intel-deep-learning-essentials-2025.1.3.8_offline.exe +set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.deep-learning-essentials.product -set XPU_BUNDLE_VERSION=2025.1.3+5 +set XPU_BUNDLE_VERSION=2025.2.1+20 set XPU_BUNDLE_INSTALLED=0 set XPU_BUNDLE_UNINSTALL=0 set XPU_EXTRA_URL=NULL @@ -24,9 +24,9 @@ set XPU_EXTRA_VERSION=2025.0.1+1226 set XPU_EXTRA_INSTALLED=0 set XPU_EXTRA_UNINSTALL=0 -if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.2] ( - set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe - set XPU_BUNDLE_VERSION=2025.2.1+20 +if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.3] ( + set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/0909c8b0-1475-414f-a9a9-489ee3822dbf/intel-deep-learning-essentials-2025.3.1.11_offline.exe + set XPU_BUNDLE_VERSION=2025.3.1+8 ) :: Check if XPU bundle is target version or already installed diff --git a/.circleci/scripts/binary_windows_build.sh b/.circleci/scripts/binary_windows_build.sh index 18dcde50e2b65..59dbbb3d9b6a8 100644 --- a/.circleci/scripts/binary_windows_build.sh +++ b/.circleci/scripts/binary_windows_build.sh @@ -15,7 +15,7 @@ fi if [[ "$DESIRED_CUDA" == 'xpu' ]]; then export VC_YEAR=2022 export USE_SCCACHE=0 - export XPU_VERSION=2025.2 + export XPU_VERSION=2025.3 fi echo "Free space on filesystem before build:" diff --git a/.circleci/scripts/binary_windows_test.sh b/.circleci/scripts/binary_windows_test.sh index 9326d9037e8b3..b8b82979caf48 100644 --- a/.circleci/scripts/binary_windows_test.sh +++ b/.circleci/scripts/binary_windows_test.sh @@ -8,7 +8,7 @@ export VC_YEAR=2022 if [[ "$DESIRED_CUDA" == 'xpu' ]]; then export VC_YEAR=2022 - export XPU_VERSION=2025.2 + export XPU_VERSION=2025.3 fi pushd "$PYTORCH_ROOT/.ci/pytorch/" diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 7fb1ba1f238f4..47c7bd3819c26 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -122,26 +122,27 @@ "nvidia-cufile==1.15.1.6; platform_system == 'Linux'" ), "xpu": ( - "intel-cmplr-lib-rt==2025.2.1 | " - "intel-cmplr-lib-ur==2025.2.1 | " - "intel-cmplr-lic-rt==2025.2.1 | " - "intel-sycl-rt==2025.2.1 | " - "oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "onemkl-sycl-blas==2025.2.0 | " - "onemkl-sycl-dft==2025.2.0 | " - "onemkl-sycl-lapack==2025.2.0 | " - "onemkl-sycl-rng==2025.2.0 | " - "onemkl-sycl-sparse==2025.2.0 | " - "dpcpp-cpp-rt==2025.2.1 | " - "intel-opencl-rt==2025.2.1 | " - "mkl==2025.2.0 | " - "intel-openmp==2025.2.1 | " - "tbb==2022.2.0 | " - "tcmlib==1.4.0 | " - "umf==0.11.0 | " - "intel-pti==0.13.1" + "intel-cmplr-lib-rt==2025.3.1 | " + "intel-cmplr-lib-ur==2025.3.1 | " + "intel-cmplr-lic-rt==2025.3.1 | " + "intel-sycl-rt==2025.3.1 | " + "oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "onemkl-license==2025.3.0 | " + "onemkl-sycl-blas==2025.3.0 | " + "onemkl-sycl-dft==2025.3.0 | " + "onemkl-sycl-lapack==2025.3.0 | " + "onemkl-sycl-rng==2025.3.0 | " + "onemkl-sycl-sparse==2025.3.0 | " + "dpcpp-cpp-rt==2025.3.1 | " + "intel-opencl-rt==2025.3.1 | " + "mkl==2025.3.0 | " + "intel-openmp==2025.3.1 | " + "tbb==2022.3.0 | " + "tcmlib==1.4.1 | " + "umf==1.0.2 | " + "intel-pti==0.15.0" ), } diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 754432bf461bf..553e9b6670c39 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -627,7 +627,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1301,7 +1301,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1975,7 +1975,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2649,7 +2649,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3323,7 +3323,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3997,7 +3997,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -4671,7 +4671,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index e14cb79c0000e..409c8619b434c 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -1004,7 +1004,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -2189,7 +2189,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -3374,7 +3374,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -4559,7 +4559,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -5744,7 +5744,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -6929,7 +6929,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -8114,7 +8114,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14t" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.3.1 | intel-cmplr-lib-ur==2025.3.1 | intel-cmplr-lic-rt==2025.3.1 | intel-sycl-rt==2025.3.1 | oneccl-devel==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.17.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.17.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-license==2025.3.0 | onemkl-sycl-blas==2025.3.0 | onemkl-sycl-dft==2025.3.0 | onemkl-sycl-lapack==2025.3.0 | onemkl-sycl-rng==2025.3.0 | onemkl-sycl-sparse==2025.3.0 | dpcpp-cpp-rt==2025.3.1 | intel-opencl-rt==2025.3.1 | mkl==2025.3.0 | intel-openmp==2025.3.1 | tbb==2022.3.0 | tcmlib==1.4.1 | umf==1.0.2 | intel-pti==0.15.0 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index 8799743809a77..84b8aa3cd91d4 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -95,7 +95,7 @@ jobs: build-environment: win-vs2022-xpu-n-1-py3 cuda-version: cpu use-xpu: true - xpu-version: '2025.1' + xpu-version: '2025.2' vc-year: '2022' secrets: inherit @@ -107,6 +107,6 @@ jobs: build-environment: win-vs2022-xpu-n-py3 cuda-version: cpu use-xpu: true - xpu-version: '2025.2' + xpu-version: '2025.3' vc-year: '2022' secrets: inherit From e36b4d6a08d3bdba77316e2ef1725b1f97ae0d46 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 5 Dec 2025 02:54:48 +0000 Subject: [PATCH 316/338] [3/N] Remove unused header inclusion (#169200) Remove unneeded header inclusion in C++ source files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169200 Approved by: https://github.com/albanD --- aten/src/ATen/CPUGeneratorImpl.cpp | 1 - aten/src/ATen/CachedTensorUtils.cpp | 1 - aten/src/ATen/Context.cpp | 2 -- aten/src/ATen/DLConvertor.cpp | 1 - aten/src/ATen/DynamicLibrary.cpp | 3 +-- aten/src/ATen/EmptyTensor.cpp | 3 --- aten/src/ATen/FunctionalTensorWrapper.cpp | 3 --- aten/src/ATen/FunctionalizeFallbackKernel.cpp | 3 --- aten/src/ATen/LegacyBatchingRegistrations.cpp | 3 --- aten/src/ATen/LegacyVmapTransforms.cpp | 2 -- aten/src/ATen/MapAllocator.cpp | 1 - aten/src/ATen/MemoryOverlap.cpp | 1 - aten/src/ATen/PythonTorchFunctionTLS.cpp | 1 - aten/src/ATen/ScalarOps.cpp | 1 - aten/src/ATen/SparseCsrTensorImpl.cpp | 4 ---- aten/src/ATen/SparseTensorImpl.cpp | 2 -- aten/src/ATen/TensorUtils.cpp | 3 --- aten/src/ATen/ThreadLocalPythonObjects.cpp | 1 - aten/src/ATen/ThreadLocalState.cpp | 1 - aten/src/ATen/VmapModeRegistrations.cpp | 1 - aten/src/ATen/core/interned_strings.cpp | 1 - aten/src/ATen/cuda/CUDABlas.cpp | 4 ---- aten/src/ATen/cuda/CUDAContext.cpp | 1 - aten/src/ATen/cuda/CUDASparseDescriptors.cpp | 2 -- aten/src/ATen/cuda/CachingHostAllocator.cpp | 3 --- aten/src/ATen/cuda/Exceptions.cpp | 1 - aten/src/ATen/cuda/MemPool.cpp | 1 - aten/src/ATen/cuda/detail/CUDAHooks.cpp | 1 - aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp | 1 - aten/src/ATen/cuda/tunable/StreamTimer.cpp | 2 -- aten/src/ATen/cuda/tunable/Tunable.cpp | 2 -- aten/src/ATen/cudnn/AutocastRNN.cpp | 1 - aten/src/ATen/cudnn/Descriptors.cpp | 1 - aten/src/ATen/cudnn/Types.cpp | 1 - aten/src/ATen/functorch/BatchRulesActivation.cpp | 2 -- aten/src/ATen/functorch/BatchRulesBinaryOps.cpp | 1 - aten/src/ATen/functorch/BatchRulesConvolution.cpp | 1 - aten/src/ATen/functorch/BatchRulesLoss.cpp | 2 -- aten/src/ATen/functorch/BatchRulesModules.cpp | 2 -- aten/src/ATen/functorch/BatchRulesNorm.cpp | 2 -- aten/src/ATen/functorch/BatchRulesPooling.cpp | 3 --- aten/src/ATen/functorch/BatchRulesRandomness.cpp | 1 - aten/src/ATen/functorch/BatchRulesReduceOps.cpp | 1 - aten/src/ATen/functorch/BatchRulesScatterOps.cpp | 1 - aten/src/ATen/functorch/BatchRulesUnaryOps.cpp | 1 - aten/src/ATen/functorch/BatchRulesViews.cpp | 3 --- aten/src/ATen/functorch/BatchedFallback.cpp | 1 - aten/src/ATen/functorch/DynamicLayer.cpp | 2 -- aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp | 4 ---- aten/src/ATen/functorch/LegacyVmapTransforms.cpp | 1 - aten/src/ATen/functorch/PlumbingHelper.cpp | 1 - aten/src/ATen/functorch/VmapModeRegistrations.cpp | 5 ----- aten/src/ATen/metal/Context.cpp | 1 - aten/src/ATen/native/AdaptiveMaxPooling3d.cpp | 1 - aten/src/ATen/native/ReplicationPadding.cpp | 1 - aten/src/ATen/native/Scalar.cpp | 1 - aten/src/ATen/native/SpectralOps.cpp | 1 - aten/src/ATen/native/nested/NestedTensorBackward.cpp | 1 - aten/src/ATen/native/nested/NestedTensorMatmul.cpp | 1 - aten/src/ATen/native/quantized/AffineQuantizerBase.cpp | 2 +- aten/src/ATen/native/quantized/cpu/qhardswish.cpp | 1 - aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp | 1 - aten/src/ATen/native/sparse/ParamUtils.cpp | 1 - .../ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp | 1 - aten/src/ATen/nnapi/nnapi_bind.cpp | 1 - aten/src/ATen/vulkan/Context.cpp | 1 - torch/csrc/distributed/c10d/FileStore.cpp | 1 - torch/csrc/distributed/c10d/Functional.cpp | 2 -- torch/csrc/distributed/c10d/GlooDeviceFactory.cpp | 4 ---- torch/csrc/distributed/c10d/HashStore.cpp | 1 - torch/csrc/distributed/c10d/NCCLUtils.cpp | 5 ++--- torch/csrc/distributed/c10d/Ops.cpp | 1 - torch/csrc/distributed/c10d/ProcessGroup.cpp | 6 ------ torch/csrc/distributed/c10d/ProcessGroupGloo.cpp | 5 ----- torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp | 3 --- torch/csrc/distributed/c10d/debug.cpp | 2 -- torch/csrc/distributed/c10d/python_comm_hook.cpp | 3 --- torch/csrc/distributed/c10d/reducer.cpp | 4 ---- torch/csrc/distributed/c10d/socket.cpp | 1 - .../distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp | 2 -- .../csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp | 1 - torch/csrc/dynamo/extra_state.cpp | 1 - torch/csrc/dynamo/framelocals_mapping.cpp | 1 - torch/csrc/dynamo/init.cpp | 2 -- torch/csrc/dynamo/python_compiled_autograd.cpp | 3 --- torch/csrc/inductor/aoti_eager/kernel_holder.cpp | 6 ------ torch/csrc/inductor/aoti_package/model_package_loader.cpp | 3 --- torch/csrc/inductor/aoti_package/pybind.cpp | 3 --- torch/csrc/inductor/aoti_runner/model_container_runner.cpp | 1 - torch/csrc/inductor/inductor_ops.cpp | 2 -- torch/csrc/inductor/static_cuda_launcher.cpp | 4 ---- torch/csrc/jit/frontend/strtod.cpp | 4 ---- torch/csrc/jit/tensorexpr/block_codegen.cpp | 4 ---- torch/csrc/jit/tensorexpr/bounds_inference.cpp | 2 -- torch/csrc/jit/tensorexpr/bounds_overlap.cpp | 2 -- torch/csrc/jit/tensorexpr/cuda_codegen.cpp | 1 - torch/csrc/jit/tensorexpr/external_functions_codegen.cpp | 2 -- torch/csrc/jit/tensorexpr/graph_opt.cpp | 1 - torch/csrc/jit/tensorexpr/ir_cloner.cpp | 2 -- torch/csrc/jit/tensorexpr/ir_mutator.cpp | 1 - torch/csrc/jit/tensorexpr/ir_verifier.cpp | 3 --- torch/csrc/jit/tensorexpr/loopnest.cpp | 1 - torch/csrc/jit/tensorexpr/reduction.cpp | 1 - torch/csrc/jit/tensorexpr/tensor.cpp | 1 - torch/csrc/jit/tensorexpr/tensorexpr_init.cpp | 3 --- torch/csrc/jit/tensorexpr/types.cpp | 3 --- 106 files changed, 4 insertions(+), 203 deletions(-) diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp index 4d3dafc65663e..61c6bd3e62b80 100644 --- a/aten/src/ATen/CPUGeneratorImpl.cpp +++ b/aten/src/ATen/CPUGeneratorImpl.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include namespace at { diff --git a/aten/src/ATen/CachedTensorUtils.cpp b/aten/src/ATen/CachedTensorUtils.cpp index d9e0f1453f4e5..87d0a6a10a4d3 100644 --- a/aten/src/ATen/CachedTensorUtils.cpp +++ b/aten/src/ATen/CachedTensorUtils.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 6bc321887502d..4f66a8a5ff38a 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -3,12 +3,10 @@ #include #include -#include #include #include #include -#include #include #include diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 9d7ebb3a86cfb..74f69b291b09e 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -1,5 +1,4 @@ #include -#include using namespace std; namespace at { diff --git a/aten/src/ATen/DynamicLibrary.cpp b/aten/src/ATen/DynamicLibrary.cpp index 7dc27f38fa7f0..df933c23ea800 100644 --- a/aten/src/ATen/DynamicLibrary.cpp +++ b/aten/src/ATen/DynamicLibrary.cpp @@ -1,13 +1,12 @@ #include -#include #include -#include #ifndef _WIN32 #include #include #else #include +#include #endif namespace at { diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index 0e535ab20cd21..4d12942eb0449 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -1,9 +1,6 @@ #define TORCH_ASSERT_NO_OPERATORS #include -#include -#include #include -#include #include #include diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 8b7b3bc42a9cb..9610360be4dd9 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -1,9 +1,6 @@ #include -#include -#include -#include #include #include #include diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index 10f988b4d2815..100b4efe90b67 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -16,14 +16,11 @@ #include #else #include -#include #include #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/LegacyBatchingRegistrations.cpp b/aten/src/ATen/LegacyBatchingRegistrations.cpp index 2c54718e938fb..cb1c71916c42c 100644 --- a/aten/src/ATen/LegacyBatchingRegistrations.cpp +++ b/aten/src/ATen/LegacyBatchingRegistrations.cpp @@ -1,12 +1,9 @@ #include -#include #include #include #include #include -#include #include -#include #include diff --git a/aten/src/ATen/LegacyVmapTransforms.cpp b/aten/src/ATen/LegacyVmapTransforms.cpp index 540bdd3bda3e4..53de9799577d6 100644 --- a/aten/src/ATen/LegacyVmapTransforms.cpp +++ b/aten/src/ATen/LegacyVmapTransforms.cpp @@ -1,6 +1,4 @@ #include -#include -#include #include namespace at { diff --git a/aten/src/ATen/MapAllocator.cpp b/aten/src/ATen/MapAllocator.cpp index d8ad62c8c62a4..f2f0545410794 100644 --- a/aten/src/ATen/MapAllocator.cpp +++ b/aten/src/ATen/MapAllocator.cpp @@ -7,7 +7,6 @@ #define AT_ATOMIC_IPC_REFCOUNT 1 #endif -#include #include #ifdef _WIN32 diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 1bc8c30158aec..5cdf192c1abf2 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -1,6 +1,5 @@ #include #include -#include #include namespace at { diff --git a/aten/src/ATen/PythonTorchFunctionTLS.cpp b/aten/src/ATen/PythonTorchFunctionTLS.cpp index e90065543e35b..37ea3a318b0f4 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.cpp +++ b/aten/src/ATen/PythonTorchFunctionTLS.cpp @@ -1,5 +1,4 @@ #include -#include namespace at::impl { diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp index da4f7a35a2f47..080bb5011cd3f 100644 --- a/aten/src/ATen/ScalarOps.cpp +++ b/aten/src/ATen/ScalarOps.cpp @@ -1,5 +1,4 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include #include #include diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp index dec6d2e95960b..6b3a79242a289 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.cpp +++ b/aten/src/ATen/SparseCsrTensorImpl.cpp @@ -1,10 +1,6 @@ -#include #include #include #include -#include -#include -#include namespace at { diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 2b2f286ea50d3..7a870fac117da 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -1,7 +1,5 @@ -#include #include #include -#include namespace at { diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 2752ff792e485..d5c8632134c85 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -1,8 +1,5 @@ -#include -#include #include #include -#include #include #include diff --git a/aten/src/ATen/ThreadLocalPythonObjects.cpp b/aten/src/ATen/ThreadLocalPythonObjects.cpp index 117f9e5d735de..0c70a5c14211f 100644 --- a/aten/src/ATen/ThreadLocalPythonObjects.cpp +++ b/aten/src/ATen/ThreadLocalPythonObjects.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index 22509c7be4e19..b5d1e5ff6d105 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -2,7 +2,6 @@ #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) #include -#include #endif #include diff --git a/aten/src/ATen/VmapModeRegistrations.cpp b/aten/src/ATen/VmapModeRegistrations.cpp index ca5a87bf2d253..abcafa2075288 100644 --- a/aten/src/ATen/VmapModeRegistrations.cpp +++ b/aten/src/ATen/VmapModeRegistrations.cpp @@ -1,5 +1,4 @@ #include -#include using torch::CppFunction; diff --git a/aten/src/ATen/core/interned_strings.cpp b/aten/src/ATen/core/interned_strings.cpp index 018ee82fe3227..b8e9f8d65c6e0 100644 --- a/aten/src/ATen/core/interned_strings.cpp +++ b/aten/src/ATen/core/interned_strings.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index bc7607f232011..4478791487302 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -2,17 +2,13 @@ Provides the implementations of CUDA BLAS function templates. */ -#include #include #include #include #include #include #include -#include -#include #include -#include #include #include diff --git a/aten/src/ATen/cuda/CUDAContext.cpp b/aten/src/ATen/cuda/CUDAContext.cpp index 322a4aec1fe9a..829acefc7b333 100644 --- a/aten/src/ATen/cuda/CUDAContext.cpp +++ b/aten/src/ATen/cuda/CUDAContext.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include diff --git a/aten/src/ATen/cuda/CUDASparseDescriptors.cpp b/aten/src/ATen/cuda/CUDASparseDescriptors.cpp index c7ab4fbfc95df..b1688d861bd2a 100644 --- a/aten/src/ATen/cuda/CUDASparseDescriptors.cpp +++ b/aten/src/ATen/cuda/CUDASparseDescriptors.cpp @@ -1,6 +1,4 @@ -#include #include -#include #include #include #include diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 5786e87dac519..8560cfe272688 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -1,9 +1,6 @@ #include -#include #include -#include -#include #include #include diff --git a/aten/src/ATen/cuda/Exceptions.cpp b/aten/src/ATen/cuda/Exceptions.cpp index dd240cd643e19..8945512481957 100644 --- a/aten/src/ATen/cuda/Exceptions.cpp +++ b/aten/src/ATen/cuda/Exceptions.cpp @@ -1,5 +1,4 @@ //NS: CUDACachingAllocator must be included before to get CUDART_VERSION definedi -#include #include diff --git a/aten/src/ATen/cuda/MemPool.cpp b/aten/src/ATen/cuda/MemPool.cpp index 99405965898e0..df58cbfa6111f 100644 --- a/aten/src/ATen/cuda/MemPool.cpp +++ b/aten/src/ATen/cuda/MemPool.cpp @@ -1,4 +1,3 @@ -#include #include namespace at::cuda { diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index a4fd454633dc0..39abfe7b91458 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -36,7 +36,6 @@ #include #endif -#include #include #include diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp index 68e52314d9bea..1353014ee0993 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.cpp @@ -1,5 +1,4 @@ #include -#include namespace at::cuda { diff --git a/aten/src/ATen/cuda/tunable/StreamTimer.cpp b/aten/src/ATen/cuda/tunable/StreamTimer.cpp index 8b9e6f05cbf1d..2327574834eb5 100644 --- a/aten/src/ATen/cuda/tunable/StreamTimer.cpp +++ b/aten/src/ATen/cuda/tunable/StreamTimer.cpp @@ -7,12 +7,10 @@ // Adapting TunableOp into PyTorch // Copyright (c) Advanced Micro Devices, Inc. // -#include #include #include #include -#include namespace at::cuda::tunable { diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index eb7e381d27766..9c5e0c91d6b12 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -7,7 +7,6 @@ // Adapting TunableOp into PyTorch // Copyright (c) Advanced Micro Devices, Inc. // -#include #include #include @@ -21,7 +20,6 @@ #endif #include -#include #include #include #include diff --git a/aten/src/ATen/cudnn/AutocastRNN.cpp b/aten/src/ATen/cudnn/AutocastRNN.cpp index 84571c9b45dcf..acf448702616f 100644 --- a/aten/src/ATen/cudnn/AutocastRNN.cpp +++ b/aten/src/ATen/cudnn/AutocastRNN.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index a2cb0cb0a1025..343bf108e3749 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -1,6 +1,5 @@ #include -#include #include #include diff --git a/aten/src/ATen/cudnn/Types.cpp b/aten/src/ATen/cudnn/Types.cpp index f612436f56724..8a77c094d167c 100644 --- a/aten/src/ATen/cudnn/Types.cpp +++ b/aten/src/ATen/cudnn/Types.cpp @@ -1,6 +1,5 @@ #include -#include #include diff --git a/aten/src/ATen/functorch/BatchRulesActivation.cpp b/aten/src/ATen/functorch/BatchRulesActivation.cpp index dbcc673804009..92b5527db77c5 100644 --- a/aten/src/ATen/functorch/BatchRulesActivation.cpp +++ b/aten/src/ATen/functorch/BatchRulesActivation.cpp @@ -5,8 +5,6 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include // NB: most activation functions fit pointwise unary or binary rules. // These are only the ones that have special batch rules to help with organization diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp index 937f39273ab57..c0e102e3cfbd1 100644 --- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include diff --git a/aten/src/ATen/functorch/BatchRulesConvolution.cpp b/aten/src/ATen/functorch/BatchRulesConvolution.cpp index 0ebc5da1e1e3a..748d5b1687a3c 100644 --- a/aten/src/ATen/functorch/BatchRulesConvolution.cpp +++ b/aten/src/ATen/functorch/BatchRulesConvolution.cpp @@ -6,7 +6,6 @@ #include #include -#include namespace at::functorch { diff --git a/aten/src/ATen/functorch/BatchRulesLoss.cpp b/aten/src/ATen/functorch/BatchRulesLoss.cpp index c02e58db2e65c..0c9f0ebe1fd7b 100644 --- a/aten/src/ATen/functorch/BatchRulesLoss.cpp +++ b/aten/src/ATen/functorch/BatchRulesLoss.cpp @@ -6,8 +6,6 @@ #include #include -#include -#include namespace at::functorch { // Flattens out all dims except the batch dim, and also moves batch dim diff --git a/aten/src/ATen/functorch/BatchRulesModules.cpp b/aten/src/ATen/functorch/BatchRulesModules.cpp index 5fba8d257ceb8..4e0b50c4e3fe7 100644 --- a/aten/src/ATen/functorch/BatchRulesModules.cpp +++ b/aten/src/ATen/functorch/BatchRulesModules.cpp @@ -5,8 +5,6 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include #include #include diff --git a/aten/src/ATen/functorch/BatchRulesNorm.cpp b/aten/src/ATen/functorch/BatchRulesNorm.cpp index 4546c56e2f586..51dae00e6b7ed 100644 --- a/aten/src/ATen/functorch/BatchRulesNorm.cpp +++ b/aten/src/ATen/functorch/BatchRulesNorm.cpp @@ -6,8 +6,6 @@ #include #include -#include -#include namespace at::functorch { diff --git a/aten/src/ATen/functorch/BatchRulesPooling.cpp b/aten/src/ATen/functorch/BatchRulesPooling.cpp index e94a63086e939..09b1ff90bc935 100644 --- a/aten/src/ATen/functorch/BatchRulesPooling.cpp +++ b/aten/src/ATen/functorch/BatchRulesPooling.cpp @@ -5,9 +5,6 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include -#include namespace at::functorch { diff --git a/aten/src/ATen/functorch/BatchRulesRandomness.cpp b/aten/src/ATen/functorch/BatchRulesRandomness.cpp index 2c12854f3268d..0c2ab1f7044ba 100644 --- a/aten/src/ATen/functorch/BatchRulesRandomness.cpp +++ b/aten/src/ATen/functorch/BatchRulesRandomness.cpp @@ -4,7 +4,6 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include #include #include diff --git a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp index ecee801965e71..c1c017f718814 100644 --- a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp @@ -6,7 +6,6 @@ #include #include -#include #include #include diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index ae4b5b25988e4..80034ff95ca3c 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include diff --git a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp index 48a735c3e5332..3cabdd251480f 100644 --- a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp @@ -5,7 +5,6 @@ // LICENSE file in the root directory of this source tree. #include -#include namespace at::functorch { diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp index a78d8b0eec7e1..08724d4fc1243 100644 --- a/aten/src/ATen/functorch/BatchRulesViews.cpp +++ b/aten/src/ATen/functorch/BatchRulesViews.cpp @@ -9,11 +9,8 @@ #include #include -#include -#include #include #include -#include #include namespace at::functorch { diff --git a/aten/src/ATen/functorch/BatchedFallback.cpp b/aten/src/ATen/functorch/BatchedFallback.cpp index aab1da68053b7..b479639f1c1a5 100644 --- a/aten/src/ATen/functorch/BatchedFallback.cpp +++ b/aten/src/ATen/functorch/BatchedFallback.cpp @@ -6,7 +6,6 @@ #include #include -#include #include #include diff --git a/aten/src/ATen/functorch/DynamicLayer.cpp b/aten/src/ATen/functorch/DynamicLayer.cpp index 518098a8b4a80..1420aaf0ab943 100644 --- a/aten/src/ATen/functorch/DynamicLayer.cpp +++ b/aten/src/ATen/functorch/DynamicLayer.cpp @@ -7,12 +7,10 @@ #include #include #include -#include #include #include #include -#include #include #include #include diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp index e51f4901f36bc..1df4c8938183a 100644 --- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp +++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp @@ -6,13 +6,9 @@ #include #include -#include #include -#include #include -#include -#include #include #include #include diff --git a/aten/src/ATen/functorch/LegacyVmapTransforms.cpp b/aten/src/ATen/functorch/LegacyVmapTransforms.cpp index 662aaeb8e5ca3..5f8b124924e61 100644 --- a/aten/src/ATen/functorch/LegacyVmapTransforms.cpp +++ b/aten/src/ATen/functorch/LegacyVmapTransforms.cpp @@ -7,7 +7,6 @@ #include #include -#include #include namespace at::functorch { diff --git a/aten/src/ATen/functorch/PlumbingHelper.cpp b/aten/src/ATen/functorch/PlumbingHelper.cpp index f8ebe66908237..2ecc8084b8b89 100644 --- a/aten/src/ATen/functorch/PlumbingHelper.cpp +++ b/aten/src/ATen/functorch/PlumbingHelper.cpp @@ -4,7 +4,6 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include #include #include #include diff --git a/aten/src/ATen/functorch/VmapModeRegistrations.cpp b/aten/src/ATen/functorch/VmapModeRegistrations.cpp index 195afd80bc713..e84468c3af4dd 100644 --- a/aten/src/ATen/functorch/VmapModeRegistrations.cpp +++ b/aten/src/ATen/functorch/VmapModeRegistrations.cpp @@ -5,11 +5,6 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include -#include -#include -#include #include // functorch's vmap has two Dispatch Keys that implement it: diff --git a/aten/src/ATen/metal/Context.cpp b/aten/src/ATen/metal/Context.cpp index c0d32086d4179..111e49201eb33 100644 --- a/aten/src/ATen/metal/Context.cpp +++ b/aten/src/ATen/metal/Context.cpp @@ -1,6 +1,5 @@ #include -#include #include namespace at::metal { diff --git a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp index ef4bab3ec1de0..57a1487bfbeb9 100644 --- a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include diff --git a/aten/src/ATen/native/ReplicationPadding.cpp b/aten/src/ATen/native/ReplicationPadding.cpp index 0c66c7a632997..795e2fea3f03f 100644 --- a/aten/src/ATen/native/ReplicationPadding.cpp +++ b/aten/src/ATen/native/ReplicationPadding.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index 39e203f632781..dea7ecc7118ac 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -1,5 +1,4 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include #ifndef AT_PER_OPERATOR_HEADERS diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 975e237c468d6..91a0c3ff8cf93 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -58,7 +58,6 @@ #include #include #include -#include #endif #include diff --git a/aten/src/ATen/native/nested/NestedTensorBackward.cpp b/aten/src/ATen/native/nested/NestedTensorBackward.cpp index 701c38ce52e33..328e957d1a94a 100644 --- a/aten/src/ATen/native/nested/NestedTensorBackward.cpp +++ b/aten/src/ATen/native/nested/NestedTensorBackward.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/native/nested/NestedTensorMatmul.cpp b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp index 8e0a371ba784e..60de6dd2bdaba 100644 --- a/aten/src/ATen/native/nested/NestedTensorMatmul.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp @@ -12,7 +12,6 @@ #include #include #include -#include namespace at::native { diff --git a/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp b/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp index 1086b4d0d8c58..a5b15b86d27fa 100644 --- a/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp +++ b/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #ifdef USE_FBGEMM #include diff --git a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp index 5c71e07dfad2a..569b8f487a75f 100644 --- a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp +++ b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp @@ -13,7 +13,6 @@ #include #endif -#include namespace at::native { diff --git a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp index 53da11b4d0fe7..3b01841c4aa87 100644 --- a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp +++ b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp @@ -9,7 +9,6 @@ #include #include #include -#include int register_linear_params(); diff --git a/aten/src/ATen/native/sparse/ParamUtils.cpp b/aten/src/ATen/native/sparse/ParamUtils.cpp index 1f2ee5932e40b..62d5ea5cf3212 100644 --- a/aten/src/ATen/native/sparse/ParamUtils.cpp +++ b/aten/src/ATen/native/sparse/ParamUtils.cpp @@ -1,6 +1,5 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include #include #include #include diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index cf854a84e7dad..979dbdd033ac3 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include namespace at::native { diff --git a/aten/src/ATen/nnapi/nnapi_bind.cpp b/aten/src/ATen/nnapi/nnapi_bind.cpp index 8f40ee4045681..78e51fa1c7e5f 100644 --- a/aten/src/ATen/nnapi/nnapi_bind.cpp +++ b/aten/src/ATen/nnapi/nnapi_bind.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include diff --git a/aten/src/ATen/vulkan/Context.cpp b/aten/src/ATen/vulkan/Context.cpp index 06d959b89fcb5..5b83c3e4b9a21 100644 --- a/aten/src/ATen/vulkan/Context.cpp +++ b/aten/src/ATen/vulkan/Context.cpp @@ -1,6 +1,5 @@ #include -#include #include #ifdef USE_VULKAN_API diff --git a/torch/csrc/distributed/c10d/FileStore.cpp b/torch/csrc/distributed/c10d/FileStore.cpp index 969379e739438..8a459b8080dbc 100644 --- a/torch/csrc/distributed/c10d/FileStore.cpp +++ b/torch/csrc/distributed/c10d/FileStore.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index c21c5f9129acb..1284676dae015 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -1,5 +1,3 @@ -#include -#include #include #include #include diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp index d9a74e2efa379..25448dbc9f690 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp +++ b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp @@ -1,11 +1,7 @@ #include -#include - #ifdef USE_C10D_GLOO -#include - #include #include diff --git a/torch/csrc/distributed/c10d/HashStore.cpp b/torch/csrc/distributed/c10d/HashStore.cpp index 9073333fb9a48..d7079d0c48125 100644 --- a/torch/csrc/distributed/c10d/HashStore.cpp +++ b/torch/csrc/distributed/c10d/HashStore.cpp @@ -1,6 +1,5 @@ #include -#include #include #include diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index a41f654b9ae20..b9c9b313e0b4d 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,10 +1,9 @@ #include -#include - -#include #ifdef USE_C10D_NCCL +#include #include +#include #include namespace c10d { diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index a5d42771ce05b..0ded9d4cc733d 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index b888e315021ac..903144511f297 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -7,11 +6,6 @@ #include #include -#include -#include -#include -#include -#include namespace c10d { diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index c1d28b2787cda..9eb7770381cb0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -6,7 +6,6 @@ #include #include -#include #include #include #include @@ -22,18 +21,14 @@ #include #include #endif -#include -#include #include #include -#include #include #include #include -#include #include #include diff --git a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp index fa40ff15ec74f..4a316c7733280 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp @@ -2,15 +2,12 @@ #ifdef USE_C10D_GLOO -#include #include #include #include #include #include #include -#include -#include #include #include #include diff --git a/torch/csrc/distributed/c10d/debug.cpp b/torch/csrc/distributed/c10d/debug.cpp index d5d77094e1718..eb05a9b1e7151 100644 --- a/torch/csrc/distributed/c10d/debug.cpp +++ b/torch/csrc/distributed/c10d/debug.cpp @@ -9,10 +9,8 @@ #include #include -#include #include -#include #include namespace c10d { diff --git a/torch/csrc/distributed/c10d/python_comm_hook.cpp b/torch/csrc/distributed/c10d/python_comm_hook.cpp index af3bf6b4c65d3..dfb656b003c85 100644 --- a/torch/csrc/distributed/c10d/python_comm_hook.cpp +++ b/torch/csrc/distributed/c10d/python_comm_hook.cpp @@ -1,9 +1,6 @@ #include -#include -#include #include -#include namespace c10d { diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index d2bf2c6cf7f62..7635f1a8165ef 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -5,17 +5,13 @@ #include -#include #include -#include #include #include #include #include #include #include -#include -#include #include #include #include diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index c79f5a04010eb..1a36efcc4eb36 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -7,7 +7,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp index 04838b1581ad2..51a5d5e7244b1 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp @@ -1,5 +1,4 @@ #include -#include #include #include @@ -12,7 +11,6 @@ #include #endif -#include #include #include diff --git a/torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp b/torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp index b5efcfeb3006f..c19037e1a7862 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp @@ -5,7 +5,6 @@ #include #include -#include #include namespace { diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index 8dc316b98e63c..b890c2848011b 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/dynamo/framelocals_mapping.cpp b/torch/csrc/dynamo/framelocals_mapping.cpp index 5f78dca9591f9..8165810caa58c 100644 --- a/torch/csrc/dynamo/framelocals_mapping.cpp +++ b/torch/csrc/dynamo/framelocals_mapping.cpp @@ -1,6 +1,5 @@ #include -#include #include #include diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 69d6e0555ceb4..0dfd6b828cf51 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -11,8 +11,6 @@ #include #include #include -#include -#include #include static struct PyModuleDef _module = diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index c24f2cffdd762..463eb7de0c222 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -1,12 +1,9 @@ #include #include -#include #include #include #include -#include -#include #include #include #include diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index fcdefeac9219c..5a8956c5c2354 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -1,12 +1,9 @@ #if !defined(C10_MOBILE) && !defined(ANDROID) #include -#include - #include #include #include -#include #include #include #include @@ -16,11 +13,8 @@ #ifdef USE_XPU #include #endif -#include #include -#include -#include namespace torch::inductor { diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 93c8f71e84d80..9ff0f844931cb 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -4,12 +4,10 @@ #include #include #include -#include #include #include #include -#include #include #include #include @@ -33,7 +31,6 @@ namespace fs = std::filesystem; #define access _access #define F_OK 0 #else -#include #include #endif diff --git a/torch/csrc/inductor/aoti_package/pybind.cpp b/torch/csrc/inductor/aoti_package/pybind.cpp index 591153bb1f6c2..452d46e05bff7 100644 --- a/torch/csrc/inductor/aoti_package/pybind.cpp +++ b/torch/csrc/inductor/aoti_package/pybind.cpp @@ -1,7 +1,5 @@ #include #include -#include -#include #ifdef USE_CUDA #include #endif @@ -9,7 +7,6 @@ #include #include #include -#include namespace torch::inductor { diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index 44517bcd702e8..445246f82848c 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -15,7 +15,6 @@ #include #include // std::function #else // !_WIN32 -#include #include #include #endif // _WIN32 diff --git a/torch/csrc/inductor/inductor_ops.cpp b/torch/csrc/inductor/inductor_ops.cpp index 7d0e9b612343b..9723e27e6ba8a 100644 --- a/torch/csrc/inductor/inductor_ops.cpp +++ b/torch/csrc/inductor/inductor_ops.cpp @@ -8,8 +8,6 @@ #include #include -#include - namespace torch::inductor { using namespace at; diff --git a/torch/csrc/inductor/static_cuda_launcher.cpp b/torch/csrc/inductor/static_cuda_launcher.cpp index 59916b6763bfa..4c2b3aaae2007 100644 --- a/torch/csrc/inductor/static_cuda_launcher.cpp +++ b/torch/csrc/inductor/static_cuda_launcher.cpp @@ -2,16 +2,12 @@ // We disable this file from being hipified because there are CUDA drivers hip // has not implemented yet. Also, we're passing in a cubin file directly, so it // would take more work to support ROCM anyway. -#include #include #include #include -#include -#include #include #include -#include #include #include diff --git a/torch/csrc/jit/frontend/strtod.cpp b/torch/csrc/jit/frontend/strtod.cpp index 76fc20cf6a20a..daf768ee62512 100644 --- a/torch/csrc/jit/frontend/strtod.cpp +++ b/torch/csrc/jit/frontend/strtod.cpp @@ -22,10 +22,6 @@ // respective // C stdlib functions -#include -#include -#include -#include #include namespace torch::jit { diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 99dd289fb0964..ceb49a8675918 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -1,10 +1,6 @@ #include #include -#include -#include -#include -#include namespace torch::jit::tensorexpr { diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.cpp b/torch/csrc/jit/tensorexpr/bounds_inference.cpp index 034f51f46b8f7..1c74a9f547a81 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_inference.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp index 0c785504efe85..fd7e74fcc235c 100644 --- a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp @@ -1,7 +1,5 @@ #include #include -#include -#include #include diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index 264e01d65db94..c787ccd88ddcf 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp b/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp index 3c909f44f1faa..bd50737682f7b 100644 --- a/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions_codegen.cpp @@ -2,8 +2,6 @@ // external_functions_codegen_template.cpp #include -#include -#include #include namespace torch::jit::tensorexpr { diff --git a/torch/csrc/jit/tensorexpr/graph_opt.cpp b/torch/csrc/jit/tensorexpr/graph_opt.cpp index 27c24f927b692..de2d6f011eb9b 100644 --- a/torch/csrc/jit/tensorexpr/graph_opt.cpp +++ b/torch/csrc/jit/tensorexpr/graph_opt.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include diff --git a/torch/csrc/jit/tensorexpr/ir_cloner.cpp b/torch/csrc/jit/tensorexpr/ir_cloner.cpp index 78421bb0f0a41..c2c0e7080a48e 100644 --- a/torch/csrc/jit/tensorexpr/ir_cloner.cpp +++ b/torch/csrc/jit/tensorexpr/ir_cloner.cpp @@ -4,8 +4,6 @@ #include #include -#include - namespace torch::jit::tensorexpr { template < diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index 75fbcbe074845..cb8135be6307d 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include diff --git a/torch/csrc/jit/tensorexpr/ir_verifier.cpp b/torch/csrc/jit/tensorexpr/ir_verifier.cpp index d914e5c575246..8342bc7abbbfd 100644 --- a/torch/csrc/jit/tensorexpr/ir_verifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_verifier.cpp @@ -1,9 +1,6 @@ #include #include -#include -#include -#include namespace torch::jit::tensorexpr { diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index cca7efcd0adaf..1bdae4ca7ae90 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include diff --git a/torch/csrc/jit/tensorexpr/reduction.cpp b/torch/csrc/jit/tensorexpr/reduction.cpp index d7101011f492c..524d6928c84f5 100644 --- a/torch/csrc/jit/tensorexpr/reduction.cpp +++ b/torch/csrc/jit/tensorexpr/reduction.cpp @@ -1,6 +1,5 @@ #include -#include #include diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index 156868bc5774d..90e7fb8bf072a 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -1,6 +1,5 @@ #include -#include #include #include diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp index 6c7c9c060c915..87620a9fb26af 100644 --- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -1,6 +1,4 @@ -#include #include -#include #include #include #include @@ -11,7 +9,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index f3a62fa374056..57f6c1c9ec342 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -1,10 +1,7 @@ #include -#include #include -#include - namespace torch::jit::tensorexpr { Dtype Dtype::scalar_dtype() const { From 0ff77ef2635a3b3fde31740570c5da111316fd84 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 4 Dec 2025 15:57:08 -0800 Subject: [PATCH 317/338] [DTensor] unbacked matmuls for no-redistribute case (#168051) This allows compiling a matmul on 2 DTensors with fully unbacked sizes, when a zero-cost strategy is available. Changes with the PR: - `mark_unbacked()` would previously error on tensor subclasses; now for DTensors it allocates unbacked symbols for both inner & outer sizes. The main motivation here is for testing, so happy to tweak semantics. The unbacked binding search process also now matches on DTensor outer sizes. - Selecting an op strategy in sharding propagation is based on minimal redistribution costs, and these costs are functions of tensor shapes, so can be unbacked expressions. This PR makes this process more unbacked-friendly, choosing negative or zero-cost strategies when they're available. When these "trivial" strategies aren't available, selection requires comparing unbacked costs, addressed in the next PR (with usage of fallback hints). - For matmul strategies, sharding prop rules filter out strategies where the matmul inputs fail the `is_tensor_shardable` check on the given DeviceMesh. In eager, this filters out cases where `size of sharded dim < num shards`. In the compiled & unbacked case, we'll often encounter dim size `u_` where `u_` can be both larger and smaller than num shards. This PR assumes such cases are shardable by default, and the implication is that strategies that shard on unbacked dimensions are included for consideration, and if selected, can lead to uneven sharding/zero-size shards at runtime. Alternatives would be 1) the current state of things: DDE and force the user to pick a path: `torch._check(size of sharded dim < or >= num shards)`, or 2) assume the non-shardable case and never include sharded strategies, unless the user picks the shardable path. More discussion in https://github.com/pytorch/pytorch/issues/165034#issuecomment-3417695068. - Lastly, testing traced redistribution decisions required using aot_eager backend, so that the collectives/ops were hardcoded (eager backend would go through DTensor.dispatch again). This seemed to require re-enabling proxy tracking during shard prop, basically reverting https://github.com/pytorch/pytorch/pull/163126. Otherwise, errors like `RuntimeError: Max(1, u2) (, 140294330350224)is not tracked with proxy for ` show up for DTensor outer strides... Pull Request resolved: https://github.com/pytorch/pytorch/pull/168051 Approved by: https://github.com/laithsakka --- .../tensor/test_dtensor_compile.py | 74 ++++++++++++-- .../distributed/tensor/test_dtensor_export.py | 97 +++++++++++++------ torch/_dynamo/decorators.py | 21 +++- torch/distributed/tensor/_ops/_matrix_ops.py | 28 ++++-- torch/distributed/tensor/_ops/utils.py | 32 +++++- torch/distributed/tensor/_sharding_prop.py | 78 ++++++++------- torch/fx/experimental/symbolic_shapes.py | 63 ++++++------ 7 files changed, 277 insertions(+), 116 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 05fd187bb7576..cd326ec26fb39 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -399,9 +399,6 @@ def fn(x): self.assertEqual(res, ref) @skipIfHpu - @unittest.skip( - "DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer" - ) def test_dtensor_dynamic_slice(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -444,9 +441,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) - @unittest.skip( - "DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer" - ) def test_dtensor_dynamic_cat(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -512,6 +506,74 @@ def g(x): run(g, 64, 8) self.assertEqual(cnt.frame_count, 2) + @unittest.skipIf(not HAS_GPU, "requires GPU for RNG support") + def test_dtensor_unbacked_matmuls(self): + from torch.distributed.tensor import randn as d_randn + + # use 2x2 mesh for testing + dist.destroy_process_group() + dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=4) + device_mesh = init_device_mesh(self.device_type, (2, 2)) + + def test_placements(x_placements, y_placements, out_placements): + # create DTensors with unbacked outer/inner sizes + x_dt = d_randn(64, 64, device_mesh=device_mesh, placements=x_placements) + y_dt = d_randn(64, 64, device_mesh=device_mesh, placements=y_placements) + for i in range(2): + torch._dynamo.decorators.mark_unbacked(x_dt, i) + torch._dynamo.decorators.mark_unbacked(y_dt, i) + + # full-graph capture + torch._dynamo.reset() + fn = torch.compile(torch.mm, backend="aot_eager", fullgraph=True) + out = fn(x_dt, y_dt) + + # check output placements + self.assertEqual(out.placements, out_placements) + + test_placements( + (Replicate(), Replicate()), + (Replicate(), Replicate()), + (Replicate(), Replicate()), + ) + test_placements( + (Replicate(), Shard(1)), (Replicate(), Shard(0)), (Replicate(), Partial()) + ) + test_placements( + (Replicate(), Shard(0)), (Replicate(), Replicate()), (Replicate(), Shard(0)) + ) + + @unittest.skipIf(not HAS_GPU, "requires GPU for RNG support") + def test_dtensor_matmul_zero_size_shards(self): + from torch.distributed.tensor import randn as d_randn + + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + + dist.destroy_process_group() + dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=4) + device_mesh = init_device_mesh(self.device_type, (2, 2)) + + # create DTensors with unbacked outer/inner sizes + px, py = (Replicate(), Shard(1)), (Replicate(), Shard(0)) + x_dt = d_randn(64, 64, device_mesh=device_mesh, placements=px) + y_dt = d_randn(64, 64, device_mesh=device_mesh, placements=py) + for i in range(2): + torch._dynamo.decorators.mark_unbacked(x_dt, i) + torch._dynamo.decorators.mark_unbacked(y_dt, i) + + # full-graph capture + fn = torch.compile(torch.mm, backend=cnt, fullgraph=True) + fn(x_dt, y_dt) + + # check zero-size shards + for m in [3, 0]: # n, k = 0 cause recompiles on strides + dx = d_randn(m, 1, device_mesh=device_mesh, placements=px) + dy = d_randn(1, 1, device_mesh=device_mesh, placements=py) + c_out, eager_out = fn(dx, dy), torch.mm(dx, dy) + self.assertEqual(tuple(c_out.shape), (m, 1)) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(c_out.shape, eager_out.shape) + def test_dtensor_requires_grad_recompile(self): cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index 4a88cf9a6e0b1..dbf0e33184ac2 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -266,7 +266,7 @@ def unmarked_nodes(gm): "all_reduce", "wait_tensor", "view_2", - "t_12", + "t_16", ] unmarked_nodes_fw = [ "view_3", @@ -281,48 +281,48 @@ def unmarked_nodes(gm): "all_reduce_1", "wait_tensor_1", "view_6", - "t_4", - "t_8", + "t_5", + "t_11", ] marked_nodes_bw = [ - "mm_4", - "t_13", + "mm_8", + "t_17", "view_1", - "mm_5", - "t_14", - "sum_3", - "view_9", - "t_15", + "mm_9", + "t_18", + "sum_5", + "view_11", + "t_19", "detach", "detach_3", "threshold_backward_1", - "t_16", - "mm_6", - "t_17", - "sum_4", - "view_10", - "t_18", + "t_20", + "mm_10", + "t_21", + "sum_6", + "view_12", + "t_22", ] unmarked_nodes_bw = [ - "mm", - "t_5", - "view_5", "mm_1", - "t_6", - "sum_1", - "view_7", "t_7", - "detach_1", - "detach_2", - "threshold_backward", - "mm_2", - "t_9", + "view_5", "mm_3", - "t_10", + "t_8", "sum_2", "view_8", - "t_11", + "t_9", + "detach_1", + "detach_2", + "threshold_backward", + "mm_5", + "t_13", + "mm_7", + "t_14", + "sum_4", + "view_10", + "t_15", "all_reduce_2", "wait_tensor_2", ] @@ -540,16 +540,53 @@ def forward(self, x): %item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {}) %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {}) %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {}) - %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {}) + %getitem : [num_users=3] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {}) %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {}) %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {}) + %sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem, 0), kwargs = {}) %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {}) %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {}) %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {}) %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {}) + %ge_3 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {}) + %_assert_scalar_default_3 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_3, Runtime assertion failed for expression u1 >= 0 on node 'ge_3'), kwargs = {}) + %le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 4), kwargs = {}) + %_assert_scalar_default_4 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u1 <= 4 on node 'le_1'), kwargs = {}) return (getitem,)""", # noqa: B950 ) + def test_dtensor_mark_unbacked(self): + device_mesh = init_device_mesh( + self.device_type, mesh_shape=(self.world_size // 2, 2) + ) + + class Foo(torch.nn.Module): + def forward(self, x, y): + return x @ y + + x_dt = distribute_tensor( + torch.randn(64, 64), device_mesh, placements=[Replicate(), Replicate()] + ) + y_dt = x_dt.clone() + for i in range(2): + torch._dynamo.decorators.mark_unbacked(x_dt, i) + torch._dynamo.decorators.mark_unbacked(y_dt, i) + + gm = dynamo_graph_capture_for_export(Foo())(x_dt, y_dt) + n = 0 + for node in gm.graph.nodes: + if bindings := node.meta.get("unbacked_bindings", {}): + # 2 outer sizes, 2 inner sizes + self.assertEqual(len(bindings), 4) + n += 1 + self.assertEqual(n, 2) # 2 nodes with bindings (x, y) + + # test size-0 tensor + z_dt = distribute_tensor( + torch.randn(0, 0), device_mesh, placements=[Replicate(), Replicate()] + ) + self.assertEqual(gm(z_dt, z_dt).shape, (0, 0)) + instantiate_parametrized_tests(DTensorExportTest) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 87becc8b8b1b2..3a9718b045cb6 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -575,34 +575,49 @@ def mark_unbacked( specialize_on (Optional[list[Any]], default=None): A list of specialization criteria (e.g., lambdas) for this dimension. If provided, Dynamo will generate specialized compiled regions for each criterion in addition to a generic trace. """ - # You could have copied the mark_dynamic behavior but I'm not convinced - # it's what you want - assert not is_traceable_wrapper_subclass(t), "not implemented yet" + if torch.distributed.is_available() and isinstance( + t, torch.distributed.tensor.DTensor + ): + # apply on inner tensor sizes/strides + mark_unbacked(t._local_tensor, index) + else: + # You could have copied the mark_dynamic behavior but I'm not convinced + # it's what you want + assert not is_traceable_wrapper_subclass(t), "not implemented yet" if isinstance(index, int): if strict: if not hasattr(t, "_dynamo_strict_unbacked_indices"): + # pyrefly: ignore [missing-attribute] t._dynamo_strict_unbacked_indices = set() + # pyrefly: ignore [missing-attribute] t._dynamo_strict_unbacked_indices.add(index) return if not hasattr(t, "_specialized_on"): + # pyrefly: ignore [missing-attribute] t._specialize_on = {} if not hasattr(t, "_dynamo_unbacked_indices"): + # pyrefly: ignore [missing-attribute] t._dynamo_unbacked_indices = set() if not hasattr(t, "_dynamo_hint_overrides"): + # pyrefly: ignore [missing-attribute] t._dynamo_hint_overrides = {} if hint_override: + # pyrefly: ignore [missing-attribute] t._dynamo_hint_overrides[index] = hint_override # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: # TypeError: 'Attribute' object does not support item assignment + # pyrefly: ignore [missing-attribute] if isinstance(t._specialize_on, dict): + # pyrefly: ignore [missing-attribute] t._specialize_on[index] = specialize_on if specialize_on is not None else [] + # pyrefly: ignore [missing-attribute] t._dynamo_unbacked_indices.add(index) return diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index f633088e946ed..c00a44ef8f4f4 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -83,8 +83,10 @@ def _mm_like_strategy( ) self_spec = strtg.input_specs[0] mat2_spec = strtg.input_specs[1] - if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable( - mat2_strategy.shape, mat2_spec + if is_tensor_shardable( + self_strategy.shape, self_spec, allow_unbacked_sharding=True + ) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True ): redistribute_cost = [ generate_redistribute_costs(self_strategy, self_spec), @@ -138,8 +140,10 @@ def _addmm_like_strategy( ) self_spec = DTensorSpec(mesh=mesh, placements=self_placements) - if is_tensor_shardable(mat1_strategy.shape, mat1_spec) and is_tensor_shardable( - mat2_strategy.shape, mat2_spec + if is_tensor_shardable( + mat1_strategy.shape, mat1_spec, allow_unbacked_sharding=True + ) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True ): # update input specs with new self spec strtg.input_specs = (self_spec, mat1_spec, mat2_spec) @@ -210,10 +214,18 @@ def _scaled_mm_like_strategy( ) strtg.input_specs = list(strtg.input_specs) + [scale_self_spec, scale_mat2_spec] if ( - is_tensor_shardable(self_strategy.shape, self_spec) - and is_tensor_shardable(mat2_strategy.shape, mat2_spec) - and is_tensor_shardable(scale_self_strategy.shape, scale_self_spec) - and is_tensor_shardable(scale_mat2_strategy.shape, scale_mat2_spec) + is_tensor_shardable( + self_strategy.shape, self_spec, allow_unbacked_sharding=True + ) + and is_tensor_shardable( + mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True + ) + and is_tensor_shardable( + scale_self_strategy.shape, scale_self_spec, allow_unbacked_sharding=True + ) + and is_tensor_shardable( + scale_mat2_strategy.shape, scale_mat2_spec, allow_unbacked_sharding=True + ) ): redistribute_cost = [ generate_redistribute_costs(self_strategy, self_spec), diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index f09a888734807..83857e1c3a8e9 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -89,11 +89,33 @@ def prod(xs: Iterable[int]) -> int: return functools.reduce(operator.mul, xs, 1) -def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: - """Check if the spec matches these criteria: - * any Shard placements in spec refer to valid tensor dims - * no empty local tensors (uneven sharding OK, as long as last rank has >0 size) +def is_tensor_shardable( + shape: Sequence[int], + spec: DTensorSpec, + allow_unbacked_sharding: Optional[bool] = None, +) -> bool: """ + Check if the shape is shardable according to the spec. + + allow_unbacked_sharding: determines the fallback value if unbacked shapes are involved, + and the queried shape properties are not statically known. + + e.g. when asking if u0 is shardable on num_shards, and u0 has generic bounds [0, inf], + the behavior of allow_unbacked_sharding is: + + None: will data-dependent error + True: assumes shardability; we return True, allowing zero-size shards at runtime when u0 < num_shards. + False: returns False, and lower-bounding u0, e.g. torch._check(u0 >= num_shards), is needed to enable sharding. + """ + from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true + + assert allow_unbacked_sharding in [None, True, False] + guard_fn = { + None: bool, + True: guard_or_false, + False: guard_or_true, + }[allow_unbacked_sharding] + # number of shards in each tensor dimension shards_map = [1] * len(shape) for i, placement in enumerate(spec.placements): @@ -106,7 +128,7 @@ def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: for i, dim_size in enumerate(shape): # TODO: maybe we should determine is_shardable based on # whether it's evenly sharded or not - if shards_map[i] > 1 and dim_size < shards_map[i]: + if shards_map[i] > 1 and guard_fn(dim_size < shards_map[i]): return False return True diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 11eb7a8ce667b..68bd38c11b94c 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -import contextlib +import logging import threading from collections.abc import Callable, Sequence from functools import lru_cache @@ -31,6 +31,8 @@ aten = torch.ops.aten +log = logging.getLogger(__name__) + def _length(obj) -> int: if obj is None: @@ -165,20 +167,9 @@ def _propagate_tensor_meta_non_cached( return None # NOTE: We must call the tracing in fake tensor mode so that it avoids - # materializing memory. Also disable the proxy mode tracing to prevent - # these operators to be inserted in the fx graph. - from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing - - # DTensor.dispatch runs fake tensor prop twice, once here, and once for the actual - # local tensor result. The result here is never surfaced to tracing, and so if - # the op is data-dependent, can result in PendingUnbackedSymbolNotFound errors. + # materializing memory. fake_mode = detect_fake_mode() or FakeTensorMode() - suppress_fresh_symbols_ctx = ( - fake_mode.shape_env.ignore_fresh_unbacked_symbols() - if fake_mode.shape_env - else contextlib.nullcontext() - ) - with fake_mode, disable_proxy_modes_tracing(), suppress_fresh_symbols_ctx: + with fake_mode: fake_args = op_schema.gen_fake_args() fake_kwargs = op_schema.gen_fake_kwargs() fake_out = op_schema.op(*fake_args, **fake_kwargs) @@ -593,12 +584,16 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin def _select_strategy( self, strategy: OpStrategy, op_schema: OpSchema | None = None ) -> OpSpec: + from torch.fx.experimental.symbolic_shapes import guard_or_false + if len(strategy.strategies) == 1: # short cut with only one possible OpSpec return strategy.strategies[0] - op_spec_costs: list[float] = [] + op_spec_costs: list[torch.types.FloatLikeType] = [] no_redistribute_strategy_index: int = -1 + negative_cost_index: int = -1 + zero_cost_index: int = -1 for strategy_idx, op_spec in enumerate(strategy.strategies): assert op_spec.redistribute_cost is not None, ( "must set redistribute cost each OpSpec!" @@ -606,37 +601,48 @@ def _select_strategy( redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost)) op_spec_costs.append(redistribute_cost) - # If there's no redistribute cost, we record the index of the strategy - # which doesn't need redistribute. + # If there are strategies with negative/zero/no redistribute cost, + # we record those indices. # TODO: Currently this only applies to OpStrategy selection. Requires extra # logic to make it work for TupleStrategy, if needed. - if op_schema is not None and redistribute_cost == 0: - needs_redistribute = False - for spec_idx, input_spec in enumerate(op_schema.args_spec): - desired_spec = ( - op_spec.output_spec - if op_spec.input_specs is None - else op_spec.input_specs[spec_idx] - ) - if input_spec.placements != desired_spec.placements: - needs_redistribute = True - break + if op_schema is not None: + if guard_or_false(redistribute_cost < 0): + if ( + negative_cost_index == -1 + or redistribute_cost < op_spec_costs[negative_cost_index] + ): + negative_cost_index = strategy_idx + elif guard_or_false(redistribute_cost == 0): + needs_redistribute = False + for spec_idx, input_spec in enumerate(op_schema.args_spec): + desired_spec = ( + op_spec.output_spec + if op_spec.input_specs is None + else op_spec.input_specs[spec_idx] + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + break - if not needs_redistribute: - no_redistribute_strategy_index = strategy_idx + if not needs_redistribute: + no_redistribute_strategy_index = strategy_idx + elif zero_cost_index == -1: + zero_cost_index = strategy_idx - # for eager execution, we just select the one with the minimal redistribute cost - min_cost = min(op_spec_costs) - if min_cost < 0: + # prioritize negative/zero/no redistribute cost strategies + if negative_cost_index != -1: # If there's negative cost, we select the one with the minimal cost, # even if this means we need to redistribute, e.g. via local chunking. # E.g. this can happen for ops in self.op_to_shape_and_stride_idx # when the inputs / outputs are sharded. - selected_strategy_index = op_spec_costs.index(min_cost) - elif min_cost == 0 and no_redistribute_strategy_index != -1: - # If there's no redistribute cost, we select the one with no redistribute. + selected_strategy_index = negative_cost_index + elif no_redistribute_strategy_index != -1: selected_strategy_index = no_redistribute_strategy_index + elif zero_cost_index != -1: + selected_strategy_index = zero_cost_index else: + # default to choosing minimal redistribute cost + min_cost = min(op_spec_costs) selected_strategy_index = op_spec_costs.index(min_cost) return strategy.strategies[selected_strategy_index] diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 77b2681055c44..56ffc77c23b08 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1192,35 +1192,13 @@ def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr: if pending is None: pending = set() r = {} - if isinstance(a, (tuple, list)): - # NB: real is apparently not always a tuple/list here - # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu - for i in range(len(a)): - r.update( - go( - a[i], - path + (pytree.SequenceKey(i),), - real=real[i] if real is not None else None, # type: ignore[index] - ) - ) - elif is_traceable_wrapper_subclass(a): - # TODO: Determine if this is correct - attrs, _ = a.__tensor_flatten__() - for attr in attrs: - sub = getattr(a, attr) - r.update(go(sub, path + (InnerTensorKey(attr),))) - elif isinstance(a, torch.Tensor) and is_batchedtensor(a): - unwrapped_tensor = get_unwrapped(a) - r.update(go(unwrapped_tensor, path)) - elif isinstance(a, torch.Tensor) and not is_batchedtensor(a): - from torch._subclasses.fake_tensor import FakeTensor - assert isinstance(a, FakeTensor) + def match_tensor(a: torch.Tensor, real_tensor: Optional[torch.Tensor] = None): r.update( go( a.size(), path + (CallMethodKey("size"),), - real=a.real_tensor.size() if a.real_tensor is not None else None, + real=real_tensor.size() if real_tensor is not None else None, ) ) if a.layout not in [ @@ -1233,7 +1211,7 @@ def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr: go( a.stride(), path + (CallMethodKey("stride"),), - real=a.real_tensor.stride() if a.real_tensor is not None else None, + real=real_tensor.stride() if real_tensor is not None else None, ) ) r.update( @@ -1241,13 +1219,42 @@ def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr: a.storage_offset(), path + (CallMethodKey("storage_offset"),), real=( - a.real_tensor.storage_offset() - if a.real_tensor is not None - else None + real_tensor.storage_offset() if real_tensor is not None else None ), ) ) + if isinstance(a, (tuple, list)): + # NB: real is apparently not always a tuple/list here + # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu + for i in range(len(a)): + r.update( + go( + a[i], + path + (pytree.SequenceKey(i),), + real=real[i] if real is not None else None, # type: ignore[index] + ) + ) + elif is_traceable_wrapper_subclass(a): + # TODO: Determine if this is correct + attrs, _ = a.__tensor_flatten__() + for attr in attrs: + sub = getattr(a, attr) + r.update(go(sub, path + (InnerTensorKey(attr),))) + + # match DTensor outer shapes + if torch.distributed.is_available() and isinstance( + a, torch.distributed.tensor.DTensor + ): + match_tensor(a) + elif isinstance(a, torch.Tensor) and is_batchedtensor(a): + unwrapped_tensor = get_unwrapped(a) + r.update(go(unwrapped_tensor, path)) + elif isinstance(a, torch.Tensor) and not is_batchedtensor(a): + from torch._subclasses.fake_tensor import FakeTensor + + assert isinstance(a, FakeTensor) + match_tensor(a, a.real_tensor) elif ( isinstance(a, (torch.SymInt, torch.SymFloat)) and isinstance(s := expr(a), sympy.Symbol) From 8b683e50e51dbf9d18b07c746330abafba644b05 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:40:51 +0000 Subject: [PATCH 318/338] [BE] Delete `install_vision` from Docker builds (#169609) Caffe2 used to have OpenCV integration path, but this is not the case for PyTorch Pull Request resolved: https://github.com/pytorch/pytorch/pull/169609 Approved by: https://github.com/jathu, https://github.com/atalman --- .ci/docker/build.sh | 20 ------------- .ci/docker/centos-rocm/Dockerfile | 7 ----- .ci/docker/common/install_vision.sh | 46 ----------------------------- .ci/docker/ubuntu-rocm/Dockerfile | 7 ----- .ci/docker/ubuntu-xpu/Dockerfile | 7 ----- .ci/docker/ubuntu/Dockerfile | 7 ----- 6 files changed, 94 deletions(-) delete mode 100755 .ci/docker/common/install_vision.sh diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 0e8caf69b3192..18979052c875c 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -98,7 +98,6 @@ case "$tag" in CUDA_VERSION=12.4 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -108,7 +107,6 @@ case "$tag" in CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -119,7 +117,6 @@ case "$tag" in CUDA_VERSION=13.0.0 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -129,7 +126,6 @@ case "$tag" in CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -140,7 +136,6 @@ case "$tag" in CUDA_VERSION=13.0.2 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -151,7 +146,6 @@ case "$tag" in CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 - VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -160,25 +154,21 @@ case "$tag" in pytorch-linux-jammy-py3-clang12-onnx) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=12 - VISION=yes ONNX=yes ;; pytorch-linux-jammy-py3.10-clang12) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=12 - VISION=yes TRITON=yes ;; pytorch-linux-jammy-py3.11-clang12) ANACONDA_PYTHON_VERSION=3.11 CLANG_VERSION=12 - VISION=no TRITON=no ;; pytorch-linux-jammy-py3.12-clang12) ANACONDA_PYTHON_VERSION=3.12 CLANG_VERSION=12 - VISION=no TRITON=no ;; pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3) @@ -188,7 +178,6 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.12 fi GCC_VERSION=11 - VISION=yes ROCM_VERSION=7.1 NINJA_VERSION=1.9.0 TRITON=yes @@ -222,7 +211,6 @@ case "$tag" in pytorch-linux-jammy-py3-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes TRITON=yes DOCS=yes @@ -232,18 +220,15 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 CUDA_VERSION=12.8.1 CLANG_VERSION=12 - VISION=yes TRITON=yes ;; pytorch-linux-jammy-py3-clang18-asan) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=18 - VISION=yes ;; pytorch-linux-jammy-py3.10-gcc11) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes TRITON=yes DOCS=yes @@ -285,7 +270,6 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 ACL=yes - VISION=yes OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific @@ -295,7 +279,6 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=21 ACL=yes - VISION=yes OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific @@ -305,7 +288,6 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 ACL=yes - VISION=yes OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific @@ -317,7 +299,6 @@ case "$tag" in ;; *) # Catch-all for builds that are not hardcoded. - VISION=yes echo "image '$image' did not match an existing build configuration" if [[ "$image" == *py* ]]; then extract_version_from_image_name py ANACONDA_PYTHON_VERSION @@ -366,7 +347,6 @@ docker build \ ${progress_flag} \ --build-arg "BUILD_ENVIRONMENT=${image}" \ --build-arg "LLVMDEV=${LLVMDEV:-}" \ - --build-arg "VISION=${VISION:-}" \ --build-arg "UBUNTU_VERSION=${UBUNTU_VERSION}" \ --build-arg "DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}" \ --build-arg "GLIBC_VERSION=${GLIBC_VERSION}" \ diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index 319765590fc02..bf10142db3a56 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -47,13 +47,6 @@ COPY ./common/install_conda.sh install_conda.sh COPY ./common/common_utils.sh common_utils.sh RUN bash ./install_conda.sh && rm install_conda.sh common_utils.sh /opt/conda/requirements-ci.txt -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - # Install rocm ARG ROCM_VERSION RUN mkdir ci_commit_pins diff --git a/.ci/docker/common/install_vision.sh b/.ci/docker/common/install_vision.sh deleted file mode 100755 index 78c445568ddcd..0000000000000 --- a/.ci/docker/common/install_vision.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -set -ex - -install_ubuntu() { - apt-get update - apt-get install -y --no-install-recommends \ - libopencv-dev - - # Cleanup - apt-get autoclean && apt-get clean - rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -} - -install_centos() { - # Need EPEL for many packages we depend on. - # See http://fedoraproject.org/wiki/EPEL - yum --enablerepo=extras install -y epel-release - - yum install -y \ - opencv-devel - - # Cleanup - yum clean all - rm -rf /var/cache/yum - rm -rf /var/lib/yum/yumdb - rm -rf /var/lib/yum/history -} - -# Install base packages depending on the base OS -ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') -case "$ID" in - ubuntu) - install_ubuntu - ;; - centos) - install_centos - ;; - *) - echo "Unable to determine OS..." - exit 1 - ;; -esac - -# Cache vision models used by the test -source "$(dirname "${BASH_SOURCE[0]}")/cache_vision_models.sh" diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index b517a990a057b..50f814fb2dff9 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -43,13 +43,6 @@ ARG CLANG_VERSION COPY ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - # Install rocm ARG ROCM_VERSION RUN mkdir ci_commit_pins diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index af11992a91646..f5db20a35945a 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -79,13 +79,6 @@ COPY triton_xpu_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-xpu.txt triton_version.txt -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - # (optional) Install non-default Ninja version ARG NINJA_VERSION COPY ./common/install_ninja.sh install_ninja.sh diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 2081dcbdffd17..a50cdb0506ed2 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -75,13 +75,6 @@ ADD ./common/install_ucc.sh install_ucc.sh RUN if [ -n "${UCX_COMMIT}" ] && [ -n "${UCC_COMMIT}" ]; then bash ./install_ucc.sh; fi RUN rm install_ucc.sh -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - # (optional) Install non-default Ninja version ARG NINJA_VERSION COPY ./common/install_ninja.sh install_ninja.sh From 4cd5afa836991795fd233e02021d8effc607d699 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Fri, 5 Dec 2025 04:36:04 +0000 Subject: [PATCH 319/338] [vllm hash update] update the pinned vllm hash (#165274) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vllm hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165274 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vllm.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt index 45ad7752358c9..fe05273efd400 100644 --- a/.github/ci_commit_pins/vllm.txt +++ b/.github/ci_commit_pins/vllm.txt @@ -1 +1 @@ -e5192819208c4d68194844b7dfafbc00020d0dea +bcf43ab1f380208ea33769c49d116ea83f915080 From a573c495c20e60e472a1b836dbc09cd99658dbc2 Mon Sep 17 00:00:00 2001 From: karthickai Date: Thu, 4 Dec 2025 12:14:54 -0800 Subject: [PATCH 320/338] [Inductor] Add debug output for specific pattern matching (#169603) Fixes: debug logging part of the issue #169440 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169603 Approved by: https://github.com/ProExpertProg, https://github.com/zou3519 --- test/inductor/test_pattern_matcher.py | 44 +++++++++++++++++++++++++++ torch/_inductor/pattern_matcher.py | 9 ++++++ 2 files changed, 53 insertions(+) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 9928b89b81e64..8f3fdb19a99ad 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -43,6 +43,7 @@ skipIfRocm, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test from torch.utils import _pytree as pytree @@ -1963,6 +1964,49 @@ def fn_replaced(x): self.assertEqual(fn_result, fn_replaced_result) +class TestPatternMatcherLogging(LoggingTestCase): + device_type = GPU_TYPE + + @make_logging_test() + def test_pattern_match_debug_output(self, records): + def pattern(x, y): + return x + y + + def replacement(x, y): + return x * y + + my_patterns = PatternMatcherPass() + inputs = [ + torch.randn(4, 4, device=GPU_TYPE), + torch.randn(4, 4, device=GPU_TYPE), + ] + register_replacement(pattern, replacement, inputs, fwd_only, my_patterns) + + def custom_pass(graph: torch.fx.Graph): + return my_patterns.apply(graph) + + def fn(x, y): + return x + y + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + + with unittest.mock.patch.dict( + os.environ, {"TORCHINDUCTOR_PATTERN_MATCH_DEBUG": "add"} + ): + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + result = compiled_fn(x, y) + self.assertEqual(result, x * y) + + specific_record = self.getRecord(records, "Specific pattern match") + self.assertIn( + "Match(..., [], {'x': arg0_1, 'y': arg1_1})", specific_record.getMessage() + ) + self.assertIn("add(arg0_1, arg1_1)", specific_record.getMessage()) + + if __name__ == "__main__": if IS_LINUX and HAS_GPU: run_tests() diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index c015c5232adf3..af071a55a23fc 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1553,6 +1553,15 @@ def search_fn_new(*args_new: Any) -> Any: assert node is not None specific_pattern_match = specific_pattern.match(node) + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning( + "Specific pattern match: %s%s %s %s", + node, + node.args, + specific_pattern_match, + specific_pattern, + ) + if is_match(specific_pattern_match) and extra_check(specific_pattern_match): # trace the pattern using the shapes from the user program match.replacement_graph = trace_fn(replace_fn, args) From 17ec1c3cc14aa87a487ca157d33974745947e40f Mon Sep 17 00:00:00 2001 From: karthickai Date: Thu, 4 Dec 2025 10:30:52 -0800 Subject: [PATCH 321/338] [Inductor] Fix combo kernels by populating constants for equal_to_1 args (#168127) Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168127 Approved by: https://github.com/mlazos ghstack dependencies: #167781 --- test/inductor/test_torchinductor.py | 57 ++++++++++--------- ...st_torchinductor_codegen_dynamic_shapes.py | 5 +- .../_inductor/codegen/triton_combo_kernel.py | 6 +- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 16d2fc706fb6e..6cae8f568d87a 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -9220,38 +9220,41 @@ def fn(a): self.assertFalse(torch.allclose(a0, a1)) self.assertFalse(torch.allclose(a1, a2)) - def test_rand_like_deterministic(self): - @torch.compile(backend="inductor") - def fn(a): - return torch.rand_like(a), torch.rand_like(a) + @parametrize("combo_kernels", (False, True)) + def test_rand_like_deterministic(self, combo_kernels): + with config.patch(combo_kernels=combo_kernels): - x = torch.ones(1024, device=self.device, dtype=torch.float32) + @torch.compile(backend="inductor") + def fn(a): + return torch.rand_like(a), torch.rand_like(a) - torch.manual_seed(1234) - a0 = fn(x)[0].clone() - a1 = fn(x)[0].clone() - a2 = fn(x)[0].clone() + x = torch.ones(1024, device=self.device, dtype=torch.float32) - torch.manual_seed(1234) - b0 = fn(x)[0].clone() - b1 = fn(x)[0].clone() - b2 = fn(x)[0].clone() + torch.manual_seed(1234) + a0 = fn(x)[0].clone() + a1 = fn(x)[0].clone() + a2 = fn(x)[0].clone() - # same seed, same values - self.assertTrue(torch.allclose(a0, b0)) - self.assertTrue(torch.allclose(a1, b1)) - self.assertTrue(torch.allclose(a2, b2)) + torch.manual_seed(1234) + b0 = fn(x)[0].clone() + b1 = fn(x)[0].clone() + b2 = fn(x)[0].clone() - # different calls, different values - self.assertFalse(torch.allclose(a0, a1)) - self.assertFalse(torch.allclose(a1, a2)) - - c, d = fn(x) - self.assertFalse(torch.allclose(c, d)) - self.assertTrue((c >= 0).all()) - self.assertTrue((c < 1).all()) - self.assertTrue((d >= 0).all()) - self.assertTrue((d < 1).all()) + # same seed, same values + self.assertTrue(torch.allclose(a0, b0)) + self.assertTrue(torch.allclose(a1, b1)) + self.assertTrue(torch.allclose(a2, b2)) + + # different calls, different values + self.assertFalse(torch.allclose(a0, a1)) + self.assertFalse(torch.allclose(a1, a2)) + + c, d = fn(x) + self.assertFalse(torch.allclose(c, d)) + self.assertTrue((c >= 0).all()) + self.assertTrue((c < 1).all()) + self.assertTrue((d >= 0).all()) + self.assertTrue((d < 1).all()) @config.patch(implicit_fallbacks=True) def test_needs_contiguous_strides(self): diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index e73f82ab64911..edd18519e1d2e 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -367,7 +367,10 @@ def run(*ex, **kwargs): "test_profiler_mark_wrapper_call_dynamic_shapes": TestFailure( ("cpu", "cuda", "xpu"), is_skip=True ), - "test_rand_like_deterministic_dynamic_shapes": TestFailure( + "test_rand_like_deterministic_combo_kernels_False_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_rand_like_deterministic_combo_kernels_True_dynamic_shapes": TestFailure( ("cpu", "cuda", "xpu"), is_skip=True ), "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu",)), diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 41b12d05cd32e..6edf0e2decb0c 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -36,7 +36,7 @@ from .simd import prefix_is_reduction, SIMDScheduling from .simd_kernel_features import SIMDKernelFeatures from .triton import gen_common_triton_imports, TritonKernel -from .triton_utils import config_of, signature_to_meta +from .triton_utils import config_of, equal_1_arg_indices, signature_to_meta log = logging.getLogger(__name__) @@ -610,6 +610,10 @@ def jit_line( "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), "constants": {}, } + + for arg_num in equal_1_arg_indices(signature): + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + # pyrefly: ignore [unsupported-operation] triton_meta["configs"] = [config_of(signature)] mutated_args = self.get_mutated_args_sub_kernels() From 5f76830b72871bed56f9edde79fe7e5b767512a1 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Fri, 5 Dec 2025 06:01:52 +0000 Subject: [PATCH 322/338] [ROCm] unskip some ROCm inductor UTs (#169564) Fixes https://github.com/pytorch/pytorch/issues/168478 Fixes https://github.com/pytorch/pytorch/issues/168557 Fixes https://github.com/pytorch/pytorch/issues/168573 Fixes https://github.com/pytorch/pytorch/issues/168581 Fixes https://github.com/pytorch/pytorch/issues/168586 Fixes https://github.com/pytorch/pytorch/issues/168625 Fixes https://github.com/pytorch/pytorch/issues/168647 Fixes https://github.com/pytorch/pytorch/issues/168649 Fixes https://github.com/pytorch/pytorch/issues/168672 Fixes https://github.com/pytorch/pytorch/issues/168676 Fixes https://github.com/pytorch/pytorch/issues/168677 Fixes https://github.com/pytorch/pytorch/issues/168678 Fixes https://github.com/pytorch/pytorch/issues/168679 Fixes https://github.com/pytorch/pytorch/issues/168684 Fixes https://github.com/pytorch/pytorch/issues/168683 Fixes https://github.com/pytorch/pytorch/issues/168681 Unskip some UTs Pull Request resolved: https://github.com/pytorch/pytorch/pull/169564 Approved by: https://github.com/jeffdaily --- test/distributed/test_inductor_collectives.py | 3 -- test/inductor/test_analysis.py | 35 ++++++------------- test/inductor/test_codecache.py | 35 ++++++++++++++++--- test/inductor/test_cuda_repro.py | 3 -- test/inductor/test_device_assert.py | 3 -- test/inductor/test_mkldnn_pattern_matcher.py | 9 ----- test/inductor/test_pad_mm.py | 2 -- test/inductor/test_pattern_matcher.py | 2 -- test/inductor/test_subgraph_choice.py | 5 +-- test/inductor/test_torchinductor.py | 6 ---- 10 files changed, 42 insertions(+), 61 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 4be02cbafbe1f..3a54e8c5fb1ac 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -52,7 +52,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - skipIfRocm, skipIfXpu, TEST_XPU, xfailIf, @@ -276,8 +275,6 @@ def compile(func, example_inputs): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1728 - @skipIfRocm - @xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1728 def test_eager_async_allreduce_inductor_wait(self): import torch.distributed as dist from torch._inductor.utils import run_and_get_code diff --git a/test/inductor/test_analysis.py b/test/inductor/test_analysis.py index 147760fe4df67..0731edd9c5ab2 100644 --- a/test/inductor/test_analysis.py +++ b/test/inductor/test_analysis.py @@ -274,11 +274,13 @@ def test_zip_dicts(self): self.assertEqual(set(res2), {("a", 1, 3), ("b", 2, None), ("c", None, 4)}) +def has_supported_gpu(): + """Check if any GPU platform with Triton support is available.""" + return torch.xpu.is_available() or SM80OrLater or torch.version.hip + + class TestAnalysis(TestCase): - @skipIf( - (not torch.xpu.is_available()) and (not SM80OrLater), - "Requires XPU or CUDA SM80", - ) + @skipIf(not has_supported_gpu(), "Requires XPU, CUDA SM80+, or ROCm") def test_noop(self): with ( patch("sys.stdout", new_callable=StringIO) as mock_stdout, @@ -287,10 +289,7 @@ def test_noop(self): main() self.assertEqual(mock_stdout.getvalue(), "") - @skipIf( - (not torch.xpu.is_available()) and (not SM80OrLater), - "Requires XPU or CUDA SM80", - ) + @skipIf(not has_supported_gpu(), "Requires XPU, CUDA SM80+, or ROCm") @dtypes(torch.float, torch.double, torch.float16) def test_diff(self, device, dtype): """ @@ -341,10 +340,7 @@ def test_augment_trace_helper_unit(self): expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0] verify_flops(self, expected_flops, out_profile) - @skipIf( - (not torch.xpu.is_available()) and (not SM80OrLater), - "Requires XPU or CUDA SM80", - ) + @skipIf(not has_supported_gpu(), "Requires XPU, CUDA SM80+, or ROCm") @skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU") @dtypes(torch.float, torch.double, torch.float16) @parametrize( @@ -399,10 +395,7 @@ def verify_triton(comp): verify_triton(comp_omni) - @skipIf( - (not torch.xpu.is_available()) and (not SM80OrLater), - "Requires XPU or CUDA SM80", - ) + @skipIf(not has_supported_gpu(), "Requires XPU, CUDA SM80+, or ROCm") @skipIfXpu( msg="Intel triton issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/5491" ) @@ -518,10 +511,7 @@ def test_augment_trace_against_flop_counter(self, device, dtype, maxat): self.assertTrue(seen_baddbmm) self.assertTrue(seen_conv) - @skipIf( - (not torch.xpu.is_available()) and (not SM80OrLater), - "Requires XPU or CUDA SM80", - ) + @skipIf(not has_supported_gpu(), "Requires XPU, CUDA SM80+, or ROCm") @skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU") @dtypes(torch.float, torch.float16) @parametrize( @@ -572,10 +562,7 @@ def test_pointwise_bandwidth(self, device, dtype, maxat): if event["name"] == "triton_poi_fused_add_randn_sin_0": event["args"]["kernel_num_gb"] = 0.002097168 - @skipIf( - (not torch.xpu.is_available()) and (not SM80OrLater), - "Requires XPU or CUDA SM80", - ) + @skipIf(not has_supported_gpu(), "Requires XPU, CUDA SM80+, or ROCm") @dtypes(torch.float, torch.float16) def test_combine_profiles(self, device, dtype): """ diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 1ab261051f4c6..9bab2bb970c55 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -290,7 +290,12 @@ def test_cache_load_function( """ if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if ( + device == "cuda" + and torch.version.hip is None + and dtype == torch.bfloat16 + and not SM80OrLater + ): raise unittest.SkipTest("requires SM80 or later") if use_static_cuda_launcher and not (device == "cuda" and bundle_triton): raise unittest.SkipTest( @@ -542,7 +547,12 @@ def test_cache_hot_load(self, device, dtype, dynamic): """ if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if ( + device == "cuda" + and torch.version.hip is None + and dtype == torch.bfloat16 + and not SM80OrLater + ): raise unittest.SkipTest("requires SM80 or later") def fn(x, y): @@ -634,7 +644,12 @@ def test_cache_hot_load_caching_precompile(self, device, dtype, dynamic): if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if ( + device == "cuda" + and torch.version.hip is None + and dtype == torch.bfloat16 + and not SM80OrLater + ): raise unittest.SkipTest("requires SM80 or later") def fn(x, y): @@ -1003,7 +1018,12 @@ def test_cache_load_with_guards_int32_bounds(self, device, dtype): """ if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if ( + device == "cuda" + and torch.version.hip is None + and dtype == torch.bfloat16 + and not SM80OrLater + ): raise unittest.SkipTest("requires CUDA SM80 or later") def fn(x, y): @@ -1052,7 +1072,12 @@ def test_cache_load_with_guards_static_bounds(self, device, dtype): """ if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if ( + device == "cuda" + and torch.version.hip is None + and dtype == torch.bfloat16 + and not SM80OrLater + ): raise unittest.SkipTest("requires SM80 or later") # See lowering; for all of the pooling operators, we always guard and diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 3cd2900051943..eff8c8937deb2 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -40,9 +40,7 @@ freeze_rng_state, instantiate_parametrized_tests, IS_FBCODE, - MI350_ARCH, parametrize, - skipIfRocmArch, TEST_WITH_ASAN, TEST_WITH_ROCM, xfailIfPy312Plus, @@ -223,7 +221,6 @@ def fn( # dont check rng state self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2]) - @skipIfRocmArch(MI350_ARCH) def test_effn_attn_bias_padding_misaligned(self): seqlen_start = 1008 diff --git a/test/inductor/test_device_assert.py b/test/inductor/test_device_assert.py index c5dfd8de26f0b..cbeb7960f2f55 100644 --- a/test/inductor/test_device_assert.py +++ b/test/inductor/test_device_assert.py @@ -8,7 +8,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - skipIfRocm, ) from torch.testing._internal.triton_utils import requires_gpu_and_triton @@ -59,7 +58,6 @@ def func_inline(): f_c() @requires_gpu_and_triton - @skipIfRocm @torch._inductor.config.patch(force_disable_caches=True) def test_assert_fusion(self): torch._logging.set_logs(inductor_metrics=True) @@ -78,7 +76,6 @@ def func(): torch._logging.set_logs() @requires_gpu_and_triton - @skipIfRocm @torch._inductor.config.patch(force_disable_caches=True) def test_run_assert_triton(self): @torch.compile(backend="inductor") diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 1001a8a9f997a..3626dd17301db 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -362,7 +362,6 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @reduced_f32_on_and_off() def test_conv2d_unary(self, device): self.device = device @@ -370,7 +369,6 @@ def test_conv2d_unary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @reduced_f32_on_and_off() def test_conv3d_unary(self, device): self.device = device @@ -451,7 +449,6 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) @@ -462,7 +459,6 @@ def test_conv_transpose2d_unary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) @@ -560,7 +556,6 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @reduced_f32_on_and_off(0.02) def test_conv2d_binary(self, device): self.device = device @@ -568,7 +563,6 @@ def test_conv2d_binary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @reduced_f32_on_and_off(0.02) def test_conv3d_binary(self, device): self.device = device @@ -668,7 +662,6 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @reduced_f32_on_and_off() def test_conv2d_binary_broadcast_shapes(self, device): self.device = device @@ -676,7 +669,6 @@ def test_conv2d_binary_broadcast_shapes(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @reduced_f32_on_and_off(bf32_precision=5e-2) def test_conv3d_binary_broadcast_shapes(self, device): self.device = device @@ -684,7 +676,6 @@ def test_conv3d_binary_broadcast_shapes(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skipIfRocm @unittest.skipIf(IS_FBCODE, "Failing in fbcode") @reduced_f32_on_and_off() def test_conv2d_linear_add_broadcast_shapes(self, device): diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index c61434427f535..004855606cce0 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -15,7 +15,6 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck -from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON @@ -475,7 +474,6 @@ def mm(inps, b): and (not torch.xpu.is_available()), "No perf regression on H100+ with BF16", ) - @skipIfRocm @fresh_cache() @inductor_config.patch( post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}} diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 8f3fdb19a99ad..30e7f28c45ada 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -40,7 +40,6 @@ instantiate_parametrized_tests, IS_LINUX, parametrize, - skipIfRocm, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test @@ -286,7 +285,6 @@ def fn2(a, b, c): self._test_fused_int_mm_mul_impl(fn1, args, True) self._test_fused_int_mm_mul_impl(fn2, args, True) - @skipIfRocm @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch( { diff --git a/test/inductor/test_subgraph_choice.py b/test/inductor/test_subgraph_choice.py index d2d5a3bf59a9e..408af8d379111 100644 --- a/test/inductor/test_subgraph_choice.py +++ b/test/inductor/test_subgraph_choice.py @@ -1,5 +1,4 @@ # Owner(s): ["module: inductor"] -import unittest from unittest import mock from unittest.mock import MagicMock @@ -8,7 +7,7 @@ from torch._inductor.lowering import register_lowering from torch._inductor.select_algorithm import autotune_select_algorithm from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM +from torch.testing._internal.common_utils import skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU @@ -37,7 +36,6 @@ def _create_buffer(self, name, shape, dtype): ) @skipIfXpu - @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") def test_subgraph_decompose_k(self): from torch._inductor.kernel.mm import aten_mm from torch._inductor.kernel.mm_common import mm_args @@ -98,7 +96,6 @@ def func(mat1, mat2): torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1) @skipIfXpu - @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") def test_subgraph_freeze_layout(self): from torch._inductor.kernel.mm_common import mm_args diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 6cae8f568d87a..f51825c0b0cc0 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2488,7 +2488,6 @@ def fn(a, b_int8pack, b_scales, c): @xfail_if_mps_unimplemented @xfail_if_triton_cpu @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") - @skipIfRocm @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") def test__dyn_quant_pack_4bit_weight_fp32(self): q_group = 32 @@ -2524,7 +2523,6 @@ def fn(b, in_features, out_features): @xfail_if_mps_unimplemented @xfail_if_triton_cpu @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") - @skipIfRocm @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") @skip_if_halide # bf16 def test__dyn_quant_pack_4bit_weight_bf16(self): @@ -2566,7 +2564,6 @@ def fn(b, in_features, out_features): @xfail_if_mps_unimplemented @xfail_if_triton_cpu @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") - @skipIfRocm @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") def test__dyn_quant_matmul_4bit_fp32_input(self): q_group = 32 @@ -2612,7 +2609,6 @@ def fn(a, q_group, in_features, out_features): @xfail_if_mps_unimplemented @xfail_if_triton_cpu @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") - @skipIfRocm @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") @skip_if_halide # bf16 def test__dyn_quant_matmul_4bit_bf16_input(self): @@ -4700,7 +4696,6 @@ def fn(x): check_lowp=False, # cpu doesn't understand fp16, and there are explicit .cpu() calls ) - @skipIfRocm @requires_multigpu() def test_multi_gpu_device(self): # TODO: https://github.com/pytorch/pytorch/issues/92627 @@ -11625,7 +11620,6 @@ def fn_or(x, y): (torch.randn(32), torch.randn(32)), ) - @skipIfRocm def test_conv_with_as_strided(self): class Model(nn.Module): def __init__(self) -> None: From 7375582b1aa31638020d1fd15f3812d7209707c9 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 5 Dec 2025 06:52:20 +0000 Subject: [PATCH 323/338] Fix slotscheck warnings (#169348) This PR fixes some of slotscheck warnings. The some of them are: ``` ERROR: 'torch._inductor.cudagraph_trees:AliasesNewOutput' has slots but superclass does not. ERROR: 'torch._inductor.cudagraph_trees:AliasesPriorGraphOutput' has slots but superclass does not. ERROR: 'torch._subclasses.fake_tensor:_BypassDispatchCache' has slots but superclass does not. ERROR: 'torch.distributed._functional_collectives:AsyncCollectiveTensor' has slots but superclass does not. ERROR: 'torch.distributed.elastic.timer.file_based_local_timer:FileTimerRequest' defines overlapping slots. ERROR: 'torch.distributed.tensor._shards_wrapper:LocalShardsWrapper' has slots but superclass does not. ERROR: 'torch.distributed.tensor:DTensor' has slots but superclass does not. ERROR: 'torch.multiprocessing.spawn:ProcessException' has slots but superclass does not. ERROR: 'torch.package.package_importer:_ModuleNode' has slots but superclass does not. ERROR: 'torch.sparse.semi_structured:SparseSemiStructuredTensor' has slots but superclass does not. ERROR: 'torch.testing._internal.logging_tensor:LoggingTensor' has slots but superclass does not. ``` The fixes work by adding slot to their parent. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169348 Approved by: https://github.com/Skylion007 --- torch/_inductor/cudagraph_trees.py | 2 +- .../elastic/timer/file_based_local_timer.py | 18 ++++++++++-------- torch/package/package_importer.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 98280b5af783c..72d0bcc69e3d0 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -763,7 +763,7 @@ def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: class OutputAliasInfo: - pass + __slots__ = [] class _UnaliasedStorage(OutputAliasInfo): diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 14ec6e6af8537..5855efefcc853 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -68,24 +68,26 @@ class FileTimerRequest(TimerRequest): process. """ - __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"] + __slots__ = ["version", "signal"] def __init__( self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 ) -> None: + super().__init__( + worker_id=worker_pid, scope_id=scope_id, expiration_time=expiration_time + ) self.version = 1 - self.worker_pid = worker_pid - self.scope_id = scope_id - self.expiration_time = expiration_time self.signal = signal + @property + def worker_pid(self) -> int: + return self.worker_id + def __eq__(self, other) -> bool: if isinstance(other, FileTimerRequest): return ( - self.version == other.version - and self.worker_pid == other.worker_pid - and self.scope_id == other.scope_id - and self.expiration_time == other.expiration_time + super().__eq__(other) + and self.version == other.version and self.signal == other.signal ) return False diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 10bf8981e28ae..b564dace63b4a 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -695,7 +695,7 @@ def _add_extern(self, extern_name: str): class _PathNode: - pass + __slots__ = [] class _PackageNode(_PathNode): From 09f18009ea088c383418ee63649a73e197997424 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 5 Dec 2025 06:55:26 +0000 Subject: [PATCH 324/338] Revert "[BE] Delete `install_vision` from Docker builds (#169609)" This reverts commit 8b683e50e51dbf9d18b07c746330abafba644b05. Reverted https://github.com/pytorch/pytorch/pull/169609 on behalf of https://github.com/malfet due to It causes inductor tests to fail ([comment](https://github.com/pytorch/pytorch/pull/169609#issuecomment-3615543740)) --- .ci/docker/build.sh | 20 +++++++++++++ .ci/docker/centos-rocm/Dockerfile | 7 +++++ .ci/docker/common/install_vision.sh | 46 +++++++++++++++++++++++++++++ .ci/docker/ubuntu-rocm/Dockerfile | 7 +++++ .ci/docker/ubuntu-xpu/Dockerfile | 7 +++++ .ci/docker/ubuntu/Dockerfile | 7 +++++ 6 files changed, 94 insertions(+) create mode 100755 .ci/docker/common/install_vision.sh diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 18979052c875c..0e8caf69b3192 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -98,6 +98,7 @@ case "$tag" in CUDA_VERSION=12.4 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 + VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -107,6 +108,7 @@ case "$tag" in CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 + VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -117,6 +119,7 @@ case "$tag" in CUDA_VERSION=13.0.0 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 + VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -126,6 +129,7 @@ case "$tag" in CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 + VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -136,6 +140,7 @@ case "$tag" in CUDA_VERSION=13.0.2 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 + VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -146,6 +151,7 @@ case "$tag" in CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 + VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} @@ -154,21 +160,25 @@ case "$tag" in pytorch-linux-jammy-py3-clang12-onnx) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=12 + VISION=yes ONNX=yes ;; pytorch-linux-jammy-py3.10-clang12) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=12 + VISION=yes TRITON=yes ;; pytorch-linux-jammy-py3.11-clang12) ANACONDA_PYTHON_VERSION=3.11 CLANG_VERSION=12 + VISION=no TRITON=no ;; pytorch-linux-jammy-py3.12-clang12) ANACONDA_PYTHON_VERSION=3.12 CLANG_VERSION=12 + VISION=no TRITON=no ;; pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3) @@ -178,6 +188,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.12 fi GCC_VERSION=11 + VISION=yes ROCM_VERSION=7.1 NINJA_VERSION=1.9.0 TRITON=yes @@ -211,6 +222,7 @@ case "$tag" in pytorch-linux-jammy-py3-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 + VISION=yes KATEX=yes TRITON=yes DOCS=yes @@ -220,15 +232,18 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 CUDA_VERSION=12.8.1 CLANG_VERSION=12 + VISION=yes TRITON=yes ;; pytorch-linux-jammy-py3-clang18-asan) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=18 + VISION=yes ;; pytorch-linux-jammy-py3.10-gcc11) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 + VISION=yes KATEX=yes TRITON=yes DOCS=yes @@ -270,6 +285,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 ACL=yes + VISION=yes OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific @@ -279,6 +295,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=21 ACL=yes + VISION=yes OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific @@ -288,6 +305,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 ACL=yes + VISION=yes OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific @@ -299,6 +317,7 @@ case "$tag" in ;; *) # Catch-all for builds that are not hardcoded. + VISION=yes echo "image '$image' did not match an existing build configuration" if [[ "$image" == *py* ]]; then extract_version_from_image_name py ANACONDA_PYTHON_VERSION @@ -347,6 +366,7 @@ docker build \ ${progress_flag} \ --build-arg "BUILD_ENVIRONMENT=${image}" \ --build-arg "LLVMDEV=${LLVMDEV:-}" \ + --build-arg "VISION=${VISION:-}" \ --build-arg "UBUNTU_VERSION=${UBUNTU_VERSION}" \ --build-arg "DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}" \ --build-arg "GLIBC_VERSION=${GLIBC_VERSION}" \ diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index bf10142db3a56..319765590fc02 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -47,6 +47,13 @@ COPY ./common/install_conda.sh install_conda.sh COPY ./common/common_utils.sh common_utils.sh RUN bash ./install_conda.sh && rm install_conda.sh common_utils.sh /opt/conda/requirements-ci.txt +# (optional) Install vision packages like OpenCV +ARG VISION +COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ +RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi +RUN rm install_vision.sh cache_vision_models.sh common_utils.sh +ENV INSTALLED_VISION ${VISION} + # Install rocm ARG ROCM_VERSION RUN mkdir ci_commit_pins diff --git a/.ci/docker/common/install_vision.sh b/.ci/docker/common/install_vision.sh new file mode 100755 index 0000000000000..78c445568ddcd --- /dev/null +++ b/.ci/docker/common/install_vision.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +set -ex + +install_ubuntu() { + apt-get update + apt-get install -y --no-install-recommends \ + libopencv-dev + + # Cleanup + apt-get autoclean && apt-get clean + rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* +} + +install_centos() { + # Need EPEL for many packages we depend on. + # See http://fedoraproject.org/wiki/EPEL + yum --enablerepo=extras install -y epel-release + + yum install -y \ + opencv-devel + + # Cleanup + yum clean all + rm -rf /var/cache/yum + rm -rf /var/lib/yum/yumdb + rm -rf /var/lib/yum/history +} + +# Install base packages depending on the base OS +ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') +case "$ID" in + ubuntu) + install_ubuntu + ;; + centos) + install_centos + ;; + *) + echo "Unable to determine OS..." + exit 1 + ;; +esac + +# Cache vision models used by the test +source "$(dirname "${BASH_SOURCE[0]}")/cache_vision_models.sh" diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 50f814fb2dff9..b517a990a057b 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -43,6 +43,13 @@ ARG CLANG_VERSION COPY ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh +# (optional) Install vision packages like OpenCV +ARG VISION +COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ +RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi +RUN rm install_vision.sh cache_vision_models.sh common_utils.sh +ENV INSTALLED_VISION ${VISION} + # Install rocm ARG ROCM_VERSION RUN mkdir ci_commit_pins diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index f5db20a35945a..af11992a91646 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -79,6 +79,13 @@ COPY triton_xpu_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-xpu.txt triton_version.txt +# (optional) Install vision packages like OpenCV +ARG VISION +COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ +RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi +RUN rm install_vision.sh cache_vision_models.sh common_utils.sh +ENV INSTALLED_VISION ${VISION} + # (optional) Install non-default Ninja version ARG NINJA_VERSION COPY ./common/install_ninja.sh install_ninja.sh diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index a50cdb0506ed2..2081dcbdffd17 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -75,6 +75,13 @@ ADD ./common/install_ucc.sh install_ucc.sh RUN if [ -n "${UCX_COMMIT}" ] && [ -n "${UCC_COMMIT}" ]; then bash ./install_ucc.sh; fi RUN rm install_ucc.sh +# (optional) Install vision packages like OpenCV +ARG VISION +COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ +RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi +RUN rm install_vision.sh cache_vision_models.sh common_utils.sh +ENV INSTALLED_VISION ${VISION} + # (optional) Install non-default Ninja version ARG NINJA_VERSION COPY ./common/install_ninja.sh install_ninja.sh From a0a2ae660d112b793772f4a834b8f6f91c722035 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Fri, 5 Dec 2025 00:29:41 +0000 Subject: [PATCH 325/338] [Dynamo][Guards]Fix TLParse CPP guard message with sorting get_leaf_guards and verbose_code_parts (#169102) Fix #168379. 1. The results are validated in the improved testing that the ``___dict_contains`` will be sorted based on the verbose part. The first solution was also suggested in https://fb.workplace.com/groups/1075192433118967/permalink/1650742858897252/ by sorting the ``get_leaf_guards()`` in ``construct_manager_string``. 2. The second solution will be adopted the ``OrderedSet`` in setGuards during guards construction to make sure the ``contain_dict`` are displayed as the order of being added. We decided to pursuit the second options to reduce the sorting time overhead and simplicity. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169102 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 31 +++++++++++++++---------------- torch/_dynamo/guards.py | 4 ++-- torch/_guards.py | 24 +++++++++++++----------- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b3e2e9d4fee4d..5da86066b977b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1225,30 +1225,29 @@ def fn(x, y): # Filter out id-matches that won't reproduce run to run guard_code = filter( lambda line: "id" not in line and "lookup_backend" not in line, - sorted(guard_code), + guard_code, ) guard_code_str = "\n".join(guard_code) - for line in """\ -2 <= L['x'].size()[0] -L['x'] is L['y'] -L['x'].ndimension() == 2 -L['x'].requires_grad == False + # Make sure that the dict_contains are present in the order of added + self.assertExpectedInline( + guard_code_str, + """\ L['x'].size()[1] == L['x'].size()[0] L['x'].storage_offset() == 0 -___dict_contains('operator', G['sys'].modules) -___dict_contains('operator', G['sys'].modules) +2 <= L['x'].size()[0] +utils_device.CURRENT_DEVICE == None +str(L['x'].dtype) == 'torch.float32' +str(L['x'].device) == 'cpu' +L['x'].requires_grad == False +L['x'].ndimension() == 2 hasattr(L['x'], '_dynamo_dynamic_indices') == False +L['x'] is L['y'] not ___dict_contains('aaaaaaaa', G['sys'].modules) not ___dict_contains('bbbbbbbb', G['sys'].modules) -not ___dict_contains('cccccccc', G['sys'].modules) -str(L['x'].device) == 'cpu' -str(L['x'].dtype) == 'torch.float32' -utils_device.CURRENT_DEVICE == None""".split("\n"): - self.assertIn( - line, - guard_code_str, - ) +___dict_contains('operator', G['sys'].modules) +not ___dict_contains('cccccccc', G['sys'].modules)""", + ) def test_fold(self): def fn(a): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index ea720d5c49f5f..a30e509e72e47 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -3893,7 +3893,7 @@ def _ref(x: Any) -> Any: }, global_scope=global_scope_state, _guards=torch._guards.GuardsSet( - { + OrderedSet( dataclasses.replace( guard, obj_weakref=None, @@ -3901,7 +3901,7 @@ def _ref(x: Any) -> Any: create_fn=normalize_create_fn(guard.create_fn), ) for guard in sorted_guards - } + ) ), input_source_to_sizes_strides=pytree.tree_map( convert_int_to_concrete_values, diff --git a/torch/_guards.py b/torch/_guards.py index 2f5b41527478b..e5efcfed17a6b 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -15,7 +15,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, overload, TYPE_CHECKING, TypeVar +from typing import Any, Generic, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar if sys.version_info >= (3, 11): @@ -31,6 +31,7 @@ def decorator(fn): import torch from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._traceback import CapturedTraceback, format_frame from torch.utils.weak import WeakTensorKeyDictionary @@ -500,16 +501,16 @@ class GuardsCheckpointState: The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext """ - dynamo_guards: set[Guard] = set() + dynamo_guards: OrderedSet[Guard] - def __init__(self, dynamo_guards: set[Guard]) -> None: + def __init__(self, dynamo_guards: OrderedSet[Guard]) -> None: self.dynamo_guards = dynamo_guards - def diff(self, other: GuardsCheckpointState) -> set[Guard] | None: + def diff(self, other: GuardsCheckpointState) -> Optional[OrderedSet[Guard]]: """ Produces a delta against another GuardsCheckpointState. - Returns None if no delta is found, otherwise, return a set() of mismatched + Returns None if no delta is found, otherwise, return an OrderedSet() of mismatched Guard type objects. """ r = self.dynamo_guards.difference(other.dynamo_guards) @@ -618,10 +619,11 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: # Like a Set[Guard] but will record the user stack on all guards at the # time they were installed at their destination class GuardsSet: - def __init__(self, inner: set[Guard] | None = None) -> None: + def __init__(self, inner: Optional[OrderedSet[Guard]] = None) -> None: if inner is None: - inner = set() - self.inner = inner + self.inner: OrderedSet[Guard] = OrderedSet() + else: + self.inner = inner def __iter__(self) -> Iterator[Guard]: return iter(self.inner) @@ -658,9 +660,9 @@ def remove_guards_with_source(self, source: Source) -> None: """Delete all guards that contains a given source""" from ._dynamo.source import is_from_source - self.inner = { + self.inner = OrderedSet( g for g in self.inner if not is_from_source(g.originating_source, source) - } + ) """ @@ -677,7 +679,7 @@ def __init__(self) -> None: self.aotautograd_guards: list[GuardEnvExpr] = [] def copy_graphstate(self) -> GuardsCheckpointState: - return GuardsCheckpointState(set(self.dynamo_guards.inner)) + return GuardsCheckpointState(OrderedSet(self.dynamo_guards.inner)) def restore_graphstate(self, state: GuardsCheckpointState) -> None: # NB: "steals" the passed in state From 6de6685797cabc6256df76803f3a5f772d5275a7 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 5 Dec 2025 07:54:31 +0000 Subject: [PATCH 326/338] Use context managers (#169447) This PR fixes unuser context managers detected by pylint. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169447 Approved by: https://github.com/malfet --- torch/_dynamo/convert_frame.py | 14 ++-- torch/_inductor/codecache.py | 3 +- torch/_inductor/runtime/caching/interfaces.py | 2 +- torch/cuda/_memory_viz.py | 22 +++--- torch/distributed/__init__.py | 4 +- .../elastic/rendezvous/dynamic_rendezvous.py | 39 +++++----- torch/export/pt2_archive/_package.py | 12 ++-- torch/hub.py | 72 +++++++++---------- torch/testing/_internal/common_utils.py | 17 ++--- torch/testing/_internal/jit_utils.py | 32 +++++---- torch/utils/_zip.py | 23 +++--- 11 files changed, 120 insertions(+), 120 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 0da68fa5fe042..34b8fddbab876 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -500,7 +500,7 @@ def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: log.warning("Raw profile at %s", profile_path) svg_path = profile_path.with_suffix(".svg") try: - gprof2dot_process = subprocess.Popen( + with subprocess.Popen( [ "gprof2dot", "-f", @@ -511,12 +511,12 @@ def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: str(profile_path), ], stdout=subprocess.PIPE, - ) - subprocess.check_call( - ["dot", "-Tsvg", "-o", str(svg_path)], - stdin=gprof2dot_process.stdout, - ) - log.warning("Generated SVG from profile at %s", svg_path) + ) as gprof2dot_process: + subprocess.check_call( + ["dot", "-Tsvg", "-o", str(svg_path)], + stdin=gprof2dot_process.stdout, + ) + log.warning("Generated SVG from profile at %s", svg_path) except FileNotFoundError: log.warning( "Failed to generate SVG from profile -- dumping stats instead." diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 2542d5ecefd3f..e9e2eaadf55ef 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -3628,7 +3628,8 @@ def __repr__(self) -> str: def touch(filename: str) -> None: - open(filename, "a").close() + with open(filename, "a"): + pass @clear_on_fresh_cache diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 4c0972268e6f0..eb4b8251bc399 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -572,7 +572,7 @@ def _dump_imc_to_disk(self) -> Path | None: ) fpath: Path = odc._cache_dir / "imc.save" with odc.lock(): - r_fp, w_fp = None, None + w_fp = None try: w_fp = open(fpath, "x") # noqa:SIM115 except FileExistsError: diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index 5f0d868653e0e..56da01b202d62 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -109,18 +109,18 @@ def format_flamegraph(flamegraph_lines, flamegraph_script=None): # Ok to skip, the file will be removed by tempfile pass args = [flamegraph_script, "--countname", "bytes"] - p = subprocess.Popen( + with subprocess.Popen( args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding="utf-8" - ) - assert p.stdin is not None - assert p.stdout is not None - p.stdin.write(flamegraph_lines) - p.stdin.close() - result = p.stdout.read() - p.stdout.close() - p.wait() - assert p.wait() == 0 - return result + ) as p: + assert p.stdin is not None + assert p.stdout is not None + p.stdin.write(flamegraph_lines) + p.stdin.close() + result = p.stdout.read() + p.stdout.close() + p.wait() + assert p.wait() == 0 + return result def _write_blocks(f, prefix, blocks): diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 4e20a2b27e99d..095e8e9bf2654 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -76,8 +76,8 @@ class _DistributedPdb(pdb.Pdb): def interaction(self, *args, **kwargs): _stdin = sys.stdin try: - sys.stdin = open("/dev/stdin") # noqa: SIM115 - pdb.Pdb.interaction(self, *args, **kwargs) + with open("/dev/stdin") as sys.stdin: + pdb.Pdb.interaction(self, *args, **kwargs) finally: sys.stdin = _stdin diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index 35496e62ba6ac..84adeea955731 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -1318,30 +1318,27 @@ def _keep_alive_weak(weak_self) -> None: self._keep_alive() def _keep_alive(self) -> None: - self._heartbeat_lock.acquire() + with self._heartbeat_lock: + op = _RendezvousKeepAliveOp() - op = _RendezvousKeepAliveOp() + deadline = self._get_deadline(self._settings.timeout.heartbeat) - deadline = self._get_deadline(self._settings.timeout.heartbeat) - - try: - self._op_executor.run(op, deadline) + try: + self._op_executor.run(op, deadline) - msg = ( - f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " - f"'{self._settings.run_id}'." - ) - self._record(message=msg) - logger.debug(msg) - except RendezvousError as ex: - msg = ( - f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " - f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." - ) - self._record(message=msg, node_state=NodeState.FAILED) - logger.warning(msg) - finally: - self._heartbeat_lock.release() + msg = ( + f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.debug(msg) + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " + f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) def _start_heartbeats(self) -> None: self._keep_alive_timer = _PeriodicTimer( diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 302854891f199..89061dde02197 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -74,15 +74,15 @@ def is_pt2_package(serialized_model: Union[bytes, str]) -> bool: Check if the serialized model is a PT2 Archive package. """ try: - zip_reader = zipfile.ZipFile( + with zipfile.ZipFile( io.BytesIO(serialized_model) if isinstance(serialized_model, bytes) else serialized_model - ) - root_folder = zip_reader.namelist()[0].split(os.path.sep)[0] - archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}" - if archive_format_path in zip_reader.namelist(): - return zip_reader.read(archive_format_path) == b"pt2" + ) as zip_reader: + root_folder = zip_reader.namelist()[0].split(os.path.sep)[0] + archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}" + if archive_format_path in zip_reader.namelist(): + return zip_reader.read(archive_format_path) == b"pt2" except Exception: logger.info("Model is not a PT2 package") return False diff --git a/torch/hub.py b/torch/hub.py index 3ec285fcb3a9e..4344855d0060f 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -716,17 +716,6 @@ def download_url_to_file( ... ) """ - file_size = None - req = Request(url, headers={"User-Agent": "torch.hub"}) - u = urlopen(req) - meta = u.info() - if hasattr(meta, "getheaders"): - content_length = meta.getheaders("Content-Length") - else: - content_length = meta.get_all("Content-Length") - if content_length is not None and len(content_length) > 0: - file_size = int(content_length[0]) - # We deliberately save it in a temp file and move it after # download is complete. This prevents a local working checkpoint # being overridden by a broken download. @@ -742,33 +731,42 @@ def download_url_to_file( break else: raise FileExistsError(errno.EEXIST, "No usable temporary file name found") - + req = Request(url, headers={"User-Agent": "torch.hub"}) try: - if hash_prefix is not None: - sha256 = hashlib.sha256() - with tqdm( - total=file_size, - disable=not progress, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as pbar: - while True: - buffer = u.read(READ_DATA_CHUNK) - if len(buffer) == 0: - break - f.write(buffer) # type: ignore[possibly-undefined] - if hash_prefix is not None: - sha256.update(buffer) # type: ignore[possibly-undefined] - pbar.update(len(buffer)) - - f.close() - if hash_prefix is not None: - digest = sha256.hexdigest() # type: ignore[possibly-undefined] - if digest[: len(hash_prefix)] != hash_prefix: - raise RuntimeError( - f'invalid hash value (expected "{hash_prefix}", got "{digest}")' - ) + with urlopen(req) as u: + meta = u.info() + if hasattr(meta, "getheaders"): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + file_size = None + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + sha256 = hashlib.sha256() if hash_prefix is not None else None + with tqdm( + total=file_size, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + while True: + buffer = u.read(READ_DATA_CHUNK) + if len(buffer) == 0: + break + f.write(buffer) + if sha256 is not None: + sha256.update(buffer) + pbar.update(len(buffer)) + + f.close() + if sha256 is not None and hash_prefix is not None: + digest = sha256.hexdigest() + if digest[: len(hash_prefix)] != hash_prefix: + raise RuntimeError( + f'invalid hash value (expected "{hash_prefix}", got "{digest}")' + ) shutil.move(f.name, dst) finally: f.close() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index df3ca03b76242..5618947ce8ed1 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -4594,13 +4594,14 @@ def check_nondeterministic_alert(self, fn, caller_name, should_alert=True): def run_process_no_exception(code, env=None): import subprocess - popen = subprocess.Popen( - [sys.executable, '-c', code], + with subprocess.Popen( + [sys.executable, "-c", code], stdout=subprocess.PIPE, stderr=subprocess.PIPE, - env=env) - (stdout, stderr) = popen.communicate() - return (stdout, stderr) + env=env, + ) as p: + (stdout, stderr) = p.communicate() + return (stdout, stderr) # returns captured stderr @staticmethod @@ -4667,9 +4668,9 @@ def download_file(url, binary=True): if os.path.exists(path): return path try: - data = request.urlopen(url, timeout=15).read() - with open(path, 'wb' if binary else 'w') as f: - f.write(data) + with request.urlopen(url, timeout=15) as f1, open(path, 'wb' if binary else 'w') as f2: + data = f1.read() + f2.write(data) return path except error.URLError as e: msg = f"could not download test file '{url}'" diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 7647a6595ec73..4aab838e8c87b 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -197,20 +197,24 @@ def _isHookExceptionOk(self, e): def _compared_saved_loaded(self, m): def extract_files(buffer): # crack open the zip format to get at the main module code - archive = zipfile.ZipFile(buffer) - # check that we have no duplicate names - self.assertEqual(len(set(archive.namelist())), len(archive.namelist())) - files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) - # unwrap all the code files into strings - code_files_str = filter(lambda x: x.endswith('.py'), files) - code_files_stream = (archive.open(f) for f in code_files_str) - code_files = ("".join([line.decode() for line in file]) for file in code_files_stream) - - # unpickled all the debug files - debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files) - debug_files_stream = (archive.open(f) for f in debug_files_str) - debug_files = (pickle.load(f) for f in debug_files_stream) - return code_files, debug_files + with zipfile.ZipFile(buffer) as archive: + # check that we have no duplicate names + self.assertEqual(len(set(archive.namelist())), len(archive.namelist())) + files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) + # unwrap all the code files into strings + code_files_str = filter(lambda x: x.endswith('.py'), files) + code_files = [] + for f in code_files_str: + with archive.open(f) as stream: + code_files.append("".join([line.decode() for line in stream])) + + # unpickled all the debug files + debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files) + debug_files = [] + for f in debug_files_str: + with archive.open(f) as stream: + debug_files.append(pickle.load(stream)) + return code_files, debug_files # disable the hook while we parse code, otherwise we will re-enter the hook with torch._jit_internal._disable_emit_hooks(): diff --git a/torch/utils/_zip.py b/torch/utils/_zip.py index 5dd98e43c4a77..c4bfbcb0b9b63 100644 --- a/torch/utils/_zip.py +++ b/torch/utils/_zip.py @@ -69,18 +69,17 @@ def main() -> None: zip_file_name = args.install_dir + "/" + args.zip_name strip_file_dir = args.strip_dir prepend_str = args.prepend_str - zf = ZipFile(zip_file_name, mode="w") - - for p in sorted(args.paths): - if os.path.isdir(p): - files = glob.glob(p + "/**/*.py", recursive=True) - for file_path in sorted(files): - # strip the absolute path - write_to_zip( - file_path, strip_file_dir + "/", zf, prepend_str=prepend_str - ) - else: - write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str) + with ZipFile(zip_file_name, mode="w") as zf: + for p in sorted(args.paths): + if os.path.isdir(p): + files = glob.glob(p + "/**/*.py", recursive=True) + for file_path in sorted(files): + # strip the absolute path + write_to_zip( + file_path, strip_file_dir + "/", zf, prepend_str=prepend_str + ) + else: + write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str) if __name__ == "__main__": From 3158c045f195647cae47293c81c656bac53ec185 Mon Sep 17 00:00:00 2001 From: karthickai Date: Thu, 4 Dec 2025 13:19:53 -0800 Subject: [PATCH 327/338] [Inductor] Fix pattern matcher FailedMatch format string (#169611) This PR fixes FailedMatch format string bug, part of the issue: #169440 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169611 Approved by: https://github.com/ProExpertProg, https://github.com/zou3519 ghstack dependencies: #169603 --- test/inductor/test_pattern_matcher.py | 37 +++++++++++++++++++++++++++ torch/_inductor/pattern_matcher.py | 5 +++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 30e7f28c45ada..f5a2d0a0cf808 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -2004,6 +2004,43 @@ def fn(x, y): ) self.assertIn("add(arg0_1, arg1_1)", specific_record.getMessage()) + @make_logging_test() + def test_failed_match_constant_args_format_string(self, records): + def pattern(x): + return x + 1 + + def replacement(x): + return x * 2 + + my_patterns = PatternMatcherPass() + inputs = [ + torch.randn(4, 4, device=GPU_TYPE), + ] + register_replacement(pattern, replacement, inputs, fwd_only, my_patterns) + + def custom_pass(graph: torch.fx.Graph): + return my_patterns.apply(graph) + + def fn(x): + return x + 2 + + x = torch.randn(4, 4, device=GPU_TYPE) + + with unittest.mock.patch.dict( + os.environ, {"TORCHINDUCTOR_PATTERN_MATCH_DEBUG": "add"} + ): + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + result = compiled_fn(x) + self.assertEqual(result, x + 2) + + specific_record = self.getRecord(records, "Specific pattern match") + self.assertIn( + "add(arg0_1, 2) constant_args: add 2!=1 CallFunction(aten.add.Tensor, KeywordArg('x'), 1, _users=0)", + specific_record.getMessage(), + ) + if __name__ == "__main__": if IS_LINUX and HAS_GPU: diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index af071a55a23fc..6c2c98a5609d1 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -751,7 +751,10 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: m.extend(child_match) elif isinstance(child_node, torch.fx.Node) or child_node != pattern: return FailedMatch( - "constant_args: {} {!r}!={pattern!r}", node, child_node + "constant_args: {} {!r}!={pattern!r}", + node, + child_node, + pattern=pattern, ) m.nodes.append(node) m.targets[self] = node.target From 716edc39ba35aa088ff32cb95b8af9b91cb22215 Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 5 Dec 2025 08:17:37 +0000 Subject: [PATCH 328/338] Triton 3.6 update. Add CooperativeReductionTests to slow (#169630) Related to https://github.com/pytorch/pytorch/issues/169492 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169630 Approved by: https://github.com/huydhn --- test/inductor/test_cooperative_reductions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index 45a79bbbc73f0..66ca0d9d050f6 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -153,6 +153,7 @@ def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1): ) return source_code + @slowTest @parametrize( "name", [ @@ -229,6 +230,7 @@ def fn(x): ) self.assertEqual(source_code.count(f"empty_strided_{GPU_TYPE}"), 5) + @slowTest def test_reduce_split(self): def fn(a, b): a1 = torch.linalg.vector_norm(a) From d064b2289d3980912e864e3fb481b3c54f944115 Mon Sep 17 00:00:00 2001 From: Xuan Zhang Date: Thu, 4 Dec 2025 20:48:06 -0800 Subject: [PATCH 329/338] activation offloading implementation (#167880) This PR is for the **yellow block** of the design flow below. image Assuming users mark some nodes with metadata `should_offload` (which could later be replaced as compiler automatic decisions), the implementation first checks if the tensors are safe to offload via the `can_offload()` function. For tensors that are marked as `should_offload` and pass the verification in `can_offload()`, we modify the forward and backward graph, where in the forward graph, we offload tensors to CPU, and in the backward graph, we reload tensors to GPU. We introduce two flags: * `enable_activation_offloading` -- with this turned on, we insert offload/reload nodes to the fwd/bwd graphs. * `activation_offload_separate_stream` -- with this turned on, we wrap the offload/reload nodes in a separate stream with correct waits for e.g., data dependencies. (Note that from a "frontend" perspective, everything is the same, the only difference is that the memcpy is now moved to a separate stream, which prepares for subsequent reordering and overlapping) For an [example model](https://gist.github.com/xuanzhang816/ee2e3648123670f14ced9963858ee3b4), we have 1. baseline (i.e., no activation offloading) - 9160 MB peak memory - 157.42 ms per-iteration runtime - trace for fwd and bwd as below: image 2. AO (i.e., with `enable_activation_offloading=True`) - 7460 MB peak memory - 224.15 ms per-iteration runtime - trace for fwd and bwd as below (the "M" blocks are for Memcpy): image 3. AO on separate stream (i.e., additionally with `activation_offload_separate_stream = True`) - 7460 MB peak memory - 224.11 ms per-iteration runtime - trace for fwd and bwd as below (the "M" blocks are for Memcpy): image Pull Request resolved: https://github.com/pytorch/pytorch/pull/167880 Approved by: https://github.com/eellison --- test/dynamo/test_activation_offloading.py | 218 ++++++++ .../_activation_offloading/__init__.py | 5 + .../activation_offloading.py | 518 ++++++++++++++++++ torch/_functorch/config.py | 6 + torch/_functorch/partitioners.py | 13 + 5 files changed, 760 insertions(+) create mode 100644 test/dynamo/test_activation_offloading.py create mode 100644 torch/_functorch/_activation_offloading/__init__.py create mode 100644 torch/_functorch/_activation_offloading/activation_offloading.py diff --git a/test/dynamo/test_activation_offloading.py b/test/dynamo/test_activation_offloading.py new file mode 100644 index 0000000000000..5c228110998bd --- /dev/null +++ b/test/dynamo/test_activation_offloading.py @@ -0,0 +1,218 @@ +# Owner(s): ["oncall: pt2"] +# flake8: noqa: B950 + +from functools import partial + +import pytest + +import torch +import torch._functorch.config +from functorch.compile import ( + aot_function, + default_decompositions, + min_cut_rematerialization_partition, +) +from torch._dynamo.graph_bytecode_inputs import reset_user_object_tracking +from torch._inductor.utils import run_fw_bw_and_get_code +from torch.testing import FileCheck +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU + + +networkx = pytest.importorskip("networkx") + + +def extract_graph(fx_g, _, graph_cell): + graph_cell[0] = fx_g + return fx_g + + +def get_fw_bw_graph( + f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False +): + fw_graph_cell = [None] + bw_graph_cell = [None] + aot_function( + f, + fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), + bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), + partition_fn=partitioner, + decompositions=default_decompositions, + dynamic=dynamic, + )(*inps).sum().backward() + return (fw_graph_cell[0], bw_graph_cell[0]) + + +class ActivationOffloadingTests(TestCase): + """Tests activation offloading functionality""" + + def setUp(self): + super().setUp() + + def fn(x): + return (x[0] + x[1]).sin() + (x[2] + x[3]).sin() + (x[4] + x[5]).sin() + + def mark_one_cos_for_offloading(gm, joint_inputs): + for node in gm.graph.nodes: + if node.name == "cos_1": + node.meta["should_offload"] = True + return gm + + dim = 10 + self.x = [ + torch.randn(dim, dim, requires_grad=True, device=GPU_TYPE) for _ in range(6) + ] + self.fn = fn + self.joint_custom_pass = mark_one_cos_for_offloading + + """ + The first set of tests are for the case of adding offload nodes to the fwd and bwd graphs. + """ + + @torch._functorch.config.patch(enable_activation_offloading=True) + def test_partitioner_offload(self): + torch._dynamo.reset() + torch._functorch.config.joint_custom_pass = self.joint_custom_pass + fw_graph, bw_graph = get_fw_bw_graph(self.fn, [self.x]) + + self.assertExpectedInline( + fw_graph.code.strip(), + """\ +def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): + add = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None + sin = torch.ops.aten.sin.default(add) + add_1 = torch.ops.aten.add.Tensor(primals_3, primals_4); primals_3 = primals_4 = None + sin_1 = torch.ops.aten.sin.default(add_1) + add_2 = torch.ops.aten.add.Tensor(sin, sin_1); sin = sin_1 = None + add_3 = torch.ops.aten.add.Tensor(primals_5, primals_6); primals_5 = primals_6 = None + sin_2 = torch.ops.aten.sin.default(add_3) + add_4 = torch.ops.aten.add.Tensor(add_2, sin_2); add_2 = sin_2 = None + cos = torch.ops.aten.cos.default(add_3); add_3 = None + cos_1 = torch.ops.aten.cos.default(add_1); add_1 = None + cpu_offload_cos_1 = torch.ops.prims.device_put.default(cos_1, device(type='cpu'), non_blocking = True); cos_1 = None + cos_2 = torch.ops.aten.cos.default(add); add = None + return (add_4, cos, cpu_offload_cos_1, cos_2)""", + ) + + self.assertExpectedInline( + bw_graph.code.strip(), + """\ +def forward(self, cos, cpu_offload_cos_1, cos_2, tangents_1): + mul = torch.ops.aten.mul.Tensor(tangents_1, cos); cos = None + gpu_reload_cos_1 = torch.ops.prims.device_put.default(cpu_offload_cos_1, device(type='cuda', index=0), non_blocking = True); cpu_offload_cos_1 = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, gpu_reload_cos_1); gpu_reload_cos_1 = None + mul_2 = torch.ops.aten.mul.Tensor(tangents_1, cos_2); tangents_1 = cos_2 = None + return (mul_2, mul_2, mul_1, mul_1, mul, mul)""", + ) + + def test_inductor_offload(self): + torch._dynamo.reset() + + def run_compiled(): + torch._functorch.config.enable_activation_offloading = True + torch._functorch.config.joint_custom_pass = self.joint_custom_pass + return torch.compile(self.fn)(self.x) + + _, (fw_code, bw_code) = run_fw_bw_and_get_code(run_compiled) + + ( + FileCheck() + .check("buf3 = empty_strided_cpu_pinned(") + .check("buf3.copy_(buf2, True)") + .run(fw_code) + ) + + ( + FileCheck() + .check("buf1 = empty_strided_cuda(") + .check("buf1.copy_(cpu_offload_cos_1, True)") + .check("del cpu_offload_cos_1") + .run(bw_code) + ) + + @torch._functorch.config.patch( + enable_activation_offloading=True, + activation_offload_separate_stream=True, + ) + def test_partitioner_offload_sep_stream(self): + reset_user_object_tracking() + torch._dynamo.reset() + torch._functorch.config.joint_custom_pass = self.joint_custom_pass + fw_graph, bw_graph = get_fw_bw_graph(self.fn, [self.x]) + + self.assertExpectedInline( + fw_graph.code.strip(), + """\ +def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): + add = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None + sin = torch.ops.aten.sin.default(add) + add_1 = torch.ops.aten.add.Tensor(primals_3, primals_4); primals_3 = primals_4 = None + sin_1 = torch.ops.aten.sin.default(add_1) + add_2 = torch.ops.aten.add.Tensor(sin, sin_1); sin = sin_1 = None + add_3 = torch.ops.aten.add.Tensor(primals_5, primals_6); primals_5 = primals_6 = None + sin_2 = torch.ops.aten.sin.default(add_3) + add_4 = torch.ops.aten.add.Tensor(add_2, sin_2); add_2 = sin_2 = None + cos = torch.ops.aten.cos.default(add_3); add_3 = None + cos_1 = torch.ops.aten.cos.default(add_1); add_1 = None + record_event_default = torch.ops.streams.record_event.default(2, 0); record_event_default = None + stream_in_cpu_offload_cos_1 = torch.ops.streams.fork.default(0, 1); stream_in_cpu_offload_cos_1 = None + wait_event_default = torch.ops.streams.wait_event.default(2, 1); wait_event_default = None + record_stream_cos_1 = torch.ops.streams.record_stream.default(cos_1, 1); record_stream_cos_1 = None + cpu_offload_cos_1 = torch.ops.prims.device_put.default(cos_1, device(type='cpu'), non_blocking = True); cos_1 = None + record_event_default_1 = torch.ops.streams.record_event.default(3, 1); record_event_default_1 = None + stream_out_cpu_offload_cos_1 = torch.ops.streams.join.default(1, 0); stream_out_cpu_offload_cos_1 = None + wait_event_default_1 = torch.ops.streams.wait_event.default(3, 0); wait_event_default_1 = None + cos_2 = torch.ops.aten.cos.default(add); add = None + return (add_4, cos, cpu_offload_cos_1, cos_2)""", + ) + + self.assertExpectedInline( + bw_graph.code.strip(), + """\ +def forward(self, cos, cpu_offload_cos_1, cos_2, tangents_1): + mul = torch.ops.aten.mul.Tensor(tangents_1, cos); cos = None + stream_in_gpu_reload_cos_1 = torch.ops.streams.fork.default(4, 5); stream_in_gpu_reload_cos_1 = None + wait_stream_default = torch.ops.streams.wait_stream.default(5, 4); wait_stream_default = None + gpu_reload_cos_1 = torch.ops.prims.device_put.default(cpu_offload_cos_1, device(type='cuda', index=0), non_blocking = True); cpu_offload_cos_1 = None + record_event_default = torch.ops.streams.record_event.default(6, 5); record_event_default = None + stream_out_gpu_reload_cos_1 = torch.ops.streams.join.default(5, 4); stream_out_gpu_reload_cos_1 = None + wait_event_default = torch.ops.streams.wait_event.default(6, 4); wait_event_default = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, gpu_reload_cos_1); gpu_reload_cos_1 = None + mul_2 = torch.ops.aten.mul.Tensor(tangents_1, cos_2); tangents_1 = cos_2 = None + return (mul_2, mul_2, mul_1, mul_1, mul, mul)""", + ) + + @torch._functorch.config.patch( + enable_activation_offloading=True, + activation_offload_separate_stream=True, + ) + def test_partitioner_offload_sep_stream_accuracy(self): + # Run without compilation to get reference gradients + x_ref = [x.detach().clone().requires_grad_(True) for x in self.x] + out_ref = self.fn(x_ref) + out_ref.sum().backward() + grads_ref = [inp.grad for inp in x_ref] + + # Run with aot_eager compilation and offloading enabled + reset_user_object_tracking() + torch._dynamo.reset() + torch._functorch.config.joint_custom_pass = self.joint_custom_pass + x_compile = [x.detach().clone().requires_grad_(True) for x in self.x] + compiled_fn = torch.compile(self.fn, backend="aot_eager") + out_compiled = compiled_fn(x_compile) + out_compiled.sum().backward() + grads_compiled = [inp.grad for inp in x_compile] + + # Verify gradients match between reference and compiled versions + for grad_ref, grad_compiled in zip(grads_ref, grads_compiled): + torch.testing.assert_close( + grad_compiled, + grad_ref, + rtol=1e-5, + atol=1e-5, + ) + + +if __name__ == "__main__": + if HAS_GPU: + run_tests() diff --git a/torch/_functorch/_activation_offloading/__init__.py b/torch/_functorch/_activation_offloading/__init__.py new file mode 100644 index 0000000000000..10a55772ab58b --- /dev/null +++ b/torch/_functorch/_activation_offloading/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/_functorch/_activation_offloading/activation_offloading.py b/torch/_functorch/_activation_offloading/activation_offloading.py new file mode 100644 index 0000000000000..9b3feca724a69 --- /dev/null +++ b/torch/_functorch/_activation_offloading/activation_offloading.py @@ -0,0 +1,518 @@ +""" +Activation offloading for memory optimization in (more like post) partitioners. + +This module provides functionality to offload activations to CPU during forward pass +and reload them during backward pass, reducing GPU memory usage. + +Additional TODO: +* given the fact that PT2 stream support is in active development, testings should + be done once that is more finalized. A issue currently known is that with streams, + each iteration will have its own offload streams, but the streams should be shared + across the iterations. +""" + +import logging +import operator + +import torch +import torch.fx as fx +from torch._dynamo.variables.streams import get_current_stream, new_event, new_stream +from torch._subclasses.fake_tensor import extract_tensor_metadata +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..partitioners import get_default_op_list, OpTypes + + +log: logging.Logger = logging.getLogger(__name__) + + +# Node name prefixes for offload/reload operations +# NOTE: right now we are using these prefixes as identifiers for offload/reload +CPU_OFFLOAD_PREFIX = "cpu_offload_" +GPU_RELOAD_PREFIX = "gpu_reload_" + + +def offload_activation_fw(graph: fx.Graph) -> None: + """ + Insert CPU offload operations in the forward pass graph. + + Offload operations are placed after the last effective use of each tensor marked + for offloading. This ensures the tensor is no longer needed on the GPU before + transferring it to CPU memory. + + NOTE: An alternative approach would offload tensors immediately after generation + to maximize compute-communication overlap. However, this requires additional + synchronization to ensure tensor deletion (which occurs on the default stream) + waits for the asynchronous offload operation to complete. This would necessitate + more complex tracking to separate operation scheduling from memory cleanup. + + Args: + graph: The forward graph to modify + """ + + op_types: OpTypes = get_default_op_list() + + def find_all_effective_users(node: fx.Node) -> OrderedSet[fx.Node]: + """ + Find all effective users of a node, where view ops extend the lifetime of the + original node. If a user is a view op, recursively find users of the view. + """ + effective_users: OrderedSet[fx.Node] = OrderedSet() + for user in node.users: + if user.op == "output": + continue + effective_users.add(user) + if op_types.is_view(user): + effective_users.update(find_all_effective_users(user)) + + return effective_users + + output_node: fx.Node = graph.find_nodes(op="output")[0] + fwd_outputs: tuple[fx.Node] = output_node.args[ + 0 + ] # pyrefly: ignore [bad-assignment] + node_to_offload: dict[fx.Node, fx.Node] = dict() + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + + for node in fwd_outputs: + if node.meta.get("saved_for_offloading", False) is False: + continue + + # Find insertion point, which is the last use + all_effective_users: OrderedSet[fx.Node] = find_all_effective_users(node) + if all_effective_users := find_all_effective_users(node): + last_user = max(all_effective_users, key=lambda n: node_to_index[n]) + else: + last_user: fx.Node = node + + # Insert the CPU offload operation after the last user + with graph.inserting_after(last_user): + cpu_node: fx.Node = graph.call_function( + torch.ops.prims.device_put.default, + args=(node, torch.device("cpu")), + kwargs={"non_blocking": True}, + name=CPU_OFFLOAD_PREFIX + str(node.name), + ) + cpu_node.meta["val"] = node.meta["val"].to(torch.device("cpu")) + cpu_node.meta["tensor_meta"] = extract_tensor_metadata(cpu_node.meta["val"]) + + node_to_offload[node] = cpu_node + + # Update the return node args + output_node.update_arg( + 0, tuple(node_to_offload.get(node, node) for node in fwd_outputs) + ) + + +def reload_activation_bw(graph: fx.Graph) -> None: + """ + Insert GPU reload operations in the backward pass graph. + + Reload operations are placed before the first use of each offloaded tensor, + transferring it from CPU back to GPU memory before it's needed for computation. + + Args: + graph: The backward graph to modify + """ + + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + output_node: fx.Node = graph.find_nodes(op="output")[0] + + for node in graph.find_nodes(op="placeholder"): + if node.meta.get("saved_for_offloading", False) is False: + continue + + # Find insertion point, which is the first use or output node if no users + # The later should not happen, but inserting before output node is safe + insert_point: fx.Node = ( + min(node.users.keys(), key=lambda n: node_to_index[n]) + if node.users + else output_node + ) + + # Insert the GPU reload operation before the first user + original_device: torch.Device = node.meta["original_device"] + with graph.inserting_before(insert_point): + gpu_node: fx.Node = graph.call_function( + torch.ops.prims.device_put.default, + args=(node, original_device), + kwargs={"non_blocking": True}, + name=str(node.name).replace(CPU_OFFLOAD_PREFIX, GPU_RELOAD_PREFIX), + ) + gpu_node.meta["val"] = node.meta["val"].to(original_device) + gpu_node.meta["tensor_meta"] = extract_tensor_metadata(gpu_node.meta["val"]) + + # Replace all uses of the CPU tensor with the GPU tensor + for user in list(node.users.keys()): + if user != gpu_node: + user.replace_input_with(node, gpu_node) + + +def can_offload( + node: fx.Node, + fwd_outputs: OrderedSet[fx.Node], + model_outputs: OrderedSet[fx.Node], + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> bool: + """ + Determine if a node can be offloaded to CPU. + + Args: + node: The node to check + fwd_outputs: Forward module outputs, including model outputs and activations + model_outputs: Model outputs + + NOTE: Additional context for the logic behind these offloading checks: + + * fwd_outputs: Only saved intermediate tensors should be offloaded. + + * model_outputs / static_lifetime_input_nodes: Tensors that may be accessed outside + the compiled region (e.g., model outputs, static inputs) cannot be offloaded as + they must remain accessible beyond the scope of the compiled graph. + + * views / getitems: Offloading such nodes can lead to segmentation faults. + + * contiguous: Offloading non-contiguous tensors causes CPU-side stride changes + during both forward and backward passes when using the Inductor backend. While + these stride changes cancel each other out, they introduce significant compute + overhead. This is due to the contiguity check in ir.py (see link below). + TODO: This restriction could potentially be bypassed in the future. + Reference: https://github.com/pytorch/pytorch/blob/44ac69388a4a5eb463dbd2a13f00d1e3b924566c/torch/_inductor/ir.py#L3214 + + Additional criteria to consider for offloading optimization: + + * Tensor size: Small tensors may not fully utilize available bandwidth, reducing the + efficiency gains from offloading. + + * Position in forward/backward graph: Activations generated near the end of the forward + pass are typically consumed near the beginning of the backward pass. Offloading such + tensors may be counterproductive since they are quickly reloaded, not having sufficient + time to overlap the transfer with computation. + """ + + log.debug(f"Checking node {node.name} for offloading...") # noqa: G004 + + op_types: OpTypes = get_default_op_list() + + if node not in fwd_outputs: + log.debug("\tSkipped! Can only offload nodes in fwd_module_outputs.") + return False + if node in model_outputs: + log.debug("\tSkipped! Cannot offload model outputs.") + return False + if node in static_lifetime_input_nodes: + log.debug("\tSkipped! Cannot offload static input nodes.") + return False + if op_types.is_view(node): + log.debug("\tSkipped! Cannot offload views.") + return False + if node.target == operator.getitem: + log.debug("\tSkipped! Cannot offload getitems.") + return False + if hasattr(node, "meta") and "val" in node.meta: + if ( + isinstance(val := node.meta["val"], torch.Tensor) + and not val.is_contiguous() + ): + log.debug("\tSkipped! Cannot offload non-contiguous tensors.") + return False + + log.debug("\tGood!") + return True + + +def choose_offload_sets( + fwd_module: fx.GraphModule, + num_fwd_outputs: int, + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> bool: + """ + Decide which nodes will be offloaded based on the marked nodes and feasibility. + Marks nodes with "saved_for_offloading" if they should and can be offloaded. + + Args: + fwd_module: Forward graph module + bwd_module: Backward graph module + num_fwd_outputs: Number of forward outputs + + Returns: + bool: Whether activation offloading should be performed + """ + + fwd_outputs: OrderedSet[fx.Node] = OrderedSet( + fwd_module.graph.find_nodes(op="output")[0].args[0] + ) + model_outputs: OrderedSet[fx.Node] = OrderedSet( + fwd_module.graph.find_nodes(op="output")[0].args[0][:num_fwd_outputs] + ) + + should_perform_offloading = False + for node in fwd_module.graph.nodes: + if node.meta.get("should_offload", False) and can_offload( + node, fwd_outputs, model_outputs, static_lifetime_input_nodes + ): + node.meta["saved_for_offloading"] = True + node.meta["original_device"] = node.meta["val"].device + should_perform_offloading = True + + return should_perform_offloading + + +def offload_chosen_sets( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, +) -> None: + """ + Add offload and reload nodes to the forward and backward graphs. + This function adds device_put operations without any stream handling. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + """ + + # Add offload nodes in forward graph + offload_activation_fw(fwd_module.graph) + + # Update backward graph inputs to be offloaded tensors + bwd_inputs: dict[str, fx.Node] = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + for fwd_node in fwd_module.graph.find_nodes(op="output")[0].args[0]: + if CPU_OFFLOAD_PREFIX not in fwd_node.name: + continue + + bwd_node: fx.Node = bwd_inputs[fwd_node.name.replace(CPU_OFFLOAD_PREFIX, "")] + with bwd_module.graph.inserting_after(bwd_node): + bwd_offload_node: fx.Node = bwd_module.graph.placeholder(name=fwd_node.name) + + bwd_offload_node.meta.update(fwd_node.meta) + bwd_offload_node.meta["saved_for_offloading"] = True + bwd_offload_node.meta["original_device"] = bwd_node.meta["val"].device + bwd_node.replace_all_uses_with(bwd_offload_node) + bwd_module.graph.erase_node(bwd_node) + + # Add reload nodes in backward graph + reload_activation_bw(bwd_module.graph) + + +def add_forward_offload_stream_ops(graph: fx.Graph) -> None: + """ + Add stream operations for forward pass CPU offloading. + + Pattern: record_event → fork → wait_event → record_stream → device_put → record_event_2 → join → wait_event_2 + + This ensures that: + 1. Offloading waits for the last use to complete (record_event on default stream) + 2. Offloading happens on a separate stream (fork → wait_event → device_put) + 3. The tensor is marked as used in the offload stream (record_stream) + 4. Execution returns to the default stream after offloading and + waits for offload to complete (record_event_2 → join → wait_event_2) + + NOTE: For stream optimization and overlapping compute with communication, + the "wait_event_2" ops can be sinked to the end of the graph. + + Args: + graph: The forward graph to modify + """ + + # Find all CPU offload nodes + offload_nodes: list[fx.Node] = [ + node + for node in graph.nodes + if CPU_OFFLOAD_PREFIX in node.name and node.op == "call_function" + ] + if not offload_nodes: + return + + # Get default stream id and offload stream id + current_stream_id: int = get_current_stream( + offload_nodes[0].args[0].meta["val"].device # type: ignore[assignment] + ) + offload_stream_id: int = new_stream() + + for offload_node in offload_nodes: + offload_ready_event_id: int = new_event() + offload_completion_event_id: int = new_event() + + # Get the tensor being offloaded + tensor_node: fx.Node = offload_node.args[0] # type: ignore[assignment] + + with graph.inserting_before(offload_node): + # Record event on default stream to ensure last use completes + graph.call_function( + torch.ops.streams.record_event.default, + args=(offload_ready_event_id, current_stream_id), + ) + # Fork to offload stream + graph.call_function( + torch.ops.streams.fork.default, + args=(current_stream_id, offload_stream_id), + name=f"stream_in_{offload_node.name}", + ) + # Wait for the event on offload stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(offload_ready_event_id, offload_stream_id), + ) + # Inform the CUDA Caching Allocator that this tensor will be accessed in the + # offload stream. Without this, the program may prematurely free its memory + # even though the async offload operation is still in progress, and this can + # lead to memory corruption, especially with reordering for compute and + # communication overlaps. + graph.call_function( + torch.ops.streams.record_stream.default, + args=(tensor_node, offload_stream_id), + name=f"record_stream_{tensor_node.name}", + ) + with graph.inserting_after(offload_node): + # Record event on offload stream after device_put completes + record_event_node = graph.call_function( + torch.ops.streams.record_event.default, + args=(offload_completion_event_id, offload_stream_id), + ) + with graph.inserting_after(record_event_node): + # Join back to default stream + join_node = graph.call_function( + torch.ops.streams.join.default, + args=(offload_stream_id, current_stream_id), + name=f"stream_out_{offload_node.name}", + ) + with graph.inserting_after(join_node): + # Wait for the offload to complete on default stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(offload_completion_event_id, current_stream_id), + ) + + +def add_backward_reload_stream_ops(graph: fx.Graph) -> None: + """ + Add stream operations for backward pass GPU reloading. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This ensures that: + 1. Reloading doesn't start prematurely (fork → wait_stream) + 2. Reloading happens on a separate stream (device_put) + 3. First use waits for reload completion (record_event → join → wait_event) + + NOTE: The pattern consists of two logical groups: + - First group (fork → wait_stream → device_put → record_event → join): + Performs asynchronous data transfer on a separate stream + - Second group (wait_event): + Data transfer completion check when the data is actually needed + + For prefetch optimization, the first group can be moved earlier in the graph + to overlap computation with data transfer, while the wait_event must remain + at its current position to prevent blocking computation unnecessarily. + + Args: + graph: The backward graph to modify + """ + + # Find all GPU reload nodes + reload_nodes: list[fx.Node] = [ + node + for node in graph.nodes + if GPU_RELOAD_PREFIX in node.name and node.op == "call_function" + ] + if not reload_nodes: + return + + # Get default stream id and offload stream id + current_stream_id: int = get_current_stream( + reload_nodes[0].args[0].meta["original_device"] # type: ignore[assignment] + ) + reload_stream_id: int = new_stream() + + for reload_node in reload_nodes: + event_id: int = new_event() + + with graph.inserting_before(reload_node): + # Fork to reload stream + graph.call_function( + torch.ops.streams.fork.default, + args=(current_stream_id, reload_stream_id), + name=f"stream_in_{reload_node.name}", + ) + # Wait for default stream to prevent premature reloading + graph.call_function( + torch.ops.streams.wait_stream.default, + args=(reload_stream_id, current_stream_id), + ) + with graph.inserting_after(reload_node): + # Record event on reload stream after device_put + record_event_node = graph.call_function( + torch.ops.streams.record_event.default, + args=(event_id, reload_stream_id), + ) + with graph.inserting_after(record_event_node): + # Join back to default stream + join_node = graph.call_function( + torch.ops.streams.join.default, + args=(reload_stream_id, current_stream_id), + name=f"stream_out_{reload_node.name}", + ) + with graph.inserting_after(join_node): + # Wait for the event on default stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(event_id, current_stream_id), + ) + + +def put_offload_nodes_on_separate_stream( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, +) -> None: + """ + Add stream and event related operations around offload nodes. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + """ + + add_forward_offload_stream_ops(fwd_module.graph) + add_backward_reload_stream_ops(bwd_module.graph) + + +def enable_activation_offloading( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, + num_fwd_outputs: int, + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> None: + """ + Main entry point for activation offloading. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + num_fwd_outputs: Number of forward outputs + """ + + # Step 1: Decide which nodes to offload and mark them + should_perform_offloading: bool = choose_offload_sets( + fwd_module, + num_fwd_outputs, + static_lifetime_input_nodes, + ) + if not should_perform_offloading: + return + + # Step 2: Add offload and reload nodes to the graphs + offload_chosen_sets(fwd_module, bwd_module) + + # Step 3: Put offload nodes on separate stream if configured + if config.activation_offload_separate_stream: + put_offload_nodes_on_separate_stream(fwd_module, bwd_module) + + fwd_module.graph.lint() + bwd_module.graph.lint() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 49a069a096f58..06873aa78f983 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -192,6 +192,12 @@ def remote_autograd_cache_default() -> Optional[bool]: # cost of some performance aggressive_recomputation = False +# activation offloading enablement (testing purpose) +enable_activation_offloading = False + +# activation offloading with separate CUDA stream +activation_offload_separate_stream = False + # If FakeTensor.data_ptr() should error. # This option is independent of AOTAutograd and torch.compile, but our policy # is to turn it off during torch.compile. diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 3e2abf2b5650f..be67e82bf46ff 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -3026,6 +3026,19 @@ def min_cut_rematerialization_partition( ) bw_module = reordering_to_mimic_autograd_engine(bw_module) + # pyrefly: ignore [unbound-name] + if config.enable_activation_offloading: + from ._activation_offloading.activation_offloading import ( + enable_activation_offloading, + ) + + enable_activation_offloading( + fw_module, + bw_module, + num_fwd_outputs, + node_info.static_lifetime_input_nodes, + ) + # raise all getitem ops to as early as possible # this is helpful for memory, especially in the case of aot_eager backend fw_module = raise_getitems(fw_module) From eca37f0b16bb92729699d4d93354de65c9bd2242 Mon Sep 17 00:00:00 2001 From: Xuan Zhang Date: Thu, 4 Dec 2025 20:48:06 -0800 Subject: [PATCH 330/338] activation offloading reordering for comp<>comm overlaps (#168316) We introduce two flags for computation <> communication overlap: * `activation_offload_sink_wait` -- with this turned on, the wait events in forward is sinked to the end of the graph so that the offload ops is not blocking * `activation_reload_prefetch` -- with this turned on, the reload in the backward is prefetched just enough so that the reload operation is perfectly overlapped with computation. Continue with the [example model](https://gist.github.com/xuanzhang816/ee2e3648123670f14ced9963858ee3b4) in the first PR in the stack: 4. AO on separate stream with reorders (i.e., additionally with `activation_offload_sink_wait = True` and `activation_reload_prefetch = True`) - 7534 MB peak memory - 160.86 ms ms per-iteration runtime - trace for fwd and bwd as below (the "M" blocks are for Memcpy): image Pull Request resolved: https://github.com/pytorch/pytorch/pull/168316 Approved by: https://github.com/eellison ghstack dependencies: #167880 --- test/dynamo/test_activation_offloading.py | 95 +++++- .../activation_offloading.py | 308 +++++++++++++++++- torch/_functorch/config.py | 6 + torch/_inductor/config.py | 4 + 4 files changed, 411 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_activation_offloading.py b/test/dynamo/test_activation_offloading.py index 5c228110998bd..3970a5e0c111e 100644 --- a/test/dynamo/test_activation_offloading.py +++ b/test/dynamo/test_activation_offloading.py @@ -15,7 +15,7 @@ from torch._dynamo.graph_bytecode_inputs import reset_user_object_tracking from torch._inductor.utils import run_fw_bw_and_get_code from torch.testing import FileCheck -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, serialTest, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -212,6 +212,99 @@ def test_partitioner_offload_sep_stream_accuracy(self): atol=1e-5, ) + @torch._functorch.config.patch( + enable_activation_offloading=True, + activation_offload_separate_stream=True, + activation_offload_sink_wait=True, + activation_reload_prefetch=True, + ) + def test_partitioner_offload_sep_stream_reorder(self): + reset_user_object_tracking() + torch._dynamo.reset() + torch._functorch.config.joint_custom_pass = self.joint_custom_pass + fw_graph, bw_graph = get_fw_bw_graph(self.fn, [self.x]) + + self.assertExpectedInline( + fw_graph.code.strip(), + """\ +def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): + add = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None + sin = torch.ops.aten.sin.default(add) + add_1 = torch.ops.aten.add.Tensor(primals_3, primals_4); primals_3 = primals_4 = None + sin_1 = torch.ops.aten.sin.default(add_1) + add_2 = torch.ops.aten.add.Tensor(sin, sin_1); sin = sin_1 = None + add_3 = torch.ops.aten.add.Tensor(primals_5, primals_6); primals_5 = primals_6 = None + sin_2 = torch.ops.aten.sin.default(add_3) + add_4 = torch.ops.aten.add.Tensor(add_2, sin_2); add_2 = sin_2 = None + cos = torch.ops.aten.cos.default(add_3); add_3 = None + cos_1 = torch.ops.aten.cos.default(add_1); add_1 = None + record_event_default = torch.ops.streams.record_event.default(2, 0); record_event_default = None + stream_in_cpu_offload_cos_1 = torch.ops.streams.fork.default(0, 1); stream_in_cpu_offload_cos_1 = None + wait_event_default = torch.ops.streams.wait_event.default(2, 1); wait_event_default = None + record_stream_cos_1 = torch.ops.streams.record_stream.default(cos_1, 1); record_stream_cos_1 = None + cpu_offload_cos_1 = torch.ops.prims.device_put.default(cos_1, device(type='cpu'), non_blocking = True); cos_1 = None + record_event_default_1 = torch.ops.streams.record_event.default(3, 1); record_event_default_1 = None + stream_out_cpu_offload_cos_1 = torch.ops.streams.join.default(1, 0); stream_out_cpu_offload_cos_1 = None + cos_2 = torch.ops.aten.cos.default(add); add = None + wait_event_default_1 = torch.ops.streams.wait_event.default(3, 0); wait_event_default_1 = None + return (add_4, cos, cpu_offload_cos_1, cos_2)""", + ) + + self.assertExpectedInline( + bw_graph.code.strip(), + """\ +def forward(self, cos, cpu_offload_cos_1, cos_2, tangents_1): + stream_in_gpu_reload_cos_1 = torch.ops.streams.fork.default(4, 5); stream_in_gpu_reload_cos_1 = None + wait_stream_default = torch.ops.streams.wait_stream.default(5, 4); wait_stream_default = None + gpu_reload_cos_1 = torch.ops.prims.device_put.default(cpu_offload_cos_1, device(type='cuda', index=0), non_blocking = True); cpu_offload_cos_1 = None + record_event_default = torch.ops.streams.record_event.default(6, 5); record_event_default = None + stream_out_gpu_reload_cos_1 = torch.ops.streams.join.default(5, 4); stream_out_gpu_reload_cos_1 = None + mul = torch.ops.aten.mul.Tensor(tangents_1, cos); cos = None + wait_event_default = torch.ops.streams.wait_event.default(6, 4); wait_event_default = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, gpu_reload_cos_1); gpu_reload_cos_1 = None + mul_2 = torch.ops.aten.mul.Tensor(tangents_1, cos_2); tangents_1 = cos_2 = None + return (mul_2, mul_2, mul_1, mul_1, mul, mul)""", + ) + + @torch._functorch.config.patch( + enable_activation_offloading=True, + activation_offload_separate_stream=True, + activation_offload_sink_wait=True, + activation_reload_prefetch=True, + ) + @serialTest() + def test_partitioner_offload_sep_stream_reorder_accuracy(self): + # need larger dimension so that memcpy takes longer, and the code is at the risk of + # premature memory deallocation + dim = 1024 * 8 + x_larger = [ + torch.randn(dim, dim, requires_grad=True, device=GPU_TYPE) for _ in range(6) + ] + # Run without compilation to get reference gradients + x_ref = [x.detach().clone().requires_grad_(True) for x in x_larger] + out_ref = self.fn(x_ref) + out_ref.sum().backward() + grads_ref = [inp.grad for inp in x_ref] + + # Run with aot_eager compilation and offloading enabled + reset_user_object_tracking() + torch._dynamo.reset() + torch._functorch.config.joint_custom_pass = self.joint_custom_pass + x_compile = [x.detach().clone().requires_grad_(True) for x in x_larger] + compiled_fn = torch.compile(self.fn, backend="aot_eager") + out_compiled = compiled_fn(x_compile) + out_compiled.sum().backward() + grads_compiled = [inp.grad for inp in x_compile] + + # Verify gradients match between reference and compiled versions + for grad_ref, grad_compiled in zip(grads_ref, grads_compiled): + torch.testing.assert_close( + grad_compiled, + grad_ref, + rtol=1e-5, + atol=1e-5, + ) + if __name__ == "__main__": if HAS_GPU: diff --git a/torch/_functorch/_activation_offloading/activation_offloading.py b/torch/_functorch/_activation_offloading/activation_offloading.py index 9b3feca724a69..7b1b05af49ef9 100644 --- a/torch/_functorch/_activation_offloading/activation_offloading.py +++ b/torch/_functorch/_activation_offloading/activation_offloading.py @@ -13,15 +13,18 @@ import logging import operator +from dataclasses import dataclass import torch import torch.fx as fx from torch._dynamo.variables.streams import get_current_stream, new_event, new_stream +from torch._inductor import config as inductor_config +from torch._inductor.fx_passes.overlap_scheduling import benchmark_node, is_compute_node from torch._subclasses.fake_tensor import extract_tensor_metadata from torch.utils._ordered_set import OrderedSet from .. import config -from ..partitioners import get_default_op_list, OpTypes +from ..partitioners import _size_of, get_default_op_list, OpTypes log: logging.Logger = logging.getLogger(__name__) @@ -33,6 +36,42 @@ GPU_RELOAD_PREFIX = "gpu_reload_" +@dataclass +class ReloadNodeInfo: + """ + Information about backward reload related nodes for each reload operation. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This pattern is divided into two logical groups for optimization purposes: + - Reload group (fork → wait_stream → device_put → record_event → join): + Performs the actual asynchronous data transfer on a separate stream. + These nodes can be moved earlier in the graph to overlap with computation. + - Wait group (wait_event): + Synchronization point that blocks until the data transfer completes. + This must remain at the point where the reloaded data is first needed. + """ + + reload_group_nodes: list[fx.Node] + wait_event_node: fx.Node + transfer_size_bytes: int + transfer_time_ms: float + + +@dataclass +class ReloadQueueEntry: + """ + Entry in the reload queue for prefetch scheduling. + + Attributes: + pattern: The reload pattern information + remaining_time_ms: Remaining overlap time needed in milliseconds + """ + + pattern: ReloadNodeInfo + remaining_time_ms: float + + def offload_activation_fw(graph: fx.Graph) -> None: """ Insert CPU offload operations in the forward pass graph. @@ -483,6 +522,269 @@ def put_offload_nodes_on_separate_stream( add_backward_reload_stream_ops(bwd_module.graph) +def _validate_pattern_nodes( + fork_node: fx.Node, + wait_stream_node: fx.Node, + record_event_node: fx.Node, + join_node: fx.Node, + wait_event_node: fx.Node, +) -> None: + """ + Validate that the pattern nodes match the expected structure. + + Raises ValueError if any node doesn't match expectations. + """ + + if not ( + fork_node.op == "call_function" + and fork_node.target == torch.ops.streams.fork.default + ): + raise ValueError("Expected fork node two nodes before device_put node") + + if not ( + wait_stream_node.op == "call_function" + and wait_stream_node.target == torch.ops.streams.wait_stream.default + ): + raise ValueError("Expected wait_stream node one node before device_put node") + + if not ( + record_event_node.op == "call_function" + and record_event_node.target == torch.ops.streams.record_event.default + ): + raise ValueError("Expected record_event node one node after device_put node") + + if not ( + join_node.op == "call_function" + and join_node.target == torch.ops.streams.join.default + ): + raise ValueError("Expected join node two nodes after device_put node") + + if not ( + wait_event_node.op == "call_function" + and wait_event_node.target == torch.ops.streams.wait_event.default + ): + raise ValueError("Expected wait_event node three nodes after device_put node") + + +def _calculate_transfer_size(device_put_node: fx.Node) -> int: + """Calculate the size in bytes of data being transferred.""" + + return _size_of(device_put_node.args[0]) # pyrefly: ignore [bad-argument-type] + + +def _estimate_transfer_time_in_ms(transfer_size_bytes: int) -> float: + """ + Estimate transfer time in milliseconds based on size and bandwidth. + NOTE: potentially could be standardized in node estimator class + """ + + return transfer_size_bytes / (1024**3) * 1_000 / inductor_config.cpu_gpu_bw + + +def identify_reload_patterns( + graph: fx.Graph, nodes_list: list[fx.Node], node_to_idx: dict[fx.Node, int] +) -> dict[fx.Node, ReloadNodeInfo]: + """ + Identify backward reload patterns in the graph. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This uses position-based matching since these nodes are inserted together in + add_backward_reload_stream_ops() in a specific order. Since stream operations + do not have data dependencies between them, they are unsuitable for subgroup + pattern matching type of checks. + + Returns a dict mapping device_put node to ReloadNodeInfo containing: + - reload_group_nodes: fork → wait_stream → device_put → record_event → join + - wait_event_node: the wait_event node + - transfer_size_bytes: size of data being transferred + - transfer_time_ms: estimated transfer time in milliseconds + """ + patterns: dict[fx.Node, ReloadNodeInfo] = {} + + # Find all GPU reload device_put nodes whose inputs are placeholder nodes + reload_nodes: list[fx.Node] = [ + node + for node in graph.find_nodes( + op="call_function", target=torch.ops.prims.device_put.default + ) + if GPU_RELOAD_PREFIX in node.name + and ( + node.args + and isinstance(node.args[0], fx.Node) + and node.args[0].op == "placeholder" + ) + ] + + # Extract patterns for each reload device_put node + for reload_node in reload_nodes: + reload_node_idx: int = node_to_idx[reload_node] + + fork_node: fx.Node = nodes_list[reload_node_idx - 2] + wait_stream_node: fx.Node = nodes_list[reload_node_idx - 1] + record_event_node: fx.Node = nodes_list[reload_node_idx + 1] + join_node: fx.Node = nodes_list[reload_node_idx + 2] + wait_event_node: fx.Node = nodes_list[reload_node_idx + 3] + + # Validate the nodes are what we expect + _validate_pattern_nodes( + fork_node, + wait_stream_node, + record_event_node, + join_node, + wait_event_node, + ) + + # Calculate transfer size and time + transfer_size_bytes: int = _calculate_transfer_size(reload_node) + transfer_time_ms: float = _estimate_transfer_time_in_ms(transfer_size_bytes) + + patterns[reload_node] = ReloadNodeInfo( + reload_group_nodes=[ + fork_node, + wait_stream_node, + reload_node, + record_event_node, + join_node, + ], + wait_event_node=wait_event_node, + transfer_size_bytes=transfer_size_bytes, + transfer_time_ms=transfer_time_ms, + ) + + return patterns + + +def reorder_for_prefetch( + nodes_list: list[fx.Node], + reload_patterns: dict[fx.Node, ReloadNodeInfo], +) -> None: + """ + Reorder nodes to prefetch reload operations by directly manipulating the graph. + + This follows the algorithm as follows: + - Go through nodes in reverse order + - When encountering a reload pattern, add it to a queue with its transfer time + - When encountering a compute node, use its runtime to satisfy overlap requirements + - Place reload patterns when their overlap requirement is satisfied + - When encountering placeholder nodes, flush queue as reloads cannot move before inputs + """ + + # Build a set of all nodes in reload groups for quick lookup + reload_group_nodes_set: set[fx.Node] = set() + for pattern in reload_patterns.values(): + reload_group_nodes_set.update(pattern.reload_group_nodes) + + # Queue to hold reload group nodes waiting to be placed (FIFO) + reload_queue: list[ReloadQueueEntry] = [] + + # Loop through nodes in reverse + for node in reversed(nodes_list): + if node.op == "output": + continue + elif node.op == "placeholder": + # Flush queue - place all remaining reloads after the last placeholder + while reload_queue: + entry: ReloadQueueEntry = reload_queue.pop(0) + for reload_group_node in reversed(entry.pattern.reload_group_nodes): + node.append(reload_group_node) + break + elif node in reload_patterns: + pattern: ReloadNodeInfo = reload_patterns[node] + reload_queue.append( + ReloadQueueEntry( + pattern=pattern, remaining_time_ms=pattern.transfer_time_ms + ) + ) + elif node in reload_group_nodes_set: + continue + else: + if not reload_queue: + continue + compute_runtime_ms: float = ( + benchmark_node(node) if is_compute_node(node) else 0 + ) + reload_queue[0].remaining_time_ms -= compute_runtime_ms + + # Pop and place reload if its remaining time is satisfied (<= 0) + if reload_queue[0].remaining_time_ms <= 0: + entry: ReloadQueueEntry = reload_queue.pop(0) + for reload_group_node in entry.pattern.reload_group_nodes: + node.prepend(reload_group_node) + + +def activation_offload_sink_wait(fwd_module: fx.GraphModule) -> None: + """ + Sink wait_event operations for offload completion to the end of the graph. + + This function identifies wait_event nodes for offload completion and moves them + to the end of the graph, allowing computation to overlap with offload operations. + + Args: + fwd_module: Forward module graph + """ + graph: fx.Graph = fwd_module.graph + nodes_list: list[fx.Node] = list(graph.nodes) + node_to_idx: dict[fx.Node, int] = {node: idx for idx, node in enumerate(nodes_list)} + + # Find all CPU offload device_put nodes + offload_nodes: list[fx.Node] = [ + node + for node in graph.find_nodes( + op="call_function", target=torch.ops.prims.device_put.default + ) + if CPU_OFFLOAD_PREFIX in node.name + ] + + # Collect all wait_event nodes that need to be moved + wait_nodes_to_sink: list[fx.Node] = [] + for offload_node in offload_nodes: + offload_idx: int = node_to_idx[offload_node] + wait_event_node: fx.Node = nodes_list[offload_idx + 3] + + # Validate it's actually a wait_event node + if not ( + wait_event_node.op == "call_function" + and wait_event_node.target == torch.ops.streams.wait_event.default + ): + raise ValueError( + f"Expected wait_event node three positions after {offload_node.name}" + ) + + wait_nodes_to_sink.append(wait_event_node) + + # Find the output node, and move all wait_event nodes to just before the output node + output_node: fx.Node = graph.find_nodes(op="output")[0] + for wait_node in wait_nodes_to_sink: + output_node.prepend(wait_node) + + +def activation_reload_prefetch(bwd_module: fx.GraphModule) -> None: + """ + Prefetch backward reload operations by moving them earlier in the graph + to overlap communication with computation. + + This function identifies backward reload patterns (fork → wait_stream → device_put → + record_event → join) and moves them earlier in the execution order to overlap + the data transfer with computation, while keeping the wait_event at its original + position. + + Args: + bwd_module: Backward module graph + """ + graph: fx.Graph = bwd_module.graph + nodes_list: list[fx.Node] = list(graph.nodes) + node_to_idx: dict[fx.Node, int] = {node: idx for idx, node in enumerate(nodes_list)} + + # Step 1: Identify reload patterns + reload_patterns: dict[fx.Node, ReloadNodeInfo] = identify_reload_patterns( + graph, nodes_list, node_to_idx + ) + + # Step 2: Reorder nodes by directly manipulating the graph + reorder_for_prefetch(nodes_list, reload_patterns) + + def enable_activation_offloading( fwd_module: fx.GraphModule, bwd_module: fx.GraphModule, @@ -513,6 +815,10 @@ def enable_activation_offloading( # Step 3: Put offload nodes on separate stream if configured if config.activation_offload_separate_stream: put_offload_nodes_on_separate_stream(fwd_module, bwd_module) + if config.activation_offload_sink_wait: + activation_offload_sink_wait(fwd_module) + if config.activation_reload_prefetch: + activation_reload_prefetch(bwd_module) fwd_module.graph.lint() bwd_module.graph.lint() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 06873aa78f983..759db7f91dd6f 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -198,6 +198,12 @@ def remote_autograd_cache_default() -> Optional[bool]: # activation offloading with separate CUDA stream activation_offload_separate_stream = False +# activation offloading wait sinking when using separate stream (fwd graph) +activation_offload_sink_wait = False + +# activation reloading with prefetching when using separate streams (bwd graph) +activation_reload_prefetch = False + # If FakeTensor.data_ptr() should error. # This option is independent of AOTAutograd and torch.compile, but our policy # is to turn it off during torch.compile. diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 4297880dbdbcf..4ff678d820091 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -444,6 +444,10 @@ def prologue_fusion_enabled() -> bool: # default value is InfiniBand inter_node_bw = 25 +# unit: GB/s, uni-directional CPU<>GPU bandwidth +# default value is PCIe; modify for your hardware or measured bandwidth +cpu_gpu_bw = 50.0 + # use Inductor's experimental benchmarker (runtime/benchmarking.py) # to benchmark kernels during autotuning, otherwise fall back to # Triton's `do_bench`. the experimental benchmarker may produce From a1134835b3b6335f2b456d1600f75edb1bcee603 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Thu, 4 Dec 2025 21:43:00 -0800 Subject: [PATCH 331/338] [precompile] disable dispatch when deepcloning in PrecompileContext.record_artifact (#169242) I ran into this when trying to precompile simplefsdp and realizing that deepcloning a DeviceMesh within a fake mode causes the following error: ``` [rank0]: File "/home/bobren/local/a/pytorch/spc.py", line 111, in [rank0]: .aot_compile(((input_tensor, d_input_tensor), {"mesh": mesh})) [rank0]: File "/home/bobren/local/a/pytorch/torch/_dynamo/eval_frame.py", line 800, in aot_compile [rank0]: return aot_compile_fullgraph( [rank0]: File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile.py", line 235, in aot_compile_fullgraph [rank0]: compiled_fn = backend( [rank0]: File "/home/bobren/local/a/pytorch/torch/__init__.py", line 2445, in __call__ [rank0]: return compile_fx(model_, inputs_, config_patches=self.config) [rank0]: File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2525, in compile_fx [rank0]: return _maybe_wrap_and_compile_fx_main( [rank0]: File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2602, in _maybe_wrap_and_compile_fx_main [rank0]: return _compile_fx_main( [rank0]: File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2797, in _compile_fx_main [rank0]: return aot_autograd( [rank0]: File "/home/bobren/local/a/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__ [rank0]: cg = aot_module_simplified(gm, example_inputs, **self.kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/_functorch/aot_autograd.py", line 1119, in aot_module_simplified [rank0]: compiled_fn, _ = aot_stage2_compile( [rank0]: File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 348, in aot_stage2_compile [rank0]: return aot_stage2_autograd(aot_state, aot_graph_capture) [rank0]: File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 2017, in aot_stage2_autograd [rank0]: try_save_cache_entry, entry = _cache_autograd_info( [rank0]: File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 2196, in _cache_autograd_info [rank0]: entry = try_save_cache_entry( [rank0]: File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 2186, in try_save_cache_entry [rank0]: AOTAutogradCache.save( [rank0]: File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 905, in save [rank0]: raise e [rank0]: File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 889, in save [rank0]: PrecompileContext.record_artifact(artifact) [rank0]: File "/home/bobren/local/a/pytorch/torch/_dynamo/precompile_context.py", line 147, in record_artifact [rank0]: cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy( [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 172, in deepcopy [rank0]: y = _reconstruct(x, memo, *rv) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 271, in _reconstruct [rank0]: state = deepcopy(state, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 231, in _deepcopy_dict [rank0]: y[deepcopy(key, memo)] = deepcopy(value, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 172, in deepcopy [rank0]: y = _reconstruct(x, memo, *rv) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 271, in _reconstruct [rank0]: state = deepcopy(state, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 231, in _deepcopy_dict [rank0]: y[deepcopy(key, memo)] = deepcopy(value, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 172, in deepcopy [rank0]: y = _reconstruct(x, memo, *rv) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 271, in _reconstruct [rank0]: state = deepcopy(state, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 231, in _deepcopy_dict [rank0]: y[deepcopy(key, memo)] = deepcopy(value, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 206, in _deepcopy_list [rank0]: append(deepcopy(a, memo)) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 172, in deepcopy [rank0]: y = _reconstruct(x, memo, *rv) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 271, in _reconstruct [rank0]: state = deepcopy(state, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 231, in _deepcopy_dict [rank0]: y[deepcopy(key, memo)] = deepcopy(value, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 211, in _deepcopy_tuple [rank0]: y = [deepcopy(a, memo) for a in x] [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 211, in [rank0]: y = [deepcopy(a, memo) for a in x] [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 172, in deepcopy [rank0]: y = _reconstruct(x, memo, *rv) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 271, in _reconstruct [rank0]: state = deepcopy(state, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 231, in _deepcopy_dict [rank0]: y[deepcopy(key, memo)] = deepcopy(value, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 172, in deepcopy [rank0]: y = _reconstruct(x, memo, *rv) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 271, in _reconstruct [rank0]: state = deepcopy(state, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 146, in deepcopy [rank0]: y = copier(x, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 231, in _deepcopy_dict [rank0]: y[deepcopy(key, memo)] = deepcopy(value, memo) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 153, in deepcopy [rank0]: y = copier(memo) [rank0]: File "/home/bobren/local/a/pytorch/torch/_tensor.py", line 142, in __deepcopy__ [rank0]: return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo) [rank0]: File "/home/bobren/local/a/pytorch/torch/overrides.py", line 1733, in handle_torch_function [rank0]: result = mode.__torch_function__(public_api, types, args, kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/utils/_device.py", line 109, in __torch_function__ [rank0]: return func(*args, **kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/_tensor.py", line 180, in __deepcopy__ [rank0]: new_storage = self._typed_storage()._deepcopy(memo) [rank0]: File "/home/bobren/local/a/pytorch/torch/storage.py", line 1139, in _deepcopy [rank0]: return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo)) [rank0]: File "/home/bobren/local/a/pytorch-env/lib/python3.10/copy.py", line 153, in deepcopy [rank0]: y = copier(memo) [rank0]: File "/home/bobren/local/a/pytorch/torch/storage.py", line 243, in __deepcopy__ [rank0]: new_storage = self.clone() [rank0]: File "/home/bobren/local/a/pytorch/torch/storage.py", line 257, in clone [rank0]: return type(self)(self.nbytes(), device=self.device).copy_(self) [rank0]: File "/home/bobren/local/a/pytorch/torch/utils/_stats.py", line 29, in wrapper [rank0]: return fn(*args, **kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 1397, in __torch_dispatch__ [rank0]: return self.dispatch(func, types, args, kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 2155, in dispatch [rank0]: return self._cached_dispatch_impl(func, types, args, kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 1544, in _cached_dispatch_impl [rank0]: output = self._dispatch_impl(func, types, args, kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 2823, in _dispatch_impl [rank0]: r = func(*args, **kwargs) [rank0]: File "/home/bobren/local/a/pytorch/torch/_ops.py", line 836, in __call__ [rank0]: return self._op(*args, **kwargs) [rank0]: RuntimeError: Attempted to set the storage of a tensor on device "meta" to a storage on different device "cpu". This is no longer allowed; the devices must match. ``` As you can see the underlying problem is 1) we do a clone on storage 2) which under the hood calls `_copy` 3) but when we call `copy_` the fake mode turns a self, which is on "cpu", into a meta device. This PR fixes the issue by temporarily disabling dispatch when doing the deepclone. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169242 Approved by: https://github.com/bdhirsh Co-authored-by: Bob Ren --- test/dynamo/test_aot_compile.py | 42 +++++++++++++++++++++++++++++ torch/_dynamo/precompile_context.py | 12 ++++++--- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py index 3146a37cb661a..33b3f4e7faaab 100644 --- a/test/dynamo/test_aot_compile.py +++ b/test/dynamo/test_aot_compile.py @@ -22,6 +22,7 @@ from torch._dynamo.package import DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._inductor.runtime.runtime_utils import cache_dir +from torch.distributed.tensor import DTensor, Replicate from torch.fx._graph_pickler import GraphPickler from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -269,6 +270,17 @@ def eval_mode(mdl): assert torch.allclose(expected, actual) +class RedistributeModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 32) + + def forward(self, x, d_x, mesh): + x = self.linear(x) + y = d_x.redistribute(mesh, placements=(Replicate(), Replicate())) + return x, y + + @torch._dynamo.config.patch("enable_aot_compile", True) @instantiate_parametrized_tests class TestAOTCompile(torch._inductor.test_case.TestCase): @@ -781,6 +793,36 @@ def make_inputs(): self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor") self.assertEqual(expected, actual) + def test_aot_compile_with_redistribute(self): + from torch.distributed.device_mesh import init_device_mesh + from torch.testing._internal.distributed.fake_pg import FakeStore + + fake_store = FakeStore() + torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=4 + ) + mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp")) + input_tensor = torch.randn(32, 32, device="cpu") + placements = (Replicate(), Replicate()) + d_input_tensor = DTensor.from_local(input_tensor, mesh, placements) + mod = RedistributeModel() + + compiled_fn = torch.compile( + mod, + fullgraph=True, + ).forward.aot_compile(((input_tensor, d_input_tensor, mesh), {})) + inputs = (input_tensor, d_input_tensor, mesh) + expected = mod(*inputs) + actual = compiled_fn(mod, *inputs) + self.assertEqual(expected, actual) + compiled_fn.save_compiled_function(self.path()) + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with open(self.path(), "rb") as f: + compiled_fn = torch.compiler.load_compiled_function(f) + actual = compiled_fn(mod, *inputs) + self.assertEqual(expected, actual) + def test_aot_compile_with_checkpoint(self): from torch.utils.checkpoint import checkpoint diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index f3715ca39ae1f..bae360041b58c 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -108,9 +108,15 @@ def record_artifact( """ Records a backend artifact to be used with dynamo cache entries """ - cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy( - artifact - ) + # Temporarily disable all dispatch modes (including FakeTensorMode) during + # deepcopy to avoid issues with cloning fake tensors (e.g., device mesh + # with meta tensors that fail when cloning due to device mismatches) + from torch.utils._mode_utils import no_dispatch + + with no_dispatch(): + cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy( + artifact + ) @classmethod def record_dynamo_cache_entry( From a70d81a28541ca4412507fd56837c894222e6a70 Mon Sep 17 00:00:00 2001 From: Pavan Balaji Date: Fri, 5 Dec 2025 08:58:34 +0000 Subject: [PATCH 332/338] [pytorch] Add env variable to enable IPC for expandable segments (#169487) Summary: PyTorch's expandable segments IPC capability was disabled in fbcode due to job failures (see https://github.com/pytorch/pytorch/pull/132890). However, some use cases like CTran require IPC functionality for multi-process GPU communication. This change introduces PYTORCH_CUDA_EXPANDABLE_SEGMENTS_IPC environment variable to allow opt-in enablement of IPC handle types for expandable segments in fbcode builds while maintaining backward compatibility. IPC is enabled by default in non-fbcode builds and disabled by default in fbcode builds (existing behavior). In both cases, it can be explicitly enabled by setting PYTORCH_CUDA_EXPANDABLE_SEGMENTS_IPC=1. Test Plan: CI Differential Revision: D88274246 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169487 Approved by: https://github.com/ngimel --- c10/cuda/CUDACachingAllocator.cpp | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 9e637f4f6997e..01e5ce59d7096 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -419,14 +419,28 @@ struct ExpandableSegment { CUmemGenericAllocationHandle handle = 0; CUmemAllocationProp prop = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; -#ifndef FBCODE_CAFFE2 - if (CUDAAllocatorConfig::expandable_segments_handle_type() != - Expandable_Segments_Handle_Type::FABRIC_HANDLE) { - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; - } else { - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; - } + // In fbcode, IPC handle types for expandable segments are disabled by + // default because some jobs were failing (see + // https://github.com/pytorch/pytorch/pull/132890), but can be explicitly + // enabled via environment variable when IPC functionality is required + // (e.g., for multi-process communication with CTran). In non-fbcode + // builds, IPC handle types are enabled by default. +#ifdef FBCODE_CAFFE2 + static const bool default_enable_ipc = false; +#else + static const bool default_enable_ipc = true; #endif + static const bool enable_ipc_handles = + c10::utils::check_env("TORCH_CUDA_EXPANDABLE_SEGMENTS_IPC") + .value_or(default_enable_ipc); + if (enable_ipc_handles) { + if (CUDAAllocatorConfig::expandable_segments_handle_type() != + Expandable_Segments_Handle_Type::FABRIC_HANDLE) { + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + } else { + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; + } + } int flag = 0; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuDeviceGetAttribute_( &flag, From 82e30f319712369de6c2acb54bd70f4562d50aae Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 5 Dec 2025 10:35:40 +0000 Subject: [PATCH 333/338] Fix viable strict update when no green commit found (#169585) If it gets to the end of the list and doesn't find a green commit, the LATEST_SHA is None Pull Request resolved: https://github.com/pytorch/pytorch/pull/169585 Approved by: https://github.com/huydhn --- .github/workflows/update-viablestrict.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index b3fc9efdf667f..1b4af0f274913 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -44,6 +44,8 @@ jobs: echo "${PUSH_RESULT}" if [ "$PUSH_RESULT" = "Everything up-to-date" ]; then echo "No update pushed" + elif [ "${LATEST_SHA}" == "None" ]; then + echo "No viable/strict candidate found" else echo "{\"sha\": \"${LATEST_SHA}\", \"repository\":\"pytorch/pytorch\", \"timestamp\": ${TIME}}" > "/tmp/${LATEST_SHA}.json" pip install awscli==1.29.40 From ae64a53c7c2572492bfbdff7779cf9576df54f87 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 4 Dec 2025 10:28:17 -0800 Subject: [PATCH 334/338] make the float4 dtype support equality comparisons (#169575) Summary: Makes `torch.allclose(a, b, atol=0, rtol=0)` work for `a` and `b` with dtype `torch.float4_e2m1fn_x2`. This is useful for testing. Test Plan: ``` pytest test/quantization/core/experimental/test_floatx.py -s -k test_float4_e2m1fn_x2 ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/169575 Approved by: https://github.com/eqy, https://github.com/drisspg --- aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 16 ++++++++-------- aten/src/ATen/native/cuda/CompareEQKernel.cu | 2 +- .../core/experimental/test_floatx.py | 5 +++++ torch/headeronly/util/Float4_e2m1fn_x2.h | 15 +++++++++++++++ 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 26ec55c11d823..a79643e752c9c 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -624,11 +624,11 @@ void ge_kernel(TensorIteratorBase& iter) { void eq_kernel(TensorIteratorBase& iter) { // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "eq_cpu", [&]() { + AT_DISPATCH_V2(iter.common_dtype(), "eq_cpu", AT_WRAP([&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a == b; }); - }); + }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } else { - _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "eq_cpu", [&]() { + AT_DISPATCH_V2(iter.common_dtype(), "eq_cpu", AT_WRAP([&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -636,18 +636,18 @@ void eq_kernel(TensorIteratorBase& iter) { }, [](Vectorized a, Vectorized b) -> Vectorized { return a.eq(b); }); - }); + }), kComplexHalf, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } } void ne_kernel(TensorIteratorBase& iter) { // See Note [special-case bool outputs] if (iter.dtype() == ScalarType::Bool) { - _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "ne_cpu", [&]() { + AT_DISPATCH_V2(iter.common_dtype(), "ne_cpu", AT_WRAP([&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a != b; }); - }); + }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } else { - _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "ne_cpu", [&]() { + AT_DISPATCH_V2(iter.common_dtype(), "ne_cpu", AT_WRAP([&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { @@ -655,7 +655,7 @@ void ne_kernel(TensorIteratorBase& iter) { }, [](Vectorized a, Vectorized b) -> Vectorized { return a.ne(b); }); - }); + }), kComplexHalf, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } } diff --git a/aten/src/ATen/native/cuda/CompareEQKernel.cu b/aten/src/ATen/native/cuda/CompareEQKernel.cu index 954d0b08a1d06..442e484b9fa5c 100644 --- a/aten/src/ATen/native/cuda/CompareEQKernel.cu +++ b/aten/src/ATen/native/cuda/CompareEQKernel.cu @@ -33,7 +33,7 @@ C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) { AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() { opmath_symmetric_gpu_kernel_with_scalars( iter, CompareEqFunctor(op)); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2); } void eq_kernel_cuda(TensorIteratorBase& iter) { diff --git a/test/quantization/core/experimental/test_floatx.py b/test/quantization/core/experimental/test_floatx.py index c4cea4073a5cd..6542734ee40fe 100644 --- a/test/quantization/core/experimental/test_floatx.py +++ b/test/quantization/core/experimental/test_floatx.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: quantization"] +import copy import struct import unittest @@ -407,6 +408,10 @@ def test_float4_e2m1fn_x2(self, device): # can view uint8 as float4_e2m1fn_x2 x2.view(torch.float4_e2m1fn_x2) + # can do equality comparisons + x3 = copy.deepcopy(x1) + self.assertEqual(x1, x3, atol=0, rtol=0) + def test_f4_save_load(self, device): x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view( torch.float4_e2m1fn_x2 diff --git a/torch/headeronly/util/Float4_e2m1fn_x2.h b/torch/headeronly/util/Float4_e2m1fn_x2.h index 619a0648cf49b..00075914cdc34 100644 --- a/torch/headeronly/util/Float4_e2m1fn_x2.h +++ b/torch/headeronly/util/Float4_e2m1fn_x2.h @@ -25,8 +25,23 @@ struct alignas(1) Float4_e2m1fn_x2 { C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {} }; +/// Comparison operators +inline C10_HOST_DEVICE bool operator==( + const Float4_e2m1fn_x2& a, + const Float4_e2m1fn_x2& b) { + return a.val_ == b.val_; +} + +inline C10_HOST_DEVICE bool operator!=( + const Float4_e2m1fn_x2& a, + const Float4_e2m1fn_x2& b) { + return a.val_ != b.val_; +} + } // namespace c10 HIDDEN_NAMESPACE_BEGIN(torch, headeronly) using c10::Float4_e2m1fn_x2; +using c10::operator==; +using c10::operator!=; HIDDEN_NAMESPACE_END(torch, headeronly) From 3cf2f19f0a61269c1c6b7162f26466a5769fb914 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 4 Dec 2025 12:45:27 -0800 Subject: [PATCH 335/338] add copy_ support for float4 dtype (#169595) Summary: Enables `copy_` support for the `torch.float4_e2m1fn_x2` dtype. This is useful when slicing a tensor across dim1 and then calling contiguous, which can happen in vllm and therefore should be supported. Test Plan: ``` pytest test/quantization/core/experimental/test_floatx.py -s -k test_float4_e2m1fn_x2 ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/169595 Approved by: https://github.com/drisspg ghstack dependencies: #169575 --- aten/src/ATen/native/cpu/CopyKernel.cpp | 2 ++ aten/src/ATen/native/cuda/Copy.cu | 4 ++++ test/quantization/core/experimental/test_floatx.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 68c5a867f24ee..80708e548b196 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -235,6 +235,8 @@ void direct_copy_kernel(TensorIteratorBase &iter) { }); } else if (dtype == ScalarType::ComplexHalf) { cpu_kernel(iter, [=](c10::complex a) -> c10::complex { return a; }); + } else if (dtype == ScalarType::Float4_e2m1fn_x2) { + cpu_kernel(iter, [=](Float4_e2m1fn_x2 a) -> Float4_e2m1fn_x2 { return a; }); } else if (isBitsType(dtype)) { AT_DISPATCH_BIT_TYPES(dtype, "copy_kernel", [&] { cpu_kernel( diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 754582d2d9777..4295e4a74de26 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -234,6 +234,10 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) { AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] { gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); }); + } else if (dtype == ScalarType::Float4_e2m1fn_x2) { + TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting " + "Float4_e2m1fn_x2 to different types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype); + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Float4_e2m1fn_x2 x) { return x; }); } else { AT_DISPATCH_V2( dtype, "copy_", AT_WRAP([&] { diff --git a/test/quantization/core/experimental/test_floatx.py b/test/quantization/core/experimental/test_floatx.py index 6542734ee40fe..d234d857e84a1 100644 --- a/test/quantization/core/experimental/test_floatx.py +++ b/test/quantization/core/experimental/test_floatx.py @@ -412,6 +412,9 @@ def test_float4_e2m1fn_x2(self, device): x3 = copy.deepcopy(x1) self.assertEqual(x1, x3, atol=0, rtol=0) + # can call contiguous on a dim1 slice (calls `copy_` under the hood) + x1[:, 0:2048].contiguous() + def test_f4_save_load(self, device): x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view( torch.float4_e2m1fn_x2 From fa21963e02351c32d26b59c81e2bb2b9a5e4235d Mon Sep 17 00:00:00 2001 From: linhaifeng <1371675203@qq.com> Date: Fri, 5 Dec 2025 11:41:51 +0000 Subject: [PATCH 336/338] [MPS] Add input/indices shape validation for MaxUnpool{1,2,3}d (#169261) Add missing shape validation between `input` and `indices` tensors for `nn.MaxUnpool{1,2,3}d` on MPS backend Fixes #169235 Pull Request resolved: https://github.com/pytorch/pytorch/pull/169261 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .../src/ATen/native/mps/operations/Pooling.mm | 7 +++ .../_internal/common_methods_invocations.py | 47 ++++++++++++++++--- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index a8e25389b25a3..84920275c9dba 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -570,6 +570,13 @@ static void max_unpool_out_mps_template(const Tensor& input, " elements but got ", output_size_.size()); + // Check that input and indices have the same shape + TORCH_CHECK(input.sizes() == indices.sizes(), + "Expected shape of indices to be same as that of the input tensor (", + input.sizes(), + ") but got indices tensor with shape: ", + indices.sizes()); + auto dims = input.dim(); auto leading_dims = input.dim() - pooling_dims; diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 86320ed763204..e88a4f5887739 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -9384,6 +9384,42 @@ def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(make_input((2,)), offset=1) yield SampleInput(make_input((2,)), offset=-1) + +_UNPOOL_NAME_TO_DIM = { + 'nn.functional.max_unpool1d': 1, + 'nn.functional.max_unpool2d': 2, + 'nn.functional.max_unpool3d': 3 +} + + +def error_inputs_max_unpool(op_info, device, **kwargs): + """Error inputs for max_unpool: shape mismatch between input and indices.""" + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name] + + # Create mismatched shapes for input and indices + kwargs_dict = {'kernel_size': 3, 'stride': 2, 'padding': 0} + if pool_dim == 1: + input_shape = (8, 8) + indices_shape = (8, 7) + elif pool_dim == 2: + input_shape = (1, 1, 4, 4) + indices_shape = (1, 1, 4, 1) + else: # pool_dim == 3 + input_shape = (1, 1, 4, 4, 4) + indices_shape = (1, 1, 4, 4, 1) + + yield ErrorInput( + SampleInput( + make_arg(input_shape), + args=(torch.zeros(indices_shape, device=device, dtype=torch.long),), + kwargs=kwargs_dict + ), + error_type=RuntimeError, + error_regex='Expected shape of indices to be' + ) + + def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): unpool_name_to_pool_method_dict = { 'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d, @@ -9391,15 +9427,9 @@ def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): 'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d } - unpool_name_to_dim = { - 'nn.functional.max_unpool1d': 1, - 'nn.functional.max_unpool2d': 2, - 'nn.functional.max_unpool3d': 3 - } - unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()} - pool_dim = unpool_name_to_dim[op_info.name] + pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name] pool_method = unpool_name_to_pool_method_dict[op_info.name] pool_op_info = copy.copy(op_info) @@ -16252,6 +16282,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): assert_jit_shape_analysis=False, dtypes=floating_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, skips=( # Gradients are tested in `variant_test_name=grad` below. # We skip tests here because there is non-determinism in backward @@ -16286,6 +16317,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): assert_jit_shape_analysis=False, dtypes=floating_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, skips=( # Gradients are tested in `variant_test_name=grad` below. # We skip tests here because there is non-determinism in backward @@ -16323,6 +16355,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): assert_jit_shape_analysis=False, dtypes=floating_types_and(torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, skips=( # Gradients are tested in `variant_test_name=grad` below. # We skip tests here because there is non-determinism in backward From 3620149a2b491ec50fc1cc644ff1ee479ef9d59d Mon Sep 17 00:00:00 2001 From: Wenlin Chong Date: Fri, 5 Dec 2025 12:01:21 +0000 Subject: [PATCH 337/338] Correct some grammatical and expression errors in the CONTRIBUTING.md (#167926) Correct some grammatical and expression errors in the CONTRIBUTING.md file. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/167926 Approved by: https://github.com/mikaylagawarecki --- CONTRIBUTING.md | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 850753f13b63a..85982336d563c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -789,7 +789,7 @@ with `pip install ninja`. If PyTorch was already built, you will need to run `python setup.py clean` once after installing ninja for builds to succeed. -Note: Make sure to use a machine with a larger number of CPU cores, this will significantly reduce your build times. +Note: Make sure to use a machine with a larger number of CPU cores;this will significantly reduce your build times. #### Use CCache @@ -797,7 +797,7 @@ Even when dependencies are tracked with file modification, there are many situations where files get rebuilt when a previous compilation was exactly the same. Using ccache in a situation like this is a real time-saver. -Before building pytorch, install ccache from your package manager of choice: +Before building PyTorch, install ccache from your package manager of choice: ```bash sudo apt install ccache @@ -816,7 +816,7 @@ ccache -M 25Gi # -M 0 for unlimited ccache -F 0 ``` -To check this is working, do two clean builds of pytorch in a row. The second +To check this is working, do two clean builds of PyTorch in a row. The second build should be substantially and noticeably faster than the first build. If this doesn't seem to be the case, check the `CMAKE__COMPILER_LAUNCHER` rules in `build/CMakeCache.txt`, where `` is `C`, `CXX` and `CUDA`. @@ -865,8 +865,8 @@ This adds a build step where the compiler takes `` and essentially dumps its internal AST to a file so the compiler can avoid repeating itself for every `.cpp` file. -One caveat is that when enabled, this header gets included in every file by default. -Which may change what code is legal, for example: +One caveat is that when enabled, this header gets included in every file by default, +which may change what code is legal, for example: - internal functions can never alias existing names in `` - names in `` will work even if you don't explicitly include it. @@ -886,11 +886,11 @@ python -m pip install --no-build-isolation -v -e . ### Rebuild few files with debug information -While debugging a problem one often had to maintain a debug build in a separate folder. -But often only a few files needs to be rebuild with debug info to get a symbolicated backtrace or enable source debugging +While debugging a problem, one often has to maintain a debug build in a separate folder. +But often only a few files need to be rebuilt with debug info to get a symbolicated backtrace or enable source debugging. One can easily solve this with the help of `tools/build_with_debinfo.py` -For example, suppose one wants to debug what is going on while tensor index is selected, which can be achieved by setting a breakpoint at `applySelect` function: +For example, suppose one wants to debug what is going on while a tensor index is selected, which can be achieved by setting a breakpoint at `applySelect` function: ``` % lldb -o "b applySelect" -o "process launch" -- python3 -c "import torch;print(torch.rand(5)[3])" (lldb) target create "python" @@ -912,7 +912,7 @@ libtorch_python.dylib`at::indexing::impl::applySelect: Target 0: (python) stopped. Process 87729 launched: '/usr/bin/python' (arm64) ``` -Which is not very informative, but can be easily remedied by rebuilding `python_variable_indexing.cpp` with debug information +This is not very informative, but can be easily remedied by rebuilding `python_variable_indexing.cpp` with debug information. ``` % ./tools/build_with_debinfo.py torch/csrc/autograd/python_variable_indexing.cpp [1 / 2] Building caffe2/torch/CMakeFiles/torch_python.dir/csrc/autograd/python_variable_indexing.cpp.o @@ -942,7 +942,7 @@ Process 87741 stopped Target 0: (python) stopped. Process 87741 launched: '/usr/bin/python3' (arm64) ``` -Which is much more useful, isn't it? +This is much more useful, isn't it? ### C++ frontend development tips @@ -956,10 +956,10 @@ Please follow the lead of the other tests to see how to write a new test case. ### GDB integration -If you are debugging pytorch inside GDB, you might be interested in +If you are debugging PyTorch inside GDB, you might be interested in [pytorch-gdb](tools/gdb/pytorch-gdb.py). This script introduces some -pytorch-specific commands which you can use from the GDB prompt. In -particular, `torch-tensor-repr` prints a human-readable repr of an at::Tensor +PyTorch-specific commands which you can use from the GDB prompt. In +particular, `torch-tensor-repr` prints a human-readable representation of an at::Tensor object. Example of usage: ``` @@ -993,7 +993,7 @@ tensor([1., 2., 3., 4.], dtype=torch.float64) ``` GDB tries to automatically load `pytorch-gdb` thanks to the -[.gdbinit](.gdbinit) at the root of the pytorch repo. However, auto-loadings is disabled by default, because of security reasons: +[.gdbinit](.gdbinit) at the root of the PyTorch repository. However, auto-loading is disabled by default, because of security reasons: ```bash $ gdb @@ -1034,7 +1034,7 @@ If you are working on the CUDA code, here are some useful CUDA debugging tips: `std::tuple` etc. in device code. Many of such features are possible because of the [--expt-relaxed-constexpr](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#constexpr-functions) nvcc flag. There is a known [issue](https://github.com/ROCm/hip/issues/374) - that ROCm errors out on device code, which uses such stl functions. + that ROCm errors out on device code, which uses such STL functions. 4. A good performance metric for a CUDA kernel is the [Effective Memory Bandwidth](https://devblogs.nvidia.com/how-implement-performance-metrics-cuda-cc/). It is useful for you to measure this metric whenever you are writing/optimizing a CUDA @@ -1289,7 +1289,7 @@ More information can be found We need `LD_PRELOAD` because there is a cmake check that ensures that a simple program builds and runs. If we are building with ASAN as a shared -library, we need to `LD_PRELOAD` the runtime library, otherwise there will +library, we need to use `LD_PRELOAD` to load the runtime library, otherwise there will be dynamic linker errors and the check will fail. We don’t actually need either of these if we fix the cmake checks. @@ -1361,7 +1361,7 @@ There are two possible choices for which commit to use: For all practical purposes, most people can think of the commit being used as commit `B` (choice **1**). -However, if workflow files (which govern CI behavior) were modified (either by your PR or since dev branch were created ) there's +However, if workflow files (which govern CI behavior) were modified (either by your PR or since dev branch was created) there's a nuance to know about: The workflow files themselves get taken from checkpoint `C`, the merger of your PR and the `main` branch. But only the workflow files get taken from that merged From 4710fd9ecf5844092591446dc311cc265017d656 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 5 Dec 2025 13:26:59 +0000 Subject: [PATCH 338/338] Remove unnecessary uses of thrust::pair (#168941) This PR replaces unnecessary uses of thrust::pair with std::pair. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168941 Approved by: https://github.com/albanD --- aten/src/ATen/native/cuda/ReflectionPad.cu | 9 ++++----- aten/src/ATen/native/cuda/group_norm_kernel.cu | 2 +- aten/src/ATen/native/cuda/layer_norm_kernel.cu | 2 +- aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu index 935471dad5c13..78d960850db00 100644 --- a/aten/src/ATen/native/cuda/ReflectionPad.cu +++ b/aten/src/ATen/native/cuda/ReflectionPad.cu @@ -23,7 +23,6 @@ #include #endif -#include namespace at::native { namespace { @@ -31,7 +30,7 @@ namespace { using at::cuda::detail::canUse32BitIndexMath; __device__ -inline thrust::pair get_index_mapping1d( +inline std::pair get_index_mapping1d( int64_t input_w, int64_t output_w, int64_t output_x, int64_t pad_l) { @@ -50,13 +49,13 @@ inline thrust::pair get_index_mapping1d( + 2 * pad_l + input_w - 1 - o_start_x + i_start_x; - return thrust::make_pair( + return std::make_pair( input_offset + input_x, output_offset + output_x); } __device__ -inline thrust::pair get_index_mapping2d( +inline std::pair get_index_mapping2d( int64_t input_dim_x, int64_t input_dim_y, int64_t output_dim_x, int64_t output_dim_y, int64_t pad_l, int64_t pad_t, @@ -87,7 +86,7 @@ inline thrust::pair get_index_mapping2d( + 2 * pad_t + input_dim_y - 1 - o_start_y + i_start_y; - return thrust::make_pair( + return std::make_pair( input_offset + input_y * input_dim_x + input_x, output_offset + output_y * output_dim_x + output_x); } diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index 77d26e915b65a..254cd69466e8d 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -38,7 +38,7 @@ __global__ void RowwiseMomentsCUDAKernel( using T_ACC = acc_type; using WelfordType = WelfordData; using WelfordOp = - WelfordOps>; + WelfordOps>; const int64_t i = blockIdx.x; WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 84812eb22125f..0a4b58cbdbd85 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -64,7 +64,7 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC* rstd) { using WelfordType = WelfordData; using WelfordOp = - WelfordOps>; + WelfordOps>; __shared__ typename std::aligned_storage:: diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index 410c511bebef6..9726f391129f9 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -88,7 +88,7 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { ); // this forces device-host synchronization! - thrust::pair newEnd = thrust::unique_by_key(policy, + auto newEnd = thrust::unique_by_key(policy, indicesIter, indicesIter + nnz, uniqueOffsetsIter );