From aeb5321b6360c899808d3461789b3bbd6265756e Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Tue, 5 Aug 2025 09:36:24 +0000 Subject: [PATCH 0001/1424] Allow controlling PG backend and options via init_device_mesh (#159371) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159371 Approved by: https://github.com/wconstab, https://github.com/fduwjj, https://github.com/wanchaol --- test/distributed/test_device_mesh.py | 112 ++++++++++++++- torch/_C/_distributed_c10d.pyi | 1 + .../distributed/c10d/FakeProcessGroup.hpp | 23 ++- torch/csrc/distributed/c10d/init.cpp | 29 ++-- torch/distributed/device_mesh.py | 133 ++++++++++++++++-- .../distributed/_tensor/common_dtensor.py | 24 ++-- .../testing/_internal/distributed/fake_pg.py | 10 +- 7 files changed, 297 insertions(+), 35 deletions(-) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 04aaad9990f9c..5672171d0be4d 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist import torch.distributed._functional_collectives as funcol +from torch._C._distributed_c10d import Backend as C10dBackend from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.distributed_c10d import ( @@ -30,7 +31,7 @@ DTensorTestBase, with_comms, ) -from torch.testing._internal.distributed.fake_pg import FakeStore +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeStore from torch.utils._typing_utils import not_none @@ -578,6 +579,115 @@ def test_raises_mesh_shape_mesh_dim_names_mismatch(self): mesh_dim_names=["dp", "tp"], ) + def _test_backend_override_argument_dict_with_idx_and_backend(self): + opts = FakeProcessGroup.Options() + opts.fake_option = 42 + + mesh = init_device_mesh( + self.device_type, + (2, 2, 2), + mesh_dim_names=("dp", "tp", "cp"), + backend_override={0: "fake", 2: ("fake", opts)}, + ) + + def get_opts(mesh: DeviceMesh, dim_idx: int) -> C10dBackend.Options: + return ( + mesh.get_group(dim_idx) + ._get_backend(torch.device(f"{self.device_type}:{self.rank}")) + .options + ) + + # Fake pg only have BackendType as BackendType::CUSTOM. + self.assertEqual(mesh.get_group(0)._get_backend_name(), "custom") + self.assertNotEqual(mesh.get_group(1)._get_backend_name(), "custom") + self.assertEqual(mesh.get_group(2)._get_backend_name(), "custom") + + self.assertIsNone(get_opts(mesh, 0)) + self.assertEqual(get_opts(mesh, 2).fake_option, 42) + + dp_tp_mesh = mesh["dp", "tp"]._flatten() + dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override="fake") + tp_cp_mesh = mesh["tp", "cp"]._flatten(backend_override=("fake", opts)) + + self.assertNotEqual(dp_tp_mesh.get_group(0)._get_backend_name(), "custom") + self.assertEqual(dp_cp_mesh.get_group(0)._get_backend_name(), "custom") + self.assertEqual(tp_cp_mesh.get_group(0)._get_backend_name(), "custom") + + self.assertIsNone(get_opts(dp_cp_mesh, 0)) + self.assertEqual(get_opts(tp_cp_mesh, 0).fake_option, 42) + + @with_comms + def test_backend_override_argument_dict_with_idx_and_backend_lazy(self): + self._test_backend_override_argument_dict_with_idx_and_backend() + + @with_comms(eager_init=True) + def test_backend_override_argument_dict_with_idx_and_backend_eager(self): + self._test_backend_override_argument_dict_with_idx_and_backend() + + @with_comms(backend="fake") + def test_backend_override_argument_dict_with_name_and_options(self): + opts = FakeProcessGroup.Options() + opts.fake_option = 42 + + mesh = init_device_mesh( + self.device_type, + (2, 2, 2), + mesh_dim_names=("dp", "tp", "cp"), + backend_override={"tp": opts}, + ) + + def get_opts(mesh: DeviceMesh, dim_idx: int) -> C10dBackend.Options: + return ( + mesh.get_group(dim_idx) + ._get_backend(torch.device(f"{self.device_type}:{self.rank}")) + .options + ) + + self.assertIsNone(get_opts(mesh, 0)) + self.assertEqual(get_opts(mesh, 1).fake_option, 42) + self.assertIsNone(get_opts(mesh, 2)) + + dp_tp_mesh = mesh["dp", "tp"]._flatten() + dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override=opts) + + self.assertIsNone(get_opts(dp_tp_mesh, 0)) + self.assertEqual(get_opts(dp_cp_mesh, 0).fake_option, 42) + + @with_comms + def test_backend_override_argument_errors(self): + with self.assertRaisesRegex( + RuntimeError, + "Found redundant dim index 0 and name dp in backend_override", + ): + init_device_mesh( + self.device_type, + (2, 4), + mesh_dim_names=("dp", "tp"), + backend_override={"dp": "foo", 0: "bar"}, + ) + + with self.assertRaisesRegex( + RuntimeError, + r"Found invalid keys in backend_override: got \['cp'\]", + ): + init_device_mesh( + self.device_type, + (2, 4), + mesh_dim_names=("dp", "tp"), + backend_override={"cp": "foo"}, + ) + + with self.assertRaisesRegex( + RuntimeError, + r"Found invalid keys in backend_override: got \[42\]", + ): + init_device_mesh( + self.device_type, + (2, 4), + mesh_dim_names=("dp", "tp"), + backend_override={42: "bar"}, + ) + class TestDeviceMeshGetItem(DTensorTestBase): @property diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f0413764cda6c..9007d3fbf5a09 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -315,6 +315,7 @@ class Backend: def options(self) -> Options: ... def rank(self) -> int: ... def size(self) -> int: ... + def name(self) -> str: ... def abort(self) -> None: ... def shutdown(self) -> None: ... def eager_connect_single_device(self, device: torch.device | None) -> None: ... diff --git a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp index e8cdbfbbe8c89..dc3c4889057c8 100644 --- a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp @@ -20,7 +20,25 @@ class FakeWork : public Work { class FakeProcessGroup : public Backend { public: - FakeProcessGroup(int rank, int size) : Backend(rank, size) {} + struct Options : Backend::Options { + explicit Options() : Backend::Options("fake") {} + + int fake_option = 0; + }; + + FakeProcessGroup( + int rank, + int size, + c10::intrusive_ptr options = c10::make_intrusive()) + : Backend(rank, size), options_(std::move(options)) {} + + const std::string getBackendName() const override { + return "fake"; + } + + c10::intrusive_ptr getBackendOptions() override { + return c10::static_intrusive_pointer_cast(options_); + } c10::intrusive_ptr broadcast( std::vector& /* tensors */, @@ -194,6 +212,9 @@ class FakeProcessGroup : public Backend { const BarrierOptions& /* opts */ = BarrierOptions()) override { return c10::make_intrusive(); } + + private: + c10::intrusive_ptr options_; }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 824f26414c9fb..c39957c2e8386 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -3776,14 +3776,27 @@ such as `dist.all_reduce(tensor, async_op=True)`. auto fakeProcessGroup = intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>( - module, "FakeProcessGroup", backend) - .def( - py::init([](int rank, int size) { - return c10::make_intrusive<::c10d::FakeProcessGroup>( - rank, size); - }), - py::arg("rank"), - py::arg("world_size")); + module, "FakeProcessGroup", backend); + intrusive_ptr_class_<::c10d::FakeProcessGroup::Options>( + fakeProcessGroup, "Options", backendOptions) + .def(py::init()) + .def_readwrite( + "fake_option", &::c10d::FakeProcessGroup::Options::fake_option); + fakeProcessGroup + .def( + py::init([](int rank, + int size, + c10::intrusive_ptr<::c10d::FakeProcessGroup::Options> + options) { + return c10::make_intrusive<::c10d::FakeProcessGroup>( + rank, size, std::move(options)); + }), + py::arg("rank"), + py::arg("world_size"), + py::arg("options") = + c10::make_intrusive<::c10d::FakeProcessGroup::Options>()) + .def_property_readonly( + "options", &::c10d::FakeProcessGroup::getBackendOptions); auto fakeWork = intrusive_ptr_no_gil_destructor_class_<::c10d::FakeWork>( module, "FakeWork", work) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 85f2fff4f831b..e7d1e053fbfd8 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -5,8 +5,9 @@ import os import threading import warnings +from collections.abc import Iterator from functools import reduce -from itertools import chain +from itertools import chain, zip_longest from typing import Optional, TYPE_CHECKING, Union import torch @@ -69,7 +70,7 @@ def __init__(self) -> None: self.mesh_stack: list[DeviceMesh] = [] self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {} self.mesh_dim_group_options: dict[ - int, tuple[str, Optional[C10dBackend.Options]] + int, tuple[Optional[str], Optional[C10dBackend.Options]] ] = {} self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. @@ -166,7 +167,13 @@ def create_sub_mesh( return res_submesh def create_flatten_mesh( - self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None + self, + device_mesh: "DeviceMesh", + mesh_dim_name: Optional[str] = None, + backend_override: tuple[Optional[str], Optional[C10dBackend.Options]] = ( + None, + None, + ), ) -> "DeviceMesh": root_mesh = _mesh_resources.get_root_mesh(device_mesh) @@ -217,6 +224,7 @@ def create_flatten_mesh( root_mesh.device_type, mesh_nd, mesh_dim_names=(mesh_dim_name,), + backend_override=(backend_override,), ) if cur_rank in mesh_nd: res_flattened_mesh = flattened_mesh @@ -283,7 +291,7 @@ def get_mesh_dim_by_name( def _set_mesh_dim_group_options( self, dim: int, - backend: str, + backend: Optional[str], pg_options: Optional[C10dBackend.Options] = None, ) -> None: self.mesh_dim_group_options[dim] = (backend, pg_options) @@ -439,6 +447,9 @@ def __init__( mesh: Union[torch.Tensor, "ArrayLike"], *, mesh_dim_names: Optional[tuple[str, ...]] = None, + backend_override: Optional[ + tuple[tuple[Optional[str], Optional[C10dBackend.Options]], ...] + ] = None, _init_backend: bool = True, ) -> None: self.device_type = device_type @@ -450,6 +461,8 @@ def __init__( else torch.tensor(mesh, device="cpu", dtype=torch.int) ) self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None + if backend_override is None: + backend_override = ((None, None),) * self.mesh.ndim # private field to pre-generate DeviceMesh's hash self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) @@ -463,7 +476,7 @@ def __init__( # process (we need to know if the current global rank is in the mesh or not). if _init_backend: self._setup_world_group_and_device() - self._init_process_groups() + self._init_process_groups(backend_override) if is_initialized() and get_backend() == "threaded": self._thread_id = threading.get_ident() @@ -525,7 +538,12 @@ def _setup_world_group_and_device(self): return _get_default_group() - def _init_process_groups(self): + def _init_process_groups( + self, + backend_override: tuple[ + tuple[Optional[str], Optional[C10dBackend.Options]], ... + ], + ): # group_name associated with each mesh dimension, each # mesh dimension should have one sub-group per rank # @@ -535,7 +553,9 @@ def _init_process_groups(self): if ( self.mesh.ndim == 1 and self.mesh.numel() == get_world_size() - and 0 not in _mesh_resources.mesh_dim_group_options + and _mesh_resources.mesh_dim_group_options.get(0, (None, None)) + == (None, None) + and backend_override[0] == (None, None) ): # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. # Otherwise, create new pg. @@ -563,12 +583,17 @@ def _init_process_groups(self): # Respect dim group options specified via _MeshEnv.set_dim_group_options(). # Inherit from the parent group if no options are specified for the group. if dim in _mesh_resources.mesh_dim_group_options: + if backend_override[dim] != (None, None): + raise RuntimeError( + f"Dimension {dim} present both in the backend_override argument " + "and via _mesh_resources._set_mesh_dim_group_options" + ) ( backend, pg_options, ) = _mesh_resources.mesh_dim_group_options[dim] else: - backend, pg_options = None, None + backend, pg_options = backend_override[dim] # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. @@ -591,10 +616,19 @@ def _init_process_groups(self): dim_group = None has_split_group = False if ( - bound_device_id := getattr( - default_group, "bound_device_id", None + ( + bound_device_id := getattr( + default_group, "bound_device_id", None + ) + ) + is not None + and torch.cuda.is_available() + and ( + backend is None + or default_group._get_backend(torch.device("cuda")).name() + == backend ) - ) is not None and torch.cuda.is_available(): + ): dim_group = split_group( parent_pg=default_group, pg_options=pg_options, @@ -968,7 +1002,13 @@ def get_coordinate(self) -> Optional[list[int]]: """ return self._coordinate_on_dim if self._coordinate_on_dim else None - def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": + def _flatten( + self, + mesh_dim_name: Optional[str] = None, + backend_override: Union[ + None, str, C10dBackend.Options, tuple[str, C10dBackend.Options] + ] = None, + ) -> "DeviceMesh": """ Returns a 1D DeviceMesh by flattening the current DeviceMesh. @@ -986,13 +1026,65 @@ def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": "Cannot flatten a DeviceMesh without mesh_dim_names!" ) - return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) + if backend_override is not None: + (backend_override_tuple,) = _normalize_backend_override( + {0: backend_override}, 1 + ) + else: + backend_override_tuple = (None, None) + + return _mesh_resources.create_flatten_mesh( + self, mesh_dim_name, backend_override_tuple + ) + + def _normalize_backend_override( + backend_override: dict[ + Union[int, str], + Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], + ], + ndim: int, + mesh_dim_names: Optional[tuple[str, ...]] = None, + ) -> Iterator[tuple[Optional[str], Optional[C10dBackend.Options]]]: + if mesh_dim_names is None: + mesh_dim_names = () + for dim_idx, dim_name in zip_longest(range(ndim), mesh_dim_names): + if dim_name is not None and dim_name in backend_override: + if dim_idx in backend_override: + raise RuntimeError( + f"Found redundant dim index {dim_idx} and " + f"name {dim_name} in backend_override" + ) + val = backend_override.pop(dim_name) + elif dim_idx in backend_override: + val = backend_override.pop(dim_idx) + else: + yield (None, None) + continue + + if isinstance(val, str): + yield (val, None) + elif isinstance(val, C10dBackend.Options): + yield (None, val) + else: + yield val + + if backend_override: + raise RuntimeError( + f"Found invalid keys in backend_override: got {list(backend_override.keys())}, " + f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}" + ) def init_device_mesh( device_type: str, mesh_shape: tuple[int, ...], *, mesh_dim_names: Optional[tuple[str, ...]] = None, + backend_override: Optional[ + dict[ + Union[int, str], + Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], + ] + ] = None, ) -> DeviceMesh: """ Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. @@ -1017,6 +1109,11 @@ def init_device_mesh( mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names` must be unique. + backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of + the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a + dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name + of the backend and its options, or just one of these two components (in which case the other will be + set to its default value). Returns: DeviceMesh: A :class:`DeviceMesh` object representing the device layout. @@ -1043,6 +1140,15 @@ def init_device_mesh( f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", ) + if backend_override is not None: + backend_override_tuple = tuple( + _normalize_backend_override( + backend_override, len(mesh_shape), mesh_dim_names + ) + ) + else: + backend_override_tuple = None + # assume valid device types are all letters if device_type and not device_type.isalpha(): raise RuntimeError( @@ -1058,6 +1164,7 @@ def init_device_mesh( device_type=device_type, mesh=mesh, mesh_dim_names=mesh_dim_names, + backend_override=backend_override_tuple, ) return device_mesh diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 94bfead8a0c03..32fdcce997eca 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -355,22 +355,26 @@ def backend(self) -> str: def build_device_mesh(self) -> DeviceMesh: return init_device_mesh(self.device_type, (self.world_size,)) - def init_pg(self, eager_init) -> None: + def init_pg(self, eager_init, backend: Optional[str] = None) -> None: if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) - if self.backend not in [ + if backend is None: + backend = self.backend + + if backend not in [ "nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "hccl", "xccl", + "fake", ]: - raise RuntimeError(f"Backend {self.backend} not supported!") + raise RuntimeError(f"Backend {backend} not supported!") device_id = None - if "nccl" in self.backend or "xccl" in self.backend: + if "nccl" in backend or "xccl" in backend: # set device for nccl pg for collectives torch.accelerator.set_device_index(self.rank) # we only need to set device_id for nccl backend with eager init @@ -381,7 +385,7 @@ def init_pg(self, eager_init) -> None: # so the nccl communicator is immediately formed and we can use `ncclCommSplit` # for form subgroup to avoid unnecesssary overhead. dist.init_process_group( - backend=self.backend, + backend=backend, world_size=self.world_size, rank=self.rank, # pyre-ignore[16] init_method=f"file://{self.file_name}", # pyre-ignore[16] @@ -449,13 +453,15 @@ def run_subtests(self, *args, **kwargs): # wrapper to initialize comms (processgroup) -def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc: - def decorator(func, eager_init: bool = False): +def with_comms( + eager_init: Union[TestFunc, bool] = False, backend: Optional[str] = None +) -> TestFunc: + def decorator(func, eager_init: bool = False, backend: Optional[str] = None): @wraps(func) # pyre-ignore[6] def wrapper( self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] ) -> None: - self.init_pg(eager_init) + self.init_pg(eager_init, backend) try: func(self, *args, **kwargs) # type: ignore[misc] @@ -470,7 +476,7 @@ def wrapper( return ( decorator(func=eager_init) if callable(eager_init) - else partial(decorator, eager_init=eager_init) + else partial(decorator, eager_init=eager_init, backend=backend) ) diff --git a/torch/testing/_internal/distributed/fake_pg.py b/torch/testing/_internal/distributed/fake_pg.py index a34ee75cf600e..0a2814c246459 100644 --- a/torch/testing/_internal/distributed/fake_pg.py +++ b/torch/testing/_internal/distributed/fake_pg.py @@ -11,7 +11,7 @@ class FakeStore(dist.Store): """ -def _create_fake_pg(prefix_store, rank, world_size, timeout): +def _create_fake_pg(common_opts, backend_opts): """ A fake process group (not related to FakeTensor) is a process group which doesn't actually do any communication, it just hallucinates some @@ -22,7 +22,11 @@ def _create_fake_pg(prefix_store, rank, world_size, timeout): for every collective. It should be used as a convenient tool when playing with distributed but don't care about the actual data. """ - return FakeProcessGroup(rank, world_size) + return FakeProcessGroup( + common_opts.group_rank, common_opts.group_size, backend_opts + ) -dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda", "hpu"]) +dist.Backend.register_backend( + "fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu"] +) From 0ba09a6d345816483cbca2e8b872c0bd946d822e Mon Sep 17 00:00:00 2001 From: "Zheng, Zhaoqiong" Date: Tue, 5 Aug 2025 18:37:47 +0000 Subject: [PATCH 0002/1424] fix link for tutorial of inductor on windows (#159853) fix link issue from https://docs.pytorch.org/tutorials/prototype/inductor_windows.html to https://docs.pytorch.org/tutorials/unstable/inductor_windows.html due to structure change with pr https://github.com/pytorch/tutorials/pull/3489 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159853 Approved by: https://github.com/sekyondaMeta Co-authored-by: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Co-authored-by: Zesheng Zong --- docs/source/notes/get_start_xpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index 5ca51833f0256..6414730c28d47 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -107,7 +107,7 @@ If you are migrating code from ``cuda``, you would change references from ``cuda The following points outline the support and limitations for PyTorch with Intel GPU: #. Both training and inference workflows are supported. -#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to Use Inductor on Windows with CPU/XPU `_. +#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to use torch.compile on Windows CPU/XPU `_. #. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported. Examples From e06b110f731dc1e576c50dd102229bbd0fcbe89a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 5 Aug 2025 18:57:35 +0000 Subject: [PATCH 0003/1424] [Testing] Add MPS to NATIVE_DEVICES (#153835) This would allow me to enable more opinfo tests against MPS device eventually and supposed to be a very simple test, but actually required minor adjustments to lots of test files, namely: - Introduce `all_mps_types_and` that is very similar to `all_types_and`, but skips `float64` - Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())` - Skip `test_from_dlpack_noncontinguous` as it currently crashes (need to be fixed) - Add lots of `expectedFailureIfMPS` - Delete all `@onlyNativeDeviceTypesAnd("mps")` <sarcasm> I love how well documented this variable are </sarcasm> Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835 Approved by: https://github.com/Skylion007 --- test/nn/test_convolution.py | 1 + test/nn/test_pooling.py | 19 +++++++++++++++++ test/test_dlpack.py | 15 ++++++++++++-- test/test_indexing.py | 3 +++ test/test_nn.py | 17 ++++++++++++++++ test/test_view_ops.py | 27 +++++++++++++++++++------ torch/testing/_internal/common_dtype.py | 13 ++++++++++++ torch/testing/_internal/common_utils.py | 2 +- 8 files changed, 88 insertions(+), 9 deletions(-) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index ad715598e580d..df3a3f5766c14 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -2842,6 +2842,7 @@ def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): @parametrize_test("strided", [False, True]) # Test with both contiguous and non-contiguous inputs. @parametrize_test("contiguous", [False, True]) + @expectedFailureMPS # No double support def test_conv_backend( self, device, diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index e33385bcfa11c..a8f77df22d311 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -504,6 +504,7 @@ def test_quantized_max_pool3d(self): class TestPoolingNNDeviceType(NNTestCase): + @expectedFailureMPS # No double, float shape prop does not work @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_adaptive_pooling_zero_batch(self, dtype, device): @@ -523,6 +524,7 @@ def test_adaptive_pooling_zero_batch(self, dtype, device): # when output_size = 0, in adaptive_{avg, max}_pool and its variants. # These tests are explicitly written because ErrorInputs does not support backward calls # Issue: https://github.com/pytorch/pytorch/issues/78868 + @expectedFailureMPS # No double, float shape prop does not work @onlyNativeDeviceTypes @dtypes(torch.float32, torch.float64) @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16) @@ -556,6 +558,7 @@ def test_adaptive_pooling_empty_output_size(self, dtype, device): with self.assertRaisesRegex(RuntimeError, error_msg): fn(input2, output_size).sum().backward() + @expectedFailureMPS # Error message does not match @onlyNativeDeviceTypes def test_adaptive_avg_pooling_backward_fails(self, device): grad_output = torch.randn(1, 2, 7, device=device) @@ -582,6 +585,7 @@ def test_adaptive_max_pooling_backward_fails(self, device): with self.assertRaisesRegex(RuntimeError, "expected dimensions"): torch.ops.aten.adaptive_max_pool3d_backward(grad_output, input, indices) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_batch(self, device): mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5)) @@ -592,6 +596,7 @@ def test_FractionalMaxPool2d_zero_batch(self, device): inp = torch.randn(1, 0, 50, 32, device=device) mod(inp) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_batch(self, device): mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device) @@ -602,6 +607,7 @@ def test_FractionalMaxPool3d_zero_batch(self, device): inp = torch.randn(1, 0, 50, 32, 32, device=device) mod(inp) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_out_size(self, device): mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1]) @@ -609,6 +615,7 @@ def test_FractionalMaxPool2d_zero_out_size(self, device): out = mod(inp) self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device)) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_out_size(self, device): mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1]) @@ -616,6 +623,7 @@ def test_FractionalMaxPool3d_zero_out_size(self, device): out = mod(inp) self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device)) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_samples(self, device): samples = torch.rand([0, 16, 2], device=device) @@ -630,6 +638,7 @@ def test_FractionalMaxPool2d_zero_samples(self, device): with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"): mod(inp1) + @expectedFailureMPS # Op not implemented @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_samples(self, device): samples = torch.rand([0, 16, 3], device=device) @@ -823,6 +832,7 @@ def test_MaxUnpool_index_errors( else: unpool(output, indices) + @expectedFailureMPS @onlyNativeDeviceTypes def test_AdaptiveMaxPool_zero_batch_dim(self, device): inp = torch.randn(0, 16, 50, device=device) @@ -962,6 +972,7 @@ def test_adaptive_avg_pool3d_output_size_one(self, device): c = out.size(1) self.assertEqual(out.stride(), [c, 1, 1, 1, 1]) + @expectedFailureMPS # Runtime Error not raised for mps @expectedFailureMeta # Runtime Error not raised for meta @onlyNativeDeviceTypes @dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long) @@ -976,6 +987,7 @@ def test_adaptive_pooling_no_suppot_input(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "not implemented"): module(input) + @expectedFailureMPS # TODO: fixme @onlyNativeDeviceTypes @gcIfJetson @dtypes(torch.float, torch.double) @@ -1123,6 +1135,7 @@ def helper(n, c, h, w, ks): helper(1, 100000, 32, 32, ks=4) helper(1, 100000, 1, 4, ks=(1, 4)) # test for max_pool1d + @expectedFailureMPS # TODO: Fixme @onlyNativeDeviceTypes @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double) @@ -1198,6 +1211,7 @@ def check(x, args, expected, memory_format): torch.channels_last, ) + @expectedFailureMPS # TODO: Fixme @onlyNativeDeviceTypes @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double) @@ -1722,6 +1736,7 @@ def test_maxpool_indices_no_batch_dim(self, device, dtype): @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) + @expectedFailureMPS # Exception not raise @onlyNativeDeviceTypes # TODO: Fails on XLA @gcIfJetson def test_max_pool_nan_inf(self, device, dtype): @@ -1758,6 +1773,7 @@ def test_max_pool_nan_inf(self, device, dtype): res2 = fn(x2, 1 if adaptive else 3) self.assertTrue(math.isinf(res2.item())) + @expectedFailureMPS # float64 @expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta @onlyNativeDeviceTypes def test_fractional_max_pool2d(self, device): @@ -1820,6 +1836,7 @@ def test_fractional_max_pool2d_backward_fails(self, device): grad_output, input, kernel_size, output_size, indices ) + @expectedFailureMPS # float64 @expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta @onlyNativeDeviceTypes def test_fractional_max_pool3d(self, device): @@ -1867,6 +1884,7 @@ def func(x): x, (2, 2, 2), output_size=output_size, _random_samples=samples ) + @expectedFailureMPS # Not implemented @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) @onlyNativeDeviceTypes # TODO: Fails on XLA @@ -1896,6 +1914,7 @@ def test_fractional_max_pool_nan_inf(self, device, dtype): res2.backward(torch.randn_like(res2)) self.assertTrue(math.isinf(res2.item())) + @expectedFailureMPS # TODO: Fix me @onlyNativeDeviceTypes # TODO: RuntimeError message different on XLA def test_pooling_zero_stride(self, device): for op in ("max", "avg"): diff --git a/test/test_dlpack.py b/test/test_dlpack.py index f734126b5e7c9..b960575cc6348 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -5,6 +5,7 @@ from torch.testing._internal.common_device_type import ( deviceCountAtLeast, dtypes, + dtypesIfMPS, instantiate_device_type_tests, onlyCPU, onlyCUDA, @@ -13,10 +14,14 @@ skipCUDAIfRocm, skipMeta, ) -from torch.testing._internal.common_dtype import all_types_and_complex_and +from torch.testing._internal.common_dtype import ( + all_mps_types_and, + all_types_and_complex_and, +) from torch.testing._internal.common_utils import ( IS_JETSON, run_tests, + skipIfMPS, skipIfTorchDynamo, TestCase, ) @@ -55,6 +60,7 @@ class TestTorchDlPack(TestCase): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_dlpack_capsule_conversion(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(to_dlpack(x)) @@ -72,6 +78,7 @@ def test_dlpack_capsule_conversion(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_dlpack_protocol_conversion(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(x) @@ -80,7 +87,8 @@ def test_dlpack_protocol_conversion(self, device, dtype): @skipMeta @onlyNativeDeviceTypes def test_dlpack_shared_storage(self, device): - x = make_tensor((5,), dtype=torch.float64, device=device) + dtype = torch.bfloat16 if device.startswith("mps") else torch.float64 + x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(to_dlpack(x)) z[0] = z[0] + 20.0 self.assertEqual(z, x) @@ -120,12 +128,14 @@ def test_dlpack_conversion_with_streams(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_from_dlpack(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) y = torch.from_dlpack(x) self.assertEqual(x, y) @skipMeta + @skipIfMPS # MPS crashes with noncontiguous now @onlyNativeDeviceTypes @dtypes( *all_types_and_complex_and( @@ -189,6 +199,7 @@ def test_dlpack_conversion_with_diff_streams(self, device, dtype): torch.uint64, ) ) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) def test_from_dlpack_dtype(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) y = torch.from_dlpack(x) diff --git a/test/test_indexing.py b/test/test_indexing.py index 3870734f60d34..c1b4612db9e30 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -16,6 +16,7 @@ dtypesIfCPU, dtypesIfCUDA, dtypesIfMPS, + expectedFailureMPS, instantiate_device_type_tests, onlyCUDA, onlyNativeDeviceTypes, @@ -183,6 +184,7 @@ def delitem(): @onlyNativeDeviceTypes @dtypes(torch.half, torch.double) + @dtypesIfMPS(torch.half) # TODO: add bf16 there? def test_advancedindex(self, device, dtype): # Tests for Integer Array Indexing, Part I - Purely integer array # indexing @@ -1193,6 +1195,7 @@ def func1(x, i, v): out_cpu = func1(t, ind, val) self.assertEqual(out_cuda.cpu(), out_cpu) + @expectedFailureMPS # Doubles not supported @onlyNativeDeviceTypes def test_index_put_accumulate_duplicate_indices(self, device): for i in range(1, 512): diff --git a/test/test_nn.py b/test/test_nn.py index a09404c40a1e4..904b819a6fc4d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8766,6 +8766,7 @@ def rms_norm_reference_fn(i, normalized_shape, weight, eps=None): @onlyNativeDeviceTypes @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypesIfMPS(torch.float16, torch.bfloat16, torch.float32) def test_rmsnorm_epsilon(self, device, dtype): def rms_norm_reference_fn(i, normalized_shape): eps = torch.finfo(i.dtype).eps @@ -8940,6 +8941,7 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps): Y_cpu = group_norm(X.cpu()) self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5) + @expectedFailureMPS # Double is not supported on MPS @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) def test_pad(self, device, dtype): @@ -8971,6 +8973,7 @@ def test_pad(self, device, dtype): out.fill_(4) self.assertTrue(torch.all(torch.abs(inputs) < 2)) + @expectedFailureMPS # Unsupported float64/complex128 @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) def test_ReplicationPad_empty(self, device, dtype): @@ -9109,6 +9112,7 @@ def test_Bilinear_empty(self, device): self.assertEqual(inp1.grad, torch.zeros_like(inp1)) self.assertEqual(inp2.grad, torch.zeros_like(inp2)) + @expectedFailureMPS # Double not supported @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] @onlyNativeDeviceTypes def test_TransformerEncoderLayer_empty(self, device): @@ -9138,6 +9142,7 @@ def test_TransformerEncoderLayer_empty(self, device): _test_module_empty_input(self, encoder_layer, input, check_size=False) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_TransformerEncoder_empty(self, device): for batch_first, input_shape in [(True, (0, 10, 512)), @@ -9148,6 +9153,7 @@ def test_TransformerEncoder_empty(self, device): _test_module_empty_input(self, transformer_encoder, input, check_size=False) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_TransformerDecoderLayer_empty(self, device): for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), @@ -9158,6 +9164,7 @@ def test_TransformerDecoderLayer_empty(self, device): self._test_module_empty_inputs(decoder_layer, [tgt, memory]) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_TransformerDecoder_empty(self, device): for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), @@ -9169,6 +9176,7 @@ def test_TransformerDecoder_empty(self, device): self._test_module_empty_inputs(transformer_decoder, [tgt, memory]) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] + @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_Transformer_empty(self, device): for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]: @@ -9304,6 +9312,7 @@ def test_ReflectionPad3d_large(self, device): self.assertEqual(x.grad, ref_x.grad) + @expectedFailureMPS # Unimplemented margin_loss @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_MarginLoss_empty(self, device, dtype): @@ -9370,6 +9379,7 @@ def test_mse_loss_error(self, device): with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): F.mse_loss(i, t) + @expectedFailureMPS # TODO: Fixme, and raise assert on empty tensor @onlyNativeDeviceTypes def test_Unfold_empty(self, device): inp = torch.randn(0, 3, 3, 4, device=device) @@ -9593,6 +9603,7 @@ def verify_reduction_scalars(input, reduction, output): verify_reduction_scalars(input, reduction, output) # verify that bogus reduction strings are errors + @expectedFailureMPS # CTCLoss unimplemented @onlyNativeDeviceTypes def test_invalid_reduction_strings(self, device): input = torch.randn(3, 5, requires_grad=True, device=device) @@ -10079,6 +10090,7 @@ def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize @parametrize_test("align_corners", [True, False]) @parametrize_test("mode", ["bilinear", "bicubic"]) @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) + @expectedFailureMPS # double device type @onlyNativeDeviceTypes def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format): # Forward AD does not support XLA because XLA tensors don't have storage @@ -10148,6 +10160,7 @@ def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory @parametrize_test("num_channels", [3, 5]) @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"]) @parametrize_test("dtype", integral_types() + floating_types()) + @skipIfMPS # Error message is wrong for some dtypes @onlyNativeDeviceTypes def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype): x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device) @@ -11470,6 +11483,7 @@ def test_hardsigmoid_grad(self, device): self.assertTrue(gradcheck(F.hardsigmoid, (inputs,))) # currently fails on XLA + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @onlyNativeDeviceTypes def test_hardswish_grad(self, device): inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10 @@ -11677,6 +11691,7 @@ def test_batchnorm_simple_average_mixed(self, device, dtype): self._test_batchnorm_simple_average(device, dtype, torch.float) @onlyNativeDeviceTypes + @expectedFailureMPS # Unsupported Border padding mode @dtypes(torch.float, torch.double) def test_grid_sample_nan_inf(self, device, dtype): input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype) @@ -12789,6 +12804,7 @@ def test_threshold_inplace_overlap(self, device): F.threshold(x, 0.5, 0.5, inplace=True) F.threshold_(x, 0.5, 0.5) + @expectedFailureMPS # Double is unsupported @onlyNativeDeviceTypes def test_triplet_margin_with_distance_loss_default_parity(self, device): # Test for `nn.TripletMarginWithDistanceLoss` and @@ -12823,6 +12839,7 @@ def test_triplet_margin_with_distance_loss_default_parity(self, device): self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n), (anchor, positive, negative))) + @expectedFailureMPS # Double is unsupported @onlyNativeDeviceTypes def test_triplet_margin_with_distance_loss(self, device): # Test for parity between `nn.TripletMarginWithDistanceLoss` and diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 5aa30483deba9..fd0fa0290c940 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -11,15 +11,16 @@ from torch.testing._internal.common_device_type import ( dtypes, dtypesIfMPS, + expectedFailureMPS, instantiate_device_type_tests, onlyCPU, onlyNativeDeviceTypes, - onlyNativeDeviceTypesAnd, skipLazy, skipMeta, skipXLA, ) from torch.testing._internal.common_dtype import ( + all_mps_types_and, all_types_and, all_types_and_complex_and, complex_types, @@ -157,8 +158,11 @@ def test_conj_self(self, device, dtype): @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.cfloat, torch.float, torch.half, torch.bool)) def test_view_dtype_new(self, device, dtype): dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + if device.startswith("mps"): + del dtypes[torch.float64] del dtypes[torch.bool] def generate_inputs(): @@ -271,6 +275,7 @@ def calc_expected_size_and_stride(a, view_dtype): # has a greater element size than the original dtype @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.bool)) def test_view_dtype_upsize_errors(self, device, dtype): dtype_size = torch._utils._element_size(dtype) @@ -372,6 +377,7 @@ def fn(contiguous_input=True, dim0=0, dim1=1): @onlyNativeDeviceTypes @dtypes(*complex_types(), torch.complex32) + @dtypesIfMPS(torch.cfloat, torch.chalf) def test_view_as_real(self, device, dtype): def fn(contiguous_input=True): t = torch.randn(3, 4, dtype=dtype, device=device) @@ -398,9 +404,7 @@ def fn(contiguous_input=True): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) - @dtypesIfMPS( - *integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32) - ) + @dtypesIfMPS(*all_mps_types_and(torch.bool)) def test_view_tensor_split(self, device, dtype): a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9) a_split_dim0 = a.tensor_split(7, 0) @@ -412,6 +416,7 @@ def test_view_tensor_split(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) def test_view_tensor_hsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_hsplit = torch.hsplit(t, 2) @@ -422,6 +427,7 @@ def test_view_tensor_hsplit(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) def test_view_tensor_vsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_vsplit = torch.vsplit(t, 2) @@ -432,6 +438,7 @@ def test_view_tensor_vsplit(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) def test_view_tensor_dsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_dsplit = torch.dsplit(t, 2) @@ -440,9 +447,9 @@ def test_view_tensor_dsplit(self, device, dtype): t[2, 2, 2] = 7 self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) - @onlyNativeDeviceTypesAnd("mps") + @onlyNativeDeviceTypes @dtypes(*all_types_and(torch.half, torch.bfloat16)) - @dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32)) + @dtypesIfMPS(*all_mps_types_and(torch.bool)) def test_imag_noncomplex(self, device, dtype): t = torch.ones((5, 5), dtype=dtype, device=device) @@ -451,6 +458,7 @@ def test_imag_noncomplex(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*complex_types()) + @dtypesIfMPS(torch.cfloat) def test_real_imag_view(self, device, dtype): def compare_with_numpy(contiguous_input=True): t = torch.randn(3, 3, dtype=dtype, device=device) @@ -481,6 +489,7 @@ def compare_with_numpy(contiguous_input=True): self.assertEqual(a[5:].imag, a.imag[5:]) @onlyNativeDeviceTypes + @expectedFailureMPS @dtypes(*complex_types()) def test_conj_imag_view(self, device, dtype) -> None: t = _make_tensor((4, 5), dtype, device) @@ -512,6 +521,12 @@ def test_conj_view_with_shared_memory(self, device) -> None: all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), ) ) + @dtypesIfMPS( + *product( + [torch.cfloat, torch.chalf], + all_mps_types_and(torch.cfloat, torch.chalf, torch.bool), + ) + ) @suppress_warnings def test_set_real_imag(self, device, dtypes): x = torch.randn(10, dtype=dtypes[0], device=device) diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index 774ce179f33e0..474bb689f0ad9 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -121,6 +121,19 @@ def all_types_and_half(): return _all_types_and_half +_all_mps_types = ( + _dispatch_dtypes({torch.float, torch.half, torch.bfloat16}) + _integral_types +) + + +def all_mps_types(): + return _all_mps_types + + +def all_mps_types_and(*dtypes): + return _all_mps_types + _validate_dtypes(*dtypes) + + _float8_types = _dispatch_dtypes( ( torch.float8_e4m3fn, diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 384db57e92ecb..e3adef752e406 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -297,7 +297,7 @@ def maybe_load_json(filename): if os.getenv("DISABLED_TESTS_FILE", ""): disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) -NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name()) +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name()) # used for managing devices testing for torch profiler UTs # for now cpu, cuda and xpu are added for testing torch profiler UTs From d7c83972d53efaae029933b5b5559b4edcb85f35 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Tue, 5 Aug 2025 12:16:26 -0700 Subject: [PATCH 0004/1424] tools: Add mode to find python automatically (#159820) Add support for automatically finding Python interpreters in manylinux environments to our wheel building script. Scaffolding for sequential builds Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/159820 Approved by: https://github.com/malfet --- tools/packaging/build_wheel.py | 108 ++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 2 deletions(-) diff --git a/tools/packaging/build_wheel.py b/tools/packaging/build_wheel.py index 16e9a87bd9638..10c4516a32805 100644 --- a/tools/packaging/build_wheel.py +++ b/tools/packaging/build_wheel.py @@ -4,6 +4,7 @@ import contextlib import logging import os +import re import subprocess import sys import tempfile @@ -16,11 +17,12 @@ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) ROOT_PATH = Path(__file__).absolute().parent.parent.parent SETUP_PY_PATH = ROOT_PATH / "setup.py" REQUIREMENTS_PATH = ROOT_PATH / "requirements.txt" +PYPROJECT_TOML_PATH = ROOT_PATH / "pyproject.toml" def run_cmd( @@ -45,6 +47,79 @@ def interpreter_version(interpreter: str) -> str: return str(version_string.split(" ")[1]) +def get_supported_python_versions() -> list[str]: + """Extract supported Python versions from pyproject.toml classifiers.""" + with open(PYPROJECT_TOML_PATH) as f: + content = f.read() + + # Find Python version classifiers + pattern = r'"Programming Language :: Python :: (\d+\.\d+)"' + matches = re.findall(pattern, content) + + # Sort versions and return them + return sorted(matches, key=lambda x: tuple(map(int, x.split(".")))) + + +def find_python_interpreters(mode: str) -> list[str]: + """Find Python interpreters based on the specified mode.""" + if mode == "manylinux": + return _find_manylinux_interpreters() + else: + raise ValueError(f"Unsupported mode: {mode}") + + +def _find_manylinux_interpreters() -> list[str]: + """Find Python interpreters in manylinux format (/opt/python/).""" + supported_versions = get_supported_python_versions() + interpreters = [] + + python_root = Path("/opt/python") + if not python_root.exists(): + logger.warning("Path /opt/python does not exist, no interpreters found") + return [] + + # Find all python3 binaries in /opt/python/ + python_binaries = list(python_root.glob("*/bin/python3")) + + for python_path in python_binaries: + try: + # Check if it's PyPy (skip it) + version_output = run_cmd( + [str(python_path), "--version"], capture_output=True + ) + version_string = version_output.stdout.decode("utf-8").strip() + + if "PyPy" in version_string: + logger.debug("Skipping PyPy interpreter: %s", python_path) + continue + + # Extract Python version (e.g., "Python 3.9.1" -> "3.9") + match = re.search(r"Python (\d+\.\d+)", version_string) + if not match: + logger.debug("Could not parse version from: %s", version_string) + continue + + python_version = match.group(1) + + # Check if this version is supported + if python_version in supported_versions: + interpreters.append(str(python_path)) + logger.debug( + "Found supported Python %s at %s", python_version, python_path + ) + else: + logger.debug( + "Python %s not in supported versions: %s", + python_version, + supported_versions, + ) + + except subprocess.CalledProcessError as e: + logger.debug("Failed to get version for %s: %s", python_path, e) + continue + return interpreters + + @contextlib.contextmanager def venv(interpreter: str) -> Iterator[str]: # Should this use EnvBuilder? Probably, maybe a good todo in the future @@ -100,6 +175,16 @@ def parse_args() -> argparse.Namespace: " should ideally be full paths, (default: %(default)s)" ), ) + parser.add_argument( + "--find-python", + type=str, + choices=["manylinux"], + help=( + "Automatically find Python interpreters based on the specified mode. " + "Available modes: 'manylinux' (searches /opt/python/ for interpreters " + "matching supported versions in pyproject.toml)" + ), + ) parser.add_argument( "-d", "--destination", @@ -112,7 +197,26 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - pythons = args.python or [sys.executable] + + if args.find_python: + if args.python: + logger.warning( + "Both --python and --find-python specified. Using --find-python and ignoring --python." + ) + pythons = find_python_interpreters(args.find_python) + if not pythons: + logger.error( + "No Python interpreters found with --find-python %s", args.find_python + ) + sys.exit(1) + logger.info( + "Found %d supported Python interpreters: %s", + len(pythons), + ", ".join(pythons), + ) + else: + pythons = args.python or [sys.executable] + build_times: dict[str, float] = dict() if len(pythons) > 1 and args.destination == "dist/": From 9884d0351e70cfac1444957f2f3fef6b35b70d68 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 5 Aug 2025 19:26:22 +0000 Subject: [PATCH 0005/1424] [CUDA] Decrease launch bounds of CTCLoss backward for blackwell (#159522) Otherwise we see `CUDA error: too many resources requested for launch` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159522 Approved by: https://github.com/janeyx99 --- aten/src/ATen/native/cuda/LossCTC.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index b5908cc0abcfc..c6d3c25200d50 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -644,7 +644,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta)) // As above, there may be better configurations to use. - constexpr int max_threads = std::is_same_v ? 1024 : 896; // we need 72 or so 32 bit registers for double + constexpr int max_threads_ = std::is_same_v ? 1024 : 896; // we need 72 or so 32 bit registers for double + int max_threads = max_threads_; + // Blackwell launch bounds + if (at::cuda::getCurrentDeviceProperties()->major >= 10) { + max_threads = 512; + } int threads_target = max_threads; while (threads_target / 2 >= 2*max_target_length+1) { threads_target /= 2; From eb25a95a6e4274eac083b218642850bd6f4a7406 Mon Sep 17 00:00:00 2001 From: eellison Date: Mon, 4 Aug 2025 20:30:00 -0700 Subject: [PATCH 0006/1424] Fix inductor memory estimation when a single buf has multiple mutations. Add runtime verification of mem tracking (#159569) With fsdp, we sometimes have multiple, non-overlapping views of a single buffer which are all mutated. Previously we considered the original buffer as an allocation, and make the mutated buffer the deallocation. With multiple mutations of the same buffer, we need to consider the original buffer as deallocated only when all of its aliases die (and avoid double counting the input buffer size). See comment inline: ``` When an operation mutates a buffer in-place, the scheduler creates a new buffer name to track the "before" and "after" states, even though they share the same memory. The mutated buffer represents a rename with zero allocation and deallocation cost. During dependency tracking, we transfer dependencies from the mutated name back to the original buffer, ensuring the original memory is only freed when all aliases are done. This handles cases where a buffer has multiple non-overlapping aliases - rather than trying to assign free costs to individual aliases, we forward all alias dependencies to the original buffer. Consider: buf0 = op0() buf1 = mutation_op_(buf0) del buf0 ... op(buf1) del buf1 The only memory events are the creation prior to op0, and the deletion following buf1. ``` As @IvanKobzarev 's logs in https://github.com/pytorch/pytorch/pull/158361/files#diff-e173a1d52aff49959c9f6d17ecc09946d8a616fc5909df884e62a15e1ebd1d41R1776-R1807 show, it can a bit of a pain to pinpoint which part of our memory calculation is incorrect. This pr also adds a runtime verifier `config.test_configs.track_memory_lifecycle` which tracks buffer allocation and deallocation, and errors if their lifetime does not match our expectations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159569 Approved by: https://github.com/IvanKobzarev --- .../test_compute_comm_reordering.py | 5 +- test/distributed/test_inductor_collectives.py | 7 +- test/inductor/test_memory.py | 57 ++++++- torch/_inductor/codegen/wrapper.py | 28 +++- torch/_inductor/config.py | 2 + torch/_inductor/ir.py | 42 +++++ torch/_inductor/memory.py | 148 ++++++++++++------ torch/_inductor/runtime/debug_utils.py | 138 ++++++++++++++++ torch/_inductor/scheduler.py | 81 ++++++++++ 9 files changed, 453 insertions(+), 55 deletions(-) create mode 100644 torch/_inductor/runtime/debug_utils.py diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index 63ff2fa2bbfe2..c05d5edae2330 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -179,8 +179,11 @@ def func(a): .check("extern_kernels.mm") .check("triton_poi_fused_relu") .check("torch.ops._c10d_functional.all_reduce_.default") - .check("torch.ops._c10d_functional.wait_tensor.default") + .check_same("buf0") + # mm not use buf prior to wait_tensor .check("extern_kernels.mm") + .check_not("buf0") + .check("torch.ops._c10d_functional.wait_tensor.default") .check("extern_kernels.mm") .run(code) ) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 856e1c5f7b3c4..d0b8c32497f04 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1745,10 +1745,15 @@ def _reorder_communication_preserving_peak_memory( _reorder_communication_preserving_peak_memory, ], "allow_buffer_reuse": False, + "test_configs.track_memory_lifecycle": "error", } ): - compiled = torch.compile(func) + compiled = torch.compile(func, fullgraph=True) code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) + + # make sure memory tracking is codegen. the ops will then do runtime checking with assertion. + FileCheck().check("check_memory_step").check("tracked_empty_strided").run(code) + # NOTE: The first return value should be the output of the first wait_tensor. # We want to make sure no unnecessary copy is made. ( diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 3e23442b38ec7..2231b94316b36 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -215,6 +215,7 @@ def reorder_with_only_dfs( @mock.patch.object(config, "allow_buffer_reuse", False) @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available") + @config.patch("test_configs.track_memory_lifecycle", "assert") def test_mutation_size_propogation(self): """ This tests correct size propogation in the case of mutations. @@ -262,6 +263,7 @@ def assign_memory_planning_info_for_scheduler_buffers_with_records( buffer_info[buf_name] = ( buf.mpi_buffer.size_alloc, buf.mpi_buffer.size_free, + buf.mpi_buffer.succ_nodes, ) # test example and checks @@ -281,11 +283,15 @@ def f(a, p): ): f_compiled = torch.compile(f) f_compiled(a, p) - for buf_name in ["buf0", "buf2", "buf4", "buf6"]: - self.assertEqual(buffer_info[buf_name], (2048, 0)) - for buf_name in ["buf1", "buf3", "buf5", "buf7"]: - self.assertEqual(buffer_info[buf_name], (0, 2048)) + pre_mutation = ["buf0", "buf2", "buf4", "buf6"] + post_mutation = ["buf1", "buf3", "buf5", "buf7"] + + for pre, post in zip(pre_mutation, post_mutation): + self.assertEqual(buffer_info[pre][0:2], (2048, 2048)) + self.assertEqual(buffer_info[post][0:2], (0, 0)) + # succ nodes should be forwarded to pre mutation buffer + self.assertTrue(buffer_info[post][2] <= buffer_info[pre][2]) @unittest.skipIf( not torch.cuda.is_available() @@ -359,6 +365,49 @@ def f(x, y, z): .run(code) ) + @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available") + def test_multiple_mutations_of_buf(self): + @torch.compile() + def foo(inp, inp2): + inp = inp @ inp + inp = inp.view(2, -1, 256) + x = inp[0] + y = inp[1] + x, y = torch._foreach_add([x, y], 1.0) + out = x.sum() + out2 = y.sum(dim=-1) + + return out, out2, inp2 @ inp2 + + inp = torch.rand([256, 256], device="cuda") + inp2 = torch.rand([256, 256], device="cuda") + + def replace_foreach(gm): + nodes = gm.find_nodes( + op="call_function", target=torch.ops.aten._foreach_add.Scalar + ) + assert len(nodes) == 1 + node = nodes[0] + nodes[0].target = torch.ops.aten._foreach_add_.Scalar + for inp, out in zip(node.args[0], list(node.users.keys())): + out.replace_all_uses_with(inp) + gm.erase_node(out) + + with torch._inductor.config.patch( + { + "post_grad_custom_post_pass": replace_foreach, + "test_configs.track_memory_lifecycle": "assert", + "allow_buffer_reuse": False, + # make sure the mm is at the end so + # the earlier deallocation is not at the last step, + # which doesnt distinguish between returned tensors + # and which tensors are deallocated immediately prior + "reorder_for_peak_memory": False, + } + ): + code = run_and_get_triton_code(foo, inp, inp2) + FileCheck().check("allocated=['buf0']").run(code) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index f4370e619c1ba..dd03163440999 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -963,9 +963,12 @@ def write_header(self) -> None: aot_config_comment = "" if context is not None and context.aot_graph_name is not None: aot_config_comment = f"# AOT ID: {context.aot_graph_name}" - aot_inductor_debug_utils = "" + inductor_debug_utils = "" if int(config.aot_inductor.debug_intermediate_value_printer) > 0: - aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" + inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" + elif torch._inductor.config.test_configs.track_memory_lifecycle: + inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n" + self.imports.splice( f""" {aot_config_comment} @@ -983,7 +986,7 @@ def write_header(self) -> None: from torch import device, empty_strided from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels - {aot_inductor_debug_utils} + {inductor_debug_utils} """, strip=True, ) @@ -2773,6 +2776,14 @@ def make_buffer_allocation(self, buffer: BufferLike): buffer.get_name(), device, dtype, shape, stride, allocation_shape ) + @cache_on_self + def write_memory_track_allocation_once(self): + import_str = """ + from torch._inductor.runtime.debug_utils import check_memory_step, track_tensor + """ + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + def make_allocation( self, name, device, dtype, shape, stride, allocation_shape=None ): @@ -2784,7 +2795,16 @@ def make_allocation( allocation_shape ) codegen_stride_tuple = self.codegen_python_shape_tuple(stride) - if device.type in ("cpu", "cuda", "xpu", "mtia"): + if torch._inductor.config.test_configs.track_memory_lifecycle: + out = ( + f"{name} = tracked_empty_strided(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"dtype={dtype}, " + f"device='{device.type}', " + f"name='{name}')" + ) + elif device.type in ("cpu", "cuda", "xpu", "mtia"): # optimized path for faster allocations, saving ~2us versus the stuff below out = ( f"{name} = empty_strided_{device.type}(" diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index e5b5fe224cc81..a42eb3cdeda90 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1861,6 +1861,8 @@ class test_configs: graphsafe_rng_func_ignores_fallback_random = False + track_memory_lifecycle: Optional[Literal["assert", "log"]] = None + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a3bc472a129ca..3f03c33d70daa 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5324,6 +5324,11 @@ def should_allocate(self) -> bool: @ir_dataclass(frozen=False) class ExternKernel(InputsKernel): + """ + A class that represents Kernels which are not directly lowered to Inductor + Loop Level IR, such as custom operators, or aten operators which we fallback to. + """ + constant_args: Sequence[Any] = () kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) output_view: Optional[ReinterpretView] = None @@ -6120,6 +6125,17 @@ def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None: f"# buffer {name} (op: {op_name}) is assumed to be not aligned" ) + def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None: + """ + Track outputs of fallback operators if config.test_configs.track_memory_lifecycle + """ + if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper: + return + + wrapper.write_memory_track_allocation_once() + name = self.get_name() + wrapper.writeline(f"track_tensor({name}, '{name}')") + def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]: """ get output sizes and strides, for template_codegen @@ -7579,6 +7595,7 @@ def is_number(t: torch.JitType) -> bool: if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) self.codegen_alignment_asserts(wrapper) + self.codegen_memory_tracking(wrapper) self.codegen_unbacked_symbol_defs(wrapper) @@ -7720,6 +7737,31 @@ def __init__( ) +class MemoryCheckKernel(FallbackKernel): + """ + Custom kernel for memory checking that generates direct function calls + + TODO - the custom op was erroring with str inputs. should be able to custom op directly. + """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Override codegen to write direct function call""" + # Extract our arguments from nontensor_args + wrapper.write_memory_track_allocation_once() + alive_list, dead_list, is_final_step = self.constant_args + + alive_repr = repr(alive_list) + dead_repr = repr(dead_list) + if is_final_step: + wrapper.writeline( + "# note: dont currently distinguish between buffers returned and dealloc'd in last step" + ) + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})" + else: + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})" + wrapper.writeline(call) + + @ir_dataclass class MultiOutputLayout(OutputSpec): device: torch.device diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index d287208419a9f..0967bb553e04b 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -124,6 +124,28 @@ def compute_size_for_scheduler_buffer( buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed + When an operation mutates a buffer in-place, the scheduler creates a new buffer name + to track the "before" and "after" states, even though they share the same memory. + + The mutated buffer represents a rename with zero allocation and deallocation cost. + During dependency tracking, we transfer dependencies from the mutated name back to + the original buffer, ensuring the original memory is only freed when all aliases + are done. + + This handles cases where a buffer has multiple non-overlapping aliases - rather than + trying to assign free costs to individual aliases, we forward all alias dependencies + to the original buffer. + + Consider: + buf0 = op0() + buf1 = mutation_op_(buf0) + del buf0 + ... + op(buf1) + del buf1 + + The only memory events are the creation prior to op0, and the deletion following buf1. + Returns: A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free). """ @@ -135,18 +157,11 @@ def compute_size_for_scheduler_buffer( def _compute_and_update_buf_size( sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False ) -> int: - if isinstance(sched_buf.node.layout, NoneLayout): - # mutations should inherit the size of the mutated buffer - if sched_buf.get_mutations(): - mutated_buf_name = sched_buf.get_mutations()[0] - if mutated_buf_name in sched_buf_to_size: - (_size_alloc, _size_free) = sched_buf_to_size[mutated_buf_name] - else: - (_size_alloc, _size_free) = (0, 0) - sched_buf_to_size[sched_buf.get_name()] = (0, _size_free) - sched_buf_to_size[mutated_buf_name] = (_size_alloc, 0) - else: - sched_buf_to_size[sched_buf.get_name()] = (0, 0) + if sched_buf.get_name() in V.graph.scheduler.mutation_real_name: + sched_buf_to_size[sched_buf.get_name()] = (0, 0) + return 0 + elif isinstance(sched_buf.node.layout, NoneLayout): + sched_buf_to_size[sched_buf.get_name()] = (0, 0) return 0 elif isinstance(sched_buf.node.layout, MultiOutputLayout): size_alloc = 0 @@ -200,6 +215,14 @@ def assign_memory_planning_info_for_scheduler_buffers( for dep in node.unmet_dependencies: dep_name_to_succ_nodes[dep.name].add(node) + # iterate in reverse, so dependencies are picked up transitively. + for mutating_buf_name, real_buf_name in reversed( + V.graph.scheduler.mutation_real_name.items() + ): + dep_name_to_succ_nodes[real_buf_name] |= dep_name_to_succ_nodes[ + mutating_buf_name + ] + # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) for buf_name in name_to_buf.keys(): @@ -219,58 +242,72 @@ def assign_memory_planning_info_for_scheduler_nodes( """ Assign to each scheduler node its predecessor and successor nodes. """ - from .scheduler import SchedulerBuffer - for index, node in enumerate(nodes): - size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) - pred_buffers = OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]() - for dep in node.read_writes.reads: - if dep.name in name_to_buf and dep in node.unmet_dependencies: - pred_buffers.add(name_to_buf[dep.name]) - elif dep.name in name_to_freeable_input_buf: - pred_buffers.add(name_to_freeable_input_buf[dep.name]) - pred_nodes = OrderedSet( - name_to_fused_node[pred_buffer.defining_op_name()] - for pred_buffer in pred_buffers - if (isinstance(pred_buffer, SchedulerBuffer)) - ) + node_to_pred_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + node_to_succ_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = {} + node_to_pred_buffers: dict[ + BaseSchedulerNode, OrderedSet[SchedulerBuffer | FreeableInputBuffer] + ] = collections.defaultdict(OrderedSet) + + # collect all predecessors using existing successor mappings + for node in nodes: succ_nodes = OrderedSet( succ_node for buffer in node.get_outputs() for succ_node in buffer.mpi_buffer.succ_nodes ) + node_to_succ_nodes[node] = succ_nodes + + # For each successor, add current node as its predecessor + for succ_node in succ_nodes: + node_to_pred_nodes[succ_node].add(node) + + # For each output buffer, add it as predecessor to its successor nodes + # TODO - is pred buffers needed ? + for buffer in node.get_outputs(): + for succ_node in buffer.mpi_buffer.succ_nodes: + node_to_pred_buffers[succ_node].add(buffer) + + for freeable_buffer in name_to_freeable_input_buf.values(): + for succ_node in freeable_buffer.mpi_buffer.succ_nodes: + node_to_pred_buffers[succ_node].add(freeable_buffer) + + # Second pass: assign memory planning info using completed predecessor mappings + for index, node in enumerate(nodes): + size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) + succ_nodes = node_to_succ_nodes[node] + node.mpi_node = MemoryPlanningInfoForNode( index=index, size=size_alloc, - pred_buffers=pred_buffers, - pred_nodes=pred_nodes, + pred_buffers=node_to_pred_buffers[node], + pred_nodes=node_to_pred_nodes[node], succ_nodes=succ_nodes, ) -def estimate_peak_memory( +# map each scheduler buffer to its size, start step, and end step +@dataclasses.dataclass +class BufferInfo: + buffer: Union[SchedulerBuffer, FreeableInputBuffer] + size_alloc: int + size_free: int + start_step: int + end_step: int + + +def compute_memory_timeline( nodes: list[BaseSchedulerNode], name_to_freeable_input_buf: dict[str, FreeableInputBuffer], graph_outputs: OrderedSet[str], -) -> tuple[int, list[int]]: +) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]: """ - Given a list of nodes in their execution order, estimate the peak memory, by - keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers. - - Returns: - int: peak memory - List[int]: memory usage at each node (or each step). + Compute buffer allocation and deallocation sizes and map their + lifetime to the node schedule """ - # map each scheduler buffer to its size, start step, and end step - @dataclasses.dataclass - class BufferInfo: - buffer: Union[SchedulerBuffer, FreeableInputBuffer] - size_alloc: int - size_free: int - start_step: int - end_step: int - # get the execution step of each node, this will be used to determine # the end_step of buffers node_to_step: dict[BaseSchedulerNode, int] = { @@ -325,6 +362,27 @@ class BufferInfo: ) ) + return buf_info_list, node_to_step + + +def estimate_peak_memory( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + graph_outputs: OrderedSet[str], +) -> tuple[int, list[int]]: + """ + Given a list of nodes in their execution order, estimate the peak memory, by + keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers. + + Returns: + int: peak memory + List[int]: memory usage at each node (or each step). + """ + + buf_info_list, _ = compute_memory_timeline( + nodes, name_to_freeable_input_buf, graph_outputs + ) + # incremental memory changes at each step memory = [0 for _ in range(len(nodes) + 1)] diff --git a/torch/_inductor/runtime/debug_utils.py b/torch/_inductor/runtime/debug_utils.py new file mode 100644 index 0000000000000..9c15ff890dda6 --- /dev/null +++ b/torch/_inductor/runtime/debug_utils.py @@ -0,0 +1,138 @@ +import functools +import logging +import threading +import weakref + +import torch +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + +local = threading.local() +local.memory_tracker = None + + +class BufferMemoryTracker: + """ + Tracks inductor runtime allocations and deallocations to compare against + expected behavior. + """ + + def __init__(self) -> None: + self.tensor_tracker: dict[str, torch.storage.UntypedStorage] = ( + weakref.WeakValueDictionary() # type: ignore[assignment] + ) + self.died_since_last_step: OrderedSet[str] = OrderedSet() + self.added_since_last_step: OrderedSet[str] = OrderedSet() + self.error = ( + torch._inductor.config.test_configs.track_memory_lifecycle == "assert" + ) + + def set_tensor(self, name: str, tensor: torch.Tensor) -> None: + storage = tensor.untyped_storage() + + self.added_since_last_step.add(name) + self.tensor_tracker[name] = storage + + def on_tensor_death() -> None: + self.died_since_last_step.add(name) + + weakref.finalize(storage, on_tensor_death) + + def advance_step(self) -> None: + self.died_since_last_step.clear() + self.added_since_last_step.clear() + + def log_or_raise(self, msg: str) -> None: + if self.error: + raise RuntimeError(msg) + else: + log.info(msg) + + def check_step_delta( + self, + expected_allocated: list[str], + expected_freed: list[str], + is_final_step: bool, + ) -> None: + """Check only the delta changes since last step""" + + # Check expected deaths - we dont currently distinguish between nodes which die in last step + # and are returned as outputs, so skip if final_step. + if not is_final_step: + missing_deaths = OrderedSet(expected_freed) - self.died_since_last_step + if missing_deaths: + self.log_or_raise( + f"Expected tensors to die but still alive: {missing_deaths}" + ) + + # Check for unexpected deaths + unexpected_deaths = self.died_since_last_step - OrderedSet(expected_freed) + if unexpected_deaths: + self.log_or_raise(f"Unexpected tensor deaths: {unexpected_deaths}") + + # Check newly alive tensors - separate messages like deaths + actual_allocated = self.added_since_last_step + expected_allocated_set = OrderedSet(expected_allocated) + + extra_alive = actual_allocated - expected_allocated_set + if extra_alive: + self.log_or_raise(f"Unexpected allocated tensors: {extra_alive}") + + missing_alive = expected_allocated_set - actual_allocated + if missing_alive: + self.log_or_raise( + f"Expected allocated tensors but missing: {missing_alive}" + ) + + # Reset for next step + self.advance_step() + + if is_final_step: + local.memory_tracker = None + + +def get_mem_tracker() -> BufferMemoryTracker: + if local.memory_tracker is None: + local.memory_tracker = BufferMemoryTracker() + return local.memory_tracker + + +def track_tensor(tensor: torch.Tensor, name: str) -> None: + get_mem_tracker().set_tensor(name, tensor) + + +def tracked_empty_strided( + size: list[int], + stride: list[int], + *, + dtype: torch.dtype, + device: torch.device, + name: str, +) -> torch.Tensor: + o = torch.empty_strided(size, stride, dtype=dtype, device=device) + track_tensor(o, name) + return o + + +def check_memory_step( + allocated: list[str], freed: list[str], is_final_step: bool = False +) -> None: + tracker = get_mem_tracker() + tracker.check_step_delta(allocated, freed, is_final_step) + + +@functools.lru_cache(None) +def register_check_mem_op() -> None: + lib = torch.library.Library("_inductor_debug", "FRAGMENT") # noqa: TOR901 + lib.define( + "check_memory_step(str[] allocated, str[] freed, bool is_final_step) -> ()" + ) + lib.impl("check_memory_step", check_memory_step, "BackendSelect") + from torch._higher_order_ops.effects import _EffectType, _register_effectful_op + + _register_effectful_op( + torch.ops._inductor_debug.check_memory_step.default, + _EffectType.ORDERED, + ) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 951f07ab7a5ba..abd2fe413d1af 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2184,6 +2184,10 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) self.compute_last_usage() + + if torch._inductor.config.test_configs.track_memory_lifecycle: + self.insert_memory_check_nodes() + log_ir_post_fusion(self.nodes) V.debug.graph_diagram(self.nodes) self.debug_draw_graph() @@ -2518,6 +2522,83 @@ def add_user( compute_dependencies_log.debug("BUFFER USER LIST\n") compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str) + def insert_memory_check_nodes(self) -> None: + from .memory import ( + assign_memory_planning_info_for_scheduler_buffers, + compute_memory_timeline, + FreeableInputBuffer, + get_freeable_input_buf, + ) + + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = ( + get_freeable_input_buf(self.nodes, graph_inputs) + ) + + if not torch._inductor.config.reorder_for_peak_memory: + assign_memory_planning_info_for_scheduler_buffers( + self.nodes, self.name_to_buf + ) + + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + buf_info_list, _ = compute_memory_timeline( + self.nodes, + name_to_freeable_input_buf, + graph_outputs, + ) + + step_allocs_deallocs: list[tuple[list[str], list[str]]] = [ + ([], []) for _ in range(len(self.nodes)) + ] + for buf_info in buf_info_list: + # Skip zero-size buffers + if buf_info.size_alloc == 0 and buf_info.size_free == 0: + continue + + buf_name = buf_info.buffer.get_name() + + step_allocs_deallocs[buf_info.start_step][0].append(buf_name) + step_allocs_deallocs[buf_info.end_step][1].append(buf_name) + + from torch._inductor.runtime.debug_utils import register_check_mem_op + + register_check_mem_op() + + def construct_mem_check_node( + step_idx: int, is_final_step: bool + ) -> ExternKernelSchedulerNode: + expected_newly_alive = step_allocs_deallocs[step_idx][0] + expected_newly_dead = step_allocs_deallocs[step_idx][1] + + nontensor_args = [expected_newly_alive, expected_newly_dead, is_final_step] + + node = ir.MemoryCheckKernel( + layout=NoneLayout(device=torch.device("cpu")), + kernel=torch.ops._inductor_debug.check_memory_step.default, + tensor_args=[], + nontensor_args=nontensor_args, + unflatten_args=lambda tensor_args, constant_args: ( + tensor_args, + { + "alive": constant_args[0], + "dead": constant_args[1], + "is_final_step": constant_args[2], + }, + ), + ) + node.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}" + return ExternKernelSchedulerNode(self, node) + + new_nodes = [] + + for i, node in enumerate(self.nodes): + new_nodes.append(node) + new_nodes.append( + construct_mem_check_node(i, is_final_step=(i == len(self.nodes) - 1)) + ) + + self.nodes = new_nodes + def dead_node_elimination(self) -> None: """ Remove any nodes without users From 9b953bb3fbc838d4da45ae0cd7d72492c5585c1c Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 5 Aug 2025 11:59:20 -0700 Subject: [PATCH 0007/1424] [BE] Update TensorPipe pin (#159834) No functional changes, just: - Update C++ standard to C++17 - Update `cmake` min version to 3.18 - Update `libuv` dependency to 1.51 (to move its cmake min version to 3.10) - Replace boost optional implementation with `std::optional` wrapper - Make it compilable with gcc-14.x plus by including `cstddef` in few headers - Avoid using deprecated enums for MacOS builds Pull Request resolved: https://github.com/pytorch/pytorch/pull/159834 Approved by: https://github.com/Skylion007 --- cmake/Dependencies.cmake | 7 ------- third_party/tensorpipe | 2 +- third_party/tensorpipe.BUILD | 10 +++++----- torch/csrc/distributed/rpc/tensorpipe_agent.cpp | 2 -- torch/csrc/distributed/rpc/tensorpipe_cuda.cpp | 2 -- torch/csrc/distributed/rpc/tensorpipe_utils.cpp | 2 -- 6 files changed, 6 insertions(+), 19 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index d11915fe43147..3b4b6adac94b1 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1166,17 +1166,10 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE) # Tensorpipe uses cuda_add_library torch_update_find_cuda_flags() - if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") - message(WARNING "Archived TensorPipe forces CMake compatibility mode") - set(CMAKE_POLICY_VERSION_MINIMUM 3.5) - endif() add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/tensorpipe) # Suppress warning to unblock libnop compilation by clang-17 # See https://github.com/pytorch/pytorch/issues/151316 target_compile_options_if_supported(tensorpipe -Wno-missing-template-arg-list-after-template-kw) - if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") - unset(CMAKE_POLICY_VERSION_MINIMUM) - endif() list(APPEND Caffe2_DEPENDENCY_LIBS tensorpipe) list(APPEND Caffe2_DEPENDENCY_LIBS nlohmann) diff --git a/third_party/tensorpipe b/third_party/tensorpipe index 52791a2fd214b..dacda0567d9f2 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit 52791a2fd214b2a9dc5759d36725909c1daa7f2e +Subproject commit dacda0567d9f23d4bc503e1c4f84aa65f33ac38a diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD index ece345fda4a26..5e5b69b4cb4ec 100644 --- a/third_party/tensorpipe.BUILD +++ b/third_party/tensorpipe.BUILD @@ -7,6 +7,7 @@ LIBUV_COMMON_SRCS = [ "third_party/libuv/src/inet.c", "third_party/libuv/src/random.c", "third_party/libuv/src/strscpy.c", + "third_party/libuv/src/strtok.c", "third_party/libuv/src/threadpool.c", "third_party/libuv/src/timer.c", "third_party/libuv/src/uv-common.c", @@ -37,9 +38,7 @@ LIBUV_POSIX_SRCS = [ LIBUV_LINUX_SRCS = LIBUV_POSIX_SRCS + [ "third_party/libuv/src/unix/proctitle.c", - "third_party/libuv/src/unix/linux-core.c", - "third_party/libuv/src/unix/linux-inotify.c", - "third_party/libuv/src/unix/linux-syscalls.c", + "third_party/libuv/src/unix/linux.c", "third_party/libuv/src/unix/procfs-exepath.c", "third_party/libuv/src/unix/random-getrandom.c", "third_party/libuv/src/unix/random-sysctl-linux.c", @@ -60,6 +59,7 @@ cc_library( "third_party/libuv/src/unix/*.h", ], ), + copts = ["-D_GNU_SOURCE"], visibility = ["//visibility:public"], ) @@ -151,7 +151,7 @@ cc_library( ".", ], copts = [ - "-std=c++14", + "-std=c++17", ], visibility = ["//visibility:public"], deps = [ @@ -168,7 +168,7 @@ cc_library( ".", ], copts = [ - "-std=c++14", + "-std=c++17", ], visibility = ["//visibility:public"], deps = [ diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 1907520702503..c25e83c07c6db 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -8,10 +8,8 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") -C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() -C10_DIAGNOSTIC_POP() #include #include diff --git a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp index 03b43184d143b..4c326b6a0e276 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp @@ -7,12 +7,10 @@ #include #include -C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") #include #include C10_DIAGNOSTIC_POP() -C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index f28aefc06dee0..86308ae6cdf35 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -6,10 +6,8 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") -C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() -C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { From a45a8409267f3dcb7ae3c63d08e43d7c904c9003 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 5 Aug 2025 13:46:52 -0700 Subject: [PATCH 0008/1424] [CI] Disable check-labels and check_mergeability (#159900) See https://github.com/pytorch/pytorch/issues/159825 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159900 Approved by: https://github.com/clee2000 --- .github/workflows/check-labels.yml | 3 ++- .github/workflows/check_mergeability_ghstack.yml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/check-labels.yml b/.github/workflows/check-labels.yml index 44430522b79d8..a3a87708e966e 100644 --- a/.github/workflows/check-labels.yml +++ b/.github/workflows/check-labels.yml @@ -34,7 +34,8 @@ jobs: contents: read pull-requests: write name: Check labels - if: github.repository_owner == 'pytorch' + # Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved + if: github.repository_owner == 'pytorch' && false runs-on: linux.24_04.4x steps: - name: Checkout PyTorch diff --git a/.github/workflows/check_mergeability_ghstack.yml b/.github/workflows/check_mergeability_ghstack.yml index 569a174665ba8..689ee250c809a 100644 --- a/.github/workflows/check_mergeability_ghstack.yml +++ b/.github/workflows/check_mergeability_ghstack.yml @@ -7,7 +7,8 @@ on: jobs: ghstack-mergeability-check: - if: github.repository_owner == 'pytorch' + # Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved + if: github.repository_owner == 'pytorch' && false runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 From b52a4d0821d9494ef6c11888a1855195dc4092f0 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 5 Aug 2025 21:31:53 +0000 Subject: [PATCH 0009/1424] [ez][CI] Remove some unused docker images (#159171) Removes unused docker images from the docker build workflow Then removes unused definitions in build.sh The only one I left is the vllm one because I'm pretty sure it's going to be used in the future I assume everything not mentioned is old and we forgot to remove them Pull Request resolved: https://github.com/pytorch/pytorch/pull/159171 Approved by: https://github.com/yangw-dev --- .ci/docker/build.sh | 55 ----------------------------- .github/workflows/docker-builds.yml | 5 --- 2 files changed, 60 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index a286d8da39ac6..0bf0847c3400d 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -144,16 +144,6 @@ case "$tag" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9) - CUDA_VERSION=12.6.3 - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.12 @@ -164,39 +154,6 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks) - CUDA_VERSION=12.6 - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - INDUCTOR_BENCHMARKS=yes - ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks) - CUDA_VERSION=12.6 - ANACONDA_PYTHON_VERSION=3.12 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - INDUCTOR_BENCHMARKS=yes - ;; - pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks) - CUDA_VERSION=12.6 - ANACONDA_PYTHON_VERSION=3.13 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - INDUCTOR_BENCHMARKS=yes - ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 @@ -219,18 +176,6 @@ case "$tag" in VISION=yes TRITON=yes ;; - pytorch-linux-jammy-py3.11-clang12) - ANACONDA_PYTHON_VERSION=3.11 - CLANG_VERSION=12 - VISION=yes - TRITON=yes - ;; - pytorch-linux-jammy-py3.9-gcc9) - ANACONDA_PYTHON_VERSION=3.9 - GCC_VERSION=9 - VISION=yes - TRITON=yes - ;; pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3) if [[ $tag =~ "jammy" ]]; then ANACONDA_PYTHON_VERSION=3.10 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index c27f651b6b3aa..548847944cd73 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -51,17 +51,12 @@ jobs: docker-image-name: [ pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, - pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.9-clang12, - pytorch-linux-jammy-py3.11-clang12, - pytorch-linux-jammy-py3.12-clang12, pytorch-linux-jammy-py3.13-clang12, pytorch-linux-jammy-rocm-n-py3, pytorch-linux-noble-rocm-n-py3, From 882d50c5bf0a29ee481f2235235ef0c73000ed40 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 5 Aug 2025 10:03:48 -0700 Subject: [PATCH 0010/1424] [C10] Add `Scalar::isUnsigned()` method (#159877) That returns true if Scalar hold unsigned integral value With the implications of `Tag::HAS_u` semantic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159877 Approved by: https://github.com/Skylion007, https://github.com/ezyang --- c10/core/Scalar.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 3b483c86bc88f..646a1dde39940 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -191,11 +191,17 @@ class C10_API Scalar { isIntegral() const { return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag; } + bool isIntegral(bool includeBool) const { return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag || (includeBool && isBoolean()); } + // See Note [Meaning of HAS_u] + bool isUnsigned() const { + return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0); + } + bool isComplex() const { return Tag::HAS_z == tag; } From 8085edc8f9c98f670f585586b4286a942927537a Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 5 Aug 2025 11:11:15 -0700 Subject: [PATCH 0011/1424] [autograd] torch._C._set_view_replay_enabled state leaking into other tests (#159840) This was causing view_fns to pop up in tests that ran after `TestAutograd.test_view_replay_enabled` where it isn't used as a context manager. It is unclear to me why we would want `_force_original_view_tracking` to mutate global state on __init__ rather than on __enter__, that could be an alternative fix. FIXES https://github.com/pytorch/pytorch/issues/156306 https://github.com/pytorch/pytorch/issues/156289 https://github.com/pytorch/pytorch/issues/156265 https://github.com/pytorch/pytorch/issues/156209 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159840 Approved by: https://github.com/albanD --- test/test_autograd.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_autograd.py b/test/test_autograd.py index 01929a276f569..e26e193cc799a 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -109,6 +109,10 @@ def graph_desc(fn): class TestAutograd(TestCase): + def tearDown(self): + torch.autograd._force_original_view_tracking(False) + super(TestCase, self).tearDown() + def test_copy_slices_graph_task_updates(self): def f1(x, y): out = x.clone().view(-1) From bdb07a2bc54df66441d69b49b5a215f09a0b1927 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 5 Aug 2025 11:57:58 -0700 Subject: [PATCH 0012/1424] [Cutlass] Allow offsets to be passed as arguments to kernel (#159761) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159761 Approved by: https://github.com/henrylhtsang ghstack dependencies: #159760 --- test/inductor/test_cutlass_backend.py | 20 ++++++++++++ test/inductor/test_cutlass_evt.py | 10 +++--- torch/_inductor/codegen/cuda/cuda_kernel.py | 31 ++++++++++--------- torch/_inductor/codegen/cuda/cuda_template.py | 17 +++++++--- .../cutlass_lib_extensions/evt_extensions.py | 3 +- torch/_inductor/codegen/cuda/gemm_template.py | 2 +- 6 files changed, 56 insertions(+), 27 deletions(-) diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index dc9abf2e20c6f..ea0fa87382145 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -1793,6 +1793,26 @@ def test_cutlass_backend_matmul_same_tensor(self): torch.testing.assert_close(A @ A.t(), compiled(A, A.t())) + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_cutlass_backend_matmul_nonzero_offset(self): + max_autotune_gemm_backends = "CUTLASS" + + M = 129 + A = torch.randn(M, M - 1).cuda().half() + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_max_profiling_configs": 2, + } + ): + compiled = torch.compile(torch.mm) + torch.testing.assert_close( + A[1:, :] @ A[1:, :].t(), compiled(A[1:, :], A[1:, :].t()) + ) + @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_flexible_layout(self): diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index d6891af6e6afa..eb468c3910209 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -392,12 +392,12 @@ def test_evt_argument_codegen(self): {}, /* C */ {}, /* compute_0 */ }, - {/* ptr_aux */ (float*) ptr_0, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */ + {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */ {}, /* compute_1 */ }, - {/* ptr_aux */ (float*) ptr_1, /* dAux */ {2048, _1{}, _0{}}}, /* F */ + {/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */ }, - {/* ptr_col */ (float*) ptr_2, /* null_default */ float(0), /* dCol */ {}}, /* bias */ + {/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ {}, /* compute_2 */ {}, /* compute_3 */ {}, /* compute_4 */ @@ -444,9 +444,9 @@ def fn(accum, bias): { /* thread */ { /* E */ {}, /* accum */ - {/* ptr_aux */ (float*) ptr_0, /* dAux */ {2048, _1{}, _0{}}}, /* E */ + {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */ }, - {/* ptr_col */ (float*) ptr_1, /* null_default */ float(0), /* dCol */ {}}, /* bias */ + {/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ {}, /* compute_0 */ } """, diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 224f0d2a423dc..0a9c6b0ca4e5f 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -177,6 +177,9 @@ def get_ld(node) -> Union[Expr, int]: def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: return [*self.get_layout_args(), *self.size_args] + def get_offset_args(self) -> list[Expr]: + return [node.get_layout().offset for node in self.named_nodes.values()] + @staticmethod def find_ld_idx(node: IRNode) -> int: strides = node.get_stride() @@ -264,6 +267,7 @@ def def_kernel( In this case, the `input_reorder` would be [2, 0, 1]. additional_size_args: Additional size arguments for epilogue inputs """ + # NB: name order matters here, it's used to match up offsets names = [x.strip() for x in names_str.strip().split(",")] if len(inputs) + len(outputs) != len(names): raise RuntimeError( @@ -285,6 +289,7 @@ def def_kernel( free_symbols: OrderedSet[Expr] = OrderedSet() for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): if node is not None: + # NB: named nodes must be populated in the order of names self.named_nodes[name] = node self.args.output_buffers[node.get_name()] = name @@ -306,14 +311,17 @@ def def_kernel( size_vars.extend(str(s) for s in free_symbols) self.size_args.extend(free_symbols) size_args = [f"const int {s}" for s in size_vars] - + offset_args = [f"const int {name}_offset" for name in self.named_nodes.keys()] runtime_arg_decls = ",".join( [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] ) if runtime_arg_decls: runtime_arg_decls += ", " - signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + signature = ( + f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)},\ + {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + ) self.signature = signature return signature @@ -346,10 +354,13 @@ def call_kernel( _, call_args, _, arg_types = self.args.python_argdefs() dynamic_shape_args = self.get_dynamic_shape_args() + offset_args = self.get_offset_args() call_args.extend(dynamic_shape_args) # type: ignore[arg-type] + call_args.extend(offset_args) # type: ignore[arg-type] for arg in self.runtime_arg_values: - call_args.append(arg) - arg_types.extend("int" for _ in dynamic_shape_args) + call_args.append(str(arg)) + arg_types.extend("const int" for _ in dynamic_shape_args) + arg_types.extend("const int" for _ in offset_args) for arg in self.runtime_arg_info: arg_types.append(arg.ty) # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar @@ -425,15 +436,6 @@ def max_valid_index(self, node: IRNode, default=-1): max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] return max_valid_offset - def offset(self, node: IRNode) -> str: - """ - Generates code which represents offset of a given node. - """ - - if node is None: - return "0" - return str(node.get_layout().offset) # type: ignore[union-attr] - def ptr(self, node: IRNode) -> str: """ Generates code which represents pointer of a given node. @@ -444,8 +446,7 @@ def ptr(self, node: IRNode) -> str: arg_name = self.arg_name(node) if arg_name is None: return "nullptr" - offset = self.offset(node) - return arg_name if offset == "0" else f"{arg_name} + {offset}" + return f"{arg_name} + {arg_name}_offset" def size( self, diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index cc03ccbdda863..4aa0aeb46e077 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -43,7 +43,7 @@ class ArgInfo: class CUDATemplate(KernelTemplate): index_counter = itertools.count() # dict of cache key to (code, size_args) - code_cache: dict[str, tuple[str, tuple[int, ...]]] = {} + code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {} cache_clear = staticmethod(code_cache.clear) def __init__( @@ -113,8 +113,12 @@ def generate_code_and_args( key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) if key is not None and key in self.code_cache: - code, size_args = self.code_cache[key] - extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + code, size_args, offset_args = self.code_cache[key] + extra_args = tuple( + list(size_args) + + list(offset_args) + + list(self.get_runtime_arg_values(**kwargs)) + ) return code, extra_args kernel_name = str(Placeholder.KERNEL_NAME) @@ -148,12 +152,15 @@ def generate_code_and_args( ) V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + offset_args = V.graph.sizevars.size_hints(kernel.get_offset_args()) if key is not None: - self.code_cache[key] = code, size_args + self.code_cache[key] = code, size_args, offset_args # extra args has runtime params, which shouldn't be cached - extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + extra_args = tuple( + list(size_args) + list(offset_args) + self.get_runtime_arg_values(**kwargs) + ) return code, extra_args diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py index e42a13534e6f4..605b93dff5926 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -255,7 +255,8 @@ def render_stride(x: int) -> str: return f"{{{', '.join([render_stride(x) for x in stride])}}}" elif issubclass(arg_ty, ctypes.c_void_p): - return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {arg_renames.new_name(node.get_name())}" + name = arg_renames.new_name(node.get_name()) + return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) ({name} + {name}_offset)" elif ( arg_ty in _CUTLASS_C_DTYPES ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 6436989bb0bca..e74161deeb141 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1317,7 +1317,7 @@ def test_call_statement( f"(({arg_type}){arg_name}_data.get())" for arg_type, arg_name in zip(arg_types, arg_names) ] - return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 def _render_evt( self, From 410812763bddd8d6f08eb605e24976aece74195d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 5 Aug 2025 22:00:23 +0000 Subject: [PATCH 0013/1424] Revert "[Inductor][Triton] Support TMA before strict 3.4 cutoff (#159777)" This reverts commit bbc0df1094b5a4dcd2cce83f8402127b07913231. Reverted https://github.com/pytorch/pytorch/pull/159777 on behalf of https://github.com/izaitsevfb due to breaking inductor test on ROCm ([comment](https://github.com/pytorch/pytorch/pull/159777#issuecomment-3156770098)) --- torch/_inductor/codegen/triton.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index f8ad32fafc734..49e10d7c05127 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -26,7 +26,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing -from torch.utils._triton import has_triton_package, has_triton_stable_tma_api +from torch.utils._triton import has_triton_package from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -1692,12 +1692,14 @@ def __post_init__(self): def can_use_tma( self, ) -> bool: + import triton + if not ( V.graph.get_current_device_or_throw().type == "cuda" and torch.cuda.get_device_capability()[0] >= 9 and config.triton.use_tensor_descriptor and config.assume_aligned_inputs - and has_triton_stable_tma_api() + and triton.__version__ >= "3.4.0" # For CUDA The base ptr needs to be aligned ): log.debug( From 64cc6f06b17944e0c38a29e1117f76052cf0bc2d Mon Sep 17 00:00:00 2001 From: anwang Date: Mon, 4 Aug 2025 16:21:42 -0700 Subject: [PATCH 0014/1424] [Inductor] Revert minimal changes to avoid internal test failures (#159809) The diff/PR https://github.com/pytorch/pytorch/pull/159211 caused a bunch of test failures for graph compiler(T232684410). But I couldn't figure out a forward fix so far. So with this diff/PR, I'm proposing to revert the minimal changes to resolve the test failures. I'll continue the debugging, and re-land the reverted changes once we find out a forward fix. Differential Revision: [D79221721](https://our.internmc.facebook.com/intern/diff/D79221721/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159809 Approved by: https://github.com/blaine-rister, https://github.com/eellison --- torch/_dynamo/device_interface.py | 4 ---- torch/utils/_triton.py | 1 - 2 files changed, 5 deletions(-) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 9ea53c900b054..ada43dd08393b 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -590,10 +590,6 @@ def init_device_reg() -> None: for i in range(torch.xpu.device_count()): register_interface_for_device(f"xpu:{i}", XpuInterface) - register_interface_for_device("mtia", MtiaInterface) - for i in range(torch.mtia.device_count()): - register_interface_for_device(f"mtia:{i}", MtiaInterface) - register_interface_for_device("cpu", CpuInterface) register_interface_for_device("mps", MpsInterface) diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index af1e5e0e6f42a..55beae4baf18a 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -135,7 +135,6 @@ def _return_true(device_interface: Any) -> bool: "cuda": cuda_extra_check, "xpu": _return_true, "cpu": cpu_extra_check, - "mtia": _return_true, } def is_device_compatible_with_triton() -> bool: From 8034b2a7323aaa983df0e03c60521bb0e792622e Mon Sep 17 00:00:00 2001 From: Sandeep Narendranath Karjala Date: Tue, 5 Aug 2025 11:30:55 -0700 Subject: [PATCH 0015/1424] [inductor] Add TLParse artifact for logging runtime of collective and compute ops (#159730) Summary: - debug.py: Added log_runtime_estimates() function to dump runtime estimation data as structured tlparse artifacts in JSON format - test_structured_trace.py: Added comprehensive test coverage with testing compute and collective ops Pull Request resolved: https://github.com/pytorch/pytorch/pull/159730 Approved by: https://github.com/yushangdi ghstack dependencies: #159190 --- test/dynamo/test_structured_trace.py | 145 +++++++++++++++++++++++++++ torch/_inductor/compile_fx.py | 6 ++ torch/_inductor/config.py | 6 ++ torch/_inductor/debug.py | 23 +++++ 4 files changed, 180 insertions(+) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index b692c5ee8d4a1..77ef75d125367 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -1208,6 +1208,151 @@ def forward(self, x): finally: dist.destroy_process_group() + @contextmanager + def _setup_runtime_estimates_capture(self): + """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace.""" + payload_buffer = io.StringIO() + payload_handler = logging.StreamHandler(payload_buffer) + payload_handler.setLevel(logging.DEBUG) + payload_handler.setFormatter(StructuredTracePayloadFormatter()) + payload_handler.addFilter( + StructuredTraceTestingFilter("inductor_tlparse_runtime") + ) + trace_log.addHandler(payload_handler) + try: + yield payload_buffer + finally: + trace_log.removeHandler(payload_handler) + + @requires_tlparse + @requires_distributed() + @requires_cuda + @torch._inductor.config.patch("fx_graph_cache", False) + @torch._inductor.config.patch("log_tlparse", True) + def test_runtime_estimates_simple(self): + """Test runtime estimates logging with simple compute and collective ops.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class SimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + h = self.linear(x) + h = torch.relu(h) + + h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0") + h = torch.ops._c10d_functional.wait_tensor.default(h) + return h + + try: + with self._setup_runtime_estimates_capture() as payload_buffer: + torch._dynamo.reset() + + mod = SimpleModule().cuda() + compiled = torch.compile(mod, backend="inductor") + compiled(torch.randn(4, 4, device="cuda")) + + # Verify runtime estimates artifact was logged + self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue()) + + payload_content = payload_buffer.getvalue().strip() + if payload_content: + data = json.loads(payload_content) + self.assertIn("ops", data) + ops = data["ops"] + + # Verify runtime estimates + compute_ops = [op for op in ops if op["type"] == "compute"] + collective_ops = [op for op in ops if op["type"] == "collective"] + + self.assertTrue(len(compute_ops) > 0 or len(collective_ops) > 0) + + # All ops should have runtime > 0 except wait_tensor can be 0 + for op in ops: + if "wait_tensor" not in op["name"]: + self.assertGreater( + op["estimated_runtime_ns"], + 0, + f"Op {op['name']} should have runtime > 0", + ) + + self.assertParses() + finally: + dist.destroy_process_group() + + @requires_tlparse + @requires_distributed() + @requires_cuda + @torch._inductor.config.patch("fx_graph_cache", False) + @torch._inductor.config.patch("log_tlparse", True) + def test_runtime_estimates_mixed(self): + """Test runtime estimates logging with mixed compute and collective sequence.""" + import torch.distributed as dist + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + + class MixedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.LayerNorm(4) + + def forward(self, x): + h = self.norm(x) + h = torch.nn.functional.gelu(h) + + h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0") + h = torch.ops._c10d_functional.wait_tensor.default(h) + + h = h * 0.5 + + gathered = torch.ops._c10d_functional.all_gather_into_tensor.default( + h, 2, "0" + ) + gathered = torch.ops._c10d_functional.wait_tensor.default(gathered) + + return gathered.sum(dim=0) + + try: + with self._setup_runtime_estimates_capture() as payload_buffer: + torch._dynamo.reset() + + mod = MixedModule().cuda() + compiled = torch.compile(mod, backend="inductor") + compiled(torch.randn(4, 4, device="cuda")) + + # Verify runtime estimates artifact was logged + self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue()) + + payload_content = payload_buffer.getvalue().strip() + if payload_content: + data = json.loads(payload_content) + self.assertIn("ops", data) + ops = data["ops"] + + # Should have both compute and collective ops + op_types = {op["type"] for op in ops} + self.assertIn("compute", op_types) + self.assertIn("collective", op_types) + + # All ops should have runtime > 0 except wait_tensor can be 0 + for op in ops: + if "wait_tensor" not in op["name"]: + self.assertGreater( + op["estimated_runtime_ns"], + 0, + f"Op {op['name']} should have runtime > 0", + ) + + self.assertParses() + finally: + dist.destroy_process_group() + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index bb00f46886f84..d17ffe19b3c70 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1509,6 +1509,7 @@ def codegen_and_compile( compiled_module, "runner", None ) + node_runtimes = None if inductor_metrics_log.isEnabledFor(logging.INFO): num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() metrics.num_bytes_accessed += num_bytes @@ -1523,6 +1524,11 @@ def codegen_and_compile( }, ) + # Collect and dump op runtimes for TLParse + if config.log_tlparse: + _, _, node_runtimes = graph.count_bytes() + torch._inductor.debug.log_runtime_estimates(node_runtimes) + # Collect and dump collective-op schedule for external diagnostics torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index a42eb3cdeda90..c6971301efe6c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -741,6 +741,12 @@ def decide_worker_start_method() -> str: default=True, ) +# Log per-operation runtime estimates for TLParse analysis. +log_tlparse: bool = Config( + env_name_force="LOG_TLPARSE", + default=False, +) + # Flags to turn on all_reduce fusion. These 2 flags should be automatically turned # on by DDP and should not be set by the users. _fuse_ddp_communication = False diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index 2400b8235ca9c..f3be4a6b5506f 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -22,6 +22,7 @@ from torch import fx as fx from torch._dynamo.repro.after_aot import save_graph_repro from torch._dynamo.utils import get_debug_dir +from torch._inductor import utils from torch._logging import getArtifactLogger from torch._logging._internal import trace_structured from torch.fx.graph_module import GraphModule @@ -721,6 +722,28 @@ def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None: _dump_collective_schedule(schedule) +def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None: + """Log per-operation runtime estimates for TLParse.""" + + ops = [ + { + "name": getattr(s.node, "python_kernel_name", s.get_name()), + "type": "collective" if utils.is_collective(s.node) else "compute", + "estimated_runtime_ns": runtime_ns, + } + for s, runtime_ns in node_runtimes + ] + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_tlparse_runtime", + "encoding": "json", + }, + payload_fn=lambda: {"ops": ops}, + ) + + @dataclasses.dataclass class TensorMetadataHolder: tensor_metadata: TensorMetadata From fb35a9ea4ac074a882d1069ccbd626f0e49c3353 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 5 Aug 2025 22:26:48 +0000 Subject: [PATCH 0016/1424] [export] Improve error messages (#159881) Originally, if the PT2 errored when loading, we would try to load using the old loader to fit BC issues. However this hides the error messages for if an up-to-date PT2 is erroring when loading due to some other reason. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159881 Approved by: https://github.com/yushangdi --- torch/export/__init__.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 3ed8a6c37883f..51f0865f43049 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -1,3 +1,4 @@ +import logging import os import warnings import zipfile @@ -52,6 +53,8 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] +log: logging.Logger = logging.getLogger(__name__) + @deprecated( "`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. " @@ -440,7 +443,8 @@ def load( f, expected_opset_version=expected_opset_version, ) - except RuntimeError: + except RuntimeError as e: + log.warning("Ran into the following error when deserializing: %s", e) pt2_contents = PT2ArchiveContents({}, {}, {}) if len(pt2_contents.exported_programs) > 0 or len(pt2_contents.extra_files) > 0: @@ -450,10 +454,18 @@ def load( return pt2_contents.exported_programs["model"] # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?) - warnings.warn( - "This version of file is deprecated. Please generate a new pt2 saved file." - ) with zipfile.ZipFile(f, "r") as zipf: + if "version" not in zipf.namelist(): + raise RuntimeError( + "We ran into an error when deserializing the saved file. " + "Please check the warnings above for possible errors. " + ) + + log.warning( + "Trying to deserialize for the older format. This version of file is " + "deprecated. Please generate a new pt2 saved file." + ) + # Check the version version = zipf.read("version").decode().split(".") from torch._export.serde.schema import ( From b1ec088113bac8c7602c3cc4ede5ea2c194154c4 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 5 Aug 2025 11:46:12 -0700 Subject: [PATCH 0017/1424] [mps] Turn on inductor dynamic shapes tests (#159456) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159456 Approved by: https://github.com/Skylion007, https://github.com/malfet --- test/inductor/test_torchinductor.py | 19 +++++++++ ...st_torchinductor_codegen_dynamic_shapes.py | 4 +- .../test_torchinductor_dynamic_shapes.py | 39 +++++++++++++++++-- test/run_test.py | 1 + torch/_inductor/codegen/mps.py | 10 ++--- 5 files changed, 64 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index e7b6695fee7b7..ed4b1ba3e466d 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -13693,6 +13693,25 @@ def new_test(self, value=value): other_cls.is_dtype_supported = my_cls.is_dtype_supported +def add_test_failures( + test_failures: dict[str, TestFailure], added_test_failures: dict[str, TestFailure] +): + """ + In-place modifies the given dictionary of `test_failures` to add the + contents of `added_test_failures` by unioning the test_failure.suffixes, and + or-ing the the is_skip value. + """ + for name, new_failure in added_test_failures.items(): + if name in test_failures: + orig_failure = test_failures[name] + orig_failure.suffixes = tuple( + set(orig_failure.suffixes).union(set(new_failure.suffixes)) + ) + orig_failure.is_skip = orig_failure.is_skip or new_failure.is_skip + else: + test_failures[name] = new_failure + + if RUN_CPU: class SweepInputsCpuTest(SweepInputs2, TestCase): diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 6a7d40b6b7cad..cdf76772b9366 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -25,6 +25,7 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + add_test_failures, CommonTemplate, copy_tests, run_and_get_cpp_code, @@ -382,9 +383,10 @@ def run(*ex, **kwargs): # Refinement means we don't actually generate dynamic shapes (but only on # cpu apparently?!) "test_nonzero_unbacked_refinement_dynamic_shapes": TestFailure(("cpu",)), - **dynamic_shapes_test_failures, } +add_test_failures(test_failures, dynamic_shapes_test_failures) + if not TEST_WITH_ROCM: test_failures.update( { diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index b75907894f63f..ba2a8c8f5248c 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -26,9 +26,11 @@ from torch.testing._internal.common_utils import ( IS_ARM64, IS_FBCODE, + MACOS_VERSION, parametrize, serialTest, TEST_CUDA_MEM_LEAK_CHECK, + TEST_MPS, TEST_WITH_ASAN, TEST_WITH_ROCM, ) @@ -36,6 +38,7 @@ GPU_TYPE, HAS_CPU, HAS_GPU, + HAS_MPS, patch_inductor_backend, ) @@ -59,9 +62,39 @@ "test_kwargs_dynamic_shapes": TestFailure(("cpu",)), # calling div on only symint args "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure( - ("cpu", "cuda", "xpu") + ("cpu", "cuda", "xpu", "mps") + ), + "test_argmax_argmin_with_duplicates_dynamic_shapes": TestFailure(("mps",)), + "test_batch_norm_2d_2_dynamic_shapes": TestFailure(("mps",)), + "test_buffer_batch_norm_dynamic_shapes": TestFailure(("mps",)), + "test_convolution4_dynamic_shapes": TestFailure(("mps",)), + "test_index_propagation_abs_dynamic_shapes": TestFailure(("mps",)), + "test_index_propagation_floordiv_dynamic_shapes": TestFailure(("mps",)), + "test_index_propagation_remainder_dynamic_shapes": TestFailure(("mps",)), + "test_multilayer_var_dynamic_shapes": TestFailure(("mps",)), + "test_multilayer_var_lowp_dynamic_shapes": TestFailure(("mps",)), + "test_reduction2_dynamic_shapes": TestFailure(("mps",)), + "test_reduction3_dynamic_shapes": TestFailure(("mps",)), + "test_reduction5_dynamic_shapes": TestFailure(("mps",)), + "test_reflection_pad2d_dynamic_shapes": TestFailure(("mps",)), + "test_require_stride_expanded_dynamic_shapes": TestFailure(("mps",)), + "test_roll_dynamic_shapes": TestFailure(("mps",)), + "test_std_dynamic_shapes": TestFailure(("mps",)), + "test_var_correction_dynamic_shapes": TestFailure(("mps",)), + "test_var_mean_div_by_dynamic_shapes": TestFailure(("mps",)), + "test_var_mean_tile_reduction_False_dynamic_shapes": TestFailure(("mps",)), + "test_var_mean_tile_reduction_True_dynamic_shapes": TestFailure(("mps",)), + "test_vectorized_ops_masked_var_novec_dynamic_shapes": TestFailure(("mps",)), + "test_reflection_pad2d_backward_dynamic_shapes": TestFailure( + ("mps",), is_skip=True ), } + +if TEST_MPS and MACOS_VERSION >= 15.0: + test_failures["test_scaled_dot_product_attention_dynamic_shapes"] = TestFailure( + "mps" + ) + if not torch._inductor.config.cpp_wrapper: test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure( ("cuda",) @@ -106,7 +139,7 @@ class DynamicShapesCpuTests(TestCase): copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_failures) -if HAS_GPU and not TEST_WITH_ASAN: +if (HAS_GPU or HAS_MPS) and not TEST_WITH_ASAN: class DynamicShapesGPUTests(TestCase): common = check_model_gpu @@ -1133,5 +1166,5 @@ def fn(a, descending): from torch._inductor.test_case import run_tests # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068 - if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN: + if (HAS_CPU or HAS_GPU or HAS_MPS) and not TEST_WITH_ASAN: run_tests(needs="filelock") diff --git a/test/run_test.py b/test/run_test.py index 7d1afb3f34c07..4c49acfdee9c0 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1582,6 +1582,7 @@ def get_selected_tests(options) -> list[str]: "inductor/test_mps_basic", "inductor/test_torchinductor", "inductor/test_aot_inductor", + "inductor/test_torchinductor_dynamic_shapes", ] else: # Exclude all mps tests otherwise diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 5850270a67e2c..d952a45d0b5a1 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -535,7 +535,7 @@ def _new_idxvar( var_def = "threadgroup " if is_threadgroup else "" var_def += f"{dtype} {var_name}" if elem_count: - var_def += f"[{elem_count}]" + var_def += f"[{self.sexpr(elem_count)}]" if default_value is not None: assert not is_threadgroup, "Thread group var can not have default value" var_def += f" = {default_value}" @@ -657,7 +657,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: ) return self.cse.generate( self.stores, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size_str})", dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], ) if reduction_type in ["argmin", "argmax"]: @@ -693,7 +693,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: return self.cse.generate( self.stores, f"c10::metal::threadgroup_{reduction_type}({data_acc_buf}, {idx_acc_buf}, " - f"{val}, {idx_val}, {reduction_idx}, {acc_buf_size})", + f"{val}, {idx_val}, {reduction_idx}, {acc_buf_size_str})", dtype=dtype, ) if reduction_type == "welford_reduce": @@ -702,7 +702,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") wf_res = self.cse.generate( self.compute, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", dtype=torch.float32, ) return _unwrap_helper(wf_res) @@ -733,7 +733,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: self.compute.writeline(f"{acc_thread_var} = {inp_value};") wf_res = self.cse.generate( self.stores if self.multistage_reduction_entry else self.compute, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", dtype=torch.float32, ) return _unwrap_helper(wf_res) From 74a754aae98aabc2aca67e5edb41cc684fae9a82 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 5 Aug 2025 11:46:13 -0700 Subject: [PATCH 0018/1424] Add meta kernel for sdpa_math_for_mps (#159695) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159695 Approved by: https://github.com/malfet ghstack dependencies: #159456 --- test/inductor/test_aot_inductor.py | 2 - .../test_torchinductor_dynamic_shapes.py | 7 -- test/test_mps.py | 73 +++++++++++++++++++ torch/_meta_registrations.py | 55 ++++++++++++++ 4 files changed, 128 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index e57a9c00fd700..9b501315cd9c2 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6916,8 +6916,6 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): # MPS doesn't support float8 "test_fp8": fail_mps(), "test_fp8_view_of_param": fail_mps(), - # unsupported operator: aten._scaled_dot_product_attention_math_for_mps.default - "test_issue_140766": fail_mps(), # cannot initialize a parameter of type 'double' with an rvalue of type 'std::nullptr_t' "test_fallback_kernel_with_symexpr_output": fail_mps(), # while-loop subgraph calls same kernel as outside. need to figure out how to diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index ba2a8c8f5248c..a2d5ff9be6c23 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -26,11 +26,9 @@ from torch.testing._internal.common_utils import ( IS_ARM64, IS_FBCODE, - MACOS_VERSION, parametrize, serialTest, TEST_CUDA_MEM_LEAK_CHECK, - TEST_MPS, TEST_WITH_ASAN, TEST_WITH_ROCM, ) @@ -90,11 +88,6 @@ ), } -if TEST_MPS and MACOS_VERSION >= 15.0: - test_failures["test_scaled_dot_product_attention_dynamic_shapes"] = TestFailure( - "mps" - ) - if not torch._inductor.config.cpp_wrapper: test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure( ("cuda",) diff --git a/test/test_mps.py b/test/test_mps.py index 6dfce783316f2..975ba00cc7d8a 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -29,6 +29,7 @@ from torch.testing._internal.common_dtype import get_all_dtypes, integral_types import torch.backends.mps from torch.distributions import Uniform, Exponential +from torch.utils._python_dispatch import TorchDispatchMode from functools import partial from torch.testing._internal.common_methods_invocations import ( @@ -9446,6 +9447,78 @@ def test_fast_full_attention(self, dtype, contiguous, head_dim, with_mask): self.run_fast_attention_test(q, k, v, with_mask) + + +class TestSDPAMetaDispatchMode(TorchDispatchMode): + """ + TorchDispatchMode which intercepts the + _scaled_dot_product_attention_math_for_mps aten operator to check that the + meta kernel is correct. + """ + + def __init__(self, test): + self.test = test + super().__init__() + + def __torch_dispatch__(self, func, types, args, kwargs=None): + kwargs = kwargs or {} + res = func(*args, **kwargs) + if func != torch.ops.aten._scaled_dot_product_attention_math_for_mps.default: + return res + + meta_args, meta_kwargs = pytree.tree_map_only(torch.Tensor, lambda t: t.to(device="meta"), (args, kwargs)) + meta_res = func(*meta_args, **meta_kwargs) + + def format_res(res): + return [ + (t.shape, t.stride(), t.dtype) if isinstance(t, torch.Tensor) else t + for t in pytree.tree_flatten(res)[0] + ] + + # Format the output so that we only look at the tensor metadata + self.test.assertEqual(format_res(res), format_res(meta_res)) + return res + + +def create_sdpa_meta_test(): + """ + Creates a new class which takes every test in TestSDPA and adds the + TestSDPAMetaDispatchMode context in order to test the + scaled_dot_product_attention_for_mps meta kernel. This allows us to test all + the branches for the sdpa op. If there are changes to the sdpa kernel + without changing the meta kernel, a torch.compile guard will catch the issue + but not necessarily export. + """ + orig_test_cls = TestSDPA + + new_test_cls = type(f"{orig_test_cls.__name__}Meta", orig_test_cls.__bases__, {}) + new_test_cls.__qualname__ = new_test_cls.__name__ + + for name in dir(orig_test_cls): + if name.startswith("test_"): + fn = getattr(orig_test_cls, name) + if not callable(fn): + setattr(new_test_cls, name, getattr(orig_test_cls, name)) + continue + + new_name = f"{name}_meta" + + def new_fn(self, *args, **kwargs): + with TestSDPAMetaDispatchMode(self): + fn(self, *args, **kwargs) + + new_fn.__name__ = new_name + + setattr(new_test_cls, new_name, new_fn) + + elif not hasattr(new_test_cls, name): + setattr(new_test_cls, name, getattr(orig_test_cls, name)) + + return new_test_cls + +TestSDPAMeta = create_sdpa_meta_test() +instantiate_parametrized_tests(TestSDPAMeta) + class TestGatherScatter(TestCaseMPS): def test_slicing_with_step(self): # Slicing with step diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index fc9e8a8489d8a..fc16cf58c6406 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5861,6 +5861,61 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( return grad_q, grad_k, grad_v +@register_meta([aten._scaled_dot_product_attention_math_for_mps]) +def meta__scaled_dot_product_attention_math_for_mps( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + dropout_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +) -> tuple[Tensor, Tensor]: + def ensure_4d(x): + if x.dim() == 3: + return x.unsqueeze(0), True + elif x.dim() > 4: + batch_size = 1 + for i in range(x.dim() - 3): + batch_size *= x.shape[i] + return x.view(batch_size, x.size(-3), x.size(-2), x.size(-1)), True + else: + return x, False + + q_, unsqueezed = ensure_4d(query) + k_, _ = ensure_4d(key) + v_, _ = ensure_4d(value) + + batch_size, num_head, q_size, head_size = q_.shape + _, k_size, max_seq_length, _ = k_.shape + + def sdpa_vector_fast_mps(): + out = q_.new_empty(q_.shape) + if unsqueezed: + out = out.view_as(query) + + attn = q_.new_empty((batch_size, num_head, q_size, max_seq_length)) + if unsqueezed: + if query.dim() == 3: + attn = attn.squeeze(0) + else: + shape = list(query.shape[:-3]) + attn.shape[1:4] + attn = attn.view(shape) + return out, attn + + def sdpa_vector_2pass_mps(): + blocks = 32 + out = q_.new_empty(q_.shape) + intermediate = q_.new_empty((batch_size, num_head, q_size, blocks, head_size)) + return out, intermediate + + if (max_seq_length >= 1024) or (k_size < q_size and max_seq_length >= 4096): + return sdpa_vector_2pass_mps() + else: + return sdpa_vector_fast_mps() + + @register_meta([aten._scaled_dot_product_efficient_attention]) def meta__scaled_dot_product_efficient_attention( query: Tensor, From fe8984a9f43bde10d1956abe7cb40710ed7ceed2 Mon Sep 17 00:00:00 2001 From: Alex Malyshev Date: Tue, 5 Aug 2025 23:32:48 +0000 Subject: [PATCH 0019/1424] Set PYTHONHOME for inductor subprocesses using torch (#159382) Summary: This is needed for subprocesses that are trying to call back into torch functionality, i.e. anything that's also setting `PYTHONPATH`. There are more `sys.executable` subprocesses in torch/ but it seems like they're fine. Test Plan: Local inference runs. Reviewed By: aorenste Differential Revision: D79124705 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159382 Approved by: https://github.com/aorenste --- torch/_inductor/autotune_process.py | 3 +++ torch/_inductor/compile_worker/subproc_pool.py | 3 +++ torch/_inductor/cpu_vec_isa.py | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index c936fbe92c671..c3d4b6af651dc 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -12,6 +12,7 @@ import selectors import subprocess import sys +import sysconfig import time import warnings from collections.abc import Iterable, Sequence @@ -128,6 +129,8 @@ def start(self): "PYTHONPATH": os.environ.get( "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) ), + # Need to set this for internal builds that bundle the runtime. + "PYTHONHOME": sysconfig.get_path("data"), # We shouldn't be using the Triton async compile subprocess pool, # but as a precaution set the env var that disables its creation. "TORCH_WARM_POOL": "0", diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 0b670b268b37e..80e7e75898cbf 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -8,6 +8,7 @@ import struct import subprocess import sys +import sysconfig import threading import traceback import typing @@ -158,6 +159,8 @@ def __init__( "PYTHONPATH": os.environ.get( "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) ), + # Need to set this for internal builds that bundle the runtime. + "PYTHONHOME": sysconfig.get_path("data"), # Safeguard against creating a SubprocPool in the subprocess. "TORCH_WARM_POOL": "0", # Some internal usages need a modified LD_LIBRARY_PATH. diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index b077c4da9c28d..71a27e99628db 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -6,6 +6,7 @@ import re import subprocess import sys +import sysconfig import warnings from typing import Any, Callable, Union @@ -133,9 +134,12 @@ def check_build(self, code: str) -> bool: stderr=subprocess.DEVNULL, env={ **os.environ, + # We need to set the PYTHONPATH so the subprocess can find torch. "PYTHONPATH": os.environ.get( "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) ), + # Need to set this for internal builds that bundle the runtime. + "PYTHONHOME": sysconfig.get_path("data"), }, ) except Exception: From 1052604acd652ba2fce483a5fb6251fb93c9b18e Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 5 Aug 2025 23:44:38 +0000 Subject: [PATCH 0020/1424] fix logging setup issue for Windows.. (#159887) When we setup logging config as guide: https://docs.pytorch.org/docs/stable/logging.html Such as: TORCH_LOGS="+schedule,+inductor,+output_code" On Linux, it shows as: ```cmd declare -x SSH_TTY="/dev/pts/0" declare -x TERM="xterm" declare -x TORCH_LOGS="+schedule,+inductor,+output_code" declare -x USER="xu" ``` On Windows, it shows as: ```cmd TORCHINDUCTOR_WINDOWS_TESTS=1 TORCH_LOGS="+schedule,+inductor,+output_code" UCRTVersion=10.0.22000.0 ``` For Linux, it shows quotes by default, And Windows is not shows quotes. Besides that, Windows would auto assemble quotes when env var processing. On Linux, we will get variable: "+schedule,+inductor,+output_code" On Windows, we will get variable: '"+schedule,+inductor,+output_code"' So, we need remove the outer quotes for Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159887 Approved by: https://github.com/angelayi --- torch/_logging/_internal.py | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index ffd3160b47ee8..c4bdeceeb4947 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -726,8 +726,49 @@ def _invalid_settings_err_msg(settings, verbose=False): return msg +def process_env_var_string_for_windows(env_var_str: str) -> str: + """ + When we setup logging config as guide: https://docs.pytorch.org/docs/stable/logging.html + Such as: + TORCH_LOGS="+schedule,+inductor,+output_code" + + On Linux, it shows as: + declare -x SSH_TTY="/dev/pts/0" + declare -x TERM="xterm" + declare -x TORCH_LOGS="+schedule,+inductor,+output_code" + declare -x USER="xu" + + On Windows, it shows as: + TORCHINDUCTOR_WINDOWS_TESTS=1 + TORCH_LOGS="+schedule,+inductor,+output_code" + UCRTVersion=10.0.22000.0 + + For Linux, it shows quotes by default, And Windows is not shows quotes. + Besides that, Windows would auto assemble quotes when env var processing. + On Linux, we will get variable: "+schedule,+inductor,+output_code" + On Windows, we will get variable: '"+schedule,+inductor,+output_code"' + + So, we need remove the outer quotes for Windows. + """ + _IS_WINDOWS = sys.platform == "win32" + + def remove_outer_quotes(s: str) -> str: + if len(s) >= 2 and ( + (s[0] == '"' and s[-1] == '"') or (s[0] == "'" and s[-1] == "'") + ): + return s[1:-1] + return s + + if _IS_WINDOWS: + env_var_str = remove_outer_quotes(env_var_str) + + return env_var_str + + @functools.lru_cache def _parse_log_settings(settings): + settings = process_env_var_string_for_windows(settings) + if settings == "": return {} From 49abc0e3f897d7e077d6e8a7627833ea51c3655e Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 5 Aug 2025 23:47:42 +0000 Subject: [PATCH 0021/1424] [Take 2] Setup TorchBench in Docker (#159300) Fix and reland https://github.com/pytorch/pytorch/pull/158613, I keep `checkout_install_torchbench` in `.ci/pytorch/macos-test.sh` script because it's still used there, and there is no Docker. ### Testing MacOS perf nightly run https://github.com/pytorch/pytorch/actions/runs/16580798470 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159300 Approved by: https://github.com/ZainRizvi --- .../docker}/ci_commit_pins/torchbench.txt | 0 .../common/install_inductor_benchmark_deps.sh | 30 +++++++++++++++++-- .ci/docker/requirements-ci.txt | 1 - .ci/docker/ubuntu-rocm/Dockerfile | 3 +- .ci/docker/ubuntu/Dockerfile | 3 +- .ci/pytorch/common_utils.sh | 26 ---------------- .ci/pytorch/macos-test.sh | 25 ++++++++++++++-- .ci/pytorch/test.sh | 22 +++++--------- .github/workflows/torchbench.yml | 4 +++ .github/workflows/trunk.yml | 2 +- 10 files changed, 67 insertions(+), 49 deletions(-) rename {.github => .ci/docker}/ci_commit_pins/torchbench.txt (100%) diff --git a/.github/ci_commit_pins/torchbench.txt b/.ci/docker/ci_commit_pins/torchbench.txt similarity index 100% rename from .github/ci_commit_pins/torchbench.txt rename to .ci/docker/ci_commit_pins/torchbench.txt diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index 7312dce170db2..bda3aa6009564 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -15,11 +15,37 @@ function install_timm() { commit=$(get_pinned_commit timm) pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}" - # Clean up - conda_run pip uninstall -y torch torchvision triton +} + +function install_torchbench() { + local commit + commit=$(get_pinned_commit torchbench) + git clone https://github.com/pytorch/benchmark torchbench + pushd torchbench + git checkout "$commit" + + python install.py --continue_on_fail + + # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 + # is regressing speedup metric. This needs to be investigated further + pip install transformers==4.38.1 + + echo "Print all dependencies after TorchBench is installed" + python -mpip freeze + popd + + chown -R jenkins torchbench } # Pango is needed for weasyprint which is needed for doctr conda_install pango + +# Stable packages are ok here, just to satisfy TorchBench check +pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 + +install_torchbench install_huggingface install_timm + +# Clean up +conda_run pip uninstall -y torch torchvision torchaudio triton diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index d25f79766baf5..4de9431bf300f 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -361,7 +361,6 @@ pwlf==2.2.1 #Pinned versions: 2.2.1 #test that import: test_sac_estimator.py - # To build PyTorch itself pyyaml pyzstd diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 2528da07c69e3..8f2cc6eef9581 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt +COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt # (optional) Install non-default Ninja version ARG NINJA_VERSION diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 27c466dd8d41d..077910cef9f35 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt +COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt ARG TRITON ARG TRITON_CPU diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index b9a063a2c7ef6..06decc2ea64b5 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -229,7 +229,6 @@ function install_torchrec_and_fbgemm() { pip_install tabulate # needed for newer fbgemm pip_install patchelf # needed for rocm fbgemm - pushd /tmp local wheel_dir=dist/fbgemm_gpu local found_whl=0 @@ -264,7 +263,6 @@ function install_torchrec_and_fbgemm() { done rm -rf fbgemm - popd else pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu @@ -283,30 +281,6 @@ function clone_pytorch_xla() { fi } -function checkout_install_torchbench() { - local commit - commit=$(get_pinned_commit torchbench) - git clone https://github.com/pytorch/benchmark torchbench - pushd torchbench - git checkout "$commit" - - if [ "$1" ]; then - python install.py --continue_on_fail models "$@" - else - # Occasionally the installation may fail on one model but it is ok to continue - # to install and test other models - python install.py --continue_on_fail - fi - - # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 - # is regressing speedup metric. This needs to be investigated further - pip install transformers==4.38.1 - - echo "Print all dependencies after TorchBench is installed" - python -mpip freeze - popd -} - function install_torchao() { local commit commit=$(get_pinned_commit torchao) diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 83f8e4e04331d..c38448898cb4b 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -157,6 +157,29 @@ test_jit_hooks() { assert_git_not_dirty } +# Shellcheck doesn't like it when you pass no arguments to a function +# that can take args. See https://www.shellcheck.net/wiki/SC2120 +# shellcheck disable=SC2120 +checkout_install_torchbench() { + local commit + commit=$(cat .ci/docker/ci_commit_pins/torchbench.txt) + git clone https://github.com/pytorch/benchmark torchbench + pushd torchbench + git checkout "$commit" + + if [ "$1" ]; then + python install.py --continue_on_fail models "$@" + else + # Occasionally the installation may fail on one model but it is ok to continue + # to install and test other models + python install.py --continue_on_fail + fi + + echo "Print all dependencies after TorchBench is installed" + python -mpip freeze + popd +} + torchbench_setup_macos() { git clone --recursive https://github.com/pytorch/vision torchvision git clone --recursive https://github.com/pytorch/audio torchaudio @@ -179,8 +202,6 @@ torchbench_setup_macos() { USE_OPENMP=0 python setup.py develop popd - # Shellcheck doesn't like it when you pass no arguments to a function that can take args. See https://www.shellcheck.net/wiki/SC2120 - # shellcheck disable=SC2119,SC2120 checkout_install_torchbench } diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9f2a67b4ff45b..84d40a2e458a1 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1684,13 +1684,11 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then elif [[ "${TEST_CONFIG}" == cachebench ]]; then install_torchaudio install_torchvision - checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco - PYTHONPATH=$(pwd)/torchbench test_cachebench + PYTHONPATH=/torchbench test_cachebench elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then install_torchaudio install_torchvision - checkout_install_torchbench nanogpt - PYTHONPATH=$(pwd)/torchbench test_verify_cachebench + PYTHONPATH=/torchbench test_verify_cachebench elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchaudio install_torchvision @@ -1699,28 +1697,22 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then - checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer - PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf + PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then - checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \ - llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \ - functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0 - PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf + PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then - checkout_install_torchbench - TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest + TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest else - checkout_install_torchbench # Do this after checkout_install_torchbench to ensure we clobber any # nightlies that torchbench may pull in if [[ "${TEST_CONFIG}" != *cpu* ]]; then install_torchrec_and_fbgemm fi - PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" + PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" fi elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then install_torchvision - PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" + PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" if [[ "$SHARD_NUMBER" -eq "1" ]]; then test_inductor_aoti fi diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index c656c16e97c2e..08fcd33402625 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -10,6 +10,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +permissions: + id-token: write + contents: read + jobs: get-default-label-prefix: if: github.repository_owner == 'pytorch' diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 3879b62cc020e..c7cf4c84e1888 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -205,7 +205,7 @@ jobs: with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.9-gcc11 - docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks test-matrix: | { include: [ { config: "verify_cachebench", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, From 22bedc429f27679bb9764287c443579023a63fab Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 4 Aug 2025 13:46:03 -0700 Subject: [PATCH 0022/1424] Extract some HOP utils to be importable (#159705) Useful helper function for stage 1 export -> manual partitioner -> stage 2 compile users Pull Request resolved: https://github.com/pytorch/pytorch/pull/159705 Approved by: https://github.com/zou3519 ghstack dependencies: #159134 --- .../_functorch/_aot_autograd/graph_compile.py | 87 +++++++++---------- torch/_inductor/compile_fx.py | 55 ++++++------ 2 files changed, 70 insertions(+), 72 deletions(-) diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 27cf699091ee4..a1c6e795bfec8 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -516,6 +516,48 @@ class InvokeSubgraphHopGraphs: new_num_saved_nodes: Optional[int] = None +def prepare_for_partitioner(mod, num_primals, num_fw_outputs): + # min-cut partitioner requires the placeholders to have primals and + # tangents string in the node.name. The signature of the joint graph is + # (*primals, *tangents) + + # We also have to update the output signature which is right now + # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the + # partitioner to work. + new_graph = torch.fx.Graph() + env = {} + + primals_counter = itertools.count(0) + tangents_counter = itertools.count(0) + + for idx, node in enumerate(mod.graph.nodes): + if node.op == "placeholder": + if idx < num_primals: + env[node] = new_graph.placeholder(f"primals_{next(primals_counter)}") + else: + env[node] = new_graph.placeholder(f"tangents_{next(tangents_counter)}") + env[node].meta = copy.copy(node.meta) + elif node.op == "output": + # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) + # The reason for having the reversed signature in the first + # place is to simplify step 3. + old_outputs = node.args[0] + new_outputs = ( + *old_outputs[-num_fw_outputs:], + *old_outputs[:-num_fw_outputs], + ) + new_outputs = [env[n] if n else None for n in new_outputs] + new_graph.output(tuple(new_outputs)) + else: + env[node] = new_graph.node_copy(node, lambda n: env[n]) + env[node].meta = copy.copy(node.meta) + + new_graph.lint() + + out = torch.fx.GraphModule(mod, new_graph) + return out + + def run_joint_graph_passes_on_hops( joint_gm: torch.fx.GraphModule, joint_inputs: Any, @@ -553,51 +595,6 @@ def num_outputs(mod): def num_inputs(mod): return len(mod.graph.find_nodes(op="placeholder")) - def prepare_for_partitioner(mod, num_primals, num_fw_outputs): - # min-cut partitioner requires the placeholders to have primals and - # tangents string in the node.name. The signature of the joint graph is - # (*primals, *tangents) - - # We also have to update the output signature which is right now - # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the - # partitioner to work. - new_graph = torch.fx.Graph() - env = {} - - primals_counter = itertools.count(0) - tangents_counter = itertools.count(0) - - for idx, node in enumerate(mod.graph.nodes): - if node.op == "placeholder": - if idx < num_primals: - env[node] = new_graph.placeholder( - f"primals_{next(primals_counter)}" - ) - else: - env[node] = new_graph.placeholder( - f"tangents_{next(tangents_counter)}" - ) - env[node].meta = copy.copy(node.meta) - elif node.op == "output": - # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) - # The reason for having the reversed signature in the first - # place is to simplify step 3. - old_outputs = node.args[0] - new_outputs = ( - *old_outputs[-num_fw_outputs:], - *old_outputs[:-num_fw_outputs], - ) - new_outputs = [env[n] if n else None for n in new_outputs] - new_graph.output(tuple(new_outputs)) - else: - env[node] = new_graph.node_copy(node, lambda n: env[n]) - env[node].meta = copy.copy(node.meta) - - new_graph.lint() - - out = torch.fx.GraphModule(mod, new_graph) - return out - new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( lambda: InvokeSubgraphHopGraphs() ) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index d17ffe19b3c70..eaab9020f1e84 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -2052,6 +2052,34 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[ ) +def partition_fn( + gm: GraphModule, + joint_inputs: Sequence[object], + **kwargs: object, +) -> tuple[GraphModule, GraphModule]: + cuda_context = get_cuda_device_context(gm) + with cuda_context: + # We can skip the invoke_subgraph because the + # entire_partition_fn is called recursively for invoke_subgraph + # in partitioning. + _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True) + + static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment] + "static_lifetime_input_indices", None + ) + + with dynamo_utils.dynamo_timed( + "min_cut_rematerialization_partition", log_pt2_compile_event=True + ): + return min_cut_rematerialization_partition( + gm, + joint_inputs, + compiler="inductor", + static_lifetime_input_indices=static_lifetime_input_indices, + **kwargs, + ) + + def compile_fx( model_: GraphModule, example_inputs_: Sequence[InputType], @@ -2370,33 +2398,6 @@ def fw_compiler_base( OutputCode, inference_compiler ) - def partition_fn( - gm: GraphModule, - joint_inputs: Sequence[object], - **kwargs: object, - ) -> tuple[GraphModule, GraphModule]: - cuda_context = get_cuda_device_context(gm) - with cuda_context: - # We can skip the invoke_subgraph because the - # entire_partition_fn is called recursively for invoke_subgraph - # in partitioning. - _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True) - - static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment] - "static_lifetime_input_indices", None - ) - - with dynamo_utils.dynamo_timed( - "min_cut_rematerialization_partition", log_pt2_compile_event=True - ): - return min_cut_rematerialization_partition( - gm, - joint_inputs, - compiler="inductor", - static_lifetime_input_indices=static_lifetime_input_indices, - **kwargs, - ) - @compile_time_strobelight_meta(phase_name="backward") def bw_compiler( gm: GraphModule, example_inputs: Sequence[InputType] From 6a82da392edb485491b9ed601f3edc88cb1d5dcb Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Wed, 6 Aug 2025 00:23:05 +0000 Subject: [PATCH 0023/1424] [export] Fix generated schema for C++20/23 (#159871) Summary: Fixing the issue from https://github.com/pytorch/pytorch/issues/159838 Test Plan: buck run caffe2/:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/ Rollback Plan: Differential Revision: D79647167 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159871 Approved by: https://github.com/malfet --- torch/_export/serde/schema_check.py | 2 ++ torch/csrc/utils/generated_serialization_types.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index ccc963397530b..29b9766ae18a4 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -448,6 +448,7 @@ class ForwardRef {{ ptr_ = std::make_unique(*other.ptr_); return *this; }} + ~ForwardRef(); const T& operator*() const {{ return *ptr_; }} @@ -519,6 +520,7 @@ class F64 {{ template ForwardRef::ForwardRef(ForwardRef&&) = default; template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +template ForwardRef::~ForwardRef() = default; }} // namespace _export }} // namespace torch """ diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 98803390e5104..14741e4d2c6e1 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -61,6 +61,7 @@ class ForwardRef { ptr_ = std::make_unique(*other.ptr_); return *this; } + ~ForwardRef(); const T& operator*() const { return *ptr_; } @@ -3717,6 +3718,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlo template ForwardRef::ForwardRef(ForwardRef&&) = default; template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +template ForwardRef::~ForwardRef() = default; } // namespace _export } // namespace torch From 3ddfd46bd203a09e5f56b69489c2b8f656d3e86a Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 5 Aug 2025 14:00:17 -0700 Subject: [PATCH 0024/1424] Cut a version of TORCH_ERROR_CODE_CHECK in headeronly from AOTI (#159604) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159604 Approved by: https://github.com/albanD, https://github.com/desertfire --- test/cpp/aoti_abi_check/test_exception.cpp | 6 +++++ torch/csrc/inductor/aoti_runtime/utils.h | 22 ++++++---------- torch/csrc/stable/ops.h | 8 +++--- torch/csrc/stable/tensor.h | 25 ++++++++----------- torch/header_only_apis.txt | 3 +++ torch/headeronly/util/shim_utils.h | 29 ++++++++++++++++++++++ 6 files changed, 60 insertions(+), 33 deletions(-) create mode 100644 torch/headeronly/util/shim_utils.h diff --git a/test/cpp/aoti_abi_check/test_exception.cpp b/test/cpp/aoti_abi_check/test_exception.cpp index 74a9fee5d9863..26f8092932444 100644 --- a/test/cpp/aoti_abi_check/test_exception.cpp +++ b/test/cpp/aoti_abi_check/test_exception.cpp @@ -1,6 +1,7 @@ #include #include +#include namespace torch { namespace aot_inductor { @@ -15,5 +16,10 @@ TEST(TestExceptions, TestStdTorchCheck) { std::runtime_error); } +TEST(TestExceptions, TestTorchErrorCodeCheck) { + EXPECT_NO_THROW(TORCH_ERROR_CODE_CHECK(0)); + EXPECT_THROW(TORCH_ERROR_CODE_CHECK(1), std::runtime_error); +} + } // namespace aot_inductor } // namespace torch diff --git a/torch/csrc/inductor/aoti_runtime/utils.h b/torch/csrc/inductor/aoti_runtime/utils.h index b6c009805c71d..8d1dd116afe56 100644 --- a/torch/csrc/inductor/aoti_runtime/utils.h +++ b/torch/csrc/inductor/aoti_runtime/utils.h @@ -12,6 +12,7 @@ // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +#include #if defined(__GNUC__) || defined(__clang__) #define AOTI_NOINLINE __attribute__((noinline)) @@ -21,27 +22,18 @@ #define AOTI_NOINLINE #endif -AOTI_NOINLINE static void throw_exception( - const char* call, - const char* file, - int64_t line) { - std::stringstream ss; - ss << call << " API call failed at " << file << ", line " << line; - throw std::runtime_error(ss.str()); -} - -#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ - if ((call) != AOTI_TORCH_SUCCESS) { \ - throw_exception(#call, __FILE__, __LINE__); \ +#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_TORCH_SUCCESS) { \ + torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ } using AOTIRuntimeError = int32_t; #define AOTI_RUNTIME_SUCCESS 0 #define AOTI_RUNTIME_FAILURE 1 -#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ - if ((call) != AOTI_RUNTIME_SUCCESS) { \ - throw_exception(#call, __FILE__, __LINE__); \ +#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_RUNTIME_SUCCESS) { \ + torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ } namespace torch::aot_inductor { diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index a8f68f4a5e3ad..c4a8a99848055 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -21,7 +21,7 @@ inline Tensor empty_like(const Tensor& self) { from(std::nullopt), from(std::nullopt), from(std::nullopt)}; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_call_dispatcher("aten::empty_like", "", stack.data())); return to(stack[0]); } @@ -32,7 +32,7 @@ inline Tensor empty_like(const Tensor& self) { // actually a Scalar. This is because Scalar.h is currently not // header-only. inline Tensor fill_(const Tensor& self, double value) { - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value)); + TORCH_ERROR_CODE_CHECK(aoti_torch_aten_fill__Scalar(self.get(), value)); return self; } @@ -41,7 +41,7 @@ inline Tensor fill_(const Tensor& self, double value) { inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { const auto num_args = 3; std::array stack{from(self), from(dim0), from(dim1)}; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_call_dispatcher("aten::transpose", "int", stack.data())); return to(stack[0]); } @@ -52,7 +52,7 @@ inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { inline Tensor zero_(Tensor& self) { const auto num_args = 1; std::array stack{from(self)}; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_call_dispatcher("aten::zero_", "", stack.data())); return to(stack[0]); } diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h index 1b9b3fecb4173..741da7e62e409 100644 --- a/torch/csrc/stable/tensor.h +++ b/torch/csrc/stable/tensor.h @@ -1,10 +1,8 @@ #pragma once -// TODO ASAP: THIS FILE SHOULD BE HEADER ONLY BUT ISN'T ENFORCED: -// I only need it for AOTI_TORCH_ERROR_CODE_CHECK, see #154908 -#include - #include +#include +#include namespace torch::stable { @@ -37,7 +35,7 @@ class Tensor { // Steals ownership from the ATH explicit Tensor(AtenTensorHandle ath) : ath_(ath, [](AtenTensorHandle ath) { - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); + TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); }) {} // Copy and move constructors can be default cuz the underlying handle is a @@ -65,19 +63,19 @@ class Tensor { void* data_ptr() const { void* data_ptr; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); return data_ptr; } int64_t dim() const { int64_t dim; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); return dim; } int64_t numel() const { int64_t numel; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel)); return numel; } @@ -86,35 +84,34 @@ class Tensor { // Here, we assume the default contiguous memory format. bool is_contiguous() const { bool is_contiguous; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_is_contiguous(ath_.get(), &is_contiguous)); return is_contiguous; } int64_t stride(int64_t dim) const { int64_t stride; - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_get_stride(ath_.get(), dim, &stride)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(ath_.get(), dim, &stride)); return stride; } DeviceIndex get_device() const { int32_t device_index; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_index(ath_.get(), &device_index)); return static_cast(device_index); } bool is_cuda() const { int32_t device_type; - AOTI_TORCH_ERROR_CODE_CHECK( + TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_type(ath_.get(), &device_type)); return device_type == aoti_torch_device_type_cuda(); } int64_t size(int64_t dim) const { int64_t size; - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size)); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size)); return size; } diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index e0eaa91f4ca76..72a1b46fb37e8 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -3,6 +3,9 @@ # to guarantee that compiling these symbols do not require linking libtorch # to ensure header-only-ness. +# torch/headeronly/util/shim_utils.h +TORCH_ERROR_CODE_CHECK + # c10/util/TypeCast.h convert diff --git a/torch/headeronly/util/shim_utils.h b/torch/headeronly/util/shim_utils.h new file mode 100644 index 0000000000000..5acb3e2e347c1 --- /dev/null +++ b/torch/headeronly/util/shim_utils.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include +#include + +#define TORCH_SUCCESS 0 +#define TORCH_FAILURE 1 + +namespace torch::headeronly::detail { +[[maybe_unused]] C10_NOINLINE static void throw_exception( + const char* call, + const char* file, + int64_t line) { + std::stringstream ss; + ss << call << " API call failed at " << file << ", line " << line; + throw std::runtime_error(ss.str()); +} +} // namespace torch::headeronly::detail + +// This API is 100% inspired by AOTI_TORCH_ERROR_CODE_CHECK defined in +// pytorch/torch/csrc/inductor/aoti_runtime/utils.h to handle the returns +// of the APIs in the shim. We are genericizing this for more global use +// of the shim beyond AOTI, for examples, see torch/csrc/stable/ops.h. +#define TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != TORCH_SUCCESS) { \ + torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ + } From 3eb3da9b4ba44985bea78154ff9d74402890fe96 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 5 Aug 2025 14:49:07 -0700 Subject: [PATCH 0025/1424] [dynamo][guards] Skip ID_MATCH guard on self.__class__.__closure__ (#159888) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159888 Approved by: https://github.com/williamwen42 --- torch/_dynamo/source.py | 8 ++++++++ torch/_dynamo/variables/builder.py | 14 +++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index a6bedb178e00b..3cb36a63d27ad 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -1066,6 +1066,14 @@ def is_from_nonlocal_source(source: Source) -> bool: ) +def is_from_closure_source(source: Source) -> bool: + if isinstance(source, ClosureSource): + return True + if isinstance(source, ChainedSource): + return is_from_closure_source(source.base) + return False + + def is_from_source(source: Source, target: Source) -> bool: if isinstance(source, ChainedSource): return is_from_source(source.base, target) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f9d8e273068f3..481773860f8d5 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -104,6 +104,7 @@ GetItemSource, GradSource, is_constant_source, + is_from_closure_source, is_from_global_source, is_from_nonlocal_source, is_from_optimizer_source, @@ -1332,9 +1333,16 @@ def build_key_value(i, k, v): and not is_traceable_wrapper_subclass_type(value) ): return TensorSubclassVariable(value, source=self.source) - # This is a userdefined class, so install an ID_MATCH even if its a - # global variable. - self.install_guards(GuardBuilder.ID_MATCH) + + if not is_from_closure_source(self.source): + # For closure source, the variable comes from LOAD_SUPER_ATTR, + # which calls self.__class__. This is internal Cpython + # implementation, and it is rare for the user to modify + # self.__class__ manually. + # For other cases, this is a userdefined class, so install an + # ID_MATCH even if its a global variable. + self.install_guards(GuardBuilder.ID_MATCH) + return UserDefinedClassVariable( value, source=self.source, From f7a66da5f9f6b8b75119b1ee8ce9ddc23e15570e Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 6 Aug 2025 00:36:22 +0000 Subject: [PATCH 0026/1424] Add DeviceAllocator as the base device allocator (#138222) # Motivation In line with [RFC] [A device-agnostic Python device memory related API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/134978), some memory-related APIs are widely used in popular repositories, such as HuggingFace [so many if-else conditional code](https://github.com/search?q=repo%3Ahuggingface%2Faccelerate%20torch.cuda.empty_cache&type=code). We would like to introduce a generic API set under torch.accelerator namespace to generalize these user cases.
Device-specific memory APIs torch.xxx.foo Device-agnostic memory APIs torch.accelerator.foo
```python torch.xxx.empty_cache ``` ```python torch.accelerator.empty_cache ```
```python torch.xxx.reset_peak_memory_stats ``` ```python torch.accelerator.reset_peak_memory_stats ```
```python torch.xxx.reset_accumulated_memory_stats ``` ```python torch.accelerator.reset_accumulated_memory_stats ```
```python torch.xxx.memory_stats ``` ```python torch.accelerator.memory_stats ```
```python torch.xxx.memory_allocated ``` ```python torch.accelerator.memory_allocated ```
```python torch.xxx.max_memory_allocated ``` ```python torch.accelerator.max_memory_allocated ```
```python torch.xxx.memory_reserved ``` ```python torch.accelerator.memory_reserved ```
```python torch.xxx.max_memory_reserved ``` ```python torch.accelerator.max_memory_reserved ```
# Solution This design follows a similar pattern to `HostAllocator`. We're introducing a base class `DeviceAllocator`, from which `CUDAAllocator` and `XPUAllocator` will inherit. This allows us to provide a unified call path like: `torch.accelerator.empty_cache()` -> `GetDeviceAllocator(allocator)->empty_cache()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138222 Approved by: https://github.com/albanD, https://github.com/Camyll --- aten/src/ATen/cuda/CUDAGraph.cpp | 1 - aten/src/ATen/cuda/CUDAGraph.h | 1 + c10/core/CachingDeviceAllocator.cpp | 10 ++++++ c10/core/CachingDeviceAllocator.h | 53 +++++++++++++++++++++++++++++ c10/cuda/CUDACachingAllocator.cpp | 11 ++++++ c10/cuda/CUDACachingAllocator.h | 19 ++++++----- c10/cuda/CUDAGraphsC10Utils.h | 6 ---- c10/xpu/XPUCachingAllocator.cpp | 19 +++++++---- 8 files changed, 98 insertions(+), 22 deletions(-) create mode 100644 c10/core/CachingDeviceAllocator.cpp diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 7fba7c4c7424c..2800e505a9b76 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index c8cae16b624fe..4f2aa31dd1c35 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/c10/core/CachingDeviceAllocator.cpp b/c10/core/CachingDeviceAllocator.cpp new file mode 100644 index 0000000000000..582efd59cf1b1 --- /dev/null +++ b/c10/core/CachingDeviceAllocator.cpp @@ -0,0 +1,10 @@ +#include + +namespace c10 { + +// Ensures proper DLL export of this pure virtual base class on Windows, +// since it's mainly used in other DLLs outside c10.dll. +DeviceAllocator::DeviceAllocator() = default; +DeviceAllocator::~DeviceAllocator() = default; + +} // namespace c10 diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index b23490de693a8..0bec03ae417fa 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace c10::CachingDeviceAllocator { @@ -59,3 +60,55 @@ struct DeviceStats { }; } // namespace c10::CachingDeviceAllocator + +namespace c10 { + +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by Graph mode capture_begin. +// second is set if the instance is created by Graph mode graph_pool_handle. +using MempoolId_t = std::pair; + +struct C10_API DeviceAllocator : public c10::Allocator { + DeviceAllocator(); + ~DeviceAllocator() override; + + // Returns true if the allocator has been properly initialized and is ready + // for use + virtual bool initialized() = 0; + + // Releases all cached device memory from the specified memory pool back to + // the system + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; + + // Associates a memory allocation with a stream to establish dependency + // tracking. Prevents memory reuse until all operations on the specified + // stream complete + virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0; + + // Retrieves comprehensive memory statistics for the specified device, + // including allocation patterns, usage metrics + virtual CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + + // Resets cumulative allocation statistics for the specified device to zero + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + + // Resets peak memory usage statistics for the specified device + virtual void resetPeakStats(c10::DeviceIndex device) = 0; +}; + +// This function is used to get the DeviceAllocator for a specific device type +// and keep backward compatibility with c10::GetAllocator. +C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { + TORCH_CHECK( + t != DeviceType::CPU, + "getDeviceAllocator is not supported for CPU device type."); + auto* allocator = c10::GetAllocator(t); + auto* device_allocator = dynamic_cast(allocator); + TORCH_INTERNAL_ASSERT( + device_allocator, "Allocator for ", t, " is not a DeviceAllocator."); + return device_allocator; +} + +} // namespace c10 diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index c2a46ac9f3f74..59b62dcac07f0 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -4118,7 +4118,18 @@ struct BackendStaticInitializer { BackendStaticInitializer() { auto r = parseEnvForBackend(); +// Register this HIP allocator as the CUDA allocator to allow it to work +// with both c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) +// APIs. We don't perform this masquerading inside +// HIPAllocatorMasqueradingAsCUDA because it needs to happen during static +// initialization, and doing so there may introduce static initialization +// order (SIOF) issues. +#define HIP_MASQUERADING_AS_CUDA \ + "cud" \ + "a" + at::SetAllocator(c10::Device(HIP_MASQUERADING_AS_CUDA).type(), r, 0); allocator.store(r); +#undef HIP_MASQUERADING_AS_CUDA } }; diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 956411fe22827..75a2d4c8e481b 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -202,25 +202,24 @@ struct ShareableHandle { std::string handle; }; -class CUDAAllocator : public Allocator { +class CUDAAllocator : public DeviceAllocator { public: virtual void* raw_alloc(size_t nbytes) = 0; virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; - virtual bool initialized() = 0; virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; - virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; virtual void enable(bool value) = 0; virtual bool isEnabled() const = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; - virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; - virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) = 0; - virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; - virtual void resetPeakStats(c10::DeviceIndex device) = 0; + // Keep for BC only + virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; + void recordStream(const DataPtr& ptr, c10::Stream stream) override { + CUDAStream cuda_stream = CUDAStream(stream); + recordStream(ptr, cuda_stream); + } virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; virtual void beginAllocateToPool( c10::DeviceIndex device, @@ -525,6 +524,10 @@ inline void enablePeerAccess( namespace c10::cuda { +// Keep BC only +using c10::CaptureId_t; +using c10::MempoolId_t; + // MemPool represents a pool of memory in a caching allocator. Currently, // it's just the ID of the pool object maintained in the CUDACachingAllocator. // diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index eb29ca8bc9f02..936875fd71d5c 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -9,12 +9,6 @@ namespace c10::cuda { -using CaptureId_t = unsigned long long; - -// first is set if the instance is created by CUDAGraph::capture_begin. -// second is set if the instance is created by at::cuda::graph_pool_handle. -using MempoolId_t = std::pair; - // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index afae32d92a4b4..04ab3cabcbc2b 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -539,7 +539,7 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); -class XPUAllocator : public Allocator { +class XPUAllocator : public DeviceAllocator { private: std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -575,6 +575,10 @@ class XPUAllocator : public Allocator { } } + bool initialized() override { + return !device_allocators.empty(); + } + void malloc( void** devPtr, DeviceIndex device, @@ -609,13 +613,13 @@ class XPUAllocator : public Allocator { } } - void emptyCache() { + void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { for (auto& da : device_allocators) { da->emptyCache(); } } - void recordStream(const DataPtr& ptr, XPUStream stream) { + void recordStream(const DataPtr& ptr, c10::Stream stream) override { if (!ptr.get()) { return; } @@ -625,7 +629,8 @@ class XPUAllocator : public Allocator { Block* block = get_allocated_block(ptr.get()); TORCH_CHECK(block, "No allocated block can be found."); - device_allocators[block->device]->recordStream(block, stream); + c10::xpu::XPUStream xpu_stream{stream}; + device_allocators[block->device]->recordStream(block, xpu_stream); } DataPtr allocate(size_t size) override { @@ -678,17 +683,17 @@ class XPUAllocator : public Allocator { ": did you call init?"); } - DeviceStats getDeviceStats(DeviceIndex device) { + DeviceStats getDeviceStats(DeviceIndex device) override { assertValidDevice(device); return device_allocators[device]->getStats(); } - void resetPeakStats(DeviceIndex device) { + void resetPeakStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetPeakStats(); } - void resetAccumulatedStats(DeviceIndex device) { + void resetAccumulatedStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetAccumulatedStats(); } From e16c48ae97e1785d77f5019eb8315e4385bb23ee Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Tue, 5 Aug 2025 22:30:34 +0000 Subject: [PATCH 0027/1424] [BE] Fix type hint in AOTIRunnerUtil (#159577) Not sure why it was labelled as list in the first place. In test_aot_inductor.py, I scanned a few use cases and they are tuple as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159577 Approved by: https://github.com/Skylion007 --- test/inductor/test_aot_inductor_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index a2706933d6156..a86690270461e 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -148,7 +148,7 @@ def legacy_run( @staticmethod def compile( model: Union[torch.nn.Module, types.FunctionType], - example_inputs: list[torch.Tensor], + example_inputs: tuple[torch.Tensor, ...], inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): @@ -169,7 +169,7 @@ def compile( @staticmethod def run( model: Union[torch.nn.Module, types.FunctionType], - example_inputs: list[torch.Tensor], + example_inputs: tuple[torch.Tensor, ...], inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): @@ -185,7 +185,7 @@ def run( @staticmethod def run_multiple( model: Union[torch.nn.Module, types.FunctionType], - list_example_inputs: list[list[torch.Tensor]], + list_example_inputs: list[tuple[torch.Tensor, ...]], inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): From 15f1173e5d72d6d45faba4cecd135e0160f06c6f Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 6 Aug 2025 00:36:24 +0000 Subject: [PATCH 0028/1424] Add unified memory APIs for torch.accelerator (#152932) # Motivation The following API will be put under torch.accelerator - empty_cache - max_memory_allocated - max_memory_reserved - memory_allocated - memory_reserved - memory_stats - reset_accumulated_memory_stats - reset_peak_memory_stats Pull Request resolved: https://github.com/pytorch/pytorch/pull/152932 Approved by: https://github.com/albanD ghstack dependencies: #138222 --- aten/src/ATen/DeviceAccelerator.h | 22 ++++ docs/source/accelerator.md | 23 ++++ torch/_C/__init__.pyi.in | 5 + torch/accelerator/__init__.py | 18 +++ torch/accelerator/memory.py | 201 ++++++++++++++++++++++++++++++ torch/csrc/DeviceAccelerator.cpp | 64 ++++++++++ torch/cuda/memory.py | 4 +- 7 files changed, 335 insertions(+), 2 deletions(-) create mode 100644 torch/accelerator/memory.py diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f37e492c861fe..f23b35047fcc8 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -72,6 +73,27 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +TORCH_API inline void emptyCache() { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->emptyCache(); +} + +TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); +} + +TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); +} + +TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetPeakStats(device_index); +} + } // namespace at::accelerator namespace at { diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index c6f2fb1080400..ce593a9acf518 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,3 +25,26 @@ synchronize device_index ``` + +```{eval-rst} +.. automodule:: torch.accelerator.memory +``` +```{eval-rst} +.. currentmodule:: torch.accelerator.memory +``` + +## Memory management +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + memory_allocated + memory_reserved + memory_stats + reset_accumulated_memory_stats + reset_peak_memory_stats +``` diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9e03c7dba8305..fb7e9c5ce56e0 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2435,6 +2435,11 @@ def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... +def _accelerator_isAllocatorInitialized() -> _bool: ... +def _accelerator_emptyCache() -> None: ... +def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... +def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... +def _accelerator_resetPeakStats(device_index: _int) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e9e48f1cf3061..4d1a78df1f74c 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,6 +8,16 @@ import torch from ._utils import _device_t, _get_device_index +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + memory_allocated, + memory_reserved, + memory_stats, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) __all__ = [ @@ -15,9 +25,17 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", + "empty_cache", "device_count", "device_index", "is_available", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", "set_device_idx", # deprecated "set_device_index", "set_stream", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py new file mode 100644 index 0000000000000..d34a11a3a02e5 --- /dev/null +++ b/torch/accelerator/memory.py @@ -0,0 +1,201 @@ +from collections import OrderedDict +from typing import Any + +import torch + +from ._utils import _device_t, _get_device_index + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other application. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return + torch._C._accelerator_emptyCache() + + +def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: + r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from device memory allocation. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of June 2025, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of June 2025, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed device memory allocation calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of device memory allocation calls. + - ``"num_device_free"``: number of device memory free calls. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return OrderedDict() + device_index = _get_device_index(device_index, optional=True) + stats = torch._C._accelerator_getDeviceStats(device_index) + flat_stats = [] + + def flatten(prefix: str, value: Any) -> None: + if isinstance(value, dict): + for k, v in value.items(): + nested_prefix = f"{prefix}.{k}" if prefix else k + flatten(nested_prefix, v) + else: + flat_stats.append((prefix, value)) + + flatten("", stats) + flat_stats.sort() + return OrderedDict(flat_stats) + + +def memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory occupied by tensors + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory occupied by tensors + in bytes for a given device index. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory managed by the caching allocator + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory managed by the caching allocator + in bytes for a given device index. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.peak", 0) + + +def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetAccumulatedStats(device_index) + + +def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "peak" stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetPeakStats(device_index) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 3a97c0794684f..59cb8047467c9 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -77,6 +77,70 @@ void initModule(PyObject* module) { m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); + + m.def("_accelerator_isAllocatorInitialized", []() { + const auto device_type = at::accelerator::getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->initialized(); + }); + + m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); + + m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { + using c10::CachingAllocator::Stat; + using c10::CachingAllocator::StatArray; + using c10::CachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + + const auto stats = at::accelerator::getDeviceStats(device_index); + const auto stat_to_dict = [](const Stat& stat) -> py::dict { + py::dict dict; + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict { + const std::array(StatType::NUM_TYPES)> + kStatTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(kStatTypeNames.size())) { + dict[kStatTypeNames[i]] = stat_to_dict(stats[i]); + } + return dict; + }; + + py::dict result; + result["num_alloc_retries"] = stats.num_alloc_retries; + result["num_ooms"] = stats.num_ooms; + result["max_split_size"] = stats.max_split_size; + result["num_sync_all_streams"] = stats.num_sync_all_streams; + result["num_device_alloc"] = stats.num_device_alloc; + result["num_device_free"] = stats.num_device_free; + result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes); + result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes); + result["active_bytes"] = stat_array_to_dict(stats.active_bytes); + result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes); + result["allocation"] = stat_array_to_dict(stats.allocation); + result["segment"] = stat_array_to_dict(stats.segment); + result["active"] = stat_array_to_dict(stats.active); + result["inactive_split"] = stat_array_to_dict(stats.inactive_split); + result["inactive_split_bytes"] = + stat_array_to_dict(stats.inactive_split_bytes); + result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations); + result["oversize_segments"] = stat_to_dict(stats.oversize_segments); + return result; + }); + + m.def( + "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetAccumulatedStats(device_index); + }); + + m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetPeakStats(device_index); + }); } } // namespace torch::accelerator diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 63e59096162fb..1bd6f9edc0319 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,9 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool - (as of October 2019, for size >= 1MB allocations). + (as of June 2025, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool - (as of October 2019, for size < 1MB allocations). + (as of June 2025, for size < 1MB allocations). Metric type: From 4604f0482c2b4a3001b62e5bc5085149a9bb053c Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 6 Aug 2025 00:36:26 +0000 Subject: [PATCH 0029/1424] Add UT for torch.accelerator memory-related API (#155200) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155200 Approved by: https://github.com/albanD ghstack dependencies: #138222, #152932 --- test/test_accelerator.py | 78 ++++++++++++++++++++++++++++++++++++++++ test/test_cuda.py | 36 +++++++++++++++++++ test/test_xpu.py | 37 +++++++++++++++++++ 3 files changed, 151 insertions(+) diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 0ea224d704cb8..21731bd275b60 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -1,5 +1,6 @@ # Owner(s): ["module: tests"] +import gc import sys import unittest @@ -156,6 +157,83 @@ def test_generic_event_behavior(self): ): event1.elapsed_time(event2) + @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!") + def test_memory_stats(self): + # Ensure that device allocator is initialized + acc = torch.accelerator.current_accelerator() + tmp = torch.randn(100, device=acc) + del tmp + gc.collect() + self.assertTrue(torch._C._accelerator_isAllocatorInitialized()) + torch.accelerator.empty_cache() + + pool_type = ["all", "small_pool", "large_pool"] + metric_type = ["peak", "current", "allocated", "freed"] + stats_type = [ + "allocated_bytes", + "reserved_bytes", + "active_bytes", + "requested_bytes", + ] + mem_stats = torch.accelerator.memory_stats() + expected_stats = [ + f"{st}.{pt}.{mt}" + for st in stats_type + for pt in pool_type + for mt in metric_type + ] + missing_stats = [stat for stat in expected_stats if stat not in mem_stats] + self.assertEqual( + len(missing_stats), + 0, + f"Missing expected memory statistics: {missing_stats}", + ) + + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertGreaterEqual(prev_allocated, 0) + self.assertGreaterEqual(prev_reserved, 0) + self.assertGreater(prev_max_allocated, 0) + self.assertGreater(prev_max_reserved, 0) + tmp = torch.ones(256, device=acc) + self.assertGreater(torch.accelerator.memory_allocated(), prev_allocated) + self.assertGreaterEqual(torch.accelerator.memory_reserved(), prev_reserved) + del tmp + gc.collect() + torch.accelerator.empty_cache() + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.memory_allocated(), prev_allocated) + self.assertEqual(torch.accelerator.memory_reserved(), prev_reserved) + torch.accelerator.reset_accumulated_memory_stats() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device=acc) + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + if __name__ == "__main__": run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index f2f3304069f1b..9755835853eed 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -373,6 +373,42 @@ def test_memory_allocation(self): torch.cuda.caching_allocator_delete(mem) self.assertEqual(torch.cuda.memory_allocated(), prev) + def test_memory_stats(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertEqual(prev_allocated, prev_max_allocated) + self.assertEqual(prev_reserved, prev_max_reserved) + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device="cuda") + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + def test_check_error(self): # Assert this call doesn't raise. torch.cuda.check_error(0) diff --git a/test/test_xpu.py b/test/test_xpu.py index cd5275418c440..beb5a53a4a6b3 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,5 +1,6 @@ # Owner(s): ["module: intel"] +import gc import re import subprocess import sys @@ -520,6 +521,42 @@ def test_device_memory_allocated(self): ) del a + def test_memory_stats(self): + gc.collect() + torch.xpu.empty_cache() + torch.xpu.reset_peak_memory_stats() + torch.xpu.reset_accumulated_memory_stats() + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertEqual(prev_allocated, prev_max_allocated) + self.assertEqual(prev_reserved, prev_max_reserved) + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device="xpu") + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + @skipXPUIf( int(torch.version.xpu) < 20250000, "Test requires SYCL compiler version 2025.0.0 or newer.", From 8ce81bcee1da294a34af0a90dc16483055e8c5a4 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Wed, 6 Aug 2025 02:26:07 +0000 Subject: [PATCH 0030/1424] [Torch Package] Make get names of OrderedImporters support fallback to importers (#155743) Summary: OrderedImporters is supposed to be an importer which tries out every single importer in self._importers. However the get_name API does not follow this behavior and only uses the get_name from the basic Importer class. This change is to update the OrderedImporters get_name API so that it tries the get_name API of every single importers. Differential Revision: D76463252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155743 Approved by: https://github.com/jcwchen, https://github.com/jingsh --- test/package/test_save_load.py | 7 +++---- torch/package/importer.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/test/package/test_save_load.py b/test/package/test_save_load.py index a0cc967787e67..edbba9f6f8ee8 100644 --- a/test/package/test_save_load.py +++ b/test/package/test_save_load.py @@ -208,11 +208,10 @@ def make_exporter(): # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first. return pe - # This should fail. The 'PackageAObject' type defined from 'importer1' - # is not necessarily the same 'obj2's version of 'PackageAObject'. + # This succeeds because OrderedImporter.get_name() properly + # falls back to sys_importer which can find the original PackageAObject pe = make_exporter() - with self.assertRaises(pickle.PicklingError): - pe.save_pickle("obj", "obj.pkl", obj2) + pe.save_pickle("obj", "obj.pkl", obj2) # This should also fail. The 'PackageAObject' type defined from 'importer1' # is not necessarily the same as the one defined from 'importer2' diff --git a/torch/package/importer.py b/torch/package/importer.py index 49b4512f79a60..8cfc1e336a454 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import importlib +import logging from abc import ABC, abstractmethod from pickle import ( # type: ignore[attr-defined] _getattribute, @@ -13,6 +14,7 @@ __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] +log = logging.getLogger(__name__) class ObjNotFoundError(Exception): @@ -204,6 +206,20 @@ def _is_torchpackage_dummy(self, module): return True return module.__file__ is None + def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: + for importer in self._importers: + try: + return importer.get_name(obj, name) + except (ObjNotFoundError, ObjMismatchError) as e: + warning_message = ( + f"Tried to call get_name with obj {obj}, " + f"and name {name} on {importer} and got {e}" + ) + log.warning(warning_message) + raise ObjNotFoundError( + f"Could not find obj {obj} and name {name} in any of the importers {self._importers}" + ) + def import_module(self, module_name: str) -> ModuleType: last_err = None for importer in self._importers: From 14c7358c645880196f54f84586975c6407ed3f40 Mon Sep 17 00:00:00 2001 From: Tianhao Huang Date: Wed, 6 Aug 2025 03:15:30 +0000 Subject: [PATCH 0031/1424] Enable fr_trace to read local traces from multiple hosts. (#159490) Summary: For training jobs particularly from GenAI, NCCL trace dumps are generated in the format of `.pci3_rank_`. For multi-node training jobs, the hostname varies across traces. The current prefix matching logic can't handle this case. Test Plan: Create a local folder `dumps` and several empty files: `host0.pci3_rank_0`, `host0.pci3_rank_1`, `host1.pci3_rank_0`, `host1.pci3_rank_1` inside it. Then run ``` buck2 run fbcode//caffe2/fb/flight_recorder:fr_trace -- trace_dir dumps ``` Before this diff, fr_trace cannot locate any trace files, giving the following assertion error: ``` AssertionError: no files loaded from /home/tianhaoh/dumps with prefix pci3_rank_ ``` After this diff, fr_trace is able to locate the trace files, resulting in the exceptions like ``` dump = pickle.load(infile) ^^^^^^^^^^^^^^^^^^^ EOFError: Ran out of input ``` (since the trace files are fake and empty). Rollback Plan: Differential Revision: D79224727 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159490 Approved by: https://github.com/fduwjj --- tools/flight_recorder/components/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index dd2eb109aa563..7634226bae528 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -78,9 +78,9 @@ def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]: if prefix is None: prefix = _determine_prefix(files) for f in files: - if f.find(prefix) != 0: + if (offset := f.find(prefix)) == -1: continue - details[f] = read_dump(prefix, os.path.join(root, f)) + details[f] = read_dump(f[:offset] + prefix, os.path.join(root, f)) filecount += 1 if not version: version = str(details[f]["version"]) From 311f74089ab6c423e73f1541846ee4d9290a16e6 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Tue, 5 Aug 2025 16:54:35 -0700 Subject: [PATCH 0032/1424] remove print (#159917) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159917 Approved by: https://github.com/laithsakka --- torch/fx/experimental/symbolic_shapes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index c6e757ca52011..420537ccfd3f8 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4810,7 +4810,6 @@ def create_unbacked_symfloat(self) -> SymFloat: ) self.counter["create_unbacked_symbol"] += 1 if not self._ignore_fresh_unbacked_symbols_tls(): - print(f"adding {symbol}") self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() From bfc27cf468660b50758defdc86c5d19df8750c2e Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 6 Aug 2025 03:51:42 +0000 Subject: [PATCH 0033/1424] [Distributed] Fix `@parametrize` on unordered iterable in distributed test (#159793) seems to fix https://github.com/pytorch/pytorch/issues/145807 sets aren't ordered so `@parametrize` can cause two processes to spawn with different settings originally debugged thanks to @k-artem, see https://github.com/pytorch/pytorch/issues/145807#issuecomment-2971009451 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159793 Approved by: https://github.com/Skylion007, https://github.com/wconstab --- test/distributed/fsdp/test_distributed_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index ac34246ee6432..c80602c5d50f3 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -31,10 +31,10 @@ sys.exit(0) -_DISTRIBUTED_STATE_DICT_IMPLS = { +_DISTRIBUTED_STATE_DICT_IMPLS = ( StateDictType.LOCAL_STATE_DICT, StateDictType.SHARDED_STATE_DICT, -} +) class TestDistributedCheckpoint(FSDPTest): From 704594eb239dd26354304d3e5b399e8fd77070e8 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 5 Aug 2025 16:13:03 -0700 Subject: [PATCH 0034/1424] [Dynamo] make HOPs hashable (#159910) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159910 Approved by: https://github.com/yf225 --- test/dynamo/test_misc.py | 13 +++++++++++++ torch/_dynamo/variables/dicts.py | 1 + 2 files changed, 14 insertions(+) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 82c0368c5b153..d34670c357bf4 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -11945,6 +11945,19 @@ def fn(x, d): with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn(torch.randn(4), d) + def test_hash_hop(self): + associative_scan = importlib.import_module( + "torch._higher_order_ops.associative_scan" + ) + + @torch.compile(fullgraph=True) + def fn(y, s): + d = dict() + d[s] = y + return d[s] + 1.0 + + fn(torch.ones(2, 2, device="cpu"), associative_scan.AssociativeScanOp()) + def test_iter_type(self): @torch.compile(fullgraph=True) def fn(y): diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index edb1169cb193b..dc3929c9cce4c 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -120,6 +120,7 @@ def is_hashable(x): variables.TypingVariable, variables.FunctoolsPartialVariable, variables.WeakRefVariable, + variables.TorchHigherOrderOperatorVariable, ), ) From 97649811164c3c4186a9539a8713844e079f2125 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 5 Aug 2025 08:26:29 -0700 Subject: [PATCH 0035/1424] Pass fw/bw compilers to aot_export_joint_with_descriptors (#159814) Allow overriding nop compilers with real ones when using this flow. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159814 Approved by: https://github.com/fmassa --- torch/_functorch/aot_autograd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 6be696fddbaff..cecfda2bcf1c6 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1154,6 +1154,8 @@ def aot_export_joint_with_descriptors( decompositions: Optional[dict] = None, keep_inference_input_mutations=False, ignore_shape_env=False, + fw_compiler: Optional[AOTDispatchCompiler] = boxed_nop_preserve_node_meta, + bw_compiler: Optional[AOTDispatchCompiler] = boxed_nop_preserve_node_meta, ) -> JointWithDescriptors: """ This API captures the joint graph for an nn.Module. However, unlike @@ -1231,8 +1233,8 @@ def aot_export_joint_with_descriptors( mod, args, kwargs, - boxed_nop_preserve_node_meta, - boxed_nop_preserve_node_meta, + fw_compiler, + bw_compiler, default_partition, decompositions, keep_inference_input_mutations, From 3461988a4b09aaba582297128ba05b9a42264a06 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 6 Aug 2025 05:02:31 +0000 Subject: [PATCH 0036/1424] [audio hash update] update the pinned audio hash (#159823) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159823 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 70e9da5216ae2..5e75486031249 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -9b57c7bd5ad4db093c5bb31c802df9f04d933ac9 +6fbc710b617f79b992ef2ebc7f95e818aa390293 From d0fccbc99c6dc7e4d8733005e1a35610e2c5aa43 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 5 Aug 2025 14:25:28 -0700 Subject: [PATCH 0037/1424] [CI] Delete sm86 tests from pull (#159903) And delete sm89+cuda12.4 builds from periodic (as sm86+legacy driver should be enough) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159903 Approved by: https://github.com/huydhn --- .github/workflows/periodic.yml | 31 ------------------------- .github/workflows/pull.yml | 42 +++++----------------------------- 2 files changed, 6 insertions(+), 67 deletions(-) diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 976fb241c99f9..7d43c68c61b04 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -51,37 +51,6 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-jammy-cuda12_4-py3_10-gcc11-sm89-build: - name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11 - cuda-arch-list: 8.9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_4-py3_10-gcc11-sm89-test: - name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_4-py3_10-gcc11-sm89-build - - target-determination - with: - build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 - docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-cuda12_4-py3_10-gcc11-build: name: linux-jammy-cuda12.4-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 519a1a870b16f..061586437a1a9 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -292,13 +292,14 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: 8.9 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, ]} secrets: inherit @@ -402,37 +403,6 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc11-sm89-build: - name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - cuda-arch-list: 8.9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc11-sm89-test: - name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_8-py3_10-gcc11-sm89-build - - target-determination - with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-py3-clang12-executorch-build: if: false # Docker build needs pin update name: linux-jammy-py3-clang12-executorch From 2457e62c90a53e28293d9ebd5983bb58b463d1ee Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 6 Aug 2025 05:30:20 +0000 Subject: [PATCH 0038/1424] Revert "Set PYTHONHOME for inductor subprocesses using torch (#159382)" This reverts commit fe8984a9f43bde10d1956abe7cb40710ed7ceed2. Reverted https://github.com/pytorch/pytorch/pull/159382 on behalf of https://github.com/malfet due to Broke MacOS testing see https://hud.pytorch.org/hud/pytorch/pytorch/d0fccbc99c6dc7e4d8733005e1a35610e2c5aa43/1?per_page=50&name_filter=macos ([comment](https://github.com/pytorch/pytorch/pull/159382#issuecomment-3157455367)) --- torch/_inductor/autotune_process.py | 3 --- torch/_inductor/compile_worker/subproc_pool.py | 3 --- torch/_inductor/cpu_vec_isa.py | 4 ---- 3 files changed, 10 deletions(-) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index c3d4b6af651dc..c936fbe92c671 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -12,7 +12,6 @@ import selectors import subprocess import sys -import sysconfig import time import warnings from collections.abc import Iterable, Sequence @@ -129,8 +128,6 @@ def start(self): "PYTHONPATH": os.environ.get( "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) ), - # Need to set this for internal builds that bundle the runtime. - "PYTHONHOME": sysconfig.get_path("data"), # We shouldn't be using the Triton async compile subprocess pool, # but as a precaution set the env var that disables its creation. "TORCH_WARM_POOL": "0", diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 80e7e75898cbf..0b670b268b37e 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -8,7 +8,6 @@ import struct import subprocess import sys -import sysconfig import threading import traceback import typing @@ -159,8 +158,6 @@ def __init__( "PYTHONPATH": os.environ.get( "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) ), - # Need to set this for internal builds that bundle the runtime. - "PYTHONHOME": sysconfig.get_path("data"), # Safeguard against creating a SubprocPool in the subprocess. "TORCH_WARM_POOL": "0", # Some internal usages need a modified LD_LIBRARY_PATH. diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 71a27e99628db..b077c4da9c28d 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -6,7 +6,6 @@ import re import subprocess import sys -import sysconfig import warnings from typing import Any, Callable, Union @@ -134,12 +133,9 @@ def check_build(self, code: str) -> bool: stderr=subprocess.DEVNULL, env={ **os.environ, - # We need to set the PYTHONPATH so the subprocess can find torch. "PYTHONPATH": os.environ.get( "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) ), - # Need to set this for internal builds that bundle the runtime. - "PYTHONHOME": sysconfig.get_path("data"), }, ) except Exception: From e9d27aa8fd5aa4f9dc08b13ede6f91cc8831207b Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Wed, 6 Aug 2025 06:03:58 +0000 Subject: [PATCH 0039/1424] [CUDA 13] CMake/Dependencies: no need to call find_package(CUB) (#159854) CUB library is the part of CCCL of the CUDA Toolkit 13. If CUDA Found, CUB is found as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159854 Approved by: https://github.com/eqy --- cmake/Dependencies.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 3b4b6adac94b1..0501e00c08664 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1143,7 +1143,7 @@ if(USE_UCC) endif() # ---[ CUB -if(USE_CUDA) +if(USE_CUDA AND CUDA_VERSION VERSION_LESS 13.0) find_package(CUB) if(NOT CUB_FOUND) message(FATAL_ERROR "Cannot find CUB.") From 1690c0c3a047253d4e401ab2b0233bbf3039571c Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 6 Aug 2025 07:36:37 +0000 Subject: [PATCH 0040/1424] [Reland] Migrate ScalarType to headeronly (#159911) The non ghstack version of #159416, to make sure we don't get reverted again Pull Request resolved: https://github.com/pytorch/pytorch/pull/159911 Approved by: https://github.com/mikaylagawarecki --- c10/core/ScalarType.h | 76 +----------------- test/cpp/aoti_abi_check/test_dtype.cpp | 58 ++++++++++++++ torch/header_only_apis.txt | 5 ++ torch/headeronly/CMakeLists.txt | 1 + torch/headeronly/core/ScalarType.h | 103 +++++++++++++++++++++++++ torch/headeronly/ovrsource_defs.bzl | 1 + 6 files changed, 171 insertions(+), 73 deletions(-) create mode 100644 torch/headeronly/core/ScalarType.h diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 3d8a2b0074e9e..4a15eb23ac63c 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -19,25 +19,16 @@ #include #include -#include #include #include #include #include -namespace c10 { - -// dummy struct for uint1 to uint7, actual functionality -// of these dtypes will be implemented in python with Tensor subclass -template -struct dummy_uint1_7_t {}; +#include -// dummy struct for int1 to int7, actual functionality -// of these dtypes will be implemented in python with Tensor subclass -template -struct dummy_int1_7_t {}; +namespace c10 { -// For the macros below: +// [dtype Macros note] For the macros below: // // For users: If you want to macro some code for all non-QInt scalar types // (i.e. types with complete information, you probably want one of the @@ -57,56 +48,6 @@ struct dummy_int1_7_t {}; // some old PRs where we added new dtypes (check history of this file) can // help give you an idea where to start. -// NB: Order matters for this macro; it is relied upon in -// _promoteTypesLookup and the serialization format. -#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ - _(uint8_t, Byte) /* 0 */ \ - _(int8_t, Char) /* 1 */ \ - _(int16_t, Short) /* 2 */ \ - _(int, Int) /* 3 */ \ - _(int64_t, Long) /* 4 */ \ - _(at::Half, Half) /* 5 */ \ - _(float, Float) /* 6 */ \ - _(double, Double) /* 7 */ \ - _(c10::complex, ComplexHalf) /* 8 */ \ - _(c10::complex, ComplexFloat) /* 9 */ \ - _(c10::complex, ComplexDouble) /* 10 */ \ - _(bool, Bool) /* 11 */ \ - _(c10::qint8, QInt8) /* 12 */ \ - _(c10::quint8, QUInt8) /* 13 */ \ - _(c10::qint32, QInt32) /* 14 */ \ - _(at::BFloat16, BFloat16) /* 15 */ \ - _(c10::quint4x2, QUInt4x2) /* 16 */ \ - _(c10::quint2x4, QUInt2x4) /* 17 */ \ - _(c10::bits1x8, Bits1x8) /* 18 */ \ - _(c10::bits2x4, Bits2x4) /* 19 */ \ - _(c10::bits4x2, Bits4x2) /* 20 */ \ - _(c10::bits8, Bits8) /* 21 */ \ - _(c10::bits16, Bits16) /* 22 */ \ - _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ - _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ - _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ - _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ - _(uint16_t, UInt16) /* 27 */ \ - _(uint32_t, UInt32) /* 28 */ \ - _(uint64_t, UInt64) /* 29 */ \ - _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ - _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ - _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ - _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ - _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ - _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ - _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ - _(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ - _(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ - _(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ - _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ - _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ - _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ - _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ - _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ - _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ - // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). But beware: convert() // doesn't work for all the conversions you need... @@ -152,17 +93,6 @@ struct dummy_int1_7_t {}; _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ _(at::Float8_e8m0fnu, Float8_e8m0fnu) -enum class ScalarType : int8_t { -#define DEFINE_ST_ENUM_VAL_(_1, n) n, - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) -#undef DEFINE_ENUM_ST_ENUM_VAL_ - Undefined, - NumOptions -}; - -constexpr uint16_t NumScalarTypes = - static_cast(ScalarType::NumOptions); - namespace impl { // These are used to map ScalarTypes to C++ types. diff --git a/test/cpp/aoti_abi_check/test_dtype.cpp b/test/cpp/aoti_abi_check/test_dtype.cpp index d019b4144a9d0..e6e7e75867c8d 100644 --- a/test/cpp/aoti_abi_check/test_dtype.cpp +++ b/test/cpp/aoti_abi_check/test_dtype.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -149,3 +150,60 @@ TEST(TestDtype, TestQuintsQintsAndBits) { auto i = torch::headeronly::bits8(2); auto j = torch::headeronly::bits16(6); } + +TEST(TestDtype, TestScalarType) { + using torch::headeronly::ScalarType; + constexpr ScalarType expected_scalar_types[] = { + ScalarType::Byte, + ScalarType::Char, + ScalarType::Short, + ScalarType::Int, + ScalarType::Long, + ScalarType::Half, + ScalarType::Float, + ScalarType::Double, + ScalarType::ComplexHalf, + ScalarType::ComplexFloat, + ScalarType::ComplexDouble, + ScalarType::Bool, + ScalarType::QInt8, + ScalarType::QUInt8, + ScalarType::QInt32, + ScalarType::BFloat16, + ScalarType::QUInt4x2, + ScalarType::QUInt2x4, + ScalarType::Bits1x8, + ScalarType::Bits2x4, + ScalarType::Bits4x2, + ScalarType::Bits8, + ScalarType::Bits16, + ScalarType::Float8_e5m2, + ScalarType::Float8_e4m3fn, + ScalarType::Float8_e5m2fnuz, + ScalarType::Float8_e4m3fnuz, + ScalarType::UInt16, + ScalarType::UInt32, + ScalarType::UInt64, + ScalarType::UInt1, + ScalarType::UInt2, + ScalarType::UInt3, + ScalarType::UInt4, + ScalarType::UInt5, + ScalarType::UInt6, + ScalarType::UInt7, + ScalarType::Int1, + ScalarType::Int2, + ScalarType::Int3, + ScalarType::Int4, + ScalarType::Int5, + ScalarType::Int6, + ScalarType::Int7, + ScalarType::Float8_e8m0fnu, + ScalarType::Float4_e2m1fn_x2, + ScalarType::Undefined, + }; + for (int8_t i = 0; i < static_cast(torch::headeronly::NumScalarTypes); + i++) { + EXPECT_EQ(static_cast(i), expected_scalar_types[i]); + } +} diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 72a1b46fb37e8..4cfeeb6238ad5 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -94,3 +94,8 @@ bits2x4 bits4x2 bits8 bits16 + +# torch/headeronly/core/ScalarType.h +NumScalarTypes +ScalarType +# dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType diff --git a/torch/headeronly/CMakeLists.txt b/torch/headeronly/CMakeLists.txt index 3b8f0d5466de0..93d2d7802b528 100644 --- a/torch/headeronly/CMakeLists.txt +++ b/torch/headeronly/CMakeLists.txt @@ -20,6 +20,7 @@ configure_file( file(GLOB HEADERONLY_HEADERS *.h + core/**/*.h cpu/**/*.h macros/*.h util/*.h diff --git a/torch/headeronly/core/ScalarType.h b/torch/headeronly/core/ScalarType.h new file mode 100644 index 0000000000000..0e426427997b3 --- /dev/null +++ b/torch/headeronly/core/ScalarType.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace c10 { + +// dummy struct for uint1 to uint7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_uint1_7_t {}; + +// dummy struct for int1 to int7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_int1_7_t {}; + +// See [dtype Macros note] in c10/core/ScalarType.h regarding macros + +// NB: Order matters for this macro; it is relied upon in +// _promoteTypesLookup and the serialization format. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(at::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(c10::complex, ComplexHalf) /* 8 */ \ + _(c10::complex, ComplexFloat) /* 9 */ \ + _(c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(c10::qint8, QInt8) /* 12 */ \ + _(c10::quint8, QUInt8) /* 13 */ \ + _(c10::qint32, QInt32) /* 14 */ \ + _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2) /* 16 */ \ + _(c10::quint2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16) /* 22 */ \ + _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ \ + _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ + _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ + _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ + _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ + _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ + _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ + _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ + _(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ + _(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ + _(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ + _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ + _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ + _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ + _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ + _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ + _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ + +enum class ScalarType : int8_t { +#define DEFINE_ST_ENUM_VAL_(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ + Undefined, + NumOptions +}; + +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); + +} // namespace c10 + +namespace torch::headeronly { +using c10::dummy_int1_7_t; +using c10::dummy_uint1_7_t; +using c10::NumScalarTypes; +using c10::ScalarType; +} // namespace torch::headeronly diff --git a/torch/headeronly/ovrsource_defs.bzl b/torch/headeronly/ovrsource_defs.bzl index c590f388ffb0e..3c3030c048b11 100644 --- a/torch/headeronly/ovrsource_defs.bzl +++ b/torch/headeronly/ovrsource_defs.bzl @@ -29,6 +29,7 @@ def define_torch_headeronly_ovrsource(name, is_mobile): public_include_directories = ["../.."], public_preprocessor_flags = pp_flags, public_raw_headers = native.glob([ + "core/**/*.h", "cpu/**/*.h", "macros/*.h", "util/*.h", From abfe4039811a28bae8c4e87abfdbaf576505b662 Mon Sep 17 00:00:00 2001 From: Mengtian Xu Date: Wed, 6 Aug 2025 07:39:39 +0000 Subject: [PATCH 0041/1424] [AIDIR] Internal util function to insert MLHub debugging insight for dynamic shape (#159391) Summary: This feature is Meta internal only Add a util function to put dynamic shape-related suggestion to MLHubDebugInsightService, which will then be surfaced to users in the MLHub . The rollout will be controlled by JK. Test Plan: MAST job aps-omnifmv3_dev_baseline_test-a34fdccf21 {F1980593060} * If you're not able to see the insight, please add yourself to this gk 'mlhub_debugging_insights_dev_visibility' * The URL link should route to a new Job Inspector page that will provide details and straight forward instructions of how to config the ds. The page is currently still in development so here we use the general PT2 compile JI page. * Test fails because of the export checks. I'll export after addressing all the comments from reviewers. Rollback Plan: Reviewed By: pianpwk Differential Revision: D78526522 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159391 Approved by: https://github.com/jingsh --- torch/_dynamo/pgo.py | 11 +++++++++++ torch/_utils_internal.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 403187bc6bde8..5e12e0dc36a80 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -173,6 +173,7 @@ class CodeState: _INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None _CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None +_LOGGED_DYNAMIC_ALLOWLIST: bool = False @dataclasses.dataclass(frozen=True) @@ -616,6 +617,7 @@ def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]: def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: + global _LOGGED_DYNAMIC_ALLOWLIST code_id = CodeId.make(f_code) frame_state = get_code_state()[code_id] frame_whitelist = ",".join(_collect_dynamic_sources(frame_state)) @@ -624,6 +626,15 @@ def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: CompileEventLogger.pt2_compile( name, recompile_dynamic_whitelist=frame_whitelist ) + if not _LOGGED_DYNAMIC_ALLOWLIST: + torch._utils_internal.add_mlhub_insight( + category="dynamic_shapes_analysis", + insight="Dynamic shapes detected", + insight_description="PGO detected a recompilation due to dynamic shapes. \ + Please follow the instruction from the action link to reduce shape recompilations.", + ) + # add mlhub insight only once per job + _LOGGED_DYNAMIC_ALLOWLIST = True def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 8c448adb0c6a0..4def85ec63a72 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -117,6 +117,10 @@ def signpost_event(category: str, name: str, parameters: dict[str, Any]): log.info("%s %s: %r", category, name, parameters) +def add_mlhub_insight(category: str, insight: str, insight_description: str): + pass + + def log_compilation_event(metrics): log.info("%s", metrics) From 0495cab545e0004672fa0e1fbe4cc3ffcf543a16 Mon Sep 17 00:00:00 2001 From: Colin L Reliability Rice Date: Wed, 6 Aug 2025 07:39:47 +0000 Subject: [PATCH 0042/1424] Wire in pt2_triton_builds (#159897) Summary: This allows us to start seeing the failure rate on these models (and potentially alert on it). Test Plan: ``` FORCE_LOG_TRITON_BUILDS_TO_PROD=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run @//mode/opt :compile 2>&1 | tee out ``` P1889607054 Waiting for scuba table to generate, but manual logging show it should show up at https://fburl.com/scuba/pt2_triton_builds_inc_archive/7852kt8h soon. Rollback Plan: Reviewed By: masnesral Differential Revision: D79308333 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159897 Approved by: https://github.com/masnesral --- torch/_inductor/async_compile.py | 40 ++++++++++++++---------- torch/_inductor/runtime/compile_tasks.py | 25 ++++++++++----- torch/_utils_internal.py | 4 +++ 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 0a12356de6701..b238383069233 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -49,6 +49,7 @@ ) from torch._inductor.utils import clear_on_fresh_cache from torch._inductor.virtualized import V +from torch._utils_internal import log_triton_builds from torch.hub import _Faketqdm, tqdm from torch.utils._ordered_set import OrderedSet from torch.utils._triton import has_triton_package @@ -479,22 +480,29 @@ def get_result() -> CachingAutotuner: log_waitcounter=True, waitcounter_name_override="compile_triton", ): - start_ns = time_ns() - _set_triton_ptxas_path() - kernel = load_kernel() - kernel.set_compile_info(compile_id, is_backward) - kernel.precompile( - warm_cache_only=False, - static_triton_bundle_key=CompiledTritonKernels.key(source_code), - ) - elapsed_us = (time_ns() - start_ns) // 1000 - get_metrics_context().add_top_n( - "triton_kernel_compile_times_us", kernel_name, elapsed_us - ) - info = kernel.autotune_cache_info or {} - info["compile_time_us"] = elapsed_us - _add_triton_kernel_info(kernel_name, info) - return kernel + fail = None + try: + start_ns = time_ns() + _set_triton_ptxas_path() + kernel = load_kernel() + kernel.set_compile_info(compile_id, is_backward) + kernel.precompile( + warm_cache_only=False, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), + ) + elapsed_us = (time_ns() - start_ns) // 1000 + get_metrics_context().add_top_n( + "triton_kernel_compile_times_us", kernel_name, elapsed_us + ) + info = kernel.autotune_cache_info or {} + info["compile_time_us"] = elapsed_us + _add_triton_kernel_info(kernel_name, info) + return kernel + except Exception as e: + fail = str(e) + raise + finally: + log_triton_builds(fail=fail) def multi_kernel(self, *args, **kwargs) -> Any: from torch._inductor.codegen.multi_kernel import MultiKernelCall diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 67140369faac4..850c7660d5d99 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -10,6 +10,8 @@ from types import ModuleType from typing import Any, Callable, TYPE_CHECKING +from torch._utils_internal import log_triton_builds + if TYPE_CHECKING: from torch._inductor.runtime.triton_heuristics import CachingAutotuner @@ -57,11 +59,18 @@ def _worker_compile_triton( from torch._inductor import config with config.patch(extra_config): - start_ns = time.time_ns() - kernel = load_kernel() - kernel.precompile(warm_cache_only=True) - elapsed_ns = time.time_ns() - start_ns - kernel.prepare_for_pickle() - # We can release this memory in the compile subprocesses: - linecache.clearcache() - return kernel, elapsed_ns // 1000 + fail = None + try: + start_ns = time.time_ns() + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + elapsed_ns = time.time_ns() - start_ns + kernel.prepare_for_pickle() + # We can release this memory in the compile subprocesses: + linecache.clearcache() + return kernel, elapsed_ns // 1000 + except Exception as e: + fail = str(e) + raise + finally: + log_triton_builds(fail=fail) diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 4def85ec63a72..f2613e734bbf8 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -354,3 +354,7 @@ def get_default_numa_options(): Must return None or NumaOptions, but not specifying to avoid circular import. """ return None + + +def log_triton_builds(fail: Optional[str]): + pass From dad2a05bec03ed1fef45b8e72de5cca1a5dd7eaa Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 5 Aug 2025 10:59:26 -0700 Subject: [PATCH 0043/1424] [DTensor] Set up DTensorContinuousTestBase (#159885) Also migrate `test_common_rules.py` since it was a short file `python test/distributed/tensor/test_common_rules.py` Before: Ran 10 tests in 91.516s After: Ran 10 tests in 5.604s Pull Request resolved: https://github.com/pytorch/pytorch/pull/159885 Approved by: https://github.com/ezyang --- test/distributed/tensor/test_common_rules.py | 43 +++++++------------ .../distributed/_tensor/common_dtensor.py | 18 ++++++++ 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/test/distributed/tensor/test_common_rules.py b/test/distributed/tensor/test_common_rules.py index b320f80fe03c6..3450f8faa2b5c 100644 --- a/test/distributed/tensor/test_common_rules.py +++ b/test/distributed/tensor/test_common_rules.py @@ -8,20 +8,17 @@ from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorContinuousTestBase, ) aten = torch.ops.aten -class CommonRulesTest(DTensorTestBase): - @property - def world_size(self) -> int: - # hard code world size to 4 as we need to test - # at least with 2d mesh - return 4 +class CommonRulesTest(DTensorContinuousTestBase): + # hard code world size to 4 as we need to test + # at least with 2d mesh + world_size = 4 def _gen_tensor_meta(self, shape): empty_tensor = torch.empty(shape) @@ -31,10 +28,9 @@ def _gen_tensor_meta(self, shape): empty_tensor.dtype, ) - @with_comms def test_einop_basic_propagation(self): # plain einsum, mm - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) mm_call = aten.mm.default # propagate col-wise sharding @@ -85,9 +81,8 @@ def test_einop_basic_propagation(self): self.assertIsNotNone(output_spec) self.assertTrue(output_spec.placements[0].is_partial()) - @with_comms def test_einop_pointwise_propagation(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) add_call = aten.add.Tensor # addition @@ -137,13 +132,12 @@ def test_einop_pointwise_propagation(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, -1, -1]) - @with_comms def test_einop_merge_sharding(self): # 2d mesh einop merge sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) mm_call = aten.mm.default @@ -163,12 +157,11 @@ def test_einop_merge_sharding(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, 1]) - @with_comms def test_einop_linearity(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) mm_call = aten.mm.default @@ -231,11 +224,10 @@ def test_einop_linearity(self): # mat2 mesh dim 1 should become partial now! self.assertTrue(mat2_spec.placements[1].is_partial()) - @with_comms def test_einop_multi_sharding_on_mesh_dim(self): # einop prop with multi sharding on same mesh dim mesh_shape = torch.arange(self.world_size) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) mm_call = aten.mm.default mat1, mat2 = [0, -1], [0, -1] @@ -260,12 +252,11 @@ def test_einop_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1]) - @with_comms def test_einop_errors(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) add_call = aten.add.Tensor mat1, mat2 = [0, -1], [1, -1] @@ -281,9 +272,8 @@ def test_einop_errors(self): with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"): einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {})) - @with_comms def test_pointwise_rules_broadcasting(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) where_call = aten.where.self inp1, inp2, inp3 = [0], [], [-1, -1] @@ -307,9 +297,8 @@ def test_pointwise_rules_broadcasting(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [-1, 0]) - @with_comms def test_pointwise_rules_suggestion(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) lerp_call = aten.lerp.Scalar # propagate point-wise sharding @@ -335,13 +324,12 @@ def test_pointwise_rules_suggestion(self): self.assertEqual(len(schema_suggestion.args_schema), 3) self.assertEqual(schema_suggestion.args_schema[2], -1) - @with_comms def test_pointwise_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) add_call = aten.add.Tensor @@ -381,13 +369,12 @@ def test_pointwise_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2) - @with_comms def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = DeviceMesh(self.device_type(), mesh_shape) add_call = aten.add_.Tensor diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 32fdcce997eca..f3a72441f3704 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -31,6 +31,7 @@ SequenceParallel, ) from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, MultiProcessTestCase, MultiThreadedTestCase, run_subtests, @@ -41,6 +42,8 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec +DEVICE_COUNT: int + if TEST_CUDA: DEVICE_TYPE = "cuda" PG_BACKEND = "nccl" @@ -334,6 +337,21 @@ def skip_unless_torch_gpu(method: T) -> T: return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method)) +class DTensorContinuousTestBase(MultiProcContinousTest): + @classmethod + def device_type(cls) -> str: + # if enough GPU/XPU/HPU we can use those devices, otherwise we fallback to CPU + if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < cls.world_size: + return "cpu" + else: + return DEVICE_TYPE + + @classmethod + def backend_str(cls) -> str: + backend = dist.get_default_backend_for_device(DEVICE_TYPE) + return backend + + class DTensorTestBase(MultiProcessTestCase): @property def world_size(self) -> int: From e7feedf6a9bb346ad205796aa4084c8dcfb18072 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 5 Aug 2025 08:26:59 -0700 Subject: [PATCH 0044/1424] Replace C array with std::array in formatSockAddr (#159812) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159812 Approved by: https://github.com/Skylion007 --- torch/csrc/distributed/c10d/socket.cpp | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index b23722ec384ab..f64d6ec20aa02 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -199,12 +200,18 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { // job, logging IP addresses instead. See // https://github.com/pytorch/pytorch/issues/159007 static bool disable_getnameinfo = false; - - char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT + std::array host{}; + std::array port{}; if (!disable_getnameinfo) { int err = ::getnameinfo( - addr, len, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV); + addr, + len, + host.data(), + NI_MAXHOST, + port.data(), + NI_MAXSERV, + NI_NUMERICSERV); if (err != 0) { C10D_WARNING( "The hostname of the client socket cannot be retrieved. err={}", err); @@ -221,17 +228,17 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { // if we can't resolve the hostname, display the IP address if (addr->sa_family == AF_INET) { struct sockaddr_in* psai = (struct sockaddr_in*)&addr; - // NOLINTNEXTLINE(*array*) - char ip[INET_ADDRSTRLEN]; - if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != + std::array ip{}; + if (inet_ntop( + addr->sa_family, &(psai->sin_addr), ip.data(), INET_ADDRSTRLEN) != nullptr) { return fmt::format("{}:{}", ip, psai->sin_port); } } else if (addr->sa_family == AF_INET6) { struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; - // NOLINTNEXTLINE(*array*) - char ip[INET6_ADDRSTRLEN]; - if (inet_ntop(addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != + std::array ip{}; + if (inet_ntop( + addr->sa_family, &(psai->sin6_addr), ip.data(), INET6_ADDRSTRLEN) != nullptr) { return fmt::format("[{}]:{}", ip, psai->sin6_port); } From 23cf24103963adce84b2b4c027053fec0b29ad94 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 5 Aug 2025 15:51:27 -0700 Subject: [PATCH 0045/1424] [aoti][mps] Initialize mps kernels first (#159753) In some cases we have mps kernels which are reused across higher-order-op subgraphs and the toplevel code. However, currently we initialize the variable for the mps kernel the first time we use it, which runs into an issue if we run into the mps kernel within a subgraph since the kernel will only be initialized within the subgraph scope. For instance: ``` if ... auto mps_lib_0_func = ... mps_lib_0_func->run() // since we already used mps_lib_0 once, we don't re-initialize it mps_lib_0_func->run() // error, mps_lib_0_func not initialized ``` So the solution we took here is to initialize all the kernels at the beginning: ``` const std::shared_ptr get_mps_lib_0() { static const auto func = mps_lib_0.getKernelFunction("generated_kernel"); return func; } AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() { static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get()); return handle; } ... if ... get_mps_lib_0()->run() get_mps_lib_0()->run() // success ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159753 Approved by: https://github.com/malfet ghstack dependencies: #159456, #159695 --- test/inductor/test_aot_inductor.py | 6 -- torch/_inductor/codegen/cpp_wrapper_cpu.py | 5 ++ torch/_inductor/codegen/cpp_wrapper_mps.py | 92 ++++++++++++++++------ torch/_inductor/codegen/mps.py | 6 +- 4 files changed, 72 insertions(+), 37 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 9b501315cd9c2..ac3529679e351 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6918,12 +6918,6 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_fp8_view_of_param": fail_mps(), # cannot initialize a parameter of type 'double' with an rvalue of type 'std::nullptr_t' "test_fallback_kernel_with_symexpr_output": fail_mps(), - # while-loop subgraph calls same kernel as outside. need to figure out how to - # either (1) tell outside to initialize a new kernel or (2) generate - # subgraph as a separate function, which would(?) cause (1) to happen automatically. - "test_while_loop_nested": fail_mps(), - "test_cond_with_parameters": fail_mps(), - "test_cond_share_predicte": fail_mps(), # correctness issue "test_index_put_with_none_index": fail_mps(), # Error device may not be nil diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 6047ea916fb17..ebef59717f133 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -518,6 +518,8 @@ def gen_check(handle_kind, idx, name, tensor): def write_wrapper_decl(self): inputs_len = len(V.graph.graph_inputs.keys()) if V.graph.aot_mode: + self.codegen_additional_funcs() + if V.graph.const_module: self.header.splice(V.graph.const_module.wrapper_code.header) @@ -674,6 +676,9 @@ def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name): f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));" ) + def codegen_additional_funcs(self): + pass + def codegen_model_kernels(self): self.prefix.writeline("namespace {") diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index b953927f52be1..aea4470f1c964 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -9,7 +9,7 @@ from ..virtualized import V from .cpp_wrapper_cpu import CppWrapperCpu from .cpp_wrapper_gpu import CppWrapperGpu -from .wrapper import PythonWrapperCodegen +from .wrapper import KernelCallLine, PythonWrapperCodegen class CppWrapperMps(CppWrapperGpu): @@ -47,14 +47,12 @@ def _generate_kernel_call_helper( """ Generates MPS kernel call code. It should look something like: ``` - auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel"); - auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get()); - mps_lib_0_func->runCommandBlock([&] { - mps_lib_0_func->startEncoding(); - aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0); - aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1); + get_mps_lib_0()->runCommandBlock([&] { + get_mps_lib_0()->startEncoding(); + aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 0, buf0); + aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 1, arg0_1); ... - mps_lib_0_func->dispatch(9); + get_mps_lib_0()->dispatch(9); }); ``` """ @@ -81,11 +79,11 @@ def _generate_kernel_call_helper( for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): if isinstance(arg_type, torch.dtype): new_args.append( - f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});" + f"aoti_torch_mps_set_arg_tensor(get_{kernel_name}_handle(), {idx}, {arg});" ) elif arg_type in (int, sympy.core.symbol.Symbol): new_args.append( - f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});" + f"aoti_torch_mps_set_arg_int(get_{kernel_name}_handle(), {idx}, {arg});" ) else: raise NotImplementedError( @@ -96,9 +94,11 @@ def _generate_kernel_call_helper( if threads is None: raise NotImplementedError("No threads or group_size provided") elif group_size is None: - new_args.append(f"{kernel_name}->dispatch({threads});\n") + new_args.append(f"get_{kernel_name}()->dispatch({threads});\n") else: - new_args.append(f"{kernel_name}->dispatch({threads}, {group_size});\n") + new_args.append( + f"get_{kernel_name}()->dispatch({threads}, {group_size});\n" + ) # debug printer related logic for cpp kernel type. debug_printer_manager = V.graph.wrapper_code.debug_printer @@ -113,20 +113,11 @@ def _generate_kernel_call_helper( self.write_mps_kernel_call(kernel_name, new_args) def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: - # Only add handle definition if the kernel is not already used - lib_name = name[: -len("_func")] - if name not in self._used_kernel_names: - self._used_kernel_names.add(name) - - self.writeline( - f'auto {name} = {lib_name}.getKernelFunction("generated_kernel");' - ) - self.writeline( - f"auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());" - ) - - self.writeline(f"{name}->runCommandBlock([&] {{") - self.writeline(f" {name}->startEncoding();") + # Initialization of the kernel function and kernel function handle + # variables have already been done at the beginning, which was + # codegen-ed in `codegen_mps_func_init` + self.writeline(f"get_{name}()->runCommandBlock([&] {{") + self.writeline(f" get_{name}()->startEncoding();") for call_arg in call_args: self.writeline(f" {call_arg}") self.writeline("});") @@ -138,3 +129,52 @@ def get_device_include_path(device: str) -> str: "#include \n" "#include " ) + + def codegen_additional_funcs(self) -> None: + """ + We want to codegen the mps kernel function variable initializations + ahead of time. This is so that if we reuse kernels within subgraphs, we + don't need to worry about the scope in which we're initializing the + variables. Instead we will just initialize the variables all at the top + level. + + The kernel function variable initializations should look something like: + ``` + const std::shared_ptr get_mps_lib_0() { + static const auto func = mps_lib_0.getKernelFunction("generated_kernel"); + return func; + } + AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() { + static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get()); + return handle; + } + ``` + """ + + for line in self.lines: + if not isinstance(line, KernelCallLine): + continue + if line.device.type != "mps": + continue + + # Only add handle definition once + if line.kernel_name not in self._used_kernel_names: + self._used_kernel_names.add(line.kernel_name) + + self.prefix.writeline( + f"const std::shared_ptr get_{line.kernel_name}() {{" + ) + self.prefix.writeline( + f' static const auto func = {line.kernel_name}.getKernelFunction("generated_kernel");' + ) + self.prefix.writeline(" return func;") + self.prefix.writeline("}") + + self.prefix.writeline( + f"AOTIMetalKernelFunctionHandle get_{line.kernel_name}_handle() {{" + ) + self.prefix.writeline( + f" static const auto handle = AOTIMetalKernelFunctionHandle(get_{line.kernel_name}().get());" + ) + self.prefix.writeline(" return handle;") + self.prefix.writeline("}") diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index d952a45d0b5a1..8b59db126f05d 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -1052,11 +1052,7 @@ def define_kernel( # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" - if V.graph.cpp_wrapper: - kernel_name = f"{mps_lib_name}_func" - else: - kernel_name = f"{mps_lib_name}" - + kernel_name = f"{mps_lib_name}" wrapper.src_to_kernel[src_code] = kernel_name if V.graph.cpp_wrapper: From 98316e589672c96d4f63d1355abdbe050b843ee8 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Wed, 6 Aug 2025 10:28:05 +0000 Subject: [PATCH 0046/1424] [WOQ] Add CUDA kernel for _weight_int8pack_mm (#159325) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary** This issue proposes implementing a CUDA kernel for aten._weight_int8pack_mm, a weight-only quantized (WOQ) linear operation that is currently only supported on CPU. On CUDA, the fallback path uses an unfused .mul().sum() pattern in quantization.py, which is less efficient for inference. https://github.com/pytorch/pytorch/issues/158849 **Motivation** A fused GPU kernel for aten._weight_int8pack_mm would: - Eliminate reliance on the .mul().sum() fallback in quantization.py - Improve performance for quantized inference on CUDA - Extend Inductor’s GPU quantization support across more workloads **Implementation** - Implement a Triton kernel for: ``` out[b, n] = sum_k(x[b, k] * w[n, k]) * scale[n] where: x: [B, K] float32 w: [N, K] int8 scale: [N] float32 out: [B, N] float32 ``` - Integrate the kernel with register_woq_mm_ops() in torch/_inductor/quantized_lowerings.py - Route it conditionally in quantization.py where GPU currently falls back to .mul().sum() - Add unit tests comparing results to the reference fallback path Test Plan: ``` buck2 run 'fbcode//mode/opt' :linalg test_linalg.TestLinalgCUDA.test__int8_mm_m_64_k_64_n_64_compile_True_slice_True_cuda ``` Log: P1882799769 ``` buck2 test 'fbcode//mode/opt' caffe2/test:linalg ``` https://www.internalfb.com/intern/testinfra/testconsole/testrun/6755399722424741/ Benchmark Results: ``` **[Shape B=256, K=1024, N=512]** CPU and CUDA outputs match Max abs diff: 2.59e-04, max rel diff: 0.75 CPU: 144.14 ms, CUDA: 303.67 µs Speedup: ×474.6 **[Shape B=512, K=2048, N=1024]** CPU and CUDA outputs match Max abs diff: 5.49e-04, max rel diff: 0.15 CPU: 1173.27 ms, CUDA: 2.40 ms Speedup: ×488.5 ``` Rollback Plan: Differential Revision: D79042656 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159325 Approved by: https://github.com/danielvegamyhre, https://github.com/jerryzh168 --- aten/src/ATen/native/cuda/int8mm.cu | 74 +++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 1 + test/test_linalg.py | 2 +- .../aoti_torch/generated/c_shim_cuda.h | 1 + 4 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 aten/src/ATen/native/cuda/int8mm.cu diff --git a/aten/src/ATen/native/cuda/int8mm.cu b/aten/src/ATen/native/cuda/int8mm.cu new file mode 100644 index 0000000000000..60f64cd9fc203 --- /dev/null +++ b/aten/src/ATen/native/cuda/int8mm.cu @@ -0,0 +1,74 @@ +#include +#include +#include +#include + +namespace at::native { + +__global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) { + // one thread per output element: [B, N] + int b = blockIdx.y * blockDim.y + threadIdx.y; + int n = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B || n >= N) return; + + float acc = 0.0f; + for (int k = 0; k < K; ++k) { + acc += x[b * K + k] * static_cast(w[n * K + k]); + } + + out[b * N + n] = acc * scale[n]; +} + +void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) { + const int B = x.size(0); + const int K = x.size(1); + const int N = w_int8.size(0); + + const dim3 block(16, 16); + const dim3 grid((N + block.x - 1) / block.x, (B + block.y - 1) / block.y); + + auto stream = at::cuda::getCurrentCUDAStream(); + + weight_int8pack_mm_kernel<<>>( + x.data_ptr(), + w_int8.data_ptr(), + scale.data_ptr(), + out.data_ptr(), + B, K, N); +} + + +// Main GPU entry point +at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) { + // --- Check inputs --- + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor"); + TORCH_CHECK(scale.is_cuda(), "scale must be a CUDA tensor"); + + TORCH_CHECK(x.dim() == 2, "x must be 2D"); + TORCH_CHECK(w_int8.dim() == 2, "w must be 2D"); + TORCH_CHECK(scale.dim() == 1, "scale must be 1D"); + + TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)"); + TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)"); + + // --- Determine shapes --- + auto B = x.size(0); // batch size + auto N = w_int8.size(0); // output dim + + // Ensure inputs are in the correct types for the kernel + auto x_f32 = x.to(at::kFloat); + auto w_int8_contiguous = w_int8.contiguous(); + auto scale_f32 = scale.to(at::kFloat); + + // --- Allocate output --- + auto out = at::empty({B, N}, x.options().dtype(at::kFloat)); + + // --- Launch kernel --- + launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out); + + return out; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index db8eef9349642..8920864b3a719 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4230,6 +4230,7 @@ - func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor dispatch: CPU: _weight_int8pack_mm_cpu + CUDA: _weight_int8pack_mm_cuda MPS: _weight_int8pack_mm_mps - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor diff --git a/test/test_linalg.py b/test/test_linalg.py index f1c8bf5918517..ac668fee049d2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -7765,7 +7765,7 @@ def dyn_quant_matmul_4bit( all_elements_within_threshold, "Some elements have error >= 0.06" ) - @onlyCPU + @onlyNativeDeviceTypes @parametrize("m", [32, 64]) @parametrize("k", [32, 64]) @parametrize("n", [48, 64]) diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 92d30ded855f8..470919cf389c3 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -51,6 +51,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTenso AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_abs(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); From c03a734ba182f46414df4320349417d2c82b1fa9 Mon Sep 17 00:00:00 2001 From: can-gaa-hou Date: Wed, 6 Aug 2025 10:35:10 +0000 Subject: [PATCH 0047/1424] [OpenReg] Disable automatic inclusion of data files (#159845) # Background After I built torch_openreg, I noticed that the wheel package contained the stub.c file under the csrc directory, which was not used in the runtime. # Motivation This PR aims to remove the stub.c file and any unused file when running torch_openreg. **Changes:** - Setting **include_package_data** keyword to false in the setup function Pull Request resolved: https://github.com/pytorch/pytorch/pull/159845 Approved by: https://github.com/albanD --- .../open_registration_extension/torch_openreg/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py index 07d31e73d76ba..386e34cdb56f6 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py @@ -85,6 +85,7 @@ def main(): cmdclass={ "clean": BuildClean, # type: ignore[misc] }, + include_package_data=False, ) From 2231c3ca3a25529115610d8215ee5601c4c8ee89 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 6 Aug 2025 14:44:37 +0000 Subject: [PATCH 0048/1424] [CI][CD] Fix `install_nvshem` function (#159907) When one builds CD docker, all CUDA dependencies must be installed into `/usr/local/cuda/` folder Test plan: Looks at the binary build logs, for example [here](https://github.com/pytorch/pytorch/actions/runs/16768141521/job/47477380147?pr=159907): ``` 2025-08-06T05:58:00.7347471Z -- NVSHMEM_HOME set to: '' 2025-08-06T05:58:00.7348378Z -- NVSHMEM wheel installed at: '' 2025-08-06T05:58:00.7392528Z -- NVSHMEM_HOST_LIB: '/usr/local/cuda/lib64/libnvshmem_host.so' 2025-08-06T05:58:00.7393251Z -- NVSHMEM_DEVICE_LIB: '/usr/local/cuda/lib64/libnvshmem_device.a' 2025-08-06T05:58:00.7393792Z -- NVSHMEM_INCLUDE_DIR: '/usr/local/cuda/include' 2025-08-06T05:58:00.7394252Z -- NVSHMEM found, building with NVSHMEM support ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159907 Approved by: https://github.com/Skylion007, https://github.com/ngimel --- .ci/docker/common/install_cuda.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index c8a780f65c8e5..ebebd195d6b70 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -68,8 +68,8 @@ function install_nvshmem { # download, unpack, install wget -q "${url}" tar xf "${filename}.tar.gz" - cp -a "libnvshmem/include/"* /usr/local/include/ - cp -a "libnvshmem/lib/"* /usr/local/lib/ + cp -a "libnvshmem/include/"* /usr/local/cuda/include/ + cp -a "libnvshmem/lib/"* /usr/local/cuda/lib64/ # cleanup cd .. From 2855688a1dbe29fd2ce40747530ea4042d5be6d8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 6 Aug 2025 14:55:48 +0000 Subject: [PATCH 0049/1424] Revert "Replace C array with std::array in formatSockAddr (#159812)" This reverts commit e7feedf6a9bb346ad205796aa4084c8dcfb18072. Reverted https://github.com/pytorch/pytorch/pull/159812 on behalf of https://github.com/malfet due to Looks like it broke distribtued tests, see https://hud.pytorch.org/hud/pytorch/pytorch/2231c3ca3a25529115610d8215ee5601c4c8ee89/1?per_page=50&name_filter=distributed ([comment](https://github.com/pytorch/pytorch/pull/159812#issuecomment-3160513656)) --- torch/csrc/distributed/c10d/socket.cpp | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index f64d6ec20aa02..b23722ec384ab 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -7,7 +7,6 @@ #include #include -#include #include #include #include @@ -200,18 +199,12 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { // job, logging IP addresses instead. See // https://github.com/pytorch/pytorch/issues/159007 static bool disable_getnameinfo = false; - std::array host{}; - std::array port{}; + + char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT if (!disable_getnameinfo) { int err = ::getnameinfo( - addr, - len, - host.data(), - NI_MAXHOST, - port.data(), - NI_MAXSERV, - NI_NUMERICSERV); + addr, len, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV); if (err != 0) { C10D_WARNING( "The hostname of the client socket cannot be retrieved. err={}", err); @@ -228,17 +221,17 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { // if we can't resolve the hostname, display the IP address if (addr->sa_family == AF_INET) { struct sockaddr_in* psai = (struct sockaddr_in*)&addr; - std::array ip{}; - if (inet_ntop( - addr->sa_family, &(psai->sin_addr), ip.data(), INET_ADDRSTRLEN) != + // NOLINTNEXTLINE(*array*) + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != nullptr) { return fmt::format("{}:{}", ip, psai->sin_port); } } else if (addr->sa_family == AF_INET6) { struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; - std::array ip{}; - if (inet_ntop( - addr->sa_family, &(psai->sin6_addr), ip.data(), INET6_ADDRSTRLEN) != + // NOLINTNEXTLINE(*array*) + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != nullptr) { return fmt::format("[{}]:{}", ip, psai->sin6_port); } From 79eca4677b8ca536cea370c48a4752d5e6e37066 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Wed, 6 Aug 2025 15:00:28 +0000 Subject: [PATCH 0050/1424] [precompile] Skip serializing unnecesssary objects for guards. (#158926) Summary: The following type of objects don't need to be serialized for precompile: 1. PyCapsule because we don't guard on C binding objects in meaningful ways. 2. Code object because we only id matching on these but id matches will always be dropped for precompile. 3. Nested function objects since we also ban CLOSURE_MATCH. Test Plan: buck run mode/opt test/dynamo:test_dynamo -- -k test_skipped_objects Rollback Plan: Differential Revision: D78816888 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158926 Approved by: https://github.com/jamesjwu --- test/dynamo/test_guard_serialization.py | 21 ++++++++++++++++++ torch/_dynamo/guards.py | 29 +++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 10808c922b3fb..969460364630e 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -1325,6 +1325,27 @@ def getattr_new(*args, **kwargs): finally: builtins_dict["getattr"] = getattr_original + def test_skipped_objects(self): + def foo(): + pass + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.code = foo.__code__ + self.foo = foo + self.p = torch.nn.Parameter(torch.randn(3, 2)) + + def forward(self, x): + z = x + 1 + for p in self.parameters(): + z += p + return z + + m = Module() + ref, loaded = self._test_serialization("TENSOR_MATCH", m, torch.randn(3, 2)) + self._test_check_fn(ref, loaded, {"self": m, "x": torch.randn(3, 2)}, True) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 50220f3e23299..2d5d0af995b59 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2885,6 +2885,10 @@ class GuardsState: shape_code_parts: Optional[ShapeCodeParts] +class _Missing: + pass + + class GuardsStatePickler(pickle.Pickler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2944,6 +2948,10 @@ def _unpickle_functorch_interpreter(cls, json: bytes): def _unpickle_mapping_proxy(cls, d): return types.MappingProxyType(d) + @classmethod + def _unpickle_c_op(cls, name): + return getattr(torch.ops._C, name) + def reducer_override(self, obj): import sympy @@ -3008,6 +3016,27 @@ def reducer_override(self, obj): elif isinstance(obj, types.MappingProxyType): return type(self)._unpickle_mapping_proxy, (obj.copy(),) + elif isinstance( + obj, torch._ops.OpOverloadPacket + ) and obj._qualified_op_name.startswith("_C::"): + return type(self)._unpickle_c_op, (obj.__name__,) + + elif ( + obj.__class__.__module__ == "builtins" + and obj.__class__.__name__ == "PyCapsule" + ): + # Skipping PyCapsule since there isn't much to be guarded about them. + return _Missing, () + + elif isinstance(obj, types.CodeType): + # We only do ID_MATCH on code objects which is already banned from guards serialization. + return _Missing, () + + elif inspect.isfunction(obj) and (obj.__code__.co_flags & inspect.CO_NESTED): + # Skipping nested function since CLOSURE_MATCH is banned from guards serialization. + assert obj.__qualname__ != obj.__name__ + return _Missing, () + if type(obj).__qualname__ != type(obj).__name__: raise torch._dynamo.exc.PackageError( f"Type {type(obj)} for object {obj} cannot be saved " From d87161c3c8f117ae3393990dabba087a5e8687bf Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 5 Aug 2025 14:30:21 -0700 Subject: [PATCH 0051/1424] [Easy] Fix wrong propagation of fallback_ops_dict in gen_aoti_c_shim (#159904) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159904 Approved by: https://github.com/janeyx99 --- torchgen/gen_aoti_c_shim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 655f2bd65b02d..36db26bb5ea67 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -744,7 +744,7 @@ def headers_for_aoti() -> str: f"c_shim_{device_name}.cpp", lambda: gen_aoti_c_shim( fallback_native_functions, - inductor_fallback_ops, + fallback_ops_dict, structured_func_group_dict, dispatch_key, backend_indices, From a4b07fe8f6f053cf13df928f14613c22b5f128f0 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 5 Aug 2025 21:12:49 -0700 Subject: [PATCH 0052/1424] [AOTI] Add more default options to compile_standalone (#158560) Summary: When compiling for standalone, make embed_kernel_binary and emit_multi_arch_kernel default to True, and add a default name for model_name_for_generated_files to make the generated cpp project easier to understand. Also improved the weights object file naming to be more readable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158560 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 12 +++- test/inductor/test_aot_inductor_package.py | 36 ++++++++++ torch/_inductor/codecache.py | 21 +++--- torch/_inductor/codegen/cpp_wrapper_cpu.py | 18 +++-- torch/_inductor/codegen/triton.py | 5 ++ torch/_inductor/config.py | 8 ++- torch/_inductor/cpp_builder.py | 83 +++++++++++++++++----- torch/_inductor/utils.py | 42 ++++++----- torch/export/experimental/_utils.py | 7 +- 9 files changed, 173 insertions(+), 59 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index ac3529679e351..de8a34809bd14 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6833,11 +6833,21 @@ def test_compile_standalone_sets_package_cpp(self): result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True}) self.assertEqual(result["aot_inductor.package_cpp_only"], True) self.assertEqual(result["aot_inductor.compile_standalone"], True) + self.assertEqual(result["aot_inductor.embed_kernel_binary"], True) + self.assertEqual( + result["aot_inductor.emit_multi_arch_kernel"], not torch.version.hip + ) + self.assertEqual( + result["aot_inductor.model_name_for_generated_files"], "aoti_model" + ) - def test_compile_standalone_package_cpp_already_true(self): + def test_compile_standalone_explicit_set(self): patches = { "aot_inductor.compile_standalone": True, "aot_inductor.package_cpp_only": True, + "aot_inductor.embed_kernel_binary": True, + "aot_inductor.emit_multi_arch_kernel": not torch.version.hip, + "aot_inductor.model_name_for_generated_files": "aoti_model", } result = maybe_aoti_standalone_config(patches) self.assertEqual(result, patches) diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 51343b6b1883e..2809f5533bd9c 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -15,6 +15,7 @@ from parameterized import parameterized_class import torch +import torch._inductor.config from torch._inductor.codecache import get_kernel_bin_format from torch._inductor.package import load_package, package_aoti from torch._inductor.test_case import TestCase @@ -363,6 +364,7 @@ def forward(self, x, y): ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfXpu # build system may be different + @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_after_package_static(self): # compile_standalone will set package_cpp_only=True self.check_package_cpp_only() @@ -419,12 +421,46 @@ def forward(self, x, y): with self.assertRaisesRegex(Exception, "Invalid AOTI model name"): self.cmake_compile(model, example_inputs, options, "") + @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") + @skipIfXpu # build system may be different + @torch._inductor.config.patch("test_configs.use_libtorch", True) + def test_compile_standalone_cos(self): + # compile_standalone will set package_cpp_only=True + self.check_package_cpp_only() + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return torch.cos(x) + + with torch.no_grad(): + example_inputs = (torch.randn(8, 32, device=self.device),) + model = Model().to(device=self.device) + + # Test compilation when model name is passed in + options = { + "aot_inductor.compile_standalone": True, + "aot_inductor.model_name_for_generated_files": "cos", + } + with ( + tempfile.TemporaryDirectory() as tmp_dir, + ): + build_path, _ = self.cmake_compile( + model, example_inputs, options, tmp_dir + ) + # Check if the .a file was build successfully + a_path = build_path / "libcos.a" + self.assertTrue(a_path.exists()) + @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary + @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_with_exporter(self): self.check_package_cpp_only() diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 451f72f621691..e404cd78936f0 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1711,12 +1711,6 @@ def compile( wrapper_code = "\n".join((wrapper_code, kernel_code)) kernel_code = "" - from .utils import aoti_model_name_from_config - - model_class_name = "" - if config.aot_inductor.compile_standalone: - model_class_name = aoti_model_name_from_config() - wrapper_key, wrapper_path = write( wrapper_code, "wrapper.cpp", @@ -1749,6 +1743,8 @@ def compile( "model.h", ) ) as f: + # model_name_for_generated_files is guaranteed to be non-empty when compile_standalone + model_class_name = config.aot_inductor.model_name_for_generated_files class_name = f"AOTInductorModel{model_class_name}" header_code = f.read() @@ -1763,7 +1759,7 @@ def compile( header_code, "h", specified_dir=specified_output_path, - key=f"{model_class_name}", + key=model_class_name, ) # Log the AOTInductor wrapper and kernel code, if needed. @@ -1888,7 +1884,7 @@ def format_consts_to_gnu_asm( consts_asm += f"\t.space {len(consts) - 8}\n" consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" - return consts_asm, "S" + return consts_asm, "weights.S" # Use c++ to convert consts to object file can support more compilers, such as msvc and icx. def format_consts_to_cpp( @@ -1913,7 +1909,7 @@ def format_consts_to_cpp( const_cpp += "\t\n" const_cpp += "};\t\n" const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" - return const_cpp, "cpp" + return const_cpp, "weights.cpp" def get_zero_consts_asm_code( align_bytes: int, @@ -1979,6 +1975,7 @@ def get_zero_consts_asm_code( consts_code, code_ext, specified_dir=str(specified_sub_dir), + key=config.aot_inductor.model_name_for_generated_files, ) consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( @@ -2279,7 +2276,13 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: asm_files = [] if not _IS_WINDOWS: ld, objcopy = get_ld_and_objcopy(use_relative_path) + kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) for kernel_name, value in CudaKernelParamCache.cache.items(): + if kernel_name not in kernels: + # It is possible that CudaKernelParamCache contains more Triton kernels + # than what the current graph uses + continue + if asm_file := value["asm"]: asm_files.append(asm_file) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index ebef59717f133..473b405100745 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -22,13 +22,7 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config, cpp_builder, ir -from ..utils import ( - _align, - aoti_model_name_from_config, - DeferredLineBase, - LineContext, - normalize_name, -) +from ..utils import _align, DeferredLineBase, LineContext, normalize_name from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import get_device_op_overrides, IndentedBuffer, Kernel @@ -64,11 +58,15 @@ def __init__(self): self.device = "cpu" # must be initialized prior to calling super().__init__() self.included_devices: OrderedSet[str] = OrderedSet() - self.model_class_name_suffix = "" - if config.aot_inductor.compile_standalone: - self.model_class_name_suffix = aoti_model_name_from_config() + self.model_class_name_suffix = ( + config.aot_inductor.model_name_for_generated_files + if config.aot_inductor.compile_standalone + else "" + ) self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}" + super().__init__() + self.declare = "auto " self.declare_maybe_reference = "decltype(auto) " self.ending = ";" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 49e10d7c05127..56be9dace0926 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -4483,6 +4483,11 @@ def define_kernel(self, src_code, node_schedule, kernel): kernel_name = "_".join( ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] ) + if config.aot_inductor.model_name_for_generated_files: + # When AOTI compiles multiple submodules, we need to use the model name to + # distinguish kernel related symbols. + kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}" + # use the original src_code as the key wrapper.src_to_kernel[src_code] = kernel_name subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c6971301efe6c..51a438840b040 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1471,12 +1471,12 @@ class aot_inductor: precompile_headers: bool = not is_fbcode() # Embed generated kernel binary files into model.so - embed_kernel_binary: bool = False + embed_kernel_binary: Optional[bool] = None # Generate kernel files that support multiple archs # For CUDA, this means generating fatbin files for kernels, and the fatbin files # contains PTX and SASS for the current architecture. - emit_multi_arch_kernel: bool = False + emit_multi_arch_kernel: Optional[bool] = None # If not None, the generated files with use this name in file stem. # If None, we will use a hash to name files. @@ -1869,6 +1869,10 @@ class test_configs: track_memory_lifecycle: Optional[Literal["assert", "log"]] = None + # If set to True, AOTI-generated CMakelists.txt will still use libtorch + # for unit testing + use_libtorch = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index b6a0e7aeef2ab..44efd8088c73a 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -28,7 +28,6 @@ from torch._inductor import config, exc from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA from torch._inductor.runtime.runtime_utils import cache_dir -from torch._inductor.utils import aoti_model_name_from_config from torch.torch_version import TorchVersion @@ -1545,7 +1544,9 @@ def __init__( self._aot_mode: bool = False self._name = name - self._target_name = aoti_model_name_from_config() + self._target_name = ( + config.aot_inductor.model_name_for_generated_files or "aoti_model" + ) # Code start here, initial self internal variables firstly. self._build_option = BuildOption @@ -1781,22 +1782,54 @@ def save_compile_cmd_to_cmake( project({self._target_name} LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) - # May need to point CMAKE_PREFIX_PATH to the right torch location - find_package(Torch REQUIRED) - - # Set a shared library target + # Set a library target add_library({self._target_name} {target_library_type}) - # Add macro definitions - target_compile_definitions({self._target_name} PRIVATE {definitions}) - - # Add compile flags - target_compile_options({self._target_name} PRIVATE {self._cflags_args}) - # Backend specific flags - target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) - """ ) + + if ( + not config.aot_inductor.compile_standalone + or config.test_configs.use_libtorch + ): + # When compile_standalone is True, the generated cpp project should + # not use Torch. But for unit testing purpose, we need to use Torch here. + contents += textwrap.dedent( + """ + # May need to point CMAKE_PREFIX_PATH to the right torch location + find_package(Torch REQUIRED) + + """ + ) + # flags and macros here are mostly CPU specific. Not emitting them for GPU models + # will make the generated CMake file more portable and won't really hurt performance. + # NOTE: standalone focuses on GPU now. For CPU, some of the flags and macros may + # be still needed. + contents += textwrap.dedent( + f""" + # Add macro definitions + target_compile_definitions({self._target_name} PRIVATE {definitions}) + + # Add compile flags + target_compile_options({self._target_name} PRIVATE {self._cflags_args}) + + # Backend-specific flags + target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) + + """ + ) + else: + # When compile_standalone is True, use TorchStandalone instead of Torch + contents += textwrap.dedent( + f""" + find_package(TorchStandalone REQUIRED) + # Set up include directories to find headers at the correct paths + target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}}) + target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}}/standalone) + + """ + ) + if device_type == "cuda" and torch.version.hip is None: from torch._inductor.codecache import _nvcc_arch_as_compile_option @@ -1804,7 +1837,11 @@ def save_compile_cmd_to_cmake( contents += textwrap.dedent( f""" enable_language(CUDA) + set(CMAKE_CUDA_STANDARD 17) find_package(CUDAToolkit REQUIRED) + target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}}) + target_compile_definitions({self._target_name} PRIVATE USE_CUDA) + target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static) find_program(OBJCOPY_EXECUTABLE objcopy) if(NOT OBJCOPY_EXECUTABLE) @@ -1833,7 +1870,7 @@ def save_compile_cmd_to_cmake( add_custom_command( OUTPUT ${{FATBIN_FILE}} COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} - -gencode arch=compute_80,code=compute_80 + -gencode arch=compute_{current_arch},code=compute_{current_arch} -gencode arch=compute_{current_arch},code=sm_{current_arch} DEPENDS ${{PTX_FILE}} ) @@ -1882,12 +1919,20 @@ def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> Non """ ) f.write(contents) - f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") - f.write( - f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" - ) + if asm_files: + f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") + f.write( + f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" + ) def save_link_cmd_to_cmake(self, cmake_path: str) -> None: + if ( + config.aot_inductor.compile_standalone + and not config.test_configs.use_libtorch + ): + # When compile_standalone is True, do not link with libtorch + return + lflags = " ".join(self._build_option.get_ldflags()) libs = " ".join(self._build_option.get_libraries()) contents = textwrap.dedent( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 74df1cd732490..4cc6e2c566545 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3427,20 +3427,36 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An Returns: dict[str, Any]: The possibly-updated `config_patches` dictionary. """ + + def patch_config( + config_patches: dict[str, Any], config_name: str, config_value: Any + ) -> None: + value = config_patches.get(config_name, getattr(config, config_name)) + if value is None: + config_patches[config_name] = config_value + elif not value and value != config_value: + raise RuntimeError( + f"Invalid config: {config_name}={config_value} when aot_inductor.compile_standalone is True." + ) + compile_standalone = config_patches.get( "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone ) + # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing + config_patches = config_patches.copy() if compile_standalone: - package_cpp_only = config_patches.get( - "aot_inductor.package_cpp_only", config.aot_inductor.package_cpp_only + # Standlaone AOTInductor means only generate cpp project for building a standalone binary + patch_config(config_patches, "aot_inductor.package_cpp_only", True) + # Standlaone AOTInductor needs to embed the kernel code in the binary + patch_config(config_patches, "aot_inductor.embed_kernel_binary", True) + # Default to use multi-arch kernel codegen for non-rocm GPU + patch_config( + config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip ) - if package_cpp_only is None: - config_patches = {**config_patches, "aot_inductor.package_cpp_only": True} - elif not package_cpp_only: - raise RuntimeError( - "compile_standalone=True requires package_cpp_only=True. " - "Please set aot_inductor.package_cpp_only=True in your inductor config." - ) + patch_config( + config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model" + ) + return config_patches @@ -3471,14 +3487,6 @@ def is_valid_aoti_model_name() -> bool: return True -def aoti_model_name_from_config() -> str: - from torch._inductor import config - - model_name = config.aot_inductor.model_name_for_generated_files - model_name = "aoti_model" if model_name is None else model_name - return model_name - - def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: if unbacked_only: return free_unbacked_symbols(x) diff --git a/torch/export/experimental/_utils.py b/torch/export/experimental/_utils.py index b91dfbb0db802..910c45c2ceb9d 100644 --- a/torch/export/experimental/_utils.py +++ b/torch/export/experimental/_utils.py @@ -184,9 +184,14 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str "", "set(CMAKE_CXX_STANDARD 17)", "", - "find_package(Torch REQUIRED)", ] ) + + from torch._inductor.config import test_configs + + if test_configs.use_libtorch: + ib.writeline("find_package(Torch REQUIRED)") + if cuda: ib.writeline("find_package(CUDA REQUIRED)") From 4c01991b386e7b56da59f5cc68c2edd400a28871 Mon Sep 17 00:00:00 2001 From: Meet Vadakkanchery Date: Wed, 6 Aug 2025 16:52:03 +0000 Subject: [PATCH 0053/1424] [DCP][Prototype] Checkpoint replication via PGTransport (#157963) (#159801) Summary: ### PR Context Introduce simple replication logic via PGTransport. The goal is to showcase a working prototype of replication via PGTransport, in this impl we assume world_sizes are equal allowing us to create perfect bi-directional pairs for the purpose of choosing replica "partners". Test Plan: CI Rollback Plan: Differential Revision: D79590797 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159801 Approved by: https://github.com/saumishr --- .../checkpoint/test_state_dict_stager.py | 531 +++++++++++++++++- torch/distributed/checkpoint/staging.py | 151 ++++- 2 files changed, 680 insertions(+), 2 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict_stager.py b/test/distributed/checkpoint/test_state_dict_stager.py index 86a952e0701d2..8134472f52d5c 100644 --- a/test/distributed/checkpoint/test_state_dict_stager.py +++ b/test/distributed/checkpoint/test_state_dict_stager.py @@ -1,12 +1,23 @@ # Owner(s): ["oncall: distributed"] import dataclasses +import os +import tempfile +from datetime import timedelta import torch import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard as ShardedTensorShard, + ShardedTensor, + ShardMetadata, +) from torch.distributed._tensor import DTensor -from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed.checkpoint._state_dict_stager import StateDictStager +from torch.distributed.checkpoint.staging import _ReplicationStager +from torch.distributed.tensor import DeviceMesh, distribute_tensor from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -818,5 +829,523 @@ def test_dtensor(self): self.assertEqual(cpu_state_dict["dtensor"].size(), dtensor.size()) +class TestReplicationStager(DTensorTestBase): + """ + Test suite for _ReplicationStager functionality. + Tests replication of state_dict across training ranks using CPU tensors only. + """ + + @property + def backend(self) -> str: + return "cpu:gloo,cuda:nccl" + + def _create_simple_state_dict(self, rank: int) -> dict: + """ + Create a simple state_dict with CPU tensors, deterministically unique per rank. + + Args: + rank: The rank number to create unique tensors for + + Returns: + dict: A state dictionary with CPU tensors + """ + # Create unique tensors for each rank + torch.manual_seed(42 + rank) # Different seed per rank + + return { + "layer1.weight": torch.randn(64, 128, device="cpu"), + "layer1.bias": torch.randn(64, device="cpu"), + "layer2.weight": torch.randn(32, 64, device="cpu"), + "layer2.bias": torch.randn(32, device="cpu"), + "nested": { + "param": torch.randn(16, 16, device="cpu"), + "buffer": torch.randn(8, device="cpu"), + }, + "scalar": torch.tensor(float(rank), device="cpu"), + } + + def _verify_simple_state_dict_replication( + self, replicated_dict: dict, rank: int, partner_rank: int + ): + """ + Verify that replication worked correctly. + + Args: + replicated_dict: The replicated state_dict received from partner + rank: Current rank + partner_rank: Partner rank we should have received from + """ + # Create expected state_dict (what partner rank would have created) + expected_dict = self._create_simple_state_dict(partner_rank) + + def compare_tensors(actual, expected, path=""): + if isinstance(actual, dict) and isinstance(expected, dict): + self.assertEqual( + actual.keys(), expected.keys(), f"Keys mismatch at {path}" + ) + for key in actual: + compare_tensors( + actual[key], expected[key], f"{path}.{key}" if path else key + ) + elif isinstance(actual, torch.Tensor) and isinstance( + expected, torch.Tensor + ): + self.assertEqual( + actual.device.type, "cpu", f"Tensor at {path} should be on CPU" + ) + self.assertEqual( + actual.shape, expected.shape, f"Shape mismatch at {path}" + ) + self.assertEqual( + actual.dtype, expected.dtype, f"Dtype mismatch at {path}" + ) + self.assertTrue( + torch.equal(actual, expected), f"Values mismatch at {path}" + ) + else: + self.assertEqual(actual, expected, f"Value mismatch at {path}") + + compare_tensors(replicated_dict, expected_dict) + + def _create_dtensor_state_dict(self, rank: int, device_mesh: DeviceMesh) -> dict: + """ + Create state_dict with DTensor and regular tensors for deterministic testing + due to DTensor Shard, Replicate placements. + + Args: + rank: Current rank + device_mesh: DeviceMesh for DTensor creation + + Returns: + dict: State dictionary with DTensors + """ + # Create a large global tensor with deterministic values + # Each position contains a unique value that encodes both position and rank info + global_size = 128 + global_tensor = torch.arange(0, global_size * 16, dtype=torch.float32).reshape( + global_size, 16 + ) + + # Create DTensor with Shard(0) - each rank gets different rows + sharded_dtensor = distribute_tensor(global_tensor, device_mesh, [Shard(0)]) + + # Create DTensor with Replicate() - all ranks have the same data + replicated_global = torch.full( + (8, 8), float(global_size * 100), dtype=torch.float32, device="cpu" + ) + replicated_dtensor = distribute_tensor( + replicated_global, device_mesh, [Replicate()] + ) + + return { + "sharded_param": sharded_dtensor, + "replicated_param": replicated_dtensor, + "rank_scalar": torch.tensor(float(rank), device="cpu"), + } + + def _verify_dtensor_replication( + self, replicated_dict: dict, rank: int, partner_rank: int + ): + """ + Verify DTensor replication accuracy by checking local shards and global reconstruction. + + Args: + replicated_dict: Replicated state_dict received from partner + rank: Current rank + partner_rank: Partner rank we should have received from + """ + # Verify sharded DTensor + if "sharded_param" in replicated_dict: + replicated_sharded = replicated_dict["sharded_param"] + self.assertIsInstance(replicated_sharded, DTensor, "Should receive DTensor") + + # Get local shard from replicated DTensor + replicated_local = replicated_sharded.to_local() + + # Create expected local shard (what partner rank would have) + expected_global = torch.arange(0, 128 * 16, dtype=torch.float32).reshape( + 128, 16 + ) + + # Calculate expected shard for this rank's position + world_size = dist.get_world_size() + shard_size = 128 // world_size + start_idx = partner_rank * shard_size + end_idx = (partner_rank + 1) * shard_size + expected_local = expected_global[start_idx:end_idx] + + self.assertTrue( + torch.equal(replicated_local, expected_local), + "Sharded DTensor value mismatch", + ) + + # Verify DTensor metadata is preserved + self.assertEqual( + replicated_sharded._spec.placements[0].__class__.__name__, + "Shard", + "DTensor should maintain Shard placement", + ) + + # Verify replicated DTensor + if "replicated_param" in replicated_dict: + replicated_replicated = replicated_dict["replicated_param"] + self.assertIsInstance( + replicated_replicated, DTensor, "Should receive DTensor" + ) + + # Get local data from replicated DTensor + replicated_local = replicated_replicated.to_local() + + # Expected value should be global_size * 100 + expected_value = float(128 * 100) + expected_tensor = torch.full( + (8, 8), expected_value, dtype=torch.float32, device="cpu" + ) + + self.assertTrue( + torch.equal(replicated_local, expected_tensor), + "Replicated DTensor value mismatch", + ) + + # Verify DTensor metadata is preserved + self.assertEqual( + replicated_replicated._spec.placements[0].__class__.__name__, + "Replicate", + "DTensor should maintain Replicate placement", + ) + + # Verify regular tensors + if "rank_scalar" in replicated_dict: + self.assertEqual( + replicated_dict["rank_scalar"].item(), + float(partner_rank), + f"Rank scalar should be {partner_rank}, got {replicated_dict['rank_scalar'].item()}", + ) + + def _create_sharded_tensor_state_dict(self, rank: int, world_size: int) -> dict: + """ + Create state_dict with ShardedTensor for deterministic testing. + + Args: + rank: Current rank + world_size: Total world size + + Returns: + dict: State dictionary with ShardedTensor + """ + # Create deterministic local shard for this rank + global_size = 64 + shard_size = global_size // world_size + start_idx = rank * shard_size + end_idx = (rank + 1) * shard_size + + # Create local tensor with deterministic values + local_tensor = torch.arange( + start_idx * 8, end_idx * 8, dtype=torch.float32, device="cpu" + ).reshape(shard_size, 8) + + # Create ShardedTensor using init_from_local_shards + sharded_tensor = init_from_local_shards( + [ + ShardedTensorShard( + tensor=local_tensor, + metadata=ShardMetadata( + shard_offsets=[start_idx, 0], + shard_sizes=[shard_size, 8], + placement=f"rank:{rank}/cpu", + ), + ) + ], + global_size, + 8, + ) + + return { + "sharded_tensor": sharded_tensor, + "rank_scalar": torch.tensor(float(rank), device="cpu"), + } + + def _verify_sharded_tensor_replication( + self, replicated_dict: dict, rank: int, partner_rank: int + ): + """ + Verify ShardedTensor replication accuracy by checking local shards and metadata. + + Args: + replicated_dict: Replicated state_dict received from partner + rank: Current rank + partner_rank: Partner rank we should have received from + """ + # Verify sharded tensor + if "sharded_tensor" in replicated_dict: + replicated_sharded = replicated_dict["sharded_tensor"] + self.assertIsInstance( + replicated_sharded, ShardedTensor, "Should receive ShardedTensor" + ) + + # Get local shard from replicated ShardedTensor + local_shards = replicated_sharded.local_shards() + self.assertEqual( + len(local_shards), 1, "Should have exactly one local shard" + ) + + local_shard = local_shards[0] + replicated_local = local_shard.tensor + + # Create expected local shard (what partner rank would have) + world_size = dist.get_world_size() + global_size = 64 + shard_size = global_size // world_size + start_idx = partner_rank * shard_size + end_idx = (partner_rank + 1) * shard_size + + expected_local = torch.arange( + start_idx * 8, end_idx * 8, dtype=torch.float32, device="cpu" + ).reshape(shard_size, 8) + + self.assertTrue( + torch.equal(replicated_local, expected_local), + "Sharded tensor value mismatch", + ) + + # Verify shard metadata is preserved + expected_metadata = ShardMetadata( + shard_offsets=[start_idx, 0], + shard_sizes=[shard_size, 8], + placement=f"rank:{partner_rank}/cpu", + ) + self.assertEqual( + local_shard.metadata.shard_offsets, + expected_metadata.shard_offsets, + "Shard offsets should match", + ) + self.assertEqual( + local_shard.metadata.shard_sizes, + expected_metadata.shard_sizes, + "Shard sizes should match", + ) + + # Verify regular tensors + if "rank_scalar" in replicated_dict: + self.assertEqual( + replicated_dict["rank_scalar"].item(), + float(partner_rank), + f"Rank scalar should be {partner_rank}, got {replicated_dict['rank_scalar'].item()}", + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_basic(self): + """Test basic replication functionality with world_size=16""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Create unique DTensor state_dict for this rank + state_dict = self._create_simple_state_dict(current_rank) + + # Initialize replication stager + stager = _ReplicationStager( + pg=dist.new_group(backend=dist.Backend.GLOO), + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + ) + + # Perform replication + replicated_dict = stager.stage(state_dict) + + # Calculate expected partner rank + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify DTensor replication + self._verify_simple_state_dict_replication( + replicated_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_dtensors(self): + """Test replication with DTensor and mixed tensor types""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Create CPU-based DeviceMesh for DTensor + device_mesh = DeviceMesh("cpu", list(range(world_size))) + + # Create DTensor state_dict which includes different tensor types + state_dict = self._create_dtensor_state_dict(current_rank, device_mesh) + + # Initialize replication stager + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + ) + + # Perform replication + result = stager.stage(state_dict) + + # Wait for completion + from concurrent.futures import Future + + if isinstance(result, Future): + replicated_dict = result.result() + else: + replicated_dict = result + + # Calculate expected partner + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify all DTensor types are correctly replicated + self._verify_dtensor_replication(replicated_dict, current_rank, partner_rank) + + # Clean up + stager.close() + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_sharded_tensors(self): + """Test replication with ShardedTensor and mixed tensor types""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Create ShardedTensor state_dict for this rank + state_dict = self._create_sharded_tensor_state_dict(current_rank, world_size) + + # Initialize replication stager + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + ) + + # Perform replication + result = stager.stage(state_dict) + + # Wait for completion + from concurrent.futures import Future + + if isinstance(result, Future): + replicated_dict = result.result() + else: + replicated_dict = result + + # Calculate expected partner + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify all ShardedTensor types are correctly replicated + self._verify_sharded_tensor_replication( + replicated_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + @with_comms + @skip_if_lt_x_gpu(4) + def test_replication_persistence(self): + """Test persistence functionality in _ReplicationStager""" + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Test 1: Default storage directory (auto-generated tempdir) + with tempfile.TemporaryDirectory() as _: + # Create state_dict for this rank + state_dict = self._create_simple_state_dict(current_rank) + + # Initialize stager with default storage_dir (None) + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + storage_dir=None, # Let it create its own tempdir + ) + + # Perform replication to trigger persistence + stager.stage(state_dict) + + # Calculate expected partner rank + partner_rank = (current_rank + world_size // 2) % world_size + + # Verify file was created with correct naming convention + expected_path = stager._get_persisted_path(current_rank, partner_rank) + + self.assertTrue( + os.path.exists(expected_path), + f"Persisted file should exist at {expected_path}", + ) + + # Verify the storage directory was created + self.assertTrue( + os.path.isdir(stager._storage_dir), "Storage directory should exist" + ) + self.assertTrue( + stager._storage_dir.startswith(tempfile.gettempdir()), + "Default storage directory should be in system temp directory", + ) + + # Load and verify the persisted state_dict matches the received one + loaded_state_dict = torch.load(expected_path) + self._verify_simple_state_dict_replication( + loaded_state_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + # Test 2: Custom storage directory + with tempfile.TemporaryDirectory() as custom_storage_dir: + # Create custom subdirectory + custom_subdir = os.path.join(custom_storage_dir, "custom_replication_test") + + # Create state_dict for this rank + state_dict = self._create_simple_state_dict(current_rank) + + # Initialize stager with custom storage_dir + stager = _ReplicationStager( + pg=dist.group.WORLD, + timeout=timedelta(seconds=30), + device=torch.device("cpu"), + storage_dir=custom_subdir, + ) + + # Perform replication to trigger persistence + stager.stage(state_dict) + + # Verify custom storage directory was created and used + self.assertEqual( + stager._storage_dir, + custom_subdir, + "Should use custom storage directory", + ) + self.assertTrue( + os.path.isdir(custom_subdir), + "Custom storage directory should be created", + ) + + # Verify file was created in custom directory + expected_path = stager._get_persisted_path(current_rank, partner_rank) + + self.assertTrue( + os.path.exists(expected_path), + f"Persisted file should exist in custom directory at {expected_path}", + ) + + # Load and verify the persisted state_dict + loaded_state_dict = torch.load(expected_path) + self._verify_simple_state_dict_replication( + loaded_state_dict, current_rank, partner_rank + ) + + # Clean up + stager.close() + + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index 9e1031c7fddae..e7acf4975173c 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -1,11 +1,17 @@ +import os +import tempfile from concurrent.futures import Future, ThreadPoolExecutor from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Optional, Union +from datetime import timedelta +from typing import Any, cast, Optional, Union from typing_extensions import deprecated, Protocol, runtime_checkable import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint._pg_transport import PGTransport from torch.distributed.checkpoint._state_dict_stager import StateDictStager from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE @@ -315,3 +321,146 @@ def synchronize_staging(self) -> None: def close(self) -> None: pass + + +class _ReplicationStager(AsyncStager): + """ + An AsyncStager implementation that replicates state_dict across training ranks + using PGTransport. + + Args: + pg: ProcessGroup for distributed communication + timeout: Timeout for communication operations + device: Device to use for tensor operations + storage_dir: Directory to store persisted state_dicts + + Warning: This is experimental and subject to change. + """ + + _synchronize_after_execute: bool = False + + def __init__( + self, + pg: ProcessGroup, + timeout: timedelta = timedelta(minutes=30), + device: torch.device = torch.device("cpu"), + storage_dir: Optional[str] = None, + ): + self._pg = pg + self._timeout = timeout + self._device = device + self._transport = PGTransport(pg, timeout, device, None) + + # Set up storage directory for persisting exchanged state_dicts + if storage_dir is None: + self._storage_dir = tempfile.mkdtemp(prefix="replication_stager_") + else: + self._storage_dir = storage_dir + os.makedirs(self._storage_dir, exist_ok=True) + + def stage( + self, state_dict: STATE_DICT_TYPE + ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + """ + Stage the state_dict by replicating it across ranks. Returns a state_dict representing + the received replica. + + Perform the actual replication logic. Creates bidirectional pairs where each rank exchanges + state_dict with its partner at (rank + world_size//2) % world_size. + Uses simple rank-based ordering to prevent deadlocks. + + Assumes world_size is always even. + """ + if not dist.is_initialized(): + return state_dict + + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Calculate partner rank using half-world offset + # creates bidirectional pairs for replication. + offset = world_size // 2 + partner_rank = (current_rank + offset) % world_size + + # Use simple rank-based ordering to prevent deadlocks. + # Lower-numbered rank sends first, higher-numbered rank receives first. + if current_rank < partner_rank: + # Send first, then receive + self._transport.send_checkpoint([partner_rank], state_dict) + received_state_dict = self._transport.recv_checkpoint(partner_rank) + else: + # Receive first, then send + received_state_dict = self._transport.recv_checkpoint(partner_rank) + self._transport.send_checkpoint([partner_rank], state_dict) + + # Persist the received state_dict for future discoverability + received_state_dict = cast(STATE_DICT_TYPE, received_state_dict) + self._persist_state_dict(received_state_dict, current_rank, partner_rank) + + return received_state_dict + + def _persist_state_dict( + self, state_dict: STATE_DICT_TYPE, current_rank: int, partner_rank: int + ) -> None: + """ + Persist the received state_dict to disk for future discoverability. + Only keeps one replica per rank, overwriting any previous replica. + Uses atomic write pattern (temp file + rename). + + Args: + state_dict: The state_dict received from partner rank + current_rank: Current rank that received the state_dict + partner_rank: Rank that sent the state_dict + """ + final_path = self._get_persisted_path(current_rank, partner_rank) + temp_path = final_path + ".tmp" + + try: + # Ensure parent directory exists and is writable + os.makedirs(os.path.dirname(final_path), exist_ok=True) + + # Write to temporary file with explicit flushing + with open(temp_path, "wb") as f: + torch.save(state_dict, f) + # Flush application buffers to OS buffers + f.flush() + # Force OS buffers to disk for durability + os.fsync(f.fileno()) + + # Atomic rename to final location + os.rename(temp_path, final_path) + except Exception as e: + # Clean up temp file if it exists + try: + if os.path.exists(temp_path): + os.remove(temp_path) + except Exception: + pass # Ignore cleanup errors + # Re-raise the original exception with more context + raise RuntimeError( + f"Failed to persist state_dict from rank {partner_rank} to rank {current_rank}: {e}" + ) from e + + def _get_persisted_path(self, current_rank: int, partner_rank: int) -> str: + """ + Get the file path where a state_dict would be persisted. + + Args: + current_rank: Current rank + + Returns: + File path for the persisted state_dict + """ + filename = f"rank_{current_rank}_replica_partner_{partner_rank}.pt" + return os.path.join(self._storage_dir, filename) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. + """ + + def close(self) -> None: + """ + Clean up resources. Persisted files are intentionally left for future discovery. + """ From d7a855d67d704d1c114aa285d946155958716511 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Wed, 6 Aug 2025 14:23:15 +0000 Subject: [PATCH 0054/1424] [async-TP] Make scaled-mm + reduce-scatter preserve alignment of scales (#159957) After https://github.com/pytorch/pytorch/pull/157905 started using cuBLAS for row-wise scaling on CUDA 12.9+, this broke some downstream tests for fp8 which were testing "odd" shapes. After checking in with the cuBLAS team this turned out to be due to the scale tensors' starting addresses not being aligned to 16 bytes. PyTorch storages are always aligned at 256 bytes, hence this came from a "slicing" of the scale tensor being done inside async-TP when chunking a matmul in order to overlap it with reduce-scatter. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159957 Approved by: https://github.com/vkuzo, https://github.com/danielvegamyhre --- torch/distributed/_symmetric_memory/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index d050c8b40c6c1..4b0e9acc19bd7 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -1270,6 +1270,11 @@ def _fused_scaled_matmul_reduce_scatter_impl( .flatten(0, -2) ) A_scale_shards = list(A_scale.chunk(group.size())) + # cuBLAS's row-wise kernel requires scales to be aligned to 16 bytes. + # When we slice them we might break this and need to reallocate them. + A_scale_shards = [ + t if t.data_ptr() % 16 == 0 else t.clone() for t in A_scale_shards + ] else: raise ValueError("A_scale cannot be none for scaled_mm") From c669b0ab87d9d4950e8031afc038b22ddfce3d9b Mon Sep 17 00:00:00 2001 From: Georgia Phillips Date: Wed, 6 Aug 2025 18:04:24 +0000 Subject: [PATCH 0055/1424] Fix execution frame cleanup logic (#158717) Summary: This fixes a bug in the execution fram cleanup logic - previously, whenever we hit the time interval to clear out the frames, we were removing any cached execution frames beyond the configured minimum number (frameEntry.used was unused). Instead, we only want to clear frames that were NOT USED in during the last time interval. This diff refactors the executor to have the correct logic. Test Plan: ``` buck2 test 'mode/dev-nosan' fbcode//sigmoid/inference/test_gpu:model_runner_test -- ModelRunnerTest.Basic_InterpreterCuda_Multithread_Cleanup --run-disabled --print-passing-details ``` Rollback Plan: Differential Revision: D78621408 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158717 Approved by: https://github.com/dolpm --- torch/nativert/detail/MPMCQueue.h | 9 ++ torch/nativert/executor/Executor.cpp | 122 +++++++++++---------------- torch/nativert/executor/Executor.h | 25 +----- 3 files changed, 60 insertions(+), 96 deletions(-) diff --git a/torch/nativert/detail/MPMCQueue.h b/torch/nativert/detail/MPMCQueue.h index 3b90503887bbb..8301ce3fdb4c5 100644 --- a/torch/nativert/detail/MPMCQueue.h +++ b/torch/nativert/detail/MPMCQueue.h @@ -55,6 +55,15 @@ class MPMCQueue { return true; } + /** + * Get the current size of the queue. + * @return The number of elements in the queue. + */ + size_t size() { + std::lock_guard lock(mutex_); + return storage_.size(); + } + private: std::mutex mutex_; std::deque storage_; diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp index 932972ae2b5bc..906a6ec327287 100644 --- a/torch/nativert/executor/Executor.cpp +++ b/torch/nativert/executor/Executor.cpp @@ -10,10 +10,6 @@ #include #include -// Maximum number of retries when trying to get a frame from -// clearedExecutionFrames_ -constexpr uint32_t kClearExecutionFrameRetries = 10; - namespace torch::nativert { Executor::Executor( @@ -29,7 +25,7 @@ Executor::Executor( ? std::optional(*graph_) : std::nullopt), executionFrames_(executorConfig_.maxNumConcurrentThreads), - clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads), + inactiveExecutionFrames_(executorConfig_.maxNumConcurrentThreads), numExecutionFrames_(0), lastClearedTimestamp_(getCurrentTimestampSeconds()) { if (weights) { @@ -193,34 +189,12 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() { std::shared_ptr weights; weights_.withLock([&](auto& w) { weights = w; }); - // First try to get a frame from clearedExecutionFrames_ if clearing is in - // progress - if (C10_UNLIKELY(clearingInProgress_)) { - ExecutionFrameEntry frameEntry; - uint32_t retry = 0; - while ( - retry < - kClearExecutionFrameRetries) { // Limit retries to avoid infinite loop - if (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) { - if (retry > 0) { - VLOG(1) << "Took " << retry - << " retries to pop from clearedExecutionFrames_"; - } - ExecutorFramePtr ptr{std::move(frameEntry.frame), *this}; - if (ptr->weightVersion() != weights->version()) { - ptr->setWeights(*weights); - } - return ptr; - } - retry++; - } - // If we couldn't get a frame from cleared pool after retries, move onto - // main pool - } - // Try to get a frame from the main pool or create a new one std::unique_ptr frame; - while (!executionFrames_.readIfNotEmpty(frame)) { + + // Try to get a frame from executionFrames_ or inactiveExecutionFrames_ + while (!executionFrames_.readIfNotEmpty(frame) && + !inactiveExecutionFrames_.readIfNotEmpty(frame)) { int64_t numFrames = numExecutionFrames_.load(); if (numFrames < executorConfig_.maxNumConcurrentThreads) { if (numExecutionFrames_.compare_exchange_strong( @@ -243,6 +217,7 @@ Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() { } void Executor::clearStaleExecutionFrames() { + LOG(INFO) << "Clearing stale execution frames"; if (!cleanupLock_.try_lock()) { // Another thread is already doing cleanup return; @@ -250,41 +225,48 @@ void Executor::clearStaleExecutionFrames() { // Update timestamp first to minimize contention lastClearedTimestamp_ = getCurrentTimestampSeconds(); - int numPopped = 0; + // Get the size of active execution frames queue directly + size_t activeFramesSize = executionFrames_.size(); + size_t inactiveFramesSize = inactiveExecutionFrames_.size(); + size_t total = activeFramesSize + inactiveFramesSize; + size_t numCleared = 0; std::unique_ptr frame; - // Move frames from executionFrames_ to clearedExecutionFrames_ - while (executionFrames_.readIfNotEmpty(frame)) { - ++numPopped; - // Keep the first popped entries up to minimum size - if (numPopped > executorConfig_.minNumExecutionFrames) { - // Discard stale frames - frame.reset(); - numExecutionFrames_ -= 1; - continue; - } + // If number of active frames is less than the configured min, then transfer + // the difference from inactive frames + size_t minFramesToKeep = std::min( + static_cast(executorConfig_.minNumExecutionFrames), total); + size_t framesToTransfer = + (minFramesToKeep - activeFramesSize) > minFramesToKeep + ? static_cast(0) + : minFramesToKeep - activeFramesSize; + ; + for (size_t i = 0; + i < framesToTransfer && inactiveExecutionFrames_.readIfNotEmpty(frame); + ++i) { + executionFrames_.writeIfNotFull(std::move(frame)); + } - ExecutionFrameEntry entry; - entry.used = false; - entry.frame = std::move(frame); - clearedExecutionFrames_.writeIfNotFull(std::move(entry)); - // Enable clients to pop from clearedExecutionFrames_ while clearing is in - // progress - clearingInProgress_ = true; + size_t newActiveFramesSize = executionFrames_.size(); + + // Clear remaining inactive frames (i.e. those that were not used in the last + // time interval) + while (inactiveExecutionFrames_.readIfNotEmpty(frame)) { + ++numCleared; + frame.reset(); + numExecutionFrames_ -= 1; } - uint32_t numPushed = 0; - ExecutionFrameEntry frameEntry; - // Move frames back from clearedExecutionFrames_ to executionFrames_ - while (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) { - ++numPushed; - executionFrames_.writeIfNotFull(std::move(frameEntry.frame)); - clearingInProgress_ = false; + // Move active frames to inactive so they are cleared next time if not used + // Check newActiveFramesSize > 0 to guuard against other threads adding + // frames to active queue during while loop + while (executionFrames_.readIfNotEmpty(frame) && newActiveFramesSize > 0) { + --newActiveFramesSize; + inactiveExecutionFrames_.writeIfNotFull(std::move(frame)); } - clearingInProgress_ = false; - VLOG(1) << "Cleared " << (numPopped - numPushed) << " out of " << numPopped - << " ExecutionFrame instances in the pool"; + LOG(INFO) << "Cleared " << numCleared << " out of " << total + << " ExecutionFrame instances in the pool"; cleanupLock_.unlock(); } @@ -292,6 +274,8 @@ void Executor::clearStaleExecutionFrames() { void Executor::returnExecutorFrameToPool( std::unique_ptr frame) { // Check if it's time to clean up stale frames + // TODO: consider moving cleanup to a dedicated thread so it does not impact + // p99 latency if (executorConfig_.doExecutionFrameCleanup && lastClearedTimestamp_ + executorConfig_.executionFramePoolCleanupIntervalSec < @@ -301,21 +285,11 @@ void Executor::returnExecutorFrameToPool( try { frame->destroyBorrowedIValues(); - - // Create an entry with used=true - if (C10_UNLIKELY(!clearingInProgress_)) { - TORCH_CHECK( - executionFrames_.writeIfNotFull(std::move(frame)), - "ExecutionFrame pool full"); - } else { - ExecutionFrameEntry frameEntry; - frameEntry.used = true; - frameEntry.frame = std::move(frame); - - TORCH_CHECK( - clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry)), - "Cleared ExecutionFrame pool full"); - } + // Always return to active execution frame pool, indicating that frame was + // used in the previous time interval + TORCH_CHECK( + executionFrames_.writeIfNotFull(std::move(frame)), + "ExecutionFrame pool full"); } catch (...) { sem_.release(); throw; diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h index 4f40946b4b428..64f2372b9e85b 100644 --- a/torch/nativert/executor/Executor.h +++ b/torch/nativert/executor/Executor.h @@ -122,7 +122,7 @@ class Executor { std::vector getDelegates(); // Get the number of execution frames in the pool - int getNumExecutionFrames() const { + auto getNumExecutionFrames() const { return numExecutionFrames_.load(); } @@ -149,25 +149,6 @@ class Executor { void clearStaleExecutionFrames(); private: - // Structure to track execution frame usage - struct ExecutionFrameEntry { - bool used{false}; - std::unique_ptr frame; - - // Add move constructor and assignment operator - ExecutionFrameEntry() = default; - ExecutionFrameEntry(ExecutionFrameEntry&& other) noexcept - : used(other.used), frame(std::move(other.frame)) {} - ExecutionFrameEntry& operator=(ExecutionFrameEntry&& other) noexcept { - used = other.used; - frame = std::move(other.frame); - return *this; - } - // Delete copy constructor and assignment operator - ExecutionFrameEntry(const ExecutionFrameEntry&) = delete; - ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete; - }; - void maybeRunConstantFolding(const std::shared_ptr& weights); void validateInputs(const std::vector& inputs) const; @@ -188,8 +169,8 @@ class Executor { c10::Semaphore sem_; torch::nativert::detail::MPMCQueue> executionFrames_; - torch::nativert::detail::MPMCQueue - clearedExecutionFrames_; + torch::nativert::detail::MPMCQueue> + inactiveExecutionFrames_; std::atomic_int64_t numExecutionFrames_; std::unique_ptr layoutPlanner_; From 44dd3684d287f0d010efded69b9736a5c0a2b2c2 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 4 Aug 2025 16:44:06 -0700 Subject: [PATCH 0056/1424] [AOTI] Fix memory leak from all_reduce (#159818) Summary: This PR solves two issues: 1. When lowering the all_reduce op, Inductor expects to convert it to the in-place version, all_reduce_, but it was calling ir._AllReduceKernel.create_inplace instead of ir._AllReduce_Kernel.create_inplace. This triggers a tricky bug in AOIT because it generates cpp call to the functional version aoti_torch_cpu__c10d_functional_all_reduce, but later corresponding wait operation will still wait on the input to aoti_torch_cpu__c10d_functional_all_reduce instead of the output from aoti_torch_cpu__c10d_functional_all_reduce. This causes unwaited tensor leading to memory leak. 2. Since AOTI generates the inplace version aoti_torch_cpu__c10d_functional_all_reduce_ now. The return tensor from aoti_torch_cpu__c10d_functional_all_reduce_ doesn't get used. It will be released when the program exists, so it's not a memory leak but it will unnecessarily hold that tensor which causes high memory water mark. This PR generates tensor delete operation right after calling aoti_torch_cpu__c10d_functional_all_reduce_. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159818 Approved by: https://github.com/henryhu6, https://github.com/yushangdi --- test/distributed/test_c10d_functional_native.py | 9 ++++++--- torch/_inductor/codegen/cpp_wrapper_cpu.py | 11 ++++++----- torch/_inductor/comm_lowering.py | 4 +++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 5c127634f122f..bafc781b591c6 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -827,9 +827,12 @@ def func(arg: torch.Tensor) -> torch.Tensor: with torch._inductor.config.patch({"cpp_wrapper": True}): code = run_and_get_triton_code(compiled, arg) - # Check the return tensor from wait_tensor is not used anywhere by - # checking if it is explicitly deleted by calling aoti_torch_delete_tensor_object - FileCheck().check_count("aoti_torch_delete_tensor_object(buf", 2).run(code) + # Check the return tensors from all_reduce and wait_tensor are not used anywhere by + # checking if they are explicitly deleted by calling aoti_torch_delete_tensor_object + FileCheck().check_not( + # all_reduce must have been rewritten into all_reduce_ + "aoti_torch_cpu__c10d_functional_all_reduce(buf" + ).check_count("aoti_torch_delete_tensor_object(buf", 4).run(code) # Test aoti AOTIRunnerUtil.run(func, (arg,)) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 473b405100745..6d11fe1c8be17 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1278,12 +1278,13 @@ def generate_c_shim_extern_kernel_alloc( extern_kernel.get_kernel_name(), args, device ) - if ( - extern_kernel.python_kernel_name - == "torch.ops._c10d_functional.wait_tensor.default" + if extern_kernel.python_kernel_name in ( + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.wait_tensor.default", ): - # wait_tensor returns its input, and the returned tensor is not used anywhere, - # so we can delete the returned AtenTensorHandle to reduce its lifetime. + # all_reduce_ is an inplace op and its returned tensor is not used anywhere. + # wait_tensor returns its input without any modification and the returned tensor is not used anywhere. + # In both cases, we can immediately delete the returned AtenTensorHandle to reduce its lifetime. self.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object({output_handle_name}));" ) diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index b748f61f067b9..e46909432f17e 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -209,7 +209,9 @@ def _all_reduce(inp: ir.TensorBox, reduce_op: str, group_name: str) -> ir.Tensor inp.realize() V.graph.no_fuse_buffer_names.add(inp.get_name()) inp = ir.ExternKernel.require_contiguous(inp) - ir._AllReduceKernel.create_inplace( + # Because we are lowering as inplace c10d.all_reduce_, we should generate + # _AllReduce_Kernel instead of _AllReduceKernel. + ir._AllReduce_Kernel.create_inplace( c10d.all_reduce_.default, inp, # type: ignore[arg-type] reduce_op, From ba37f589d49a64ba0f76c3e68052025250fa2998 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 6 Aug 2025 18:41:05 +0000 Subject: [PATCH 0057/1424] Revert "[dynamo] Be consistent with storing func source for UserMethodVariable (#159696)" This reverts commit ee62177c196d716fc3a2d641370bed8a673a45d3. Reverted https://github.com/pytorch/pytorch/pull/159696 on behalf of https://github.com/anijain2305 due to broke internal tests ([comment](https://github.com/pytorch/pytorch/pull/159696#issuecomment-3161196192)) --- torch/_dynamo/codegen.py | 6 +----- torch/_dynamo/variables/functions.py | 19 +++---------------- torch/_dynamo/variables/user_defined.py | 12 +----------- 3 files changed, 5 insertions(+), 32 deletions(-) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 4d4d494191bd1..f64ef6e5231af 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -42,7 +42,6 @@ from .variables.functions import ( ContextlibContextManagerLocalGeneratorObjectVariable, LocalGeneratorObjectVariable, - UserMethodVariable, ) from .variables.nn_module import NNModuleVariable from .variables.tensor import ( @@ -251,10 +250,7 @@ def __call__( value.source is not None and allow_cache and not ( - value.is_realized() - and isinstance( - value, (LocalGeneratorObjectVariable, UserMethodVariable) - ) + value.is_realized() and isinstance(value, LocalGeneratorObjectVariable) ) ): # There's a corner case for export: for instance, if the computation diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index e628a955bc904..0da182c022b99 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1122,26 +1122,13 @@ def inspect_parameter_names(self): return super().inspect_parameter_names()[1:] def var_getattr(self, tx: "InstructionTranslator", name: str): - if name == "__func__": - # self.source points to the source of the function object and not - # the method object - return VariableTracker.build(tx, self.fn, self.source) + source = self.source and AttrSource(self.source, name) if name == "__self__": return self.obj + if name == "__func__": + return VariableTracker.build(tx, self.fn, source) return super().var_getattr(tx, name) - def reconstruct(self, codegen): - if not self.obj.source or not self.source: - raise NotImplementedError - - def get_bound_method(): - codegen(self.source) - codegen.extend_output(codegen.create_load_attrs("__get__")) - - codegen.add_push_null(get_bound_method) - codegen(self.obj.source) - codegen.extend_output(create_call_function(1, False)) - class WrappedUserMethodVariable(UserMethodVariable): def __init__(self, wrapped, context, **kwargs) -> None: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 1b6d9ffacf130..7cb21ab372801 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1380,9 +1380,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): self.value.__class__, name, NO_SUCH_SUBOBJ ) is_accessible_from_type_mro = ( - subobj_from_class is subobj - and self.cls_source is not None - and self.source is not None + subobj_from_class is subobj and self.cls_source is not None ) if isinstance(subobj, property): @@ -1414,11 +1412,6 @@ def var_getattr(self, tx: "InstructionTranslator", name): func = subobj.__get__(self.value) return VariableTracker.build(tx, func, source) elif isinstance(subobj, classmethod): - if is_accessible_from_type_mro: - # Accessing from __dict__ does not resolve the descriptor, it - # returns a classmethod object, so access the __func__ - # attribute to get to the actual function. - source = AttrSource(self.get_source_by_walking_mro(name), "__func__") return variables.UserMethodVariable( subobj.__func__, self.var_getattr(tx, "__class__"), source=source ) @@ -1468,9 +1461,6 @@ def var_getattr(self, tx: "InstructionTranslator", name): isinstance(subobj, types.MethodType) and isinstance(self.value, torch.nn.Module) ): - if is_accessible_from_type_mro: - source = self.get_source_by_walking_mro(name) - # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup. # Static lookup can't tell us it's a method or function correctly, # so we trigger dynamic lookup here to get the correct type. From 6fa3592dc65b15195a145a98f344f0c38517b12f Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Wed, 6 Aug 2025 19:05:15 +0000 Subject: [PATCH 0058/1424] Dataloader benchmark script (#159432) This script adds a simple dataloading benchmark tracking throughput and memory. The output looks like this ``` System Information: PyTorch version: 2.9.0a0+gitf87d117 PyTorch location: /home/divyanshkhanna/pytorch/torch/__init__.py Torchvision version: 0.24.0a0+f52c4f1 Torchvision location: /home/divyanshkhanna/pytorch/vision/torchvision/__init__.py CUDA available: True CUDA device: NVIDIA PG509-210 CPU count: 192 Physical CPU cores: 96 Total system memory: 1510.11 GB Loading dataset from imagenet/val (1 copies) Dataset size: 50000 --- Benchmarking DataLoader with worker_method=multiprocessing --- Memory before DataLoader creation: 500.59 MB Detailed memory information: USS (Unique Set Size): 499.00 MB PSS (Proportional Set Size): 500.74 MB RSS (Resident Set Size): 497.39 MB Memory after DataLoader creation: 1127.61 MB Memory increase: 627.02 MB Starting training loop with 1 epochs (max 100 batches per epoch) Epoch 1, Batch 10, Time: 0.2910s, Memory: 12044.50 MB Epoch 1, Batch 20, Time: 0.2909s, Memory: 12185.71 MB Epoch 1, Batch 30, Time: 0.2909s, Memory: 10654.93 MB Epoch 1, Batch 40, Time: 0.2909s, Memory: 12378.26 MB Epoch 1, Batch 50, Time: 0.2907s, Memory: 12402.28 MB Epoch 1, Batch 60, Time: 0.2909s, Memory: 10559.35 MB Epoch 1, Batch 70, Time: 0.2907s, Memory: 12644.69 MB Epoch 1, Batch 80, Time: 0.2909s, Memory: 12654.65 MB Epoch 1, Batch 90, Time: 0.2909s, Memory: 12727.20 MB Epoch 1, Batch 100, Time: 0.2908s, Memory: 12722.09 MB Results: Worker method: multiprocessing DataLoader init time: 0.1553 seconds Average batch time: 0.3408 seconds Samples per second: 375.53 Peak memory usage: 12738.76 MB Memory increase: 12238.17 MB ``` > TODO: This script right now is CPU-only friendly and GPU friendly. But it might be worth upgrading it to test against a canonical DistributedDataParallel setup on say a 1x8 node. Or maybe we can keep that as a separate script inside `benchmarks` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159432 Approved by: https://github.com/ramanishsingh --- benchmarks/data/dataloader_benchmark.py | 316 ++++++++++++++++++++++++ 1 file changed, 316 insertions(+) create mode 100644 benchmarks/data/dataloader_benchmark.py diff --git a/benchmarks/data/dataloader_benchmark.py b/benchmarks/data/dataloader_benchmark.py new file mode 100644 index 0000000000000..7d1dd3afc7e98 --- /dev/null +++ b/benchmarks/data/dataloader_benchmark.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +""" +Benchmark script for PyTorch DataLoader with different worker methods. + +This script measures: +1. Dataloader initialization time +2. Dataloading speed (time per batch) +3. CPU memory utilization + +Usage: + python dataloader_benchmark.py --data_path /path/to/dataset --batch_size 32 --num_workers 4 +""" + +import argparse +import copy +import gc +import time + +import psutil +import torchvision +import torchvision.transforms as transforms +from torchvision.models import resnet18 + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torch.utils.data.dataset import ConcatDataset + + +def get_memory_usage(): + """ + Get current memory usage in MB. This includes all child processes. + + Returns: + Total memory usage in MB + """ + process = psutil.Process() + + main_memory = process.memory_full_info().pss + + # Add memory usage of all child processes + for child in process.children(recursive=True): + try: + child_mem = child.memory_full_info().pss + main_memory += child_mem + except (psutil.NoSuchProcess, psutil.AccessDenied, AttributeError): + # Process might have terminated or doesn't support PSS, fall back to USS + print(f"Failed to get PSS for {child}, falling back to USS") + child_mem = child.memory_info().uss + main_memory += child_mem + + return main_memory / (1024 * 1024) + + +def print_detailed_memory(): + """Print detailed memory information.""" + process = psutil.Process() + print("\nDetailed memory information:") + try: + print( + f" USS (Unique Set Size): {process.memory_full_info().uss / (1024 * 1024):.2f} MB" + ) + print( + f" PSS (Proportional Set Size): {process.memory_full_info().pss / (1024 * 1024):.2f} MB" + ) + print( + f" RSS (Resident Set Size): {process.memory_info().rss / (1024 * 1024):.2f} MB" + ) + except Exception: + print(" Detailed memory info not available") + + +def create_model(): + """Create a simple model for benchmarking.""" + model = resnet18() + return model + + +def benchmark_dataloader( + dataset, + batch_size, + num_workers, + num_epochs=1, + max_batches=10, + multiprocessing_context=None, + logging_freq=10, +): + """Benchmark a dataloader with specific configuration.""" + print("\n--- Benchmarking DataLoader ---") + + # Clear memory before starting + gc.collect() + torch.cuda.empty_cache() + + # Create model + model = create_model() + + # Measure memory before dataloader creation + memory_before = get_memory_usage() + print(f"Memory before DataLoader creation: {memory_before:.2f} MB") + print_detailed_memory() + + # Measure dataloader initialization time + start = time.perf_counter() + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + prefetch_factor=2 if num_workers > 0 else None, + multiprocessing_context=multiprocessing_context, + ) + it = iter(dataloader) + dataloader_init_time = time.perf_counter() - start + + # Measure memory after dataloader creation + memory_after = get_memory_usage() + print(f"Memory after DataLoader creation: {memory_after:.2f} MB") + print(f"Memory increase: {memory_after - memory_before:.2f} MB") + + # Create model and optimizer + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + + # Benchmark dataloading speed + model.train() + total_batches = 0 + total_samples = 0 + total_time = 0 + total_data_load_time = 0 + + # Measure peak memory during training + peak_memory = memory_after + + print( + f"\nStarting training loop with {num_epochs} epochs (max {max_batches} batches per epoch)" + ) + + for epoch in range(num_epochs): + while total_batches < max_batches: + batch_start = time.perf_counter() + + try: + inputs, labels = next(it) + except StopIteration: + break + + # Move data to device + inputs = inputs.to(device) + labels = labels.to(device) + + # Capture data fetch time (including sending to device) + data_load_time = time.perf_counter() - batch_start + + # Forward pass + outputs = model(inputs) + loss = criterion(outputs, labels) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Capture batch time + batch_time = time.perf_counter() - batch_start + + total_batches += 1 + total_samples += inputs.size(0) + total_data_load_time += data_load_time + total_time += batch_time + + # Update peak memory and log memory usage periodically + if total_batches % 5 == 0: + # Force garbage collection before measuring memory + gc.collect() + current_memory = get_memory_usage() + + if current_memory > peak_memory: + peak_memory = current_memory + + if total_batches % logging_freq == 0: + print( + f"Epoch {epoch + 1}, Batch {total_batches}, " + f"Time: {batch_time:.4f}s, " + f"Memory: {current_memory:.2f} MB" + ) + + # Calculate statistics + avg_data_load_time = ( + total_data_load_time / total_batches if total_batches > 0 else 0 + ) + avg_batch_time = total_time / total_batches if total_batches > 0 else 0 + samples_per_second = total_samples / total_time if total_time > 0 else 0 + + results = { + "dataloader_init_time": dataloader_init_time, + "num_workers": num_workers, + "batch_size": batch_size, + "total_batches": total_batches, + "avg_batch_time": avg_batch_time, + "avg_data_load_time": avg_data_load_time, + "samples_per_second": samples_per_second, + "peak_memory_mb": peak_memory, + "memory_increase_mb": peak_memory - memory_before, + } + + print("\nResults:") + print(f" DataLoader init time: {dataloader_init_time:.4f} seconds") + print(f" Average data loading time: {avg_data_load_time:.4f} seconds") + print(f" Average batch time: {avg_batch_time:.4f} seconds") + print(f" Samples per second: {samples_per_second:.2f}") + print(f" Peak memory usage: {peak_memory:.2f} MB") + print(f" Memory increase: {peak_memory - memory_before:.2f} MB") + + # Clean up + del model, optimizer + del dataloader + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark PyTorch DataLoader with different worker methods" + ) + parser.add_argument("--data_path", required=True, help="Path to dataset") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size") + parser.add_argument("--num_workers", type=int, default=4, help="Number of workers") + parser.add_argument( + "--max_batches", + type=int, + default=100, + help="Maximum number of batches per epoch", + ) + parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument( + "--multiprocessing_context", + choices=["fork", "spawn", "forkserver"], + default="forkserver", + help="Multiprocessing context to use (fork, spawn, forkserver)", + ) + parser.add_argument( + "--dataset_copies", + type=int, + default=1, + help="Number of copies of the dataset to concatenate (for testing memory usage)", + ) + parser.add_argument( + "--logging_freq", + type=int, + default=10, + help="Frequency of logging memory usage during training", + ) + args = parser.parse_args() + + # Print system info + print("System Information:") + # The following are handy for debugging if building from source worked correctly + print(f" PyTorch version: {torch.__version__}") + print(f" PyTorch location: {torch.__file__}") + print(f" Torchvision version: {torchvision.__version__}") + print(f" Torchvision location: {torchvision.__file__}") + print(f" CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f" CUDA device: {torch.cuda.get_device_name(0)}") + print(f" CPU count: {psutil.cpu_count(logical=True)}") + print(f" Physical CPU cores: {psutil.cpu_count(logical=False)}") + print(f" Total system memory: {psutil.virtual_memory().total / (1024**3):.2f} GB") + + # Define transforms + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + # Load dataset + print(f"\nLoading dataset from {args.data_path} ({args.dataset_copies} copies)") + + # Try to load as ImageFolder + datasets = [] + for _ in range(args.dataset_copies): + base_dataset = torchvision.datasets.ImageFolder( + args.data_path, transform=transform + ) + datasets.append(copy.deepcopy(base_dataset)) + del base_dataset + dataset = ConcatDataset(datasets) + + print(f"Dataset size: {len(dataset)}") + + # Run benchmark with specified worker method + benchmark_dataloader( + dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + multiprocessing_context=args.multiprocessing_context, + num_epochs=args.num_epochs, + max_batches=args.max_batches, + logging_freq=args.logging_freq, + ) + + +if __name__ == "__main__": + main() From c71950907df19f2438b0909dd409ea23116ccef3 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 6 Aug 2025 19:31:42 +0000 Subject: [PATCH 0059/1424] [inductor] add _get_inductor_debug_symbol_cflags for debug symbol control. (#159938) We need to add inductor debug symbol support for crash case debug. When we turn on generate debug symbol. On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. On Linux, it should create some debug sections in binary file. I added UT for it also. It works well on Windows inductor debug. image Pull Request resolved: https://github.com/pytorch/pytorch/pull/159938 Approved by: https://github.com/jansel, https://github.com/angelayi --- test/inductor/test_compile.py | 74 ++++++++++++++++++++++++++++++++++ torch/_inductor/cpp_builder.py | 28 +++++++++++-- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_compile.py b/test/inductor/test_compile.py index e1f4f146636d4..6908936eca3f3 100644 --- a/test/inductor/test_compile.py +++ b/test/inductor/test_compile.py @@ -1,6 +1,14 @@ # Owner(s): ["module: inductor"] +import os +import shlex +import subprocess +import sys +from unittest import mock + import torch from torch import _dynamo as dynamo, _inductor as inductor +from torch._inductor.codecache import write +from torch._inductor.cpp_builder import CppBuilder, CppOptions from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import gen_gm_and_inputs from torch.fx import symbolic_trace @@ -8,6 +16,25 @@ from torch.testing._internal.inductor_utils import HAS_CPU +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + + +def safe_command_output(cmd, timeout=30): + try: + return subprocess.check_output( + cmd, + stderr=subprocess.STDOUT, + text=True, + timeout=timeout, + shell=isinstance(cmd, str), + ).strip() + except subprocess.CalledProcessError as e: + return f"run failed(error code {e.returncode}): {e.output.strip()}" + except subprocess.TimeoutExpired: + return "runt timeout" + + class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -109,6 +136,53 @@ def test_inductor_via_op_with_multiple_outputs(self): mod_opt = inductor.compile(mod, inp) self.assertEqual(mod(*inp), mod_opt(*inp)) + @mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"}) + def test_inductor_generate_debug_symbol(self): + cpp_code = """ +int main(){ + return 0; +} + """ + + _, source_path = write( + cpp_code, + "cpp", + ) + build_option = CppOptions() + cpp_builder = CppBuilder( + name="test_symbol", + sources=source_path, + output_dir=os.path.dirname(source_path), + BuildOption=build_option, + ) + cpp_builder.build() + binary_path = cpp_builder.get_target_file_path() + + """ + When we turn on generate debug symbol. + On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. + On Linux, it should create some debug sections in binary file. + """ + + def check_linux_debug_section(module_path: str): + check_cmd = shlex.split(f"readelf -S {module_path}") + output = safe_command_output(check_cmd) + has_debug_sym = ".debug_info" in output + self.assertEqual(has_debug_sym, True) + + def check_windows_pdb_exist(module_path: str): + file_name_no_ext = os.path.splitext(module_path)[0] + file_name_pdb = f"{file_name_no_ext}.pdb" + has_pdb_file = os.path.exists(file_name_pdb) + self.assertEqual(has_pdb_file, True) + + if _IS_WINDOWS: + check_windows_pdb_exist(binary_path) + elif _IS_MACOS: + pass # MacOS not sure that if it should be works. + else: + check_linux_debug_section(binary_path) + if __name__ == "__main__": if HAS_CPU: diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 44efd8088c73a..baa852fbaf4fc 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -637,7 +637,7 @@ def _get_optimization_cflags( return cflags -def _get_shared_cflag(do_link: bool) -> list[str]: +def _get_shared_cflags(do_link: bool) -> list[str]: if _IS_WINDOWS: """ MSVC `/MD` using python `ucrtbase.dll` lib as runtime. @@ -652,6 +652,25 @@ def _get_shared_cflag(do_link: bool) -> list[str]: return ["shared", "fPIC"] +def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]: + """ + When we turn on generate debug symbol. + On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. + On Linux, it should create some debug sections in binary file. + """ + cflags: list[str] = [] + ldflags: list[str] = [] + b_enable_debug_symbol = os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1" + if b_enable_debug_symbol: + if _IS_WINDOWS: + cflags = ["Z7", "_DEBUG", "OD"] + ldflags = ["DEBUG", "OPT:REF", "OPT:ICF"] + else: + cflags.append("g") + + return cflags, ldflags + + def get_cpp_options( cpp_compiler: str, do_link: bool, @@ -667,12 +686,15 @@ def get_cpp_options( libraries: list[str] = [] passthrough_args: list[str] = [] + dbg_cflags, dbg_ldflags = _get_inductor_debug_symbol_cflags() + cflags = ( - _get_shared_cflag(do_link) + _get_shared_cflags(do_link) + _get_optimization_cflags(cpp_compiler, min_optimize) + _get_warning_all_cflag(warning_all) + _get_cpp_std_cflag() + _get_os_related_cpp_cflags(cpp_compiler) + + dbg_cflags ) if not _IS_WINDOWS and config.aot_inductor.enable_lto and _is_clang(cpp_compiler): @@ -685,7 +707,7 @@ def get_cpp_options( definitions, include_dirs, cflags, - ldflags, + ldflags + dbg_ldflags, libraries_dirs, libraries, passthrough_args, From d10e9e47815d3045b3f237289d3bc2a94ed1ebbd Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 5 Aug 2025 22:27:30 -0700 Subject: [PATCH 0060/1424] [MPS] Remove all pre-MacOS14 logic (#159912) Delete older enums, checks for MacOS-13.3+ for int64 support, etc Fixes https://github.com/pytorch/pytorch/issues/159275 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159912 Approved by: https://github.com/manuelcandales --- aten/src/ATen/mps/EmptyTensor.cpp | 1 - aten/src/ATen/mps/MPSDevice.h | 6 +- aten/src/ATen/mps/MPSDevice.mm | 16 +- aten/src/ATen/mps/MPSHooks.mm | 16 +- aten/src/ATen/native/mps/OperationUtils.h | 29 +--- aten/src/ATen/native/mps/OperationUtils.mm | 39 +---- .../ATen/native/mps/operations/BinaryOps.mm | 17 --- aten/src/ATen/native/mps/operations/Blas.mm | 3 - .../ATen/native/mps/operations/Convolution.mm | 4 - aten/src/ATen/native/mps/operations/Copy.mm | 20 +-- .../native/mps/operations/Distributions.mm | 1 - .../mps/operations/FastFourierTransform.mm | 3 - .../ATen/native/mps/operations/GridSampler.mm | 9 -- .../ATen/native/mps/operations/Indexing.mm | 15 +- .../ATen/native/mps/operations/ReduceOps.mm | 52 ++----- aten/src/ATen/native/mps/operations/Repeat.mm | 10 +- .../ATen/native/mps/operations/ScanKernel.mm | 137 +----------------- aten/src/ATen/native/mps/operations/Sort.mm | 6 +- .../native/mps/operations/TensorCompare.mm | 3 - .../ATen/native/mps/operations/UnaryOps.mm | 48 ++---- 20 files changed, 42 insertions(+), 393 deletions(-) diff --git a/aten/src/ATen/mps/EmptyTensor.cpp b/aten/src/ATen/mps/EmptyTensor.cpp index 7b04d65ebdd02..d858df0733975 100644 --- a/aten/src/ATen/mps/EmptyTensor.cpp +++ b/aten/src/ATen/mps/EmptyTensor.cpp @@ -43,7 +43,6 @@ TensorBase empty_mps( int64_t nelements = c10::multiply_integers(size); auto dtype = dtype_or_default(dtype_opt); TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED); - TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer"); auto dtype_meta = scalarTypeToTypeMeta(dtype); diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index a70ce25108201..87c820430c98a 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -18,11 +18,7 @@ namespace at::mps { // Helper enum to check if a MPSGraph op is supported in a given macOS version enum class MacOSVersion : uint32_t { - MACOS_VER_13_1_PLUS = 0, - MACOS_VER_13_2_PLUS, - MACOS_VER_13_3_PLUS, - MACOS_VER_14_0_PLUS, - MACOS_VER_14_4_PLUS, + MACOS_VER_14_4_PLUS = 0, MACOS_VER_15_0_PLUS, MACOS_VER_15_1_PLUS, MACOS_VER_15_2_PLUS, diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 55af5f83b388c..72a066c69450a 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -32,11 +32,11 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de MPSDevice::MPSDevice() : _mtl_device(nil) { // Check that MacOS 13.0+ version of MPS framework is available - // Create the MPSGraph and check method introduced in 13.0 + // Create the MPSGraph and check method introduced in 14.0 // which is used by MPS backend. id mpsCD = NSClassFromString(@"MPSGraph"); - if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) { + if ([mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)] == NO) { return; } @@ -66,24 +66,12 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}]; } }; - static bool _macos_13_1_plus = is_os_version_at_least(13, 1); - static bool _macos_13_2_plus = is_os_version_at_least(13, 2); - static bool _macos_13_3_plus = is_os_version_at_least(13, 3); - static bool _macos_14_0_plus = is_os_version_at_least(14, 0); static bool _macos_14_4_plus = is_os_version_at_least(14, 4); static bool _macos_15_0_plus = is_os_version_at_least(15, 0); static bool _macos_15_1_plus = is_os_version_at_least(15, 1); static bool _macos_15_2_plus = is_os_version_at_least(15, 2); switch (version) { - case MacOSVersion::MACOS_VER_13_1_PLUS: - return _macos_13_1_plus; - case MacOSVersion::MACOS_VER_13_2_PLUS: - return _macos_13_2_plus; - case MacOSVersion::MACOS_VER_13_3_PLUS: - return _macos_13_3_plus; - case MacOSVersion::MACOS_VER_14_0_PLUS: - return _macos_14_0_plus; case MacOSVersion::MACOS_VER_14_4_PLUS: return _macos_14_4_plus; case MacOSVersion::MACOS_VER_15_0_PLUS: diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index f6133e8877222..a2ec221c1bfea 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -34,7 +34,7 @@ case 14: switch (minor) { case 0: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); + return true; case 4: return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); default: @@ -42,19 +42,7 @@ return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); } case 13: - switch (minor) { - case 0: - return true; - case 1: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS); - case 2: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); - case 3: - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - default: - TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+"); - return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - } + return true; default: TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false"); return false; diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index e6f87f5499a47..f9cd28ca06fa8 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -88,14 +88,8 @@ std::string getArrayRefString(const IntArrayRef s); // use has_storage() on the returned tensor to determine if src actually is a view Tensor gatherViewTensor(const Tensor& src, Tensor& dst); Tensor& scatterViewTensor(const Tensor& src, Tensor& output); -MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64 = false); -MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64 = false); +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input); MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray); MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); @@ -435,14 +429,6 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::functionexecuteMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE); } -static inline void checkSupportsComplex() { - TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer."); -} - MPSDataType getMPSDataType(ScalarType scalar_type) { switch (scalar_type) { case ScalarType::Float: @@ -100,7 +96,6 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { case ScalarType::Half: return MPSDataTypeFloat16; case ScalarType::BFloat16: - checkSupportsBFloat16(); return MPSDataTypeBFloat16; case ScalarType::Int: return MPSDataTypeInt32; @@ -119,10 +114,8 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. " "Please use float32 instead.") case ScalarType::ComplexHalf: - checkSupportsComplex(); return MPSDataTypeComplexFloat16; case ScalarType::ComplexFloat: - checkSupportsComplex(); return MPSDataTypeComplexFloat32; // Unsigned types case ScalarType::UInt64: @@ -140,16 +133,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast to these // types. -MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64) { +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) { MPSDataType dataType = getMPSDataType(input.scalar_type()); - bool condition = - (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); - if (includesInt64) { - condition = condition && (dataType != MPSDataTypeInt64); - } + bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && + (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64); if (condition) { dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; @@ -160,16 +147,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast from these // types. -MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, - MPSGraphTensor* inputTensor, - const TensorBase& input, - bool includesInt64) { +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) { MPSDataType dataType = getMPSDataType(input.scalar_type()); - bool condition = - (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); - if (includesInt64) { - condition = condition && (dataType != MPSDataTypeInt64); - } + bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && + (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64); if (condition) { inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; } @@ -186,7 +167,6 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Half: return MPSDataTypeFloat16; case ScalarType::BFloat16: - checkSupportsBFloat16(); return MPSDataTypeBFloat16; case ScalarType::Int: return MPSDataTypeInt32; @@ -201,13 +181,11 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Bool: return MPSDataTypeBool; case ScalarType::ComplexHalf: - checkSupportsComplex(); return MPSDataTypeComplexFloat16; // This is an intentional fallthrough supporting ComplexDouble for Scalar // types as they are casted to Complex64 currently. case ScalarType::ComplexDouble: case ScalarType::ComplexFloat: - checkSupportsComplex(); return MPSDataTypeComplexFloat32; // Unsigned types case ScalarType::UInt64: @@ -267,7 +245,6 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Half: return "half"; case ScalarType::BFloat16: - checkSupportsBFloat16(); return "bfloat"; case ScalarType::Int: return "int"; @@ -879,9 +856,7 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} MTLCompileOptions* options = compile_options; if (!options) { options = [[MTLCompileOptions new] autorelease]; - // Need 3.0 for atomic oprations, 3.1 introduces bfloat support - [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 - : MTLLanguageVersion3_0]; + [options setLanguageVersion:MTLLanguageVersion3_1]; if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe; options.mathFloatingPointFunctions = diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index a9589ecc490ee..06b6edcff9407 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -48,28 +48,11 @@ #define BinaryOpFn(graph, primary, secondary) \ MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary) -static inline Tensor legacy_complex_as_view(const Tensor& t) { - // Convert non-complex types (and cdouble CPU scalars) to cfloat - if (!isComplexType(t.scalar_type()) || t.scalar_type() == kComplexDouble) { - return at::view_as_real(t.to(kMPS, kComplexFloat)); - } - return at::view_as_real(t.dim() != 0 ? t : t.to(kMPS)); -} - static void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock) { - TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) && - (self.scalar_type() == ScalarType::Long || - (other.scalar_type() == ScalarType::Long && - (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))), - "MPS: ", - op_name, - " op with int64 input is supported natively starting from macOS 13.2"); - TORCH_CHECK_TYPE(!isComplexType(self.scalar_type()) || mps::supportsComplex(), - "Complex types are supported starting from MacOS 14.0+"); MPSStream* mpsStream = getCurrentMPSStream(); const bool is_self_scalar = self.dim() == 0; diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index f167067216d48..101ef5feb224e 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -51,9 +51,6 @@ inline void dot_check(const Tensor& self, const Tensor& other) { } // namespace mps Tensor dot_mps(const Tensor& self, const Tensor& other) { - TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || self.scalar_type() != ScalarType::Long, - "MPS: dot op doesn't support int64 input on MacOS13") - using namespace mps; using CachedGraph = MPSBinaryCachedGraph; diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index 97d562730dd8a..d572d52d103a1 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -124,7 +124,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, IntArrayRef dilation, int64_t groups, std::optional input_shape) { - const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); Tensor input_t = input_t_; bool is3DConv = input_t.dim() == 5; @@ -132,9 +131,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, input_t = input_t.contiguous(); } - TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer), - "Conv3D is only supported on MPS for MacOS_13_2 or newer"); - TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types"); using namespace at::native::mps; diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index 4f879c3b63b02..0c121cee8fb62 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -60,7 +60,6 @@ static void copy_cast_mps(at::Tensor& dst, outputTensor = [mpsGraph castTensor:outputTensor toType:dstDType name:@"cast"]; } if (needs_conj) { - TORCH_CHECK(supportsComplex(), "MPS complex tensors conjugation needs MacOS14+"); outputTensor = [mpsGraph conjugateWithTensor:outputTensor name:nil]; } @@ -275,24 +274,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { // for GPU to GPU copies we only encode to stream's command buffer (no flushing) stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset, profile_id); } else { - // Simulate cast to Complex on older MacOS by initializing real and imag parts - if (dst_.is_complex() && !supportsComplex()) { - if (!src.is_complex()) { - at::real(dst_).copy_(src); - at::imag(dst_).fill_(0); - } else if (src.is_conj() || dst_.is_conj()) { - // One cannot take view of conjugated tensor, but for some reason real and imag views are fine - // Use this to implement a conjugation - at::real(dst_).copy_(at::real(src)); - if (src.is_conj() != dst_.is_conj()) { - at::imag(dst_).copy_(at::neg(at::imag(src))); - } else { - at::imag(dst_).copy_(at::imag(src)); - } - } else { - at::view_as_real(dst_).copy_(at::view_as_real(src)); - } - } else if (dst_byte_offset) { + if (dst_byte_offset) { auto maybeCastedSource = at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource); diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index d072e5a40ac96..4d3f99ea9e02d 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -87,7 +87,6 @@ case kFloat: return MPSDataTypeFloat32; case kBFloat16: { - checkSupportsBFloat16(); return MPSDataTypeBFloat16; } default: diff --git a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm index a9ac701106170..7e9867c9b948d 100644 --- a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm +++ b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm @@ -88,7 +88,6 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, // TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237 Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) { - TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(onesided); @autoreleasepool { @@ -129,7 +128,6 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t normalization, int64_t last_dim_size, Tensor& out) { - TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(last_dim_size); @autoreleasepool { @@ -155,7 +153,6 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, } Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) { - TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(forward); @autoreleasepool { diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index 1e701d314354d..8f51474e7a2c2 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -127,15 +127,6 @@ Tensor grid_sampler_2d_mps(const Tensor& input, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS)) { - TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.2. ", - "Falling back on CPU. This may have performance implications."); - - return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners) - .clone() - .to("mps"); - } - auto in_size = input.sizes(); auto grid_size = grid.sizes(); auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options()); diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index f00d155559da0..66ae1114f841d 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -353,14 +353,7 @@ static Tensor nonzero_fallback(const Tensor& self) { } Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) { - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ", - "Falling back on CPU. This may have performance implications."); - Tensor out_fallback = nonzero_fallback(self); - at::native::resize_output(out_, out_fallback.sizes()); - out_.copy_(out_fallback); - return out_; - } else if (self.is_complex()) { + if (self.is_complex()) { TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ", "Falling back on CPU. This may have performance implications."); Tensor out_fallback = nonzero_fallback(self); @@ -445,11 +438,7 @@ static Tensor nonzero_fallback(const Tensor& self) { } Tensor nonzero_mps(const Tensor& self) { - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ", - "Falling back on CPU. This may have performance implications."); - return nonzero_fallback(self); - } else if (self.is_complex()) { + if (self.is_complex()) { TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ", "Falling back on CPU. This may have performance implications."); return nonzero_fallback(self); diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 21020bad467d0..4b209403f853a 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -152,8 +152,6 @@ static void reduction_out_mps(const Tensor& input_t, const Tensor& output_t, MPSReductionType reduction_type, const std::string& func_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name); // NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor bool canSqueezeLastDim = true; IntArrayRef input_shape = input_t.sizes(); @@ -236,12 +234,10 @@ static void reduction_out_mps(const Tensor& input_t, MPSGraphTensor* castInputTensor = inputTensor; MPSDataType inputCastType = MPSDataTypeInvalid; if (dtype.has_value() && - (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || - (dtype.value() == kLong && macOS13_3_plus))) { + (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || dtype.value() == kLong)) { inputCastType = getMPSDataType(dtype.value()); } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && - inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && - (inputScalarType != kLong || !macOS13_3_plus)) { + inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && inputScalarType != kLong) { inputCastType = getMPSDataType(kFloat); } @@ -615,9 +611,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, } static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median"); - IntArrayRef input_shape = input_t.sizes(); int64_t num_in_elements = c10::multiply_integers(input_shape); @@ -634,8 +627,7 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { auto medianCachedGraph = LookUpOrCreateCachedGraph(medianKey, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil]; @@ -693,9 +685,6 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { } static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max"); - using CachedGraph = MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); @@ -713,8 +702,7 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* castOutputTensor = nil; - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); NSArray* axes = getTensorAxes(input_t); if (reduction_type == MPSReductionType::MAX) { @@ -749,9 +737,6 @@ static void min_max_out_mps(const Tensor& input_t, const Tensor& indices_t, MPSReductionType reduction_type, const std::string& func_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out"); - if (output_t.numel() == 0) { return; } @@ -789,8 +774,7 @@ static void min_max_out_mps(const Tensor& input_t, auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); if (reduction_type == MPSReductionType::MAX) { outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; @@ -896,9 +880,6 @@ static void argmax_argmin_out_mps(const Tensor& input_t, const std::string& func_name) { using CachedGraph = MPSUnaryCachedGraph; - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out"); - int64_t dim_ = -1; if (dim.has_value()) { @@ -953,7 +934,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t, MPSGraphTensor* castInputTensor = inputTensor; if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && - (inputScalarType != kLong || !macOS13_3_plus)) { + inputScalarType != kLong) { castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat); } if (reduction_type == MPSReductionType::MAX) { @@ -1282,9 +1263,6 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, op_name); - int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, op_name.c_str()); @@ -1303,7 +1281,7 @@ static void all_any_common_impl_mps(const Tensor& input_t, auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); // reductionOrWithTensor:axis: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 MPSGraphTensor* outputTensor = nil; @@ -1369,14 +1347,11 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out"); - @autoreleasepool { std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); // reductionOrWithTensor:axes: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 if (input_t.dim() > 4) { @@ -1420,14 +1395,11 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out"); - @autoreleasepool { std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); // reductionAndWithTensor:axes: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 if (input_t.ndimension() > 4) { @@ -1512,9 +1484,6 @@ static void median_out_mps_common(const Tensor& input_t, Tensor& indices, const std::string& func_name, bool nanmedian) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out"); - int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "max()"); @@ -1585,8 +1554,7 @@ static void median_out_mps_common(const Tensor& input_t, getTensorsStringKey(indices); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); MPSGraphTensor* effectiveLengthTensor = nil; if (nanmedian) { diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index 10668309a8c23..40afa15b4f700 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -129,16 +129,8 @@ void computeRepeatIndices(const index_t* repeat_ptr, }); } -Tensor repeat_interleave_mps(const Tensor& repeat_, std::optional output_size) { +Tensor repeat_interleave_mps(const Tensor& repeat, std::optional output_size) { Tensor output; - Tensor repeat = repeat_; - if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { - // #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output, - // which currently doesn't support int64_t as input. Casting internally the indices to int32_t. - TORCH_WARN_ONCE( - "MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3"); - repeat = repeat.to(kInt); - } AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() { output = repeat_interleave_common>(repeat, output_size); }); diff --git a/aten/src/ATen/native/mps/operations/ScanKernel.mm b/aten/src/ATen/native/mps/operations/ScanKernel.mm index 9e3269d970143..80495ba9d501d 100644 --- a/aten/src/ATen/native/mps/operations/ScanKernel.mm +++ b/aten/src/ATen/native/mps/operations/ScanKernel.mm @@ -23,125 +23,6 @@ #include #endif -// Generic scan implementation that handles both simple scans and scans with indices -static void scan_mps_impl(const Tensor& self, - const std::vector& outputs, - int64_t dim, - const std::string& op_name) { - if (outputs[0].numel() == 0) { - return; - } - - const int64_t ndim = self.dim(); - const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim); - - // Calculate dimensions for scan operation - int64_t row_size = self.size(wrapped_dim); - auto sizes = self.sizes(); - - bool is_innermost = (wrapped_dim == ndim - 1); - - // Check if all tensors are contiguous - bool is_contiguous = self.is_contiguous(); - for (const auto& output : outputs) { - is_contiguous = is_contiguous && output.is_contiguous(); - } - - uint32_t num_rows, num_orows, num_irows, num_threads; - - if (is_innermost) { - // Treat all outer dimensions as a single dimension - num_rows = self.numel() / row_size; - num_threads = num_rows; - } else { - // Treat all outer dimensions (i.e. dim_ < dim) as one - num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim); - // Treat all inner dimensions (i.e. dim > dimension) as one - num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end()); - num_threads = num_orows * num_irows; - } - - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync_with_rethrow(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - - // Choose kernel based on contiguity and dimension - std::string kernel_name; - if (is_contiguous) { - kernel_name = - op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self); - } else { - kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self); - } - - id scanPSO = lib.getPipelineStateForFunc(kernel_name); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() { - std::vector all_tensors = {self}; - all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end()); - return all_tensors; - }()); - - [computeEncoder setComputePipelineState:scanPSO]; - - // Set input tensor - mtl_setBuffer(computeEncoder, self, 0); - - // Set output tensors - for (size_t i = 0; i < outputs.size(); ++i) { - mtl_setBuffer(computeEncoder, outputs[i], i + 1); - } - - if (is_contiguous) { - // Contiguous kernels - if (is_innermost) { - if (outputs.size() == 1) { - // Simple scan - mtl_setArgs<2>(computeEncoder, num_rows, static_cast(row_size)); - } else { - // Scan with indices - mtl_setArgs<3>(computeEncoder, num_rows, static_cast(row_size)); - } - } else { - if (outputs.size() == 1) { - // Simple scan - mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast(row_size)); - } else { - // Scan with indices - mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast(row_size)); - } - } - } else { - // Strided kernels - pass full tensor information - if (outputs.size() == 1) { - // Simple scan - mtl_setArgs<2>(computeEncoder, - self.sizes(), - self.strides(), - outputs[0].strides(), - static_cast(self.ndimension()), - static_cast(wrapped_dim)); - } else { - // Scan with indices - mtl_setArgs<3>(computeEncoder, - self.sizes(), - self.strides(), - outputs[0].strides(), - outputs[1].strides(), - static_cast(self.ndimension()), - static_cast(wrapped_dim)); - } - } - - mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads); - - getMPSProfiler().endProfileKernel(scanPSO); - } - }); -} - // Utility function to get 2D grid dimensions for dispatch static std::pair get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) { size_t grid_x = 1; @@ -375,19 +256,11 @@ static void scan_with_indices_mps_impl(const Tensor& self, } // namespace mps void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { - if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax"); - } else { - mps::scan_mps_impl(self, {values, indices}, dim, "cummax"); - } + mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax"); } void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { - if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin"); - } else { - mps::scan_mps_impl(self, {values, indices}, dim, "cummin"); - } + mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin"); } Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) { @@ -402,11 +275,7 @@ void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int6 return result; } - if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { - mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp"); - } else { - mps::scan_mps_impl(self, {result}, wrap_dim, "logcumsumexp"); - } + mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp"); return result; } diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index c73b7c33098f1..cfec1e443e251 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -26,9 +26,6 @@ const Tensor& indices) { using namespace mps; - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); - MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out"); - if (self.numel() == 0) { return; } @@ -55,8 +52,7 @@ auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); - MPSGraphTensor* castInputTensor = - castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self); MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor axis:(NSInteger)dim descending:(BOOL)descending diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 6e030c99d0356..16e0608012f37 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -297,9 +297,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, const auto common_type = at::result_type(elements, test_elements); TORCH_CHECK(elements.is_mps() && test_elements.is_mps()); - TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || supportedFloatingType(common_type), - "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ", - common_type); @autoreleasepool { std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert); diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index edf45a5ff80d0..8fbefcb6ab8a0 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -208,28 +208,12 @@ static void unary_op(const Tensor& self, } Tensor& angle_out_mps(const Tensor& self, Tensor& output) { - if (mps::supportsComplex()) { - mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil]; - auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil]; - return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil]; - }); - return output; - } else { - TORCH_CHECK(!self.is_complex(), "MPS does not support angle with complex input on macOS13") - mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - // On macOS 13 with non-complex input, realPartOfTensor and imaginaryPartOfTensor are - // not available, and NaN is not propagated correctly: - auto imagPart = [mpsGraph constantWithScalar:0.0 shape:inputTensor.shape dataType:inputTensor.dataType]; - auto result = [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:inputTensor name:nil]; - auto nanMask = [mpsGraph isNaNWithTensor:inputTensor name:nil]; - return [mpsGraph selectWithPredicateTensor:nanMask - truePredicateTensor:inputTensor - falsePredicateTensor:result - name:nil]; - }); - return output; - } + mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil]; + auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil]; + return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil]; + }); + return output; } Tensor angle_mps(const Tensor& self) { @@ -362,7 +346,6 @@ static void cumulative_op_impl(const Tensor& self, const Tensor& result, MPSCumulativeOpType cumulativeOpType, const std::string& op_name) { - bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); auto nDims = self.dim(); auto wrapped_dim = maybe_wrap_dim(dim, nDims); TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()), @@ -381,11 +364,6 @@ static void cumulative_op_impl(const Tensor& self, bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int && input.scalar_type() != ScalarType::Long); - TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long, - "MPS does not support ", - op_name, - " op with int64 input. Support has been added in macOS 13.3"); - mps::unary_op( input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { if (castInputData) { @@ -440,17 +418,9 @@ static void cumulative_op_impl(const Tensor& self, Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) { TORCH_CHECK(self.is_complex()); - if (!mps::supportsComplex()) { - if (!result.is_same_size(self)) { - result.resize_(self.sizes()); - } - at::real(result).copy_(at::real(self)); - at::imag(result).copy_(at::neg(at::imag(self))); - } else { - mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - return [mpsGraph conjugateWithTensor:inputTensor name:nil]; - }); - } + mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return [mpsGraph conjugateWithTensor:inputTensor name:nil]; + }); return result; } From 12a54e4ac13a9d4804c393f7d28c4e27a881499e Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Wed, 6 Aug 2025 03:58:52 -0700 Subject: [PATCH 0061/1424] [Inductor UT][Fix XPU CI] Fix case failures introduced by community. (#159759) Fixes #159631 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159759 Approved by: https://github.com/EikanWang, https://github.com/jansel --- test/dynamo/test_modes.py | 3 +++ test/inductor/test_pattern_matcher.py | 6 +++--- test/inductor/test_torchinductor.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 8dab1819f2548..a844efd51af93 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -12,6 +12,7 @@ _push_on_torch_function_stack, ) from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode +from torch.testing._internal.common_utils import skipIfXpu from torch.testing._internal.triton_utils import requires_gpu from torch.utils._device import DeviceContext from torch.utils._python_dispatch import TorchDispatchMode @@ -678,6 +679,7 @@ def forward(self, x): torch.compile(mod, fullgraph=True)(x) @requires_gpu + @skipIfXpu(msg="XPU does not support flex attention") def test_hop(self): import torch import torch._higher_order_ops @@ -701,6 +703,7 @@ def test_hop(self): ) @requires_gpu + @skipIfXpu(msg="XPU does not support flex attention") def test_hop_eager(self): import torch import torch._higher_order_ops diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index ac940f0480098..0ffe7cb37deb6 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1355,13 +1355,13 @@ def repl(inp, x1, x2): FileCheck().check_not("extern_kernels.addmm(").run(code[0]) def test_addmm_dtype_mismatch(self): - a = torch.nn.Linear(1024, 1024, bias=False).cuda() + a = torch.nn.Linear(1024, 1024, bias=False).to(GPU_TYPE) a = a.to(dtype=torch.float16) - w = torch.randn(1024, 1024, device="cuda") + w = torch.randn(1024, 1024, device=GPU_TYPE) def func(): - x = torch.ones(1024, 1024, device="cuda", dtype=torch.float16) + x = torch.ones(1024, 1024, device=GPU_TYPE, dtype=torch.float16) x = a(x) x = x + w return x diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ed4b1ba3e466d..1a73c6ef13032 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14538,11 +14538,11 @@ def fn(x): else: self.assertTrue("Graph fragment" in code) self.assertTrue( - '%sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default]' + f'%sin : Tensor "f32[4, 4][4, 1]{GPU_TYPE}:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default]' in code ) self.assertTrue( - '%relu : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.relu.default]' + f'%relu : Tensor "f32[4, 4][4, 1]{GPU_TYPE}:0"[num_users=1] = call_function[target=torch.ops.aten.relu.default]' in code ) From 0de2a45a48b1b97860c4281cc491ee161419e7c9 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 6 Aug 2025 08:38:21 -0700 Subject: [PATCH 0062/1424] [BE] Merge 3 CUDA build jobs into one (#159890) Before this change there were build+test jobs: - s89 build+tests - sm75 build+distributed_test - sm_75 build+pr_time_benchmark test This change compiles all 3 builds into one (for 2 architectures) and skips testing sm86 as it never found any new regressions that were not found at the same time on sm89 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159890 Approved by: https://github.com/clee2000, https://github.com/seemethere --- .ci/pytorch/build.sh | 2 +- .github/workflows/pull.yml | 61 ++++---------------------------------- 2 files changed, 6 insertions(+), 57 deletions(-) diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index a7ce0fef736cf..34982ac9b3233 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -176,7 +176,7 @@ fi # We only build FlashAttention files for CUDA 8.0+, and they require large amounts of # memory to build and will OOM -if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]]; then +if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && echo "${TORCH_CUDA_ARCH_LIST}" | tr ' ' '\n' | sed 's/$/>= 8.0/' | bc | grep -q 1; then export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j 2" fi diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 061586437a1a9..8c297b1136889 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -254,36 +254,6 @@ jobs: timeout-minutes: 600 secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed: - name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - cuda-arch-list: '7.5' - test-matrix: | - { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc11-test-distributed: - name: linux-jammy-cuda12.8-py3.10-gcc11-test - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed - - target-determination - with: - timeout-minutes: 360 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.test-matrix }} - secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc11-build: name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml @@ -292,7 +262,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - cuda-arch-list: 8.9 + cuda-arch-list: '7.5 8.9' test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, @@ -300,6 +270,10 @@ jobs: { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, ]} secrets: inherit @@ -429,31 +403,6 @@ jobs: test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: - name: cuda12.8-py3.10-gcc9-sm75 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks - cuda-arch-list: '7.5' - test-matrix: | - { include: [ - { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: - name: cuda12.8-py3.10-gcc9-sm75 - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build - with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-xpu-2025_1-py3_9-build: name: linux-jammy-xpu-2025.1-py3.9 uses: ./.github/workflows/_linux-build.yml From b8ef60b6bcce244a7c5baa5f5cd29a81abde8c92 Mon Sep 17 00:00:00 2001 From: Frank Seide Date: Wed, 6 Aug 2025 20:20:32 +0000 Subject: [PATCH 0063/1424] Enable XNNPACK aarch64 builds (#159762) Summary: This fixes the build of TorchScript's XNNPACK dependency for our aarch64 device. Thanks to andrewjcg for proposing this fix. Rollback Plan: Reviewed By: andrewjcg Differential Revision: D79497613 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159762 Approved by: https://github.com/frankseide, https://github.com/malfet Co-authored-by: Frank Seide --- third_party/xnnpack.buck.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index db16e3565273a..b353d5d0d5982 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -2227,6 +2227,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], # doesn't cover iphonesimulator-x86_64 "ovr_config//runtime:arm64-linux-ubuntu-neon": [":arm64_lib"], + "ovr_config//runtime:fbcode-arm64": [":arm64_lib"], "ovr_config//runtime:platform010": [":x86_and_x86_64_lib"], }), ) From 50580b505326272e694a480dfbe056c8d5e605bd Mon Sep 17 00:00:00 2001 From: Alan Du Date: Wed, 6 Aug 2025 20:33:58 +0000 Subject: [PATCH 0064/1424] Add minimal nn.functional.log_softmax support for NestedTensor (#159662) This only works for the jagged layout and for the non-batch and non-jagged dimensions. I did this mostly by copy-pasting from the existing softmax implementation, but it seems fairly straightforward and I think it should work. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159662 Approved by: https://github.com/jbschlosser --- test/test_nestedtensor.py | 27 +++++++++++++++-------- torch/nested/_internal/ops.py | 40 +++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 9 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 38c029f3c367c..a0c018c45d80f 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4444,12 +4444,18 @@ def test_jagged_op_different_output_shape_dim( @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) + @parametrize( + "func", + [torch.nn.functional.softmax, torch.nn.functional.log_softmax], + name_fn=lambda func: func.__name__, + ) def test_softmax_dim( self, device, dtype, requires_grad, components_require_grad, + func, ): """ Softmax passes when reducing on valid reduction dimensions. @@ -4468,7 +4474,7 @@ def test_softmax_dim( for reduce_dim, _ in reduce_dims: nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) - out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim) + out_actual = func(nt, dim=reduce_dim) torch._dynamo.disable(self.assertEqual)( len(out_actual.shape), len(output_shape) ) # disable if running on dynamo @@ -4498,12 +4504,10 @@ def test_softmax_dim( reduce_dim, reduce_dim_expected = reduce_dim_tuple if nt.dim() > reduce_dim: - out_actual = torch.nn.functional.softmax( - nt, dim=reduce_dim - ) # nested tensor - out_expected = torch.nn.functional.softmax( - nt.values(), dim=reduce_dim_expected - ) # dense tensor of dimensions 1 less than out_actual + # nested tensor + out_actual = func(nt, dim=reduce_dim) + # dense tensor of dimensions 1 less than out_actual + out_expected = func(nt.values(), dim=reduce_dim_expected) self.assertTrue( torch.allclose(out_actual.values().view(-1), out_expected.view(-1)) ) @@ -4601,8 +4605,13 @@ def test_softmax_dim_reduce_ragged_idx_1( @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) + @parametrize( + "func", + [torch.nn.functional.softmax, torch.nn.functional.log_softmax], + name_fn=lambda func: func.__name__, + ) def test_softmax_reduce_batch_dim( - self, device, dtype, requires_grad, components_require_grad + self, device, dtype, requires_grad, components_require_grad, func ): """ Softmax on NestedTensor fails when trying to reduce across batch dimension. @@ -4627,7 +4636,7 @@ def test_softmax_reduce_batch_dim( RuntimeError, "not supported when reducing across the batch dimension for NestedTensor", ): - out = torch.nn.functional.softmax(nt, dim=reduce_dim) + out = func(nt, dim=reduce_dim) @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 8eb962f8a308d..1f26a4d90a4a0 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -841,6 +841,46 @@ def _softmax_default(func, *args, **kwargs): return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) +@register_jagged_func( + torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any" +) +def _log_softmax_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + if isinstance(new_kwargs["dim"], tuple): + raise RuntimeError( + "log_softmax(): not supported for dimensions of type 'tuple' for NestedTensor" + ) + + inp = new_kwargs.pop("input") + + ( + new_kwargs["dim"], + reduce_on_batch, + reduce_on_ragged, + _reduce_on_non_batch, + ) = _wrap_jagged_dims( + inp.dim(), (new_kwargs["dim"],), "log_softmax", inp._ragged_idx + ) + + if reduce_on_batch: + raise RuntimeError( + "log_softmax(): not supported when reducing across the batch dimension for NestedTensor" + ) + + if reduce_on_ragged: + raise RuntimeError( + "log_softmax(): not supported when reducing along the ragged dimension for NestedTensor" + ) + + # torch.log_softmax takes in the reduction dimension as an integer + new_kwargs["dim"] = new_kwargs["dim"][0] + + return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) + + @register_jagged_func( torch.ops.aten._softmax_backward_data.default, "grad_output: jt, output: jt, dim: any, input_dtype: any", From 0afaeb7c4ec7fd7ecd03e7553b170f76b348e782 Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Wed, 6 Aug 2025 20:45:18 +0000 Subject: [PATCH 0065/1424] Improve `extract_test_fn` (#158637) The current implementation assumes test functions are resolved as test_module.TestClass.test_fn, however this would not work for modules nested in directories e.g. inductor.test_torchinductor.TestClass.test_fn Pull Request resolved: https://github.com/pytorch/pytorch/pull/158637 Approved by: https://github.com/jbschlosser --- torch/testing/_internal/common_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index e3adef752e406..57b7a9fed43fb 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -329,9 +329,10 @@ def extract_test_fn() -> Optional[Callable]: self_val = frame.f_locals["self"] if isinstance(self_val, unittest.TestCase): test_id = self_val.id() - test_name = test_id.split('.')[2] - test_fn = getattr(self_val, test_name).__func__ - return test_fn + *_, cls_name, test_name = test_id.rsplit('.', 2) + if cls_name == type(self_val).__name__ and test_name.startswith("test"): + test_fn = getattr(self_val, test_name).__func__ + return test_fn except Exception: pass return None From d2368aa6f38416345cc0c1393efafe7413d1a324 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 6 Aug 2025 20:54:05 +0000 Subject: [PATCH 0066/1424] [CPUBLAS] add macros for brgemm APIs for versioning (#158629) **Summary** Add macros for brgemm, so that callers (e.g., Torchao's cpp kernels) know which APIs are available. It is useful when callers need to co-work with old versions of PyTorch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158629 Approved by: https://github.com/CaoE, https://github.com/Valentine233, https://github.com/ezyang --- aten/src/ATen/native/CPUBlas.cpp | 2 +- aten/src/ATen/native/CPUBlas.h | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 79dbe7353e159..b16c1ef04fa0a 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -51,7 +51,7 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int * // brgemm_pack_B is changed to transform and the setting of brgemm beta is changed to set_add_C #if (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR == 5) #define ONEDNN_UKERNEL_1 -#elif (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 6) +#elif ((IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR >= 6) || (IDEEP_VERSION_MAJOR > 3)) #define ONEDNN_UKERNEL_2 #endif #if ((defined(ONEDNN_UKERNEL_1) || defined(ONEDNN_UKERNEL_2)) && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))) diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 95d11903dc773..8b75f12ebaf21 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -206,6 +206,16 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex float +#define CPUBLAS_BRGEMM_BF16BF16F32 // bfloat16 * bfloat16 -> float +#define CPUBLAS_BRGEMM_F32F32F32 // float * float -> float +#define CPUBLAS_BRGEMM_U8U8I32 // unsigned char * unsigned char -> int32 +#define CPUBLAS_BRGEMM_U8I8I32 // unsigned char * signed char -> int32 +#define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32 + TORCH_API void brgemm( int64_t M, int64_t N, From 512b4730e3c7b931360ae7f78953d943bb483d9a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 6 Aug 2025 13:34:54 -0700 Subject: [PATCH 0067/1424] [EZ] Remove useless `cross_compile_arm64` (#159986) As we don't have any Intel Mac runners in CI for last 2+ years Pull Request resolved: https://github.com/pytorch/pytorch/pull/159986 Approved by: https://github.com/atalman --- .ci/wheel/build_wheel.sh | 3 --- .github/scripts/generate_ci_workflows.py | 3 --- .github/templates/macos_binary_build_workflow.yml.j2 | 3 --- 3 files changed, 9 deletions(-) diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 878d6595c84c0..0c6857f62b249 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -192,9 +192,6 @@ retry brew install libomp # For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule export USE_DISTRIBUTED=1 -if [[ -n "$CROSS_COMPILE_ARM64" ]]; then - export CMAKE_OSX_ARCHITECTURES=arm64 -fi export USE_MKLDNN=OFF export USE_QNNPACK=OFF export BUILD_TEST=OFF diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 4df6150f97655..9dfed6d00df8f 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -59,7 +59,6 @@ class BinaryBuildWorkflow: is_scheduled: str = "" branches: str = "nightly" # Mainly for macos - cross_compile_arm64: bool = False macos_runner: str = "macos-14-xlarge" use_split_build: bool = False # Mainly used for libtorch builds @@ -338,7 +337,6 @@ class OperatingSystem: generate_binary_build_matrix.RELEASE, libtorch_variants=["shared-with-deps"], ), - cross_compile_arm64=False, macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH}, @@ -351,7 +349,6 @@ class OperatingSystem: build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.MACOS_ARM64 ), - cross_compile_arm64=False, macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index 29b92ad461ef4..1a5780b01519d 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -47,9 +47,6 @@ env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} SKIP_ALL_TESTS: 0 -{%- if cross_compile_arm64 %} - CROSS_COMPILE_ARM64: 1 -{% endif %} !{{ common.concurrency(build_environment) }} jobs: From 289f62ce8a121223cc98cbba37fcdffdcc62551f Mon Sep 17 00:00:00 2001 From: Ruben Rodriguez Buchillon Date: Wed, 6 Aug 2025 02:45:23 -0700 Subject: [PATCH 0068/1424] [inductor][ez] fixup scaled_mm (#159948) Summary: This reverts the part of #159383 for scaled_mm where now, like before, we pass through the normal input_nodes (not the triton_input_nodes) to select_algorithm - #159383 refactored how kwargs are retrieved - it introduced this notion of KernelInputs that wrap input_nodes - scaled_mm uses unsqueezed input nodes for triton to retrieve params - the issue: it uses a squeezed (regular) bias for select_algorithm instead This fixes that by passing the original input nodes rather than the triton input nodes. Test Plan: ``` buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_False (caffe2.test.inductor.test_fp8.TestFP8Lowering)' buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_True (caffe2.test.inductor.test_fp8.TestFP8Lowering)' ``` This set of tests was failing, and is passing now Side note: these tests were failing I believe because the unsqueezed bias made the ATEN choice no longer eligible, and there is some minor numerical discrepancy between ATEN and Triton for this. I'm not sure the test should be written like that, as we're implicitly relying on ATEN being the choice here. Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D79717654](https://our.internmc.facebook.com/intern/diff/D79717654) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159948 Approved by: https://github.com/izaitsevfb, https://github.com/eellison --- torch/_inductor/kernel/mm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index d97eebdb78e5b..6e741430f36d6 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1259,9 +1259,7 @@ def tuned_scaled_mm( if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) - return autotune_select_algorithm( - "scaled_mm", choices, kernel_inputs.nodes(), layout - ) + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) @functools.cache From a5725965ea21f684a314defab0bba5b9b5407705 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Wed, 6 Aug 2025 18:25:16 +0000 Subject: [PATCH 0069/1424] Remove unnecessary "# noqa: set_linter" comments (#159467) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159467 Approved by: https://github.com/eellison --- torch/_inductor/autotune_process.py | 2 +- torch/_inductor/codegen/rocm/rocm_benchmark_request.py | 2 +- torch/_inductor/codegen/triton.py | 2 +- torch/_inductor/utils.py | 5 ++--- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index c936fbe92c671..dfaabd1ef5941 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -764,7 +764,7 @@ def update_workspace_size(self) -> None: return self.ensure_dll_loaded() unique_input_count = len( - {meta.name for meta in self.input_tensor_meta} # noqa: set_linter + dict.fromkeys(meta.name for meta in self.input_tensor_meta) ) args = [c_void_p(None) for _ in range(unique_input_count + 1)] stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) diff --git a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py index 4a08773433c3a..df4982988aa15 100644 --- a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py +++ b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -96,7 +96,7 @@ def update_workspace_size(self) -> None: return self.ensure_dll_loaded() unique_input_count = len( - {meta.name for meta in self.input_tensor_meta} # noqa: set_linter + dict.fromkeys(meta.name for meta in self.input_tensor_meta) ) args = [c_void_p(None) for _ in range(unique_input_count + 1)] stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 56be9dace0926..0f9139ae0611a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -3970,8 +3970,8 @@ def add_constexpr_arg(arg_name): optimize_mem = V.graph.is_inference or V.graph.is_backward inductor_meta = { - # Triton will not accept an OrderedSet for autotune_hints "grid_type": self._get_grid_type().__name__, + # Triton will not accept an OrderedSet for autotune_hints "autotune_hints": set(self.autotune_hints), # noqa: set_linter "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), "mutated_arg_names": mutated_args, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 4cc6e2c566545..026f5f14fe74f 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3366,13 +3366,12 @@ def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str: for i, e in enumerate(row): widths[i] = max(widths[i], len(str(e))) lines = [] - # Need nested {} for string formatting; ignore SET_LINTER here - lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # noqa: set_linter + lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # widths whitespace horizontal separators total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1) lines.append("-" * total_width) for row in elements: - lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) # noqa: set_linter + lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) return "\n".join(lines) From 40c4d61f9ab95b3416de90257694a8207f683605 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Wed, 6 Aug 2025 21:52:14 +0000 Subject: [PATCH 0070/1424] [Dynamo][Better Engineering] Typing `torch/_dynamo/guards.py` (#159315) As part of better engineering effort, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to `torch/_dynamo/guards.py` Running ``` mypy torch/_dynamo/guards.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 2030 | 3945 | 51.46% | 70 | 138 | 50.72% | | This PR | 4055 | 4055 | 100.00% | 138 | 138 | 100.00% | | Delta | +2025 | +90 | +48.54% | +68 | 0 | +49.28% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/159315 Approved by: https://github.com/williamwen42, https://github.com/Skylion007 --- torch/_C/_dynamo/eval_frame.pyi | 23 +- torch/_C/_dynamo/guards.pyi | 225 ++++++++++++- torch/_dynamo/guards.py | 565 +++++++++++++++++++------------- torch/_dynamo/output_graph.py | 7 +- torch/_dynamo/testing.py | 2 +- torch/_guards.py | 4 +- 6 files changed, 577 insertions(+), 249 deletions(-) diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 6261679dcdef4..117795db5ac3e 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -2,12 +2,9 @@ import enum import types from typing import Optional, overload -from torch._dynamo.types import ( - DynamoCallback, - DynamoGuardCompleteHook, - DynamoGuardHook, - GuardFn, -) +from torch._dynamo.guards import GuardManagerWrapper +from torch._dynamo.types import DynamoCallback, DynamoGuardCompleteHook, DynamoGuardHook +from torch._guards import CompileId def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def set_skip_guard_eval_unsafe(value: bool) -> bool: ... @@ -25,14 +22,20 @@ def raise_sigtrap() -> None: ... class _CacheEntry: def check_fn(self, *args: object, **kwargs: object) -> bool: ... + def update_diff_guard_root_manager(self) -> None: ... code: types.CodeType + compile_id: CompileId + # If we run into circular issues, just use object + guard_manager: GuardManagerWrapper next: _CacheEntry | None class _PrecompileEntry: - guard_manager: GuardFn + guard_manager: GuardManagerWrapper class _ExtraState: - def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ... + def invalidate( + self, cache_entry: _CacheEntry, guard_manager: GuardManagerWrapper + ) -> None: ... class _FrameAction(enum.IntEnum): DEFAULT = 0 @@ -69,7 +72,9 @@ py_opcode_caches: list[int] def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... def _load_precompile_entry( - code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType + code: types.CodeType, + guard_manager: GuardManagerWrapper, + dynamo_code: types.CodeType, ) -> None: ... def _reset_precompile_entries(code: types.CodeType) -> None: ... def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ... diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 9c2c379ae589b..5e0a014e8f784 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -7,8 +7,15 @@ class GlobalStateGuard: def check(self) -> bool: ... def reason(self) -> str: ... -class LeafGuard: ... -class GuardDebugInfo: ... +class LeafGuard: + def verbose_code_parts(self) -> list[str]: ... + +class RelationalGuard: ... + +class GuardDebugInfo: + verbose_code_parts: list[str] + result: bool + num_guards_executed: int class GuardManager: def check(self, value) -> bool: ... @@ -36,6 +43,84 @@ class GuardManager: example_value, guard_manager_enum, ) -> GuardManager: ... + def grad_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def generic_getattr_manager( + self, + attr: str, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def get_generic_dict_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def list_getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tuple_getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def set_getitem_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def func_defaults_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def func_kwdefaults_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tuple_iterator_getitem_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def weakref_call_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def call_function_no_args_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... def global_weakref_manager( self, global_name: str, @@ -91,7 +176,44 @@ class GuardManager: example_value, guard_manager_enum, ) -> GuardManager: ... - + def get_root(self) -> RootGuardManager: ... + def get_source(self) -> str: ... + def fail_count(self) -> int: ... + def get_child_managers(self) -> list[GuardManager]: ... + def repr(self) -> str: ... + def type_of_guarded_value(self) -> str: ... + def get_leaf_guards(self) -> list[LeafGuard]: ... + def get_accessors(self) -> list[GuardManager]: ... + def is_guarded_value_immutable(self) -> bool: ... + def is_tag_safe(self) -> bool: ... + def is_tag_safe_root(self) -> bool: ... + def has_no_accessors(self) -> bool: ... + def has_object_aliasing_guard(self) -> bool: ... + def get_type_of_guarded_value(self) -> type: ... + def type_dict_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def type_mro_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def code_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def closure_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... # Leaf guards def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ... def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ... @@ -106,7 +228,94 @@ class GuardManager: def add_torch_function_mode_stack_guard( self, initial_stack, verbose_code_parts: list[str] ) -> None: ... - def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ... + def add_mapping_keys_guard(self, value, verbose_code_parts: list[str]) -> None: ... + def add_dict_length_check_guard( + self, value, verbose_code_parts: list[str] + ) -> None: ... + def add_length_check_guard(self, value, verbose_code_parts: list[str]) -> None: ... + def add_true_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_false_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_none_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_not_none_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_dispatch_key_set_guard( + self, + dispatch_key, + verbose_code_parts: list[str], + ) -> None: ... + def add_tensor_match_guard( + self, + value, + sizes, + strides, + tensor_name, + verbose_code_parts: list[str], + ptype, + dispatch_keys, + ) -> None: ... + def add_dynamic_indices_guard( + self, + value, + verbose_code_parts: list[str], + ) -> None: ... + def add_no_hasattr_guard( + self, + attr_name, + verbose_code_parts: list[str], + ) -> None: ... + def add_dict_contains_guard( + self, + contains, + key, + verbose_code_parts: list[str], + ) -> None: ... + def add_type_match_guard( + self, + value, + verbose_code_parts: list[str], + ) -> None: ... + def add_dict_version_guard( + self, + value, + verbose_code_parts: list[str], + ) -> None: ... + def add_set_contains_guard( + self, + contains, + item, + verbose_code_parts: list[str], + ) -> None: ... + def add_tuple_iterator_length_guard( + self, + length, + type_id, + verbose_code_parts: list[str], + ) -> None: ... + def add_range_iterator_match_guard( + self, + start, + stop, + step, + type_id, + verbose_code_parts: list[str], + ) -> None: ... + def add_default_device_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def mark_tag_safe(self) -> None: ... + def mark_tag_safe_root(self) -> None: ... class RootGuardManager(GuardManager): def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ... @@ -118,6 +327,7 @@ class RootGuardManager(GuardManager): def clone_manager( self, clone_filter_fn: Callable[[GuardManager], bool] ) -> RootGuardManager: ... + def attach_compile_id(self, compile_id: str) -> None: ... class DictGuardManager(GuardManager): def get_key_manager( @@ -134,6 +344,9 @@ class DictGuardManager(GuardManager): example_value, guard_manager_enum, ) -> GuardManager: ... + def get_key_value_managers( + self, + ) -> dict[int, tuple[GuardManager, GuardManager]]: ... # Guard accessor stubs class GuardAccessor: ... @@ -146,8 +359,8 @@ class GetAttrGuardAccessor(GuardAccessor): def get_attr_name(self) -> str: ... def install_object_aliasing_guard( - guard_managers: list[GuardManager], - tensor_names: list[str], + x: GuardManager, + y: GuardManager, verbose_code_parts: list[str], ): ... def install_no_tensor_aliasing_guard( diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 2d5d0af995b59..5ffa6d06d7c4e 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Core guard system for Dynamo that detects when compiled code needs to be recompiled due to changes in program state. Guards are conditions that must remain true for previously-compiled @@ -40,6 +38,7 @@ from copy import deepcopy from inspect import currentframe from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAliasType, TypeVar from weakref import ReferenceType import torch @@ -53,11 +52,15 @@ DictGetItemGuardAccessor, DictGuardManager, GetGenericDictGuardAccessor, + GuardDebugInfo, + GuardManager, install_no_tensor_aliasing_guard, install_object_aliasing_guard, install_storage_overlapping_guard, install_symbolic_shape_guard, + LeafGuard, profile_guard_manager, + RelationalGuard, RootGuardManager, ) from torch._dynamo.source import ( @@ -83,6 +86,7 @@ Source, StorageOverlap, ) +from torch._inductor.utils import IndentedBuffer from torch._logging import structured from torch._utils_internal import justknobs_check from torch.fx.experimental.symbolic_shapes import ( @@ -182,11 +186,14 @@ if TYPE_CHECKING: - from sympy import Symbol + from collections.abc import Generator, KeysView, Sequence - from torch._dynamo.output_graph import OutputGraphGuardsState + from sympy import Symbol + from torch._C import DispatchKeySet + from torch._dynamo.output_graph import OutputGraph +T = TypeVar("T") log = logging.getLogger(__name__) guards_log = torch._logging.getArtifactLogger(__name__, "guards") recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") @@ -196,6 +203,17 @@ verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") +class IndentedBufferWithPrefix(IndentedBuffer): + def prefix(self) -> str: + return "| " * (self._indent * self.tabwidth) + + def writeline(self, line: str, skip_prefix: bool = False) -> None: # type: ignore[override] + if skip_prefix: + super().writeline(line) + else: + super().writeline("+- " + line) + + class GuardManagerWrapper: """ A helper class that contains the root guard manager. An instance of this @@ -204,37 +222,38 @@ class is stored in the Dynamo cache entry, so that the cache entry can the check_nopybind from C++. """ - def __init__(self, root=None): + def __init__(self, root: Optional[RootGuardManager] = None) -> None: if root is None: self.root = RootGuardManager() else: self.root = root - self.diff_guard_root = None - self.closure_vars = None - self.args = None - self.code_parts = [] - self.verbose_code_parts = None - self.global_scope = None - self.guard_fail_fn = None - self.cache_entry = None - self.extra_state = None - self.id_matched_objs = {} - self.no_tensor_aliasing_sources = [] + self.diff_guard_root: Optional[RootGuardManager] = None + self.closure_vars: Optional[dict[str, Any]] = None + self.args: Optional[list[str]] = None + self.code_parts: list[str] = [] + self.verbose_code_parts: Optional[list[str]] = None + self.global_scope: Optional[dict[str, Any]] = None + self.guard_fail_fn: Optional[Callable[[GuardFail], None]] = None + self.cache_entry: Optional[CacheEntry] = None + self.extra_state: Optional[ExtraState] = None + self.id_matched_objs: dict[str, ReferenceType[object]] = {} + self.no_tensor_aliasing_sources: list[str] = [] - self.printed_relational_guards = set() + self.printed_relational_guards: set[RelationalGuard] = set() self.diff_guard_sources: OrderedSet[str] = OrderedSet() @contextmanager - def _preserve_printed_relational_guards(self): + def _preserve_printed_relational_guards(self) -> Generator[None, None, None]: self.printed_relational_guards = set() try: yield finally: self.printed_relational_guards = set() - def collect_diff_guard_sources(self): + # TODO: clarify what fn and attributes guard manager has to get the right things here + def collect_diff_guard_sources(self) -> OrderedSet[str]: # At the time of finalize, we have only marked guard managers with # TENSOR_MATCH guards as diff guard managers. So, we do a tree traversal # and collect all the nodes in the tree (branches) that lead to tensor @@ -244,7 +263,7 @@ def collect_diff_guard_sources(self): # 0, so we collect them as well. Later on, we accumulate the diff guard # sources for all the guard managers. - def visit_dict_manager(node): + def visit_dict_manager(node: DictGuardManager) -> bool: is_diff_guard_node = ( node.get_source() in self.diff_guard_sources or node.fail_count() > 0 ) @@ -258,7 +277,7 @@ def visit_dict_manager(node): return is_diff_guard_node - def visit_manager(node): + def visit_manager(node: GuardManager) -> bool: assert not isinstance(node, DictGuardManager) is_diff_guard_node = ( @@ -272,7 +291,7 @@ def visit_manager(node): return is_diff_guard_node - def visit(node): + def visit(node: GuardManager) -> bool: if node is None: return False if isinstance(node, DictGuardManager): @@ -283,18 +302,18 @@ def visit(node): return self.diff_guard_sources - def finalize(self): + def finalize(self) -> None: if config.use_recursive_dict_tags_for_guards and justknobs_check( "pytorch/compiler:use_recursive_dict_tags_for_guards" ): self.find_tag_safe_roots() self.prepare_diff_guard_manager() - def prepare_diff_guard_manager(self): + def prepare_diff_guard_manager(self) -> None: self.collect_diff_guard_sources() self.populate_diff_guard_manager() - def find_tag_safe_roots(self): + def find_tag_safe_roots(self) -> None: """ Identify ``tag safe nodes`` and ``tag safe roots`` within a guard tree. @@ -352,7 +371,7 @@ def find_tag_safe_roots(self): subset that are tag safe roots. """ - def visit_dict_manager(node): + def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]: # Just recurse through the key and value dict managers and check if # all of them are tag safe nodes. assert issubclass(node.get_type_of_guarded_value(), dict) @@ -382,7 +401,7 @@ def visit_dict_manager(node): node.mark_tag_safe() return tag_safe_roots - def visit_manager(node): + def visit_manager(node: GuardManager) -> list[GuardManager]: assert not isinstance(node, DictGuardManager) # Collect the subtree tag safe roots @@ -425,7 +444,7 @@ def visit_manager(node): ] return tag_safe_roots - def visit(node): + def visit(node: GuardManager) -> list[GuardManager]: if node is None: return [] if isinstance(node, DictGuardManager): @@ -437,7 +456,7 @@ def visit(node): if issubclass(node.get_type_of_guarded_value(), torch.nn.Module): node.mark_tag_safe_root() - def populate_diff_guard_manager(self): + def populate_diff_guard_manager(self) -> None: self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources) # Ensure that that C++ side points to the updated diff guard manager. @@ -450,19 +469,23 @@ def populate_diff_guard_manager(self): if self.cache_entry: self.cache_entry.update_diff_guard_root_manager() - def clone_with_chosen_sources(self, chosen_sources): - def filter_fn(node_mgr): + def clone_with_chosen_sources( + self, chosen_sources: OrderedSet[str] + ) -> RootGuardManager: + def filter_fn(node_mgr: GuardManager) -> bool: return node_mgr.get_source() in chosen_sources return self.root.clone_manager(filter_fn) - def get_guard_lines(self, guard): + def get_guard_lines(self, guard: LeafGuard) -> list[str]: guard_name = guard.__class__.__name__ parts = guard.verbose_code_parts() parts = [guard_name + ": " + part for part in parts] return parts - def get_manager_line(self, guard_manager, accessor_str=None): + def get_manager_line( + self, guard_manager: GuardManager, accessor_str: Optional[str] = None + ) -> str: source = guard_manager.get_source() t = guard_manager.__class__.__name__ s = t + ": source=" + source @@ -472,7 +495,9 @@ def get_manager_line(self, guard_manager, accessor_str=None): s += f", tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})" return s - def construct_dict_manager_string(self, mgr, body): + def construct_dict_manager_string( + self, mgr: DictGuardManager, body: IndentedBufferWithPrefix + ) -> None: for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()): body.writeline(f"KeyValueManager pair at index={idx}") with body.indent(): @@ -484,10 +509,12 @@ def construct_dict_manager_string(self, mgr, body): body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}") self.construct_manager_string(val_mgr, body) - def construct_manager_string(self, mgr, body): + def construct_manager_string( + self, mgr: GuardManager, body: IndentedBufferWithPrefix + ) -> None: with body.indent(): for guard in mgr.get_leaf_guards(): - if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] + if isinstance(guard, RelationalGuard): if guard not in self.printed_relational_guards: self.printed_relational_guards.add(guard) body.writelines(self.get_guard_lines(guard)) @@ -513,19 +540,7 @@ def construct_manager_string(self, mgr, body): ) self.construct_manager_string(child_mgr, body) - def __str__(self): - from torch._inductor.utils import IndentedBuffer - - class IndentedBufferWithPrefix(IndentedBuffer): - def prefix(self): - return "| " * (self._indent * self.tabwidth) - - def writeline(self, line, skip_prefix=False): - if skip_prefix: - super().writeline(line) - else: - super().writeline("+- " + line) - + def __str__(self) -> str: with self._preserve_printed_relational_guards(): body = IndentedBufferWithPrefix() body.tabwidth = 1 @@ -538,29 +553,29 @@ def writeline(self, line, skip_prefix=False): body.writelines(self.get_guard_lines(guard)) return body.getvalue() - def check(self, x): + def check(self, x: Any) -> bool: # Only needed for debugging purposes. return self.root.check(x) - def check_verbose(self, x): + def check_verbose(self, x: Any) -> GuardDebugInfo: # Only needed for debugging purposes. return self.root.check_verbose(x) - def populate_code_parts_for_debugging(self): + def populate_code_parts_for_debugging(self) -> None: # This should be called when the guard manager is fully populated relational_guards_seen = set() - def get_code_parts(leaf_guard): + def get_code_parts(leaf_guard: LeafGuard) -> list[str]: code_parts = [] for verbose_code_part in leaf_guard.verbose_code_parts(): code_part = verbose_code_part.split("#")[0].rstrip() code_parts.append(code_part) return code_parts - def visit(mgr): + def visit(mgr: GuardManager) -> None: nonlocal relational_guards_seen for guard in mgr.get_leaf_guards(): - if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] + if isinstance(guard, RelationalGuard): if guard not in relational_guards_seen: self.code_parts.extend(get_code_parts(guard)) relational_guards_seen.add(guard) @@ -573,7 +588,7 @@ def visit(mgr): visit(self.root) -def from_numpy(a): +def from_numpy(a: Any) -> torch.Tensor: # If not numpy array, piggy back on e.g. tensor guards to check type # Re-enable torch function since we disable it on leaf guards # we need it to properly construct the tensor if a default device is set @@ -583,7 +598,7 @@ def from_numpy(a): # For user stack printing @functools.cache -def uninteresting_files(): +def uninteresting_files() -> set[str]: import torch._dynamo.external_utils import torch._dynamo.polyfills @@ -599,7 +614,7 @@ def uninteresting_files(): _CLOSURE_VARS: Optional[dict[str, object]] = None -def _get_closure_vars(): +def _get_closure_vars() -> dict[str, object]: global _CLOSURE_VARS if _CLOSURE_VARS is None: _CLOSURE_VARS = { @@ -635,7 +650,7 @@ def _ast_unparse(node: ast.AST) -> str: strip_function_call = torch._C._dynamo.strip_function_call -def get_verbose_code_part(code_part: str, guard: Guard) -> str: +def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str: extra = "" if guard is not None: if guard.user_stack: @@ -653,14 +668,14 @@ def get_verbose_code_part(code_part: str, guard: Guard) -> str: def get_verbose_code_parts( - code_parts: Union[str | list[str]], guard: Guard + code_parts: Union[str, list[str]], guard: Optional[Guard] ) -> list[str]: if not isinstance(code_parts, list): code_parts = [code_parts] return [get_verbose_code_part(code_part, guard) for code_part in code_parts] -def convert_int_to_concrete_values(dim) -> Optional[int]: +def convert_int_to_concrete_values(dim: Any) -> Optional[int]: if dim is None: return None if not is_symbolic(dim): @@ -670,11 +685,18 @@ def convert_int_to_concrete_values(dim) -> Optional[int]: return dim.node.maybe_as_int() -def convert_to_concrete_values(size_or_stride): +def convert_to_concrete_values(size_or_stride: list[Any]) -> list[Optional[int]]: return [convert_int_to_concrete_values(dim) for dim in size_or_stride] -def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys): +def get_tensor_guard_code_part( + value: torch.Tensor, + name: str, + sizes: list[Optional[int]], + strides: list[Optional[int]], + pytype: type, + dispatch_keys: DispatchKeySet, +) -> str: dispatch_key = ( dispatch_keys | torch._C._dispatch_tls_local_include_set() ) - torch._C._dispatch_tls_local_exclude_set() @@ -688,7 +710,7 @@ def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_key return guard_str -def get_key_index(dct, key): +def get_key_index(dct: dict[Any, Any], key: Any) -> int: # Ensure that we call dict.keys and not value.keys (which can call # overridden keys method). In the C++ guards, we relied on PyDict_Next # to traverse the dictionary, which uses the internal data structure and @@ -696,7 +718,7 @@ def get_key_index(dct, key): return list(builtin_dict_keys(dct)).index(key) -def get_key_index_source(source, index): +def get_key_index_source(source: Any, index: Any) -> str: return f"list(dict.keys({source}))[{index}]" @@ -724,8 +746,12 @@ class NNModuleAttrAccessorInfo: def getitem_on_dict_manager( - source, base_guard_manager, base_example_value, example_value, guard_manager_enum -): + source: Union[DictGetItemSource, DictSubclassGetItemSource], + base_guard_manager: DictGuardManager, + base_example_value: Any, + example_value: Any, + guard_manager_enum: GuardManagerType, +) -> GuardManager: base_source_name = source.base.name() if isinstance(source.index, ConstDictKeySource): index = source.index.index @@ -764,7 +790,7 @@ def getitem_on_dict_manager( ) -def match_on_id_for_tensor(guard): +def match_on_id_for_tensor(guard: Guard) -> bool: source = guard.originating_source # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads # to a new tensor every time and therefore id differs. @@ -791,7 +817,7 @@ class GuardManagerType(enum.Enum): @functools.cache -def code_framelocals_names_reversed_cached(code: types.CodeType): +def code_framelocals_names_reversed_cached(code: types.CodeType) -> list[str]: return list(reversed(code_framelocals_names(code))) @@ -799,16 +825,16 @@ class GuardBuilder(GuardBuilderBase): def __init__( self, f_code: types.CodeType, - id_ref: Callable[[Any, str], str], + id_ref: Callable[[object, str], int], source_ref: Callable[[Source], str], - lookup_weakrefs: Callable[[object], ReferenceType[object]], + lookup_weakrefs: Callable[[object], Optional[weakref.ref[object]]], local_scope: dict[str, object], global_scope: dict[str, object], guard_manager: GuardManagerWrapper, check_fn_manager: CheckFunctionManager, serialization_mode: Optional[str] = None, - runtime_global_scope: Optional[dict[str, Any]] = None, - ): + runtime_global_scope: Optional[dict[str, object]] = None, + ) -> None: self.f_code = f_code self.id_ref = id_ref self.source_ref = source_ref @@ -839,7 +865,7 @@ def __init__( # Collect the guard managers and debug info to insert no tensor aliasing # guards. self.no_tensor_aliasing_names: list[str] = [] - self.no_tensor_aliasing_guard_managers: list[GuardManagerWrapper] = [] + self.no_tensor_aliasing_guard_managers: list[GuardManager] = [] self.check_fn_manager: CheckFunctionManager = check_fn_manager @@ -848,6 +874,7 @@ def __init__( # to access the same object - self._module["param"] is same as # self.param. self.key_order_guarded_dict_ids = set() + assert self.check_fn_manager.output_graph is not None for source in self.check_fn_manager.output_graph.guard_on_key_order: self.key_order_guarded_dict_ids.add(id(self.get(source.name()))) @@ -857,9 +884,7 @@ def __init__( self.id_matched_objs: dict[str, ReferenceType[object]] = {} # Save the guard managers to avoid repeatedly traversing sources. - self._cached_guard_managers: dict[ - str, torch._C._dynamo.guards.GuardManager - ] = {} + self._cached_guard_managers: dict[str, GuardManager] = {} self._cached_duplicate_input_guards: set[tuple[str, str]] = set() self.object_aliasing_guard_codes: list[tuple[str, str]] = [] self.serialization_mode = serialization_mode @@ -870,7 +895,9 @@ def __init__( tuple[str, str] ] = OrderedSet() - def guard_on_dict_keys_and_ignore_order(self, example_value, guard): + def guard_on_dict_keys_and_ignore_order( + self, example_value: dict[Any, Any], guard: Guard + ) -> None: dict_mgr = self.get_guard_manager(guard) if isinstance(dict_mgr, DictGuardManager): raise NotImplementedError( @@ -898,7 +925,7 @@ def guard_on_dict_keys_and_ignore_order(self, example_value, guard): guard_manager_enum=guard_manager_enum, ) - def guard_on_dict_keys_and_order(self, value, guard): + def guard_on_dict_keys_and_order(self, value: dict[Any, Any], guard: Guard) -> None: # Add key managers for the DictGuardManager. Then add either an # ID_MATCH or EQUALS_MATCH guard on the key. dict_mgr = self.get_guard_manager(guard) @@ -937,7 +964,7 @@ def guard_on_dict_keys_and_order(self, value, guard): ) @staticmethod - def _get_generic_dict_manager_example_value(example_value): + def _get_generic_dict_manager_example_value(example_value: Any) -> Optional[Any]: # due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115, # reported in https://github.com/python/cpython/issues/125608, # fixed by https://github.com/python/cpython/pull/125611), we cannot take @@ -956,14 +983,14 @@ def _get_generic_dict_manager_example_value(example_value): def getattr_on_nn_module( self, - source, - base_guard_manager, - base_example_value, - example_value, - base_source_name, - source_name, - guard_manager_enum, - ): + source: AttrSource, + base_guard_manager: GuardManager, + base_example_value: Any, + example_value: Any, + base_source_name: str, + source_name: str, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: """ This tries to avoid calling the expensive nn module custom getattr method by checking if the attribute is accessible via __dict__. For attributes that @@ -982,8 +1009,13 @@ def getattr_on_nn_module( """ def getitem_on_dict_mgr( - mgr, key, source_name, base_example_value, example_value, guard_manager_enum - ): + mgr: GuardManager, + key: Any, + source_name: str, + base_example_value: Any, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: if isinstance(mgr, DictGuardManager): # Case where the user code relies on key order, e.g., # named_parameters @@ -1093,6 +1125,7 @@ def getitem_on_dict_mgr( ) if l2_key: + assert l2_source_name is not None and l2_guard_manager_enum is not None return getitem_on_dict_mgr( mgr=l1_mgr, key=l2_key, @@ -1103,14 +1136,20 @@ def getitem_on_dict_mgr( ) return l1_mgr - def requires_key_order_guarding(self, source): + def requires_key_order_guarding(self, source: Source) -> bool: source_name = source.name() if source_name == "": return False obj_id = id(self.get(source_name)) return obj_id in self.key_order_guarded_dict_ids - def get_guard_manager_type(self, source, example_value): + def get_guard_manager_type( + self, + source: Source, + example_value: Optional[ + Union[KeysView[Any], set[Any], frozenset[Any], dict[Any, Any]] + ], + ) -> GuardManagerType: guard_manager_enum = GuardManagerType.GUARD_MANAGER if self.requires_key_order_guarding(source): # Fix this if condition @@ -1126,10 +1165,10 @@ def get_guard_manager_type(self, source, example_value): guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER return guard_manager_enum - def manager_guards_on_keys(self, mgr_enum): + def manager_guards_on_keys(self, mgr_enum: GuardManagerType) -> bool: return mgr_enum == GuardManagerType.DICT_GUARD_MANAGER - def get_global_guard_manager(self): + def get_global_guard_manager(self) -> GuardManager: return self.guard_manager.root.globals_dict_manager( f_globals=self.runtime_global_scope, source="G", @@ -1137,7 +1176,7 @@ def get_global_guard_manager(self): guard_manager_enum=GuardManagerType.GUARD_MANAGER, ) - def get_guard_manager_from_source(self, source): + def get_guard_manager_from_source(self, source: Source) -> GuardManager: root_guard_manager = self.guard_manager.root example_value = None @@ -1275,12 +1314,13 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, (AttrSource, UnspecializedParamBufferSource)): assert base_guard_manager # to make mypy happy - + assert isinstance(source, AttrSource) if ( isinstance(base_example_value, torch.nn.Module) and get_custom_getattr(base_example_value) is unpatched_nn_module_getattr ): + assert base_source_name out = self.getattr_on_nn_module( source, base_guard_manager, @@ -1300,6 +1340,7 @@ def get_guard_manager_from_source(self, source): elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)): assert base_guard_manager # to make mypy happy assert isinstance(base_example_value, (dict, collections.OrderedDict)) + assert isinstance(source, (DictGetItemSource, DictSubclassGetItemSource)) if isinstance(base_guard_manager, DictGuardManager): assert self.manager_guards_on_keys(base_guard_manager_enum) out = getitem_on_dict_manager( @@ -1538,16 +1579,16 @@ def get_guard_manager_from_source(self, source): self._cached_guard_managers[source.name()] = out return out - def get_guard_manager(self, guard: Guard): + def get_guard_manager(self, guard: Guard) -> GuardManager: return self.get_guard_manager_from_source(guard.originating_source) def add_python_lambda_leaf_guard_to_root( self, - code_parts, - verbose_code_parts, - closure_vars=None, - is_epilogue=True, - ): + code_parts: list[str], + verbose_code_parts: list[str], + closure_vars: Optional[dict[str, object]] = None, + is_epilogue: bool = True, + ) -> None: if closure_vars is None: closure_vars = _get_closure_vars() # Adds a lambda leaf guard to the root guard manager. It wraps the @@ -1602,7 +1643,12 @@ def arg_ref(self, guard: Union[str, Guard]) -> str: return name - def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): + def _guard_on_attribute( + self, + guard: Guard, + attr_name: str, + guard_fn: Callable[[GuardBuilderBase, Guard], Any], + ) -> None: if attr_name == "__code__": attr_source = CodeSource(guard.originating_source) else: @@ -1614,7 +1660,7 @@ def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): new_guard.create(self) # Note: the order of the guards in this file matters since we sort guards on the same object by lineno - def HASATTR(self, guard: Guard): + def HASATTR(self, guard: Guard) -> None: source = guard.originating_source if isinstance(source, NNModuleSource): source = source.base @@ -1652,7 +1698,7 @@ def HASATTR(self, guard: Guard): and get_custom_getattr(base_example_value) is unpatched_nn_module_getattr ): - return self.getattr_on_nn_module( + self.getattr_on_nn_module( source, base_manager, base_example_value, @@ -1671,7 +1717,9 @@ def HASATTR(self, guard: Guard): else: base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard)) - def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: + def NOT_PRESENT_IN_GENERIC_DICT( + self, guard: Guard, attr: Optional[Any] = None + ) -> None: assert attr is not None ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1714,7 +1762,7 @@ def TYPE_MATCH(self, guard: Guard) -> None: obj_id, get_verbose_code_parts(code, guard) ) - def DICT_VERSION(self, guard: Guard): + def DICT_VERSION(self, guard: Guard) -> None: if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError( "DICT_VERSION guard cannot be serialized." @@ -1732,7 +1780,7 @@ def DICT_VERSION(self, guard: Guard): val, get_verbose_code_parts(code, guard) ) - def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): + def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool) -> None: dict_ref = self.arg_ref(guard) maybe_not = "not " if invert else "" @@ -1743,7 +1791,7 @@ def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): not invert, key, get_verbose_code_parts(code, guard) ) - def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool): + def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None: set_ref = self.arg_ref(guard) item = key contains = not invert # install_dict_contains_guard inverts "contains" @@ -1756,7 +1804,7 @@ def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool): contains, item, get_verbose_code_parts(code, guard) ) - def BOOL_MATCH(self, guard: Guard): + def BOOL_MATCH(self, guard: Guard) -> None: # checks val == True or val == False ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1773,7 +1821,7 @@ def BOOL_MATCH(self, guard: Guard): get_verbose_code_parts(code, guard) ) - def NONE_MATCH(self, guard: Guard): + def NONE_MATCH(self, guard: Guard) -> None: # checks `val is None` ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1785,12 +1833,12 @@ def NONE_MATCH(self, guard: Guard): get_verbose_code_parts(code, guard) ) - def ID_MATCH(self, guard: Guard): + def ID_MATCH(self, guard: Guard) -> None: if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.") return self.id_match_unchecked(guard) - def id_match_unchecked(self, guard: Guard): + def id_match_unchecked(self, guard: Guard) -> None: # ___check_obj_id is same as `id(x) == y` if isinstance(guard.originating_source, TypeSource): # optional optimization to produce cleaner/faster guard code @@ -1820,7 +1868,7 @@ def id_match_unchecked(self, guard: Guard): if weak_id is not None: self.id_matched_objs[local_name] = weak_id - def NOT_NONE_MATCH(self, guard: Guard, value=None): + def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: ref = self.arg_ref(guard) val = self.get(guard.name) assert isinstance(val, torch.Tensor) @@ -1831,7 +1879,7 @@ def NOT_NONE_MATCH(self, guard: Guard, value=None): get_verbose_code_parts(code, guard) ) - def DISPATCH_KEY_SET_MATCH(self, guard: Guard): + def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None: ref = self.arg_ref(guard) val = self.get(guard.name) assert isinstance(val, torch._C.DispatchKeySet) @@ -1841,28 +1889,30 @@ def DISPATCH_KEY_SET_MATCH(self, guard: Guard): val, get_verbose_code_parts(code_parts, guard) ) - def NAME_MATCH(self, guard: Guard): - self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) + def NAME_MATCH(self, guard: Guard) -> None: + self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) # type: ignore[arg-type] - def DUAL_LEVEL(self, guard: Guard): + def DUAL_LEVEL(self, guard: Guard) -> None: # Invalidate dual level if current dual level is different than the one # in the fx graph + assert self.check_fn_manager.output_graph is not None dual_level = self.check_fn_manager.output_graph.dual_level code = [f"torch.autograd.forward_ad._current_level == {dual_level}"] - self._set_guard_export_info(guard, [code]) + self._set_guard_export_info(guard, code) # TODO(anijain2305) - Consider this moving this guard to C++ forward_ad = torch.autograd.forward_ad - def fn(x): + def fn(x: Any) -> bool: return forward_ad._current_level == dual_level self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) - def FUNCTORCH_STACK_MATCH(self, guard: Guard): + def FUNCTORCH_STACK_MATCH(self, guard: Guard) -> None: # Invalidate functorch code if current level is different than # the one when FX graph was generated + assert self.check_fn_manager.output_graph is not None cis = self.check_fn_manager.output_graph.functorch_layers states = [ci.get_state() for ci in cis] code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"] @@ -1871,20 +1921,22 @@ def FUNCTORCH_STACK_MATCH(self, guard: Guard): # TODO(anijain2305) - Consider this moving this guard to C++ compare_fn = torch._functorch.pyfunctorch.compare_functorch_state - def fn(x): + def fn(x: Any) -> bool: return compare_fn(states) self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) - def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard): + def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard) -> None: get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks are_inline_hooks = ( torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable ) - def hooks_ids_fn(hooks): + def hooks_ids_fn( + hooks: tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]], + ) -> Optional[tuple[int, ...]]: if not are_inline_hooks(hooks): return None @@ -1898,27 +1950,27 @@ def hooks_ids_fn(hooks): ] self._set_guard_export_info(guard, code) - def fn(x): + def fn(x: Any) -> bool: return guard_hooks_ids == hooks_ids_fn(get_hooks()) self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) - def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard): + def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None: value = self.get(guard.name) original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1]) if hasattr(value, "__metadata_guard__"): verify_guard_fn_signature(value) - def metadata_checker(x): + def metadata_checker(x: Any) -> bool: return value.__metadata_guard__( original_metadata, x.__tensor_flatten__()[1] ) else: - def metadata_checker(x): + def metadata_checker(x: Any) -> bool: return x.__tensor_flatten__()[1] == original_metadata global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}" @@ -1926,7 +1978,7 @@ def metadata_checker(x): metadata_checker, get_verbose_code_parts(global_name, guard) ) - def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None): + def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None: ref = self.arg_ref(guard) val = self.get(guard.name) if np: @@ -2034,7 +2086,7 @@ def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None): self._set_guard_export_info(guard, code) return - def CONSTANT_MATCH(self, guard: Guard): + def CONSTANT_MATCH(self, guard: Guard) -> None: val = self.get(guard.name) if istype(val, bool): self.BOOL_MATCH(guard) @@ -2045,7 +2097,7 @@ def CONSTANT_MATCH(self, guard: Guard): else: self.EQUALS_MATCH(guard) - def NN_MODULE(self, guard: Guard): + def NN_MODULE(self, guard: Guard) -> None: # don't support this in serialization because it uses unsupported ID_MATCH if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError( @@ -2057,7 +2109,7 @@ def NN_MODULE(self, guard: Guard): assert istype(val.training, bool) if not self.guard_nn_modules: # If guard_nn_modules is true, we will guard on the right set of guards - self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) + self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) # type: ignore[arg-type] else: exc.unimplemented_v2( gb_type="Attempted to guard on uninitialized nn.Module", @@ -2069,7 +2121,7 @@ def NN_MODULE(self, guard: Guard): ], ) - def FUNCTION_MATCH(self, guard: Guard): + def FUNCTION_MATCH(self, guard: Guard) -> None: """things like torch.add and user defined functions""" # don't support this in serialization because it uses unsupported ID_MATCH if self.serialization_mode == "save": @@ -2078,7 +2130,7 @@ def FUNCTION_MATCH(self, guard: Guard): ) return self.ID_MATCH(guard) - def CLOSURE_MATCH(self, guard: Guard): + def CLOSURE_MATCH(self, guard: Guard) -> None: """matches a closure by __code__ id.""" # don't support this in serialization because it uses unsupported FUNCTION_MATCH if self.serialization_mode == "save": @@ -2088,12 +2140,12 @@ def CLOSURE_MATCH(self, guard: Guard): val = self.get(guard.name) # Strictly only want user-defined functions if type(val) == types.FunctionType and hasattr(val, "__code__"): - self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) - self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) + self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type] + self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type] else: self.FUNCTION_MATCH(guard) - def BUILTIN_MATCH(self, guard: Guard): + def BUILTIN_MATCH(self, guard: Guard) -> None: if self.serialization_mode == "save": # Record which builtin variables are used for pruning later. if isinstance(guard.originating_source, DictGetItemSource): @@ -2104,7 +2156,7 @@ def BUILTIN_MATCH(self, guard: Guard): return self.ID_MATCH(guard) - def SEQUENCE_LENGTH(self, guard): + def SEQUENCE_LENGTH(self, guard: Guard) -> None: # This guard is used to check length of PySequence objects like list, # tuple, collections.deque etc ref = self.arg_ref(guard) @@ -2130,7 +2182,7 @@ def SEQUENCE_LENGTH(self, guard): len(value), get_verbose_code_parts(code, guard) ) - def TUPLE_ITERATOR_LEN(self, guard): + def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None: ref = self.arg_ref(guard) value = self.get(guard.name) t = type(value) @@ -2146,7 +2198,7 @@ def TUPLE_ITERATOR_LEN(self, guard): tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard) ) - def RANGE_ITERATOR_MATCH(self, guard): + def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None: ref = self.arg_ref(guard) value = self.get(guard.name) t = type(value) @@ -2165,7 +2217,7 @@ def RANGE_ITERATOR_MATCH(self, guard): ) # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards - def DUPLICATE_INPUT(self, guard, source_b): + def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None: if self.serialization_mode == "save": if name := get_local_source_name(source_b): self.check_fn_manager.additional_used_local_vars.add(name) @@ -2205,7 +2257,7 @@ def DUPLICATE_INPUT(self, guard, source_b): get_verbose_code_parts(code, guard), ) - def WEAKREF_ALIVE(self, guard): + def WEAKREF_ALIVE(self, guard: Guard) -> None: if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError( "WEAKREF_ALIVE guard cannot be serialized." @@ -2217,7 +2269,7 @@ def WEAKREF_ALIVE(self, guard): get_verbose_code_parts(code, guard) ) - def MAPPING_KEYS_CHECK(self, guard): + def MAPPING_KEYS_CHECK(self, guard: Guard) -> None: """Guard on the key order of types.MappingProxyType object""" ref = self.arg_ref(guard) value = self.get(guard.name) @@ -2227,7 +2279,7 @@ def MAPPING_KEYS_CHECK(self, guard): self._set_guard_export_info(guard, code) self.get_guard_manager(guard).add_mapping_keys_guard(value, code) - def DICT_KEYS_MATCH(self, guard): + def DICT_KEYS_MATCH(self, guard: Guard) -> None: """Insert guard to check that the keys of a dict are same""" ref = self.arg_ref(guard) value = self.get(guard.name) @@ -2252,29 +2304,30 @@ def DICT_KEYS_MATCH(self, guard): else: self.guard_on_dict_keys_and_ignore_order(value, guard) - def EMPTY_NN_MODULE_HOOKS_DICT(self, guard): + def EMPTY_NN_MODULE_HOOKS_DICT(self, guard: Guard) -> None: """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards""" if config.skip_nnmodule_hook_guards: # This is unsafe if you add/remove a hook on nn module variable return self.SEQUENCE_LENGTH(guard) - def GRAD_MODE(self, guard: Guard): + def GRAD_MODE(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def DETERMINISTIC_ALGORITHMS(self, guard: Guard): + def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def TORCH_FUNCTION_STATE(self, guard: Guard): + def TORCH_FUNCTION_STATE(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def FSDP_TRAINING_STATE(self, guard: Guard): + def FSDP_TRAINING_STATE(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() - def DEFAULT_DEVICE(self, guard: Guard): + def DEFAULT_DEVICE(self, guard: Guard) -> None: """Guard on CURRENT_DEVICE per torch.utils._device""" assert guard.source is GuardSource.GLOBAL + assert self.check_fn_manager.output_graph is not None code = [ f"utils_device.CURRENT_DEVICE == {self.check_fn_manager.output_graph.current_device!r}" ] @@ -2284,9 +2337,10 @@ def DEFAULT_DEVICE(self, guard: Guard): get_verbose_code_parts(code, guard) ) - def SHAPE_ENV(self, guard: Guard): + def SHAPE_ENV(self, guard: Guard) -> None: assert guard.name == "" output_graph = self.check_fn_manager.output_graph + assert output_graph is not None if self.serialization_mode == "load": assert self.check_fn_manager.shape_code_parts is not None shape_code_parts = self.check_fn_manager.shape_code_parts @@ -2303,7 +2357,7 @@ def SHAPE_ENV(self, guard: Guard): fs = output_graph.tracked_fakes input_contexts = [a.symbolic_context for a in fs] - def get_sources(t_id, dim): + def get_sources(t_id: int, dim: int) -> list[Source]: # Looks up base sources mapped to a tensor id and uses them to create # sources for the corresponding tensor dimension. return [ @@ -2311,6 +2365,7 @@ def get_sources(t_id, dim): for source in output_graph.tracked_fakes_id_to_source[t_id] ] + assert output_graph.shape_env is not None if output_graph.export_constraints: names: dict[str, tuple[int, int]] = {} source_pairs: list[tuple[Source, Source]] = [] @@ -2319,7 +2374,7 @@ def get_sources(t_id, dim): ] = [] phantom_symbols: dict[str, Symbol] = {} relaxed_sources: set[Source] = set() - for constraint in output_graph.export_constraints: + for constraint in output_graph.export_constraints: # type: ignore[attr-defined] if constraint.t_id in output_graph.tracked_fakes_id_to_source: torch.export.dynamic_shapes._process_equalities( constraint, @@ -2343,15 +2398,15 @@ def get_sources(t_id, dim): else: equalities_inputs = None - def _get_code_parts(langs): + def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: return output_graph.shape_env.produce_guards_verbose( - [a.fake for a in fs], + [a.fake for a in fs], # type: ignore[misc] [a.source for a in fs], - input_contexts=input_contexts, + input_contexts=input_contexts, # type: ignore[arg-type] equalities_inputs=equalities_inputs, source_ref=self.source_ref, # Export keeps static. - ignore_static=(not self.check_fn_manager.output_graph.export), + ignore_static=(not output_graph.export), langs=langs, ) @@ -2359,7 +2414,7 @@ def _get_code_parts(langs): try: # For exporting we need the python code parts python_code_parts, verbose_code_parts, cpp_code_parts = ( - _get_code_parts(("python", "verbose_python", "cpp")) + _get_code_parts(("python", "verbose_python", "cpp")) # type: ignore[assignment] ) python_fallback = False except OverflowError: @@ -2376,7 +2431,7 @@ def _get_code_parts(langs): # When exporting, we may work with the shape constraints some more in # postprocessing, so don't freeze yet - if not self.check_fn_manager.output_graph.export: + if not output_graph.export: output_graph.shape_env.freeze() if self.serialization_mode == "save": @@ -2520,7 +2575,7 @@ def _get_code_parts(langs): closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, ) - def TENSOR_MATCH(self, guard: Guard, value=None): + def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module(): return # For tensors that are part of the Dynamo extracted Fx graph module, an @@ -2573,6 +2628,7 @@ def TENSOR_MATCH(self, guard: Guard, value=None): # The list of tensor fields and calls we care about can be found in `terms` below. # TODO(voz): We are missing storage offset in all our tensor guards? code: list[str] = [] + assert self.check_fn_manager.output_graph is not None if self.check_fn_manager.output_graph.export: self.TYPE_MATCH(guard) terms = [ @@ -2624,7 +2680,12 @@ def TENSOR_MATCH(self, guard: Guard, value=None): verbose_code_parts = get_verbose_code_parts( get_tensor_guard_code_part( - value, tensor_name, size, stride, pytype, dispatch_keys + value, + tensor_name, + size, + stride, + pytype, + dispatch_keys, # type: ignore[arg-type] ), guard, ) @@ -2700,8 +2761,12 @@ def TENSOR_MATCH(self, guard: Guard, value=None): # A util that in the case of export, adds data onto guards def _set_guard_export_info( - self, guard, code_list, provided_guarded_object=None, provided_func_name=None - ): + self, + guard: Guard, + code_list: list[str], + provided_guarded_object: Optional[Any] = None, + provided_func_name: Optional[str] = None, + ) -> None: # WARNING: It is important that cur_frame/caller do NOT stay in # the current frame, because they will keep things live longer # than they should. See TestMisc.test_release_module_memory @@ -2779,7 +2844,7 @@ class ExprCounter(ast.NodeVisitor): def __init__(self, config: PyExprCSEPass.Config) -> None: self._config = config - def visit(self, node: ast.AST) -> Any: + def visit(self, node: ast.AST) -> None: if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): self._config.expr_count[_ast_unparse(node)] += 1 super().visit(node) @@ -2847,7 +2912,7 @@ def replace(self, expr: str) -> tuple[list[str], str]: return replacer.preface, _ast_unparse(new_node) -def must_add_nn_module_guards(guard): +def must_add_nn_module_guards(guard: Guard) -> bool: # For config.guard_nn_modules=False, we can skip all the guards that # originate from inside of nn module except for a few categories. return ( @@ -2862,11 +2927,11 @@ def must_add_nn_module_guards(guard): class DeletedGuardManagerWrapper(GuardManagerWrapper): - def __init__(self, reason): + def __init__(self, reason: str) -> None: super().__init__() self.invalidation_reason = reason - def populate_diff_guard_manager(self): + def populate_diff_guard_manager(self) -> None: self.diff_guard_root = None @@ -2881,7 +2946,7 @@ class ShapeCodeParts: @dataclasses.dataclass class GuardsState: - output_graph: OutputGraphGuardsState + output_graph: OutputGraph shape_code_parts: Optional[ShapeCodeParts] @@ -2890,19 +2955,26 @@ class _Missing: class GuardsStatePickler(pickle.Pickler): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.fake_mode = torch._subclasses.FakeTensorMode() self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() @classmethod - def _unpickle_module(cls, state): + def _unpickle_module(cls, state: Any) -> torch.nn.Module: mod = torch.nn.Module() mod.__setstate__(state) return mod @classmethod - def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw, grad): + def _unpickle_tensor( + cls, + meta_tensor: torch.Tensor, + device: torch.device, + pytype: type, + dispatch_keys_raw: int, + grad: torch.Tensor, + ) -> torch.Tensor: fake_mode = torch._subclasses.FakeTensorMode() tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() ret = tensor_converter.from_meta_and_device( @@ -2917,15 +2989,21 @@ def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw, grad): @classmethod def _unpickle_traceable_wrapper_subclass( - cls, meta_tensor, device, pytype, dispatch_keys_raw, ctx, inner_data - ): + cls, + meta_tensor: torch.Tensor, + device: torch.device, + pytype: type, + dispatch_keys_raw: int, + ctx: Any, + inner_data: list[tuple[str, Callable[..., Any], tuple[Any, ...]]], + ) -> torch.Tensor: # Unpickle the inner tensor components. These could also be subclass instances. inner_tensors = {} for attr, unpickle_func, unpickle_func_args in inner_data: inner_tensors[attr] = unpickle_func(*unpickle_func_args) outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride() - out = type(meta_tensor).__tensor_unflatten__( + out = type(meta_tensor).__tensor_unflatten__( # type: ignore[attr-defined] inner_tensors, ctx, outer_size, outer_stride ) out.pytype = pytype @@ -2933,26 +3011,32 @@ def _unpickle_traceable_wrapper_subclass( return out @classmethod - def _unpickle_python_module(cls, alias: str): + def _unpickle_python_module(cls, alias: str) -> types.ModuleType: return importlib.import_module(alias) @classmethod - def _unpickle_dispatch_key_set(cls, raw_repr: int): + def _unpickle_dispatch_key_set(cls, raw_repr: int) -> torch._C.DispatchKeySet: return torch._C.DispatchKeySet.from_raw_repr(raw_repr) @classmethod - def _unpickle_functorch_interpreter(cls, json: bytes): + def _unpickle_functorch_interpreter( + cls, json: bytes + ) -> torch._C._functorch.CInterpreter: return torch._C._functorch.CInterpreter.deserialize(json) @classmethod - def _unpickle_mapping_proxy(cls, d): + def _unpickle_mapping_proxy( + cls, d: dict[Any, Any] + ) -> types.MappingProxyType[Any, Any]: return types.MappingProxyType(d) @classmethod - def _unpickle_c_op(cls, name): + def _unpickle_c_op(cls, name: str) -> Any: return getattr(torch.ops._C, name) - def reducer_override(self, obj): + def reducer_override( + self, obj: Any + ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]: import sympy if isinstance(obj, torch.Tensor) and obj.device.type != "meta": @@ -3065,9 +3149,9 @@ def pickle_guards_state(state: GuardsState) -> bytes: class CheckFunctionManager: def __init__( self, - f_code, - output_graph=None, - cache_entry=None, + f_code: types.CodeType, + output_graph: Optional[OutputGraph] = None, + cache_entry: Optional[CacheEntry] = None, guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, guard_filter_fn: Optional[ Callable[[list[GuardFilterEntry]], list[bool]] @@ -3110,7 +3194,7 @@ def __init__( ): _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs]) - def guard_filter_fn(guards): + def guard_filter_fn(guards: list[GuardFilterEntry]) -> list[bool]: ret = [] for keep, g in zip(_guard_filter_fn(guards), guards): if not keep: @@ -3130,6 +3214,7 @@ def guard_filter_fn(guards): return ret sorted_guards = sorted(guards or (), key=Guard.sort_key) + assert output_graph is not None builder, guard_manager = self.build_guards( sorted_guards, existing_diff_guard_sources, @@ -3140,7 +3225,7 @@ def guard_filter_fn(guards): if guard_filter_fn: - def make_guard_filter_entry(guard): + def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: MISSING = object() name = strip_local_scope(guard.name) if name == "": @@ -3160,15 +3245,15 @@ def make_guard_filter_entry(guard): is_global = get_global_source_name(guard.originating_source) is not None guard_fn = guard.create_fn if isinstance(guard_fn, functools.partial): - guard_fn = guard.create_fn.func + guard_fn = guard.create_fn.func # type: ignore[attr-defined] return GuardFilterEntry( name=name, has_value=has_value, value=value, guard_type=guard_fn.__name__, - derived_guard_types=tuple(guard.guard_types) - if guard.guard_types - else (), + derived_guard_types=( + tuple(guard.guard_types) if guard.guard_types else () + ), is_global=is_global, orig_guard=guard, ) @@ -3214,7 +3299,7 @@ def make_guard_filter_entry(guard): if not output_graph.export and self.guards_serialization_mode != "load": if not self.guard_manager.check(output_graph.local_scope): reasons = get_guard_fail_reason_helper( - self.guard_manager, # type: ignore[arg-type] + self.guard_manager, output_graph.local_scope, CompileContext.current_compile_id(), ) @@ -3247,12 +3332,13 @@ def make_guard_filter_entry(guard): CompileEventLogger.increment_toplevel("guard_latency_us", int(latency)) self.guards_state: Optional[bytes] = None + assert self.output_graph is not None builtins_dict_name = self.output_graph.name_of_builtins_dict_key_in_fglobals if self.guards_serialization_mode == "save": used_global_vars = set() used_local_vars = set() - def prune_variable(source): + def prune_variable(source: Source) -> None: if name := get_global_source_name(source): assert isinstance(name, str) # Leave out the builtins dict key, as we will special handle @@ -3277,10 +3363,10 @@ def prune_variable(source): for source in self.output_graph.guard_on_key_order: prune_variable(source) - def normalize_create_fn(x): + def normalize_create_fn(x: Any) -> Any: if isinstance(x, functools.partial): - def _ref(x): + def _ref(x: Any) -> Any: if isinstance(x, (TensorWeakRef, weakref.ref)): return x() return x @@ -3300,7 +3386,7 @@ def _ref(x): k: v for k, v in output_graph_guards_state.global_scope[ builtins_dict_name - ].items() + ].items() # type: ignore[attr-defined] if k in self.used_builtin_vars } output_graph_guards_state = dataclasses.replace( @@ -3328,7 +3414,7 @@ def _ref(x): ), ) guards_state = GuardsState( - output_graph=output_graph_guards_state, + output_graph=output_graph_guards_state, # type: ignore[arg-type] shape_code_parts=self.shape_code_parts, ) self.guards_state = pickle_guards_state(guards_state) @@ -3351,18 +3437,18 @@ def _ref(x): def build_guards( self, - sorted_guards, - existing_diff_guard_sources, - f_code, - output_graph, - serialization_mode=None, - ): + sorted_guards: list[Guard], + existing_diff_guard_sources: OrderedSet[str], + f_code: types.CodeType, + output_graph: OutputGraph, + serialization_mode: Optional[str] = None, + ) -> tuple[GuardBuilder, GuardManagerWrapper]: guard_manager = GuardManagerWrapper() guard_manager.diff_guard_sources = existing_diff_guard_sources w_builder = None - def source_ref(source): + def source_ref(source: Source) -> str: guard_source = source.guard_source() if guard_source is GuardSource.CONSTANT: # No need to track constants @@ -3386,10 +3472,10 @@ def source_ref(source): ) # Break retain cycle. See test_release_scope_memory - def cleanup_builder(weak_b): + def cleanup_builder(weak_b: weakref.ref[GuardBuilder]) -> None: b = weak_b() if b: - b.scope = None + b.scope = None # type: ignore[assignment] # Break retain cycle. See test_release_input_memory w_builder = weakref.ref(builder, cleanup_builder) @@ -3413,7 +3499,12 @@ def cleanup_builder(weak_b): guard.create(builder) return builder, guard_manager - def compile_check_fn(self, builder, guards_out, guard_fail_fn): + def compile_check_fn( + self, + builder: GuardBuilder, + guards_out: list[Guard], + guard_fail_fn: Optional[Callable[[GuardFail], None]], + ) -> None: # see parallel handling of ".0" / "___implicit0" in _eval_frame.c largs = builder.argnames largs += ["**___kwargs_ignored"] @@ -3424,6 +3515,7 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): verbose_code_parts = [] structured_guard_fns: list[Callable[[], dict[str, Any]]] = [] + assert self.torch_function_mode_stack is not None torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard( self.torch_function_mode_stack ) @@ -3447,7 +3539,9 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): # Clear references to torch_function modes held in the list self.torch_function_mode_stack = None - def add_code_part(code_part, guard, log_only=False): + def add_code_part( + code_part: str, guard: Optional[Guard], log_only: bool = False + ) -> None: verbose_code_part = get_verbose_code_part(code_part, guard) guards_log.debug("%s", verbose_code_part) @@ -3617,7 +3711,7 @@ def add_code_part(code_part, guard, log_only=False): self.guard_manager.extra_state = None self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names - def invalidate(self, obj_str): + def invalidate(self, obj_str: str) -> None: # Some tests reveal that CheckFunctionManager has no attribute # guard_manager, but this case should not be of any concern. # This case doesn't seem easy to repro. @@ -3634,7 +3728,7 @@ def invalidate(self, obj_str): extra_state.invalidate(cache_entry, deleted_guard_manager) self.guard_manager = deleted_guard_manager - def id_ref(self, obj, obj_str): + def id_ref(self, obj: object, obj_str: str) -> int: """add a weakref, return the id""" try: if id(obj) not in self._weakrefs: @@ -3649,14 +3743,14 @@ def id_ref(self, obj, obj_str): pass # cannot weakref bool object return id(obj) - def lookup_weakrefs(self, obj): + def lookup_weakrefs(self, obj: object) -> Optional[weakref.ref[object]]: """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects""" if id(obj) in self._weakrefs: return self._weakrefs[id(obj)] return None -def build_guard_function(code_parts, closure_args) -> tuple[str, str]: +def build_guard_function(code_parts: list[str], closure_args: str) -> tuple[str, str]: from torch._inductor.utils import IndentedBuffer csepass = PyExprCSEPass() @@ -3665,6 +3759,7 @@ def build_guard_function(code_parts, closure_args) -> tuple[str, str]: def replace(expr: str) -> tuple[list[str], str]: return csepass.replace(expr) + except RecursionError: # If we hit recursion limits during CSE analysis, fall back to a no-op replace function # This can happen with extremely complex guard expressions @@ -3699,19 +3794,21 @@ def replace(expr: str) -> tuple[list[str], str]: return guard_body.getvalue(), make_guard_fn.getvalue() -def is_recompiles_enabled(): +def is_recompiles_enabled() -> bool: return torch._logging._internal.log_state.is_artifact_enabled("recompiles") -def is_recompiles_verbose_enabled(): +def is_recompiles_verbose_enabled() -> bool: return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") # this will only be used if cpp guards are disabled -def make_torch_function_mode_stack_guard(initial_stack): +def make_torch_function_mode_stack_guard( + initial_stack: list[torch.overrides.TorchFunctionMode], +) -> Callable[[], bool]: types = [type(x) for x in initial_stack] - def check_torch_function_mode_stack(): + def check_torch_function_mode_stack() -> bool: cur_stack = get_torch_function_mode_stack() if len(cur_stack) != len(types): @@ -3726,10 +3823,16 @@ def check_torch_function_mode_stack(): return check_torch_function_mode_stack -def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope): +Scope = TypeAliasType("Scope", dict[str, object]) + + +def recompilation_reason_for_no_tensor_aliasing_guard( + guard_manager: GuardManagerWrapper, scope: Scope +) -> list[str]: + assert guard_manager.global_scope is not None global_scope = dict(guard_manager.global_scope) ids_to_source = collections.defaultdict(list) - for tensor_source in guard_manager.no_tensor_aliasing_sources: # type: ignore[attr-defined] + for tensor_source in guard_manager.no_tensor_aliasing_sources: global_scope["__compile_source__"] = tensor_source tensor_id = id(eval(tensor_source, global_scope, scope)) ids_to_source[tensor_id].append(tensor_source) @@ -3756,7 +3859,7 @@ def strip_local_scope(s: str) -> str: def get_guard_fail_reason_helper( - guard_manager: GuardFn, + guard_manager: GuardManagerWrapper, f_locals: dict[str, object], compile_id: Optional[CompileId], ) -> str: @@ -3765,6 +3868,8 @@ def get_guard_fail_reason_helper( Updates `guard_failures` with the generated reason. Only the first failed check of guard_manager is reported. """ + assert guard_manager.global_scope is not None + assert guard_manager.closure_vars is not None scope = {"L": f_locals, "G": guard_manager.global_scope["G"]} scope.update(guard_manager.closure_vars) reasons: list[str] = [] @@ -3772,7 +3877,7 @@ def get_guard_fail_reason_helper( no_tensor_aliasing_check_failed = False verbose_code_parts: list[str] = [] - guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined] + guard_debug_info = guard_manager.check_verbose(f_locals) # For test_export_with_map_cond, the check_verbose fail even without the # C++ guard manager. We need to fix the issue to remove the comment. # assert not guard_debug_info.result @@ -3823,7 +3928,7 @@ def get_guard_fail_reason_helper( def get_guard_fail_reason( - guard_manager: GuardFn, + guard_manager: GuardManagerWrapper, code: types.CodeType, f_locals: dict[str, object], compile_id: CompileId, @@ -3847,7 +3952,7 @@ def get_guard_fail_reason( def get_and_maybe_log_recompilation_reasons( - cache_entry, frame: DynamoFrameType + cache_entry: Optional[CacheEntry], frame: DynamoFrameType ) -> list[str]: """ Return the list of guard failure reasons using cache_entry. @@ -3906,18 +4011,20 @@ def get_and_maybe_log_recompilation_reasons( return reasons -def update_diff_guard_managers_for_existing_cache_entries(cache_entry): +def update_diff_guard_managers_for_existing_cache_entries( + cache_entry: Optional[CacheEntry], +) -> OrderedSet[str]: first_cache_entry = cache_entry # On the first pass, go through the cache entries and accumulate the diff # guard sources. Different guard managers can fail with different sources. # So, we collect all of them first. - acc_diff_guard_sources = set() + acc_diff_guard_sources: OrderedSet[str] = OrderedSet() while cache_entry is not None: acc_diff_guard_sources.update( cache_entry.guard_manager.collect_diff_guard_sources() ) - cache_entry = cache_entry.next + cache_entry = cache_entry.next # type: ignore[assignment] # On the second pass, set the diff_guard_sources for each cache line to the # accumulated value. And the re-populate the diff guard manager. @@ -3925,7 +4032,7 @@ def update_diff_guard_managers_for_existing_cache_entries(cache_entry): while cache_entry is not None: cache_entry.guard_manager.diff_guard_sources = acc_diff_guard_sources cache_entry.guard_manager.populate_diff_guard_manager() - cache_entry = cache_entry.next + cache_entry = cache_entry.next # type: ignore[assignment] # return the accumulated sources to set up the new cache line. return acc_diff_guard_sources @@ -3937,7 +4044,7 @@ def guard_error_hook( f_locals: dict[str, object], index: int, last: bool, -): +) -> None: print( f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" ) @@ -3957,7 +4064,7 @@ def guard_error_hook( set_guard_error_hook(guard_error_hook) -def unique(seq): +def unique(seq: Sequence[T]) -> Generator[T, None, None]: seen = set() for x in seq: if x not in seen: @@ -3965,7 +4072,9 @@ def unique(seq): seen.add(x) -def make_dupe_guard(obj_source, dupe_source): +def make_dupe_guard( + obj_source: Source, dupe_source: Source +) -> Optional[functools.partial[Any]]: # Note - we may end up in a situation where we invoke something like # def fn(x, y) # with fn(x, x) @@ -3999,7 +4108,7 @@ def make_dupe_guard(obj_source, dupe_source): return None -def install_guard(*guards, skip=0): +def install_guard(*guards: Guard, skip: int = 0) -> None: """ Add dynamo guards to the current tracing context. diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index aa8902f05e2b9..caa7b6fef5305 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -31,7 +31,7 @@ import sys import traceback import weakref -from collections.abc import Generator +from collections.abc import Generator, Sequence from dataclasses import dataclass, field as dc_field from types import CodeType from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union @@ -57,6 +57,7 @@ ) from torch._subclasses.fake_tensor import FakeTensor from torch._utils_internal import signpost_event +from torch.export.dynamic_shapes import _ConstraintTarget from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.symbolic_shapes import ( @@ -388,7 +389,7 @@ def __init__( compiler_fn: Optional[CompilerFn], root_tx: "InstructionTranslatorBase", export: bool, - export_constraints: Any, + export_constraints: Sequence[_ConstraintTarget], frame_state: Any, local_scope: Scope, global_scope: Scope, @@ -414,7 +415,7 @@ def __init__( # de-duplicate graph inputs by source and reuse the tracker self.input_source_to_var: dict[Source, VariableTracker] = {} self.export = export - self.export_constraints = export_constraints + self.export_constraints = export_constraints # type: ignore[assignment] self.frame_state = frame_state self.cleanup_hooks: list[Callable[[], Any]] = [] # compile_id is an id number for the current torch.compile diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index c87efa048cec2..f0f1dab4f9c8c 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -206,7 +206,7 @@ def insert_nops(instructions: list[Any], code_options: Any) -> None: compiler_fn=None, root_tx=None, # type: ignore[arg-type] export=False, - export_constraints=None, + export_constraints=[], frame_state={"_id": 0}, # TODO: shouldn't this be f_locals/f_globals from frame? local_scope=locals(), diff --git a/torch/_guards.py b/torch/_guards.py index fa6f9cc1e7bd6..dd2ba47747923 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -267,7 +267,7 @@ class Guard: guard_types: Optional[list[str]] = None code_list: Optional[list[str]] = None obj_weakref: Optional[object] = None - guarded_class_weakref: Optional[type] = None + guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None stack: Optional[CapturedTraceback] = None user_stack: Optional[traceback.StackSummary] = None @@ -380,7 +380,7 @@ def is_local(self) -> bool: def set_export_info( self, guard_type: str, - guarded_class: Optional[type], + guarded_class: Optional[weakref.ReferenceType[Any]], code_list: list[str], obj_weakref: object, ) -> None: From 2507ae63f293354170695fd20a5c5ce5f64e323d Mon Sep 17 00:00:00 2001 From: Xiaochang Wu Date: Wed, 6 Aug 2025 22:12:47 +0000 Subject: [PATCH 0071/1424] Partitioner: Fix to align partition node order with original graph (#157892) Fixes #157891 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157892 Approved by: https://github.com/ezyang --- test/fx/test_partitioner_order.py | 15 ++++++++-- torch/fx/passes/infra/partitioner.py | 44 +++++++++++++++++++++------- torch/fx/passes/utils/fuser_utils.py | 4 +-- 3 files changed, 49 insertions(+), 14 deletions(-) diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index ab50b59fb96b7..f4c3ef072f9a6 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -24,6 +24,7 @@ def __init__(self, graph_module: torch.fx.GraphModule): ) +# original graph node order is: ['x', 'add', 'add_1', 'output'] class AddModule(torch.nn.Module): def forward(self, x): y = torch.add(x, x) @@ -32,8 +33,18 @@ def forward(self, x): class TestPartitionerOrder(TestCase): - # partitoner test to check graph node order - def test_partitioner_order(self): + # partitoner test to check graph node order remains the same with the original graph after partitioning + def test_partitioner_graph_node_order(self): + m = AddModule() + traced_m = torch.fx.symbolic_trace(m) + origin_node_order = [n.name for n in traced_m.graph.nodes] + partions = DummyPartitioner(traced_m).propose_partitions() + partion_nodes = [list(partition.nodes) for partition in partions] + partition_node_order = [n.name for n in partion_nodes[0]] + self.assertTrue(partition_node_order == origin_node_order) + + # partitoner test to check graph node order remains the same during multiple runs + def test_partitioner_multiple_runs_order(self): m = AddModule() traced_m = torch.fx.symbolic_trace(m) partitions = DummyPartitioner(traced_m).propose_partitions() diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 438661090942a..6fc17b959424d 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -18,16 +18,29 @@ class Partition: def __init__( - self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + self, + id: Optional[int] = None, + nodes: Optional[Iterable[Node]] = None, + node_orders: Optional[Iterable[int]] = None, ): self.id = id - self.nodes = dict.fromkeys(nodes) if nodes is not None else {} + self.nodes: dict[Node, Optional[int]] = {} + if nodes is not None: + if node_orders is None: + self.nodes = dict.fromkeys(nodes, None) + else: + nodes_list = list(nodes) + node_orders_list = list(node_orders) + assert len(nodes_list) == len(node_orders_list), ( + "nodes and node_orders must have the same length" + ) + self.nodes = dict(zip(nodes_list, node_orders_list)) def __repr__(self) -> str: return str(self.nodes) - def add_node(self, node: Node): - self.nodes.update({node: None}) + def add_node(self, node: Node, node_order: Optional[int] = None): + self.nodes.update({node: node_order}) def remove_node(self, node: Node): del self.nodes[node] @@ -172,7 +185,7 @@ def dfs_iter_find_cycle(all_user_nodes: set[Node]): return merge_id, True - def merge_single_node(node: Node, id: Optional[int]): + def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]): def _update_partition_map(node: Node, id: int): # Iterate through all the users of this node and update the partition map to indicate # that there is a path from the partition id of this node to the target partition id. @@ -189,16 +202,19 @@ def _update_partition_map(node: Node, id: int): assignment.pop(node) elif id not in partitions_by_id: assignment[node] = id - partitions_by_id[id] = Partition(id=id, nodes=[node]) + assert node_order is not None + partitions_by_id[id] = Partition( + id=id, nodes=[node], node_orders=[node_order] + ) partition_users[id] = set(node.users) _update_partition_map(node, id) else: assignment[node] = id - partitions_by_id[id].add_node(node) + partitions_by_id[id].add_node(node, node_order) logger.debug("Proposing partitions...") - for node in reversed(self.graph_module.graph.nodes): + for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)): # use Dict as an ordered set to ensure deterministic partitioning result, don't care value merge_candidates: dict[int, None] = {} @@ -211,7 +227,7 @@ def _update_partition_map(node: Node, id: int): partition_id = next(new_partition_id) nodes_order[node] = partition_id partitions_order[partition_id] = partition_id - merge_single_node(node, partition_id) + merge_single_node(node, node_order, partition_id) merge_candidates[partition_id] = None # merge all possible partitions @@ -228,6 +244,14 @@ def _update_partition_map(node: Node, id: int): # in the graph, otherwise, this is a no-op self_id, _ = maybe_merge_partition(self_id, other_id) + # sort partition nodes based on descending node order + for partition in partitions_by_id.values(): + partition.nodes = dict( + sorted( + partition.nodes.items(), key=operator.itemgetter(1), reverse=True + ) + ) + # post processing to re-assign "getitem" nodes into upstream partition logger.debug("Reassigning getitem nodes to its producer node's partition...") nodes_reassignment: dict[Node, int] = {} @@ -248,7 +272,7 @@ def _update_partition_map(node: Node, id: int): if assignment.get(user, None) != id: # type: ignore[arg-type] nodes_reassignment[user] = id # type: ignore[assignment] for node, id in nodes_reassignment.items(): - merge_single_node(node, id) + merge_single_node(node, None, id) # filter out single node partitions if not self.allows_single_node_partition: diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 1b22490405de5..33db9fd03d790 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -96,7 +96,7 @@ def fuse_as_graphmodule( gm: GraphModule, nodes: NodeList, module_name: str, - partition_lookup_table: _Optional[dict[Node, None]] = None, + partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None, *, always_return_tuple: bool = False, ) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: @@ -249,7 +249,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: @compatibility(is_backward_compatible=False) def fuse_by_partitions( gm: GraphModule, - partitions: list[dict[Node, None]], + partitions: list[dict[Node, _Optional[int]]], prefix: str = "fused_", always_return_tuple: bool = False, ) -> GraphModule: From 9fd5b5f73589cf08dca60910368cc0f05c7906c8 Mon Sep 17 00:00:00 2001 From: Jovian Anthony Jaison Date: Wed, 6 Aug 2025 22:33:04 +0000 Subject: [PATCH 0072/1424] [pytorch] Moving torch.compile worker process logs to a dedicated rank based log directory (#159874) Summary: Writing torch.compile worked logs to dedicated_log_rank{RANK} if we're running on mast. Test Plan: See: D79456310 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159874 Approved by: https://github.com/c00w --- test/inductor/test_compile_worker.py | 15 ++++++++++- .../_inductor/compile_worker/subproc_pool.py | 26 +++++++++++++++---- torch/_inductor/config.py | 3 +++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index dcbf1b380934f..e76bf932d145a 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -1,6 +1,8 @@ # Owner(s): ["module: inductor"] +import importlib import operator import os +import tempfile from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, @@ -11,7 +13,6 @@ from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.inductor_utils import HAS_CPU - class TestCompileWorker(TestCase): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_basic_jobs(self): @@ -66,6 +67,18 @@ def test_quiesce(self): finally: pool.shutdown() + @skipIfWindows(msg="pass_fds not supported on Windows.") + def test_logging(self): + os.environ["MAST_HPC_JOB_NAME"] = "test_job" + os.environ["ROLE_RANK"] = "0" + with tempfile.NamedTemporaryFile(delete=True) as temp_log: + os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name + pool = SubprocPool(2) + try: + pool.submit(operator.add, 100, 1) + self.assertEqual(os.path.exists(temp_log.name), True) + finally: + pool.shutdown() if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 0b670b268b37e..dd8cab8643f1d 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -145,11 +145,24 @@ def __init__( f"--write-fd={str(subproc_write_fd)}", f"--torch-key={torch_key_str}", ] - local = False - if config.worker_suppress_logging: + mast_job_id = os.environ.get("MAST_HPC_JOB_NAME", None) + global_rank = os.environ.get("ROLE_RANK", "0") + worker_log_path = os.environ.get("TORCHINDUCTOR_WORKER_LOGPATH", config.worker_log_path) + stdout_pipe = None + stderr_pipe = None + self.log_file = None + + if mast_job_id is not None: + log_loc = f"{worker_log_path}{global_rank}" + self.log_file = open(log_loc, "w") + elif config.worker_suppress_logging: log.info("Suppressing compile worker output due to config") - local = True + self.log_file = open(os.devnull, "w") + if self.log_file: + stdout_pipe = self.log_file + stderr_pipe = self.log_file + self.process = subprocess.Popen( cmd, env={ @@ -164,9 +177,10 @@ def __init__( "LD_LIBRARY_PATH": get_ld_library_path(), }, pass_fds=(subproc_read_fd, subproc_write_fd), - stdout=subprocess.DEVNULL if local else None, - stderr=subprocess.DEVNULL if local else None, + stdout=stdout_pipe, + stderr=stderr_pipe, ) + self.write_lock = threading.Lock() self.read_thread = threading.Thread( target=self._read_thread, name="InductorSubproc", daemon=True @@ -262,6 +276,8 @@ def shutdown(self) -> None: _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) self.write_pipe.close() self.process.wait(300) + if self.log_file: + self.log_file.close() except OSError as e: log.warning("Ignored OSError in pool shutdown: %s", e) finally: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 51a438840b040..c581a7611862c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -81,6 +81,9 @@ def prologue_fusion_enabled() -> bool: # Whether to enable printing the source code for each future verbose_progress = False +# Configurable compile worker logging path for subproc_pool +worker_log_path = "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None + # precompilation timeout precompilation_timeout_seconds: int = 60 * 60 From 3a2c3c8ed365eb4e4cf4620c25d70b2f70483762 Mon Sep 17 00:00:00 2001 From: christinaburge Date: Wed, 6 Aug 2025 22:41:07 +0000 Subject: [PATCH 0073/1424] unskipped mobilenet_v3 quantization and mobilenet_v2 quantization plus tests from https://github.com/pytorch/pytorch/issues/125438 (#157786) These tests now pass on AArch64 in our downstream CI. `test_quantization.py::TestNumericSuiteEager::test_mobilenet_v2 <- test/quantization/eager/test_numeric_suite_eager.py PASSED [2.4434s] [ 35%]` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157786 Approved by: https://github.com/jerryzh168, https://github.com/malfet --- test/quantization/eager/test_numeric_suite_eager.py | 5 +---- test/test_linalg.py | 2 -- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/test/quantization/eager/test_numeric_suite_eager.py b/test/quantization/eager/test_numeric_suite_eager.py index cd11e96859937..ccffad4b5ab63 100644 --- a/test/quantization/eager/test_numeric_suite_eager.py +++ b/test/quantization/eager/test_numeric_suite_eager.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: quantization"] # ruff: noqa: F841 -import unittest import torch import torch.ao.nn.quantized as nnq @@ -38,7 +37,7 @@ test_only_eval_fn, ) from torch.testing._internal.common_quantized import override_qengines -from torch.testing._internal.common_utils import IS_ARM64, raise_on_run_directly +from torch.testing._internal.common_utils import raise_on_run_directly class SubModule(torch.nn.Module): @@ -600,14 +599,12 @@ def compute_error(x, y): act_compare_dict = get_matching_activations(float_model, qmodel) @skip_if_no_torchvision - @unittest.skipIf(IS_ARM64, "Not working on arm right now") def test_mobilenet_v2(self): from torchvision.models.quantization import mobilenet_v2 self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False)) @skip_if_no_torchvision - @unittest.skipIf(IS_ARM64, "Not working on arm right now") def test_mobilenet_v3(self): from torchvision.models.quantization import mobilenet_v3_large diff --git a/test/test_linalg.py b/test/test_linalg.py index ac668fee049d2..909e8747f1d34 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1401,8 +1401,6 @@ def run_test_case(input_size, ord, keepdim): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16) def test_vector_norm(self, device, dtype): - if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]: - raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438") # have to use torch.randn(...).to(bfloat16) instead of # This test compares torch.linalg.vector_norm's output with # torch.linalg.norm given a flattened tensor From 93da9952a77f59cb29a2d599362ba9c7ba22eaec Mon Sep 17 00:00:00 2001 From: Nathan Brown Date: Wed, 6 Aug 2025 22:56:31 +0000 Subject: [PATCH 0074/1424] gloo: fix building system gloo with CUDA/HIP (#146637) Fix incorrect linking of Gloo's libraries when building with system Gloo. Previously, either Gloo's native library or Gloo's CUDA library were linked. However, Gloo had changed such that all users of Gloo must link the native library, and can optionally link the CUDA or HIP library for Gloo + CUDA/HIP support. This had been updated when building/linking with vendored Gloo, but not when using system Gloo. Fixes: #146239 Reported-by: Adam J Stewart Pull Request resolved: https://github.com/pytorch/pytorch/pull/146637 Approved by: https://github.com/malfet --- cmake/Dependencies.cmake | 11 ++++++++-- cmake/Modules/FindGloo.cmake | 39 +++++++++++++++--------------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 0501e00c08664..b7f545027b02d 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1235,10 +1235,17 @@ if(USE_GLOO) if(NOT Gloo_FOUND) message(FATAL_ERROR "Cannot find gloo") endif() - message("Found gloo: ${Gloo_LIBRARY}") + message("Found gloo: ${Gloo_NATIVE_LIBRARY}, cuda lib: ${Gloo_CUDA_LIBRARY}, hip lib: ${Gloo_HIP_LIBRARY}") message("Found gloo include directories: ${Gloo_INCLUDE_DIRS}") add_library(gloo SHARED IMPORTED) - set_target_properties(gloo PROPERTIES IMPORTED_LOCATION ${Gloo_LIBRARY}) + set_target_properties(gloo PROPERTIES IMPORTED_LOCATION ${Gloo_NATIVE_LIBRARY}) + if(USE_CUDA) + add_library(gloo_cuda SHARED IMPORTED) + set_target_properties(gloo_cuda PROPERTIES IMPORTED_LOCATION ${Gloo_CUDA_LIBRARY}) + elseif(USE_ROCM) + add_library(gloo_hip SHARED IMPORTED) + set_target_properties(gloo_hip PROPERTIES IMPORTED_LOCATION ${Gloo_HIP_LIBRARY}) + endif() # need to use Gloo_INCLUDE_DIRS over third_party/gloo to find Gloo's auto-generated config.h include_directories(BEFORE SYSTEM ${Gloo_INCLUDE_DIRS}) endif() diff --git a/cmake/Modules/FindGloo.cmake b/cmake/Modules/FindGloo.cmake index e965326e2e8a0..944cd4d8d2573 100644 --- a/cmake/Modules/FindGloo.cmake +++ b/cmake/Modules/FindGloo.cmake @@ -1,7 +1,8 @@ # Try to find the Gloo library and headers. # Gloo_FOUND - system has Gloo lib # Gloo_INCLUDE_DIRS - the Gloo include directory -# Gloo_LIBRARY/Gloo_NATIVE_LIBRARY - libraries needed to use Gloo +# Gloo_NATIVE_LIBRARY - base gloo library, needs to be linked +# Gloo_CUDA_LIBRARY/Gloo_HIP_LIBRARY - CUDA/HIP support library in Gloo find_path(Gloo_INCLUDE_DIR NAMES gloo/common/common.h @@ -10,40 +11,32 @@ find_path(Gloo_INCLUDE_DIR find_library(Gloo_NATIVE_LIBRARY NAMES gloo - DOC "The Gloo library (without CUDA)" + DOC "The Gloo library" ) +# Gloo has optional CUDA support +# if Gloo + CUDA is desired, Gloo_CUDA_LIBRARY +# needs to be linked into desired target find_library(Gloo_CUDA_LIBRARY NAMES gloo_cuda - DOC "The Gloo library (with CUDA)" + DOC "Gloo's CUDA support/code" +) + +# Gloo has optional HIP support +# if Gloo + HIP is desired, Gloo_HIP_LIBRARY +# needs to be linked to desired target +find_library(Gloo_HIP_LIBRARY + NAMES gloo_hiop + DOC "Gloo's HIP support/code" ) set(Gloo_INCLUDE_DIRS ${Gloo_INCLUDE_DIR}) -# use the CUDA library depending on the Gloo_USE_CUDA variable -if (DEFINED Gloo_USE_CUDA) - if (${Gloo_USE_CUDA}) - set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - else() - set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - endif() -else() - # else try to use the CUDA library if found - if (${Gloo_CUDA_LIBRARY} STREQUAL "Gloo_CUDA_LIBRARY-NOTFOUND") - set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - else() - set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY}) - set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) - endif() -endif() include(FindPackageHandleStandardArgs) find_package_handle_standard_args(Gloo FOUND_VAR Gloo_FOUND - REQUIRED_VARS Gloo_INCLUDE_DIR Gloo_LIBRARY + REQUIRED_VARS Gloo_INCLUDE_DIR Gloo_NATIVE_LIBRARY ) mark_as_advanced(Gloo_FOUND) From 64dc30c2139f607b2e9c11ca299e8f92f3ead7ff Mon Sep 17 00:00:00 2001 From: Thomas Bohnstingl Date: Wed, 6 Aug 2025 23:02:42 +0000 Subject: [PATCH 0075/1424] [HOP, map] Rework of map autograd to the new interface (#153343) This PR reworks the current autograd implementation of map to the new interface. @pytorchbot label "topic: not user facing" Pull Request resolved: https://github.com/pytorch/pytorch/pull/153343 Approved by: https://github.com/ydwu4 --- torch/_dynamo/variables/higher_order_ops.py | 1 - torch/_higher_order_ops/cond.py | 58 +------ torch/_higher_order_ops/map.py | 166 +++++++++----------- torch/_higher_order_ops/scan.py | 12 +- torch/_higher_order_ops/utils.py | 64 ++++++++ 5 files changed, 140 insertions(+), 161 deletions(-) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index cdaf1e9e52ccc..8c0730907a4d5 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -3415,7 +3415,6 @@ def _call_function( _hop_name_to_variable_class = { "cond": CondHigherOrderVariable, "while_loop": WhileLoopHigherOrderVariable, - "map": MapHigherOrderVariable, "map_impl": MapHigherOrderVariable, "executorch_call_delegate": ExecutorchCallDelegateHigherOrderVariable, "out_dtype": OutDtypeHigherOrderVariable, diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 648d41b0b95a6..10f6ca9f386c5 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -6,7 +6,6 @@ from typing import Any, Callable, Optional, Union import torch -import torch._subclasses.functional_tensor import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._C._functorch import ( @@ -19,6 +18,7 @@ from torch._higher_order_ops.utils import ( _maybe_run_with_interpreter, _set_compilation_env, + create_bw_fn, materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, @@ -36,8 +36,6 @@ ) from torch.utils._python_dispatch import _get_current_dispatch_mode -from .utils import clone_outputs_aliasing_inputs - log = logging.getLogger(__name__) @@ -201,60 +199,6 @@ def _cond_op_wrapper(*args, **kwargs): ) -def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable: - """ - For a fn that accepts flat inputs and returns flat outputs: - fw_out = fn(*args), - this function returns: - grad_args = bw_fn(*args_and_grad_output) - with the following invariants: - 1. args + fw_out has an 1-1 correspondence to args_and_grad_output - 2. grad_args has an 1-1 corresponsence to args - 3. for tensor arg whose requires_grad is False, its corresponding grad in - grad_args will be a zero tensor with the same shape. - """ - - from torch._functorch.aot_autograd import AOTConfig, create_joint - from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad - - dummy_aot_config = AOTConfig( - fw_compiler=None, # type: ignore[arg-type] - bw_compiler=None, # type: ignore[arg-type] - partition_fn=None, # type: ignore[arg-type] - decompositions={}, - num_params_buffers=0, - aot_id=0, - keep_inference_input_mutations=False, - ) - n_primals = len(args) - - bw_fn = create_joint( - prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config - ) - - def flat_fn(*args_and_grad_outs): - primals = args_and_grad_outs[:n_primals] - tangents = args_and_grad_outs[n_primals:] - grad_args = bw_fn(primals, tangents)[1] - assert len(args) == len(grad_args) - # In order to keep HOPs functional where the backward graph, - # would have outputs that are aliasing inputs. - # For example in cases where the backward of the function is simply - # passing the upstream gradients through. - maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs) - - return [ - ( - torch.zeros_like(arg) - if isinstance(arg, torch.Tensor) and grad is None - else maybe_clone(grad) - ) - for grad, arg in zip(grad_args, primals) - ] - - return flat_fn - - def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): assert isinstance(operands, (list, tuple)), ( f"Cond operands must be a list or tuple of tensors and SymInts {operands}" diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 9f73df7ef478a..332bde7e464f2 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -13,7 +13,6 @@ from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, - make_fx, ProxyTorchDispatchMode, track_tensor_tree, ) @@ -22,10 +21,11 @@ _from_fun, _stack_pytree, _unstack_pytree, - clone_outputs_aliasing_inputs, - prepare_fw_with_masks, + create_bw_fn, + materialize_as_graph, save_tensors_and_symints_for_backward, saved_tensors_and_symints, + split_into_chunks, ) @@ -40,77 +40,6 @@ def __call__(self, *args, **kwargs): map_impl = MapImpl() -def create_fw_bw_graph(f, num_mapped_args, *args): - mapped_xs = args[:num_mapped_args] - pos_args = args[num_mapped_args:] - - # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py - - with suspend_functionalization(), disable_functional_mode(): - with disable_proxy_modes_tracing(): - unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs) - example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] - - example_pos_args = [ - _from_fun(arg) if isinstance(arg, torch.Tensor) else arg - for arg in pos_args - ] - example_flat_out = pytree.tree_map( - _from_fun, f(*example_xs, *example_pos_args) - ) - if any( - not isinstance(out, torch.Tensor) - for out in example_flat_out - if out is not None - ): - raise RuntimeError( - "Expect outputs of map only contains tensors or None. " - f"Got types {[type(out) for out in example_flat_out]}." - ) - example_grad = [_from_fun(out) for out in example_flat_out] - - fw_graph = make_fx(f)(*example_xs, *example_pos_args) - - from torch._functorch.aot_autograd import AOTConfig, create_joint - - dummy_aot_config = AOTConfig( - fw_compiler=None, # type: ignore[arg-type] - bw_compiler=None, # type: ignore[arg-type] - partition_fn=None, # type: ignore[arg-type] - decompositions={}, - num_params_buffers=0, - aot_id=0, - keep_inference_input_mutations=False, - ) - - def joint_f(*example_args): - joint_mapped_args = example_args[:joint_num_mapped] - args = example_args[joint_num_mapped:] - - mapped_input = joint_mapped_args[:num_mapped_args] - mapped_grads = joint_mapped_args[num_mapped_args:] - - joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config) - _, grads = joint( - list(mapped_input) + list(args), - [ - grad - for grad in mapped_grads - if grad is not None and grad.requires_grad - ], - ) - - # In order to keep map functional for backward graph, - # we clone outputs that are aliasing inputs - maybe_clone = clone_outputs_aliasing_inputs(example_args) - - return pytree.tree_map(maybe_clone, grads) - - joint_num_mapped = len(example_grad) + len(example_xs) - joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args) - return fw_graph, joint_graph - - def map( f: Callable[[pytree.PyTree, tuple[pytree.PyTree, ...]], pytree.PyTree], xs: Union[pytree.PyTree, torch.Tensor], @@ -193,36 +122,88 @@ def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs): class MapAutogradOp(torch.autograd.Function): @staticmethod - def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): - save_tensors_and_symints_for_backward(ctx, flat_args) - ctx._joint_graph = joint_graph + def forward(ctx, f, num_mapped_args, *flat_args): + ctx._f = f ctx._num_mapped_args = num_mapped_args + ctx._num_pos_args = len(flat_args) - num_mapped_args + + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + save_tensors_and_symints_for_backward(ctx, flat_args) with torch._C._AutoDispatchBelowAutograd(): return ( - *map_impl( - fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:] - ), + *map_impl(f, flat_args[:num_mapped_args], flat_args[num_mapped_args:]), ) @staticmethod def backward(ctx, *flat_grads): fw_args = saved_tensors_and_symints(ctx) - fw_mapped_args = fw_args[: ctx._num_mapped_args] - pos_args = fw_args[ctx._num_mapped_args :] - - grads = map_impl( - ctx._joint_graph, - fw_mapped_args + flat_grads, - pos_args, + num_mapped_args = ctx._num_mapped_args + num_pos_args = ctx._num_pos_args + num_grads = len(flat_grads) + + fw_mapped_args, pos_args = split_into_chunks( + fw_args, + [ + num_mapped_args, + num_pos_args, + ], ) - return None, None, None, *grads + + bw_f = create_bw_fn(ctx._f, fw_args) + + # Create a wrapper around thefor the bw_f + def bw_f_wrapper(*args): + # Dissect args and re-order them for the ``ctx._bw_f`` + # args provided to the wrapper are composed of [*fw_mapped_args, *flat_grads, *pos_args] + # The content of ``bw_f_tangents`` are the upstream gradients, i.e. flat_grads + # The content of ``bw_f_primals`` are the fw_args, i.e., [*fw_mapped_args, *pos_args] + # The bw_f requires *bw_f_primals, *bw_f_tangents + fw_m_args, bw_f_tangents, pos_args = split_into_chunks( + args, [num_mapped_args, num_grads, num_pos_args] + ) + bw_f_primals = *fw_m_args, *pos_args + return bw_f(*bw_f_primals, *bw_f_tangents) + + def construct_args_single_step_bw(): + unwrapped_mapped_xs = pytree.tree_map(_from_fun, fw_mapped_args) + example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] + unwrapped_grads = pytree.tree_map(_from_fun, flat_grads) + example_grads = _unstack_pytree(unwrapped_grads)[0] + example_pos_args = [ + _from_fun(arg) if isinstance(arg, torch.Tensor) else arg + for arg in pos_args + ] + return *example_xs, *example_grads, *example_pos_args + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + args_single_step_bw = construct_args_single_step_bw() + + # TODO: we need to materialize the bw graphs because dynamo is unable to + # trace through the joint function when torch.compile torch.autograd.grad. + fn_bw_gm = materialize_as_graph( + bw_f_wrapper, + args_single_step_bw, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + force_enable_grad=True, + ) + + grads = map_impl(fn_bw_gm, fw_mapped_args + flat_grads, pos_args) + + return None, None, *grads def trace_map(proxy_mode, func_overload, f, xs, pos_args): - example_input = _unstack_pytree(xs)[0] - body_graph = f + with disable_proxy_modes_tracing(): + example_input = _unstack_pytree(xs)[0] + + body_graph = f - body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) + body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_") @@ -249,8 +230,7 @@ def map_dense(f, xs, pos_args): @map_impl.py_autograd_impl def map_autograd(f, xs, pos_args): num_mapped_args = len(xs) - fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) - flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args) + flat_out = MapAutogradOp.apply(f, num_mapped_args, *xs, *pos_args) return flat_out diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index 3cd5bf9ec4e22..4e636b396b38b 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -1,22 +1,22 @@ # mypy: allow-untyped-defs import functools import itertools -from collections.abc import Sequence from typing import Any, Callable, Optional import torch import torch._prims_common as utils import torch.utils._pytree as pytree from torch._C import DispatchKey -from torch._higher_order_ops.cond import create_bw_fn from torch._higher_order_ops.utils import ( _maybe_compile_and_run_fn, check_meta_consistency, + create_bw_fn, first_slice_copy, materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, saved_tensors_and_symints, + split_into_chunks, unique_graph_id, validate_subgraph_args_types, ) @@ -95,14 +95,6 @@ def first_slice_copy_with_grad(li: list[Any]) -> list[Any]: return slc -def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: - it = iter(iterable) - assert sum(chunk_sizes) == len(iterable), ( - "the sum of all chunks needs to match the length of the iterable." - ) - return [list(itertools.islice(it, size)) for size in chunk_sizes] - - def call_operator(operator, *args): return pytree.tree_leaves(operator(*args)) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 25ef972864d58..ab0fc4e654c60 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import contextlib import functools +from collections.abc import Sequence from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import Any, Callable, Optional, overload, TypeVar, Union @@ -722,6 +723,69 @@ def saved_tensors_and_symints(ctx): return tuple(args) +def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: + assert sum(chunk_sizes) == len(iterable), ( + "the sum of all chunks needs to match the length of the iterable." + ) + elements = [] + idx = 0 + for size in chunk_sizes: + elements.append(iterable[idx : idx + size]) + idx += size + return elements + + +def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable: + """ + For a fn that accepts flat inputs and returns flat outputs: + fw_out = fn(*args), + this function returns: + grad_args = bw_fn(*args_and_grad_output) + with the following invariants: + 1. args + fw_out has an 1-1 correspondence to args_and_grad_output + 2. grad_args has an 1-1 corresponsence to args + 3. for tensor arg whose requires_grad is False, its corresponding grad in + grad_args will be a zero tensor with the same shape. + """ + + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + n_primals = len(args) + + bw_fn = create_joint( + prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config + ) + + def flat_fn(*args_and_grad_outs): + primals = args_and_grad_outs[:n_primals] + tangents = args_and_grad_outs[n_primals:] + grad_args = bw_fn(primals, tangents)[1] + assert len(args) == len(grad_args) + + maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs) + + return [ + ( + torch.zeros_like(arg) + if isinstance(arg, torch.Tensor) and grad is None + else maybe_clone(grad) + ) + for grad, arg in zip(grad_args, primals) + ] + + return flat_fn + + def get_dummy_aot_autograd_config(): from torch._functorch.aot_autograd import AOTConfig From a6bc296207843134302e3e55b3ae77afdcb3532b Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 5 Aug 2025 12:05:59 -0700 Subject: [PATCH 0076/1424] [FlexAttention] Update the guard semantics for divisibility (#159884) We don't add guards unless we know (and another guard has ensured this) that this is a safe optimization Pull Request resolved: https://github.com/pytorch/pytorch/pull/159884 Approved by: https://github.com/Chillee --- test/inductor/test_flex_attention.py | 44 +++++++++++++++++++ torch/_inductor/kernel/flex/flex_attention.py | 8 ++-- torch/_inductor/kernel/flex/flex_decoding.py | 9 ++-- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index e78cf68244ee6..8e4746212a0bc 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1889,6 +1889,50 @@ def score_mod_scale(qk, b, h, q, kv): self.run_test(score_mod_scale, dtype, device=device) + @supported_platform + @dtypes(*device_configs["cpu"].dtypes_fast) + @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) + @skip_on_cpu + def test_dynamic_divisibility_guards(self, device, dtype): + """Test guards for divisible/non-divisible shape transitions""" + if device == "cpu" and dtype is torch.float16: + dtype = torch.float32 + + def score_mod(qk, b, h, q, kv): + return torch.where(q >= kv, qk, -float("inf")) + + def test_shape(S, backend): + """Test a single shape configuration""" + block_mask = create_block_mask(noop_mask, 1, 1, S, S, device=device) + sdpa_partial = create_attention(score_mod, block_mask=block_mask) + + tensors = [ + torch.randn( + 2, 4, S, 64, dtype=dtype, device=device, requires_grad=False + ) + for _ in range(3) + ] + + compiled_sdpa = torch.compile(sdpa_partial, backend=backend) + out, code = run_and_get_code(compiled_sdpa, *tensors) + + # Check divisibility flag + is_divisible = S % 128 == 0 + expected_flag = f"IS_DIVISIBLE : tl.constexpr = {is_divisible}" + self.assertIn( + expected_flag, str(code), f"S={S} should have {expected_flag}" + ) + + self.assertEqual(out.shape, (2, 4, S, 64)) + return out, code + + torch._dynamo.reset() + backend = CompileCounterWithBackend("inductor") + + # Test divisible and non-divisible shapes + test_shapes = [256, 255, 383, 384] + _ = [test_shape(S, backend) for S in test_shapes] + @supported_platform def test_multiple_score_mod_calls(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index 0553fd06755d0..b6f5646bb57cb 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -1535,10 +1535,12 @@ def flex_attention_backward(*args, **kwargs): for k, v in kernel_options.items() } kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) - if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: - kernel_options.setdefault("IS_DIVISIBLE", False) - else: + seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0) + seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0) + if seq_q_divisible and seq_kv_divisible: kernel_options.setdefault("IS_DIVISIBLE", True) + else: + kernel_options.setdefault("IS_DIVISIBLE", False) fwd_placeholder_inps = [ create_placeholder(name, dtype, device) diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py index 83c6b59cec96c..7f92fbc705a59 100644 --- a/torch/_inductor/kernel/flex/flex_decoding.py +++ b/torch/_inductor/kernel/flex/flex_decoding.py @@ -410,11 +410,12 @@ def create_flex_decoding_kernel(*args, **kwargs): for k, v in kernel_options.items() } - # TODO: Fix flex decoding non-divisible case! - if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: - kernel_options.setdefault("IS_DIVISIBLE", False) - else: + seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0) + seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0) + if seq_q_divisible and seq_kv_divisible: kernel_options.setdefault("IS_DIVISIBLE", True) + else: + kernel_options.setdefault("IS_DIVISIBLE", False) # Calculate GQA head sharing gqa_shared_heads = Hq // Hkv From cb4b29b754bb76fed5464fb51413bf9c023e124f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 6 Aug 2025 23:21:29 +0000 Subject: [PATCH 0077/1424] Revert "[pytorch] Moving torch.compile worker process logs to a dedicated rank based log directory (#159874)" This reverts commit 9fd5b5f73589cf08dca60910368cc0f05c7906c8. Reverted https://github.com/pytorch/pytorch/pull/159874 on behalf of https://github.com/malfet due to Broke lint ([comment](https://github.com/pytorch/pytorch/pull/159874#issuecomment-3161896978)) --- test/inductor/test_compile_worker.py | 15 +---------- .../_inductor/compile_worker/subproc_pool.py | 26 ++++--------------- torch/_inductor/config.py | 3 --- 3 files changed, 6 insertions(+), 38 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index e76bf932d145a..dcbf1b380934f 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -1,8 +1,6 @@ # Owner(s): ["module: inductor"] -import importlib import operator import os -import tempfile from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, @@ -13,6 +11,7 @@ from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.inductor_utils import HAS_CPU + class TestCompileWorker(TestCase): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_basic_jobs(self): @@ -67,18 +66,6 @@ def test_quiesce(self): finally: pool.shutdown() - @skipIfWindows(msg="pass_fds not supported on Windows.") - def test_logging(self): - os.environ["MAST_HPC_JOB_NAME"] = "test_job" - os.environ["ROLE_RANK"] = "0" - with tempfile.NamedTemporaryFile(delete=True) as temp_log: - os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name - pool = SubprocPool(2) - try: - pool.submit(operator.add, 100, 1) - self.assertEqual(os.path.exists(temp_log.name), True) - finally: - pool.shutdown() if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index dd8cab8643f1d..0b670b268b37e 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -145,24 +145,11 @@ def __init__( f"--write-fd={str(subproc_write_fd)}", f"--torch-key={torch_key_str}", ] - mast_job_id = os.environ.get("MAST_HPC_JOB_NAME", None) - global_rank = os.environ.get("ROLE_RANK", "0") - worker_log_path = os.environ.get("TORCHINDUCTOR_WORKER_LOGPATH", config.worker_log_path) - stdout_pipe = None - stderr_pipe = None - self.log_file = None - - if mast_job_id is not None: - log_loc = f"{worker_log_path}{global_rank}" - self.log_file = open(log_loc, "w") - elif config.worker_suppress_logging: + local = False + if config.worker_suppress_logging: log.info("Suppressing compile worker output due to config") - self.log_file = open(os.devnull, "w") + local = True - if self.log_file: - stdout_pipe = self.log_file - stderr_pipe = self.log_file - self.process = subprocess.Popen( cmd, env={ @@ -177,10 +164,9 @@ def __init__( "LD_LIBRARY_PATH": get_ld_library_path(), }, pass_fds=(subproc_read_fd, subproc_write_fd), - stdout=stdout_pipe, - stderr=stderr_pipe, + stdout=subprocess.DEVNULL if local else None, + stderr=subprocess.DEVNULL if local else None, ) - self.write_lock = threading.Lock() self.read_thread = threading.Thread( target=self._read_thread, name="InductorSubproc", daemon=True @@ -276,8 +262,6 @@ def shutdown(self) -> None: _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) self.write_pipe.close() self.process.wait(300) - if self.log_file: - self.log_file.close() except OSError as e: log.warning("Ignored OSError in pool shutdown: %s", e) finally: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c581a7611862c..51a438840b040 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -81,9 +81,6 @@ def prologue_fusion_enabled() -> bool: # Whether to enable printing the source code for each future verbose_progress = False -# Configurable compile worker logging path for subproc_pool -worker_log_path = "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None - # precompilation timeout precompilation_timeout_seconds: int = 60 * 60 From 3daef4d128879d1f6bad55d33d0396e94f19981b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 6 Aug 2025 13:36:02 -0700 Subject: [PATCH 0078/1424] [dynamo] Trace nn.Module __delattr__ (#159969) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159969 Approved by: https://github.com/atalman, https://github.com/malfet, https://github.com/StrongerXi --- test/dynamo/test_modules.py | 52 ++++++++++++++++++++++++++++ torch/_dynamo/variables/nn_module.py | 13 ++++--- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index f38b9bc502775..7cac7eca72394 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3422,6 +3422,58 @@ def forward(self, x): compiled_mod = torch.compile(mod, backend="eager") compiled_mod(x) + def test_trace_delattr(self): + TMP_PREFIX = "_tmp_" + + def pre_forward_rename_hook(module: torch.nn.Module, _input: torch.Tensor): + param_name = "weight" + original_param = getattr(module, param_name) + setattr(module, TMP_PREFIX + param_name, original_param) + new_param = original_param + 1.0 + delattr(module, param_name) + setattr(module, param_name, new_param) + + def post_forward_restore_hook( + module: torch.nn.Module, _input: torch.Tensor, _output: torch.Tensor + ): + param_name = "weight" + tmp_param_name = TMP_PREFIX + param_name + original_param = getattr(module, tmp_param_name) + delattr(module, param_name) + setattr(module, param_name, original_param) + delattr(module, tmp_param_name) + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + torch.manual_seed(0) + model = SimpleModel() + + model.linear.register_forward_pre_hook(pre_forward_rename_hook) + model.linear.register_forward_hook(post_forward_restore_hook) + + input_tensor = torch.randn(4, 10) + + eager_output = model(input_tensor) + assert hasattr(model.linear, "weight") + assert not hasattr(model.linear, "_tmp_weight") + + torch.manual_seed(0) + model_to_compile = SimpleModel() + model_to_compile.linear.register_forward_pre_hook(pre_forward_rename_hook) + model_to_compile.linear.register_forward_hook(post_forward_restore_hook) + + compiled_model = torch.compile(model_to_compile, fullgraph=True) + compiled_output = compiled_model(input_tensor) + assert hasattr(model.linear, "weight") + assert not hasattr(compiled_model.linear, "_tmp_weight") + torch.testing.assert_close(eager_output, compiled_output) + devices = ["cuda", "hpu", "xpu"] instantiate_device_type_tests( diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 3ca91814b8ae9..10ad8c4a12865 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -909,7 +909,11 @@ def set_nn_module_stack_source(self, source): @functools.cache def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} + supported = { + torch.nn.Module.__setattr__, + torch.nn.Module.__init__, + torch.nn.Module.__delattr__, + } return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -1091,9 +1095,10 @@ def call_method( # Handle submodules self.is_state_mutated = True - if method is torch.nn.Module.__setattr__ and isinstance( - args[1], variables.DeletedVariable - ): + if ( + method is torch.nn.Module.__setattr__ + and isinstance(args[1], variables.DeletedVariable) + ) or method is torch.nn.Module.__delattr__: # Trace through __delattr__ to track mutations on the module # members like `_modules``. return tx.inline_user_function_return( From fd606a3a918f34824333111038e034c9e18ea8e2 Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 6 Aug 2025 11:32:19 -0700 Subject: [PATCH 0079/1424] [dynamo] update pytorch-labs -> meta-pytorch in graph break URLs (#159975) Related PR: https://github.com/meta-pytorch/compile-graph-break-site/pull/30 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159975 Approved by: https://github.com/Lucaskabela --- .../fsdp/test_fully_shard_compile.py | 2 +- test/dynamo/test_error_messages.py | 80 +++++++++---------- test/dynamo/test_exc.py | 4 +- test/dynamo/test_reorder_logs.py | 2 +- test/dynamo/test_repros.py | 2 +- test/dynamo/test_sets.py | 2 +- test/test_custom_ops.py | 2 +- torch/_dynamo/exc.py | 2 +- 8 files changed, 48 insertions(+), 48 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 478eb498ac5d5..c8e98c5c3e1f3 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -549,7 +549,7 @@ def test_compiled(): Developer debug context: call_method TensorVariable() backward () {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0123.html""", # noqa: B950 + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0123.html""", # noqa: B950 ) else: self.assertGreater(len(counters["graph_break"]), 1) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 063e6863b8705..e91e7ef52097c 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -62,7 +62,7 @@ def fn(): Developer debug context: aten.nonzero.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0036.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html from user code: File "test_error_messages.py", line N, in fn @@ -84,7 +84,7 @@ def fn(): Developer debug context: aten.linalg_lstsq.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0037.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0037.html from user code: File "test_error_messages.py", line N, in fn @@ -107,7 +107,7 @@ def fn(x): Developer debug context: call_method TensorVariable() item () {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0124.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html from user code: File "test_error_messages.py", line N, in fn @@ -131,7 +131,7 @@ def fn(x): Developer debug context: aten.equal.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0033.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0033.html from user code: File "test_error_messages.py", line N, in fn @@ -159,7 +159,7 @@ def fn(lst): Developer debug context: TensorVariable() - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0207.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0207.html from user code: File "test_error_messages.py", line N, in fn @@ -185,7 +185,7 @@ def fn(it): Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ [] {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0156.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html from user code: File "test_error_messages.py", line N, in fn @@ -214,7 +214,7 @@ def fn(x, items): Developer debug context: call_method UserDefinedObjectVariable(dict_items) __iter__ [] {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0156.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html from user code: File "test_error_messages.py", line N, in fn @@ -238,7 +238,7 @@ def fn(it): Developer debug context: call_function UserDefinedObjectVariable(zip) [] {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0147.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0147.html from user code: File "test_error_messages.py", line N, in fn @@ -262,7 +262,7 @@ def fn(obj): Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0142.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html from user code: File "test_error_messages.py", line N, in fn @@ -293,7 +293,7 @@ def fn(x): return x + 1 - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0219.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0219.html""", ) def test_unsupported_builtin(self): @@ -312,7 +312,7 @@ def fn(): Developer debug context: builtin print [] False - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0059.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0059.html from user code: File "test_error_messages.py", line N, in fn @@ -338,7 +338,7 @@ def post_munge(s): Developer debug context: module: unittest.case, qualname: skip, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html from user code: File "test_error_messages.py", line N, in fn @@ -360,7 +360,7 @@ def fn(): Developer debug context: module: torch._dynamo.decorators, qualname: disable, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html from user code: File "test_error_messages.py", line N, in fn @@ -389,7 +389,7 @@ def post_munge(s): Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0008.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0008.html from user code: File "test_error_messages.py", line N, in fn @@ -411,7 +411,7 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -432,7 +432,7 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{'msg': ConstantVariable(str: 'test graph break')}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -454,7 +454,7 @@ def fn(): Developer debug context: module: _warnings, qualname: warn, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html from user code: File "test_error_messages.py", line N, in fn @@ -483,7 +483,7 @@ def fn(x): Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", ) @scoped_load_inline @@ -530,7 +530,7 @@ def f(x): Developer debug context: module: mylib, qualname: PyCapsule.foobar, skip reason: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0007.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", ) cpp_source = """ @@ -582,7 +582,7 @@ def fn(x, y): Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0038.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html from user code: File "test_error_messages.py", line N, in fn @@ -604,7 +604,7 @@ def fn(): Developer debug context: raised exception RuntimeError([ConstantVariable(str: 'test')]) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0088.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html from user code: File "test_error_messages.py", line N, in fn @@ -630,7 +630,7 @@ def fn(mod): Developer debug context: Foo - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0119.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0119.html from user code: File "test_error_messages.py", line N, in fn @@ -659,7 +659,7 @@ def fn(mod, x): Developer debug context: nn.Module subclass: Foo, name: attr, attribute type: module - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0161.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0161.html from user code: File "test_error_messages.py", line N, in fn @@ -689,7 +689,7 @@ def fn(): Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr), GenericContextWrappingVariable(GenericCtxMgr)] - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0066.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html from user code: File "test_error_messages.py", line N, in fn @@ -705,7 +705,7 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html""", ) def test_load_build_class(self): @@ -726,7 +726,7 @@ class Foo: Developer debug context: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0075.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0075.html from user code: File "test_error_messages.py", line N, in fn @@ -759,7 +759,7 @@ def post_munge(s): Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. Developer debug context: GET_AITER with args (, Instruction(GET_AITER) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0082.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0082.html from user code: File "test_error_messages.py", line N, in fn @@ -790,7 +790,7 @@ def post_munge(s): Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0092.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html from user code: File "test_error_messages.py", line N, in fn @@ -826,7 +826,7 @@ def post_munge(s): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_error_messages.py", line N, in test_reconstruction_failure_gb torch.compile(fn, backend="eager")() @@ -846,7 +846,7 @@ def post_munge(s): Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0092.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html from user code: File "test_error_messages.py", line N, in fn @@ -875,7 +875,7 @@ def fn(x): Developer debug context: - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0087.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0087.html from user code: File "test_error_messages.py", line N, in fn @@ -899,7 +899,7 @@ def fn(x): Developer debug context: attempted to jump with TensorVariable() - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html from user code: File "test_error_messages.py", line N, in fn @@ -966,7 +966,7 @@ def fn(x): Developer debug context: value: ConstantVariable(bool: False) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0034.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0034.html from user code: File "test_error_messages.py", line N, in fn @@ -1010,7 +1010,7 @@ def gn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -1063,7 +1063,7 @@ def gn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_error_messages.py", line N, in fn @@ -1099,7 +1099,7 @@ def hn(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_error_messages.py", line N, in test_nested_compile_user_frames torch.compile(fn, backend="eager")(torch.randn(3)) @@ -1213,7 +1213,7 @@ def f3(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_error_messages.py", line N, in test_graph_break_traceback_collapsed_resume_frames f1(torch.randn(3)) @@ -1303,7 +1303,7 @@ def post_munge(s): Developer debug context: .f at 0xmem_addr> - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0098.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html from user code: File "test_error_messages.py", line N, in outer @@ -1325,7 +1325,7 @@ def g(x): Developer debug context: .g at 0xmem_addr> - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0098.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html from user code: File "test_error_messages.py", line N, in outer @@ -1351,7 +1351,7 @@ def forward(self, x): Developer debug context: source: LocalSource(local_name='fn', is_input=True, dynamism=None, is_derefed_cell_contents=False) - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0148.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0148.html from user code: File "test_error_messages.py", line N, in outer diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index a7cb02132bd5f..ad56417ed568d 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -43,7 +43,7 @@ def fn001(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html from user code: File "test_exc.py", line N, in fn001 @@ -183,7 +183,7 @@ def fn001(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_exc.py", line N, in test_graph_break_log torch.compile(fn001, backend="eager")(torch.randn(1)) diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py index e833dd9df8865..be6bf8085af27 100644 --- a/test/dynamo/test_reorder_logs.py +++ b/test/dynamo/test_reorder_logs.py @@ -211,7 +211,7 @@ def f(x): Developer debug context: call_method TensorVariable() item () {} - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0124.html""", # noqa: B950 + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html""", # noqa: B950 ) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e0a3f7a5223f0..1da35106d54c8 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7160,7 +7160,7 @@ def fn(): "Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.\n\n" " Developer debug context: \n\n" " For more details about this graph break, please visit: " - "https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0264.html" + "https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0264.html" ) self.assertEqual(explain_output.break_reasons[0].reason, expected_msg) diff --git a/test/dynamo/test_sets.py b/test/dynamo/test_sets.py index 0871c0c1e565c..7b6421ce6a25a 100644 --- a/test/dynamo/test_sets.py +++ b/test/dynamo/test_sets.py @@ -174,7 +174,7 @@ def fn(x, s): Developer debug context: Python set containing torch.Tensor elements - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0222.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0222.html from user code: File "test_sets.py", line N, in fn diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index b713edeb7a954..5a494f5487423 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -1769,7 +1769,7 @@ def f(x): Developer debug context: _torch_testing.numpy_nonzero.default - For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0036.html""", + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html""", ) # pre-existing problem: torch.compile(dynamic=True) will, by default, diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 5039cf63526c3..e1247917ef82e 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -527,7 +527,7 @@ def get_gbid_documentation_link(gb_type: str) -> Optional[str]: A string containing the documentation URL if found, otherwise None. """ GRAPH_BREAK_SITE_URL = ( - "https://pytorch-labs.github.io/compile-graph-break-site/gb/" # @lint-ignore + "https://meta-pytorch.github.io/compile-graph-break-site/gb/" # @lint-ignore ) registry = _load_graph_break_registry() From 5cedc5a0ff236529f76ac514805b825bc73e1a74 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 6 Aug 2025 20:57:29 +0000 Subject: [PATCH 0080/1424] [BE][PYFMT] migrate PYFMT for `torch/[p-z]*/` to `ruff format` (#144552) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552 Approved by: https://github.com/ezyang --- tools/linter/adapters/pyfmt_linter.py | 1 - torch/package/_mangling.py | 1 + torch/package/package_exporter.py | 6 +- torch/package/package_importer.py | 7 +- torch/profiler/__init__.py | 1 + torch/profiler/_memory_profiler.py | 6 +- torch/profiler/_utils.py | 13 +- torch/profiler/profiler.py | 29 ++- torch/quantization/fuser_method_mappings.py | 1 + torch/quantization/fx/_equalize.py | 1 + torch/quantization/fx/convert.py | 1 + torch/quantization/fx/fuse.py | 1 + torch/quantization/fx/fusion_patterns.py | 1 + torch/quantization/fx/graph_module.py | 1 + torch/quantization/fx/match_utils.py | 1 + torch/quantization/fx/pattern_utils.py | 1 + torch/quantization/fx/prepare.py | 1 + .../quantization/fx/quantization_patterns.py | 1 + torch/quantization/fx/quantization_types.py | 1 + torch/quantization/fx/utils.py | 1 + torch/quantization/observer.py | 1 + torch/quantization/qconfig.py | 1 + torch/quantization/quantization_mappings.py | 1 + torch/signal/windows/windows.py | 32 +-- torch/sparse/__init__.py | 8 +- torch/sparse/_triton_ops.py | 10 +- torch/sparse/_triton_ops_meta.py | 7 +- torch/sparse/semi_structured.py | 27 ++- torch/special/__init__.py | 200 +++++------------- torch/testing/_comparison.py | 4 +- torch/testing/_creation.py | 4 +- torch/testing/_internal/common_device_type.py | 48 ++--- torch/testing/_internal/common_distributed.py | 18 +- torch/testing/_internal/common_fsdp.py | 12 +- torch/testing/_internal/common_optimizers.py | 6 +- .../distributed/_tensor/common_dtensor.py | 4 +- .../ddp_under_dist_autograd_test.py | 9 +- torch/testing/_internal/opinfo/core.py | 24 +-- .../_internal/opinfo/definitions/_masked.py | 20 +- torch/utils/_config_module.py | 15 +- torch/utils/_cxx_pytree.py | 71 +++---- torch/utils/_functools.py | 2 +- torch/utils/_python_dispatch.py | 33 +-- torch/utils/_pytree.py | 97 ++++----- .../_strobelight/cli_function_profiler.py | 2 +- torch/utils/_sympy/functions.py | 20 +- torch/utils/_sympy/value_ranges.py | 27 +-- torch/utils/backend_registration.py | 6 +- torch/utils/data/_utils/collate.py | 8 +- torch/utils/data/_utils/pin_memory.py | 4 +- torch/utils/data/_utils/worker.py | 2 +- torch/utils/data/dataloader.py | 6 +- torch/utils/data/datapipes/_decorator.py | 3 +- torch/utils/data/datapipes/datapipe.py | 8 +- torch/utils/data/datapipes/iter/callable.py | 7 +- .../data/datapipes/iter/combinatorics.py | 10 +- torch/utils/data/datapipes/iter/combining.py | 33 +-- torch/utils/data/datapipes/iter/fileopener.py | 8 +- torch/utils/data/datapipes/iter/grouping.py | 11 +- torch/utils/data/datapipes/map/utils.py | 4 +- torch/utils/data/datapipes/utils/decoder.py | 22 +- torch/utils/data/dataset.py | 30 +-- torch/utils/data/sampler.py | 22 +- torch/utils/module_tracker.py | 1 + torch/xpu/__init__.py | 4 +- 65 files changed, 446 insertions(+), 522 deletions(-) diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 55ffa429e7f9a..927325bffeb2f 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -52,7 +52,6 @@ # torch/[e-m]*/** # torch/optim/** # torch/[p-z]*/** - "torch/[p-z]*/**", ], ), ) diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py index 09d7901c2d6cc..08b0560f79322 100644 --- a/torch/package/_mangling.py +++ b/torch/package/_mangling.py @@ -2,6 +2,7 @@ """Import mangling. See mangling.md for details. """ + import re diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 21446c626b9a3..6118e8ce80964 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -605,9 +605,9 @@ def save_pickle( dependencies (bool, optional): If ``True``, we scan the source for dependencies. """ - assert (pickle_protocol == 4) or ( - pickle_protocol == 3 - ), "torch.package only supports pickle protocols 3 and 4" + assert (pickle_protocol == 4) or (pickle_protocol == 3), ( + "torch.package only supports pickle protocols 3 and 4" + ) filename = self._filename(package, resource) # Write the pickle data for `obj` diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index a97cf475b350a..7291227e42ae2 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -423,7 +423,12 @@ def _load_module(self, name: str, parent: str): module.__dict__.setdefault(old_name, new_name) return module - return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined] + return self._make_module( + name, + cur.source_file, # type: ignore[attr-defined] + isinstance(cur, _PackageNode), + parent, + ) def _compile_source(self, fullpath: str, mangled_filename: str): source = self.zip_reader.get_record(fullpath) diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index a90a371130e7a..153d4560e2641 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -7,6 +7,7 @@ An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated. """ + import os from typing import Any from typing_extensions import TypeVarTuple, Unpack diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 7ad917d1e86be..d9f3a917c1525 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -239,10 +239,12 @@ def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> tuple[Optional[bool], .. def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]: signature = tuple( # Tensor - TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) + TensorKey.from_tensor(i) + if isinstance(i, _TensorMetadata) # # TensorList - else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) + else [TensorKey.from_tensor(j) for j in i] + if isinstance(i, list) # # Scalar and uncaptured inputs. else i diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index b1160324cb906..5b631ef743c6e 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -124,9 +124,9 @@ def compute_self_time(self) -> None: for child_event in curr_event.children: self_time -= child_event.duration_time_ns stack.append(child_event) - assert ( - EventKey(curr_event) not in self.metrics - ), f"Duplicate id: {curr_event.id}, {curr_event.name}" + assert EventKey(curr_event) not in self.metrics, ( + f"Duplicate id: {curr_event.id}, {curr_event.name}" + ) self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) self.metrics[ EventKey(curr_event) @@ -227,8 +227,7 @@ def new_old_event_comparator(event): while ( current_kernel_index < len(cuda_kernel_events) - and (cuda_kernel_events[current_kernel_index].start_ns()) - <= start_time # type: ignore[possibly-undefined] + and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time # type: ignore[possibly-undefined] ): current_kernel_index += 1 current_queue_depth = spawned_kernel_index - current_kernel_index + 1 @@ -352,11 +351,11 @@ def get_optimizable_events(self, length: int = 1, print_enable: bool = True): output += "\n".join( [ - f"""{'-' * 80} + f"""{"-" * 80} Event: {event} Source code location: {source_code_location(event.event)} Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% -{'-' * 80}""" +{"-" * 80}""" for event in event_list ] ) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f7be416cfaa7f..d88d6c5cad72c 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -624,8 +624,7 @@ class profile(_KinetoProfile): ] ) as p: code_to_profile() - print(p.key_averages().table( - sort_by="self_cuda_time_total", row_limit=-1)) + print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: @@ -635,16 +634,17 @@ class profile(_KinetoProfile): # on different iterations of the training loop; # trace_handler is called every time a new trace becomes available def trace_handler(prof): - print(prof.key_averages().table( - sort_by="self_cuda_time_total", row_limit=-1)) + print( + prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1) + ) # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], - # In this example with wait=1, warmup=1, active=2, repeat=1, # profiler will skip the first step/iteration, # start warming up on the second, record @@ -652,20 +652,15 @@ def trace_handler(prof): # after which the trace will become available # and on_trace_ready (when set) is called; # the cycle repeats starting with the next step - - schedule=torch.profiler.schedule( - wait=1, - warmup=1, - active=2, - repeat=1), - on_trace_ready=trace_handler + schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1), + on_trace_ready=trace_handler, # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') # used when outputting for tensorboard - ) as p: - for iter in range(N): - code_iteration_to_profile(iter) - # send a signal to the profiler that the next iteration has started - p.step() + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`) diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index cfb13ac96271f..5a68fbf02015f 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -6,6 +6,7 @@ `torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement here. """ + from torch.ao.quantization.fuser_method_mappings import ( _DEFAULT_OP_LIST_TO_FUSER_METHOD, fuse_conv_bn, diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py index 7acea4f84a2a0..d6b8611d4a769 100644 --- a/torch/quantization/fx/_equalize.py +++ b/torch/quantization/fx/_equalize.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx._equalize import ( _convert_equalization_ref, _InputEqualizationObserver, diff --git a/torch/quantization/fx/convert.py b/torch/quantization/fx/convert.py index 9d6ac350602bb..30a661da41e5e 100644 --- a/torch/quantization/fx/convert.py +++ b/torch/quantization/fx/convert.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.convert import convert diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 67527080304fb..22ad750e9f878 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.fuse import fuse diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index e29337b3f861e..982d919655f36 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/torch/quantization/fx/graph_module.py b/torch/quantization/fx/graph_module.py index a71e980a57ba1..74b63903d7400 100644 --- a/torch/quantization/fx/graph_module.py +++ b/torch/quantization/fx/graph_module.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.graph_module import ( _is_observed_module, _is_observed_standalone_module, diff --git a/torch/quantization/fx/match_utils.py b/torch/quantization/fx/match_utils.py index 8b49f7c645d8d..8585a21ad445d 100644 --- a/torch/quantization/fx/match_utils.py +++ b/torch/quantization/fx/match_utils.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.match_utils import ( _find_matches, _is_match, diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index 2a83e180fc4db..fa601d1eb619c 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.pattern_utils import ( _register_fusion_pattern, _register_quant_pattern, diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py index ca65dcc04dd00..a6007ef242af5 100644 --- a/torch/quantization/fx/prepare.py +++ b/torch/quantization/fx/prepare.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.prepare import prepare diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 20d8cc52ee4fb..89f8d4406e912 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.quantize_handler import ( BatchNormQuantizeHandler, BinaryOpQuantizeHandler, diff --git a/torch/quantization/fx/quantization_types.py b/torch/quantization/fx/quantization_types.py index a422cdd3142e0..0820ea057078e 100644 --- a/torch/quantization/fx/quantization_types.py +++ b/torch/quantization/fx/quantization_types.py @@ -6,4 +6,5 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index ef35559884b7c..e45c82b8fb6f2 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -6,6 +6,7 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ + from torch.ao.quantization.fx.utils import ( all_node_args_have_no_tensors, assert_and_get_unique_device, diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 6e6c7c1917c83..2163e2717b069 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -6,6 +6,7 @@ `torch/ao/quantization/observer.py`, while adding an import statement here. """ + from torch.ao.quantization.observer import ( _is_activation_post_process, _is_per_channel_script_obs_instance, diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index 6bb7e14110cb9..a02ff7d6f7388 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -6,6 +6,7 @@ `torch/ao/quantization/qconfig.py`, while adding an import statement here. """ + from torch.ao.quantization.qconfig import ( _add_module_to_qconfig_obs_ctr, _assert_valid_qconfig, diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 8b44a980ce82f..faa24d391d31a 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -6,6 +6,7 @@ `torch/ao/quantization/quantization_mappings.py`, while adding an import statement here. """ + from torch.ao.quantization.quantization_mappings import ( _get_special_act_post_process, _has_special_act_post_process, diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 7d67de3f83848..e68c202f03e8a 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -128,9 +128,7 @@ def _window_function_checks( >>> # Generates a periodic exponential window and decay factor equal to .5 >>> torch.signal.windows.exponential(10, sym=False,tau=.5) tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04]) - """.format( - **window_common_args - ), + """.format(**window_common_args), ) def exponential( M: int, @@ -452,9 +450,7 @@ def kaiser( >>> # Generates a periodic Hamming window. >>> torch.signal.windows.hamming(10, sym=False) tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def hamming( M: int, @@ -508,9 +504,7 @@ def hamming( >>> # Generates a periodic Hann window. >>> torch.signal.windows.hann(10, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def hann( M: int, @@ -564,9 +558,7 @@ def hann( >>> # Generates a periodic Blackman window. >>> torch.signal.windows.blackman(5, sym=False) tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def blackman( M: int, @@ -627,9 +619,7 @@ def blackman( >>> # Generates a periodic Bartlett window. >>> torch.signal.windows.bartlett(10, sym=False) tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def bartlett( M: int, @@ -704,9 +694,7 @@ def bartlett( >>> # Generates a periodic general cosine window with 2 coefficients. >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def general_cosine( M, @@ -799,9 +787,7 @@ def general_cosine( >>> # Generates a periodic Hann window with the general Hamming window. >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def general_hamming( M, @@ -866,9 +852,7 @@ def general_hamming( >>> # Generates a periodic Nuttall window. >>> torch.signal.windows.general_hamming(5, sym=False) tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01]) -""".format( - **window_common_args - ), +""".format(**window_common_args), ) def nuttall( M: int, diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 39d78e8c26ab7..31299314a85f1 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -559,7 +559,11 @@ def as_sparse_gradcheck(gradcheck): For example: >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck) - >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True) + >>> x = ( + ... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64) + ... .to_sparse_coo() + ... .requires_grad_(True) + ... ) >>> gradcheck(lambda x: x.to_sparse_csr(), x) True """ @@ -667,7 +671,7 @@ def restore_from_strided_representation(args): ) else: raise NotImplementedError( - f'conversion of {d["layout"]} strided representation to tensor' + f"conversion of {d['layout']} strided representation to tensor" ) new_args.append(a) return tuple(new_args) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index a5e802084c28b..ea36264d8f822 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -296,11 +296,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): for b in range(nbatches): for i, r in enumerate(r_offsets): r0, r1 = divmod(r, N) - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] - for g in range(c_indices[i], c_indices[i+1]): + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] + for g in range(c_indices[i], c_indices[i + 1]): p = p_offsets[g] q0, q1 = divmod(q_offsets[g], N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. @@ -320,11 +320,11 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): n = (r % N) // Ns r0, r1 = divmod(r, N) c0, c1 = c_indices[m], c_indices[m + 1] - acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] for i, p in enumerate(range(c0, c1)): q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] q0, q1 = divmod(q, N) - acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] + acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 762874077c7ac..89245246395a9 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -97,6 +97,7 @@ kernel parameters for addmm-based operations. """ + __all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"] import inspect @@ -432,9 +433,9 @@ def from_key(key, parameters): def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): - assert ( - sparsity <= 1.0 and sparsity >= 0.0 - ), "sparsity should be a value between 0 and 1" + assert sparsity <= 1.0 and sparsity >= 0.0, ( + "sparsity should be a value between 0 and 1" + ) assert M % blocksize[0] == 0 assert N % blocksize[1] == 0 shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 721f25512794d..b225eaabb3206 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -465,14 +465,26 @@ def prune_dense_static_sort( The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUTLASS - from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + from torch.sparse._semi_structured_conversions import ( + _sparse_semi_structured_tile, + _compute_compressed_swizzled_bitmask, + ) pruned = _sparse_semi_structured_tile(dense) packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) - packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) + packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass( + pruned.t().contiguous() + ) bitmask = _compute_compressed_swizzled_bitmask(pruned) - SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask) + SparseSemiStructuredTensorCUTLASS( + dense.shape, + packed_cutlass, + meta_cutlass, + packed_t_cutlass, + meta_t_cutlass, + bitmask, + ) ``` """ # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag. @@ -583,14 +595,19 @@ def prune_dense_static_sort( The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + from torch.sparse._semi_structured_conversions import ( + _sparse_semi_structured_tile, + _compute_compressed_swizzled_bitmask, + ) pruned = _sparse_semi_structured_tile(dense) packed_cusparselt = torch._cslt_compress(pruned) packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) bitmask = _compute_compressed_swizzled_bitmask(pruned) - SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask) + SparseSemiStructuredTensorCUSPARSELT( + dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask + ) ``` """ ( diff --git a/torch/special/__init__.py b/torch/special/__init__.py index be027caa94cbb..dbc9314ad2087 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -134,9 +134,7 @@ >>> torch.special.digamma(a) tensor([-0.5772, -1.9635]) -""".format( - **common_args - ), +""".format(**common_args), ) gammaln = _add_docstr( @@ -162,9 +160,7 @@ >>> torch.special.gammaln(a) tensor([ 0.5724, 0.0000, -0.1208]) -""".format( - **common_args - ), +""".format(**common_args), ) polygamma = _add_docstr( @@ -200,9 +196,7 @@ tensor([ 6.4939, 97.4091]) >>> torch.special.polygamma(4, a) tensor([ -24.8863, -771.4742]) -""".format( - **common_args - ), +""".format(**common_args), ) erf = _add_docstr( @@ -226,9 +220,7 @@ >>> torch.special.erf(torch.tensor([0, -1., 10.])) tensor([ 0.0000, -0.8427, 1.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) erfc = _add_docstr( @@ -253,9 +245,7 @@ >>> torch.special.erfc(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 1.8427, 0.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) erfcx = _add_docstr( @@ -283,9 +273,7 @@ >>> torch.special.erfcx(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 5.0090, 0.0561]) -""".format( - **common_args - ), +""".format(**common_args), ) erfinv = _add_docstr( @@ -311,9 +299,7 @@ >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.])) tensor([ 0.0000, 0.4769, -inf]) -""".format( - **common_args - ), +""".format(**common_args), ) logit = _add_docstr( @@ -351,9 +337,7 @@ tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) >>> torch.special.logit(a, eps=1e-6) tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261]) -""".format( - **common_args - ), +""".format(**common_args), ) logsumexp = _add_docstr( @@ -362,9 +346,7 @@ logsumexp(input, dim, keepdim=False, *, out=None) Alias for :func:`torch.logsumexp`. -""".format( - **multi_dim_common - ), +""".format(**multi_dim_common), ) expit = _add_docstr( @@ -391,9 +373,7 @@ tensor([ 0.9213, 1.0887, -0.8858, -1.7683]) >>> torch.special.expit(t) tensor([ 0.7153, 0.7481, 0.2920, 0.1458]) -""".format( - **common_args - ), +""".format(**common_args), ) exp2 = _add_docstr( @@ -418,9 +398,7 @@ >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4])) tensor([ 1., 2., 8., 16.]) -""".format( - **common_args - ), +""".format(**common_args), ) expm1 = _add_docstr( @@ -448,9 +426,7 @@ >>> torch.special.expm1(torch.tensor([0, math.log(2.)])) tensor([ 0., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) xlog1py = _add_docstr( @@ -495,9 +471,7 @@ tensor([1.6094, 3.2189, 4.8283]) >>> torch.special.xlog1py(2, y) tensor([2.7726, 2.1972, 1.3863]) -""".format( - **common_args - ), +""".format(**common_args), ) xlogy = _add_docstr( @@ -542,9 +516,7 @@ tensor([1.3863, 2.7726, 4.1589]) >>> torch.special.xlogy(2, y) tensor([2.1972, 1.3863, 0.0000]) -""".format( - **common_args - ), +""".format(**common_args), ) i0 = _add_docstr( @@ -570,9 +542,7 @@ >>> torch.i0(torch.arange(5, dtype=torch.float32)) tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019]) -""".format( - **common_args - ), +""".format(**common_args), ) i0e = _add_docstr( @@ -597,9 +567,7 @@ >>> torch.special.i0e(torch.arange(5, dtype=torch.float32)) tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070]) -""".format( - **common_args - ), +""".format(**common_args), ) i1 = _add_docstr( @@ -624,9 +592,7 @@ >>> torch.special.i1(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595]) -""".format( - **common_args - ), +""".format(**common_args), ) i1e = _add_docstr( @@ -652,9 +618,7 @@ >>> torch.special.i1e(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788]) -""".format( - **common_args - ), +""".format(**common_args), ) ndtr = _add_docstr( @@ -679,9 +643,7 @@ >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987]) -""".format( - **common_args - ), +""".format(**common_args), ) ndtri = _add_docstr( @@ -709,9 +671,7 @@ >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1])) tensor([ -inf, -0.6745, 0.0000, 0.6745, inf]) -""".format( - **common_args - ), +""".format(**common_args), ) log_ndtr = _add_docstr( @@ -736,9 +696,7 @@ >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014]) -""".format( - **common_args - ), +""".format(**common_args), ) log1p = _add_docstr( @@ -779,9 +737,7 @@ tensor([ 0.2252, -0.2948, 1.0267, -1.1566]) >>> torch.special.sinc(t) tensor([ 0.9186, 0.8631, -0.0259, -0.1300]) -""".format( - **common_args - ), +""".format(**common_args), ) round = _add_docstr( @@ -886,9 +842,7 @@ tensor([1.6449, 0.0823]) >>> torch.special.zeta(2, torch.tensor([1., 2.])) tensor([1.6449, 0.6449]) -""".format( - **common_args - ), +""".format(**common_args), ) multigammaln = _add_docstr( @@ -925,9 +879,7 @@ >>> torch.special.multigammaln(a, 2) tensor([[0.3928, 0.4007, 0.7586], [1.0311, 0.3901, 0.5049]]) -""".format( - **common_args - ), +""".format(**common_args), ) gammainc = _add_docstr( @@ -976,9 +928,7 @@ >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) gammaincc = _add_docstr( @@ -1026,9 +976,7 @@ >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) -""".format( - **common_args - ), +""".format(**common_args), ) airy_ai = _add_docstr( @@ -1045,9 +993,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_j0 = _add_docstr( @@ -1064,9 +1010,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_j1 = _add_docstr( @@ -1083,9 +1027,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_y0 = _add_docstr( @@ -1102,9 +1044,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) bessel_y1 = _add_docstr( @@ -1121,9 +1061,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_t = _add_docstr( @@ -1154,9 +1092,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_u = _add_docstr( @@ -1188,9 +1124,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_v = _add_docstr( @@ -1208,9 +1142,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) chebyshev_polynomial_w = _add_docstr( @@ -1228,9 +1160,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) hermite_polynomial_h = _add_docstr( @@ -1256,9 +1186,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) hermite_polynomial_he = _add_docstr( @@ -1284,9 +1212,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) laguerre_polynomial_l = _add_docstr( @@ -1312,9 +1238,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) legendre_polynomial_p = _add_docstr( @@ -1340,9 +1264,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_i0 = _add_docstr( @@ -1359,9 +1281,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_i1 = _add_docstr( @@ -1378,9 +1298,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_k0 = _add_docstr( @@ -1397,9 +1315,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) modified_bessel_k1 = _add_docstr( @@ -1416,9 +1332,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) scaled_modified_bessel_k0 = _add_docstr( @@ -1435,9 +1349,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) scaled_modified_bessel_k1 = _add_docstr( @@ -1454,9 +1366,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_t = _add_docstr( @@ -1474,9 +1384,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_u = _add_docstr( @@ -1494,9 +1402,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_v = _add_docstr( @@ -1514,9 +1420,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) shifted_chebyshev_polynomial_w = _add_docstr( @@ -1534,9 +1438,7 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) spherical_bessel_j0 = _add_docstr( @@ -1553,7 +1455,5 @@ Keyword args: {out} -""".format( - **common_args - ), +""".format(**common_args), ) diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 228c04cd312f2..eff07c413deb4 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -1538,7 +1538,9 @@ def assert_close( >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. - >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") + >>> torch.testing.assert_close( + ... actual, expected, msg="Argh, the tensors are not close!" + ... ) Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index e513b8d856035..23d80d6ceae4f 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -115,11 +115,11 @@ def make_tensor( >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) - >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) + >>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1) >>> # xdoctest: +SKIP tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA - >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) + >>> make_tensor((2, 2), device="cuda", dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0') """ diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 01499280da8f5..528497ba54576 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -721,9 +721,9 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo intersect = set(except_for if except_for else []) & set( only_for if only_for else [] ) - assert ( - not intersect - ), f"device ({intersect}) appeared in both except_for and only_for" + assert not intersect, ( + f"device ({intersect}) appeared in both except_for and only_for" + ) # Replace your privateuse1 backend name with 'privateuse1' if is_privateuse1_backend_available(): @@ -1407,9 +1407,9 @@ def __init__(self, num_required_devices): self.num_required_devices = num_required_devices def __call__(self, fn): - assert not hasattr( - fn, "num_required_devices" - ), f"deviceCountAtLeast redefinition for {fn.__name__}" + assert not hasattr(fn, "num_required_devices"), ( + f"deviceCountAtLeast redefinition for {fn.__name__}" + ) fn.num_required_devices = self.num_required_devices @wraps(fn) @@ -1474,13 +1474,13 @@ def only_fn(self, *args, **kwargs): # self.precision *2, max(1, self.precision)). class precisionOverride: def __init__(self, d): - assert isinstance( - d, dict - ), "precisionOverride not given a dtype : precision dict!" + assert isinstance(d, dict), ( + "precisionOverride not given a dtype : precision dict!" + ) for dtype in d.keys(): - assert isinstance( - dtype, torch.dtype - ), f"precisionOverride given unknown dtype {dtype}" + assert isinstance(dtype, torch.dtype), ( + f"precisionOverride given unknown dtype {dtype}" + ) self.d = d @@ -1513,12 +1513,12 @@ class toleranceOverride: def __init__(self, d): assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" for dtype, prec in d.items(): - assert isinstance( - dtype, torch.dtype - ), f"toleranceOverride given unknown dtype {dtype}" - assert isinstance( - prec, tol - ), "toleranceOverride not given a dtype : tol dict!" + assert isinstance(dtype, torch.dtype), ( + f"toleranceOverride given unknown dtype {dtype}" + ) + assert isinstance(prec, tol), ( + "toleranceOverride not given a dtype : tol dict!" + ) self.d = d @@ -1546,13 +1546,13 @@ def __init__(self, *args, device_type="all"): "all dtype variants must be. " f"Received non-list non-tuple dtype {str(arg)}" ) - assert all( - isinstance(dtype, torch.dtype) for dtype in arg - ), f"Unknown dtype in {str(arg)}" + assert all(isinstance(dtype, torch.dtype) for dtype in arg), ( + f"Unknown dtype in {str(arg)}" + ) else: - assert all( - isinstance(arg, torch.dtype) for arg in args - ), f"Unknown dtype in {str(args)}" + assert all(isinstance(arg, torch.dtype) for arg in args), ( + f"Unknown dtype in {str(args)}" + ) self.args = args self.device_type = device_type diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index af1aafd3871ae..0dbb6ca0ea718 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -253,9 +253,9 @@ def verify_ddp_error_logged(model_DDP, err_substr): if err_substr.find("\nException raised from ") == -1 else err_substr.split("\nException raised from ")[0] ) - assert ( - actual in logging_err - ), f"Did not find expected {actual} in ddp logging data error: {logging_err}" + assert actual in logging_err, ( + f"Did not find expected {actual} in ddp logging data error: {logging_err}" + ) def with_nccl_blocking_wait(func): @@ -294,9 +294,9 @@ def wrapper(*args, **kwargs): finally: # restore old values. if cached_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = cached_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + cached_nccl_async_error_handling + ) if cached_nccl_blocking_wait is not None: os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait @@ -812,7 +812,7 @@ def run_test(self, test_name: str, parent_pipe) -> None: sys.exit(TEST_SKIPS["generic"].exit_code) except Exception: logger.error( - "Caught exception: \n%s exiting " "process %s with exit code: %s", + "Caught exception: \n%s exiting process %s with exit code: %s", traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE, @@ -1689,9 +1689,7 @@ def _spawn_processes(cls, world_size) -> None: cls.processes.append(process) cls.task_queues.append(task_queue) cls.completion_queues.append(completion_queue) - logger.info( - "Started process %s with pid %s", rank, process.pid - ) # noqa: UP031 + logger.info("Started process %s with pid %s", rank, process.pid) # noqa: UP031 @classmethod def setUpClass(cls): diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index a9e24eb90ef8c..0e50762893d70 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1285,10 +1285,10 @@ def _train_for_several_steps( loss = sharded_grad_scaler.scale(loss) if not mixed_precision and not use_pure_fp16: - assert ( - loss.dtype == torch.float32 - ), "loss data type should be float32, as the original \ + assert loss.dtype == torch.float32, ( + "loss data type should be float32, as the original \ parameter data type is float32." + ) else: if use_pure_fp16: self.assertEqual(loss.dtype, torch.float16) @@ -1354,9 +1354,9 @@ def _test_fsdp_parity( wrapper should provide data parallel semantics. If ``None``, then the callable defaults to the DDP constructor. """ - assert ( - fsdp_init_mode != FSDPInitMode.NO_FSDP - ), "Expects an FSDP init mode that wraps with FSDP" + assert fsdp_init_mode != FSDPInitMode.NO_FSDP, ( + "Expects an FSDP init mode that wraps with FSDP" + ) if init_kwargs is None: init_kwargs = {} lr = 1e-2 diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 780514e674397..96bab4a084c4f 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1268,9 +1268,9 @@ def _get_optim_inputs_including_global_cliquey_kwargs( trivial. That said, we sometimes want to test for all possible configs on an optimizer including all supported flags, so this helper returns all optim inputs. """ - assert all( - x in ["foreach", "fused", "differentiable"] for x in skip - ), "skip must be a subset of ['foreach', 'fused', 'differentiable']" + assert all(x in ["foreach", "fused", "differentiable"] for x in skip), ( + "skip must be a subset of ['foreach', 'fused', 'differentiable']" + ) optim_inputs = optim_info.optim_inputs_func(device) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index f3a72441f3704..4eb6677a035ec 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -477,7 +477,9 @@ def with_comms( def decorator(func, eager_init: bool = False, backend: Optional[str] = None): @wraps(func) # pyre-ignore[6] def wrapper( - self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + self, + *args: tuple[object], + **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: self.init_pg(eager_init, backend) diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index 1ac9252d498e0..61c21be3ca075 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -253,7 +253,11 @@ def train_batch( else: input_batches = batches - with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext(): + with ( + self.hybrid_module.join() + if simulate_uneven_inputs + else contextlib.nullcontext() + ): for b in input_batches: with dist_autograd.context() as context_id: output = self.hybrid_module.forward(b) @@ -261,8 +265,7 @@ def train_batch( dist_autograd.backward(context_id, [loss]) grads_dict = dist_autograd.get_gradients(context_id) gLogger.info( - "Loss is %s for mini batch: %s. " - "Grads dict has %s entries: %s", + "Loss is %s for mini batch: %s. Grads dict has %s entries: %s", loss, mini_batch, len(grads_dict), diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 5cd248792dcb1..97dee3c7c0f4e 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -162,9 +162,7 @@ def __init__( # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as # SampleInput(input, *args, **kwargs) but not to mix the two forms if args is not None or kwargs is not None: - assert ( - not var_args and not var_kwargs - ), """ + assert not var_args and not var_kwargs, """ A SampleInput can be constructed "naturally" with *args and **kwargs or by explicitly setting the "args" and "kwargs" parameters, but the two methods of construction cannot be mixed!""" @@ -226,7 +224,7 @@ def _repr_helper(self, formatter): f"name={repr(self.name)}", ] - return f'SampleInput({", ".join(a for a in arguments if a is not None)})' + return f"SampleInput({', '.join(a for a in arguments if a is not None)})" def __repr__(self): return self._repr_helper(lambda x: x) @@ -1601,13 +1599,11 @@ def __post_init__(self): # returns a string identifier of the rule type @abstractmethod - def type(self) -> str: - ... + def type(self) -> str: ... # returns an appropriate context that handles the xfail, skips, etc. @abstractmethod - def get_context(self, test_case): - ... + def get_context(self, test_case): ... # useful for specifying xfails @@ -1791,8 +1787,10 @@ def __init__( # kwargs to use when calling the op. This is required for operators that # have other required parameters besides the input tensor. generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( - yield (), - {}, + yield ( + (), + {}, + ) ), # Options from the OpInfo base class **kwargs, @@ -2476,9 +2474,9 @@ def __init__( self.supports_one_python_scalar = True if self.supports_one_python_scalar: - assert ( - supports_rhs_python_scalar - ), "Can't support lhs and rhs Python scalars but not rhs scalars!" + assert supports_rhs_python_scalar, ( + "Can't support lhs and rhs Python scalars but not rhs scalars!" + ) # The following functions and classes are for testing elementwise unary operators. diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index e05299632d04d..c5d08073803bb 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -102,8 +102,9 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwar for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -224,8 +225,9 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs): op_info, device, dtype, requires_grad, **kwargs ): sample_input_args, sample_input_kwargs = ( - ord, - ) + sample_input.args, sample_input.kwargs.copy() + (ord,) + sample_input.args, + sample_input.kwargs.copy(), + ) yield SampleInput( sample_input.input.clone().requires_grad_(requires_grad), args=sample_input_args, @@ -276,8 +278,9 @@ def masked_samples(): for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -364,8 +367,9 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs) ): if type(mask) != torch.Tensor: continue - sample_input_args, sample_input_kwargs = sample_input.args, dict( - mask=mask, **sample_input.kwargs + sample_input_args, sample_input_kwargs = ( + sample_input.args, + dict(mask=mask, **sample_input.kwargs), ) if "keepdim" in sample_input_kwargs: sample_input_kwargs.pop("keepdim") diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 4ec4e5b591596..811b45fd1d697 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -112,7 +112,7 @@ def __init__( @staticmethod def string_or_list_of_string_to_list( - val: Optional[Union[str, list[str]]] + val: Optional[Union[str, list[str]]], ) -> Optional[list[str]]: if val is None: return None @@ -135,8 +135,7 @@ def Config( env_name_force: Optional[Union[str, list[str]]] = None, value_type: Optional[type] = None, alias: Optional[str] = None, - ) -> T: - ... + ) -> T: ... else: @@ -323,9 +322,9 @@ def __init__(self, config: _Config): # Ensure justknobs and envvars are allowlisted types if self.justknob is not None and self.default is not None: - assert isinstance( - self.default, bool - ), f"justknobs only support booleans, {self.default} is not a boolean" + assert isinstance(self.default, bool), ( + f"justknobs only support booleans, {self.default} is not a boolean" + ) if self.value_type is not None and ( config.env_name_default is not None or config.env_name_force is not None ): @@ -334,7 +333,9 @@ def __init__(self, config: _Config): str, Optional[bool], Optional[str], - ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + ), ( + f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + ) class ConfigModule(ModuleType): diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 24c73061b716a..5ddda2c7edb6c 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -282,9 +282,9 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) False Args: @@ -586,29 +586,28 @@ def tree_map_( # These specializations help with type inference on the lambda passed to this # function @overload -def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: - ... +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: - ... +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: - ... +def map_only( + type_or_types_or_pred: Type3[T, S, U], / +) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... # This specialization is needed for the implementations below that call @overload -def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @overload -def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only( + type_or_types_or_pred: Callable[[Any], bool], / +) -> MapOnlyFn[FnAny[Any]]: ... def map_only( @@ -664,8 +663,7 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -675,8 +673,7 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -686,8 +683,7 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -697,8 +693,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -708,8 +703,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only( @@ -729,8 +723,7 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -740,8 +733,7 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -751,8 +743,7 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -762,8 +753,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -773,8 +763,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only_( @@ -812,8 +801,7 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -823,8 +811,7 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -834,8 +821,7 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_all_only( @@ -856,8 +842,7 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -867,8 +852,7 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -878,8 +862,7 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_any_only( diff --git a/torch/utils/_functools.py b/torch/utils/_functools.py index 40ffd8f80a9e7..0b555ffc27f96 100644 --- a/torch/utils/_functools.py +++ b/torch/utils/_functools.py @@ -12,7 +12,7 @@ def cache_method( - f: Callable[Concatenate[_C, _P], _T] + f: Callable[Concatenate[_C, _P], _T], ) -> Callable[Concatenate[_C, _P], _T]: """ Like `@functools.cache` but for methods. diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 664994e6fe38f..84353fbbebf7a 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -302,14 +302,12 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # Subtypes which have __tensor_flatten__ and __tensor_unflatten__. class TensorWithFlatten(Protocol): - def __tensor_flatten__(self) -> tuple[Sequence[str], object]: - ... + def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ... @staticmethod def __tensor_unflatten__( inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... # It would be really nice to be able to say that the return of # is_traceable_wrapper_subclass() is Intersection[torch.Tensor, @@ -318,26 +316,20 @@ def __tensor_unflatten__( shape: torch._C.Size @overload - def stride(self, dim: None = None) -> tuple[int, ...]: - ... + def stride(self, dim: None = None) -> tuple[int, ...]: ... @overload - def stride(self, dim: int) -> int: - ... + def stride(self, dim: int) -> int: ... @overload - def size(self, dim: None = None) -> tuple[int, ...]: - ... + def size(self, dim: None = None) -> tuple[int, ...]: ... @overload - def size(self, dim: int) -> int: - ... + def size(self, dim: int) -> int: ... - def storage_offset(self) -> int: - ... + def storage_offset(self) -> int: ... - def dim(self) -> int: - ... + def dim(self) -> int: ... @overload def to( @@ -347,8 +339,7 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @overload def to( @@ -359,8 +350,7 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @overload def to( @@ -370,8 +360,7 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 3e7cadc6dc7a7..02954d33866cb 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -99,17 +99,13 @@ class KeyEntry(Protocol): - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... - def __eq__(self, other: object) -> bool: - ... + def __eq__(self, other: object) -> bool: ... - def __str__(self) -> str: - ... + def __str__(self) -> str: ... - def get(self, parent: Any) -> Any: - ... + def get(self, parent: Any) -> Any: ... class EnumEncoder(json.JSONEncoder): @@ -757,7 +753,7 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]: def _tuple_flatten_with_keys( - d: tuple[T, ...] + d: tuple[T, ...], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _tuple_flatten(d) return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -785,7 +781,7 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]: def _dict_flatten_with_keys( - d: dict[Any, T] + d: dict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _dict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -849,7 +845,7 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]: def _ordereddict_flatten_with_keys( - d: OrderedDict[Any, T] + d: OrderedDict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _ordereddict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -872,7 +868,7 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]: def _defaultdict_flatten_with_keys( - d: defaultdict[Any, T] + d: defaultdict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _defaultdict_flatten(d) _, dict_context = context @@ -1035,9 +1031,9 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False - >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) False """ if is_leaf is not None and is_leaf(tree): @@ -1346,9 +1342,9 @@ def tree_map( See also :func:`tree_map_`. - >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)}) {'x': 8, 'y': (43, 65)} - >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None}) {'x': False, 'y': (False, False), 'z': True} If multiple inputs are given, the structure of the tree is taken from the first input; @@ -1432,29 +1428,28 @@ def tree_map_( # These specializations help with type inference on the lambda passed to this # function @overload -def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: - ... +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: - ... +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: ... @overload -def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: - ... +def map_only( + type_or_types_or_pred: Type3[T, S, U], / +) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... # This specialization is needed for the implementations below that call @overload -def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @overload -def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only( + type_or_types_or_pred: Callable[[Any], bool], / +) -> MapOnlyFn[FnAny[Any]]: ... def map_only( @@ -1510,8 +1505,7 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1521,8 +1515,7 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1532,8 +1525,7 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1543,8 +1535,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1554,8 +1545,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only( @@ -1575,8 +1565,7 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1586,8 +1575,7 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1597,8 +1585,7 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1608,8 +1595,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1619,8 +1605,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only_( @@ -1658,8 +1643,7 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1669,8 +1653,7 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1680,8 +1663,7 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_all_only( @@ -1702,8 +1684,7 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1713,8 +1694,7 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1724,8 +1704,7 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_any_only( @@ -1862,7 +1841,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: raise NotImplementedError( - f'Deserializing {json_schema["type"]} in pytree is not registered.', + f"Deserializing {json_schema['type']} in pytree is not registered.", ) typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 39e981a78ac5b..9b94a7b7a484b 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -301,7 +301,7 @@ def strobelight( profiler = StrobelightCLIFunctionProfiler(**kwargs) def strobelight_inner( - work_function: Callable[_P, _R] + work_function: Callable[_P, _R], ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 42c99839d4164..2b6c159f5c3a0 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -98,7 +98,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool: def _keep_float( - f: Callable[[Unpack[_Ts]], _T] + f: Callable[[Unpack[_Ts]], _T], ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: @functools.wraps(f) def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: @@ -926,10 +926,12 @@ def _find_localzeros(cls, values, **options): _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 _eval_is_antihermitian = lambda s: _torf( # noqa: E731 - i.is_antihermitian for i in s.args # noqa: E731 + i.is_antihermitian + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_commutative = lambda s: _torf( # noqa: E731 - i.is_commutative for i in s.args # noqa: E731 + i.is_commutative + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 @@ -943,10 +945,12 @@ def _find_localzeros(cls, values, **options): _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 _eval_is_nonnegative = lambda s: _torf( # noqa: E731 - i.is_nonnegative for i in s.args # noqa: E731 + i.is_nonnegative + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_nonpositive = lambda s: _torf( # noqa: E731 - i.is_nonpositive for i in s.args # noqa: E731 + i.is_nonpositive + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 @@ -956,10 +960,12 @@ def _find_localzeros(cls, values, **options): _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 _eval_is_extended_real = lambda s: _torf( # noqa: E731 - i.is_extended_real for i in s.args # noqa: E731 + i.is_extended_real + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_transcendental = lambda s: _torf( # noqa: E731 - i.is_transcendental for i in s.args # noqa: E731 + i.is_transcendental + for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 1b360337a53bb..e02e049cc36dd 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -144,16 +144,14 @@ def __init__( self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn, - ) -> None: - ... + ) -> None: ... @overload def __init__( # type: ignore[misc] self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn, - ) -> None: - ... + ) -> None: ... def __init__(self, lower: AllIn, upper: AllIn) -> None: lower = simple_sympify(lower) @@ -240,15 +238,13 @@ def tighten(self, other) -> ValueRanges: def __and__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], - ) -> ValueRanges[sympy.Expr]: - ... + ) -> ValueRanges[sympy.Expr]: ... @overload def __and__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], - ) -> ValueRanges[SympyBoolean]: - ... + ) -> ValueRanges[SympyBoolean]: ... def __and__(self: AllVR, other: AllVR) -> AllVR: if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): @@ -272,15 +268,13 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: def __or__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], - ) -> ValueRanges[sympy.Expr]: - ... + ) -> ValueRanges[sympy.Expr]: ... @overload def __or__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], - ) -> ValueRanges[SympyBoolean]: - ... + ) -> ValueRanges[SympyBoolean]: ... def __or__(self: AllVR, other: AllVR) -> AllVR: if ValueRanges.unknown() in (self, other): @@ -343,8 +337,7 @@ def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: @overload @staticmethod - def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: - ... + def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ... @overload @staticmethod @@ -384,8 +377,7 @@ def coordinatewise_increasing_map( x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2, - ) -> ExprVR: - ... + ) -> ExprVR: ... @overload @staticmethod @@ -393,8 +385,7 @@ def coordinatewise_increasing_map( # type: ignore[misc] x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2, - ) -> BoolVR: - ... + ) -> BoolVR: ... @staticmethod def coordinatewise_increasing_map( diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index e11a7afc09d8a..5a83aede8d468 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -426,9 +426,9 @@ def func_name(*args, **kwargs): it is marked as private. It is a convenience function for backend implementers to more easily call the hooks into their backend extensions. """ - assert isinstance( - func_name, str - ), f"func_name must be `str`, but got `{type(func_name)}`." + assert isinstance(func_name, str), ( + f"func_name must be `str`, but got `{type(func_name)}`." + ) backend_name = _get_privateuse1_backend_name() custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type] function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type] diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index 68a4da0731c0e..3b291b1e60a4c 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -44,7 +44,7 @@ def default_convert(data): >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) @@ -366,13 +366,13 @@ def default_collate(batch): >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: - >>> default_collate(['a', 'b', 'c']) + >>> default_collate(["a", "b", "c"]) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: - >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) + >>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index c75756dd5fdb1..b53c7aef9596f 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -69,7 +69,9 @@ def pin_memory(data, device=None): ) return clone else: - return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg] + return type(data)( + {k: pin_memory(sample, device) for k, sample in data.items()} + ) # type: ignore[call-arg] except TypeError: # The mapping type may not support `copy()` / `update(mapping)` # or `__init__(iterable)`. diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index a275e2e86b6ff..97c7243e78ef7 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. +r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These **needs** to be in global scope since Py2 doesn't support serializing static methods. diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index dd7a73ea11e08..991b4f00eb85e 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,6 +5,7 @@ functions to be run in multiprocessing. E.g., the data loading worker loop is in `./_utils/worker.py`. """ + from __future__ import annotations import functools @@ -1208,7 +1209,10 @@ def __init__(self, loader): atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) # .pid can be None only before process is spawned (not the case, so ignore) - _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_worker_pids( + id(self), + tuple(w.pid for w in self._workers), # type: ignore[misc] + ) _utils.signal_handling._set_SIGCHLD_handler() self._worker_pids_set = True self._reset(loader, first_iter=True) diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py index 13e28a19d6266..0833f8fdf759b 100644 --- a/torch/utils/data/datapipes/_decorator.py +++ b/torch/utils/data/datapipes/_decorator.py @@ -109,8 +109,7 @@ def __call__(self, *args, **kwargs): # Decorate with a functional argument if not ( - isinstance(args[0], type) - and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] + isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] ): raise TypeError( f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found" diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index d3eeee0ebfdd5..506f642c411db 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -99,7 +99,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> from torchdata.datapipes.iter import IterableWrapper, Mapper >>> dp = IterableWrapper(range(10)) >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor - >>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended) + >>> map_dp_2 = dp.map( + ... lambda x: x + 1 + ... ) # Using functional form (recommended) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> list(map_dp_2) @@ -114,7 +116,9 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> list(it1) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> it1 = iter(source_dp) - >>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1` + >>> it2 = iter( + ... source_dp + ... ) # The creation of a new iterator invalidates `it1` >>> next(it2) 0 >>> next(it1) # Further usage of `it1` will raise a `RunTimeError` diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 718e728c9389d..41c6bb362af2b 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -55,7 +55,8 @@ class MapperIterDataPipe(IterDataPipe[_T_co]): >>> def add_one(x): ... return x + 1 >>> dp = IterableWrapper(range(10)) - >>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred + >>> # Invocation via functional form is preferred + ... map_dp_1 = dp.map(add_one) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` @@ -202,7 +203,7 @@ class CollatorIterDataPipe(MapperIterDataPipe): >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): ... def __init__(self, start, end): ... super(MyIterDataPipe).__init__() - ... assert end > start, "this example code only works with end >= start" + ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... @@ -211,13 +212,11 @@ class CollatorIterDataPipe(MapperIterDataPipe): ... ... def __len__(self): ... return self.end - self.start - ... >>> ds = MyIterDataPipe(start=3, end=7) >>> print(list(ds)) [3, 4, 5, 6] >>> def collate_fn(batch): ... return torch.tensor(batch, dtype=torch.float) - ... >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) >>> print(list(collated_ds)) [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index 4c602ce4eeda0..f92edd6b7b39c 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -38,15 +38,17 @@ def __init__( sampler_args: Optional[tuple] = None, sampler_kwargs: Optional[dict] = None, ) -> None: - assert isinstance( - datapipe, Sized - ), "Sampler class requires input datapipe implemented `__len__`" + assert isinstance(datapipe, Sized), ( + "Sampler class requires input datapipe implemented `__len__`" + ) super().__init__() self.datapipe = datapipe self.sampler_args = () if sampler_args is None else sampler_args self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs # https://github.com/python/mypy/pull/9629 will solve - self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs) # type: ignore[misc] + self.sampler = sampler( + *self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs + ) # type: ignore[misc] def __iter__(self) -> Iterator[_T_co]: return iter(self.sampler) diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index deaca079c68c0..8c6abc5062105 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -116,16 +116,13 @@ class _ContainerTemplate(ABC): r"""Abstract class for container ``DataPipes``. The followings are three required methods.""" @abstractmethod - def get_next_element_by_instance(self, instance_id: int): - ... + def get_next_element_by_instance(self, instance_id: int): ... @abstractmethod - def is_every_instance_exhausted(self) -> bool: - ... + def is_every_instance_exhausted(self) -> bool: ... @abstractmethod - def reset(self) -> None: - ... + def reset(self) -> None: ... @abstractmethod def get_length_by_instance(self, instance_id: int): @@ -403,7 +400,9 @@ class DemultiplexerIterDataPipe(IterDataPipe): >>> # It can also filter out any element that gets `None` from the `classifier_fn` >>> def odd_or_even_no_zero(n): ... return n % 2 if n != 0 else None - >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True) + >>> dp1, dp2 = source_dp.demux( + ... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True + ... ) >>> list(dp1) [2, 4] >>> list(dp2) @@ -428,7 +427,9 @@ def __new__( # When num_instances == 1, demux can be replaced by filter, # but keep it as Demultiplexer for the sake of consistency # like throwing Error when classification result is out of o range - container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) # type: ignore[abstract] + container = _DemultiplexerIterDataPipe( + datapipe, num_instances, classifier_fn, drop_none, buffer_size + ) # type: ignore[abstract] return [_ChildDataPipe(container, i) for i in range(num_instances)] @@ -602,16 +603,18 @@ class MultiplexerIterDataPipe(IterDataPipe): Example: >>> # xdoctest: +REQUIRES(module:torchdata) >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(3)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) >>> list(dp1.mux(dp2, dp3)) [0, 10, 20, 1, 11, 21, 2, 12, 22] """ def __init__(self, *datapipes): self.datapipes = datapipes - self.buffer: list = ( - [] - ) # Store values to be yielded only when every iterator provides one + self.buffer: list = [] # Store values to be yielded only when every iterator provides one def __iter__(self): iterators = [iter(x) for x in self.datapipes] @@ -670,7 +673,11 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]): Example: >>> # xdoctest: +REQUIRES(module:torchdata) >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> dp1, dp2, dp3 = ( + ... IterableWrapper(range(5)), + ... IterableWrapper(range(10, 15)), + ... IterableWrapper(range(20, 25)), + ... ) >>> list(dp1.zip(dp2, dp3)) [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] """ diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 2542c89773bdd..3025b809e12df 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -33,8 +33,12 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]): Example: >>> # xdoctest: +SKIP - >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader - >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt')) + >>> from torchdata.datapipes.iter import ( + ... FileLister, + ... FileOpener, + ... StreamReader, + ... ) + >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt")) >>> dp = FileOpener(dp) >>> dp = StreamReader(dp) >>> list(dp) diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 08d124fdc6087..055d9c28b09be 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -182,7 +182,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): >>> from torchdata.datapipes.iter import IterableWrapper >>> def group_fn(file): ... return os.path.basename(file).split(".")[0] - >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]) + >>> source_dp = IterableWrapper( + ... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"] + ... ) >>> dp0 = source_dp.groupby(group_key_fn=group_fn) >>> list(dp0) [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] @@ -191,7 +193,12 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): >>> list(dp1) [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` - >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2) + >>> dp2 = source_dp.groupby( + ... group_key_fn=group_fn, + ... buffer_size=3, + ... group_size=3, + ... guaranteed_group_size=2, + ... ) >>> list(dp2) [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] """ diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py index 02865e8064f86..e1290df323724 100644 --- a/torch/utils/data/datapipes/map/utils.py +++ b/torch/utils/data/datapipes/map/utils.py @@ -31,8 +31,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]): >>> dp = SequenceWrapper(range(10)) >>> list(dp) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) - >>> dp['a'] + >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400}) + >>> dp["a"] 100 """ diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py index ee5bee8f15280..9db7309bdc525 100644 --- a/torch/utils/data/datapipes/utils/decoder.py +++ b/torch/utils/data/datapipes/utils/decoder.py @@ -45,8 +45,8 @@ def basichandlers(extension: str, data): Example: >>> import pickle - >>> data = pickle.dumps('some data') - >>> new_data = basichandlers('pickle', data) + >>> data = pickle.dumps("some data") + >>> new_data = basichandlers("pickle", data) >>> new_data some data @@ -169,9 +169,9 @@ class ImageHandler: """ def __init__(self, imagespec): - assert imagespec in list( - imagespecs.keys() - ), f"unknown image specification: {imagespec}" + assert imagespec in list(imagespecs.keys()), ( + f"unknown image specification: {imagespec}" + ) self.imagespec = imagespec.lower() def __call__(self, extension, data): @@ -205,18 +205,18 @@ def __call__(self, extension, data): return img elif atype == "numpy": result = np.asarray(img) - assert ( - result.dtype == np.uint8 - ), f"numpy image array should be type uint8, but got {result.dtype}" + assert result.dtype == np.uint8, ( + f"numpy image array should be type uint8, but got {result.dtype}" + ) if etype == "uint8": return result else: return result.astype("f") / 255.0 elif atype == "torch": result = np.asarray(img) - assert ( - result.dtype == np.uint8 - ), f"numpy image array should be type uint8, but got {result.dtype}" + assert result.dtype == np.uint8, ( + f"numpy image array should be type uint8, but got {result.dtype}" + ) if etype == "uint8": result = np.array(result.transpose(2, 0, 1)) diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index d0234c553ce68..e8164e015a668 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -96,7 +96,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]): >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() - ... assert end > start, "this example code only works with end >= start" + ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... @@ -138,7 +138,7 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]): >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() - ... assert end > start, "this example code only works with end >= start" + ... assert end > start, "this example only works with end >= start" ... self.start = start ... self.end = end ... @@ -198,9 +198,9 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]): tensors: tuple[Tensor, ...] def __init__(self, *tensors: Tensor) -> None: - assert all( - tensors[0].size(0) == tensor.size(0) for tensor in tensors - ), "Size mismatch between tensors" + assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), ( + "Size mismatch between tensors" + ) self.tensors = tensors def __getitem__(self, index): @@ -222,7 +222,7 @@ class StackDataset(Dataset[_T_stack]): >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) - >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} + >>> dict_stack[0] == {"image": images[0], "text": texts[0]} Args: *args (Dataset): Datasets for stacking returned as tuple. @@ -323,9 +323,9 @@ def __init__(self, datasets: Iterable[Dataset]) -> None: self.datasets = list(datasets) assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] for d in self.datasets: - assert not isinstance( - d, IterableDataset - ), "ConcatDataset does not support IterableDataset" + assert not isinstance(d, IterableDataset), ( + "ConcatDataset does not support IterableDataset" + ) self.cumulative_sizes = self.cumsum(self.datasets) def __len__(self): @@ -371,17 +371,17 @@ def __init__(self, datasets: Iterable[Dataset]) -> None: def __iter__(self): for d in self.datasets: - assert isinstance( - d, IterableDataset - ), "ChainDataset only supports IterableDataset" + assert isinstance(d, IterableDataset), ( + "ChainDataset only supports IterableDataset" + ) yield from d def __len__(self): total = 0 for d in self.datasets: - assert isinstance( - d, IterableDataset - ), "ChainDataset only supports IterableDataset" + assert isinstance(d, IterableDataset), ( + "ChainDataset only supports IterableDataset" + ) total += len(d) # type: ignore[arg-type] return total diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index c92bdbb00e102..6c2e6dcaf2f45 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -236,9 +236,17 @@ class WeightedRandomSampler(Sampler[int]): Example: >>> # xdoctest: +IGNORE_WANT("non-deterministic") - >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) + >>> list( + ... WeightedRandomSampler( + ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True + ... ) + ... ) [4, 4, 1, 4, 5] - >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) + >>> list( + ... WeightedRandomSampler( + ... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False + ... ) + ... ) [0, 1, 4, 3, 2] """ @@ -298,9 +306,15 @@ class BatchSampler(Sampler[list[int]]): its size would be less than ``batch_size`` Example: - >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) + >>> list( + ... BatchSampler( + ... SequentialSampler(range(10)), batch_size=3, drop_last=False + ... ) + ... ) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] - >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) + >>> list( + ... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True) + ... ) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """ diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index 8ac97f2e2e826..4c7dec0481522 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -49,6 +49,7 @@ class ModuleTracker: def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias + torch.nn.functional.linear = my_linear mod(torch.rand(2, 2)) diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 23e3a25c90f5f..9a4ade5e71eaa 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -6,6 +6,7 @@ This package is lazily initialized, so you can always import it, and use :func:`is_available()` to determine if your system supports XPU. """ + import threading import traceback from functools import lru_cache @@ -292,6 +293,7 @@ class StreamContext: ``None``. .. note:: Streams are per-device. """ + cur_stream: Optional["torch.xpu.Stream"] def __init__(self, stream: Optional["torch.xpu.Stream"]): @@ -438,7 +440,7 @@ def get_gencode_flags() -> str: arch_list = get_arch_list() if len(arch_list) == 0: return "" - return f'-device {",".join(arch for arch in arch_list)}' + return f"-device {','.join(arch for arch in arch_list)}" def _get_generator(device: torch.device) -> torch._C.Generator: From 8b0be7b65a5dd83c2739a1d4d17e177e2e5cf569 Mon Sep 17 00:00:00 2001 From: Denghui Dong Date: Thu, 7 Aug 2025 01:17:52 +0000 Subject: [PATCH 0081/1424] [Profiler] Fix unexpected C return events (#159574) The fix in https://github.com/pytorch/pytorch/pull/155446 addressed the "stack empty" issue that's easily reproducible on CPython 3.12.0-4. While this issue can also appear in other versions, it's not as easy to reproduce there. I recently found a new cause for this problem. https://github.com/python/cpython/blob/1df5d0014578be7fe7ae25e2cc60c50c8b5cc0f7/Python/ceval.c#L5807-L5836 In the CPython 3.10 implementation, PyTrace_C_CALL and PyTrace_C_RETURN/PyTrace_C_EXCEPTION are supposed to appear in pairs. However, when c_profilefunc is changed, unexpected PyTrace_C_RETURN/PyTrace_C_EXCEPTION events can occur. Here is the code to reproduce this problem. ``` import threading import time import torch from threading import Event, Lock lock = Lock() lock.acquire() event1 = Event() event2 = Event() event3 = Event() def run(): event1.set() event2.wait() lock.acquire() event3.set() threading.Thread(target=run).start() with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True): event1.wait() event2.set() time.sleep(1) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True): lock.release() event3.wait() ``` image To fix this problem, we can record active_frames_ and remaining_start_frames_ for each thread, and when the PyTrace_C-RETURN/PyTrace_CEXT CEPTION event occurs, we can determine whether to record this event based on these two fields. In reality, even without this fix, the final data appears to be right since the match process can handle this case (it would just result in an exception log being printed). Do you think the fix is necessary? Pull Request resolved: https://github.com/pytorch/pytorch/pull/159574 Approved by: https://github.com/sraikund16 --- test/profiler/test_python_tracer.py | 41 +++++++++++++++++++++++++ torch/csrc/autograd/profiler_python.cpp | 23 ++++++++++++-- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/test/profiler/test_python_tracer.py b/test/profiler/test_python_tracer.py index 389395d8027c6..f7732b0b3893f 100644 --- a/test/profiler/test_python_tracer.py +++ b/test/profiler/test_python_tracer.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: profiler"] import json +import subprocess import sys import time @@ -63,6 +64,46 @@ def test_monitoring_callback(self): name = monitoring.get_tool(2) self.assertEqual(name, None) + def test_unexpected_c_return_events(self): + code = """ +import threading +import time +import torch + +from threading import Event, Lock + +lock = Lock() +lock.acquire() +event1 = Event() +event2 = Event() +event3 = Event() + +def run(): + event1.set() + event2.wait() + lock.acquire() + event3.set() + +threading.Thread(target=run).start() + +with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True): + event1.wait() + event2.set() + time.sleep(1) + +with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True): + lock.release() + event3.wait() + """ + + result = subprocess.run( + [sys.executable, "-c", code], capture_output=True, text=True, check=True + ) + + self.assertFalse( + "Python replay stack is empty during pop operation" in result.stderr + ) + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index fd672a48502a5..7c6792f5e6986 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -674,6 +674,9 @@ struct ThreadLocalResults { CallTypeHelper::tuple_type trace_keys_; AppendOnlyList exit_times_; AppendOnlyList c_exit_times_; + + int active_frames_{0}; + int remaining_start_frames_{0}; }; // ============================================================================ @@ -999,7 +1002,8 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) PyThreadState_Swap(thread_state); thread_local_results_.emplace_back(thread_state, &value_cache_, this); - auto* ctx = thread_local_results_.back().ctx_; + auto& tls = thread_local_results_.back(); + auto* ctx = tls.ctx_; // When we begin profiling there are already frames on the Python // interpreter stack. To ensure a complete trace, we must push calls @@ -1021,7 +1025,7 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) } for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { - recordPyCall(thread_local_results_.back(), it->get(), true); + recordPyCall(tls, it->get(), true); auto frame_refcount = Py_REFCNT(it->get()); // We hold one reference in `current_stack`, and the interpreter holds @@ -1029,6 +1033,8 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount); } + tls.remaining_start_frames_ = tls.active_frames_; + // Note: // This profile will not compose with other CPython profilers, and // cannot be round tripped via `sys.settrace(sys.gettrace())` @@ -1141,6 +1147,7 @@ void PythonTracer::recordPyCall( const auto time = c10::getApproximateTime(); is_startup_frame ? start_frames_.push_back({key, time}) : queue_->getSubqueue()->emplace_py_call(key, time); + ++tls.active_frames_; } void PythonTracer::recordCCall( @@ -1160,6 +1167,7 @@ void PythonTracer::recordCCall( auto key = tls.intern( arg, (void*)(fn->m_ml), frame); queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); + ++tls.active_frames_; } // ============================================================================ @@ -1457,11 +1465,20 @@ int PythonTracer::pyProfileFn( case PyTrace_RETURN: local_results.exit_times_.emplace_back(c10::getApproximateTime()); + local_results.active_frames_--; + if (local_results.active_frames_ < + local_results.remaining_start_frames_) { + local_results.remaining_start_frames_ = local_results.active_frames_; + } break; case PyTrace_C_EXCEPTION: case PyTrace_C_RETURN: - local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); + if (local_results.active_frames_ > + local_results.remaining_start_frames_) { + local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); + local_results.active_frames_--; + } break; } return 0; From 1bb5e6c076990b55d0704ee2fcfc49551e609c7b Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 5 Aug 2025 07:20:57 -0700 Subject: [PATCH 0082/1424] update expected results (#159867) refresh due to https://github.com/pytorch/pytorch/pull/159696 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159867 Approved by: https://github.com/masnesral --- benchmarks/dynamo/pr_time_benchmarks/expected_results.csv | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 5398c40f3573a..debddc5c7fa36 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,1009000000,0.1 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1 @@ -82,7 +82,7 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1 -basic_NestedModule_eager,compile_time_instruction_count,8787000000,0.1 +basic_NestedModule_eager,compile_time_instruction_count,9199000000,0.1 From 2ba2f598f3d6b9b656ce850a6b58be99b2d7b162 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Tue, 5 Aug 2025 14:07:12 +0000 Subject: [PATCH 0083/1424] [Dynamo] Add torch.xpu.stream to trace rules (#159844) # Motivation Previously, I thought using `with stream:` was sufficient. However, many older scripts still use `torch.xpu.stream` as the context manager. To maintain backward compatibility, I had to include `torch.xpu.stream` in the trace rules. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159844 Approved by: https://github.com/jansel --- torch/_dynamo/trace_rules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index a3beb561f1866..56b5e508f058e 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2963,6 +2963,7 @@ "torch.xpu.random.seed_all", "torch.xpu.random.seed", "torch.xpu.set_stream", + "torch.xpu.stream", "torch.xpu.synchronize", ], TorchInGraphFunctionVariable, From 38d65c64658928929a5c70114b56041096aaf0dd Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 6 Aug 2025 15:42:31 -0700 Subject: [PATCH 0084/1424] Add a USE_NIGHTLY option to setup.py (#159965) If you run python setup.py develop with USE_NIGHTLY, instead of actually building PyTorch we will just go ahead and download the corresponding nightly version you specified and dump its binaries. This is intended to obsolete tools/nightly.py. There's some UX polish for detecting what the latest nightly is if you pass in a blank string. I only tested on OS X. Coded with claude code. Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/159965 Approved by: https://github.com/malfet --- setup.py | 372 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 371 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 189a78c23bbb6..e30896a2fdf4e 100644 --- a/setup.py +++ b/setup.py @@ -229,6 +229,11 @@ # # BUILD_PYTHON_ONLY # Builds pytorch as a wheel using libtorch.so from a separate wheel +# +# USE_NIGHTLY=VERSION +# Skip cmake build and instead download and extract nightly PyTorch wheel +# matching the specified version (e.g., USE_NIGHTLY="2.8.0.dev20250608+cpu") +# into the local directory for development use from __future__ import annotations @@ -266,8 +271,10 @@ import shutil import subprocess import sysconfig +import tempfile import textwrap import time +import zipfile from collections import defaultdict from pathlib import Path from typing import Any, ClassVar, IO @@ -588,9 +595,372 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +# ATTENTION: THIS IS AI SLOP +def extract_variant_from_version(version: str) -> str: + """Extract variant from version string, defaulting to 'cpu'.""" + import re + + variant_match = re.search(r"\+([^-\s,)]+)", version) + return variant_match.group(1) if variant_match else "cpu" + + +# ATTENTION: THIS IS AI SLOP +def get_nightly_git_hash(version: str) -> str: + """Download a nightly wheel and extract the git hash from its version.py file.""" + # Extract variant from version to construct correct URL + variant = extract_variant_from_version(version) + nightly_index_url = f"https://download.pytorch.org/whl/nightly/{variant}/" + + torch_version_spec = f"torch=={version}" + + # Create a temporary directory for downloading + with tempfile.TemporaryDirectory(prefix="pytorch-hash-extract-") as temp_dir: + temp_path = Path(temp_dir) + + # Download the wheel + report(f"-- Downloading {version} wheel to extract git hash...") + download_cmd = [ + "uvx", + "pip", + "download", + "--index-url", + nightly_index_url, + "--pre", + "--no-deps", + "--dest", + str(temp_path), + torch_version_spec, + ] + + result = subprocess.run(download_cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"Failed to download {version} wheel for git hash extraction: {result.stderr}" + ) + + # Find the downloaded wheel file + wheel_files = list(temp_path.glob("torch-*.whl")) + if not wheel_files: + raise RuntimeError(f"No torch wheel found after downloading {version}") + + wheel_file = wheel_files[0] + + # Extract the wheel and look for version.py + with tempfile.TemporaryDirectory( + prefix="pytorch-wheel-extract-" + ) as extract_dir: + extract_path = Path(extract_dir) + + with zipfile.ZipFile(wheel_file, "r") as zip_ref: + zip_ref.extractall(extract_path) + + # Find torch directory and version.py + torch_dirs = list(extract_path.glob("torch")) + if not torch_dirs: + torch_dirs = list(extract_path.glob("*/torch")) + + if not torch_dirs: + raise RuntimeError(f"Could not find torch directory in {version} wheel") + + version_file = torch_dirs[0] / "version.py" + if not version_file.exists(): + raise RuntimeError(f"Could not find version.py in {version} wheel") + + # Read and parse version.py to extract git_version (nightly branch commit) + from ast import literal_eval + + nightly_commit = None + with version_file.open(encoding="utf-8") as f: + for line in f: + if line.strip().startswith("git_version"): + try: + # Parse the git_version assignment, e.g., git_version = "abc123def456" + nightly_commit = literal_eval( + line.partition("=")[2].strip() + ) + break + except (ValueError, SyntaxError): + continue + + if not nightly_commit: + raise RuntimeError( + f"Could not parse git_version from {version} wheel's version.py" + ) + + # Now fetch the nightly branch and extract the real source commit from the message + report("-- Fetching nightly branch to extract source commit...") + + # Fetch only the nightly branch + subprocess.check_call(["git", "fetch", "origin", "nightly"], cwd=str(CWD)) + + # Get the commit message from the nightly commit + commit_message = subprocess.check_output( + ["git", "show", "--no-patch", "--format=%s", nightly_commit], + cwd=str(CWD), + text=True, + ).strip() + + # Parse the commit message to extract the real hash + # Format: "2025-08-06 nightly release (74a754aae98aabc2aca67e5edb41cc684fae9a82)" + import re + + hash_match = re.search(r"\(([0-9a-fA-F]{40})\)", commit_message) + if hash_match: + real_commit = hash_match.group(1) + report(f"-- Extracted source commit: {real_commit[:12]}...") + return real_commit + else: + raise RuntimeError( + f"Could not parse commit hash from nightly commit message: {commit_message}" + ) + + +# ATTENTION: THIS IS AI SLOP +def get_latest_nightly_version(variant: str = "cpu") -> str: + """Get the latest available nightly version using pip to query the PyTorch nightly index.""" + # Get the latest available nightly version for the specified variant + nightly_index_url = f"https://download.pytorch.org/whl/nightly/{variant}/" + + # Run pip index to get available versions + output = subprocess.check_output( + [ + "uvx", + "pip", + "index", + "versions", + "--index-url", + nightly_index_url, + "--pre", + "torch", + ], + text=True, + timeout=30, + ) + + # Parse the first line to get the latest version + # Format: "torch (2.9.0.dev20250806)" or "torch (2.9.0.dev20250806+cpu)" + first_line = output.strip().split("\n")[0] + if "(" in first_line and ")" in first_line: + # Extract version from parentheses exactly as reported + version = first_line.split("(")[1].split(")")[0] + return version + + raise RuntimeError(f"Could not parse version from pip index output: {first_line}") + + +# ATTENTION: THIS IS AI SLOP +def download_and_extract_nightly_wheel(version: str) -> None: + """Download and extract nightly PyTorch wheel for USE_NIGHTLY=VERSION builds.""" + + # Extract variant from version (e.g., cpu, cu121, cu118, rocm5.7) + variant = extract_variant_from_version(version) + nightly_index_url = f"https://download.pytorch.org/whl/nightly/{variant}/" + + # Construct the full torch version spec + torch_version_spec = f"torch=={version}" + + # Create a temporary directory for downloading + with tempfile.TemporaryDirectory(prefix="pytorch-nightly-") as temp_dir: + temp_path = Path(temp_dir) + + # Use pip to download the specific nightly wheel + download_cmd = [ + "uvx", + "pip", + "download", + "--index-url", + nightly_index_url, + "--pre", + "--no-deps", + "--dest", + str(temp_path), + torch_version_spec, + ] + + report("-- Downloading nightly PyTorch wheel...") + result = subprocess.run(download_cmd, capture_output=True, text=True) + if result.returncode != 0: + # Try to get the latest nightly version for the same variant to help the user + variant = extract_variant_from_version(version) + try: + report(f"-- Detecting latest {variant} nightly version...") + latest_version = get_latest_nightly_version(variant) + error_msg = f"Failed to download nightly wheel for version {version}: {result.stderr.strip()}" + error_msg += ( + f"\n\nLatest available {variant} nightly version: {latest_version}" + ) + error_msg += f'\nTry: USE_NIGHTLY="{latest_version}"' + + # Also get the git hash for the latest version + git_hash = get_nightly_git_hash(latest_version) + error_msg += f"\n\nIMPORTANT: You must checkout the matching source commit:\ngit checkout {git_hash}" + except Exception: + # If we can't get latest for this variant, try CPU as fallback + try: + report("-- Detecting latest CPU nightly version...") + latest_version = get_latest_nightly_version("cpu") + error_msg = f"Failed to download nightly wheel for version {version}: {result.stderr.strip()}" + error_msg += f"\n\nCould not find {variant} nightlies. Latest available CPU nightly version: {latest_version}" + error_msg += f'\nTry: USE_NIGHTLY="{latest_version}"' + except Exception: + error_msg = f"Failed to download nightly wheel for version {version}: {result.stderr.strip()}" + error_msg += "\n\nCould not determine latest nightly version. " + error_msg += "Check https://download.pytorch.org/whl/nightly/ for available versions." + + raise RuntimeError(error_msg) + + # Find the downloaded wheel file + wheel_files = list(temp_path.glob("torch-*.whl")) + if not wheel_files: + raise RuntimeError("No torch wheel found after download") + elif len(wheel_files) > 1: + raise RuntimeError(f"Multiple torch wheels found: {wheel_files}") + + wheel_file = wheel_files[0] + report(f"-- Downloaded wheel: {wheel_file.name}") + + # Extract the wheel + with tempfile.TemporaryDirectory( + prefix="pytorch-wheel-extract-" + ) as extract_dir: + extract_path = Path(extract_dir) + + # Use Python's zipfile to extract the wheel + with zipfile.ZipFile(wheel_file, "r") as zip_ref: + zip_ref.extractall(extract_path) + + # Find the torch directory in the extracted wheel + torch_dirs = list(extract_path.glob("torch")) + if not torch_dirs: + # Sometimes the torch directory might be nested + torch_dirs = list(extract_path.glob("*/torch")) + + if not torch_dirs: + raise RuntimeError("Could not find torch directory in extracted wheel") + + source_torch_dir = torch_dirs[0] + target_torch_dir = TORCH_DIR + + report( + f"-- Extracting wheel contents from {source_torch_dir} to {target_torch_dir}" + ) + + # Copy the essential files from the wheel to our local directory + # Based on the file listing logic from tools/nightly.py + files_to_copy: list[Path] = [] + + # Get platform-specific binary files + if IS_LINUX: + files_to_copy.extend(source_torch_dir.glob("*.so")) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.so*") + if (source_torch_dir / "lib").exists() + else [] + ) + elif IS_DARWIN: + files_to_copy.extend(source_torch_dir.glob("*.so")) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.dylib") + if (source_torch_dir / "lib").exists() + else [] + ) + elif IS_WINDOWS: + files_to_copy.extend(source_torch_dir.glob("*.pyd")) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.lib") + if (source_torch_dir / "lib").exists() + else [] + ) + files_to_copy.extend( + (source_torch_dir / "lib").glob("*.dll") + if (source_torch_dir / "lib").exists() + else [] + ) + + # Add essential directories and files + essential_items = ["version.py", "bin", "include", "lib"] + for item_name in essential_items: + item_path = source_torch_dir / item_name + if item_path.exists(): + files_to_copy.append(item_path) + + # Add testing internal generated files + testing_generated = source_torch_dir / "testing" / "_internal" / "generated" + if testing_generated.exists(): + files_to_copy.append(testing_generated) + + # Copy all the files and directories + for src_path in files_to_copy: + rel_path = src_path.relative_to(source_torch_dir) + dst_path = target_torch_dir / rel_path + + # Copy files and directories, preserving existing subdirectories + if src_path.is_dir(): + # Create destination directory if it doesn't exist + dst_path.mkdir(parents=True, exist_ok=True) + # Copy individual entries from source directory + for src_item in src_path.iterdir(): + dst_item = dst_path / src_item.name + if src_item.is_dir(): + # Recursively copy subdirectories (this will preserve existing ones) + shutil.copytree(src_item, dst_item, dirs_exist_ok=True) + else: + # Copy individual files, overwriting existing ones + shutil.copy2(src_item, dst_item) + else: + # For files, remove existing and copy new + if dst_path.exists(): + dst_path.unlink() + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_path, dst_path) + + report(f" Copied {rel_path}") + + report("-- Nightly wheel extraction completed") + + # all the work we need to do _before_ setup runs def build_deps() -> None: report(f"-- Building version {TORCH_VERSION}") + + # ATTENTION: THIS IS AI SLOP + # Check for USE_NIGHTLY=VERSION to bypass normal build and download nightly wheel + nightly_version = os.getenv("USE_NIGHTLY") + if nightly_version is not None: + import re + + if ( + nightly_version == "" + or nightly_version == "cpu" + or re.match(r"^cu\d+$", nightly_version) + or re.match(r"^rocm\d+\.\d+$", nightly_version) + ): + # Empty string or variant-only specification, show error with latest version + variant = "cpu" if nightly_version == "" else nightly_version + report(f"-- Detecting latest {variant} nightly version...") + latest_version = get_latest_nightly_version(variant) + # Also get the git hash to tell user which commit to checkout + git_hash = get_nightly_git_hash(latest_version) + + if nightly_version == "": + error_msg = f"USE_NIGHTLY cannot be empty. Latest available version: {latest_version}\n" + else: + error_msg = ( + "USE_NIGHTLY requires a specific version, not just a variant. " + "Latest available {nightly_version} version: {latest_version}\n" + ) + + error_msg += f'Try: USE_NIGHTLY="{latest_version}"' + error_msg += f"\n\nIMPORTANT: You must checkout the matching source commit for this binary:\ngit checkout {git_hash}" + raise RuntimeError(error_msg) + else: + # Full version specification + report( + f"-- USE_NIGHTLY={nightly_version} detected, downloading nightly wheel" + ) + download_and_extract_nightly_wheel(nightly_version) + return + check_submodules() check_pydep("yaml", "pyyaml") build_pytorch( @@ -750,7 +1120,7 @@ def _embed_libomp(self) -> None: def run(self) -> None: # Report build options. This is run after the build completes so # `CMakeCache.txt` exists # and we can get an accurate report on what is used and what is not. - cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables()) + cmake_cache_vars = get_cmake_cache_vars() if cmake_cache_vars["USE_NUMPY"]: report("-- Building with NumPy bindings") else: From d0226719a956ef891105f7cddcec39c415fbb177 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 6 Aug 2025 13:34:54 -0700 Subject: [PATCH 0085/1424] [BE][EZ] Delete remains of split-build logic (#159990) Hopefully last piece of https://github.com/pytorch/pytorch/issues/138750 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159990 Approved by: https://github.com/atalman ghstack dependencies: #159986 --- .ci/manywheel/build_common.sh | 33 +--- .ci/pytorch/build.sh | 13 +- .ci/wheel/build_wheel.sh | 11 +- .circleci/scripts/binary_linux_test.sh | 12 +- .circleci/scripts/binary_populate_env.sh | 1 - .circleci/scripts/binary_upload.sh | 4 - .../actions/test-pytorch-binary/action.yml | 1 - .../scripts/generate_binary_build_matrix.py | 13 -- .github/scripts/generate_ci_workflows.py | 35 ---- .github/templates/upload.yml.j2 | 5 - .github/workflows/_binary-build-linux.yml | 10 - .github/workflows/_binary-test-linux.yml | 9 - .github/workflows/_binary-upload.yml | 8 - .github/workflows/_linux-build.yml | 1 - ...linux-aarch64-binary-manywheel-nightly.yml | 30 --- .../generated-linux-binary-manywheel-main.yml | 6 - ...nerated-linux-binary-manywheel-nightly.yml | 171 ------------------ ...rated-linux-binary-manywheel-rocm-main.yml | 2 - ...d-linux-s390x-binary-manywheel-nightly.yml | 15 -- tools/packaging/split_wheel.py | 109 ----------- 20 files changed, 7 insertions(+), 482 deletions(-) delete mode 100644 tools/packaging/split_wheel.py diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index 49549c9f2994e..4c268befb30e5 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -138,28 +138,11 @@ fi echo "Calling setup.py bdist at $(date)" -if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ - BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 \ +time CMAKE_ARGS=${CMAKE_ARGS[@]} \ + EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR - echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" - time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ - BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 \ - BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ - USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ - CMAKE_FRESH=1 python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR - echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" -else - time CMAKE_ARGS=${CMAKE_ARGS[@]} \ - EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ - BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ - USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ - python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR -fi echo "Finished setup.py bdist at $(date)" # Build libtorch packages @@ -272,10 +255,6 @@ ls /tmp/$WHEELHOUSE_DIR mkdir -p "/$WHEELHOUSE_DIR" mv /tmp/$WHEELHOUSE_DIR/torch*linux*.whl /$WHEELHOUSE_DIR/ -if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - mv /tmp/$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/ || true -fi - if [[ -n "$BUILD_PYTHONLESS" ]]; then mkdir -p /$LIBTORCH_HOUSE_DIR mv /tmp/$LIBTORCH_HOUSE_DIR/*.zip /$LIBTORCH_HOUSE_DIR @@ -452,16 +431,8 @@ if [[ -z "$BUILD_PYTHONLESS" ]]; then pushd $PYTORCH_ROOT/test # Install the wheel for this Python version - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - pip uninstall -y "$TORCH_NO_PYTHON_PACKAGE_NAME" || true - fi - pip uninstall -y "$TORCH_PACKAGE_NAME" - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - pip install "$TORCH_NO_PYTHON_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v - fi - pip install "$TORCH_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v # Print info on the libraries installed in this wheel diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 34982ac9b3233..c7d2cb93a64b9 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -265,22 +265,13 @@ else WERROR=1 python setup.py clean - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - python3 tools/packaging/split_wheel.py bdist_wheel - else - WERROR=1 python setup.py bdist_wheel - fi + WERROR=1 python setup.py bdist_wheel else python setup.py clean if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then source .ci/pytorch/install_cache_xla.sh fi - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - echo "USE_SPLIT_BUILD cannot be used with xla or rocm" - exit 1 - else - python setup.py bdist_wheel - fi + python setup.py bdist_wheel fi pip_install_whl "$(echo dist/*.whl)" diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 0c6857f62b249..b90e6f38e9111 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -199,16 +199,7 @@ export BUILD_TEST=OFF pushd "$pytorch_rootdir" echo "Calling setup.py bdist_wheel at $(date)" -if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel -d "$whl_tmp_dir" - echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" - echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" - BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 CMAKE_FRESH=1 python setup.py bdist_wheel -d "$whl_tmp_dir" - echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" -else - python setup.py bdist_wheel -d "$whl_tmp_dir" -fi +python setup.py bdist_wheel -d "$whl_tmp_dir" echo "Finished setup.py bdist_wheel at $(date)" diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 11678cabb2c31..c24a50b8b17ed 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -65,16 +65,8 @@ fi if [[ "$PACKAGE_TYPE" != libtorch ]]; then if [[ "\$BUILD_ENVIRONMENT" != *s390x* ]]; then - if [[ "$USE_SPLIT_BUILD" == "true" ]]; then - pkg_no_python="$(ls -1 /final_pkgs/torch_no_python* | sort |tail -1)" - pkg_torch="$(ls -1 /final_pkgs/torch-* | sort |tail -1)" - # todo: after folder is populated use the pypi_pkg channel instead - pip install "\$pkg_no_python" "\$pkg_torch" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}_pypi_pkg" - retry pip install -q numpy protobuf typing-extensions - else - pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" - retry pip install -q numpy protobuf typing-extensions - fi + pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" + retry pip install -q numpy protobuf typing-extensions else pip install "\$pkg" retry pip install -q numpy protobuf typing-extensions diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 7f89c5c2dd8e6..87fea14b8d285 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -134,7 +134,6 @@ export DESIRED_PYTHON="${DESIRED_PYTHON:-}" export DESIRED_CUDA="$DESIRED_CUDA" export LIBTORCH_VARIANT="${LIBTORCH_VARIANT:-}" export BUILD_PYTHONLESS="${BUILD_PYTHONLESS:-}" -export USE_SPLIT_BUILD="${USE_SPLIT_BUILD:-}" if [[ "${OSTYPE}" == "msys" ]]; then export LIBTORCH_CONFIG="${LIBTORCH_CONFIG:-}" if [[ "${LIBTORCH_CONFIG:-}" == 'debug' ]]; then diff --git a/.circleci/scripts/binary_upload.sh b/.circleci/scripts/binary_upload.sh index cf87748d538ce..6c4aa8bee1dfd 100755 --- a/.circleci/scripts/binary_upload.sh +++ b/.circleci/scripts/binary_upload.sh @@ -23,10 +23,6 @@ if [[ "${DRY_RUN}" = "disabled" ]]; then AWS_S3_CP="aws s3 cp" fi -if [[ "${USE_SPLIT_BUILD:-false}" == "true" ]]; then - UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_pypi_pkg" -fi - # this is special build with all dependencies packaged if [[ ${BUILD_NAME} == *-full* ]]; then UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_full" diff --git a/.github/actions/test-pytorch-binary/action.yml b/.github/actions/test-pytorch-binary/action.yml index 63acd791b85c6..d4b8be8b609a0 100644 --- a/.github/actions/test-pytorch-binary/action.yml +++ b/.github/actions/test-pytorch-binary/action.yml @@ -24,7 +24,6 @@ runs: -e PYTORCH_FINAL_PACKAGE_DIR \ -e PYTORCH_ROOT \ -e SKIP_ALL_TESTS \ - -e USE_SPLIT_BUILD \ --tty \ --detach \ -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index def91d29f2bd2..ce4a44953413b 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -273,7 +273,6 @@ def generate_wheels_matrix( os: str, arches: Optional[list[str]] = None, python_versions: Optional[list[str]] = None, - use_split_build: bool = False, ) -> list[dict[str, str]]: package_type = "wheel" if os == "linux" or os == "linux-aarch64" or os == "linux-s390x": @@ -321,15 +320,6 @@ def generate_wheels_matrix( ): continue - if use_split_build and ( - arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux" - ): - raise RuntimeError( - "Split build is only supported on linux with cuda 12* and cpu.\n" - f"Currently attempting to build on arch version {arch_version} and os {os}.\n" - "Please modify the matrix generation to exclude this combination." - ) - # cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install if ( @@ -344,7 +334,6 @@ def generate_wheels_matrix( "gpu_arch_type": gpu_arch_type, "gpu_arch_version": gpu_arch_version, "desired_cuda": desired_cuda, - "use_split_build": "True" if use_split_build else "False", "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split( ":" )[0], @@ -377,7 +366,6 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), - "use_split_build": "True" if use_split_build else "False", "container_image": WHEEL_CONTAINER_IMAGES[ arch_version ].split(":")[0], @@ -400,7 +388,6 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), - "use_split_build": "True" if use_split_build else "False", "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split( ":" )[0], diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 9dfed6d00df8f..b0849ca0f8524 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -60,7 +60,6 @@ class BinaryBuildWorkflow: branches: str = "nightly" # Mainly for macos macos_runner: str = "macos-14-xlarge" - use_split_build: bool = False # Mainly used for libtorch builds build_variant: str = "" @@ -71,9 +70,6 @@ def __post_init__(self) -> None: for item in [self.os, "binary", self.package_type, self.build_variant] if item != "" ) - if self.use_split_build: - # added to distinguish concurrency groups - self.build_environment += "-split" def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: output_file_path = ( @@ -116,21 +112,6 @@ class OperatingSystem: isolated_workflow=True, ), ), - # See https://github.com/pytorch/pytorch/issues/138750 - # BinaryBuildWorkflow( - # os=OperatingSystem.LINUX, - # package_type="manywheel", - # build_configs=generate_binary_build_matrix.generate_wheels_matrix( - # OperatingSystem.LINUX, - # use_split_build=True, - # arches=["11.8", "12.1", "12.4", "cpu"], - # ), - # ciflow_config=CIFlowConfig( - # labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, - # isolated_workflow=True, - # ), - # use_split_build=True, - # ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", @@ -179,22 +160,6 @@ class OperatingSystem: ), branches="main", ), - # See https://github.com/pytorch/pytorch/issues/138750 - # BinaryBuildWorkflow( - # os=OperatingSystem.LINUX, - # package_type="manywheel", - # build_configs=generate_binary_build_matrix.generate_wheels_matrix( - # OperatingSystem.LINUX, - # arches=["11.8", "12.1", "12.4"], - # python_versions=["3.9"], - # use_split_build=True, - # ), - # ciflow_config=CIFlowConfig( - # labels={LABEL_CIFLOW_PERIODIC}, - # ), - # branches="main", - # use_split_build=True, - # ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", diff --git a/.github/templates/upload.yml.j2 b/.github/templates/upload.yml.j2 index f159d623f1bf7..763784f5f3e1e 100644 --- a/.github/templates/upload.yml.j2 +++ b/.github/templates/upload.yml.j2 @@ -25,11 +25,6 @@ DOCKER_IMAGE: !{{ config["container_image"] }} DOCKER_IMAGE_TAG_PREFIX: !{{ config["container_image_tag_prefix"] }} {%- endif %} -{%- if config["package_type"] == "manywheel" %} - {%- if config.use_split_build is defined %} - use_split_build: !{{ config["use_split_build"] }} - {%- endif %} -{%- endif %} {%- if config["package_type"] == "libtorch" %} {%- if config["libtorch_config"] %} LIBTORCH_CONFIG: !{{ config["libtorch_config"] }} diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index f11ee4a6621e1..bfa035bc753b8 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -26,13 +26,6 @@ on: default: 240 type: number description: timeout for the job - use_split_build: - description: | - [Experimental] Build a libtorch only wheel and build pytorch such that - are built from the libtorch wheel. - required: false - type: boolean - default: false ALPINE_IMAGE: required: false type: string @@ -117,7 +110,6 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - USE_SPLIT_BUILD: ${{ inputs.use_split_build }} steps: - name: Make the env permanent during this workflow (but not the secrets) shell: bash @@ -142,7 +134,6 @@ jobs: echo "PR_NUMBER=${{ env.PR_NUMBER }}" echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" echo "SHA1=${{ env.SHA1 }}" - echo "USE_SPLIT_BUILD=${{ env.use_split_build }}" } >> "${GITHUB_ENV} }}" - name: List the env @@ -261,7 +252,6 @@ jobs: -e PYTORCH_ROOT \ -e SKIP_ALL_TESTS \ -e PYTORCH_EXTRA_INSTALL_REQUIREMENTS \ - -e USE_SPLIT_BUILD \ --tty \ --detach \ -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index 434167d0f0c6d..476dd182db0f8 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -64,13 +64,6 @@ on: required: true type: string description: Hardware to run this job on. Valid values are linux.4xlarge, linux.4xlarge.nvidia.gpu, linux.arm64.2xlarge, and linux.rocm.gpu - use_split_build: - description: | - [Experimental] Build a libtorch only wheel and build pytorch such that - are built from the libtorch wheel. - required: false - type: boolean - default: false secrets: github-token: required: true @@ -104,7 +97,6 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - USE_SPLIT_BUILD: ${{ inputs.use_split_build }} steps: - name: Make the env permanent during this workflow (but not the secrets) shell: bash @@ -129,7 +121,6 @@ jobs: echo "PR_NUMBER=${{ env.PR_NUMBER }}" echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" echo "SHA1=${{ env.SHA1 }}" - echo "USE_SPLIT_BUILD=${{ env.USE_SPLIT_BUILD }}" } >> "${GITHUB_ENV} }}" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" diff --git a/.github/workflows/_binary-upload.yml b/.github/workflows/_binary-upload.yml index 6750102b5a293..636b76d42931a 100644 --- a/.github/workflows/_binary-upload.yml +++ b/.github/workflows/_binary-upload.yml @@ -51,13 +51,6 @@ on: required: false type: string description: Desired python version - use_split_build: - description: | - [Experimental] Build a libtorch only wheel and build pytorch such that - are built from the libtorch wheel. - required: false - type: boolean - default: false secrets: github-token: required: true @@ -86,7 +79,6 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - USE_SPLIT_BUILD: ${{ inputs.use_split_build }} steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index 5173425009f69..4d46de4b86576 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -306,7 +306,6 @@ jobs: -e OUR_GITHUB_JOB_ID \ -e HUGGING_FACE_HUB_TOKEN \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ - -e USE_SPLIT_BUILD \ -e BUILD_ADDITIONAL_PACKAGES \ --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \ diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 8cde3006e3816..757eadc0cc043 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -60,7 +60,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -84,7 +83,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -108,7 +106,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 secrets: @@ -129,7 +126,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -156,7 +152,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda-aarch64-12_9 secrets: @@ -176,7 +171,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -200,7 +194,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -224,7 +217,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 secrets: @@ -245,7 +237,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -272,7 +263,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda-aarch64-12_9 secrets: @@ -292,7 +282,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -316,7 +305,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -340,7 +328,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 secrets: @@ -361,7 +348,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -388,7 +374,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda-aarch64-12_9 secrets: @@ -408,7 +393,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -432,7 +416,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -456,7 +439,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 secrets: @@ -477,7 +459,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -504,7 +485,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda-aarch64-12_9 secrets: @@ -524,7 +504,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -548,7 +527,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -572,7 +550,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-aarch64 secrets: @@ -593,7 +570,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -620,7 +596,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda-aarch64-12_9 secrets: @@ -640,7 +615,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -664,7 +638,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -688,7 +661,6 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu-aarch64 secrets: @@ -709,7 +681,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -736,7 +707,6 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda-aarch64-12_9 secrets: diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index d1e89bb6e2d85..c532d5774b530 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -56,7 +56,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_6 @@ -80,7 +79,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_6 build_environment: linux-binary-manywheel @@ -103,7 +101,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_8 @@ -127,7 +124,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_8 build_environment: linux-binary-manywheel @@ -150,7 +146,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_9 @@ -174,7 +169,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_9 build_environment: linux-binary-manywheel diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 464bef0e1f7db..e68d26c669ad5 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -60,7 +60,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu @@ -82,7 +81,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel @@ -105,7 +103,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu secrets: @@ -126,7 +123,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_6 @@ -150,7 +146,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_6 build_environment: linux-binary-manywheel @@ -174,7 +169,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_6 secrets: @@ -195,7 +189,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_8 @@ -219,7 +212,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_8 build_environment: linux-binary-manywheel @@ -243,7 +235,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_8 secrets: @@ -264,7 +255,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_9 @@ -288,7 +278,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_9 build_environment: linux-binary-manywheel @@ -312,7 +301,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_9 secrets: @@ -333,7 +321,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_3 @@ -358,7 +345,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -426,7 +412,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_3 secrets: @@ -447,7 +432,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_4 @@ -472,7 +456,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -540,7 +523,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_4 secrets: @@ -560,7 +542,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-xpu @@ -585,7 +566,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.9" permissions: id-token: write @@ -653,7 +633,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-xpu secrets: @@ -673,7 +652,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu @@ -695,7 +673,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel @@ -718,7 +695,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu secrets: @@ -739,7 +715,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_6 @@ -763,7 +738,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel @@ -787,7 +761,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_6 secrets: @@ -808,7 +781,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_8 @@ -832,7 +804,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel @@ -856,7 +827,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_8 secrets: @@ -877,7 +847,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_9 @@ -901,7 +870,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_9 build_environment: linux-binary-manywheel @@ -925,7 +893,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_9 secrets: @@ -946,7 +913,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_3 @@ -971,7 +937,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1039,7 +1004,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_3 secrets: @@ -1060,7 +1024,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_4 @@ -1085,7 +1048,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1153,7 +1115,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_4 secrets: @@ -1173,7 +1134,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu @@ -1198,7 +1158,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.10" permissions: id-token: write @@ -1266,7 +1225,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-xpu secrets: @@ -1286,7 +1244,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu @@ -1308,7 +1265,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel @@ -1331,7 +1287,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu secrets: @@ -1352,7 +1307,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_6 @@ -1376,7 +1330,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel @@ -1400,7 +1353,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_6 secrets: @@ -1421,7 +1373,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8 @@ -1445,7 +1396,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel @@ -1469,7 +1419,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8 secrets: @@ -1490,7 +1439,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8-full @@ -1513,7 +1461,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8-full build_environment: linux-binary-manywheel @@ -1537,7 +1484,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8-full secrets: @@ -1558,7 +1504,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_9 @@ -1582,7 +1527,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_9 build_environment: linux-binary-manywheel @@ -1606,7 +1550,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_9 secrets: @@ -1627,7 +1570,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_3 @@ -1652,7 +1594,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -1720,7 +1661,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_3 secrets: @@ -1741,7 +1681,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_4 @@ -1766,7 +1705,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -1834,7 +1772,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_4 secrets: @@ -1854,7 +1791,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu @@ -1879,7 +1815,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.11" permissions: id-token: write @@ -1947,7 +1882,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-xpu secrets: @@ -1967,7 +1901,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu @@ -1989,7 +1922,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel @@ -2012,7 +1944,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu secrets: @@ -2033,7 +1964,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_6 @@ -2057,7 +1987,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel @@ -2081,7 +2010,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_6 secrets: @@ -2102,7 +2030,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_8 @@ -2126,7 +2053,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel @@ -2150,7 +2076,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_8 secrets: @@ -2171,7 +2096,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_9 @@ -2195,7 +2119,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_9 build_environment: linux-binary-manywheel @@ -2219,7 +2142,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_9 secrets: @@ -2240,7 +2162,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_3 @@ -2265,7 +2186,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -2333,7 +2253,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_3 secrets: @@ -2354,7 +2273,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_4 @@ -2379,7 +2297,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -2447,7 +2364,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_4 secrets: @@ -2467,7 +2383,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu @@ -2492,7 +2407,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.12" permissions: id-token: write @@ -2560,7 +2474,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-xpu secrets: @@ -2580,7 +2493,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu @@ -2602,7 +2514,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel @@ -2625,7 +2536,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu secrets: @@ -2646,7 +2556,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_6 @@ -2670,7 +2579,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel @@ -2694,7 +2602,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_6 secrets: @@ -2715,7 +2622,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_8 @@ -2739,7 +2645,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel @@ -2763,7 +2668,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_8 secrets: @@ -2784,7 +2688,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_9 @@ -2808,7 +2711,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_9 build_environment: linux-binary-manywheel @@ -2832,7 +2734,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_9 secrets: @@ -2853,7 +2754,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-rocm6_3 @@ -2878,7 +2778,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13" steps: - name: Setup ROCm @@ -2946,7 +2845,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-rocm6_3 secrets: @@ -2967,7 +2865,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-rocm6_4 @@ -2992,7 +2889,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13" steps: - name: Setup ROCm @@ -3060,7 +2956,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-rocm6_4 secrets: @@ -3080,7 +2975,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu @@ -3105,7 +2999,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13" permissions: id-token: write @@ -3173,7 +3066,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-xpu secrets: @@ -3193,7 +3085,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cpu @@ -3215,7 +3106,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu build_environment: linux-binary-manywheel @@ -3238,7 +3128,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu secrets: @@ -3259,7 +3148,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_6 @@ -3283,7 +3171,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel @@ -3307,7 +3194,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_6 secrets: @@ -3328,7 +3214,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_8 @@ -3352,7 +3237,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel @@ -3376,7 +3260,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_8 secrets: @@ -3397,7 +3280,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_9 @@ -3421,7 +3303,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_9 build_environment: linux-binary-manywheel @@ -3445,7 +3326,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_9 secrets: @@ -3466,7 +3346,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-rocm6_3 @@ -3491,7 +3370,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13t" steps: - name: Setup ROCm @@ -3559,7 +3437,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-rocm6_3 secrets: @@ -3580,7 +3457,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-rocm6_4 @@ -3605,7 +3481,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13t" steps: - name: Setup ROCm @@ -3673,7 +3548,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-rocm6_4 secrets: @@ -3693,7 +3567,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-xpu @@ -3718,7 +3591,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13t" permissions: id-token: write @@ -3786,7 +3658,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-xpu secrets: @@ -3806,7 +3677,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cpu @@ -3828,7 +3698,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cpu build_environment: linux-binary-manywheel @@ -3851,7 +3720,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cpu secrets: @@ -3872,7 +3740,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_6 @@ -3896,7 +3763,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cuda12_6 build_environment: linux-binary-manywheel @@ -3920,7 +3786,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cuda12_6 secrets: @@ -3941,7 +3806,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_8 @@ -3965,7 +3829,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cuda12_8 build_environment: linux-binary-manywheel @@ -3989,7 +3852,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cuda12_8 secrets: @@ -4010,7 +3872,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_9 @@ -4034,7 +3895,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cuda12_9 build_environment: linux-binary-manywheel @@ -4058,7 +3918,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-cuda12_9 secrets: @@ -4079,7 +3938,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-rocm6_3 @@ -4104,7 +3962,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.14" steps: - name: Setup ROCm @@ -4172,7 +4029,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-rocm6_3 secrets: @@ -4193,7 +4049,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-rocm6_4 @@ -4218,7 +4073,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.14" steps: - name: Setup ROCm @@ -4286,7 +4140,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-rocm6_4 secrets: @@ -4306,7 +4159,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-xpu @@ -4331,7 +4183,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.14" permissions: id-token: write @@ -4399,7 +4250,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.14" build_name: manywheel-py3_14-xpu secrets: @@ -4419,7 +4269,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cpu @@ -4441,7 +4290,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cpu build_environment: linux-binary-manywheel @@ -4464,7 +4312,6 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cpu secrets: @@ -4485,7 +4332,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_6 @@ -4509,7 +4355,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cuda12_6 build_environment: linux-binary-manywheel @@ -4533,7 +4378,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cuda12_6 secrets: @@ -4554,7 +4398,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_8 @@ -4578,7 +4421,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cuda12_8 build_environment: linux-binary-manywheel @@ -4602,7 +4444,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cuda12_8 secrets: @@ -4623,7 +4464,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_9 @@ -4647,7 +4487,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cuda12_9 build_environment: linux-binary-manywheel @@ -4671,7 +4510,6 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cuda12_9 secrets: @@ -4692,7 +4530,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-rocm6_3 @@ -4717,7 +4554,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.14t" steps: - name: Setup ROCm @@ -4785,7 +4621,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-rocm6_3 secrets: @@ -4806,7 +4641,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-rocm6_4 @@ -4831,7 +4665,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.14t" steps: - name: Setup ROCm @@ -4899,7 +4732,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-rocm6_4 secrets: @@ -4919,7 +4751,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-xpu @@ -4944,7 +4775,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.14t" permissions: id-token: write @@ -5012,7 +4842,6 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu - use_split_build: False DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-xpu secrets: diff --git a/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml index b6b63c4e38d5e..a3e5937fdcc4e 100644 --- a/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml @@ -58,7 +58,6 @@ jobs: GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_4 @@ -83,7 +82,6 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 66c0813afe900..9570f8d97a2db 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -60,7 +60,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.9" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -84,7 +83,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -107,7 +105,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x secrets: @@ -127,7 +124,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.10" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -151,7 +147,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -174,7 +169,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x secrets: @@ -194,7 +188,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.11" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -218,7 +211,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -241,7 +233,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x secrets: @@ -261,7 +252,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.12" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -285,7 +275,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -308,7 +297,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x secrets: @@ -328,7 +316,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.13" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -352,7 +339,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -375,7 +361,6 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x - use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x secrets: diff --git a/tools/packaging/split_wheel.py b/tools/packaging/split_wheel.py deleted file mode 100644 index fd52c39a22b02..0000000000000 --- a/tools/packaging/split_wheel.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Script to build split pytorch wheels - -What is split build / why is it important? - > Split build is splitting the PyTorch build into a libtorch & - > PyTorch python frontend package. This allows us to to publish - > both as separate packages and opens up our ability to have users - > install different libtorch backends per their PyTorch frontend - > - > Example: opening up the door to things like: - > pip install torch[cuda] - > pip install torch[rocm] - > pip install torch[cpu] - > etc. - -Why does this exist? - > Currently our split build requires you to invoke setup.py twice - > Which ends up complicating the build process and adds some level - > of complexity to our setup.py / build invocation for split builds. - > Ideally this script will eventually not be needed but for - > development purposes we should have an easy way to invoke this script -""" - -import argparse -import logging -import os -import subprocess -import sys -from pathlib import Path -from typing import Optional - - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -# NOTE: This will need to be updated if this script is ever moved -ROOT_PATH = Path(__file__).absolute().parents[2] -SETUP_PY_PATH = ROOT_PATH / "setup.py" - - -def requirements_installed() -> bool: - try: - import setuptools # type: ignore[import-untyped] # noqa: F401 - - return True - except ImportError: - logger.error( - "Requirements not installed, run the following command to install:" - ) - logger.error( - " > %s -m pip install -r %s/requirements.txt", sys.executable, ROOT_PATH - ) - return False - - -def setup_py(cmd_args: list[str], extra_env: Optional[dict[str, str]] = None) -> None: - if extra_env is None: - extra_env = {} - cmd = [sys.executable, str(SETUP_PY_PATH), *cmd_args] - logger.debug("+ %s", " ".join(cmd)) - subprocess.run( - cmd, - # Give the parent environment to the subprocess - env={**os.environ, **extra_env}, - check=True, - ) - - -def split_build(cmd: str) -> None: - logger.info("Running %s for libtorch wheel", cmd) - setup_py( - [cmd], - extra_env={"BUILD_LIBTORCH_WHL": "1", "BUILD_PYTHON_ONLY": "0"}, - ) - logger.info("Running %s for torch wheel", cmd) - # NOTE: Passing CMAKE_FRESH=1 is necessary here since the torch frontend has it's - # own cmake files that it needs to generate - setup_py( - [cmd], - extra_env={ - "BUILD_LIBTORCH_WHL": "0", - "BUILD_PYTHON_ONLY": "1", - "CMAKE_FRESH": "1", - }, - ) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - command_subparser = parser.add_subparsers(dest="command") - # Ideally these should mirror setuptools commands if we need support here for that - command_subparser.add_parser("install") - command_subparser.add_parser("bdist_wheel") - command_subparser.add_parser("develop") - return parser.parse_args() - - -def main() -> None: - args = parse_args() - if not requirements_installed(): - sys.exit(1) - split_build(args.command) - - -if __name__ == "__main__": - main() From 81d72fb1f7d42584688011c5a13d6b667539fe32 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 6 Aug 2025 13:34:55 -0700 Subject: [PATCH 0086/1424] Move smoke binary builds to 3.12 (#159993) And limit them just to stable CUDA version (as there weren't any recent instances when only one of those jobs failed to build) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159993 Approved by: https://github.com/ngimel ghstack dependencies: #159986, #159990 --- .github/scripts/generate_ci_workflows.py | 4 +- .../generated-linux-binary-manywheel-main.yml | 104 ++---------------- 2 files changed, 9 insertions(+), 99 deletions(-) diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index b0849ca0f8524..67906d4ad88d5 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -155,8 +155,8 @@ class OperatingSystem: package_type="manywheel", build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.LINUX, - arches=["12.6", "12.8", "12.9"], - python_versions=["3.9"], + arches=["12.8"], + python_versions=["3.12"], ), branches="main", ), diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index c532d5774b530..6387d75a73b50 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -42,52 +42,7 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - manywheel-py3_9-cuda12_6-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_6 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_6-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_6-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_6 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_8-build: + manywheel-py3_12-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -101,17 +56,17 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.9" + DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_8 + build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_8-test: # Testing + manywheel-py3_12-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda12_8-build + - manywheel-py3_12-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -124,53 +79,8 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_9-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: 12.9 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_9 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_9-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_9-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: 12.9 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_9 + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner From d4c1a08c89f37d249a0146ff511c82ecc5c53b8f Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 6 Aug 2025 11:25:32 -0700 Subject: [PATCH 0087/1424] Relax unclaimed successes in dtype op tests when running under TEST_WITH_DYNAMO/TEST_WITH_INDUCTOR (#159976) This PR changes the behavior for compile wrapped op tests: - supported_but_unclaimed_forward - supported_but_unclaimed_backward These typically manifest when the op doesn't support inputs of certain dtypes. But under torch.compile, Dynamo/AOTAutograd will trace the graph with FakeTensors, which @ezyang and @eellison tell me need to run decomps before op dispatch. The decomp may map this test to a different op, one that does support the dtype. I suspect all of our failures here are due to decomps, and so I propose to just disable this check for compile. ~~TODO: re-enable all the failed tests.~~ jk there were no failed tests outside of compiled autograd due to this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159976 Approved by: https://github.com/ezyang --- test/test_ops.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 201b0323a86fd..2d5af9966690f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1601,6 +1601,16 @@ def _tensor_requires_grad(x): ) == 0: return + if TEST_WITH_TORCHDYNAMO: + # NOTE: Also for TEST_WITH_TORCHINDUCTOR tests + # Under compile, some ops may be decomposed into supported ops + # So it is okay to have supported_but_unclaimed_* + if ( + len(claimed_but_unsupported_forward) + + len(claimed_but_unsupported_backward) + ) == 0: + return + # Reference operators often support additional dtypes, and that's OK if op in python_ref_db: if ( From c859ba7114b1fcb49527e090745fa17091d1f8d5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 7 Aug 2025 04:06:04 +0000 Subject: [PATCH 0088/1424] Make onnx export SDPA match aten behavior (#159973) This PR makes onnx sdpa export match the behavior of aten sdpa when boolean mask is used. @justinchuby ```python import onnxruntime as ort import torch class ScaledDotProductAttention(torch.nn.Module): def forward(self, query, key, value, attn_mask): return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask) model = ScaledDotProductAttention() attn_mask = torch.ones(2, 4, 8, 8).bool() # boolean mask for attention attn_mask[0, 0, 0, :] = False # masking an entire row (padding token) query = key = value = torch.randn(2, 4, 8, 16) output = model(query, key, value, attn_mask) torch.onnx.export( model, (query, key, value, attn_mask), "scaled_dot_product_attention.onnx", input_names=["query", "key", "value", "attn_mask"], output_names=["output"], dynamo=false, # or True, ) ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx") np_inputs = {"query": query.numpy(), "key": key.numpy(), "value": value.numpy(), "attn_mask": attn_mask.numpy()} onnx_outputs = ort_session.run(None, np_inputs)[0] torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True) ``` fails the assertion because the ort model outputs nans. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159973 Approved by: https://github.com/xadupre, https://github.com/titaiwangms --- torch/onnx/symbolic_opset14.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 8bc6f0f9f4d26..80743c6a49121 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -177,6 +177,7 @@ def scaled_dot_product_attention( if symbolic_helper._is_none(attn_mask): mul_qk_add = mul_qk + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) elif ( _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL @@ -186,19 +187,24 @@ def scaled_dot_product_attention( const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) mul_qk_add = g.op("Add", mul_qk, attn_mask) + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + # When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values + # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output. + # This is because there's no safe softmax imp in ONNX, so we need to handle NaN values explicitly to match + # the behavior of PyTorch with boolean masks. + attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight) elif _type_utils.JitScalarType.from_value(attn_mask) in ( _type_utils.JitScalarType.FLOAT, _type_utils.JitScalarType.HALF, _type_utils.JitScalarType.BFLOAT16, ): mul_qk_add = g.op("Add", mul_qk, attn_mask) + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) else: raise ValueError( f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" ) - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) - if dropout_p != 0: attn_weight = g.op( "Dropout", From 3f1636ebef9b45e8a3cb0eb20d327ee6acb74be0 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Thu, 7 Aug 2025 04:16:32 +0000 Subject: [PATCH 0089/1424] [audio hash update] update the pinned audio hash (#160046) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160046 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 5e75486031249..cdfbede9e8f09 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -6fbc710b617f79b992ef2ebc7f95e818aa390293 +0c22347335f4c9a5b92a2f5bad65e05e2464c184 From aa75e917bdb0f95bb6dee81853c2d3c4ab3e1883 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 7 Aug 2025 07:31:42 +0000 Subject: [PATCH 0090/1424] [Export Schema] Remove deviceAllocationMap field (#159653) Summary: This field is not used today, and it's not useful either. The device allocation is configured at model loading time, specified by user. It shouldn't be part of the model definition. Test Plan: CI Rollback Plan: Differential Revision: D79385513 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159653 Approved by: https://github.com/zhxchen17 --- torch/_export/serde/export_schema.thrift | 3 +-- torch/_export/serde/schema.py | 1 - torch/_export/serde/schema.yaml | 4 +--- torch/csrc/utils/generated_serialization_types.h | 13 +------------ 4 files changed, 3 insertions(+), 18 deletions(-) diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index 50472c02375cc..0b2f2b4fe7408 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<31664e4faa0eacd6f538ffed163078e190d9d2b98d762dd45b68eb1b7b12f0d1>> +// checksum<<0b6fec18525f05577f007055f774b5e6f143ca7499b931474d1f4cd4a5dc5004>> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -341,7 +341,6 @@ struct Model { 20: map tensorPaths; 40: Program program; 50: map delegates; - 60: map deviceAllocationMap; 70: map constantPaths; } diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 933d30310b72c..30bc119a54007 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -461,7 +461,6 @@ class Model: # Backend-specialized Lowered GraphModule # e.g. "aotinductor-a100" : ExportedProgram_with_AOTInductor_delegate delegates: Annotated[dict[str, Program], 50] - deviceAllocationMap: Annotated[dict[str, str], 60] # key is the FQN of constant in exported program (constant tensor or torchbind objs) # value is the archive path of serialized constants constantPaths: Annotated[dict[str, str], 70] diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 9167a6820ef40..56e40f309744e 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> +# checksum<<89a616d78254f20c027a2e0f882a3f8b096b4169c781d5dfd0254c8bce33cb35>> AOTInductorModelPickleData: kind: struct fields: @@ -304,8 +304,6 @@ Model: type: Program delegates: type: Dict[str, Program] - deviceAllocationMap: - type: Dict[str, str] constantPaths: type: Dict[str, str] ModuleCallEntry: diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 14741e4d2c6e1..f93532ef9de23 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> +// checksum<<89a616d78254f20c027a2e0f882a3f8b096b4169c781d5dfd0254c8bce33cb35>> // clang-format off #pragma once @@ -3093,7 +3093,6 @@ class Model { std::unordered_map tensorPaths; Program program; std::unordered_map delegates; - std::unordered_map deviceAllocationMap; std::unordered_map constantPaths; public: @@ -3130,14 +3129,6 @@ class Model { delegates = std::move(def); } - const std::unordered_map& get_deviceAllocationMap() const { - return deviceAllocationMap; - } - - void set_deviceAllocationMap(std::unordered_map def) { - deviceAllocationMap = std::move(def); - } - const std::unordered_map& get_constantPaths() const { return constantPaths; } @@ -3515,7 +3506,6 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_ nlohmann_json_j["tensorPaths"] = nlohmann_json_t.tensorPaths; nlohmann_json_j["program"] = nlohmann_json_t.program; nlohmann_json_j["delegates"] = nlohmann_json_t.delegates; - nlohmann_json_j["deviceAllocationMap"] = nlohmann_json_t.deviceAllocationMap; nlohmann_json_j["constantPaths"] = nlohmann_json_t.constantPaths; } @@ -3525,7 +3515,6 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, Model& nlohmann_jso nlohmann_json_t.tensorPaths = nlohmann_json_j.value("tensorPaths", nlohmann_json_default_obj.tensorPaths); nlohmann_json_t.program = nlohmann_json_j.value("program", nlohmann_json_default_obj.program); nlohmann_json_t.delegates = nlohmann_json_j.value("delegates", nlohmann_json_default_obj.delegates); - nlohmann_json_t.deviceAllocationMap = nlohmann_json_j.value("deviceAllocationMap", nlohmann_json_default_obj.deviceAllocationMap); nlohmann_json_t.constantPaths = nlohmann_json_j.value("constantPaths", nlohmann_json_default_obj.constantPaths); } From 24f43d0da7ad9c6e95a09a2fee610387728cc1cd Mon Sep 17 00:00:00 2001 From: thenumberouscode Date: Thu, 7 Aug 2025 08:03:01 +0000 Subject: [PATCH 0091/1424] [inductor] [cpu] fix the dype hardcoded to int64 in store_reduction (#157904) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Fixes https://github.com/pytorch/pytorch/issues/157683 ## mini repro * Just copy the code from the issue to reproduce it. ```python import torch device = "cpu" # Input tensors v2_0 = torch.randn(16, 24, 59, dtype=torch.complex64, device=device) v3_0 = torch.randn(16, 24, 59, dtype=torch.complex64, device=device) def my_model(v2_0, v3_0): v6_0 = -v3_0 v4_0 = v2_0 * v3_0 v1_0 = v4_0.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) v0_0 = v2_0.to(torch.int32) v5_0 = v0_0.amax(dim=0) return v6_0, v4_0, v1_0, v0_0, v5_0 v6_0, v4_0, v1_0, v0_0, v5_0 = my_model(v2_0, v3_0) print("v6_0", v6_0.shape) print("v4_0", v4_0.shape) compiled_model = torch.compile(my_model, backend="inductor") v6_0, v4_0, v1_0, v0_0, v5_0 = compiled_model(v2_0, v3_0) print("v6_0", v6_0.shape) print("v4_0", v4_0.shape) print("v1_0", v1_0.shape) print("v0_0", v0_0.shape) print("v5_0", v5_0.shape) ``` error_stack ``` /home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:candidate: ‘template std::enable_if_t<(! is_same_v), at::vec::CPU_CAPABILITY::Vectorized > at::vec::CPU_CAPABILITY::convert(const at::vec::CPU_CAPABILITY::Vectorized&)’ 41 | convert(const Vectorized& src) { | ^~~~~~~ /home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注: template argument deduction/substitution failed: /tmp/torchinductor_admin/6k/c6kr65o43rlmp2cmkpn5ezewhe5bla4w72hpcrg5biyelrs4skyw.main.cpp:37:99: 错误:模板参数数目不对(不应是 4 个而应是 2 个) 37 | auto int32_t_tmp_acc0_vec = at::vec::convert(tmp_acc0_vec); ``` ## summary **The C++ kernel generated by the Inductor had the wrong data type for the output variable; it should be int32_t instead of int64_t. This incorrect data type led to an incompatible data type conversion, which caused the g++ compilation to fail.** The original code that caused the problem. ``` def my_model(v2_0, v3_0): v6_0 = -v3_0 v4_0 = v2_0 * v3_0 v1_0 = v4_0.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) v0_0 = v2_0.to(torch.int32) // The original code that caused the problem. v5_0 = v0_0.amax(dim=0) ``` ## proof procedure The c++ kernel generated by inductor: ```c++ #include extern "C" void kernel(const int32_t* in_ptr0, int32_t* out_ptr0) { { for(int64_t x0=static_cast(0L); x0(1416L); x0+=static_cast(16L)) { { int32_t tmp_acc0_arr[16]; for (int i = 0; i < 16; i++) { tmp_acc0_arr[i] = std::numeric_limits::min(); } int32_t tmp_acc0 = std::numeric_limits::min(); at::vec::Vectorized tmp_acc0_vec = at::vec::Vectorized(std::numeric_limits::min()); for(int64_t x1=static_cast(0L); x1(16L); x1+=static_cast(1L)) { { if(C10_LIKELY(x0 >= static_cast(0) && x0 < static_cast(1408L))) { auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(x0 + 1416L*x1), static_cast(16)); tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp0); } if(C10_UNLIKELY(x0 >= static_cast(1408L) && x0 < static_cast(1416L))) { for (int64_t x0_tail = static_cast(1408L);x0_tail < static_cast(1416L); x0_tail++) { auto tmp0 = in_ptr0[static_cast(x0_tail + 1416L*x1)]; tmp_acc0_arr[x0_tail - static_cast(1408L)] = max_propagate_nan(tmp_acc0_arr[x0_tail - static_cast(1408L)], tmp0); } } } } if(C10_LIKELY(x0 >= static_cast(0) && x0 < static_cast(1408L))) { // impossible data type conversion which would caused the g++ compilation to fail. auto int32_t_tmp_acc0_vec = at::vec::convert(tmp_acc0_vec); int32_t_tmp_acc0_vec.store(out_ptr0 + static_cast(x0), static_cast(16)); } if(C10_UNLIKELY(x0 >= static_cast(1408L) && x0 < static_cast(1416L))) { for (int64_t x0_tail = static_cast(1408L);x0_tail < static_cast(1416L); x0_tail++) { out_ptr0[static_cast(x0_tail)] = tmp_acc0_arr[x0_tail - static_cast(1408L)]; } } } } } } ``` the compilers complains ```text /home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:candidate: ‘template std::enable_if_t<(! is_same_v), at::vec::CPU_CAPABILITY::Vectorized > at::vec::CPU_CAPABILITY::convert(const at::vec::CPU_CAPABILITY::Vectorized&)’ 41 | convert(const Vectorized& src) { | ^~~~~~~ /home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注: template argument deduction/substitution failed: /tmp/torchinductor_admin/6k/c6kr65o43rlmp2cmkpn5ezewhe5bla4w72hpcrg5biyelrs4skyw.main.cpp:37:99: 错误:模板参数数目不对(不应是 4 个而应是 2 个) 37 | auto int32_t_tmp_acc0_vec = at::vec::convert(tmp_acc0_vec); ``` so the following line have problem ```c++ // this line means that tmp_acc0_vec should be Vectorized, and it will convert it to Vectorized. auto int32_t_tmp_acc0_vec = at::vec::convert(tmp_acc0_vec); ``` The issue is that tmp_acc0_vec is of type Vectorized, but the template parameters expect it to be Vectorized. and it will convert it to a Vectorized. this is conflict. the conversion should not be exist for tmp_acc0_vec is already Vectorized.The following line hardcodes the output variable type to int64, which causes unnecessary and incorrect type conversions. https://github.com/pytorch/pytorch/blob/d89f30ad45b9d4bfe5cf5ab441b53e849e55df7b/torch/_inductor/codegen/cpp.py#L2985-L2993 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157904 Approved by: https://github.com/jgong5 --- test/inductor/test_cpu_repro.py | 24 ++++++++++++++++++++++++ torch/_inductor/codegen/cpp.py | 9 ++++----- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 55c0a2977daf9..53b3e013a6b28 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3117,6 +3117,30 @@ def get_traj_idx(lengths: torch.Tensor, num_slices: int) -> torch.Tensor: lengths = torch.zeros(11, dtype=torch.long) get_traj_idx(lengths, num_slices=4) + def test_store_reduction(self): + # fix https://github.com/pytorch/pytorch/issues/157683 + def fn(x, y): + r1 = x.amax(dim=0) + r2 = y.amax(dim=0) + return r1, r2 + + device = "cpu" + for int_dypte, float_dtype in zip( + [torch.int64, torch.int32, torch.int16, torch.int8], + [torch.float64, torch.float32, torch.float16, torch.bfloat16], + ): + x = torch.randint( + low=0, high=100, size=(16, 24, 59), dtype=int_dypte, device=device + ) + y = torch.randn(16, 24, 59, dtype=float_dtype, device=device) + self.common( + fn, + ( + x, + y, + ), + ) + @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index e995faae26523..1ee9d033d4f97 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3218,11 +3218,10 @@ def store_reduction(self, name, index, value): index = self.rename_indexing(index) var = self.args.output(name) out_dtype = V.graph.get_dtype(name) - dtype = ( - (out_dtype if out_dtype == torch.double else torch.float) - if out_dtype.is_floating_point - else torch.int64 - ) + if out_dtype.is_floating_point and out_dtype != torch.double: + dtype = torch.float + else: + dtype = out_dtype out_num_vectors = V.kernel._get_num_vectors(out_dtype) src_num_vectors = V.kernel._get_num_vectors(dtype) code = IndentedBuffer() From 422bd6808bb98cbbac31d157d9c82ad11ba9732d Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Thu, 7 Aug 2025 08:22:41 +0000 Subject: [PATCH 0092/1424] dataclass pytree fix (#159916) Differential Revision: D79687243 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159916 Approved by: https://github.com/XuehaiPan, https://github.com/angelayi --- torch/utils/_pytree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 02954d33866cb..773e9f00e3d15 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -370,7 +370,7 @@ def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] - return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], flat_names _private_register_pytree_node( cls, From b0df7715e8c590c0001d1f9cdb97057be80c9107 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 7 Aug 2025 09:26:58 +0000 Subject: [PATCH 0093/1424] Remove benchmark dependencies from regular ROCm CI images (#160047) Instead, use a new `pytorch-linux-jammy-rocm-n-py3-benchmarks` image for Docker benchmark job. This addresses 2 issues: * The current ROCm failures in trunk w.r.t librosa version https://github.com/pytorch/pytorch/actions/runs/16789466749/job/47549950994 that TorchBench pulls in. * Reduce the size of the regular ROCm CI images by removing TorchBench models, which is needed only for benchmarking jobs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160047 Approved by: https://github.com/malfet, https://github.com/izaitsevfb --- .ci/docker/build.sh | 7 ++++--- .github/workflows/docker-builds.yml | 1 + .github/workflows/inductor-perf-test-nightly-rocm.yml | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 0bf0847c3400d..aabfbd5a47724 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -176,7 +176,7 @@ case "$tag" in VISION=yes TRITON=yes ;; - pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3) + pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3) if [[ $tag =~ "jammy" ]]; then ANACONDA_PYTHON_VERSION=3.10 else @@ -190,7 +190,9 @@ case "$tag" in KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} - INDUCTOR_BENCHMARKS=yes + if [[ $tag =~ "benchmarks" ]]; then + INDUCTOR_BENCHMARKS=yes + fi ;; pytorch-linux-noble-rocm-alpha-py3) ANACONDA_PYTHON_VERSION=3.12 @@ -202,7 +204,6 @@ case "$tag" in KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} - INDUCTOR_BENCHMARKS=yes PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950" ;; pytorch-linux-jammy-xpu-2025.0-py3) diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 548847944cd73..c83609facbd97 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -61,6 +61,7 @@ jobs: pytorch-linux-jammy-rocm-n-py3, pytorch-linux-noble-rocm-n-py3, pytorch-linux-noble-rocm-alpha-py3, + pytorch-linux-jammy-rocm-n-py3-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12, pytorch-linux-jammy-py3.9-gcc11, pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, diff --git a/.github/workflows/inductor-perf-test-nightly-rocm.yml b/.github/workflows/inductor-perf-test-nightly-rocm.yml index 377f6d04bc8ce..1ec494ace6577 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm.yml @@ -85,7 +85,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-jammy-rocm-py3_10 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks test-matrix: | { include: [ { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, From 8cb91e20bc205b1416648d0ffd98d1ba1f3a6fc4 Mon Sep 17 00:00:00 2001 From: Dev Sashidhar Date: Thu, 7 Aug 2025 11:24:40 +0000 Subject: [PATCH 0094/1424] Renaming HAS_XPU to HAS_XPU_AND_TRITON (#159908) This PR follows up on the discussion in #159399 where @Akabbaj and @janeyx99 mentioned renaming HAS_XPU to HAS_XPU_AND_TRITON for consistency. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159908 Approved by: https://github.com/janeyx99, https://github.com/guangyey --- test/dynamo/test_logging.py | 6 ++++-- test/dynamo/test_package.py | 24 +++++++++++----------- test/inductor/test_fused_attention.py | 9 ++++++-- test/inductor/test_torchinductor_opinfo.py | 6 ++++-- test/inductor/test_triton_kernels.py | 9 ++++++-- test/inductor/test_xpu_basic.py | 4 ++-- torch/testing/_internal/inductor_utils.py | 4 ++-- 7 files changed, 38 insertions(+), 24 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 015bb660512bd..99d992a899dbc 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -26,7 +26,7 @@ TEST_XPU, xfailIf, ) -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU +from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU_AND_TRITON from torch.testing._internal.logging_utils import ( LoggingTestCase, make_logging_test, @@ -35,7 +35,9 @@ requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") -requires_gpu = unittest.skipUnless(HAS_CUDA or HAS_XPU, "requires cuda or xpu") +requires_gpu = unittest.skipUnless( + HAS_CUDA or HAS_XPU_AND_TRITON, "requires cuda or xpu with triton" +) requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index a3c83ec28222f..5739f45504a6d 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -24,7 +24,7 @@ skipIfRocm, skipIfXpu, ) -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU +from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU_AND_TRITON def compute_loss_helper(x): @@ -96,7 +96,7 @@ def forward(self, x): def test_basic_fn(self, backend, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -140,7 +140,7 @@ def fn(x): def test_lazy_backward(self, backend, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -187,7 +187,7 @@ def fn(x): def test_graph_break_bomb(self, backend, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -251,7 +251,7 @@ def guard_filter_fn(guards): def test_dynamic_shape(self, backend, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() @@ -370,7 +370,7 @@ def guard_filter_fn(guards): def test_dynamo_cache_manual_load(self, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -407,7 +407,7 @@ def fn2(x): def test_automatic_dynamo_serialize(self, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -443,7 +443,7 @@ def fn2(x): def test_automatic_dynamo_autotune_cache(self, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x, y): @@ -476,7 +476,7 @@ def fn(x, y): def test_automatic_dynamo_recompiles(self, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -509,7 +509,7 @@ def fn(x): def test_automatic_dynamo_graph_breaks(self, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x, l, r): @@ -555,7 +555,7 @@ def guard_filter_fn(guards): def test_automatic_dynamo_lazy_backward(self, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") def fn(x): @@ -584,7 +584,7 @@ def fn(x): def test_call_function_from_resume(self, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: + if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") mod = torch.nn.Linear(2, 3, device=device) diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index a0e1b47032b86..19757d8942071 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -15,7 +15,12 @@ SM80OrLater, ) from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_CUDA, + HAS_XPU_AND_TRITON, +) def checkpoint_wrapper(fn): @@ -1114,7 +1119,7 @@ def dot_prod_attention( ) -if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): +if HAS_XPU_AND_TRITON or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): device = GPU_TYPE diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 2ea0f263d5937..242f774a0c880 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -48,7 +48,7 @@ HAS_CPU, HAS_CUDA, has_triton, - HAS_XPU, + HAS_XPU_AND_TRITON, maybe_skip_size_asserts, ) from torch.utils._dtype_abbrs import dtype_abbrs @@ -1116,7 +1116,9 @@ def tearDown(self): True ) # inductor kernels failing this test intermittently @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") - @skipXPUIf(not HAS_XPU, "Skipped! Supported XPU compiler not found") + @skipXPUIf( + not HAS_XPU_AND_TRITON, "Skipped! Supported XPU compiler and Triton not found" + ) @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @skipIfTorchDynamo("Test uses dynamo already") diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 689cf218b2bcd..03ba4dc712702 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -31,7 +31,12 @@ skipIfWindows, skipIfXpu, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CUDA, + HAS_GPU, + HAS_XPU_AND_TRITON, +) from torch.testing._internal.logging_utils import log_settings, logs_to_string # Defines all the kernels for tests @@ -58,7 +63,7 @@ fast_dividef, fast_dividef as my_fast_dividef, ) - elif HAS_XPU: + elif HAS_XPU_AND_TRITON: from triton.language.extra.intel.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, diff --git a/test/inductor/test_xpu_basic.py b/test/inductor/test_xpu_basic.py index 0572eccb77fd4..4501b8264c5f9 100644 --- a/test/inductor/test_xpu_basic.py +++ b/test/inductor/test_xpu_basic.py @@ -53,7 +53,7 @@ def fn(a, b): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_XPU + from torch.testing._internal.inductor_utils import HAS_XPU_AND_TRITON - if HAS_XPU: + if HAS_XPU_AND_TRITON: run_tests(needs="filelock") diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 8a521d56f5f84..7ce065c64317c 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -71,11 +71,11 @@ def test_cpu(): HAS_CUDA = torch.cuda.is_available() and HAS_TRITON -HAS_XPU = torch.xpu.is_available() and HAS_TRITON +HAS_XPU_AND_TRITON = torch.xpu.is_available() and HAS_TRITON HAS_MPS = torch.mps.is_available() -HAS_GPU = HAS_CUDA or HAS_XPU +HAS_GPU = HAS_CUDA or HAS_XPU_AND_TRITON GPU_TYPE = get_gpu_type() From a53d14d5f846ac44f6c205abb1c5bc4d2f3126ae Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 7 Aug 2025 13:09:33 +0000 Subject: [PATCH 0095/1424] Revert "unskipped mobilenet_v3 quantization and mobilenet_v2 quantization plus tests from https://github.com/pytorch/pytorch/issues/125438 (#157786)" This reverts commit 3a2c3c8ed365eb4e4cf4620c25d70b2f70483762. Reverted https://github.com/pytorch/pytorch/pull/157786 on behalf of https://github.com/albanD due to Breaks lint ([comment](https://github.com/pytorch/pytorch/pull/157786#issuecomment-3164126250)) --- test/quantization/eager/test_numeric_suite_eager.py | 5 ++++- test/test_linalg.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/quantization/eager/test_numeric_suite_eager.py b/test/quantization/eager/test_numeric_suite_eager.py index ccffad4b5ab63..cd11e96859937 100644 --- a/test/quantization/eager/test_numeric_suite_eager.py +++ b/test/quantization/eager/test_numeric_suite_eager.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: quantization"] # ruff: noqa: F841 +import unittest import torch import torch.ao.nn.quantized as nnq @@ -37,7 +38,7 @@ test_only_eval_fn, ) from torch.testing._internal.common_quantized import override_qengines -from torch.testing._internal.common_utils import raise_on_run_directly +from torch.testing._internal.common_utils import IS_ARM64, raise_on_run_directly class SubModule(torch.nn.Module): @@ -599,12 +600,14 @@ def compute_error(x, y): act_compare_dict = get_matching_activations(float_model, qmodel) @skip_if_no_torchvision + @unittest.skipIf(IS_ARM64, "Not working on arm right now") def test_mobilenet_v2(self): from torchvision.models.quantization import mobilenet_v2 self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False)) @skip_if_no_torchvision + @unittest.skipIf(IS_ARM64, "Not working on arm right now") def test_mobilenet_v3(self): from torchvision.models.quantization import mobilenet_v3_large diff --git a/test/test_linalg.py b/test/test_linalg.py index 909e8747f1d34..ac668fee049d2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1401,6 +1401,8 @@ def run_test_case(input_size, ord, keepdim): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16) def test_vector_norm(self, device, dtype): + if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]: + raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438") # have to use torch.randn(...).to(bfloat16) instead of # This test compares torch.linalg.vector_norm's output with # torch.linalg.norm given a flattened tensor From 83875cdb5594ccb3c9206b8eb5745fe1d011cf26 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Thu, 7 Aug 2025 14:23:21 +0000 Subject: [PATCH 0096/1424] [nativert] Expose ModelRunner to public through pmpl type ModelRunnerHandle. (#159989) Summary: Today users outside of pytorch core cannot `#include `. It turns out that we should place a header inside `torch/csrc/api/include/`. Placing every single nativert header here would pollute the namespace a lot and that's not what we want in general. Therefore here we just create a Handle type which hold a pointer to decouple the actual type from header definition. Test Plan: CI Rollback Plan: Differential Revision: D79751098 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159989 Approved by: https://github.com/dolpm --- .../torch/nativert/ModelRunnerHandle.h | 46 +++++++++++++++++++ torch/nativert/ModelRunner.cpp | 17 +++++++ torch/nativert/ModelRunner.h | 1 + 3 files changed, 64 insertions(+) create mode 100644 torch/csrc/api/include/torch/nativert/ModelRunnerHandle.h diff --git a/torch/csrc/api/include/torch/nativert/ModelRunnerHandle.h b/torch/csrc/api/include/torch/nativert/ModelRunnerHandle.h new file mode 100644 index 0000000000000..866e09b13407a --- /dev/null +++ b/torch/csrc/api/include/torch/nativert/ModelRunnerHandle.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::nativert { + +// We don't want to forward declare in general but including ModelRunner will +// pollute the public API namespace too much. Therefore, we just use pimpl an +// incomplete ModelRunner here. +class ModelRunner; + +class TORCH_API ModelRunnerHandle { + public: + ModelRunnerHandle( + const std::string& packagePath, + const std::string& modelName); + + ModelRunnerHandle(ModelRunnerHandle&&) = default; + ModelRunnerHandle& operator=(ModelRunnerHandle&&) = default; + ModelRunnerHandle(const ModelRunnerHandle&) = delete; + ModelRunnerHandle& operator=(const ModelRunnerHandle&) = delete; + ~ModelRunnerHandle(); + + c10::IValue run( + const std::vector& args, + const std::unordered_map& kwargs); + + /** + * A low level API which expects user to always pass in flattened inputs. + * The ownership of the entire input list must be transferred to the + * executor via std::move or in-place construction. + */ + std::vector runWithFlatInputsAndOutputs( + std::vector flatInputs); + + private: + std::unique_ptr impl_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/ModelRunner.cpp b/torch/nativert/ModelRunner.cpp index f1c2a35db14cb..83cb0e00bd728 100644 --- a/torch/nativert/ModelRunner.cpp +++ b/torch/nativert/ModelRunner.cpp @@ -136,4 +136,21 @@ std::vector ModelRunner::runWithFlatInputsAndOutputs( return executor_->execute(std::move(flatInputs)); } +ModelRunnerHandle::ModelRunnerHandle( + const std::string& packagePath, + const std::string& modelName) + : impl_(std::make_unique(packagePath, modelName)) {} +ModelRunnerHandle::~ModelRunnerHandle() = default; + +c10::IValue ModelRunnerHandle::run( + const std::vector& args, + const std::unordered_map& kwargs) { + return impl_->run(args, kwargs); +} + +std::vector ModelRunnerHandle::runWithFlatInputsAndOutputs( + std::vector flatInputs) { + return impl_->runWithFlatInputsAndOutputs(std::move(flatInputs)); +} + } // namespace torch::nativert diff --git a/torch/nativert/ModelRunner.h b/torch/nativert/ModelRunner.h index 4c88757318850..e037e3b26ca89 100644 --- a/torch/nativert/ModelRunner.h +++ b/torch/nativert/ModelRunner.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include From d20c4c20e61adecf00335c4d8c22eb1ace472cd3 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Thu, 7 Aug 2025 15:18:48 +0000 Subject: [PATCH 0097/1424] [CI] Update xpu ci use rolling driver for new features (#158340) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/158340 Approved by: https://github.com/seemethere Co-authored-by: xinan.lin --- .ci/docker/common/install_xpu.sh | 41 +++++++++++-------- test/inductor/test_compile_subprocess.py | 3 -- test/inductor/test_max_autotune.py | 6 +++ test/inductor/test_torchinductor.py | 4 ++ test/inductor/test_torchinductor_opinfo.py | 11 +++++ .../test_torchinductor_strided_blocks.py | 3 ++ 6 files changed, 49 insertions(+), 19 deletions(-) diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index ecbbb8ccccf89..7f21d2e42c723 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -34,18 +34,27 @@ function install_ubuntu() { # The xpu-smi packages apt-get install -y flex bison xpu-smi - # Compute and Media Runtimes - apt-get install -y \ - intel-opencl-icd intel-level-zero-gpu level-zero \ - intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \ - libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ - libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ - mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo - if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then - apt-get install -y intel-ocloc + + if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then + # Compute and Media Runtimes + apt-get install -y \ + intel-opencl-icd intel-level-zero-gpu level-zero \ + intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \ + libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ + libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ + mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo + # Development Packages + apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev + else # rolling driver + apt-get install -y \ + intel-opencl-icd libze-intel-gpu1 libze1 \ + intel-media-va-driver-non-free libmfx-gen1 libvpl2 \ + libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ + libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ + mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc + apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev fi - # Development Packages - apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev + # Install Intel Support Packages apt-get install -y ${XPU_PACKAGES} @@ -130,11 +139,11 @@ function install_sles() { } -# Default use GPU driver LTS releases -XPU_DRIVER_VERSION="/lts/2350" -if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then - # Use GPU driver rolling releases - XPU_DRIVER_VERSION="" +# Default use GPU driver rolling releases +XPU_DRIVER_VERSION="" +if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then + # Use GPU driver LTS releases + XPU_DRIVER_VERSION="/lts/2350" fi # Default use Intel® oneAPI Deep Learning Essentials 2025.0 diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index 04297c38bf299..51aa7b70b9c40 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -62,9 +62,6 @@ "test_remove_noop_slice_scatter": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_default": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_dtype": TestFailure(("xpu"), is_skip=True), - # TODO:remove test_upsample_bicubic2d after the following issue resolved: - # https://github.com/intel/intel-xpu-backend-for-triton/issues/4184 - "test_upsample_bicubic2d": TestFailure(("xpu"), is_skip=False), } diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 1163ec408148b..8917c7a6ed360 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -2155,6 +2155,9 @@ def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): "del", num_deallocs, exactly=True ).run(code_str) + @skipIfXpu( + msg="Triton issue exposed by new driver, will be resolved after next triton update." + ) @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_upcast(self, sizes): M, K, N = sizes @@ -2319,6 +2322,9 @@ def test_multiple_fusions(x): ).run(code[0]) self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05) + @skipIfXpu( + msg="Triton issue exposed by new driver, will be resolved after next triton update." + ) @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_multiple_inputs(self, sizes): M, K, N = sizes diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 1a73c6ef13032..3b71fe464667b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -9837,6 +9837,7 @@ def fn(x): ], ) + @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin2(self): def fn(x): return ( @@ -9848,6 +9849,7 @@ def fn(x): self.common(fn, (torch.randn([144, 144]),)) + @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin_with_duplicates(self): def fn(x): return ( @@ -9869,6 +9871,7 @@ def fn(x): t1 = torch.randint(8, size=(1028, 1028)) self.common(fn, (t1,)) + @skipIfXpu(msg="# Incorrect XPU reference ") @xfail_if_mps # eager nan is wrong, see https://github.com/pytorch/pytorch/issues/130295 @skip_if_halide # nan behavior def test_argmax_argmin_with_nan(self): @@ -9969,6 +9972,7 @@ def shrink_rank(x, rank): [rank4_inps, rank3_inps, rank5_inps], ) + @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin3(self): def fn(x): return ( diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 242f774a0c880..2a0e4c63fb682 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -293,6 +293,17 @@ def format_op(op): # a deconvolution forward propagation primitive "nn.functional.conv_transpose2d": {f32, f64}, "nn.functional.conv_transpose3d": {f32, f64}, + # [Begin] Incorrect XPU reference due to new driver. + "masked.prod": {b8, i32, i64}, + "masked.amin": {i64}, + "masked.amax": {i64}, + "amax": {i64}, + "amin": {i64}, + "std": {f64}, + "var": {f64}, + "std_mean": {f64}, + "var_mean": {f64}, + # [End] } diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 82bfdd6290bba..67d197f0750d0 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -1188,6 +1188,9 @@ def foo(x, y, z): # } # This is now fixed by ensuring that that wild symbols only match integers @xfail_if_use_tensor_descriptor + @skipIfXpu( + msg="Triton issue exposed by new driver, will be resolved after next triton update." + ) def test_ensure_integral_dims_and_strides(self): def model(data, *args): return torch.nn.functional.unfold(data, *args) From 8ab5868a2199fe485c2d66533b9244ccb97e487d Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 4 Aug 2025 10:12:15 -0700 Subject: [PATCH 0098/1424] Actually run the einops tests in CI (#159776) The test filter was wrong, it should not start with "test/". Test Plan: - wait for CI - Tested locally with `python test/run_test.py --einops --verbose` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159776 Approved by: https://github.com/atalman, https://github.com/StrongerXi --- test/run_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/run_test.py b/test/run_test.py index 4c49acfdee9c0..5e9548d4eab11 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1555,7 +1555,7 @@ def get_selected_tests(options) -> list[str]: if options.einops: selected_tests = list( filter( - lambda test_name: test_name.startswith("test/dynamo/test_einops"), + lambda test_name: test_name.startswith("dynamo/test_einops"), selected_tests, ) ) From f60454cce8b93e5bbf67f2f3c88c8ac01ed65457 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov Date: Thu, 7 Aug 2025 15:58:30 +0000 Subject: [PATCH 0099/1424] S390X: update test dependencies (#158636) numba currently doesn't build from source due to https://github.com/numba/numba/pull/10073 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158636 Approved by: https://github.com/malfet --- .ci/docker/requirements-ci.txt | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 4de9431bf300f..d4bdd9b2a9cbf 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -63,11 +63,12 @@ lark==0.12.0 #Pinned versions: 0.12.0 #test that import: -librosa>=0.6.2 ; python_version < "3.11" -librosa==0.10.2 ; python_version == "3.12" +librosa>=0.6.2 ; python_version < "3.11" and platform_machine != "s390x" +librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x" #Description: A python package for music and audio analysis #Pinned versions: >=0.6.2 #test that import: test_spectral_ops.py +#librosa depends on numba; disable it for s390x while numba is disabled too #mkl #this breaks linux-bionic-rocm4.5-py3.7 #Description: Intel oneAPI Math Kernel Library @@ -110,14 +111,15 @@ ninja==1.11.1.3 #Pinned versions: 1.11.1.3 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py -numba==0.49.0 ; python_version < "3.9" -numba==0.55.2 ; python_version == "3.9" -numba==0.55.2 ; python_version == "3.10" -numba==0.60.0 ; python_version == "3.12" +numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x" +numba==0.55.2 ; python_version == "3.9" and platform_machine != "s390x" +numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x" +numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x" #Description: Just-In-Time Compiler for Numerical Functions #Pinned versions: 0.54.1, 0.49.0, <=0.49.1 #test that import: test_numba_integration.py #For numba issue see https://github.com/pytorch/pytorch/issues/51511 +#Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073 #numpy #Description: Provides N-dimensional arrays and linear algebra @@ -307,7 +309,7 @@ pytest-cpp==2.3.0 #Pinned versions: 2.3.0 #test that import: -z3-solver==4.15.1.0 +z3-solver==4.15.1.0 ; platform_machine != "s390x" #Description: The Z3 Theorem Prover Project #Pinned versions: #test that import: From e248719ac03c103767ab72034f6b9fd56855bf98 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Wed, 6 Aug 2025 17:54:21 -0700 Subject: [PATCH 0100/1424] [DTensor] support _StridedShard in view op (#159656) **Summary** Some thoughts on view-op and `_StridedShard` interaction: 1. `_StridedShard` has no impact on sharding (i.e. how tensor is partitioned) compared to `Shard`. It only changes how shards permute across the devices. 2. `view()` op on DTensor strictly forbids shard redistribution which means if `view()` may cause shard permutation across devices, it should be rejected. This is enforced in today's sharding prop for `view()`. 3. Since DTensor `view()` won't introduce any redistribution, it's certain that `placements` won't change except the inner `dim` attribute of `Shard` or `_StridedShard`. Therefore, to support `_StridedShard` in `view()` op, the only change required is to keep `_StridedShard` as `_StridedShard` in the output spec. **Test** `pytest test/distributed/tensor/test_view_ops.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159656 Approved by: https://github.com/wconstab --- test/distributed/tensor/test_view_ops.py | 39 ++++++++++++++++++---- torch/distributed/tensor/_ops/_view_ops.py | 31 +++++++++++++++-- 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/test/distributed/tensor/test_view_ops.py b/test/distributed/tensor/test_view_ops.py index 92de79bc188b8..39f5b98d4eabc 100644 --- a/test/distributed/tensor/test_view_ops.py +++ b/test/distributed/tensor/test_view_ops.py @@ -10,6 +10,7 @@ from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, + DTensor, init_device_mesh, Replicate, Shard, @@ -25,7 +26,7 @@ view_groups, ) from torch.distributed.tensor.debug import CommDebugMode -from torch.distributed.tensor.placement_types import Placement +from torch.distributed.tensor.placement_types import _StridedShard, Placement from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -168,8 +169,34 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): *(device_mesh.ndim * [sharding_choices]) ) - for in_shard in all_sharding_choices: - in_dt = distribute_tensor(args[0], device_mesh, in_shard) + outer_mesh = device_mesh["outer"] + inner_mesh = device_mesh["inner"] + inner_mesh_size = inner_mesh.size() + strided_sharding_choices = [ + (_StridedShard(i, split_factor=inner_mesh_size), Shard(i)) + for i, s in enumerate(in_shape) + if s > 1 and i not in no_shard_dims + ] + + for in_shard in itertools.chain(all_sharding_choices, strided_sharding_choices): + if isinstance(in_shard[0], _StridedShard): + if op != Tensor.view: + continue + # cannot produce DTensor using ``distribute_tensor()`` + # with ``_StridedShard``. Need to distribute the input + # over inner mesh dim first, then distribute the + # _local_tensor over the outer mesh dim. + in_dt = distribute_tensor(args[0], inner_mesh, (in_shard[1],)) + in_dt = distribute_tensor( + in_dt._local_tensor, outer_mesh, (Shard(in_shard[0].dim),) + ) + in_dt = DTensor.from_local( + in_dt._local_tensor, + device_mesh, + in_shard, + ) + else: + in_dt = distribute_tensor(args[0], device_mesh, in_shard) comm_mode = CommDebugMode() with comm_mode: @@ -216,8 +243,9 @@ def test_illegal_views(self): @with_comms def test_view_ops(self): - self.device_mesh = DeviceMesh( - self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) + mesh_shape = (dist.get_world_size() // 2, 2) + self.device_mesh = init_device_mesh( + self.device_type, mesh_shape=mesh_shape, mesh_dim_names=("outer", "inner") ) self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),)) self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),)) @@ -442,7 +470,6 @@ def test_view_ops(self): (randn(42, 24, 36), 1), (InputDim(0), Singleton(), InputDim(1), InputDim(2)), ) - self.dimmap_test( Tensor.view, (randn(6, 12, 24), 72, 24), diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index c942da67cd8a1..1f0906b0beff0 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -22,7 +22,12 @@ prod, register_op_strategy, ) -from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Placement, + Replicate, + Shard, +) aten = torch.ops.aten @@ -605,8 +610,30 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: ) for mesh_dim, p in enumerate(input_src_placements) ] + + def _rewrite_shard_dim(p: Shard): + """ + Rewrite the shard dim to the corresponding tensor dim in output. + For ``_StridedShard``, we can safely keep the placement type and + ``split_factor`` unchanged and only rewrite the ``dim`` because: + 1. ``_StridedShard`` has no impact on sharding (i.e. how + tensor is partitioned) compared to ``Shard``. It only changes + how shards permute across the devices. + 2. ``view()`` op on DTensor strictly forbids shard redistribution + which means if ``view()`` may cause shard permutation across + devices, it should be rejected. This is enforced in today's + sharding prop for ``view()``. + 3. Since DTensor ``view()`` won't introduce any redistribution, + it's certain that ``placements`` won't change except the + inner ``dim`` attribute of ``Shard`` or ``_StridedShard``. + """ + if isinstance(p, _StridedShard): + return _StridedShard(shard_dim_map[p.dim], split_factor=p.split_factor) + else: + return Shard(shard_dim_map[p.dim]) + output_placements = [ - Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + _rewrite_shard_dim(p) if isinstance(p, Shard) else p for p in input_tgt_placements ] From 90b78ee50f73b5c963996076a3d54b74b1b965be Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Thu, 7 Aug 2025 16:22:52 +0000 Subject: [PATCH 0101/1424] Move xla jobs to unstable workflow (#159272) Disables the job on PRs completely, so that we don't litter people's CI signals and use machines unnecessarily. If you want to run these xla tests, add the ciflow/unstable label to your PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/159272 Approved by: https://github.com/atalman, https://github.com/malfet --- .github/workflows/pull.yml | 24 ------------------------ .github/workflows/unstable.yml | 28 +++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 8c297b1136889..cc2c4e89664ba 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -304,30 +304,6 @@ jobs: ]} secrets: inherit - linux-jammy-py3_9-clang9-xla-build: - name: linux-jammy-py3_9-clang9-xla - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.9-clang9-xla - docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite - test-matrix: | - { include: [ - { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - ]} - secrets: inherit - - linux-jammy-py3_9-clang9-xla-test: - name: linux-jammy-py3_9-clang9-xla - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-py3_9-clang9-xla-build - with: - build-environment: linux-jammy-py3.9-clang9-xla - docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-cpu-py3_10-gcc11-bazel-test: name: linux-jammy-cpu-py3.10-gcc11-bazel-test uses: ./.github/workflows/_bazel-build-test.yml diff --git a/.github/workflows/unstable.yml b/.github/workflows/unstable.yml index 08ae920e7cb0d..7f0fe6058bd08 100644 --- a/.github/workflows/unstable.yml +++ b/.github/workflows/unstable.yml @@ -12,7 +12,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: # There must be at least one job here to satisfy GitHub action workflow syntax @@ -51,3 +53,27 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} + + linux-jammy-py3_9-clang9-xla-build: + name: linux-jammy-py3_9-clang9-xla + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-clang9-xla + docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite + test-matrix: | + { include: [ + { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + ]} + secrets: inherit + + linux-jammy-py3_9-clang9-xla-test: + name: linux-jammy-py3_9-clang9-xla + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-py3_9-clang9-xla-build + with: + build-environment: linux-jammy-py3.9-clang9-xla + docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }} + secrets: inherit From c4e64467b5a30d12fefcb8e1de5a8963cb01306d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 7 Aug 2025 16:34:36 +0000 Subject: [PATCH 0102/1424] Revert "Add UT for torch.accelerator memory-related API (#155200)" This reverts commit 4604f0482c2b4a3001b62e5bc5085149a9bb053c. Reverted https://github.com/pytorch/pytorch/pull/155200 on behalf of https://github.com/jithunnair-amd due to Broke ROCm periodic runs on MI300 e.g. https://github.com/pytorch/pytorch/actions/runs/16764977800/job/47470050573 ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3164941815)) --- test/test_accelerator.py | 78 ---------------------------------------- test/test_cuda.py | 36 ------------------- test/test_xpu.py | 37 ------------------- 3 files changed, 151 deletions(-) diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 21731bd275b60..0ea224d704cb8 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -1,6 +1,5 @@ # Owner(s): ["module: tests"] -import gc import sys import unittest @@ -157,83 +156,6 @@ def test_generic_event_behavior(self): ): event1.elapsed_time(event2) - @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!") - def test_memory_stats(self): - # Ensure that device allocator is initialized - acc = torch.accelerator.current_accelerator() - tmp = torch.randn(100, device=acc) - del tmp - gc.collect() - self.assertTrue(torch._C._accelerator_isAllocatorInitialized()) - torch.accelerator.empty_cache() - - pool_type = ["all", "small_pool", "large_pool"] - metric_type = ["peak", "current", "allocated", "freed"] - stats_type = [ - "allocated_bytes", - "reserved_bytes", - "active_bytes", - "requested_bytes", - ] - mem_stats = torch.accelerator.memory_stats() - expected_stats = [ - f"{st}.{pt}.{mt}" - for st in stats_type - for pt in pool_type - for mt in metric_type - ] - missing_stats = [stat for stat in expected_stats if stat not in mem_stats] - self.assertEqual( - len(missing_stats), - 0, - f"Missing expected memory statistics: {missing_stats}", - ) - - prev_allocated = torch.accelerator.memory_allocated() - prev_reserved = torch.accelerator.memory_reserved() - prev_max_allocated = torch.accelerator.max_memory_allocated() - prev_max_reserved = torch.accelerator.max_memory_reserved() - self.assertGreaterEqual(prev_allocated, 0) - self.assertGreaterEqual(prev_reserved, 0) - self.assertGreater(prev_max_allocated, 0) - self.assertGreater(prev_max_reserved, 0) - tmp = torch.ones(256, device=acc) - self.assertGreater(torch.accelerator.memory_allocated(), prev_allocated) - self.assertGreaterEqual(torch.accelerator.memory_reserved(), prev_reserved) - del tmp - gc.collect() - torch.accelerator.empty_cache() - torch.accelerator.reset_peak_memory_stats() - self.assertEqual(torch.accelerator.memory_allocated(), prev_allocated) - self.assertEqual(torch.accelerator.memory_reserved(), prev_reserved) - torch.accelerator.reset_accumulated_memory_stats() - prev_max_allocated = torch.accelerator.max_memory_allocated() - prev_max_reserved = torch.accelerator.max_memory_reserved() - # Activate 1kB memory - prev_active_current = torch.accelerator.memory_stats()[ - "active_bytes.all.current" - ] - tmp = torch.randn(256, device=acc) - # Detect if the current active memory is 1kB - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.current"], - 1024 + prev_active_current, - ) - self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) - del tmp - gc.collect() - torch.accelerator.empty_cache() - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.current"], - prev_active_current, - ) - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 - ) - torch.accelerator.reset_peak_memory_stats() - self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) - self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) - if __name__ == "__main__": run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index 9755835853eed..f2f3304069f1b 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -373,42 +373,6 @@ def test_memory_allocation(self): torch.cuda.caching_allocator_delete(mem) self.assertEqual(torch.cuda.memory_allocated(), prev) - def test_memory_stats(self): - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - torch.cuda.reset_accumulated_memory_stats() - prev_allocated = torch.accelerator.memory_allocated() - prev_reserved = torch.accelerator.memory_reserved() - prev_max_allocated = torch.accelerator.max_memory_allocated() - prev_max_reserved = torch.accelerator.max_memory_reserved() - self.assertEqual(prev_allocated, prev_max_allocated) - self.assertEqual(prev_reserved, prev_max_reserved) - # Activate 1kB memory - prev_active_current = torch.accelerator.memory_stats()[ - "active_bytes.all.current" - ] - tmp = torch.randn(256, device="cuda") - # Detect if the current active memory is 1kB - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.current"], - 1024 + prev_active_current, - ) - self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) - del tmp - gc.collect() - torch.accelerator.empty_cache() - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.current"], - prev_active_current, - ) - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 - ) - torch.accelerator.reset_peak_memory_stats() - self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) - self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) - def test_check_error(self): # Assert this call doesn't raise. torch.cuda.check_error(0) diff --git a/test/test_xpu.py b/test/test_xpu.py index beb5a53a4a6b3..cd5275418c440 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,6 +1,5 @@ # Owner(s): ["module: intel"] -import gc import re import subprocess import sys @@ -521,42 +520,6 @@ def test_device_memory_allocated(self): ) del a - def test_memory_stats(self): - gc.collect() - torch.xpu.empty_cache() - torch.xpu.reset_peak_memory_stats() - torch.xpu.reset_accumulated_memory_stats() - prev_allocated = torch.accelerator.memory_allocated() - prev_reserved = torch.accelerator.memory_reserved() - prev_max_allocated = torch.accelerator.max_memory_allocated() - prev_max_reserved = torch.accelerator.max_memory_reserved() - self.assertEqual(prev_allocated, prev_max_allocated) - self.assertEqual(prev_reserved, prev_max_reserved) - # Activate 1kB memory - prev_active_current = torch.accelerator.memory_stats()[ - "active_bytes.all.current" - ] - tmp = torch.randn(256, device="xpu") - # Detect if the current active memory is 1kB - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.current"], - 1024 + prev_active_current, - ) - self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) - del tmp - gc.collect() - torch.accelerator.empty_cache() - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.current"], - prev_active_current, - ) - self.assertEqual( - torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 - ) - torch.accelerator.reset_peak_memory_stats() - self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) - self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) - @skipXPUIf( int(torch.version.xpu) < 20250000, "Test requires SYCL compiler version 2025.0.0 or newer.", From 74da2604c9da37bf3701205c051e67e48a3d17bd Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 7 Aug 2025 16:34:36 +0000 Subject: [PATCH 0103/1424] Revert "Add unified memory APIs for torch.accelerator (#152932)" This reverts commit 15f1173e5d72d6d45faba4cecd135e0160f06c6f. Reverted https://github.com/pytorch/pytorch/pull/152932 on behalf of https://github.com/jithunnair-amd due to Broke ROCm periodic runs on MI300 e.g. https://github.com/pytorch/pytorch/actions/runs/16764977800/job/47470050573 ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3164941815)) --- aten/src/ATen/DeviceAccelerator.h | 22 ---- docs/source/accelerator.md | 23 ---- torch/_C/__init__.pyi.in | 5 - torch/accelerator/__init__.py | 18 --- torch/accelerator/memory.py | 201 ------------------------------ torch/csrc/DeviceAccelerator.cpp | 64 ---------- torch/cuda/memory.py | 4 +- 7 files changed, 2 insertions(+), 335 deletions(-) delete mode 100644 torch/accelerator/memory.py diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f23b35047fcc8..f37e492c861fe 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -73,27 +72,6 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); -TORCH_API inline void emptyCache() { - const auto device_type = getAccelerator(true).value(); - at::getDeviceAllocator(device_type)->emptyCache(); -} - -TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); -} - -TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); -} - -TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - at::getDeviceAllocator(device_type)->resetPeakStats(device_index); -} - } // namespace at::accelerator namespace at { diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index ce593a9acf518..c6f2fb1080400 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,26 +25,3 @@ synchronize device_index ``` - -```{eval-rst} -.. automodule:: torch.accelerator.memory -``` -```{eval-rst} -.. currentmodule:: torch.accelerator.memory -``` - -## Memory management -```{eval-rst} -.. autosummary:: - :toctree: generated - :nosignatures: - - empty_cache - max_memory_allocated - max_memory_reserved - memory_allocated - memory_reserved - memory_stats - reset_accumulated_memory_stats - reset_peak_memory_stats -``` diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index fb7e9c5ce56e0..9e03c7dba8305 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2435,11 +2435,6 @@ def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... -def _accelerator_isAllocatorInitialized() -> _bool: ... -def _accelerator_emptyCache() -> None: ... -def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... -def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... -def _accelerator_resetPeakStats(device_index: _int) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index 4d1a78df1f74c..e9e48f1cf3061 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,16 +8,6 @@ import torch from ._utils import _device_t, _get_device_index -from .memory import ( - empty_cache, - max_memory_allocated, - max_memory_reserved, - memory_allocated, - memory_reserved, - memory_stats, - reset_accumulated_memory_stats, - reset_peak_memory_stats, -) __all__ = [ @@ -25,17 +15,9 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", - "empty_cache", "device_count", "device_index", "is_available", - "max_memory_allocated", - "max_memory_reserved", - "memory_allocated", - "memory_reserved", - "memory_stats", - "reset_accumulated_memory_stats", - "reset_peak_memory_stats", "set_device_idx", # deprecated "set_device_index", "set_stream", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py deleted file mode 100644 index d34a11a3a02e5..0000000000000 --- a/torch/accelerator/memory.py +++ /dev/null @@ -1,201 +0,0 @@ -from collections import OrderedDict -from typing import Any - -import torch - -from ._utils import _device_t, _get_device_index - - -__all__ = [ - "empty_cache", - "max_memory_allocated", - "max_memory_reserved", - "memory_allocated", - "memory_reserved", - "memory_stats", - "reset_accumulated_memory_stats", - "reset_peak_memory_stats", -] - - -def empty_cache() -> None: - r"""Release all unoccupied cached memory currently held by the caching - allocator so that those can be used in other application. - - .. note:: This function is a no-op if the memory allocator for the current - :ref:`accelerator ` has not been initialized. - """ - if not torch._C._accelerator_isAllocatorInitialized(): - return - torch._C._accelerator_emptyCache() - - -def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: - r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. - - The return value of this function is a dictionary of statistics, each of - which is a non-negative integer. - - Core statistics: - - - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of allocation requests received by the memory allocator. - - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of allocated memory. - - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of reserved segments from device memory allocation. - - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of reserved memory. - - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of active memory blocks. - - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of active memory. - - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of inactive, non-releasable memory blocks. - - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of inactive, non-releasable memory. - - For these core statistics, values are broken down as follows. - - Pool type: - - - ``all``: combined statistics across all memory pools. - - ``large_pool``: statistics for the large allocation pool - (as of June 2025, for size >= 1MB allocations). - - ``small_pool``: statistics for the small allocation pool - (as of June 2025, for size < 1MB allocations). - - Metric type: - - - ``current``: current value of this metric. - - ``peak``: maximum value of this metric. - - ``allocated``: historical total increase in this metric. - - ``freed``: historical total decrease in this metric. - - In addition to the core statistics, we also provide some simple event - counters: - - - ``"num_alloc_retries"``: number of failed device memory allocation calls that - result in a cache flush and retry. - - ``"num_ooms"``: number of out-of-memory errors thrown. - - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. - - ``"num_device_alloc"``: number of device memory allocation calls. - - ``"num_device_free"``: number of device memory free calls. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - if not torch._C._accelerator_isAllocatorInitialized(): - return OrderedDict() - device_index = _get_device_index(device_index, optional=True) - stats = torch._C._accelerator_getDeviceStats(device_index) - flat_stats = [] - - def flatten(prefix: str, value: Any) -> None: - if isinstance(value, dict): - for k, v in value.items(): - nested_prefix = f"{prefix}.{k}" if prefix else k - flatten(nested_prefix, v) - else: - flat_stats.append((prefix, value)) - - flatten("", stats) - flat_stats.sort() - return OrderedDict(flat_stats) - - -def memory_allocated(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` device memory occupied by tensors - in bytes for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("allocated_bytes.all.current", 0) - - -def max_memory_allocated(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` maximum device memory occupied by tensors - in bytes for a given device index. - - By default, this returns the peak allocated memory since the beginning of - this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to - reset the starting point in tracking this metric. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("allocated_bytes.all.peak", 0) - - -def memory_reserved(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` device memory managed by the caching allocator - in bytes for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("reserved_bytes.all.current", 0) - - -def max_memory_reserved(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` maximum device memory managed by the caching allocator - in bytes for a given device index. - - By default, this returns the peak cached memory since the beginning of this - program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset - the starting point in tracking this metric. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("reserved_bytes.all.peak", 0) - - -def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None: - r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator` - memory allocator for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - - .. note:: This function is a no-op if the memory allocator for the current - :ref:`accelerator ` has not been initialized. - """ - device_index = _get_device_index(device_index, optional=True) - return torch._C._accelerator_resetAccumulatedStats(device_index) - - -def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: - r"""Reset the "peak" stats tracked by the current :ref:`accelerator` - memory allocator for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - - .. note:: This function is a no-op if the memory allocator for the current - :ref:`accelerator ` has not been initialized. - """ - device_index = _get_device_index(device_index, optional=True) - return torch._C._accelerator_resetPeakStats(device_index) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 59cb8047467c9..3a97c0794684f 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -77,70 +77,6 @@ void initModule(PyObject* module) { m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); - - m.def("_accelerator_isAllocatorInitialized", []() { - const auto device_type = at::accelerator::getAccelerator(true).value(); - return at::getDeviceAllocator(device_type)->initialized(); - }); - - m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); - - m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { - using c10::CachingAllocator::Stat; - using c10::CachingAllocator::StatArray; - using c10::CachingAllocator::StatType; - using c10::CachingDeviceAllocator::DeviceStats; - - const auto stats = at::accelerator::getDeviceStats(device_index); - const auto stat_to_dict = [](const Stat& stat) -> py::dict { - py::dict dict; - dict["current"] = stat.current; - dict["peak"] = stat.peak; - dict["allocated"] = stat.allocated; - dict["freed"] = stat.freed; - return dict; - }; - - const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict { - const std::array(StatType::NUM_TYPES)> - kStatTypeNames = {"all", "small_pool", "large_pool"}; - py::dict dict; - for (const auto i : c10::irange(kStatTypeNames.size())) { - dict[kStatTypeNames[i]] = stat_to_dict(stats[i]); - } - return dict; - }; - - py::dict result; - result["num_alloc_retries"] = stats.num_alloc_retries; - result["num_ooms"] = stats.num_ooms; - result["max_split_size"] = stats.max_split_size; - result["num_sync_all_streams"] = stats.num_sync_all_streams; - result["num_device_alloc"] = stats.num_device_alloc; - result["num_device_free"] = stats.num_device_free; - result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes); - result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes); - result["active_bytes"] = stat_array_to_dict(stats.active_bytes); - result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes); - result["allocation"] = stat_array_to_dict(stats.allocation); - result["segment"] = stat_array_to_dict(stats.segment); - result["active"] = stat_array_to_dict(stats.active); - result["inactive_split"] = stat_array_to_dict(stats.inactive_split); - result["inactive_split_bytes"] = - stat_array_to_dict(stats.inactive_split_bytes); - result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations); - result["oversize_segments"] = stat_to_dict(stats.oversize_segments); - return result; - }); - - m.def( - "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) { - at::accelerator::resetAccumulatedStats(device_index); - }); - - m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { - at::accelerator::resetPeakStats(device_index); - }); } } // namespace torch::accelerator diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 1bd6f9edc0319..63e59096162fb 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,9 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool - (as of June 2025, for size >= 1MB allocations). + (as of October 2019, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool - (as of June 2025, for size < 1MB allocations). + (as of October 2019, for size < 1MB allocations). Metric type: From f3a4d742ece08de4cb0e59dcc62e0093a7d0b0c7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 7 Aug 2025 16:34:36 +0000 Subject: [PATCH 0104/1424] Revert "Add DeviceAllocator as the base device allocator (#138222)" This reverts commit f7a66da5f9f6b8b75119b1ee8ce9ddc23e15570e. Reverted https://github.com/pytorch/pytorch/pull/138222 on behalf of https://github.com/jithunnair-amd due to Broke ROCm periodic runs on MI300 e.g. https://github.com/pytorch/pytorch/actions/runs/16764977800/job/47470050573 ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3164941815)) --- aten/src/ATen/cuda/CUDAGraph.cpp | 1 + aten/src/ATen/cuda/CUDAGraph.h | 1 - c10/core/CachingDeviceAllocator.cpp | 10 ------ c10/core/CachingDeviceAllocator.h | 53 ----------------------------- c10/cuda/CUDACachingAllocator.cpp | 11 ------ c10/cuda/CUDACachingAllocator.h | 19 +++++------ c10/cuda/CUDAGraphsC10Utils.h | 6 ++++ c10/xpu/XPUCachingAllocator.cpp | 19 ++++------- 8 files changed, 22 insertions(+), 98 deletions(-) delete mode 100644 c10/core/CachingDeviceAllocator.cpp diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 2800e505a9b76..7fba7c4c7424c 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index 4f2aa31dd1c35..c8cae16b624fe 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -2,7 +2,6 @@ #include #include -#include #include #include #include diff --git a/c10/core/CachingDeviceAllocator.cpp b/c10/core/CachingDeviceAllocator.cpp deleted file mode 100644 index 582efd59cf1b1..0000000000000 --- a/c10/core/CachingDeviceAllocator.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include - -namespace c10 { - -// Ensures proper DLL export of this pure virtual base class on Windows, -// since it's mainly used in other DLLs outside c10.dll. -DeviceAllocator::DeviceAllocator() = default; -DeviceAllocator::~DeviceAllocator() = default; - -} // namespace c10 diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index 0bec03ae417fa..b23490de693a8 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -1,7 +1,6 @@ #pragma once #include -#include namespace c10::CachingDeviceAllocator { @@ -60,55 +59,3 @@ struct DeviceStats { }; } // namespace c10::CachingDeviceAllocator - -namespace c10 { - -using CaptureId_t = unsigned long long; - -// first is set if the instance is created by Graph mode capture_begin. -// second is set if the instance is created by Graph mode graph_pool_handle. -using MempoolId_t = std::pair; - -struct C10_API DeviceAllocator : public c10::Allocator { - DeviceAllocator(); - ~DeviceAllocator() override; - - // Returns true if the allocator has been properly initialized and is ready - // for use - virtual bool initialized() = 0; - - // Releases all cached device memory from the specified memory pool back to - // the system - virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; - - // Associates a memory allocation with a stream to establish dependency - // tracking. Prevents memory reuse until all operations on the specified - // stream complete - virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0; - - // Retrieves comprehensive memory statistics for the specified device, - // including allocation patterns, usage metrics - virtual CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) = 0; - - // Resets cumulative allocation statistics for the specified device to zero - virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; - - // Resets peak memory usage statistics for the specified device - virtual void resetPeakStats(c10::DeviceIndex device) = 0; -}; - -// This function is used to get the DeviceAllocator for a specific device type -// and keep backward compatibility with c10::GetAllocator. -C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { - TORCH_CHECK( - t != DeviceType::CPU, - "getDeviceAllocator is not supported for CPU device type."); - auto* allocator = c10::GetAllocator(t); - auto* device_allocator = dynamic_cast(allocator); - TORCH_INTERNAL_ASSERT( - device_allocator, "Allocator for ", t, " is not a DeviceAllocator."); - return device_allocator; -} - -} // namespace c10 diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 59b62dcac07f0..c2a46ac9f3f74 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -4118,18 +4118,7 @@ struct BackendStaticInitializer { BackendStaticInitializer() { auto r = parseEnvForBackend(); -// Register this HIP allocator as the CUDA allocator to allow it to work -// with both c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) -// APIs. We don't perform this masquerading inside -// HIPAllocatorMasqueradingAsCUDA because it needs to happen during static -// initialization, and doing so there may introduce static initialization -// order (SIOF) issues. -#define HIP_MASQUERADING_AS_CUDA \ - "cud" \ - "a" - at::SetAllocator(c10::Device(HIP_MASQUERADING_AS_CUDA).type(), r, 0); allocator.store(r); -#undef HIP_MASQUERADING_AS_CUDA } }; diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 75a2d4c8e481b..956411fe22827 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -202,24 +202,25 @@ struct ShareableHandle { std::string handle; }; -class CUDAAllocator : public DeviceAllocator { +class CUDAAllocator : public Allocator { public: virtual void* raw_alloc(size_t nbytes) = 0; virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; + virtual bool initialized() = 0; virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; virtual void enable(bool value) = 0; virtual bool isEnabled() const = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; - // Keep for BC only - virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; - void recordStream(const DataPtr& ptr, c10::Stream stream) override { - CUDAStream cuda_stream = CUDAStream(stream); - recordStream(ptr, cuda_stream); - } + virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; + virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + virtual void resetPeakStats(c10::DeviceIndex device) = 0; virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; virtual void beginAllocateToPool( c10::DeviceIndex device, @@ -524,10 +525,6 @@ inline void enablePeerAccess( namespace c10::cuda { -// Keep BC only -using c10::CaptureId_t; -using c10::MempoolId_t; - // MemPool represents a pool of memory in a caching allocator. Currently, // it's just the ID of the pool object maintained in the CUDACachingAllocator. // diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index 936875fd71d5c..eb29ca8bc9f02 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -9,6 +9,12 @@ namespace c10::cuda { +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by CUDAGraph::capture_begin. +// second is set if the instance is created by at::cuda::graph_pool_handle. +using MempoolId_t = std::pair; + // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 04ab3cabcbc2b..afae32d92a4b4 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -539,7 +539,7 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); -class XPUAllocator : public DeviceAllocator { +class XPUAllocator : public Allocator { private: std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -575,10 +575,6 @@ class XPUAllocator : public DeviceAllocator { } } - bool initialized() override { - return !device_allocators.empty(); - } - void malloc( void** devPtr, DeviceIndex device, @@ -613,13 +609,13 @@ class XPUAllocator : public DeviceAllocator { } } - void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { + void emptyCache() { for (auto& da : device_allocators) { da->emptyCache(); } } - void recordStream(const DataPtr& ptr, c10::Stream stream) override { + void recordStream(const DataPtr& ptr, XPUStream stream) { if (!ptr.get()) { return; } @@ -629,8 +625,7 @@ class XPUAllocator : public DeviceAllocator { Block* block = get_allocated_block(ptr.get()); TORCH_CHECK(block, "No allocated block can be found."); - c10::xpu::XPUStream xpu_stream{stream}; - device_allocators[block->device]->recordStream(block, xpu_stream); + device_allocators[block->device]->recordStream(block, stream); } DataPtr allocate(size_t size) override { @@ -683,17 +678,17 @@ class XPUAllocator : public DeviceAllocator { ": did you call init?"); } - DeviceStats getDeviceStats(DeviceIndex device) override { + DeviceStats getDeviceStats(DeviceIndex device) { assertValidDevice(device); return device_allocators[device]->getStats(); } - void resetPeakStats(DeviceIndex device) override { + void resetPeakStats(DeviceIndex device) { assertValidDevice(device); device_allocators[device]->resetPeakStats(); } - void resetAccumulatedStats(DeviceIndex device) override { + void resetAccumulatedStats(DeviceIndex device) { assertValidDevice(device); device_allocators[device]->resetAccumulatedStats(); } From 06824f3c7268bb807a422b663047cd0900ddd126 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 7 Aug 2025 16:37:52 +0000 Subject: [PATCH 0105/1424] [inductor] fix test_dynamo_timed on Windows. (#159981) Fixed `test_dynamo_timed `: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/159981 Approved by: https://github.com/angelayi --- test/dynamo/test_utils.py | 208 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index b14a6c41dbdc7..d4206575d7b08 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -12,6 +12,9 @@ from torch._inductor.test_case import TestCase +_IS_WINDOWS = sys.platform == "win32" + + class TestUtils(TestCase): def test_nan(self): a = torch.Tensor([float("nan")]) @@ -283,6 +286,37 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): self.assertExpectedInline( pprint.pformat(utils.compilation_time_metrics), """\ +{'GraphLowering.codegen': [0.0, 0.0], + 'GraphLowering.compile_to_fn': [0.0, 0.0], + 'GraphLowering.compile_to_module': [0.0, 0.0], + 'GraphLowering.run': [0.0, 0.0], + 'OutputGraph.call_user_compiler': [0.0], + 'PyCodeCache.load_by_key_path': [0.0, 0.0], + 'PythonWrapperCodegen.generate': [0.0, 0.0], + 'Scheduler.__init__': [0.0, 0.0], + 'Scheduler.codegen': [0.0, 0.0], + 'Scheduler.fused_nodes': [0.0, 0.0], + '_compile.compile_inner': [0.0], + '_recursive_joint_graph_passes': [0.0], + '_recursive_post_grad_passes': [0.0, 0.0], + '_recursive_pre_grad_passes': [0.0], + 'additional_fake_tensor_prop': [0.0, 0.0], + 'aot_collect_metadata': [0.0], + 'aot_trace_joint_graph': [0.0], + 'backward._backward_impl': [0.0], + 'build_guards': [0.0], + 'bytecode_tracing': [0.0], + 'compile_attempt_0': [0.0], + 'compile_file': [0.0, 0.0], + 'compile_fx..bw_compiler': [0.0], + 'compile_fx..fw_compiler_base': [0.0], + 'compile_fx_inner': [0.0, 0.0], + 'create_aot_dispatcher_function': [0.0], + 'fx_codegen_and_compile': [0.0, 0.0], + 'gc': [0.0], + 'min_cut_rematerialization_partition': [0.0]}""" + if _IS_WINDOWS + else """\ {'GraphLowering.codegen': [0.0, 0.0], 'GraphLowering.compile_to_fn': [0.0, 0.0], 'GraphLowering.compile_to_module': [0.0, 0.0], @@ -321,6 +355,18 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): self.assertExpectedInline( pprint.pformat(time_spent), """\ +{'_recursive_joint_graph_passes': 0.0, + '_recursive_post_grad_passes': 0.0, + '_recursive_pre_grad_passes': 0.0, + 'backend_compile': 0.0, + 'code_gen': 0.0, + 'entire_backward_compile': 0.0, + 'entire_frame_compile': 0.0, + 'gc': 0.0, + 'inductor_compile': 0.0, + 'total_wall_time': 0.0}""" + if _IS_WINDOWS + else """\ {'_recursive_joint_graph_passes': 0.0, '_recursive_post_grad_passes': 0.0, '_recursive_pre_grad_passes': 0.0, @@ -364,6 +410,87 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): self.assertExpectedInline( pprint.pformat(raw), """\ +{'accumulated_cache_size': 0, + 'aot_autograd_cumulative_compile_time_us': 0, + 'backend_compile_time_s': 0.0, + 'backward_cumulative_compile_time_us': None, + 'cache_size': 0, + 'co_filename': None, + 'co_firstlineno': None, + 'co_name': 'forward', + 'code_gen_time_s': 0.0, + 'compile_id': '1/0', + 'compile_time_autotune_time_us': None, + 'compliant_custom_ops': set(), + 'config_inline_inbuilt_nn_modules': False, + 'config_suppress_errors': False, + 'cuda_version': None, + 'cudagraph_skip_reason': None, + 'distributed_ephemeral_timeout_us': None, + 'duration_us': 0, + 'dynamo_compile_time_before_restart_us': 0, + 'dynamo_config': None, + 'dynamo_cumulative_compile_time_us': 0, + 'dynamo_time_before_restart_s': 0.0, + 'end_time_us': 100, + 'entire_frame_compile_time_s': 0.0, + 'fail_reason': None, + 'fail_type': None, + 'fail_user_frame_filename': None, + 'fail_user_frame_lineno': None, + 'frame_key': '1', + 'gc_time_us': 0, + 'graph_input_count': 1, + 'graph_node_count': 3, + 'graph_op_count': 1, + 'guard_count': 9, + 'has_guarded_code': True, + 'inductor_code_gen_cumulative_compile_time_us': 0, + 'inductor_compile_time_s': 0.0, + 'inductor_config': None, + 'inductor_cumulative_compile_time_us': 0, + 'inductor_fx_remote_cache_backend_type': None, + 'inductor_fx_remote_cache_hit_count': None, + 'inductor_fx_remote_cache_hit_keys': None, + 'inductor_fx_remote_cache_miss_count': None, + 'inductor_fx_remote_cache_miss_keys': None, + 'is_forward': True, + 'is_runtime': False, + 'joint_graph_pass_time_us': 0, + 'log_format_version': 3, + 'non_compliant_ops': set(), + 'num_graph_breaks': 0, + 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, + 'post_grad_pass_time_us': 0, + 'pre_grad_pass_time_us': 0, + 'python_version': None, + 'recompile_reason': None, + 'recompile_user_contexts': None, + 'remote_cache_time_saved_s': None, + 'remote_cache_version': None, + 'remote_fx_graph_cache_get_time_ms': None, + 'remote_fx_graph_cache_get_time_us': None, + 'remote_fx_graph_cache_put_time_ms': None, + 'remote_fx_graph_cache_put_time_us': None, + 'restart_reasons': set(), + 'runtime_cudagraphify_time_us': None, + 'runtime_triton_autotune_time_us': None, + 'shape_env_guard_count': 0, + 'specialize_float': False, + 'start_time': 0.0001, + 'start_time_us': 100, + 'structured_logging_overhead_s': 0.0, + 'structured_logging_overhead_us': 0, + 'tensorify_float_attempt': None, + 'tensorify_float_failure': None, + 'tensorify_float_success': None, + 'triton_compile_time_us': None, + 'triton_kernel_compile_times_us': None, + 'triton_version': None}""" + if _IS_WINDOWS + else """\ {'accumulated_cache_size': 0, 'aot_autograd_cumulative_compile_time_us': 0, 'backend_compile_time_s': 0.0, @@ -456,6 +583,87 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): self.assertExpectedInline( pprint.pformat(raw), """\ +{'accumulated_cache_size': None, + 'aot_autograd_cumulative_compile_time_us': None, + 'backend_compile_time_s': None, + 'backward_cumulative_compile_time_us': 0, + 'cache_size': None, + 'co_filename': None, + 'co_firstlineno': None, + 'co_name': None, + 'code_gen_time_s': 0.0, + 'compile_id': '1/0', + 'compile_time_autotune_time_us': None, + 'compliant_custom_ops': None, + 'config_inline_inbuilt_nn_modules': False, + 'config_suppress_errors': False, + 'cuda_version': None, + 'cudagraph_skip_reason': None, + 'distributed_ephemeral_timeout_us': None, + 'duration_us': 0, + 'dynamo_compile_time_before_restart_us': None, + 'dynamo_config': None, + 'dynamo_cumulative_compile_time_us': None, + 'dynamo_time_before_restart_s': None, + 'end_time_us': 100, + 'entire_frame_compile_time_s': None, + 'fail_reason': None, + 'fail_type': None, + 'fail_user_frame_filename': None, + 'fail_user_frame_lineno': None, + 'frame_key': None, + 'gc_time_us': None, + 'graph_input_count': None, + 'graph_node_count': None, + 'graph_op_count': None, + 'guard_count': None, + 'has_guarded_code': None, + 'inductor_code_gen_cumulative_compile_time_us': 0, + 'inductor_compile_time_s': 0.0, + 'inductor_config': None, + 'inductor_cumulative_compile_time_us': 0, + 'inductor_fx_remote_cache_backend_type': None, + 'inductor_fx_remote_cache_hit_count': None, + 'inductor_fx_remote_cache_hit_keys': None, + 'inductor_fx_remote_cache_miss_count': None, + 'inductor_fx_remote_cache_miss_keys': None, + 'is_forward': False, + 'is_runtime': False, + 'joint_graph_pass_time_us': None, + 'log_format_version': 3, + 'non_compliant_ops': None, + 'num_graph_breaks': 0, + 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, + 'post_grad_pass_time_us': 0, + 'pre_grad_pass_time_us': None, + 'python_version': None, + 'recompile_reason': None, + 'recompile_user_contexts': None, + 'remote_cache_time_saved_s': None, + 'remote_cache_version': None, + 'remote_fx_graph_cache_get_time_ms': None, + 'remote_fx_graph_cache_get_time_us': None, + 'remote_fx_graph_cache_put_time_ms': None, + 'remote_fx_graph_cache_put_time_us': None, + 'restart_reasons': None, + 'runtime_cudagraphify_time_us': None, + 'runtime_triton_autotune_time_us': None, + 'shape_env_guard_count': None, + 'specialize_float': None, + 'start_time': 0.0001, + 'start_time_us': 100, + 'structured_logging_overhead_s': 0.0, + 'structured_logging_overhead_us': 0, + 'tensorify_float_attempt': None, + 'tensorify_float_failure': None, + 'tensorify_float_success': None, + 'triton_compile_time_us': None, + 'triton_kernel_compile_times_us': None, + 'triton_version': None}""" + if _IS_WINDOWS + else """\ {'accumulated_cache_size': None, 'aot_autograd_cumulative_compile_time_us': None, 'backend_compile_time_s': None, From e1cf0d496ea85d1807c8c740f296e77bf7bdc1df Mon Sep 17 00:00:00 2001 From: "Han, Xu" Date: Thu, 7 Aug 2025 16:37:57 +0000 Subject: [PATCH 0106/1424] [inductor] unification for inductor debug. (#159998) Unification inductor debug build, follow @desertfire 's suggestion: https://github.com/pytorch/pytorch/pull/159938#pullrequestreview-3093803196 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159998 Approved by: https://github.com/angelayi --- torch/_inductor/cpp_builder.py | 120 ++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 55 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index baa852fbaf4fc..45e655d1dfa8e 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -601,40 +601,70 @@ def _get_ffast_math_flags() -> list[str]: return flags +def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]: + """ + When we turn on generate debug symbol. + On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. + On Linux, it should create some debug sections in binary file. + """ + cflags: list[str] = [] + ldflags: list[str] = [] + + if _IS_WINDOWS: + cflags = ["ZI", "_DEBUG"] + ldflags = ["DEBUG", "ASSEMBLYDEBUG ", "OPT:REF", "OPT:ICF"] + else: + cflags.append("g") + + return cflags, ldflags + + def _get_optimization_cflags( cpp_compiler: str, min_optimize: bool = False -) -> list[str]: - if _IS_WINDOWS: - return ["O1" if min_optimize else "O2"] +) -> tuple[list[str], list[str]]: + cflags: list[str] = [] + ldflags: list[str] = [] + + b_debug_build = ( + config.aot_inductor.debug_compile + or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1" + ) + wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level + + if b_debug_build: + cflags, ldflags = _get_inductor_debug_symbol_cflags() + if _IS_WINDOWS: + cflags += ["Od", "Ob0", "Oy-"] + else: + cflags.append("O0") else: - wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level - cflags = ( - ["O0", "g"] - if config.aot_inductor.debug_compile - else [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"] - ) - cflags += _get_ffast_math_flags() - cflags.append("fno-finite-math-only") - if not config.cpp.enable_unsafe_math_opt_flag: - cflags.append("fno-unsafe-math-optimizations") - cflags.append(f"ffp-contract={config.cpp.enable_floating_point_contract_flag}") - - if sys.platform != "darwin": - # on macos, unknown argument: '-fno-tree-loop-vectorize' - if _is_gcc(cpp_compiler): - cflags.append("fno-tree-loop-vectorize") - # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 - # `-march=native` is unrecognized option on M1 - if not config.is_fbcode(): - if platform.machine() == "ppc64le": - cflags.append("mcpu=native") - else: - cflags.append("march=native") - - if config.aot_inductor.enable_lto and _is_clang(cpp_compiler): - cflags.append("flto=thin") - - return cflags + if _IS_WINDOWS: + cflags = ["O1" if min_optimize else "O2"] + else: + cflags = [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"] + + cflags += _get_ffast_math_flags() + cflags.append("fno-finite-math-only") + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + cflags.append(f"ffp-contract={config.cpp.enable_floating_point_contract_flag}") + + if sys.platform != "darwin": + # on macos, unknown argument: '-fno-tree-loop-vectorize' + if _is_gcc(cpp_compiler): + cflags.append("fno-tree-loop-vectorize") + # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 + # `-march=native` is unrecognized option on M1 + if not config.is_fbcode(): + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + else: + cflags.append("march=native") + + if config.aot_inductor.enable_lto and _is_clang(cpp_compiler): + cflags.append("flto=thin") + + return cflags, ldflags def _get_shared_cflags(do_link: bool) -> list[str]: @@ -652,25 +682,6 @@ def _get_shared_cflags(do_link: bool) -> list[str]: return ["shared", "fPIC"] -def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]: - """ - When we turn on generate debug symbol. - On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. - On Linux, it should create some debug sections in binary file. - """ - cflags: list[str] = [] - ldflags: list[str] = [] - b_enable_debug_symbol = os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1" - if b_enable_debug_symbol: - if _IS_WINDOWS: - cflags = ["Z7", "_DEBUG", "OD"] - ldflags = ["DEBUG", "OPT:REF", "OPT:ICF"] - else: - cflags.append("g") - - return cflags, ldflags - - def get_cpp_options( cpp_compiler: str, do_link: bool, @@ -686,15 +697,14 @@ def get_cpp_options( libraries: list[str] = [] passthrough_args: list[str] = [] - dbg_cflags, dbg_ldflags = _get_inductor_debug_symbol_cflags() + opt_cflags, opt_ldflags = _get_optimization_cflags(cpp_compiler, min_optimize) cflags = ( - _get_shared_cflags(do_link) - + _get_optimization_cflags(cpp_compiler, min_optimize) + opt_cflags + + _get_shared_cflags(do_link) + _get_warning_all_cflag(warning_all) + _get_cpp_std_cflag() + _get_os_related_cpp_cflags(cpp_compiler) - + dbg_cflags ) if not _IS_WINDOWS and config.aot_inductor.enable_lto and _is_clang(cpp_compiler): @@ -707,7 +717,7 @@ def get_cpp_options( definitions, include_dirs, cflags, - ldflags + dbg_ldflags, + ldflags + opt_ldflags, libraries_dirs, libraries, passthrough_args, From b1a602762e6a6674b406a3137e7e7a678885a97b Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Thu, 7 Aug 2025 16:44:41 +0000 Subject: [PATCH 0107/1424] [Profiler] Update README (#159816) Summary: Updated README with code structure and explanation of core features within profiler Test Plan: N/A Rollback Plan: Differential Revision: D79604189 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159816 Approved by: https://github.com/sanrise, https://github.com/aaronenyeshi --- torch/csrc/profiler/README.md | 74 +++++++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/torch/csrc/profiler/README.md b/torch/csrc/profiler/README.md index 339c84c0a08e7..dc27337349ddc 100644 --- a/torch/csrc/profiler/README.md +++ b/torch/csrc/profiler/README.md @@ -13,14 +13,49 @@ The profiler instruments PyTorch to collect information about the model's execut - [Codebase Structure](#codebase-structure) - [`RecordFunction`](#recordfunction) - [Autograd Integration](#autograd-integration) -- [Collection and Post-Processing](#collection-and-post-processing) +- [Torch Operation Collection](#torch-operation-collection) +- [Allocation Event Collection](#allocation-event-collection) - [Kineto Integration](#kineto-integration) - [Python Tracing](#python-tracing) +- [Clock Alignment](#clock-alignment) ## Codebase Structure ## -TODO - +This section highlights directories an files that are significant to the profiler. Lesser relevant files, directories, and modules are omitted. +``` +torch/ +│ +├── profiler/ # Main package containing the core frontend logic +│ ├── __init__.py # Initialization file for profiler package +│ ├── profiler.py # Main profiler frontend class +│ └── _utils.py # FunctionEvent utils +│ +├── autograd/ # Autograd package +│ ├── __init__.py # Initialization file for autograd package +│ ├── profiler.py # Main profiler backend class +│ └── profiler_utils.py # FunctionEvent utils +│ +├── csrc/ # C and C++ source code +│ └── profiler/ # Profiler C++ source code +│ ├── collection.cpp # Main collection logic +│ ├── collection.h # Collection definitions +│ ├── kineto_client_interface.cpp # Interface to call Profiler from kineto (on-demand only) +│ ├── kineto_client_interface.h # Client interface definitions +│ ├── kineto_shim.cpp # Shim to call kineto from profiler +│ ├── kineto_shim.h # Shim definitions +│ ├── util.cpp # utils for handling args in profiler events +│ ├── util.h # util definitions +│ └── README.md # This file +│ └── autograd/ # Autograd C++ source code +│ ├── profiler_python.cpp # Main python stack collection logic +│ ├── profiler_python.h # Python stack collection definitions +│ ├── profiler_kineto.cpp # Profiler backend logic for starting collection/kineto +│ └── profiler_kineto.h # Profiler backend definitions for starting collection/kineto +│ └── ATen/ # ATen C++ source code +│ ├── record_function.cpp # RecordFunction collection logic +│ └── record_function.h # RecordFunction definitions +└── LICENSE # License information +``` ## `RecordFunction` ## [aten/src/ATen/record_function.h](../../../aten/src/ATen/record_function.h) @@ -43,14 +78,39 @@ The profiler records two pieces of information from the autograd engine: (\*) Note that only op invocations whose inputs require gradients are assigned a sequence number -## Collection and Post-Processing ## +## Torch Operation Collection ## +This section describes the general flow for collecting torch operations during auto-trace (in-process, synchronous tracing). For details on on-demand tracing (out-of-process, asynchronous), please refer to the Libkineto README. + +When a trace begins, the autograd/profiler backend calls into `profiler_kineto.cpp` to prepare, start, or stop collection. At the start of tracing, the `onFunctionEnter` and `onFunctionExit` callbacks defined in `profiler_kineto.cpp` are registered. + +Callback registration can be either global or local, depending on the `ExperimentalConfig` used: +- **Global:** The callback is registered to all threads throughout execution. +- **Local:** The callback is registered only to threads present *at the start* of tracing. +Within `onFunctionEnter`, the profiler creates a `ThreadLocalSubqueue` instance for each thread, ensuring that each CPU operation is associated with the thread on which it was executed. When a torch operation is entered, the profiler calls `begin_op` (defined in `collection.cpp`) to record the necessary information. The `begin_op` routine is intentionally lightweight, as it is on the "hot path" during profiling. Excessive overhead here would distort the profile and reduce its usefulness. Therefore, only minimal information is collected during the callback; most logic occurs during post-processing. -TODO +## Allocation Event Collection ## + +Unlike torch operations, which have a start and stop, allocation events are represented as `cpu_instant_event` (zero duration). As a result, `RecordFunction` is bypassed for these events. Instead, `emplace_allocation_event` is called directly to enqueue the event into the appropriate `ThreadLocalSubqueue`. ## Kineto Integration ## -TODO +Kineto serves as an abstraction layer for collecting events across multiple architectures. It interacts with libraries such as CUPTI to receive GPU and accelerator events, which are then forwarded to the frontend profiler. Kineto requires time to "prepare" (also referred to as "warmup") these third-party modules to avoid distorting the profile with initialization routines. While this could theoretically be done at job startup, keeping a heavy library like CUPTI running unnecessarily introduces significant overhead. +As previously mentioned, `profiler_kineto.cpp` is used in the backend to invoke the appropriate profiler stage. It also calls into `kineto_shim.cpp`, which triggers the corresponding routines in Kineto. Once a trace is complete, all events collected by Kineto are forwarded to the profiler for two main reasons: +1. To coalesce all data and complete any post-processing between profiler and Kineto events. +2. To forward these events to the Python frontend as `FunctionEvents`. +The final step in integration is file export. After all events have been collected and post-processed, they can be exported to a JSON file for visualization in Perfetto or Chrome Tracer. This is done by calling Kineto's `ActivityTraceInterface::save`, which writes all event information to disk. ## Python Tracing ## -TODO +When `with_stack=True` is set in the profiler, the Python stack tracer is generated using the `make` function defined in `PythonTracerBase`. The implementation resides in `profiler_python.cpp`. +To profile the stack, `PyEval_SetProfile` is used to trace and handle various execution events within a Python program. This enables comprehensive profiling by monitoring and responding to specific cases: +- **Python Function Calls (`PyTrace_CALL`):** The `recordPyCall` method logs each Python function call, capturing essential details for later analysis. +- **C Function Calls (`PyTrace_C_CALL`):** The `recordCCall` method documents calls to C functions, including relevant arguments, providing a complete view of the program's execution flow. +- **Python Function Returns (`PyTrace_RETURN`):** Exit times of Python functions are recorded, enabling precise measurement of function execution durations. +- **C Function Returns and Exceptions (`PyTrace_C_RETURN` and `PyTrace_C_EXCEPTION`):** Exit times for C functions are tracked, whether they conclude normally or due to an exception, ensuring all execution paths are accounted for. +This setup allows for detailed and accurate data collection on both Python and C function executions, facilitating thorough post-processing and analysis. After profiling, the accumulated event stacks are processed to match entrances and exits, constructing complete events for further analysis by the profiler. +**Note:** For Python 3.12.0–3.12.4, a bug in CPython requires the use of `sys.monitoring` as a workaround. + +## Clock Alignment ## + +Depending on the system environment, the profiler will use the most efficient clock when creating a timestamp. The default for most Linux systems is TSC, which records time in the form of CPU cycles. To convert from this time to the unix time in nanoseconds, we create a clock converter. If Kineto is included in the profiler, this converter will also be passed into Kineto as well to ensure alignment. From e167c7d0f3b77e7440208f2a4096f56a0e285c29 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Wed, 6 Aug 2025 14:08:09 -0700 Subject: [PATCH 0108/1424] [inductor] allocate non-blocking copy destinations in pinned memory (#155121) (#158758) Fixes #155121 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158758 Approved by: https://github.com/EikanWang, https://github.com/eellison --- test/inductor/test_aot_inductor.py | 30 +++++++++ test/inductor/test_torchinductor.py | 43 ++++++++++++ torch/_inductor/codegen/common.py | 3 + torch/_inductor/codegen/cpp_wrapper_cpu.py | 6 +- .../codegen/cpp_wrapper_cpu_array_ref.py | 13 +++- torch/_inductor/codegen/wrapper.py | 13 +++- torch/_inductor/ir.py | 66 ++++++++++++++++++- torch/csrc/dynamo/guards.cpp | 16 ++++- torch/csrc/inductor/aoti_torch/c/shim.h | 10 +++ .../csrc/inductor/aoti_torch/shim_common.cpp | 22 +++++++ 10 files changed, 212 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index de8a34809bd14..e0218cd9d8bec 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6785,6 +6785,36 @@ def forward(self, x, y): aot_inductor_module = torch._inductor.aoti_load_package(package_path) self.assertEqual(aot_inductor_module(*example_inputs), model(*example_inputs)) + def test_copy_non_blocking_is_pinned(self): + if self.device == "cpu" or self.device == "mps": + raise unittest.SkipTest("only matters for device-to-cpu copy") + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + a_cpu = a.to(device="cpu", non_blocking=True) + b_cpu = b.to(device="cpu", non_blocking=True) + a_to_cpu_event = torch.Event() + a_to_cpu_event.record() + a_to_cpu_event.synchronize() + return torch.cat([a_cpu, b_cpu]) + + model = Model() + a = torch.randn(2, 2, device=self.device) + b = torch.randn(2, 2, device=self.device) + example_inputs = (a, b) + outputs = model(*example_inputs) + package_path, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, model, example_inputs + ) + FileCheck().check("pinned").run(code) + model_aoti = torch._inductor.aoti_load_package(package_path) + outputs_aoti = model_aoti(*example_inputs) + + self.assertEqual(outputs, outputs_aoti) + class AOTInductorLoggingTest(LoggingTestCase): @make_logging_test(dynamic=logging.DEBUG) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 3b71fe464667b..98604366b842b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -13654,6 +13654,49 @@ def forward(self, x): inputs = (torch.randn(4, device=self.device),) self.common(Model(), inputs) + @requires_cuda + @parametrize("use_cat", [True, False]) + def test_copy_non_blocking_is_pinned(self, use_cat): + def f(a_list): + a_cpu_list = [] + a_to_cpu_event_list = [] + + for a in a_list: + a_cpu = a.to(device="cpu", non_blocking=True) + a_to_cpu_event = torch.Event() + a_to_cpu_event.record() + a_cpu_list.append(a_cpu) + a_to_cpu_event_list.append(a_to_cpu_event) + + for e in a_to_cpu_event_list: + e.synchronize() + + if use_cat: + return torch.cat(a_cpu_list) + else: + return a_cpu_list + + f_compiled = torch.compile(f) + inputs = [ + torch.rand(1000, dtype=torch.float16, device=GPU_TYPE) for _ in range(100) + ] + outputs = f(inputs) + + with torch.profiler.profile( + activities=[ + getattr(torch.profiler.ProfilerActivity, GPU_TYPE.upper()), + ], + ) as p: + outputs_compiled = f_compiled(inputs) + + # outputs_compiled, (code,) = run_and_get_code(f_compiled, inputs) + # self.assertTrue("pinned" in code) + + self.assertEqual(outputs, outputs_compiled) + profile_output = str(p.key_averages()) + print(profile_output) + self.assertFalse("Pageable" in profile_output) + @dataclasses.dataclass class TestFailure: diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index dad5a281e10a6..471c9030f1e6c 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -253,6 +253,9 @@ def get_stride(self) -> list[sympy.Expr]: def get_name(self) -> str: return self.outer_name + def get_is_pinned(self) -> bool: + return False + def get_inputs_that_alias_output(self) -> list[str]: return [] diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 6d11fe1c8be17..0edeabccebbd8 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1575,10 +1575,11 @@ def make_buffer_allocation(self, buffer): buffer.get_size(), buffer.get_stride(), V.graph.get_allocation_size(buffer), + buffer.get_is_pinned(), ) def make_allocation( - self, name, device, dtype, shape, stride, allocation_shape=None + self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False ): if allocation_shape is None: allocation_shape = shape @@ -1630,8 +1631,9 @@ def make_allocation( ] self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + pinned_str = "_pinned" if is_pinned else "" self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" ) if allocation_size != size: diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index eb3390cbc39cf..fd145ece606d1 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -565,10 +565,18 @@ def make_buffer_allocation(self, buffer): buffer.get_size(), buffer.get_stride(), buffer if self.can_stack_allocate_buffer(buffer) else None, + buffer.get_is_pinned(), ) def make_allocation( - self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None + self, + name, + device, + dtype, + shape, + stride, + buffer_if_can_stack_allocate=None, + is_pinned=False, ): orig_stride = stride device_str = self.codegen_device(device) @@ -615,8 +623,9 @@ def make_allocation( ] self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + pinned_str = "_pinned" if is_pinned else "" self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" ) return f"RAIIAtenTensorHandle {name}({name}_handle);" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index dd03163440999..49f8549170b6b 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -998,6 +998,7 @@ def write_header(self) -> None: assert_size_stride = torch._C._dynamo.guards.assert_size_stride assert_alignment = torch._C._dynamo.guards.assert_alignment empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia @@ -2772,8 +2773,9 @@ def make_buffer_allocation(self, buffer: BufferLike): shape = tuple(buffer.get_size()) allocation_shape = tuple(V.graph.get_allocation_size(buffer)) stride = tuple(buffer.get_stride()) + is_pinned = buffer.get_is_pinned() return self.make_allocation( - buffer.get_name(), device, dtype, shape, stride, allocation_shape + buffer.get_name(), device, dtype, shape, stride, allocation_shape, is_pinned ) @cache_on_self @@ -2785,7 +2787,7 @@ def write_memory_track_allocation_once(self): self.imports.splice(import_str, strip=True) def make_allocation( - self, name, device, dtype, shape, stride, allocation_shape=None + self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False ): if allocation_shape is None: allocation_shape = shape @@ -2804,6 +2806,13 @@ def make_allocation( f"device='{device.type}', " f"name='{name}')" ) + elif device.type == "cpu" and is_pinned: + out = ( + f"{name} = empty_strided_cpu_pinned(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"{dtype})" + ) elif device.type in ("cpu", "cuda", "xpu", "mtia"): # optimized path for faster allocations, saving ~2us versus the stuff below out = ( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 3f03c33d70daa..4f9f2f1e0b59f 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -510,6 +510,7 @@ def try_match_insignificant_strides( old_layout.size, new_stride, old_layout.offset, + old_layout.is_pinned, ) return TensorBox(ReinterpretView(data=storage, layout=new_layout)) @@ -2906,6 +2907,7 @@ def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: list(new_size), new_stride, old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -2952,6 +2954,7 @@ def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView: [old_layout.size[i] for i in dims], [old_layout.stride[i] for i in dims], old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3013,6 +3016,7 @@ def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode: new_size, new_stride, old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3131,6 +3135,7 @@ def fake_reindex(index: Any) -> tuple[int, ...]: new_size, FlexibleLayout.contiguous_strides(new_size), old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3365,6 +3370,7 @@ def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView: old_layout.size, old_layout.stride, old_layout.offset, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) return DtypeView(data=x, target_dtype=new_dtype) @@ -3472,6 +3478,7 @@ def create( # type: ignore[override] new_size, new_stride, old_layout.offset + old_layout.stride[dim] * start, + old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3568,6 +3575,13 @@ def storage_size(self) -> int: @ir_dataclass class Layout(OutputSpec): + """ + Layout base class + + Carries tensor meta-information including offset and + whether it is pinned. + """ + def __init__( self, device: torch.device, @@ -3575,6 +3589,7 @@ def __init__( size: Sequence[Expr], stride: Optional[Sequence[Expr]] = None, offset: Expr = Integer(0), + is_pinned: bool = False, ) -> None: if stride is None: stride = FlexibleLayout.contiguous_strides(size) @@ -3585,6 +3600,9 @@ def __init__( self.size = size self.stride = stride self.offset = offset + self.is_pinned = is_pinned + # is_pinned implies cpu + assert (not self.is_pinned) or (self.device.type == "cpu") def __str__(self) -> str: offset = "" @@ -3592,9 +3610,12 @@ def __str__(self) -> str: offset = f", offset={self.offset}" device_index_str = "" if self.device.index is None else f":{self.device.index}" + is_pinned_str = "" + if self.is_pinned: + is_pinned_str = f", is_pinned={self.is_pinned}" return ( f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, " - f"size={self.size}, stride={self.stride}{offset})" + f"size={self.size}, stride={self.stride}{offset}{is_pinned_str})" ) __repr__ = __str__ @@ -3609,6 +3630,7 @@ def get_example(self) -> torch.Tensor: convert_shape_to_symint(self.stride), dtype=self.dtype, device=self.device, + pin_memory=self.is_pinned, ) def is_contiguous(self) -> bool: @@ -3760,6 +3782,7 @@ def as_fixed(self) -> FixedLayout: self.size, self.stride, self.offset, + self.is_pinned, ) def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: @@ -3776,6 +3799,7 @@ def __eq__(self, other: object) -> bool: and self.size == other.size and self.stride == other.stride and self.offset == other.offset + and self.is_pinned == other.is_pinned ) def storage_size(self) -> Expr: @@ -3889,6 +3913,7 @@ def as_stride_order( self.size, new_stride, self.offset, + self.is_pinned, ) def as_exact_strides( @@ -3904,6 +3929,7 @@ def as_exact_strides( self.size, new_stride, self.offset, + self.is_pinned, ) def as_fill_order(self, order: Sequence[int]) -> FixedLayout: @@ -3916,6 +3942,7 @@ def as_fill_order(self, order: Sequence[int]) -> FixedLayout: self.size, new_stride, self.offset, + self.is_pinned, ) def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: @@ -3928,6 +3955,7 @@ def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: self.size, new_stride, self.offset, + self.is_pinned, ) def __init__( @@ -3936,12 +3964,13 @@ def __init__( dtype: torch.dtype, size: Sequence[Expr], stride_order: Optional[Sequence[Union[int, Integer]]] = None, + is_pinned: bool = False, ) -> None: if stride_order: strides = FlexibleLayout.fill_ordered(size, stride_order) else: strides = FlexibleLayout.contiguous_strides(size) - super().__init__(device, dtype, size, strides) + super().__init__(device, dtype, size, strides, is_pinned=is_pinned) class NonOwningLayout(Layout): @@ -4007,6 +4036,7 @@ def __init__( size=fixed.size, stride=fixed.stride, offset=fixed.offset, + is_pinned=fixed.is_pinned, ) self.comm_buffer_type = comm_buffer_type self.group_name = group_name @@ -4181,6 +4211,9 @@ def get_output_spec(self) -> OutputSpec: def get_storage_numel(self) -> int: return self.get_numel() + def get_is_pinned(self) -> bool: + return self.get_layout().is_pinned + def freeze_layout(self) -> None: if isinstance(self.layout, Layout) and not isinstance( self.layout, NonOwningLayout @@ -5148,6 +5181,9 @@ class ConcatKernel(NopKernel): @classmethod def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: + """ + Create the concat kernel from inputs + """ device = inputs[0].get_device() dtype = inputs[0].get_dtype() new_size = list(inputs[0].get_size()) @@ -5201,6 +5237,10 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: ): output_stride = make_channels_last_strides_for(new_size) + is_pinned = all( + is_storage_and_layout(x) and x.get_layout().is_pinned for x in inputs + ) + assert device is not None concat_kernel = ConcatKernel( name=None, @@ -5209,6 +5249,7 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: dtype=dtype, size=new_size, stride=output_stride, + is_pinned=is_pinned, ), inputs=[], ) @@ -5693,6 +5734,7 @@ def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: size=x.get_size(), stride=strides, offset=offset, + is_pinned=False, ), ) @@ -7027,12 +7069,21 @@ def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: if x.get_size(): # x.get_stride() may be unimplemented if x's size is empty stride = x.get_stride() + is_destination_pinned = ( + x_device.type == "cuda" and device.type == "cpu" and non_blocking + ) + is_source_pinned = ( + x_device.type == "cpu" and device.type == "cuda" and non_blocking + ) + if is_source_pinned and is_storage_and_layout(x): + x.get_layout().is_pinned = True return DeviceCopy( FixedLayout( device, x.get_dtype(), x.get_size(), stride, + is_pinned=is_destination_pinned, ), [cls.realize_input(x)], constant_args, @@ -7601,11 +7652,18 @@ def is_number(t: torch.JitType) -> bool: @staticmethod def tensor_to_layout(output: torch.Tensor) -> FixedLayout: + is_pinned = False + try: + is_pinned = output.is_pinned() + except RuntimeError: + # dispatch not implemented + pass return FixedLayout( output.device, output.dtype, convert_shape_to_inductor(output.size()), convert_shape_to_inductor(output.stride()), + is_pinned=is_pinned, ) @classmethod @@ -8006,6 +8064,7 @@ def realize(self) -> Optional[str]: device=device, dtype=self.data.get_dtype(), size=self.data.get_size(), + is_pinned=False, ), data=self.data, ) @@ -8186,6 +8245,7 @@ def create_output( size=output.get_size(), stride=output.get_stride(), offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, ), invoke_subgraph, # type: ignore[has-type] [(list, ind)], @@ -8315,6 +8375,7 @@ def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]: size=[_maybe_expr(sz) for sz in merged_output.size()], stride=[_maybe_expr(sz) for sz in merged_output.stride()], offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, ), conditional, [(list, i)], @@ -8542,6 +8603,7 @@ def _guard_list_equals( size=output.get_size(), stride=output.get_stride(), offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, ), while_loop, [(list, idx)], diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index ae7aa20be29c8..9e25d07b1e839 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -1042,7 +1042,8 @@ static void _parse_empty_strided_args( static PyObject* _empty_strided_device( PyObject* dummy, PyObject* args, - c10::DeviceType device_type) { + c10::DeviceType device_type, + bool is_pinned = false) { HANDLE_TH_ERRORS; at::SmallVector sizes; at::SmallVector strides; @@ -1050,7 +1051,7 @@ static PyObject* _empty_strided_device( _parse_empty_strided_args(args, sizes, strides, dtype); if (device_type == c10::DeviceType::CPU) { return THPVariable_Wrap( - at::detail::empty_strided_cpu(sizes, strides, dtype)); + at::detail::empty_strided_cpu(sizes, strides, dtype, is_pinned)); } #ifdef USE_CUDA else if (device_type == c10::DeviceType::CUDA) { @@ -1084,6 +1085,13 @@ static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) { return _empty_strided_device(dummy, args, c10::DeviceType::CPU); } +static PyObject* _empty_strided_cpu_pinned(PyObject* dummy, PyObject* args) { + // at::empty_strided is surprising slow. This is a lower-overhead + // version that saves ~2us on every allocation. + return _empty_strided_device( + dummy, args, c10::DeviceType::CPU, /*is_pinned=*/true); +} + static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) { // at::empty_strided is surprising slow. This is lower-overhead. return _empty_strided_device(dummy, args, c10::DeviceType::CUDA); @@ -1127,6 +1135,10 @@ static PyMethodDef _methods[] = { {"assert_alignment", assert_alignment, METH_VARARGS, nullptr}, {"dict_version", dict_version, METH_VARARGS, nullptr}, {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr}, + {"_empty_strided_cpu_pinned", + _empty_strided_cpu_pinned, + METH_VARARGS, + nullptr}, {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr}, {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr}, {"_empty_strided_mtia", _empty_strided_mtia, METH_VARARGS, nullptr}, diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 9d512ce1f4817..d6f32358cdcc5 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -267,6 +267,16 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided( AtenTensorHandle* ret_new_tensor // returns new reference ); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided_pinned( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor // returns new reference +); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_as_strided( AtenTensorHandle self, const int64_t* sizes_ptr, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index a33198fd1ba06..eff8276315a20 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -452,6 +452,28 @@ AOTITorchError aoti_torch_empty_strided( }); } +AOTITorchError aoti_torch_empty_strided_pinned( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::IntArrayRef sizes(sizes_ptr, ndim); + c10::IntArrayRef strides(strides_ptr, ndim); + TORCH_CHECK( + c10::DeviceType(device_type) == c10::DeviceType::CPU, + "only CPU tensors can be pinned"); + *ret_new_tensor = new_tensor_handle(at::detail::empty_strided_cpu( + sizes, + strides, + static_cast(dtype), + /*is_pinned=*/true)); + }); +} + AOTITorchError aoti_torch_create_tensor_from_blob( void* data, int64_t ndim, From 57f738b6357cc8fcdde479a0948e723809a1a44d Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Wed, 6 Aug 2025 14:08:09 -0700 Subject: [PATCH 0109/1424] [inductor] move all cpu scalars using pinned memory for graph partition (#155360) (#158983) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158983 Approved by: https://github.com/eellison ghstack dependencies: #158758 --- test/inductor/test_cudagraph_trees.py | 22 +++++++++++ torch/_inductor/fx_passes/post_grad.py | 55 ++++++++++++++++++++------ 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index dc8ec985fbae3..688c4d87230cf 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -2849,6 +2849,28 @@ def foo(x): self.assertEqual(x, torch.tensor(1, device="cpu")) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar_multiple(self): + def f(x, y, z): + return x + y, x + z + + compiled_f = torch.compile(f, mode="reduce-overhead") + + inputs = ( + torch.ones((), device="cpu"), + torch.ones((), device="cpu"), + torch.ones(2, 2, device="cuda"), + ) + for i in range(3): + if i == 0: + _, code = run_and_get_code(compiled_f, *inputs) + FileCheck().check_regex(r".copy_.*True").run(code[0]) + FileCheck().check_count(".copy_", 1, exactly=True).run(code[0]) + else: + compiled_f(*inputs) + self.assertEqual(compiled_f(*inputs), f(*inputs)) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + @torch._inductor.config.patch("graph_partition", True) @torch._inductor.config.patch("triton.cudagraphs", False) def test_graph_partition_reduce_overhead_mode_effectiveness(self): diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 7133d77740bc9..db273b06c8e6c 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -1760,17 +1760,44 @@ def __call__(self, graph: fx.Graph) -> None: movable_constructors = self.find_movable_constructors(graph, constructors) target_device = next(iter(target_devices)) - for node in movable_constructors: - if node in cpu_placeholders: - with graph.inserting_after(node): - gpu_node = graph.call_function( - torch.ops.prims.device_put.default, (node, target_device) + movable_cpu_placeholders = movable_constructors & cpu_placeholders + if movable_cpu_placeholders: + node = next(iter(reversed(movable_cpu_placeholders))) + last_node = node + unsqueezed_nodes = [] + for elem in movable_cpu_placeholders: + with graph.inserting_after(last_node): + unsqueezed_nodes.append( + graph.call_function(torch.ops.aten.unsqueeze.default, (elem, 0)) ) - node.replace_all_uses_with( - gpu_node, - lambda x: x != gpu_node - and x.target != torch.ops.aten.copy_.default, + last_node = unsqueezed_nodes[-1] + with graph.inserting_after(last_node): + cpu_concat = graph.call_function( + torch.ops.aten.cat.default, (unsqueezed_nodes,) + ) + last_node = cpu_concat + with graph.inserting_after(last_node): + gpu_concat = graph.call_function( + torch.ops.prims.device_put.default, + (cpu_concat, target_device, True), ) + last_node = gpu_concat + with graph.inserting_after(last_node): + gpu_split = graph.call_function( + torch.ops.aten.unbind.int, (gpu_concat,) + ) + last_node = gpu_split + for idx, node in enumerate(movable_cpu_placeholders): + with graph.inserting_after(last_node): + gpu_node = graph.call_function(operator.getitem, (gpu_split, idx)) + node.replace_all_uses_with( + gpu_node, + lambda x: x + not in [cpu_concat, gpu_concat, gpu_split, gpu_node] + + unsqueezed_nodes + and x.target != torch.ops.aten.copy_.default, + ) + last_node = gpu_node # noop elimination if there are other device_put for gpu_node to # target device. Alternatively, we could just move the other device_put @@ -1784,10 +1811,12 @@ def __call__(self, graph: fx.Graph) -> None: for noop in noop_device_puts: noop.replace_all_uses_with(gpu_node) graph.erase_node(noop) - else: - kwargs = node.kwargs.copy() - kwargs["device"] = target_device - node.kwargs = kwargs + + movable_constructors -= movable_cpu_placeholders + for node in movable_constructors: + kwargs = node.kwargs.copy() + kwargs["device"] = target_device + node.kwargs = kwargs def find_movable_constructors( self, graph: fx.Graph, constructors: list[fx.Node] From 69cc606fda9d70828e01346f891298bee3917683 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Thu, 7 Aug 2025 06:48:21 -0700 Subject: [PATCH 0110/1424] HF component update to not use fsspec components (#159405) Update HF components to not inherit from fsspec components and instead use filesystem writer/reader. The reason is because there doesn't seem to be much of a need for fsspec, since users are using mounted storage. Using local storage will allow for performance improvements because we can take advantage of the safe_open API provided by HF safetensors (30s vs 4s for load of 8b model), which is signifcant performance wins over reading bytes and converting to tensors which is what we are doing now. Also, we can use the official methods provided by HF instead of relying on reading the metadata by bytes and loading it Differential Revision: [D78993550](https://our.internmc.facebook.com/intern/diff/D78993550/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159405 Approved by: https://github.com/saumishr --- torch/distributed/checkpoint/hf_storage.py | 44 ++++++---------------- 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 13fd61910dd21..81ba503fb9ee9 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -7,10 +7,10 @@ import torch from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter from torch.distributed.checkpoint._consolidate_hf_safetensors import ( consolidate_safetensors_files, ) -from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter from torch.distributed.checkpoint._hf_utils import ( _gen_file_name, _get_dtype, @@ -52,7 +52,7 @@ __all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] -class HuggingFaceStorageWriter(FsspecWriter): +class HuggingFaceStorageWriter(FileSystemWriter): """ A writer that writes to a huggingface repository in the huggingface format. Uses Fsspec back-end to communicate with back-end storage. @@ -64,26 +64,20 @@ def __init__( path: str, fqn_to_index_mapping: Optional[dict[str, int]] = None, thread_count: int = 1, - token: Optional[str] = None, save_distributed: bool = False, enable_consolidation: bool = False, - consolidated_output_path: Optional[str] = None, thread_count_consolidation: int = 1, ) -> None: """ Initialize the huggingface writer pointing to path. Args: - path: hf directory where the checkpoint will be read from. - Needs to have .safetensors files, but can be from any fsspec supported storage, - including localFS and hf://. - This needs to be a remote path if you want to enable consolidation after saving. + path: directory where the checkpoint will be read from. fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. Indices are from 1 to N, where N is the number of files. If not provided, the tensors will be written to a single file. If none, then all the tensors on the same rank will be written to the same file. thread_count: Number of threads to use to write distributed checkpoint. Default to 1. - token: The token to use to authenticate with huggingface hub. save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. Default is False which assumes rank-0 checkpointing of the full state_dict. enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded tensors will be @@ -92,19 +86,11 @@ def __init__( to consolidated output files. Default to 1. """ - if token is not None: - super().__init__( - path=path, - token=token, - serialization_format=SerializationFormat.SAFETENSORS, - thread_count=thread_count, - ) - else: - super().__init__( - path=path, - serialization_format=SerializationFormat.SAFETENSORS, - thread_count=thread_count, - ) + super().__init__( + path=path, + serialization_format=SerializationFormat.SAFETENSORS, + thread_count=thread_count, + ) self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping self.save_distributed: bool = save_distributed self.enable_consolidation: bool = enable_consolidation @@ -215,28 +201,22 @@ def metadata_path(self) -> str: return _metadata_fn -class HuggingFaceStorageReader(FsspecReader): +class HuggingFaceStorageReader(FileSystemReader): """ A reader that reads from a huggingface repository in the huggingface format. Uses in Fsspec back-end to communicate with storage. Fsspec registration of the storage solution is required. """ - def __init__(self, path: str, token: Optional[str] = None) -> None: + def __init__(self, path: str) -> None: """ Initialize the huggingface reader pointing to path. Args: - path: hf directory where the checkpoint will be read from. - Needs to have .safetensors file, but can be from any fsspec supported storage, - including localFS and hf://. - token: The token to use to authenticate with huggingface hub. + path: directory where the checkpoint will be read from. """ - if token is not None: - super().__init__(path=path, token=token) - else: - super().__init__(path=path) + super().__init__(path=path) def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: per_file: dict[str, list[ReadItem]] = {} From 0b187b3114fa9f2c938d624d3c8b8b0178a666bd Mon Sep 17 00:00:00 2001 From: Ankita George Date: Thu, 7 Aug 2025 06:48:22 -0700 Subject: [PATCH 0111/1424] DCP HF reader: use safe_open instead of reading the bytes (#159406) Reading the bytes and converting to tensors is much slower than using safe_open. For a 8B model across 8 ranks, took ~30s to load before this change and ~4s after. Differential Revision: [D78994259](https://our.internmc.facebook.com/intern/diff/D78994259/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159406 Approved by: https://github.com/saumishr ghstack dependencies: #159405 --- .../distributed/checkpoint/test_hf_storage.py | 13 +++++++++++- torch/distributed/checkpoint/hf_storage.py | 20 ++++++++----------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/test/distributed/checkpoint/test_hf_storage.py b/test/distributed/checkpoint/test_hf_storage.py index 637dd228944f1..478c1722d4e39 100644 --- a/test/distributed/checkpoint/test_hf_storage.py +++ b/test/distributed/checkpoint/test_hf_storage.py @@ -162,8 +162,16 @@ def test_write_data_with_sharding(self) -> None: ) def test_read_data_hf(self) -> None: - # Create test tensors tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) + + mock_safe_open = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value.get_slice.return_value = tensor_0 + mock_safe_open.return_value = mock_context + + sys.modules["safetensors"] = MagicMock() + sys.modules["safetensors"].safe_open = mock_safe_open + with tempfile.TemporaryDirectory() as path: # Create the reader reader = HuggingFaceStorageReader(path=path) @@ -260,6 +268,9 @@ def test_read_data_hf(self) -> None: # Verify results - the target tensors should now contain the values from our test tensor self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0)) + mock_safe_open.assert_called_once_with(filename=file_path, framework="pt") + mock_context.__enter__.return_value.get_slice.assert_called_with("tensor_0") + def test_write_metadata_hf(self) -> None: mock_module = MagicMock() sys.modules["huggingface_hub"] = mock_module diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 81ba503fb9ee9..21a1636b308d7 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -6,7 +6,6 @@ from typing import Any, Optional import torch -from torch.distributed._shard._utils import narrow_tensor_by_index from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter from torch.distributed.checkpoint._consolidate_hf_safetensors import ( consolidate_safetensors_files, @@ -219,6 +218,8 @@ def __init__(self, path: str) -> None: super().__init__(path=path) def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + from safetensors import safe_open # type: ignore[import] + per_file: dict[str, list[ReadItem]] = {} for read_item in plan.items: @@ -227,21 +228,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: per_file.setdefault(file_name, []).append(read_item) for file_name, reqs in per_file.items(): - with self.fs.create_stream(file_name, "rb") as stream: + with safe_open(filename=file_name, framework="pt") as f: for req in reqs: item_md = self.storage_data[req.storage_index] - stream.seek(item_md.offset) - tensor_bytes = stream.read(item_md.length) - - tensor = torch.frombuffer( - tensor_bytes, - dtype=item_md.dtype, - ) - tensor = tensor.reshape(item_md.shape) - tensor = narrow_tensor_by_index( - tensor, req.storage_offsets, req.lengths + # Create slices for each dimension based on offsets and lengths + slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) ) + tensor = f.get_slice(req.storage_index.fqn)[slices] target_tensor = planner.resolve_tensor(req).detach() assert target_tensor.size() == tensor.size(), ( From 8399cf88ce8399d2be93355f29d4cb69f51c0654 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Thu, 7 Aug 2025 06:48:23 -0700 Subject: [PATCH 0112/1424] Use only safetensors APIs in HFStorageReader (#159681) Get rid of the logic to read the metadata from the header of the safetensors file manually and use the functions as part of safe_open() to get the metadata. This is much cleaner and allows us to not rely on our own custom methods to get metadata, but use safetensors provided APIs Differential Revision: [D79460272](https://our.internmc.facebook.com/intern/diff/D79460272/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159681 Approved by: https://github.com/saumishr ghstack dependencies: #159405, #159406 --- .../distributed/checkpoint/test_hf_storage.py | 61 +++++++++++-------- torch/distributed/checkpoint/_hf_utils.py | 2 - torch/distributed/checkpoint/hf_storage.py | 53 ++++++---------- 3 files changed, 56 insertions(+), 60 deletions(-) diff --git a/test/distributed/checkpoint/test_hf_storage.py b/test/distributed/checkpoint/test_hf_storage.py index 478c1722d4e39..81558db13a69f 100644 --- a/test/distributed/checkpoint/test_hf_storage.py +++ b/test/distributed/checkpoint/test_hf_storage.py @@ -208,8 +208,6 @@ def test_read_data_hf(self) -> None: fqn="tensor_0", offset=torch.Size([0]), index=None ): _HFStorageInfo( file_path, - len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN, - tensor_0.numel() * tensor_0.element_size(), tensor_0.shape, tensor_0.dtype, ), @@ -324,35 +322,50 @@ def test_write_metadata_hf(self) -> None: self.assertEqual(metadata, expected_metadata) def test_read_metadata_hf(self): + mock_safe_open = MagicMock() + mock_context = MagicMock() + + mock_safe_open.return_value = mock_context + + mock_context.__enter__.return_value.keys.return_value = ["tensor_0"] + mock_context.__enter__.return_value.metadata.return_value = {} + + mock_slice = MagicMock() + mock_slice.get_shape.return_value = [5, 10] + mock_slice.get_dtype.return_value = "F32" + mock_context.__enter__.return_value.get_slice.return_value = mock_slice + + mock_safetensors = MagicMock() + mock_safetensors.safe_open = mock_safe_open + + mock_safetensors.torch._getdtype = MagicMock(return_value=torch.float32) + + sys.modules["safetensors"] = mock_safetensors + sys.modules["safetensors.torch"] = mock_safetensors.torch + with tempfile.TemporaryDirectory() as path: reader = HuggingFaceStorageReader(path=path) key = "tensor_0" file_name = "test.safetensors" - with open(os.path.join(path, file_name), "wb") as f: - # write metadata the same way it would be in safetensors file - metadata_contents = json.dumps( - { - "tensor_0": { - "dtype": "F32", - "shape": [5, 10], - "data_offsets": [0, 200], - } - } - ) - metadata_bytes = metadata_contents.encode("utf-8") + file_path = os.path.join(path, file_name) - f.write( - len(metadata_bytes).to_bytes( - NUM_BYTES_FOR_HEADER_LEN, byteorder="little" - ) - ) - f.write(metadata_bytes) + # Create an empty file so fs.ls can find it + with open(file_path, "wb") as _: + pass + + # Mock the fs.ls method to return our test file + original_ls = reader.fs.ls + reader.fs.ls = MagicMock(return_value=[file_path]) - tensor = torch.rand(5, 10) - f.write(tensor.numpy().tobytes()) + try: + metadata = reader.read_metadata() + finally: + # Restore the original ls method + reader.fs.ls = original_ls - metadata = reader.read_metadata() + # Verify that safe_open was called with our file path + mock_safe_open.assert_called_once_with(file_path, framework="pt") self.assertEqual( metadata.state_dict_metadata, @@ -376,8 +389,6 @@ def test_read_metadata_hf(self): fqn=key, offset=torch.Size([0, 0]), index=None ): _HFStorageInfo( os.path.join(path, file_name), - len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN, - 200, torch.Size([5, 10]), torch.float32, ) diff --git a/torch/distributed/checkpoint/_hf_utils.py b/torch/distributed/checkpoint/_hf_utils.py index 1a3f627fd69b5..0d14229b7f8cc 100644 --- a/torch/distributed/checkpoint/_hf_utils.py +++ b/torch/distributed/checkpoint/_hf_utils.py @@ -51,8 +51,6 @@ class _HFStorageInfo: """This is the per entry storage info.""" relative_path: str - offset: int - length: int shape: torch.Size dtype: torch.dtype diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 21a1636b308d7..6b36e619f7ced 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -12,16 +12,10 @@ ) from torch.distributed.checkpoint._hf_utils import ( _gen_file_name, - _get_dtype, - _get_safetensors_file_metadata, _HFStorageInfo, _metadata_fn, CUSTOM_METADATA_KEY, - DATA_OFFSETS_KEY, - DEFAULT_EXTRA_METADATA_KEY, - DTYPE_KEY, SAVED_OFFSETS_KEY, - SHAPE_KEY, SHARDED_DIR_NAME, SUFFIX, ) @@ -252,6 +246,9 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: return fut def read_metadata(self) -> Metadata: + from safetensors import safe_open # type: ignore[import] + from safetensors.torch import _getdtype # type: ignore[import] + state_dict_metadata: dict[str, TensorStorageMetadata] = {} storage_data: dict[MetadataIndex, _HFStorageInfo] = {} @@ -261,53 +258,47 @@ def read_metadata(self) -> Metadata: safetensors_files.append(file) for safetensor_file in safetensors_files: - with self.fs.create_stream(safetensor_file, "rb") as f: - safetensors_metadata, metadata_size = _get_safetensors_file_metadata(f) - custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY) + with safe_open(safetensor_file, framework="pt") as f: + keys = f.keys() + extra_metadata = f.metadata() dcp_sharding_info = None - if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY): + if extra_metadata and extra_metadata.get(CUSTOM_METADATA_KEY): dcp_sharding_info = json.loads( - custom_metadata.get(CUSTOM_METADATA_KEY) + extra_metadata.get(CUSTOM_METADATA_KEY) ) - for key, val in safetensors_metadata.items(): - if key == DEFAULT_EXTRA_METADATA_KEY: - continue - + for key in keys: + shape = f.get_slice(key).get_shape() + dtype = f.get_slice(key).get_dtype() # construct state_dict_metadata if dcp_sharding_info is not None: offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] else: - offset = [0] * len(val[SHAPE_KEY]) + offset = [0] * len(shape) if key not in state_dict_metadata: state_dict_metadata[key] = TensorStorageMetadata( - properties=TensorProperties( - dtype=_get_dtype(val[DTYPE_KEY]) - ), + properties=TensorProperties(dtype=_getdtype(dtype)), size=torch.Size( - [ - saved + offset - for saved, offset in zip(val[SHAPE_KEY], offset) - ] + [saved + offset for saved, offset in zip(shape, offset)] ), chunks=[ ChunkStorageMetadata( offsets=torch.Size(offset), - sizes=torch.Size(val[SHAPE_KEY]), + sizes=torch.Size(shape), ) ], ) else: state_dict_metadata[key].chunks.append( ChunkStorageMetadata( - torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]) + torch.Size(offset), sizes=torch.Size(shape) ) ) size = list(state_dict_metadata[key].size) for i in range(len(size)): - size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i]) + size[i] = max(size[i], shape[i] + offset[i]) state_dict_metadata[key].size = torch.Size(size) # construct storage data @@ -316,15 +307,11 @@ def read_metadata(self) -> Metadata: fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY] ) else: - metadata_index = MetadataIndex( - fqn=key, offset=[0] * len(val[SHAPE_KEY]) - ) + metadata_index = MetadataIndex(fqn=key, offset=[0] * len(shape)) storage_data[metadata_index] = _HFStorageInfo( relative_path=safetensor_file, - offset=val[DATA_OFFSETS_KEY][0] + metadata_size, - length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0], - shape=torch.Size(val[SHAPE_KEY]), - dtype=_get_dtype(val[DTYPE_KEY]), + shape=torch.Size(shape), + dtype=_getdtype(dtype), ) metadata = Metadata( From 0bd3af4fb87445f4de3a1f9b823e399c8b3cefde Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Thu, 7 Aug 2025 17:32:58 +0000 Subject: [PATCH 0113/1424] Further fix failing tests in test/inductor/test_analysis.py (#160070) This is a follow up on #159800 as other tests are still failing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160070 Approved by: https://github.com/aorenste --- test/inductor/test_analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_analysis.py b/test/inductor/test_analysis.py index 51c601b4d1d7b..ac0467a2d1b80 100644 --- a/test/inductor/test_analysis.py +++ b/test/inductor/test_analysis.py @@ -337,6 +337,7 @@ def test_augment_trace_helper_unit(self): ], ) @skipIf(not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune") + @torch._inductor.config.patch(force_disable_caches=True) def test_triton_has_metadata(self, device, dtype, maxat): """ make sure that the chrome trace of triton kernels contains certain values @@ -359,7 +360,6 @@ def om(i, w): options={ "benchmark_kernel": True, "max_autotune_gemm_backends": backends, - "force_disable_caches": True, "max_autotune": max_autotune, }, ) @@ -507,6 +507,7 @@ def test_augment_trace_against_flop_counter(self, device, dtype, maxat): @unittest.skipIf( not IS_BIG_GPU, "we can't use Triton only as a backend for max autotune" ) + @torch._inductor.config.patch(force_disable_caches=True) def test_pointwise_bandwidth(self, device, dtype, maxat): # this tests to see if we can only use a Triton backend for max autotune max_autotune, backends = maxat @@ -518,7 +519,6 @@ def test_pointwise_bandwidth(self, device, dtype, maxat): options={ "benchmark_kernel": True, "max_autotune_gemm_backends": backends, - "force_disable_caches": True, "max_autotune": max_autotune, }, ) From ee1fb43450c2e985657f95a91b68328d6f20f24e Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov Date: Thu, 7 Aug 2025 17:41:47 +0000 Subject: [PATCH 0114/1424] Fix docker image creation (#158634) Since switching from wheel 0.34.2 to wheel 0.45.1 python symlinks are no longer correctly created. Migrate to packaging package for symlink creation Pull Request resolved: https://github.com/pytorch/pytorch/pull/158634 Approved by: https://github.com/malfet --- .ci/docker/common/install_cpython.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index d7fc6ea264ddb..c160e5704ba31 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -66,8 +66,9 @@ function do_cpython_build { ln -s pip3 ${prefix}/bin/pip fi # install setuptools since python 3.12 is required to use distutils - ${prefix}/bin/pip install wheel==0.45.1 setuptools==80.9.0 - local abi_tag=$(${prefix}/bin/python -c "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag; print('{0}{1}-{2}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag()))") + # packaging is needed to create symlink since wheel no longer provides needed information + ${prefix}/bin/pip install packaging==25.0 wheel==0.45.1 setuptools==80.9.0 + local abi_tag=$(${prefix}/bin/python -c "from packaging.tags import interpreter_name, interpreter_version; import sysconfig ; from sysconfig import get_config_var; print('{0}{1}-{0}{1}{2}'.format(interpreter_name(), interpreter_version(), 't' if sysconfig.get_config_var('Py_GIL_DISABLED') else ''))") ln -sf ${prefix} /opt/python/${abi_tag} } From 21392c0e06ac2b2621950455975ca6332f0bf641 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 7 Aug 2025 18:07:32 +0000 Subject: [PATCH 0115/1424] [inductor] disable flex decoding on Windows. (#160072) Discussed with @jianan-gu and @Valentine233 , disable flex decoding on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160072 Approved by: https://github.com/angelayi --- test/inductor/test_flex_decoding.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index b5ec59dc291c6..9a0cb945fc331 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -2,6 +2,7 @@ # flake8: noqa: B950 import functools +import sys import unittest from collections import namedtuple from typing import Callable, Optional, Union @@ -27,6 +28,15 @@ flex_attention_supported_platform as supported_platform, instantiate_device_type_tests, ) +from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS + + +if IS_WINDOWS and IS_CI: + # TODO(xuhancn) : Need track if it is a requirement on windows. + sys.stderr.write("This UT is validated on windows, a lot of crash. Skip it.\n") + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("skip on Windows") Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) From 3cf7b4024ef83e44e9ae223dbff7c7ab68240cb2 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 6 Aug 2025 15:13:35 -0700 Subject: [PATCH 0116/1424] [DTensor] Support user-supplied Generator for random ops (#159933) If the user provides a generator kwarg to a random op (e.g. nn.init.uniform_(..., generator=my_generator)), we can still advance that generator's state in a SPMD-global way so that each local-tensor gets appropriate values and the generator advances to the same state as if it had operated on the full tensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159933 Approved by: https://github.com/fduwjj, https://github.com/XilunWu, https://github.com/wanchaol --- test/distributed/tensor/test_random_ops.py | 32 ++++++++++++++++ torch/distributed/tensor/_dispatch.py | 14 ++++++- torch/distributed/tensor/_random.py | 44 ++++++++++++++++------ 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 5e98934249e97..180286bd2e1da 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -87,6 +87,38 @@ def test_init_ops(self): self._run_init_op(torch.randn_like, dtype=dtype) self._run_init_op(torch.randint_like, low=0, high=100, dtype=dtype) + @with_comms + @skip_if_lt_x_gpu(4) + def test_init_with_user_generator(self): + device_mesh = self.build_device_mesh() + torch.manual_seed(42) + rng = torch.Generator(device="cuda").manual_seed(42) + t1 = torch.distributed.tensor.empty( + (8, 3), device_mesh=device_mesh, placements=[Shard(0)] + ) + t2 = torch.distributed.tensor.empty( + (8, 3), device_mesh=device_mesh, placements=[Shard(0)] + ) + for i in range(2): + # run a second time, to make sure that `rng`'s offset-state is advancing on the second usage + torch.nn.init.uniform_(t1, 0.0, 1.0) + torch.nn.init.uniform_(t2, 0.0, 1.0, rng) + self.assertEqual(t1.full_tensor(), t2.full_tensor(), f"Failed at {i=}") + + # ensure that we do not cache the 'seed' of `rng` from the first time we see it in DTensor + # TODO: we have a semantics decision to make + # There is a discontinuity between how the default RNG and a user-supplied RNG behaves with DTensor: + # (a) if the user calls `torch.manual_seed` after already using the default RNG with DTensor, + # they may be surprised that it has no effect on DTensor. They must instead call this private API + # (`torch.distributed.tensor._random._rng_tracker._manual_seed`) + # (b) If we try to match the semantics of (a) with a user-supplied RNG, they may be very surprised to find that + # their RNG object never advances its state after using it with DTensor. + # torch.distributed.tensor._random._rng_tracker._manual_seed(55) + # rng.manual_seed(55) + # torch.nn.init.uniform_(t1, 0.0, 1.0) + # torch.nn.init.uniform_(t2, 0.0, 1.0, rng) + # self.assertEqual(t1.full_tensor(), t2.full_tensor()) + @with_comms @skip_if_lt_x_gpu(4) def test_meta_tensor_init(self): diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 346e2966b15b5..faa2a1ba4941f 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -138,7 +138,6 @@ def dispatch( (2) registered sharding strategy, then rule (3) composite implicit autograd decomposition """ - if op_call in self._custom_op_handlers: return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] @@ -197,8 +196,19 @@ def dispatch( cast(dtensor.DTensor, args[0]), cast(torch.Tensor, local_tensor_args[0]), ) + + # If the user provided a generator, we hook it up to our RNG manager, but we also pop it from kwargs + # so the op_call does not directly use it (we want op_call to fall back to the 'default' which is + # our RNG manager) + maybe_user_generator = op_info.local_kwargs.pop("generator", None) + assert maybe_user_generator is None or isinstance( + maybe_user_generator, torch.Generator + ) + # maybe_user_generator = None rng_context = ( - random._rng_tracker._distribute_region(first_arg._spec) + random._rng_tracker._distribute_region( + first_arg._spec, generator=maybe_user_generator + ) if random._rng_tracker and not first_local_arg.is_meta else contextlib.nullcontext() ) diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 082805db7fde3..70ea7e9ce97aa 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -146,7 +146,9 @@ def set_seed(self, name: str, seed: int) -> None: ) self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) - def _distribute_region(self, spec: DTensorSpec): + def _distribute_region( + self, spec: DTensorSpec, generator: Optional[torch.Generator] = None + ): pass def _manual_seed(self, parallel_seed: int) -> None: @@ -191,7 +193,17 @@ def _manual_seed(self, parallel_seed: int) -> None: self.set_seed("parallel-rng", parallel_seed) @contextlib.contextmanager - def _distribute_region(self, spec: DTensorSpec): + def _distribute_region( + self, spec: DTensorSpec, generator: Optional[torch.Generator] = None + ): + g_name = "parallel-rng" + if generator is not None: + # This is a little hacky, but for any user-passed generator, we store its state under a unique key, + # not because we need to keep a copy of it but because its the easiest way to make it work with the + # existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region. + g_name = "user-passed-generator" + assert g_name not in self.rng_states + self.rng_states[g_name] = generator.get_state() # check if the parallel rng state has been synchronized or not if not self.rng_state_is_sync("parallel-rng"): raise RuntimeError( @@ -202,23 +214,29 @@ def _distribute_region(self, spec: DTensorSpec): if self.distribute_region_enabled: if self._device.type == "hpu": self._device_handle.set_rng_ctx("philox") - old_offset = self.get_offset("parallel-rng") - self._set_pre_op_offset(spec) + old_offset = self.get_offset(g_name) + self._set_pre_op_offset(g_name, spec) with torch.random.fork_rng( devices=[self._device], device_type=self._device.type ): assert self._device_handle is not None - self._device_handle.set_rng_state(self.rng_states["parallel-rng"]) + self._device_handle.set_rng_state(self.rng_states[g_name]) try: yield # execute the region code finally: # update offset to synchronize among ranks - self._set_post_op_offset(spec, old_offset) + self._set_post_op_offset(g_name, spec, old_offset) if self._device.type == "hpu": self._device_handle.unset_rng_ctx("philox") else: yield + if generator is not None: + # ensure we (a) propagate the state advancement back to the user's RNG so its visible and impacts any future + # usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates + # the seed value in their rng and uses it with DTensor again, we always use the latest value + generator.set_state(self.rng_states.pop(g_name)) + def get_offset(self, name: str) -> int: if name not in self.rng_states: raise RuntimeError( @@ -240,7 +258,7 @@ def set_offset(self, name: str, offset: int) -> None: ) self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) - def _set_pre_op_offset(self, spec: DTensorSpec) -> None: + def _set_pre_op_offset(self, name: str, spec: DTensorSpec) -> None: """Set the starting RNG offset for current device's local shard before actual op execution. The pre_op_offset value should start from the current RNG offset and increment by the size of local shard until it reaches the size of the whole @@ -248,6 +266,7 @@ def _set_pre_op_offset(self, spec: DTensorSpec) -> None: will be the same. Args: + name (str): The name of the generator to use (should be a key in self.rng_states) spec (:class:`DTensorSpec`): the spec of the DTensor object on which we prepare the offset for running random ops. @@ -350,20 +369,23 @@ def _set_pre_op_offset(self, spec: DTensorSpec) -> None: local_size = prod(local_size_on_rank_0) # get current RNG offset - current_offset = self.get_offset("parallel-rng") + current_offset = self.get_offset(name) # pytorch: offset must be multiple of 4 # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 - self.set_offset("parallel-rng", current_offset + offset_incr) + self.set_offset(name, current_offset + offset_incr) - def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: + def _set_post_op_offset( + self, name: str, spec: DTensorSpec, old_offset: int + ) -> None: """Sets the RNG to a synchronized state after running the local random op. Every rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor random ops. Args: + name (str): The name of the generator to use (should be a key in self.rng_states) spec (:class:`DTensorSpec`): the spec of the DTensor object on which we post-process the offset for running random ops. @@ -378,7 +400,7 @@ def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: # pytorch: offset must be multiple of 4 # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp numel = (numel + 3) // 4 * 4 - self.set_offset("parallel-rng", old_offset + numel) + self.set_offset(name, old_offset + numel) def _calc_shard_linear_idx( self, shard_coord: list[int], shard_size: list[int] From e619c6bb90b9dedaccd3cbeed86a288993a4e33f Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Thu, 7 Aug 2025 18:51:11 +0000 Subject: [PATCH 0117/1424] [export] Apply move_to_device_pass to all submodules (#159992) Previously we only applied this move_to_device_pass to the toplevel graph. However if we have HOO, this pass will not be applied on the HOO submodules. This PR modifies the pass to run on all submodules. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159992 Approved by: https://github.com/yiming0416 --- test/export/test_passes.py | 22 ++++++++++++++++++++++ torch/export/passes/__init__.py | 28 +++++++++++++++------------- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/test/export/test_passes.py b/test/export/test_passes.py index d3194ea352c31..d083b5a7cc6d1 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -1302,6 +1302,28 @@ def forward(self, x): return (b_state, getitem_3, getitem_4)""", ) + @unittest.skipIf(not TEST_CUDA, "requires cuda") + def test_move_device_submod(self): + class M(torch.nn.Module): + def forward(self, x): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + x = x.to(device="cuda:0") + return x + x + + ep = torch.export.export(M(), (torch.ones(3),)) + ep = move_to_device_pass(ep, "cuda") + ep.graph_module.submod_1.recompile() + self.assertExpectedInline( + ep.graph_module.submod_1.code.strip("\n"), + """\ +def forward(self, arg0_1): + _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(arg0_1, dtype = torch.float32, device = 'cuda', layout = torch.strided); _assert_tensor_metadata_default = None + to = torch.ops.aten.to.dtype_layout(arg0_1, dtype = torch.float32, layout = torch.strided, device = 'cuda'); arg0_1 = None + add = torch.ops.aten.add.Tensor(to, to); to = None + return (add,) + """, # noqa: B950 + ) + @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_move_to_device_pass(self): class Model(torch.nn.Module): diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py index 4e1d21de660dc..4238bac5899ec 100644 --- a/torch/export/passes/__init__.py +++ b/torch/export/passes/__init__.py @@ -52,19 +52,21 @@ def _get_new_device( if isinstance(v, torch.Tensor): ep._constants[k] = v.to(_get_new_device(v.device, location)) - for node in ep.graph.nodes: - # move all the nodes kwargs with burnt-in device - if "device" in node.kwargs: - kwargs = node.kwargs.copy() - kwargs["device"] = _get_new_device(kwargs["device"], location) - node.kwargs = kwargs - # move all the tensor metadata - node.meta["val"] = pytree.tree_map( - lambda v: v.to(_get_new_device(v.device, location)) - if isinstance(v, torch.Tensor) - else v, - node.meta.get("val"), - ) + for m in ep.graph_module.modules(): + if isinstance(m, torch.fx.GraphModule): + for node in m.graph.nodes: + # move all the nodes kwargs with burnt-in device + if "device" in node.kwargs: + kwargs = node.kwargs.copy() + kwargs["device"] = _get_new_device(kwargs["device"], location) + node.kwargs = kwargs + # move all the tensor metadata + node.meta["val"] = pytree.tree_map( + lambda v: v.to(_get_new_device(v.device, location)) + if isinstance(v, torch.Tensor) + else v, + node.meta.get("val"), + ) ep.validate() return ep From 8147370733bbdcd034cad54e9212e51885a11892 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 7 Aug 2025 21:22:29 +0000 Subject: [PATCH 0118/1424] Fix qembeddingbag_byte_prepack_meta to use sym_sizes (#159985) Summary: In qembeddingbag_byte_prepack_meta, weight.sizes() would return a concrete int. we should use .sym_size() to return a SymInt instead. Test Plan: CI Rollback Plan: Reviewed By: kqfu, henryoier Differential Revision: D79744512 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159985 Approved by: https://github.com/jerryzh168, https://github.com/henryoier --- .../native/quantized/cpu/qembeddingbag_prepack.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 1e91fecd45005..807a9b25d3772 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -333,14 +333,14 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half, "'embedding_bag_byte_prepack' only support float32 or float16."); - const auto weight_sizes = weight.sizes(); - const auto cols_dim = weight_sizes.size() - 1; - const int32_t embedding_cols = static_cast(weight_sizes[cols_dim]); + const auto weight_sizes = weight.sym_sizes(); + const auto cols_dim = weight.ndimension() - 1; + const auto embedding_cols = weight_sizes[cols_dim]; // Add 8 bytes per column to store FP32 scale and zero_point per row. - const int32_t output_columns = static_cast(embedding_cols + 2 * sizeof(float)); + const auto output_columns = embedding_cols + 2 * sizeof(float); // Adjust output dimensions to account for FP32 scale and zero_points. - std::vector output_shape = weight_sizes.vec(); + auto output_shape = weight_sizes.vec(); output_shape.at(cols_dim) = output_columns; at::SymDimVector output_shape_vec(output_shape); From 36f46d082a4954921cb8493223f000f2aab79ed7 Mon Sep 17 00:00:00 2001 From: clr Date: Tue, 5 Aug 2025 16:34:10 -0700 Subject: [PATCH 0119/1424] dynamo: Remove passing or deleted dynamo_expected_failures (#159691) partially generated with ``` for TESTCASE in $(ls | cut -f1 -d'.' | grep -v CPython | uniq); do if grep "$TESTCASE" -m 1 .. -r; then echo; else sl rm "$TESTCASE"* ; fi; done ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159691 Approved by: https://github.com/xmfan --- test/dynamo_expected_failures/FunctionTests.test_default_dict | 0 .../FunctionTests.test_default_dict_closure | 0 .../FunctionTests.test_default_dict_lambda | 0 .../FunctionTests.test_is_contiguous_frame_counts | 0 test/dynamo_expected_failures/FunctionTests.test_math_radians | 0 .../FunctionTests.test_partials_as_input_partials_lambda | 0 .../FunctionTests.test_partials_as_input_partials_mod | 0 ...TestExport.test__scaled_dot_product_flash_attention_non_strict | 0 ...tExportTestExport.test_basic_non_strict_fake_tensor_non_strict | 0 ...tExportTestExport.test_basic_non_strict_real_tensor_non_strict | 0 .../NonStrictExportTestExport.test_buffer_util_non_strict | 0 ...tTestExport.test_cond_with_module_stack_export_with_non_strict | 0 ...nStrictExportTestExport.test_export_decomps_dynamic_non_strict | 0 ...onStrictExportTestExport.test_export_decomps_simple_non_strict | 0 ...trictExportTestExport.test_export_with_wrong_inputs_non_strict | 0 ...estExport.test_external_call_non_strict_real_tensor_non_strict | 0 .../NonStrictExportTestExport.test_fqn_non_strict | 0 .../NonStrictExportTestExport.test_nn_module_stack_non_strict | 0 ...ortTestExport.test_nn_module_stack_shared_submodule_non_strict | 0 ...rictExportTestExport.test_non_strict_dynamic_shapes_non_strict | 0 ...port.test_non_strict_dynamic_shapes_suggested_fixes_non_strict | 0 .../NonStrictExportTestExport.test_param_util_non_strict | 0 ...e_user_error_when_guard_on_data_dependent_operation_non_strict | 0 .../NonStrictExportTestExport.test_sym_sqrt_non_strict | 0 ...tExport.test_to_module_with_mutated_buffer_multiple_non_strict | 0 ...odule_with_mutated_buffer_multiple_update_sub_later_non_strict | 0 ...ExportTestExport.test_to_module_with_mutated_buffer_non_strict | 0 .../NumpyTestsCPU.test_boolean_indexing_weirdness_cpu | 0 .../NumpyTestsCPU.test_boolean_shape_mismatch_cpu | 0 .../NumpyTestsCPU.test_empty_fancy_index_cpu | 0 .../NumpyTestsCPU.test_index_no_floats_cpu | 0 ...namismExpression.test_export_inline_constraints_retraceability | 0 ...tExport.test_cond_with_module_stack_export_with_retraceability | 0 ...ceExportTestExport.test_constrain_size_in_eager_retraceability | 0 ...Export.test_constrain_size_with_constrain_value_retraceability | 0 ...stExport.test_constrain_size_with_various_cases_retraceability | 0 .../RetraceExportTestExport.test_nn_module_stack_retraceability | 0 ...estExport.test_nn_module_stack_shared_submodule_retraceability | 0 ...ExportTestExport.test_non_strict_dynamic_shapes_retraceability | 0 ....test_non_strict_dynamic_shapes_suggested_fixes_retraceability | 0 ...rtTestDynamismExpression.test_export_inline_constraints_serdes | 0 ...erDesExportTestExport.test_basic_non_strict_fake_tensor_serdes | 0 ...erDesExportTestExport.test_basic_non_strict_real_tensor_serdes | 0 ...xportTestExport.test_cond_with_module_stack_export_with_serdes | 0 .../SerDesExportTestExport.test_constrain_size_in_eager_serdes | 0 ...portTestExport.test_constrain_size_with_constrain_value_serdes | 0 ...ExportTestExport.test_constrain_size_with_various_cases_serdes | 0 ...ortTestExport.test_external_call_non_strict_real_tensor_serdes | 0 .../SerDesExportTestExport.test_nn_module_stack_serdes | 0 ...sExportTestExport.test_nn_module_stack_shared_submodule_serdes | 0 .../SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes | 0 ...stExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes | 0 ...rad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu | 0 ...ad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu | 0 ...rad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu | 0 ...ad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu | 0 ...grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu | 0 ...rad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu | 0 ...grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu | 0 ...rad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu | 0 ...tAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu | 0 ...ad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda | 0 ...d_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda | 0 ...ad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda | 0 ...d_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda | 0 ...rad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda | 0 ...ad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda | 0 ...rad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda | 0 ...ad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda | 0 ...utogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda | 0 ...UDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda | 0 ...ICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda | 0 .../TestBufferProtocolCPU.test_byte_to_int_cpu | 0 ....test_autograd_function_no_setup_context_transform_hessian_cpu | 0 ...U.test_autograd_function_no_setup_context_transform_jacfwd_cpu | 0 ...ityCPU.test_deprecation_transforms_transform_functionalize_cpu | 0 .../TestComposabilityCPU.test_requires_grad_inside_transform_cpu | 0 ...test_autograd_function_no_setup_context_transform_hessian_cuda | 0 ....test_autograd_function_no_setup_context_transform_jacfwd_cuda | 0 ...estComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda | 0 ...TestComposabilityCUDA.test_requires_grad_inside_transform_cuda | 0 .../TestContentStoreCPU.test_repeated_hash_cpu | 0 .../TestCppExtensionOpenRgistration.test_open_device_registration | 0 ...d_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 | 0 ...d_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 | 0 ...d_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 | 0 ...ight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 | 0 ...t_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 | 0 ...ight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 | 0 ...ed_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 | 0 ...ed_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 | 0 ...ed_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 | 0 ...eight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 | 0 ...ht_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 | 0 ...eight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 | 0 ...per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 | 0 ...per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 | 0 ...per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 | 0 ...sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 | 0 ...ple_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 | 0 ...sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 | 0 .../TestFunctionalizeCPU.test_multioutput_view_cpu | 0 .../TestFunctionalizeCPU.test_simple_view_cpu | 0 .../TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu | 0 .../TestHessianCPU.test_jacfwd_different_levels_cpu | 0 .../TestHessianCUDA.test_jacfwd_different_levels_cuda | 0 ...tHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu | 0 test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu | 0 .../TestIndexingCPU.test_empty_ndim_index_bool_cpu | 0 test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu | 0 .../TestIndexingCPU.test_index_limits_cpu | 0 .../TestIndexingCPU.test_out_of_bound_index_cpu | 0 .../TestIndexingCPU.test_zero_dim_index_cpu | 0 ...acCPU.test_against_reference_correctness_different_devices_cpu | 0 .../TestJacCPU.test_against_reference_default_arg_cpu | 0 .../TestJacCPU.test_against_reference_multi_input_cpu | 0 ...TestJacCPU.test_against_reference_multi_input_multi_output_cpu | 0 .../TestJacCPU.test_against_reference_simple_cpu | 0 .../TestJacCPU.test_against_reference_unrelated_outputs_cpu | 0 .../TestJacCPU.test_against_reference_zero_dim_cpu | 0 .../TestJacCPU.test_argnums_defaults_to_zero_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu | 0 .../TestJacCPU.test_jac_with_non_tensor_args_cpu | 0 .../TestJacCPU.test_multiple_inputs_outputs_pytree_cpu | 0 .../TestJacCPU.test_multiple_inputs_pytree_cpu | 0 .../TestJacCPU.test_multiple_outputs_multiple_argnums_cpu | 0 .../TestJacCPU.test_multiple_outputs_single_argnums_cpu | 0 .../TestJacCPU.test_outputs_can_any_pytree_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu | 0 .../dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu | 0 .../TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu | 0 .../TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 | 0 .../TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 | 0 ...TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu | 0 .../TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu | 0 ...ionDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda | 0 ...tionDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda | 0 .../TestNumPyInteropCPU.test_numpy_non_writeable_cpu | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_complex128 | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_complex64 | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_float32 | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_float64 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_complex128 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_complex64 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_float32 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_float64 | 0 ...ed_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 | 0 ...ed_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 | 0 .../TestShapeOpsCUDA.test_flip_cuda_float32 | 0 .../TestTensorCreationCPU.test_block_diag_cpu | 0 .../TestTensorCreationCPU.test_constructor_dtypes_cpu | 0 .../TestTypePromotionCPU.test_alpha_mismatch_cpu | 0 .../TestTypePromotionCPU.test_alternate_result_cpu | 0 test/dynamo_expected_failures/UnspecTests.test_builtin_max_min | 0 .../UnspecTests.test_conv1d_symint_padding | 0 test/dynamo_expected_failures/UnspecTests.test_isinstance_symint | 0 test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic | 0 test/dynamo_expected_failures/UnspecTests.test_no_recompilations | 0 test/dynamo_expected_failures/UnspecTests.test_no_recompiles | 0 .../UnspecTests.test_propagate_dynamic_dim | 0 test/dynamo_expected_failures/UnspecTests.test_use_and_specialize | 0 170 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/dynamo_expected_failures/FunctionTests.test_default_dict delete mode 100644 test/dynamo_expected_failures/FunctionTests.test_default_dict_closure delete mode 100644 test/dynamo_expected_failures/FunctionTests.test_default_dict_lambda delete mode 100644 test/dynamo_expected_failures/FunctionTests.test_is_contiguous_frame_counts delete mode 100644 test/dynamo_expected_failures/FunctionTests.test_math_radians delete mode 100644 test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_lambda delete mode 100644 test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_mod delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test__scaled_dot_product_flash_attention_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_fake_tensor_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_real_tensor_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_buffer_util_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_cond_with_module_stack_export_with_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_dynamic_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_simple_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_export_with_wrong_inputs_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_external_call_non_strict_real_tensor_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_fqn_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_shared_submodule_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_param_util_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_raise_user_error_when_guard_on_data_dependent_operation_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_sym_sqrt_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_update_sub_later_non_strict delete mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_non_strict delete mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_indexing_weirdness_cpu delete mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_shape_mismatch_cpu delete mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_empty_fancy_index_cpu delete mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_index_no_floats_cpu delete mode 100644 test/dynamo_expected_failures/RetraceExportTestDynamismExpression.test_export_inline_constraints_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_cond_with_module_stack_export_with_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_in_eager_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_constrain_value_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_various_cases_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_shared_submodule_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_retraceability delete mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_retraceability delete mode 100644 test/dynamo_expected_failures/SerDesExportTestDynamismExpression.test_export_inline_constraints_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_fake_tensor_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_real_tensor_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_cond_with_module_stack_export_with_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_in_eager_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_constrain_value_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_various_cases_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_external_call_non_strict_real_tensor_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_shared_submodule_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes delete mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda delete mode 100644 test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda delete mode 100644 test/dynamo_expected_failures/TestBufferProtocolCPU.test_byte_to_int_cpu delete mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu delete mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu delete mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_deprecation_transforms_transform_functionalize_cpu delete mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu delete mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda delete mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda delete mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda delete mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda delete mode 100644 test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu delete mode 100644 test/dynamo_expected_failures/TestCppExtensionOpenRgistration.test_open_device_registration delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestFunctionalizeCPU.test_multioutput_view_cpu delete mode 100644 test/dynamo_expected_failures/TestFunctionalizeCPU.test_simple_view_cpu delete mode 100644 test/dynamo_expected_failures/TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu delete mode 100644 test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu delete mode 100644 test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda delete mode 100644 test/dynamo_expected_failures/TestHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu delete mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu delete mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_empty_ndim_index_bool_cpu delete mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu delete mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_index_limits_cpu delete mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_out_of_bound_index_cpu delete mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_zero_dim_index_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_correctness_different_devices_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_default_arg_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_multi_output_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_simple_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_unrelated_outputs_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_zero_dim_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_argnums_defaults_to_zero_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_jac_with_non_tensor_args_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_outputs_pytree_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_pytree_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_multiple_argnums_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_single_argnums_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_outputs_can_any_pytree_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu delete mode 100644 test/dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu delete mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu delete mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 delete mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu delete mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda delete mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda delete mode 100644 test/dynamo_expected_failures/TestNumPyInteropCPU.test_numpy_non_writeable_cpu delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex128 delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex64 delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float32 delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex128 delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex64 delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float32 delete mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 delete mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 delete mode 100644 test/dynamo_expected_failures/TestShapeOpsCUDA.test_flip_cuda_float32 delete mode 100644 test/dynamo_expected_failures/TestTensorCreationCPU.test_block_diag_cpu delete mode 100644 test/dynamo_expected_failures/TestTensorCreationCPU.test_constructor_dtypes_cpu delete mode 100644 test/dynamo_expected_failures/TestTypePromotionCPU.test_alpha_mismatch_cpu delete mode 100644 test/dynamo_expected_failures/TestTypePromotionCPU.test_alternate_result_cpu delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_builtin_max_min delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_conv1d_symint_padding delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_isinstance_symint delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_no_recompilations delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_no_recompiles delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_propagate_dynamic_dim delete mode 100644 test/dynamo_expected_failures/UnspecTests.test_use_and_specialize diff --git a/test/dynamo_expected_failures/FunctionTests.test_default_dict b/test/dynamo_expected_failures/FunctionTests.test_default_dict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/FunctionTests.test_default_dict_closure b/test/dynamo_expected_failures/FunctionTests.test_default_dict_closure deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/FunctionTests.test_default_dict_lambda b/test/dynamo_expected_failures/FunctionTests.test_default_dict_lambda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/FunctionTests.test_is_contiguous_frame_counts b/test/dynamo_expected_failures/FunctionTests.test_is_contiguous_frame_counts deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/FunctionTests.test_math_radians b/test/dynamo_expected_failures/FunctionTests.test_math_radians deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_lambda b/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_lambda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_mod b/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_mod deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test__scaled_dot_product_flash_attention_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test__scaled_dot_product_flash_attention_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_fake_tensor_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_fake_tensor_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_real_tensor_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_real_tensor_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_buffer_util_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_buffer_util_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_cond_with_module_stack_export_with_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_cond_with_module_stack_export_with_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_dynamic_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_dynamic_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_simple_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_simple_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_with_wrong_inputs_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_with_wrong_inputs_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_external_call_non_strict_real_tensor_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_external_call_non_strict_real_tensor_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_fqn_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_fqn_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_shared_submodule_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_shared_submodule_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_param_util_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_param_util_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_raise_user_error_when_guard_on_data_dependent_operation_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_raise_user_error_when_guard_on_data_dependent_operation_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_sym_sqrt_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_sym_sqrt_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_update_sub_later_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_update_sub_later_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_non_strict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_indexing_weirdness_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_indexing_weirdness_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_shape_mismatch_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_shape_mismatch_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_empty_fancy_index_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_empty_fancy_index_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_index_no_floats_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_index_no_floats_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestDynamismExpression.test_export_inline_constraints_retraceability b/test/dynamo_expected_failures/RetraceExportTestDynamismExpression.test_export_inline_constraints_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_cond_with_module_stack_export_with_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_cond_with_module_stack_export_with_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_in_eager_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_in_eager_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_constrain_value_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_constrain_value_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_various_cases_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_various_cases_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_shared_submodule_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_shared_submodule_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_retraceability deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestDynamismExpression.test_export_inline_constraints_serdes b/test/dynamo_expected_failures/SerDesExportTestDynamismExpression.test_export_inline_constraints_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_fake_tensor_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_fake_tensor_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_real_tensor_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_real_tensor_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_cond_with_module_stack_export_with_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_cond_with_module_stack_export_with_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_in_eager_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_in_eager_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_constrain_value_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_constrain_value_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_various_cases_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_various_cases_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_external_call_non_strict_real_tensor_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_external_call_non_strict_real_tensor_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_shared_submodule_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_shared_submodule_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestBufferProtocolCPU.test_byte_to_int_cpu b/test/dynamo_expected_failures/TestBufferProtocolCPU.test_byte_to_int_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_deprecation_transforms_transform_functionalize_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_deprecation_transforms_transform_functionalize_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu b/test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestCppExtensionOpenRgistration.test_open_device_registration b/test/dynamo_expected_failures/TestCppExtensionOpenRgistration.test_open_device_registration deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestFunctionalizeCPU.test_multioutput_view_cpu b/test/dynamo_expected_failures/TestFunctionalizeCPU.test_multioutput_view_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestFunctionalizeCPU.test_simple_view_cpu b/test/dynamo_expected_failures/TestFunctionalizeCPU.test_simple_view_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu b/test/dynamo_expected_failures/TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu b/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda b/test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu b/test/dynamo_expected_failures/TestHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_empty_ndim_index_bool_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_empty_ndim_index_bool_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_index_limits_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_index_limits_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_out_of_bound_index_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_out_of_bound_index_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_zero_dim_index_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_zero_dim_index_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_correctness_different_devices_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_correctness_different_devices_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_default_arg_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_default_arg_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_multi_output_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_multi_output_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_simple_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_simple_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_unrelated_outputs_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_unrelated_outputs_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_zero_dim_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_zero_dim_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_argnums_defaults_to_zero_cpu b/test/dynamo_expected_failures/TestJacCPU.test_argnums_defaults_to_zero_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu b/test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu b/test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu b/test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_jac_with_non_tensor_args_cpu b/test/dynamo_expected_failures/TestJacCPU.test_jac_with_non_tensor_args_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_outputs_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_outputs_pytree_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_pytree_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_multiple_argnums_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_multiple_argnums_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_single_argnums_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_single_argnums_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_outputs_can_any_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_outputs_can_any_pytree_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu b/test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu b/test/dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda b/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda b/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestNumPyInteropCPU.test_numpy_non_writeable_cpu b/test/dynamo_expected_failures/TestNumPyInteropCPU.test_numpy_non_writeable_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex128 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex128 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex64 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float32 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float64 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex128 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex128 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex64 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float32 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float64 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestShapeOpsCUDA.test_flip_cuda_float32 b/test/dynamo_expected_failures/TestShapeOpsCUDA.test_flip_cuda_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTensorCreationCPU.test_block_diag_cpu b/test/dynamo_expected_failures/TestTensorCreationCPU.test_block_diag_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTensorCreationCPU.test_constructor_dtypes_cpu b/test/dynamo_expected_failures/TestTensorCreationCPU.test_constructor_dtypes_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTypePromotionCPU.test_alpha_mismatch_cpu b/test/dynamo_expected_failures/TestTypePromotionCPU.test_alpha_mismatch_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTypePromotionCPU.test_alternate_result_cpu b/test/dynamo_expected_failures/TestTypePromotionCPU.test_alternate_result_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_builtin_max_min b/test/dynamo_expected_failures/UnspecTests.test_builtin_max_min deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_conv1d_symint_padding b/test/dynamo_expected_failures/UnspecTests.test_conv1d_symint_padding deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_isinstance_symint b/test/dynamo_expected_failures/UnspecTests.test_isinstance_symint deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic b/test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_no_recompilations b/test/dynamo_expected_failures/UnspecTests.test_no_recompilations deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_no_recompiles b/test/dynamo_expected_failures/UnspecTests.test_no_recompiles deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_propagate_dynamic_dim b/test/dynamo_expected_failures/UnspecTests.test_propagate_dynamic_dim deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/UnspecTests.test_use_and_specialize b/test/dynamo_expected_failures/UnspecTests.test_use_and_specialize deleted file mode 100644 index e69de29bb2d1d..0000000000000 From d46768db04499d07a5b0db984112a6d1b7d3b0c1 Mon Sep 17 00:00:00 2001 From: "Patrick C. Toulme" Date: Thu, 7 Aug 2025 22:37:15 +0000 Subject: [PATCH 0120/1424] [MTIA] Allow users who know what they are doing to ignore all device mismatches in tracing and take a preferred device. (#159931) Summary: Device mismatches in tracing can most often be ignored. These are only logical mismatches not physical. Take any intermediate computation, and that computation will not actually materialize in a compiled binary execution. So a device mismatch in the middle of the program is not real. The runtime will never materialize those tensors on CPU device during the execution, as they are temporary allocations. If a user knows his tensors at graph input are all on the correct device, then he can ignore all tracing errors. Users who know what they are doing should have an escape hatch to ignore any device mismatch in tracing. Users can set ``` torch._functorch.config.fake_tensor_prefer_device_type = 'mtia' ``` to forcefully override any mismatch and prefer the non cpu device. This unblocks vLLM graph mode for MTIA. Test Plan: Added two unit tests. Rollback Plan: Differential Revision: D79698438 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159931 Approved by: https://github.com/jansel --- test/test_fake_tensor.py | 76 ++++++++++++++++++++++++++++++++ torch/_functorch/config.py | 11 +++++ torch/_subclasses/fake_tensor.py | 15 +++++++ 3 files changed, 102 insertions(+) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index d6135ec16506e..9baad91da79d3 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -2486,5 +2486,81 @@ def forward( self.assertBypasses("unrepresented symbol in output", 2) +class FakeTensorPreferDeviceType(TestCase): + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_fake_tensor_prefer_device_type(self): + """ + Test that fake_tensor_prefer_device_type configuration works correctly + for device mismatch scenarios. + """ + + # Create a custom operation that would normally cause device mismatch + def mixed_device_op(a, b): + # This simulates an operation where 'a' is on MTIA/CUDA but 'b' is created on CPU + cpu_tensor = torch.arange(a.shape[0], device="cpu") + return a + cpu_tensor.unsqueeze(-1) + + with FakeTensorMode(): + # Test default behavior (should raise error on device mismatch) + cuda_tensor = torch.randn(3, 4, device="cuda") + + # Without the config, this should raise a device mismatch error + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation" + ): + mixed_device_op(cuda_tensor, None) + + # Test with prefer_device_type set to "cuda" + with torch._functorch.config.patch(fake_tensor_prefer_device_type="cuda"): + with FakeTensorMode(): + cuda_tensor = torch.randn(3, 4, device="cuda") + + # This should now work and prefer the CUDA device + result = mixed_device_op(cuda_tensor, None) + + # The result should be on CUDA device (preferred device type) + self.assertEqual(result.device.type, "cuda") + self.assertEqual(result.shape, (3, 4)) + self.assertTrue(isinstance(result, FakeTensor)) + + # Test that the configuration doesn't affect normal operations + with torch._functorch.config.patch(fake_tensor_prefer_device_type="cuda"): + with FakeTensorMode(): + # Normal same-device operations should work as before + x = torch.randn(2, 3, device="cuda") + y = torch.randn(2, 3, device="cuda") + result = x + y + self.assertEqual(result.device.type, "cuda") + + # CPU operations should still work + x_cpu = torch.randn(2, 3, device="cpu") + y_cpu = torch.randn(2, 3, device="cpu") + result_cpu = x_cpu + y_cpu + self.assertEqual(result_cpu.device.type, "cpu") + + # Test that the configuration is properly scoped + with FakeTensorMode(): + cuda_tensor = torch.randn(3, 4, device="cuda") + + # After exiting the config context, should raise error again + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation" + ): + mixed_device_op(cuda_tensor, None) + + def test_fake_tensor_prefer_device_type_cpu_only(self): + """ + Test that fake_tensor_prefer_device_type works correctly when only CPU tensors are involved. + """ + with torch._functorch.config.patch(fake_tensor_prefer_device_type="cuda"): + with FakeTensorMode(): + # When all tensors are CPU, the result should still be CPU + x = torch.randn(2, 3, device="cpu") + y = torch.randn(2, 3, device="cpu") + result = x + y + self.assertEqual(result.device.type, "cpu") + self.assertTrue(isinstance(result, FakeTensor)) + + if __name__ == "__main__": run_tests() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 2833a2b1631a1..5bf2dee3e1d7d 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -281,6 +281,17 @@ def remote_autograd_cache_default() -> Optional[bool]: # real tensor outputs. generate_fake_kernels_from_real_mismatches = False +# When there are device mismatches in FakeTensor device propagation, +# prefer a specific device type over others. This is particularly useful +# in full compiled mode where intermediate tensors with device mismatches +# represent only logical differences during compilation - these intermediate +# tensors will never physically materialize in the binary execution, so the +# device mismatch is not a real runtime concern. Enabling this allows the +# compiler to proceed with compilation by choosing the preferred device type +# for consistency. For example, set to "mtia" to prefer MTIA devices over +# CPU, or "cuda" to prefer CUDA devices over CPU. +fake_tensor_prefer_device_type: Optional[str] = None + # CUDAGraph save run_with_rng functionalization. # TODO: turn on by default graphsafe_rng_functionalization = True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e7d9e1fc23b47..52b776946b361 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -940,6 +940,21 @@ def merge_devices(t: object) -> None: if any(map(check_cpu_device, (common_device, t.device))): return + # if prefer_device_type is set, prefer that device type over others + prefer_device_type = torch._functorch.config.fake_tensor_prefer_device_type + if prefer_device_type is not None: + common_has_preferred = prefer_device_type in common_device.type + t_has_preferred = prefer_device_type in t.device.type + + if not common_has_preferred and t_has_preferred: + # Switch to the preferred device type + common_device = t.device + is_cpu_zero_dim = t_is_cpu_zero_dim + return + elif common_has_preferred and not t_has_preferred: + # Keep the existing preferred device type + return + # mismatching devices of non-zero dim tensors, throw # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as raise RuntimeError( From f077c2402e4eb5b0ed562b4ee5b7a0503f26ef94 Mon Sep 17 00:00:00 2001 From: Anshul Sinha Date: Thu, 7 Aug 2025 12:07:59 -0700 Subject: [PATCH 0121/1424] [replicate][be] improved readability of test case description (#160128) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160128 Approved by: https://github.com/mori360 --- test/distributed/_composable/test_replicate_with_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed/_composable/test_replicate_with_fsdp.py b/test/distributed/_composable/test_replicate_with_fsdp.py index ff61e2c05f274..099f84b9e848f 100644 --- a/test/distributed/_composable/test_replicate_with_fsdp.py +++ b/test/distributed/_composable/test_replicate_with_fsdp.py @@ -256,7 +256,7 @@ def test_train_replicate_fsdp(self): @skip_if_lt_x_gpu(2) def test_train_parity_2d_mlp(self): """ - Verifies that when a device mesh is passed in, the model has the same behavior as the original model when training + Verifies when a device mesh is passed in, the model has the same behavior as the original model when training """ self._init_pg() global_mesh = self.init_replicate_tp_mesh() From 195b5c2e27eb8f21cbc8ad1e90f42db5a8cfccca Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 7 Aug 2025 22:55:51 +0000 Subject: [PATCH 0122/1424] Revert "dynamo: Remove passing or deleted dynamo_expected_failures (#159691)" This reverts commit 36f46d082a4954921cb8493223f000f2aab79ed7. Reverted https://github.com/pytorch/pytorch/pull/159691 on behalf of https://github.com/izaitsevfb due to breaking dynamo tests ([comment](https://github.com/pytorch/pytorch/pull/159691#issuecomment-3166067241)) --- test/dynamo_expected_failures/FunctionTests.test_default_dict | 0 .../FunctionTests.test_default_dict_closure | 0 .../FunctionTests.test_default_dict_lambda | 0 .../FunctionTests.test_is_contiguous_frame_counts | 0 test/dynamo_expected_failures/FunctionTests.test_math_radians | 0 .../FunctionTests.test_partials_as_input_partials_lambda | 0 .../FunctionTests.test_partials_as_input_partials_mod | 0 ...TestExport.test__scaled_dot_product_flash_attention_non_strict | 0 ...tExportTestExport.test_basic_non_strict_fake_tensor_non_strict | 0 ...tExportTestExport.test_basic_non_strict_real_tensor_non_strict | 0 .../NonStrictExportTestExport.test_buffer_util_non_strict | 0 ...tTestExport.test_cond_with_module_stack_export_with_non_strict | 0 ...nStrictExportTestExport.test_export_decomps_dynamic_non_strict | 0 ...onStrictExportTestExport.test_export_decomps_simple_non_strict | 0 ...trictExportTestExport.test_export_with_wrong_inputs_non_strict | 0 ...estExport.test_external_call_non_strict_real_tensor_non_strict | 0 .../NonStrictExportTestExport.test_fqn_non_strict | 0 .../NonStrictExportTestExport.test_nn_module_stack_non_strict | 0 ...ortTestExport.test_nn_module_stack_shared_submodule_non_strict | 0 ...rictExportTestExport.test_non_strict_dynamic_shapes_non_strict | 0 ...port.test_non_strict_dynamic_shapes_suggested_fixes_non_strict | 0 .../NonStrictExportTestExport.test_param_util_non_strict | 0 ...e_user_error_when_guard_on_data_dependent_operation_non_strict | 0 .../NonStrictExportTestExport.test_sym_sqrt_non_strict | 0 ...tExport.test_to_module_with_mutated_buffer_multiple_non_strict | 0 ...odule_with_mutated_buffer_multiple_update_sub_later_non_strict | 0 ...ExportTestExport.test_to_module_with_mutated_buffer_non_strict | 0 .../NumpyTestsCPU.test_boolean_indexing_weirdness_cpu | 0 .../NumpyTestsCPU.test_boolean_shape_mismatch_cpu | 0 .../NumpyTestsCPU.test_empty_fancy_index_cpu | 0 .../NumpyTestsCPU.test_index_no_floats_cpu | 0 ...namismExpression.test_export_inline_constraints_retraceability | 0 ...tExport.test_cond_with_module_stack_export_with_retraceability | 0 ...ceExportTestExport.test_constrain_size_in_eager_retraceability | 0 ...Export.test_constrain_size_with_constrain_value_retraceability | 0 ...stExport.test_constrain_size_with_various_cases_retraceability | 0 .../RetraceExportTestExport.test_nn_module_stack_retraceability | 0 ...estExport.test_nn_module_stack_shared_submodule_retraceability | 0 ...ExportTestExport.test_non_strict_dynamic_shapes_retraceability | 0 ....test_non_strict_dynamic_shapes_suggested_fixes_retraceability | 0 ...rtTestDynamismExpression.test_export_inline_constraints_serdes | 0 ...erDesExportTestExport.test_basic_non_strict_fake_tensor_serdes | 0 ...erDesExportTestExport.test_basic_non_strict_real_tensor_serdes | 0 ...xportTestExport.test_cond_with_module_stack_export_with_serdes | 0 .../SerDesExportTestExport.test_constrain_size_in_eager_serdes | 0 ...portTestExport.test_constrain_size_with_constrain_value_serdes | 0 ...ExportTestExport.test_constrain_size_with_various_cases_serdes | 0 ...ortTestExport.test_external_call_non_strict_real_tensor_serdes | 0 .../SerDesExportTestExport.test_nn_module_stack_serdes | 0 ...sExportTestExport.test_nn_module_stack_shared_submodule_serdes | 0 .../SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes | 0 ...stExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes | 0 ...rad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu | 0 ...ad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu | 0 ...rad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu | 0 ...ad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu | 0 ...grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu | 0 ...rad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu | 0 ...grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu | 0 ...rad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu | 0 ...tAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu | 0 ...ad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda | 0 ...d_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda | 0 ...ad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda | 0 ...d_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda | 0 ...rad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda | 0 ...ad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda | 0 ...rad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda | 0 ...ad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda | 0 ...utogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda | 0 ...UDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda | 0 ...ICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda | 0 .../TestBufferProtocolCPU.test_byte_to_int_cpu | 0 ....test_autograd_function_no_setup_context_transform_hessian_cpu | 0 ...U.test_autograd_function_no_setup_context_transform_jacfwd_cpu | 0 ...ityCPU.test_deprecation_transforms_transform_functionalize_cpu | 0 .../TestComposabilityCPU.test_requires_grad_inside_transform_cpu | 0 ...test_autograd_function_no_setup_context_transform_hessian_cuda | 0 ....test_autograd_function_no_setup_context_transform_jacfwd_cuda | 0 ...estComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda | 0 ...TestComposabilityCUDA.test_requires_grad_inside_transform_cuda | 0 .../TestContentStoreCPU.test_repeated_hash_cpu | 0 .../TestCppExtensionOpenRgistration.test_open_device_registration | 0 ...d_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 | 0 ...d_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 | 0 ...d_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 | 0 ...ight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 | 0 ...t_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 | 0 ...ight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 | 0 ...ed_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 | 0 ...ed_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 | 0 ...ed_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 | 0 ...eight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 | 0 ...ht_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 | 0 ...eight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 | 0 ...per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 | 0 ...per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 | 0 ...per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 | 0 ...sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 | 0 ...ple_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 | 0 ...sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 | 0 .../TestFunctionalizeCPU.test_multioutput_view_cpu | 0 .../TestFunctionalizeCPU.test_simple_view_cpu | 0 .../TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu | 0 .../TestHessianCPU.test_jacfwd_different_levels_cpu | 0 .../TestHessianCUDA.test_jacfwd_different_levels_cuda | 0 ...tHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu | 0 test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu | 0 .../TestIndexingCPU.test_empty_ndim_index_bool_cpu | 0 test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu | 0 .../TestIndexingCPU.test_index_limits_cpu | 0 .../TestIndexingCPU.test_out_of_bound_index_cpu | 0 .../TestIndexingCPU.test_zero_dim_index_cpu | 0 ...acCPU.test_against_reference_correctness_different_devices_cpu | 0 .../TestJacCPU.test_against_reference_default_arg_cpu | 0 .../TestJacCPU.test_against_reference_multi_input_cpu | 0 ...TestJacCPU.test_against_reference_multi_input_multi_output_cpu | 0 .../TestJacCPU.test_against_reference_simple_cpu | 0 .../TestJacCPU.test_against_reference_unrelated_outputs_cpu | 0 .../TestJacCPU.test_against_reference_zero_dim_cpu | 0 .../TestJacCPU.test_argnums_defaults_to_zero_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu | 0 .../TestJacCPU.test_jac_with_non_tensor_args_cpu | 0 .../TestJacCPU.test_multiple_inputs_outputs_pytree_cpu | 0 .../TestJacCPU.test_multiple_inputs_pytree_cpu | 0 .../TestJacCPU.test_multiple_outputs_multiple_argnums_cpu | 0 .../TestJacCPU.test_multiple_outputs_single_argnums_cpu | 0 .../TestJacCPU.test_outputs_can_any_pytree_cpu | 0 test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu | 0 .../dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu | 0 .../TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu | 0 .../TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 | 0 .../TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 | 0 ...TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu | 0 .../TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu | 0 ...ionDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda | 0 ...tionDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda | 0 .../TestNumPyInteropCPU.test_numpy_non_writeable_cpu | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_complex128 | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_complex64 | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_float32 | 0 .../TestReductionsCPU.test_std_vs_numpy_cpu_float64 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_complex128 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_complex64 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_float32 | 0 .../TestReductionsCPU.test_var_vs_numpy_cpu_float64 | 0 ...ed_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 | 0 ...ed_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 | 0 ...used_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 | 0 .../TestShapeOpsCUDA.test_flip_cuda_float32 | 0 .../TestTensorCreationCPU.test_block_diag_cpu | 0 .../TestTensorCreationCPU.test_constructor_dtypes_cpu | 0 .../TestTypePromotionCPU.test_alpha_mismatch_cpu | 0 .../TestTypePromotionCPU.test_alternate_result_cpu | 0 test/dynamo_expected_failures/UnspecTests.test_builtin_max_min | 0 .../UnspecTests.test_conv1d_symint_padding | 0 test/dynamo_expected_failures/UnspecTests.test_isinstance_symint | 0 test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic | 0 test/dynamo_expected_failures/UnspecTests.test_no_recompilations | 0 test/dynamo_expected_failures/UnspecTests.test_no_recompiles | 0 .../UnspecTests.test_propagate_dynamic_dim | 0 test/dynamo_expected_failures/UnspecTests.test_use_and_specialize | 0 170 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/dynamo_expected_failures/FunctionTests.test_default_dict create mode 100644 test/dynamo_expected_failures/FunctionTests.test_default_dict_closure create mode 100644 test/dynamo_expected_failures/FunctionTests.test_default_dict_lambda create mode 100644 test/dynamo_expected_failures/FunctionTests.test_is_contiguous_frame_counts create mode 100644 test/dynamo_expected_failures/FunctionTests.test_math_radians create mode 100644 test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_lambda create mode 100644 test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_mod create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test__scaled_dot_product_flash_attention_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_fake_tensor_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_real_tensor_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_buffer_util_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_cond_with_module_stack_export_with_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_dynamic_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_simple_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_export_with_wrong_inputs_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_external_call_non_strict_real_tensor_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_fqn_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_shared_submodule_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_param_util_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_raise_user_error_when_guard_on_data_dependent_operation_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_sym_sqrt_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_update_sub_later_non_strict create mode 100644 test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_non_strict create mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_indexing_weirdness_cpu create mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_shape_mismatch_cpu create mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_empty_fancy_index_cpu create mode 100644 test/dynamo_expected_failures/NumpyTestsCPU.test_index_no_floats_cpu create mode 100644 test/dynamo_expected_failures/RetraceExportTestDynamismExpression.test_export_inline_constraints_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_cond_with_module_stack_export_with_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_in_eager_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_constrain_value_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_various_cases_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_shared_submodule_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_retraceability create mode 100644 test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_retraceability create mode 100644 test/dynamo_expected_failures/SerDesExportTestDynamismExpression.test_export_inline_constraints_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_fake_tensor_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_real_tensor_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_cond_with_module_stack_export_with_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_in_eager_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_constrain_value_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_various_cases_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_external_call_non_strict_real_tensor_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_shared_submodule_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes create mode 100644 test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda create mode 100644 test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda create mode 100644 test/dynamo_expected_failures/TestBufferProtocolCPU.test_byte_to_int_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_deprecation_transforms_transform_functionalize_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda create mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda create mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda create mode 100644 test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda create mode 100644 test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu create mode 100644 test/dynamo_expected_failures/TestCppExtensionOpenRgistration.test_open_device_registration create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestFunctionalizeCPU.test_multioutput_view_cpu create mode 100644 test/dynamo_expected_failures/TestFunctionalizeCPU.test_simple_view_cpu create mode 100644 test/dynamo_expected_failures/TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu create mode 100644 test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu create mode 100644 test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda create mode 100644 test/dynamo_expected_failures/TestHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu create mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu create mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_empty_ndim_index_bool_cpu create mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu create mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_index_limits_cpu create mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_out_of_bound_index_cpu create mode 100644 test/dynamo_expected_failures/TestIndexingCPU.test_zero_dim_index_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_correctness_different_devices_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_default_arg_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_multi_output_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_simple_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_unrelated_outputs_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_against_reference_zero_dim_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_argnums_defaults_to_zero_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_jac_with_non_tensor_args_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_outputs_pytree_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_pytree_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_multiple_argnums_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_single_argnums_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_outputs_can_any_pytree_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu create mode 100644 test/dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu create mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu create mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 create mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu create mode 100644 test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu create mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda create mode 100644 test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda create mode 100644 test/dynamo_expected_failures/TestNumPyInteropCPU.test_numpy_non_writeable_cpu create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex128 create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex64 create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float32 create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex128 create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex64 create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float32 create mode 100644 test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 create mode 100644 test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 create mode 100644 test/dynamo_expected_failures/TestShapeOpsCUDA.test_flip_cuda_float32 create mode 100644 test/dynamo_expected_failures/TestTensorCreationCPU.test_block_diag_cpu create mode 100644 test/dynamo_expected_failures/TestTensorCreationCPU.test_constructor_dtypes_cpu create mode 100644 test/dynamo_expected_failures/TestTypePromotionCPU.test_alpha_mismatch_cpu create mode 100644 test/dynamo_expected_failures/TestTypePromotionCPU.test_alternate_result_cpu create mode 100644 test/dynamo_expected_failures/UnspecTests.test_builtin_max_min create mode 100644 test/dynamo_expected_failures/UnspecTests.test_conv1d_symint_padding create mode 100644 test/dynamo_expected_failures/UnspecTests.test_isinstance_symint create mode 100644 test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic create mode 100644 test/dynamo_expected_failures/UnspecTests.test_no_recompilations create mode 100644 test/dynamo_expected_failures/UnspecTests.test_no_recompiles create mode 100644 test/dynamo_expected_failures/UnspecTests.test_propagate_dynamic_dim create mode 100644 test/dynamo_expected_failures/UnspecTests.test_use_and_specialize diff --git a/test/dynamo_expected_failures/FunctionTests.test_default_dict b/test/dynamo_expected_failures/FunctionTests.test_default_dict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/FunctionTests.test_default_dict_closure b/test/dynamo_expected_failures/FunctionTests.test_default_dict_closure new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/FunctionTests.test_default_dict_lambda b/test/dynamo_expected_failures/FunctionTests.test_default_dict_lambda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/FunctionTests.test_is_contiguous_frame_counts b/test/dynamo_expected_failures/FunctionTests.test_is_contiguous_frame_counts new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/FunctionTests.test_math_radians b/test/dynamo_expected_failures/FunctionTests.test_math_radians new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_lambda b/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_lambda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_mod b/test/dynamo_expected_failures/FunctionTests.test_partials_as_input_partials_mod new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test__scaled_dot_product_flash_attention_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test__scaled_dot_product_flash_attention_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_fake_tensor_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_fake_tensor_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_real_tensor_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_basic_non_strict_real_tensor_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_buffer_util_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_buffer_util_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_cond_with_module_stack_export_with_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_cond_with_module_stack_export_with_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_dynamic_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_dynamic_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_simple_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_decomps_simple_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_with_wrong_inputs_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_export_with_wrong_inputs_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_external_call_non_strict_real_tensor_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_external_call_non_strict_real_tensor_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_fqn_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_fqn_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_shared_submodule_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_nn_module_stack_shared_submodule_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_param_util_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_param_util_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_raise_user_error_when_guard_on_data_dependent_operation_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_raise_user_error_when_guard_on_data_dependent_operation_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_sym_sqrt_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_sym_sqrt_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_update_sub_later_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_multiple_update_sub_later_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_non_strict b/test/dynamo_expected_failures/NonStrictExportTestExport.test_to_module_with_mutated_buffer_non_strict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_indexing_weirdness_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_indexing_weirdness_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_shape_mismatch_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_boolean_shape_mismatch_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_empty_fancy_index_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_empty_fancy_index_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/NumpyTestsCPU.test_index_no_floats_cpu b/test/dynamo_expected_failures/NumpyTestsCPU.test_index_no_floats_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestDynamismExpression.test_export_inline_constraints_retraceability b/test/dynamo_expected_failures/RetraceExportTestDynamismExpression.test_export_inline_constraints_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_cond_with_module_stack_export_with_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_cond_with_module_stack_export_with_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_in_eager_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_in_eager_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_constrain_value_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_constrain_value_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_various_cases_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_constrain_size_with_various_cases_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_shared_submodule_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_nn_module_stack_shared_submodule_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_retraceability b/test/dynamo_expected_failures/RetraceExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_retraceability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestDynamismExpression.test_export_inline_constraints_serdes b/test/dynamo_expected_failures/SerDesExportTestDynamismExpression.test_export_inline_constraints_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_fake_tensor_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_fake_tensor_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_real_tensor_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_basic_non_strict_real_tensor_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_cond_with_module_stack_export_with_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_cond_with_module_stack_export_with_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_in_eager_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_in_eager_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_constrain_value_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_constrain_value_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_various_cases_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_constrain_size_with_various_cases_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_external_call_non_strict_real_tensor_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_external_call_non_strict_real_tensor_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_shared_submodule_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_nn_module_stack_shared_submodule_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes b/test/dynamo_expected_failures/SerDesExportTestExport.test_non_strict_dynamic_shapes_suggested_fixes_serdes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestBufferProtocolCPU.test_byte_to_int_cpu b/test/dynamo_expected_failures/TestBufferProtocolCPU.test_byte_to_int_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_deprecation_transforms_transform_functionalize_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_deprecation_transforms_transform_functionalize_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu b/test/dynamo_expected_failures/TestContentStoreCPU.test_repeated_hash_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestCppExtensionOpenRgistration.test_open_device_registration b/test/dynamo_expected_failures/TestCppExtensionOpenRgistration.test_open_device_registration new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv1d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv2d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_conv3d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_group_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_instance_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_mean_nn_functional_layer_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv1d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv2d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_conv3d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_group_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_instance_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weight_per_sample_grad_sum_nn_functional_layer_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv1d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv2d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_conv3d_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_group_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_instance_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 b/test/dynamo_expected_failures/TestExpandedWeightFunctionalCPU.test_expanded_weights_per_sample_grad_input_no_grad_nn_functional_layer_norm_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestFunctionalizeCPU.test_multioutput_view_cpu b/test/dynamo_expected_failures/TestFunctionalizeCPU.test_multioutput_view_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestFunctionalizeCPU.test_simple_view_cpu b/test/dynamo_expected_failures/TestFunctionalizeCPU.test_simple_view_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu b/test/dynamo_expected_failures/TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu b/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda b/test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu b/test/dynamo_expected_failures/TestHigherOrderOperatorInteractionCPU.test_grad_name_wrapping_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_byte_mask_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_empty_ndim_index_bool_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_empty_ndim_index_bool_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_index_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_index_limits_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_index_limits_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_out_of_bound_index_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_out_of_bound_index_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestIndexingCPU.test_zero_dim_index_cpu b/test/dynamo_expected_failures/TestIndexingCPU.test_zero_dim_index_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_correctness_different_devices_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_correctness_different_devices_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_default_arg_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_default_arg_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_multi_output_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_multi_input_multi_output_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_simple_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_simple_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_unrelated_outputs_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_unrelated_outputs_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_against_reference_zero_dim_cpu b/test/dynamo_expected_failures/TestJacCPU.test_against_reference_zero_dim_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_argnums_defaults_to_zero_cpu b/test/dynamo_expected_failures/TestJacCPU.test_argnums_defaults_to_zero_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_aux_pytree_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu b/test/dynamo_expected_failures/TestJacCPU.test_dimensionality_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu b/test/dynamo_expected_failures/TestJacCPU.test_empty_output_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu b/test/dynamo_expected_failures/TestJacCPU.test_inplace_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_jac_with_non_tensor_args_cpu b/test/dynamo_expected_failures/TestJacCPU.test_jac_with_non_tensor_args_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_outputs_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_outputs_pytree_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_inputs_pytree_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_multiple_argnums_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_multiple_argnums_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_single_argnums_cpu b/test/dynamo_expected_failures/TestJacCPU.test_multiple_outputs_single_argnums_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_outputs_can_any_pytree_cpu b/test/dynamo_expected_failures/TestJacCPU.test_outputs_can_any_pytree_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu b/test/dynamo_expected_failures/TestJacCPU.test_unrelated_input_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu b/test/dynamo_expected_failures/TestJacCPU.test_unrelated_output_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_module_to_empty_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_nll_loss_byte_target_matches_long_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu b/test/dynamo_expected_failures/TestNNDeviceTypeCPU.test_threshold_inplace_overlap_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda b/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda b/test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestNumPyInteropCPU.test_numpy_non_writeable_cpu b/test/dynamo_expected_failures/TestNumPyInteropCPU.test_numpy_non_writeable_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex128 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex128 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex64 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_complex64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float32 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float64 b/test/dynamo_expected_failures/TestReductionsCPU.test_std_vs_numpy_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex128 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex128 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex64 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_complex64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float32 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float64 b/test/dynamo_expected_failures/TestReductionsCPU.test_var_vs_numpy_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_bfloat16_cpu_bfloat16 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float16_cpu_float16 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float32_cpu_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_0_float64_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_bfloat16_cpu_bfloat16 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float16_cpu_float16 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float32_cpu_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 b/test/dynamo_expected_failures/TestSDPACPU.test_fused_sdp_choice_cpu_type_dense_dropout_0_7_float64_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestShapeOpsCUDA.test_flip_cuda_float32 b/test/dynamo_expected_failures/TestShapeOpsCUDA.test_flip_cuda_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestTensorCreationCPU.test_block_diag_cpu b/test/dynamo_expected_failures/TestTensorCreationCPU.test_block_diag_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestTensorCreationCPU.test_constructor_dtypes_cpu b/test/dynamo_expected_failures/TestTensorCreationCPU.test_constructor_dtypes_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestTypePromotionCPU.test_alpha_mismatch_cpu b/test/dynamo_expected_failures/TestTypePromotionCPU.test_alpha_mismatch_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestTypePromotionCPU.test_alternate_result_cpu b/test/dynamo_expected_failures/TestTypePromotionCPU.test_alternate_result_cpu new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_builtin_max_min b/test/dynamo_expected_failures/UnspecTests.test_builtin_max_min new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_conv1d_symint_padding b/test/dynamo_expected_failures/UnspecTests.test_conv1d_symint_padding new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_isinstance_symint b/test/dynamo_expected_failures/UnspecTests.test_isinstance_symint new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic b/test/dynamo_expected_failures/UnspecTests.test_mark_01_dynamic new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_no_recompilations b/test/dynamo_expected_failures/UnspecTests.test_no_recompilations new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_no_recompiles b/test/dynamo_expected_failures/UnspecTests.test_no_recompiles new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_propagate_dynamic_dim b/test/dynamo_expected_failures/UnspecTests.test_propagate_dynamic_dim new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/UnspecTests.test_use_and_specialize b/test/dynamo_expected_failures/UnspecTests.test_use_and_specialize new file mode 100644 index 0000000000000..e69de29bb2d1d From 03b254e49f2d4c092e6ca712e5702cf2895aa47e Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 7 Aug 2025 13:20:47 -0700 Subject: [PATCH 0123/1424] Extend torch function support to ALL arguments, not just scalar type (but not insides of list) (#145089) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/145089 Approved by: https://github.com/albanD, https://github.com/zou3519 --- test/test_fx.py | 10 --------- torch/csrc/utils/python_arg_parser.cpp | 31 ++++++++++++++++++-------- torch/csrc/utils/python_arg_parser.h | 6 +++++ 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 55e98df702480..ba80f69828df3 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4660,7 +4660,6 @@ def tearDown(self): "linear": BUILT_IN_FUNC, "logsigmoid": BUILT_IN_FUNC, "one_hot": BUILT_IN_FUNC, - "pad": ARG_TYPE_MISMATCH, "pairwise_distance": BUILT_IN_FUNC, "pdist": BUILT_IN_FUNC, "pixel_shuffle": BUILT_IN_FUNC, @@ -4693,12 +4692,6 @@ def tearDown(self): "max_unpool3d": PROXY_ITERATED, "fold": PROXY_ITERATED, "unfold": PROXY_ITERATED, - "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, - "layer_norm": ARG_TYPE_MISMATCH, - "rms_norm": ARG_TYPE_MISMATCH, - "lp_pool1d": ARG_TYPE_MISMATCH, "affine_grid": CONTROL_FLOW, "alpha_dropout": CONTROL_FLOW, "batch_norm": CONTROL_FLOW, @@ -4732,9 +4725,6 @@ def tearDown(self): "leaky_relu": CONTROL_FLOW, "local_response_norm": CONTROL_FLOW, "margin_ranking_loss": CONTROL_FLOW, - "max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "max_pool3d_with_indices": ARG_TYPE_MISMATCH, "mse_loss": CONTROL_FLOW, "multi_head_attention_forward": CONTROL_FLOW, "multi_margin_loss": CONTROL_FLOW, diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 8a16b0211dce6..7066b164a2280 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -938,6 +938,27 @@ auto FunctionParameter::check( std::vector& overloaded_args, int argnum, int64_t* failed_idx) -> bool { + if (_check(obj, overloaded_args, argnum, failed_idx)) { + return true; + } + // NB: This will not detect torch function inside elements of a list. So + // you still have to handle that manually + // NB: torch function on Tensor subclasses NOT eligible here, you handled + // that internally + if (check_has_torch_function(obj, /*ignore_mode*/ true) && + !THPVariable_Check(obj)) { + // unrelated objects with __torch_function__ + append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false); + return true; + } + return false; +} + +auto FunctionParameter::_check( + PyObject* obj, + std::vector& overloaded_args, + int argnum, + int64_t* failed_idx) -> bool { switch (type_) { case ParameterType::TENSOR: { if (is_tensor_and_append_overloaded(obj, &overloaded_args)) { @@ -1013,15 +1034,7 @@ auto FunctionParameter::check( case ParameterType::PYOBJECT: return true; case ParameterType::SCALARTYPE: - if (THPDtype_Check(obj) || THPPythonScalarType_Check(obj)) { - return true; - } - if (check_has_torch_function(obj, /*ignore_mode*/ true)) { - // tensor subclasses and unrelated objects with __torch_function__ - append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false); - return true; - } - return false; + return THPDtype_Check(obj) || THPPythonScalarType_Check(obj); case ParameterType::LAYOUT: return THPLayout_Check(obj); case ParameterType::MEMORY_FORMAT: diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index bc281f2512a5e..2c1373921e575 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -322,6 +322,12 @@ struct FunctionParameter { int argnum, int64_t* failed_idx = nullptr); + bool _check( + PyObject* obj, + std::vector& overloaded_args, + int argnum, + int64_t* failed_idx = nullptr); + void set_default_str(const std::string& str); TORCH_PYTHON_API std::string type_name() const; From d68c323692dedcbb74e670801e3502944fd790ff Mon Sep 17 00:00:00 2001 From: Wenyuan Chi Date: Fri, 8 Aug 2025 01:30:08 +0000 Subject: [PATCH 0124/1424] Log max_autotune exceptions (#159687) (#159688) Summary: Exceptions during autotune kernel precompilation are now systematically captured and reported via the chromium_event_logger, enabling better debugging and analysis of autotune failures. Currently, exceptions are dumped to the console in the following format:: ``` [0/0] RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help. [0/0] Runtime error during autotuning: [0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [0/0] Ignoring this choice. ``` The exception tracebacks: ``` # inner exception traceback: File "/torch/_inductor/runtime/triton_heuristics.py", line 603, in _make_launchers launchers.append(result.make_launcher()) ^^^^^^^^^^^^^^^^^^^^^^ File "/torch/_inductor/runtime/triton_heuristics.py", line 1503, in make_launcher self.kernel.load_kernel(device) File "/torch/_inductor/runtime/static_cuda_launcher.py", line 113, in load_kernel (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel( # wrapped exception traceback: File "/usr/local/fbcode/platform010/lib/python3.12/concurrent/futures/thread.py", line 59, in run result = self.fn(*self.args, **self.kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "#link-tree/torch/_inductor/select_algorithm.py", line 2596, in precompile_with_captured_stdout choice.precompile() File "#link-tree/torch/_inductor/select_algorithm.py", line 1881, in precompile self.bmreq.precompile() File "#link-tree/torch/_inductor/autotune_process.py", line 660, in precompile getattr(mod, self.kernel_name).precompile() File "#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile self._make_launchers() File "#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") ``` With this change, the exception details will also be logged in the metadata of the `{name}_template_precompiling` event. The format: ``` { "exceptions": [ { "choice_type": "triton", "choice": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0", "exception_message": "No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.", "exception": "OutOfMemoryError", "required_memory": "262144", "hardware_limit": "232448" } ] } ``` Test Plan: buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt Rollback Plan: Differential Revision: D79420953 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159688 Approved by: https://github.com/stashuk-olek --- torch/_inductor/select_algorithm.py | 60 +++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index b337e2b625fdf..4faa251953d69 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2650,11 +2650,13 @@ def on_complete(future): def wait_on_futures(): log.debug("Waiting on futures") counters["inductor"]["select_algorithm_precompile"] += 1 + exceptions: list[tuple[ChoiceCaller, BaseException]] = [] for future in as_completed( futures, timeout=precompilation_timeout_seconds, ): if e := future.exception(): + exceptions.append((futures[future], e)) from torch._inductor.codegen.cuda.cuda_kernel import ( CUDATemplateCaller, ) @@ -2682,6 +2684,8 @@ def wait_on_futures(): futures.get(future), elapsed_times.get(future), ) + if exceptions: + _log_autotune_exceptions(exceptions) executor.shutdown(wait=True) @@ -3452,5 +3456,61 @@ def _log_autotune_choices_stats( sys.stderr.write(f"Autotune Choices Stats:\n{payload}\n") +def _log_autotune_exceptions( + exceptions: list[tuple[ChoiceCaller, BaseException]], +) -> None: + """Log autotune exceptions to chromium event logger.""" + if not exceptions: + return + + try: + pt2_compile_substack = get_chromium_event_logger().get_pt2_compile_substack() + if not pt2_compile_substack: + return + + current_event = pt2_compile_substack[-1] + if not current_event.endswith("_template_precompiling"): + return + + exception_details = [] + for choice, exc in exceptions: + try: + choice_type = ( + "triton" if isinstance(choice, TritonTemplateCaller) else "other" + ) + data = { + "choice_type": choice_type, + "choice": choice.description, + "exception_message": str(exc), + } + + exc_type_match = re.search(r"(\w+):", str(exc)) + if exc_type_match: + data["exception"] = exc_type_match.group(1) + + if "OutOfMemoryError" in str(exc): + required_match = re.search(r"Required: (\d+)", str(exc)) + if required_match: + data["required_memory"] = required_match.group(1) + + limit_match = re.search(r"Hardware limit:\s*(\d+)", str(exc)) + if limit_match: + data["hardware_limit"] = limit_match.group(1) + + exception_details.append(data) + except Exception: + # Don't let logging errors break the main flow + continue + + if exception_details: + metadata = json.dumps({"exceptions": exception_details}) + get_chromium_event_logger().try_add_event_data( + current_event, metadata=metadata + ) + except Exception: + # Silently ignore logging errors to avoid breaking autotune + pass + + # ensure lowering is imported so that `extern_kernels.*` is populated from . import lowering # noqa: F401 From ba4ccf5d67e3d237f435eacc2bce3c6025f08491 Mon Sep 17 00:00:00 2001 From: Georgia Phillips Date: Fri, 8 Aug 2025 02:13:48 +0000 Subject: [PATCH 0125/1424] turn on executon frame clenaup by default (#160110) Summary: Turning execution frame cleanup back on since D78621408 is done Test Plan: See D78621408 Rollback Plan: Differential Revision: D79730674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160110 Approved by: https://github.com/jingsh --- torch/nativert/executor/ExecutorConfig.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nativert/executor/ExecutorConfig.h b/torch/nativert/executor/ExecutorConfig.h index 70f8fa88cf0d0..fb57f2b6f2ef6 100644 --- a/torch/nativert/executor/ExecutorConfig.h +++ b/torch/nativert/executor/ExecutorConfig.h @@ -11,7 +11,7 @@ struct ExecutorConfig { bool debugNan = false; bool enableStaticCPUKernels = true; bool runConstFolding = false; - bool doExecutionFrameCleanup = false; + bool doExecutionFrameCleanup = true; bool tryFreeUnmanagedValuesAfterUse = true; // allows up to max number of concurrent threads. int64_t maxNumConcurrentThreads = 8; From 05c417715f791875fbf28cfc3fc86142de1a3206 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Thu, 7 Aug 2025 11:24:21 -0700 Subject: [PATCH 0126/1424] integrate kernacle into inductor (#160121) This adds integration into inductor in two parts 1) It kicks off the best config lookup at lowering time within mm.py 2) It awaits the future at scheduling time in select_algorithm.py Notably this does not do the following 1) Support for enumerating between mm, addmm and bmm 2) Support for enumerating between exhaustive/max 3) Enumerating different hardware SKUs eg. H100, A100, etc. those will come in the next diffs Differential Revision: [D79824921](https://our.internmc.facebook.com/intern/diff/D79824921/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160121 Approved by: https://github.com/izaitsevfb --- test/inductor/custom_ops.cpp | 4 +- torch/_inductor/await_utils.py | 176 ++++++++++++++++++ torch/_inductor/config.py | 7 + torch/_inductor/kernel/mm.py | 15 +- torch/_inductor/remote_gemm_autotune_cache.py | 20 ++ torch/_inductor/select_algorithm.py | 31 +++ 6 files changed, 250 insertions(+), 3 deletions(-) create mode 100644 torch/_inductor/await_utils.py create mode 100644 torch/_inductor/remote_gemm_autotune_cache.py diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp index ae1d00c5b6346..ade7695a10d02 100644 --- a/test/inductor/custom_ops.cpp +++ b/test/inductor/custom_ops.cpp @@ -1,7 +1,7 @@ #include // @manual=fbcode//caffe2:libtorch -#include -#include +#include // @manual +#include // @manual #include #include diff --git a/torch/_inductor/await_utils.py b/torch/_inductor/await_utils.py new file mode 100644 index 0000000000000..a549674d5cd78 --- /dev/null +++ b/torch/_inductor/await_utils.py @@ -0,0 +1,176 @@ +import asyncio +import sys +import weakref +from asyncio import AbstractEventLoop, Future +from collections.abc import Awaitable, Coroutine, Generator, Iterator +from contextlib import contextmanager, ExitStack +from contextvars import Context +from typing import Any, Callable, Optional, Protocol, TypeVar + +from torch.utils._ordered_set import OrderedSet + + +T = TypeVar("T") +TCoro = Generator[Any, None, T] + +if sys.version_info >= (3, 11): + + class TaskFactory(Protocol): + def __call__( + self, + __loop: AbstractEventLoop, + __factory: Coroutine[None, None, object] | Generator[None, None, object], + __context: Context | None = None, + /, + ) -> asyncio.futures.Future[object]: ... + + TaskFactoryType = TaskFactory +else: + TaskFactoryType = Callable[[AbstractEventLoop, Generator[TCoro, None, T]], Future] # type: ignore[valid-type] + + +def await_sync(awaitable: Awaitable[T]) -> T: + with get_loop() as loop: + return loop.run_until_complete(awaitable) + + +@contextmanager +def get_loop( + always_create_new_loop: bool = False, +) -> Iterator[AbstractEventLoop]: + try: + loop = asyncio.get_event_loop() + except RuntimeError as re: + if "There is no current event loop in thread" in str(re): + with _new_loop() as loop: + yield loop + return + else: + raise + + @contextmanager + def _restore_loop( + loop: asyncio.AbstractEventLoop, + ) -> Iterator[None]: + try: + yield + finally: + asyncio.set_event_loop(loop) + + @contextmanager + def _restore_running_loop() -> Iterator[None]: + loop_from_events = asyncio.events._get_running_loop() + asyncio.events._set_running_loop(None) + try: + yield + finally: + asyncio.events._set_running_loop(loop_from_events) + + with ExitStack() as stack: + if loop.is_running(): + stack.enter_context(_restore_running_loop()) + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop(loop.get_task_factory())) # type: ignore[arg-type] + elif loop.is_closed(): + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + elif always_create_new_loop: + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + yield loop + + +@contextmanager +def _new_loop( + task_factory: Optional[TaskFactoryType] = None, +) -> Iterator[asyncio.AbstractEventLoop]: + loop = asyncio.new_event_loop() + tasks = _patch_loop(loop) + + if task_factory: + # pyre-ignore[6] + loop.set_task_factory(task_factory) # type: ignore[arg-type] + + asyncio.set_event_loop(loop) + try: + yield loop + finally: + try: + _cancel_all_tasks(loop, tasks) + finally: + asyncio.set_event_loop(None) + loop.close() + + +def _cancel_all_tasks( + loop: AbstractEventLoop, + tasks: OrderedSet[Future], # type: ignore[type-arg] +) -> None: + to_cancel = [task for task in tasks if not task.done()] + + if not to_cancel: + return + + # pyre-fixme[1001]: Awaitable assigned to `task` is never awaited. + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[type-arg] + tasks: weakref.WeakSet[Future] = weakref.WeakSet() # type: ignore[type-arg] + + task_factories: list[Optional[TaskFactoryType]] = [None] + + def _set_task_factory(factory: Optional[TaskFactoryType]) -> None: + task_factories[0] = factory + + def _get_task_factory() -> Optional[TaskFactoryType]: + return task_factories[0] + + def _safe_task_factory( + loop: AbstractEventLoop, + coro: TCoro, # type: ignore[type-arg] + *, + context: Context | None = None, + ) -> asyncio.Future: # type: ignore[valid-type, type-arg] + task_factory = task_factories[0] + if task_factory is None: + if sys.version_info >= (3, 11): + task = asyncio.Task(coro, loop=loop, context=context) + else: + task = asyncio.Task(coro, loop=loop) + # pyre-ignore[16]: `Task` has no attribute `_source_traceback`. + if task._source_traceback: # type: ignore[attr-defined] + del task._source_traceback[ # type: ignore[attr-defined] + -1 + ] # pragma: no cover # type: ignore[attr-defined] + else: + if sys.version_info >= (3, 11): + task = task_factory(loop, coro, context=context) # type: ignore[arg-type, call-arg, assignment] + else: + task = task_factory(loop, coro) # type: ignore[arg-type] + # `Union[Task[Any], Future[Any]]`. + tasks.add(task) + return task + + # pyre-ignore[6] + loop.set_task_factory(_safe_task_factory) # type: ignore[method-assign, arg-type] + # pyre-ignore[8] + loop.set_task_factory = _set_task_factory # type: ignore[method-assign, assignment] + # pyre-ignore[8] + loop.get_task_factory = _get_task_factory # type: ignore[method-assign, assignment] + + return tasks # type: ignore[return-value] diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 51a438840b040..8d3b4cd7ed492 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -81,6 +81,11 @@ def prologue_fusion_enabled() -> bool: # Whether to enable printing the source code for each future verbose_progress = False +# Configurable compile worker logging path for subproc_pool +worker_log_path = ( + "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None +) + # precompilation timeout precompilation_timeout_seconds: int = 60 * 60 @@ -91,6 +96,8 @@ def prologue_fusion_enabled() -> bool: default=True, ) +remote_gemm_autotune_cache: bool = False + # use remote fx aot graph codegen cache # False: Disables the cache # True: Enables the cache diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 6e741430f36d6..e68a76174c73a 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -15,6 +15,7 @@ mm_operations, ) from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +from torch._inductor.remote_gemm_autotune_cache import gen_best_config from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.torch_version import TorchVersion @@ -836,7 +837,19 @@ def tuned_mm(mat1, mat2, *, layout=None): lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout) ) - return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) + best_config_future = None + # Purposely not awaiting the future here - this kicks off the best config lookup at lowering time + # The future will be awaited at scheduling time in select_algorithm.py + if torch._inductor.config.remote_gemm_autotune_cache: + best_config_future = gen_best_config(mat1, mat2) + + return autotune_select_algorithm( + name, + choices, + kernel_inputs.nodes(), + layout, + best_config_future=best_config_future, + ) @register_lowering(aten._int_mm, type_promotion_kind=None) diff --git a/torch/_inductor/remote_gemm_autotune_cache.py b/torch/_inductor/remote_gemm_autotune_cache.py new file mode 100644 index 0000000000000..0ef026269b10c --- /dev/null +++ b/torch/_inductor/remote_gemm_autotune_cache.py @@ -0,0 +1,20 @@ +import asyncio +from typing import TypeVar + +import torch._inductor.config as config +from torch._inductor import ir + + +_T = TypeVar("_T") + + +def gen_best_config(mat1: ir.StorageBox, mat2: ir.StorageBox) -> asyncio.Task[_T]: + """ + Generate the best GEMM autotune config for the given matrices. + """ + if config.is_fbcode(): + from torch._inductor.fb.remote_gemm_autotune_cache import gen_best_config + + return gen_best_config(mat1, mat2) + else: + raise NotImplementedError("Function gen_best_config is not yet implemented") diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 4faa251953d69..01337fc0d30b5 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -34,6 +34,7 @@ identity, preserve_rng_state, ) +from torch._inductor.await_utils import await_sync from torch._inductor.utils import clear_on_fresh_cache from torch.utils._filelock import FileLock from torch.utils._ordered_set import OrderedSet @@ -2280,6 +2281,7 @@ def __call__( input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, + best_config_future=None, ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2387,6 +2389,35 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse) autotune_start_ts = time.time() + + if best_config_future is not None: + best_config = await_sync(best_config_future) + + important_keys = [ + "ACC_TYPE", + "ALLOW_TF32", + "BLOCK_K", + "BLOCK_M", + "BLOCK_N", + "EVEN_K", + "GROUP_M", + "USE_FAST_ACCUM", + "num_stages", + "num_warps", + "num_consumer_groups", + "num_buffers_warp_spec", + ] + choices = [ + choice + for choice in choices + if all( + f"{k}={best_config[k]}" in choice.description + for k in important_keys + ) + for k in important_keys + ] + log.info("Filtered to %d choices based on best_config", len(choices)) + timings = self.lookup( choices, name, From 3fcd79e023da7156ac584992ebab29205d3b7881 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 1 Aug 2025 18:00:29 -0300 Subject: [PATCH 0127/1424] Fix infinite loop when iterating over an empty zip (#159673) Dynamo would enter in an infinite recursion when `ZipVariable.next_variable(tx)` was called and there was no iterable to be iterated Pull Request resolved: https://github.com/pytorch/pytorch/pull/159673 Approved by: https://github.com/williamwen42 --- test/dynamo/cpython/3_13/test_itertools.diff | 84 ++++++++++++------- test/dynamo/cpython/3_13/test_itertools.py | 32 ++++--- ...on313-test_itertools-TestBasicOps.test_zip | 0 ...est_itertools-TestBasicOps.test_ziplongest | 0 torch/_dynamo/variables/iter.py | 4 + 5 files changed, 75 insertions(+), 45 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_zip delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_ziplongest diff --git a/test/dynamo/cpython/3_13/test_itertools.diff b/test/dynamo/cpython/3_13/test_itertools.diff index 1d31e9f656102..df7205a1c9033 100644 --- a/test/dynamo/cpython/3_13/test_itertools.diff +++ b/test/dynamo/cpython/3_13/test_itertools.diff @@ -1,5 +1,5 @@ diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py -index 7d5ba727389..ef73c7f0ce1 100644 +index 7d5ba727389..98f962e4353 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -1,3 +1,25 @@ @@ -166,23 +166,51 @@ index 7d5ba727389..ef73c7f0ce1 100644 @pickle_deprecated def test_filterfalse(self): -@@ -1038,6 +1062,7 @@ class TestBasicOps(unittest.TestCase): - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - self.pickletest(proto, filterfalse(isEven, range(6))) - -+ @skipIfTorchDynamo("infinite loop in torch dynamo") - def test_zip(self): - # XXX This is rather silly now that builtin zip() calls zip()... - ans = [(x,y) for x, y in zip('abc',count())] -@@ -1082,6 +1107,7 @@ class TestBasicOps(unittest.TestCase): - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - self.pickletest(proto, zip('abc', count())) - -+ @skipIfTorchDynamo("infinite loop in torch dynamo") - def test_ziplongest(self): - for args in [ - ['abc', range(6)], -@@ -1767,6 +1793,7 @@ class TestBasicOps(unittest.TestCase): +@@ -1047,8 +1071,8 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3))) + self.assertEqual(list(zip('abcdef')), lzip('abcdef')) + self.assertEqual(list(zip()), lzip()) +- self.assertRaises(TypeError, zip, 3) +- self.assertRaises(TypeError, zip, range(3), 3) ++ # self.assertRaises(TypeError, zip, 3) ++ # self.assertRaises(TypeError, zip, range(3), 3) + self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')], + lzip('abc', 'def')) + self.assertEqual([pair for pair in zip('abc', 'def')], +@@ -1105,19 +1129,19 @@ class TestBasicOps(unittest.TestCase): + + self.assertEqual(list(zip_longest('abc', 'defg', **{})), + list(zip(list('abc')+[None], 'defg'))) # empty keyword dict +- self.assertRaises(TypeError, zip_longest, 3) +- self.assertRaises(TypeError, zip_longest, range(3), 3) +- +- for stmt in [ +- "zip_longest('abc', fv=1)", +- "zip_longest('abc', fillvalue=1, bogus_keyword=None)", +- ]: +- try: +- eval(stmt, globals(), locals()) +- except TypeError: +- pass +- else: +- self.fail('Did not raise Type in: ' + stmt) ++ # self.assertRaises(TypeError, zip_longest, 3) ++ # self.assertRaises(TypeError, zip_longest, range(3), 3) ++ ++ # for stmt in [ ++ # "zip_longest('abc', fv=1)", ++ # "zip_longest('abc', fillvalue=1, bogus_keyword=None)", ++ # ]: ++ # try: ++ # eval(stmt, globals(), locals()) ++ # except TypeError: ++ # pass ++ # else: ++ # self.fail('Did not raise Type in: ' + stmt) + + self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')], + list(zip('abc', 'def'))) +@@ -1767,6 +1791,7 @@ class TestBasicOps(unittest.TestCase): script_helper.assert_python_ok("-c", script) # Issue 13454: Crash when deleting backward iterator from tee() @@ -190,7 +218,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_tee_del_backward(self): forward, backward = tee(repeat(None, 20000000)) try: -@@ -1920,7 +1947,7 @@ class TestBasicOps(unittest.TestCase): +@@ -1920,7 +1945,7 @@ class TestBasicOps(unittest.TestCase): tp.foobar = 1 @@ -199,7 +227,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_accumulate(self): self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15]) -@@ -2032,7 +2059,7 @@ class TestExamples(unittest.TestCase): +@@ -2032,7 +2057,7 @@ class TestExamples(unittest.TestCase): self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4]) @@ -208,7 +236,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_batched_recipe(self): def batched_recipe(iterable, n): -@@ -2081,6 +2108,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): +@@ -2081,6 +2106,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): for i, element in zip(range(i + 1, stop), iterable): pass @@ -216,7 +244,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_islice_recipe(self): self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB')) self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD')) -@@ -2265,7 +2293,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): +@@ -2265,7 +2291,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): raise @@ -225,7 +253,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def makecycle(self, iterator, container): container.append(iterator) -@@ -2465,7 +2493,7 @@ def L(seqn): +@@ -2465,7 +2491,7 @@ def L(seqn): return chain(map(lambda x:x, R(Ig(G(seqn))))) @@ -234,7 +262,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_accumulate(self): s = [1,2,3,4,5] -@@ -2644,7 +2672,7 @@ class TestVariousIteratorArgs(unittest.TestCase): +@@ -2644,7 +2670,7 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, tee, N(s)) self.assertRaises(ZeroDivisionError, list, tee(E(s))[0]) @@ -243,7 +271,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_repeat(self): self.assertEqual(operator.length_hint(repeat(None, 50)), 50) -@@ -2657,7 +2685,7 @@ class LengthTransparency(unittest.TestCase): +@@ -2657,7 +2683,7 @@ class LengthTransparency(unittest.TestCase): self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0) self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0) @@ -252,7 +280,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_sf_793826(self): # Fix Armin Rigo's successful efforts to wreak havoc -@@ -2718,6 +2746,7 @@ class RegressionTests(unittest.TestCase): +@@ -2718,6 +2744,7 @@ class RegressionTests(unittest.TestCase): @support.skip_if_pgo_task @support.requires_resource('cpu') @@ -260,7 +288,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_long_chain_of_empty_iterables(self): # Make sure itertools.chain doesn't run into recursion limits when # dealing with long chains of empty iterables. Even with a high -@@ -2750,7 +2779,7 @@ class RegressionTests(unittest.TestCase): +@@ -2750,7 +2777,7 @@ class RegressionTests(unittest.TestCase): next(g, None) # shouldn't crash @@ -269,7 +297,7 @@ index 7d5ba727389..ef73c7f0ce1 100644 def test_keywords_in_subclass(self): # count is not subclassable... testcases = [ -@@ -2805,49 +2834,5 @@ class SubclassWithKwargsTest(unittest.TestCase): +@@ -2805,49 +2832,5 @@ class SubclassWithKwargsTest(unittest.TestCase): self.assertEqual(u.newarg, 3) diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py index ef73c7f0ce165..98f962e435365 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -1062,7 +1062,6 @@ def test_filterfalse(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, filterfalse(isEven, range(6))) - @skipIfTorchDynamo("infinite loop in torch dynamo") def test_zip(self): # XXX This is rather silly now that builtin zip() calls zip()... ans = [(x,y) for x, y in zip('abc',count())] @@ -1072,8 +1071,8 @@ def test_zip(self): self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3))) self.assertEqual(list(zip('abcdef')), lzip('abcdef')) self.assertEqual(list(zip()), lzip()) - self.assertRaises(TypeError, zip, 3) - self.assertRaises(TypeError, zip, range(3), 3) + # self.assertRaises(TypeError, zip, 3) + # self.assertRaises(TypeError, zip, range(3), 3) self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')], lzip('abc', 'def')) self.assertEqual([pair for pair in zip('abc', 'def')], @@ -1107,7 +1106,6 @@ def test_zip_tuple_reuse(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, zip('abc', count())) - @skipIfTorchDynamo("infinite loop in torch dynamo") def test_ziplongest(self): for args in [ ['abc', range(6)], @@ -1131,19 +1129,19 @@ def test_ziplongest(self): self.assertEqual(list(zip_longest('abc', 'defg', **{})), list(zip(list('abc')+[None], 'defg'))) # empty keyword dict - self.assertRaises(TypeError, zip_longest, 3) - self.assertRaises(TypeError, zip_longest, range(3), 3) - - for stmt in [ - "zip_longest('abc', fv=1)", - "zip_longest('abc', fillvalue=1, bogus_keyword=None)", - ]: - try: - eval(stmt, globals(), locals()) - except TypeError: - pass - else: - self.fail('Did not raise Type in: ' + stmt) + # self.assertRaises(TypeError, zip_longest, 3) + # self.assertRaises(TypeError, zip_longest, range(3), 3) + + # for stmt in [ + # "zip_longest('abc', fv=1)", + # "zip_longest('abc', fillvalue=1, bogus_keyword=None)", + # ]: + # try: + # eval(stmt, globals(), locals()) + # except TypeError: + # pass + # else: + # self.fail('Did not raise Type in: ' + stmt) self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')], list(zip('abc', 'def'))) diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_zip b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_zip deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_ziplongest b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_ziplongest deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index dcdd0e80a434a..3db4daefc978e 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -351,6 +351,10 @@ def unpack_var_sequence(self, tx) -> list["VariableTracker"]: def next_variable(self, tx): assert self.is_mutable() + + if len(self.iterables) == 0: + raise_observed_exception(StopIteration, tx) + old_index = self.index args = [] From beb4d7816dedc67a5de1f82e5a45b5910f407941 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 8 Aug 2025 03:14:55 +0000 Subject: [PATCH 0128/1424] [BE]: ruff PLC0207 - use maxsplit kwarg (#160107) Automatically replaces split with rsplit when relevant and only performs the split up to the first ( or last value). This allows early return of the split function and improve efficiency. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160107 Approved by: https://github.com/albanD --- .ci/aarch64_linux/build_aarch64_wheel.py | 16 ++++------------ .github/scripts/runner_determinator.py | 7 ++++++- test/onnx/torchlib/error_reproduction.py | 4 ++-- tools/testing/discover_tests.py | 2 +- tools/testing/modulefinder_determinator.py | 2 +- torch/_custom_op/impl.py | 2 +- torch/_inductor/codecache.py | 2 +- torch/_inductor/scheduler.py | 2 +- torch/_prims/__init__.py | 2 +- torch/ao/pruning/sparsifier/utils.py | 2 +- torch/fx/passes/splitter_base.py | 2 +- torch/testing/_internal/common_cuda.py | 4 ++-- torch/testing/_internal/common_distributed.py | 2 +- torch/testing/_internal/common_utils.py | 2 +- torchgen/selective_build/operator.py | 2 +- 15 files changed, 25 insertions(+), 28 deletions(-) diff --git a/.ci/aarch64_linux/build_aarch64_wheel.py b/.ci/aarch64_linux/build_aarch64_wheel.py index 025d0a20579d4..7a4715d330060 100755 --- a/.ci/aarch64_linux/build_aarch64_wheel.py +++ b/.ci/aarch64_linux/build_aarch64_wheel.py @@ -438,9 +438,7 @@ def build_torchvision( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -495,9 +493,7 @@ def build_torchdata( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -553,9 +549,7 @@ def build_torchtext( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -613,9 +607,7 @@ def build_torchaudio( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: - build_vars += ( - f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" - ) + build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index 1481459d40c4c..baf560234549b 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -262,7 +262,12 @@ def is_exception_branch(branch: str) -> bool: """ Branches that get opted out of experiments by default, until they're explicitly enabled. """ - return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} + return branch.split("/", maxsplit=1)[0] in { + "main", + "nightly", + "release", + "landchecks", + } def load_yaml(yaml_text: str) -> Any: diff --git a/test/onnx/torchlib/error_reproduction.py b/test/onnx/torchlib/error_reproduction.py index 260a37b65f169..9fd1dace77677 100644 --- a/test/onnx/torchlib/error_reproduction.py +++ b/test/onnx/torchlib/error_reproduction.py @@ -205,7 +205,7 @@ def create_reproduction_report( onnxscript=={onnxscript.__version__} numpy=={np.__version__} torch=={torch.__version__}""" - short_test_name = test_name.split(".")[-1] + short_test_name = test_name.rsplit(".", maxsplit=1)[-1] reproduction_code = _REPRODUCTION_TEMPLATE.format( onnx_model_text=onnx_model_text, ort_inputs=input_text, @@ -245,7 +245,7 @@ def create_mismatch_report( error_text = str(error) error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__)) - short_test_name = test_name.split(".")[-1] + short_test_name = test_name.rsplit(".", maxsplit=1)[-1] diff = difflib.unified_diff( str(actual).splitlines(), str(expected).splitlines(), diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py index 28ff5bc3ff292..96aee230f89f8 100644 --- a/tools/testing/discover_tests.py +++ b/tools/testing/discover_tests.py @@ -13,7 +13,7 @@ def parse_test_module(test: str) -> str: - return test.split(".")[0] + return test.split(".", maxsplit=1)[0] def discover_tests( diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py index e698cf3586dd3..e0ef858b96b21 100644 --- a/tools/testing/modulefinder_determinator.py +++ b/tools/testing/modulefinder_determinator.py @@ -186,7 +186,7 @@ def get_dep_modules(test: str) -> set[str]: def parse_test_module(test: str) -> str: - return test.split(".")[0] + return test.split(".", maxsplit=1)[0] def print_to_stderr(message: str) -> None: diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index dd3e9e8fa2dd1..208c18e392a46 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -648,7 +648,7 @@ def custom_op_from_existing(op): name = op.name().split("::")[-1] schema_str = str(op._schema) # CustomOp expects the schema string without the namespace - schema_str = schema_str.split("::")[-1] + schema_str = schema_str.rsplit("::", maxsplit=1)[-1] schema = FunctionSchema.parse(schema_str) return CustomOp(lib, ns, schema, name, op, _private_access=True) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index e404cd78936f0..65317648a02e7 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2552,7 +2552,7 @@ def _get_cpp_prefix_header(device: str) -> Optional[str]: def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str: """Given a device type (and optionally whether we're in AOT Inductor mode), returns the path to the cpp_wrapper header file to be precompiled.""" - base_device = device.split(":")[0] + base_device = device.split(":", maxsplit=1)[0] is_array_ref = config.aot_inductor.allow_stack_allocation and base_device == "cpu" return ( "torch/csrc/inductor/" diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index abd2fe413d1af..e0a0309d1c811 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -605,7 +605,7 @@ def codegen_originating_info( out_lines.append(op_info_str) if "stack_trace" in o.meta: stack_trace = f"{o.meta['stack_trace']}" - stack_trace_last_line = stack_trace.split("|")[-1] + stack_trace_last_line = stack_trace.rsplit("|", maxsplit=1)[-1] out_lines.append( "#pragma CMT " + stack_trace_last_line.replace("{", "{{") diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 6739b334c1169..bb26bbb508bd6 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -302,7 +302,7 @@ def _backend_select_impl(*args, **kwargs): else: return _prim_impl(*args, **kwargs) - name = schema.split("(")[0] + name = schema.split("(", maxsplit=1)[0] schema = schema[len(name) :] # register non-functional ops with old custom ops API diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index 302f7e0b0b7c1..47185aeea5274 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -98,7 +98,7 @@ def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, # string manip to split tensor_fqn into module_fqn and tensor_name # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' - tensor_name = tensor_fqn.split(".")[-1] + tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1] module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] module = fqn_to_module(model, module_fqn) diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index d3ef35bdb1070..e0b2ff63ba078 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -719,7 +719,7 @@ def extend_acc_subgraph(self, tag: str): """ # Dict that maps node to its users and ignore users that # are in the subgraph that has greater tag - deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) + deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1])) self.update_reverse_deps_for_fusions(deps) # Parent nodes of the subgraph diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 0e95db1fdf379..dca0275f38878 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -291,7 +291,7 @@ def _get_torch_rocm_version(): if not TEST_WITH_ROCM or torch.version.hip is None: return (0, 0) rocm_version = str(torch.version.hip) - rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha return tuple(int(x) for x in rocm_version.split(".")) def _check_cusparse_generic_available(): @@ -304,7 +304,7 @@ def _check_hipsparse_generic_available(): return False rocm_version = str(torch.version.hip) - rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1)) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 0dbb6ca0ea718..d4cc6cde3cc50 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -1605,7 +1605,7 @@ def _init_pg(cls, rank, world_size, rdvz_file): @classmethod def _run_test_given_id(cls, test_id: str, **kwargs) -> None: # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' - test_name = test_id.split(".")[-1] + test_name = test_id.rsplit(".", maxsplit=1)[-1] # Get the test function from the test class self = cls(test_name) self.rank = cls.rank diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 57b7a9fed43fb..bfc568bc14645 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2017,7 +2017,7 @@ def dec_fn(fn): def wrap_fn(self, *args, **kwargs): if TEST_WITH_ROCM: rocm_version = str(torch.version.hip) - rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version): reason = f"ROCm {rocm_version_tuple} is available but {version} required" diff --git a/torchgen/selective_build/operator.py b/torchgen/selective_build/operator.py index 0cb92dfc09e28..8047f033e3d2b 100644 --- a/torchgen/selective_build/operator.py +++ b/torchgen/selective_build/operator.py @@ -168,4 +168,4 @@ def merge_operator_dicts( def strip_operator_overload_name(op_name: str) -> str: - return op_name.split(".")[0] + return op_name.split(".", maxsplit=1)[0] From 2ea40fba841b3af8103f332ba62e54f350ba9a51 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Wed, 6 Aug 2025 03:58:53 -0700 Subject: [PATCH 0129/1424] [Linter] Improve device-bias linter by adding detection for `with torch.device("cuda")`. (#159926) ``` For example, detect the following situation: >>>Lint for test/dynamo/test_modes.py: Error (TEST_DEVICE_BIAS) [device-bias] `@requires_gpu` function should not hardcode `with torch.device('cuda')`, suggest to use torch.device(GPU_TYPE) 687 | flex_attention as flex_attention_eager, 688 | ) 689 | >>> 690 | with torch.device("cuda"): 691 | flex_attention = torch.compile(flex_attention_eager, dynamic=False) 692 | 693 | with self.assertRaisesRegex( ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159926 Approved by: https://github.com/EikanWang, https://github.com/jansel ghstack dependencies: #159759 --- test/dynamo/test_modes.py | 5 +++-- .../adapters/test_device_bias_linter.py | 21 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index a844efd51af93..ec9c4473a17fb 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -13,6 +13,7 @@ ) from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode from torch.testing._internal.common_utils import skipIfXpu +from torch.testing._internal.inductor_utils import GPU_TYPE from torch.testing._internal.triton_utils import requires_gpu from torch.utils._device import DeviceContext from torch.utils._python_dispatch import TorchDispatchMode @@ -687,7 +688,7 @@ def test_hop(self): flex_attention as flex_attention_eager, ) - with torch.device("cuda"): + with torch.device(GPU_TYPE): flex_attention = torch.compile(flex_attention_eager, dynamic=False) with self.assertRaisesRegex( @@ -711,7 +712,7 @@ def test_hop_eager(self): flex_attention as flex_attention_eager, ) - with torch.device("cuda"): + with torch.device(GPU_TYPE): with self.assertRaisesRegex( torch._dynamo.exc.Unsupported, "raised exception HopDetectionError([ConstantVariable(str: 'test')])", diff --git a/tools/linter/adapters/test_device_bias_linter.py b/tools/linter/adapters/test_device_bias_linter.py index 9901d5f3fe523..00786ef3df86c 100644 --- a/tools/linter/adapters/test_device_bias_linter.py +++ b/tools/linter/adapters/test_device_bias_linter.py @@ -105,6 +105,25 @@ def _check_device_methods(self, subnode: ast.Call, msg_prefix: str) -> None: f"{msg_prefix} .to('{arg.value}'), suggest to use .to(GPU_TYPE)", ) + def _check_with_statement(self, node: ast.With, msg_prefix: str) -> None: + for item in node.items: + ctx_expr = item.context_expr + if isinstance(ctx_expr, ast.Call): + func = ctx_expr.func + if ( + isinstance(func, ast.Attribute) + and func.attr == "device" + and isinstance(func.value, ast.Name) + and func.value.id == "torch" + and ctx_expr.args + and isinstance(ctx_expr.args[0], ast.Constant) + and any(bias in ctx_expr.args[0].value for bias in DEVICE_BIAS) + ): + self.record( + ctx_expr, + f"{msg_prefix} `with torch.device('{ctx_expr.args[0].value}')`, suggest to use torch.device(GPU_TYPE)", + ) + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # Check if the function is decorated with @requires_gpu, which indicates # that the function is intended to run on GPU devices (e.g., CUDA or XPU), @@ -121,6 +140,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: subnode.func, ast.Attribute ): self._check_device_methods(subnode, msg_prefix) + elif isinstance(subnode, ast.With): + self._check_with_statement(subnode, msg_prefix) self.generic_visit(node) From 017259f9c65b6fad55fb9597d7077e2543eaae46 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Fri, 8 Aug 2025 03:38:28 +0000 Subject: [PATCH 0130/1424] [benchmarks] Add nativert benchmark (#159922) Add NativeRT as an option in the PT2 OSS benchmark ``` python ./benchmarks/dynamo/huggingface.py --performance --inference --export-nativert python ./benchmarks/dynamo/timm_models.py --performance --inference --export-nativert python ./benchmarks/dynamo/torchbench.py --performance --inference --export-nativert ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159922 Approved by: https://github.com/angelayi --- benchmarks/dynamo/common.py | 63 +++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 516549d7f6569..651bc90ba194b 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -21,6 +21,7 @@ import signal import subprocess import sys +import tempfile import time import weakref from contextlib import contextmanager @@ -41,6 +42,7 @@ import torch.distributed import torch.multiprocessing as mp from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU +from torch._C._nativert import PyModelRunner from torch._dynamo.profiler import fx_insert_profiling, Profiler from torch._dynamo.testing import ( dummy_fx_compile, @@ -1100,6 +1102,8 @@ def maybe_mark_profile(*args, **kwargs): frozen_model_iter_fn = export_aot_inductor( model, example_inputs, args.inductor_compile_mode ) + elif args.export_nativert: + frozen_model_iter_fn = export_nativert(model, example_inputs) else: frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) @@ -1446,6 +1450,38 @@ def get_excess_memory(cls, model) -> float: return cls.cache.get(weakref.ref(model), (None, 0.0))[1] +class NativeRTCache: + cache: dict[weakref.ref, Any] = {} + + @classmethod + def load(cls, model, example_inputs): + from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path + + key = weakref.ref(model) + if key not in cls.cache: + example_args, example_kwargs = _normalize_bench_inputs(example_inputs) + example_outputs = model(*example_args, **example_kwargs) + _register_dataclass_output_as_pytree(example_outputs) + + combined_args = _combine_args(model, example_args, example_kwargs) + dynamic_shapes = _tree_map_with_path( + _produce_dynamic_shapes_for_export, combined_args + ) + + ep = torch.export.export( + model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes + ) + ep = ep.run_decompositions({}) + with tempfile.NamedTemporaryFile(delete=False) as f: + torch.export.pt2_archive._package.package_pt2( + f, exported_programs={"forward": ep} + ) + filename = f.name + cls.cache[key] = PyModelRunner(filename, "forward") + + return cls.cache[key] + + def export(model, example_inputs): from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path @@ -1472,6 +1508,16 @@ def opt_export(_, example_inputs): return opt_export +def export_nativert(model, example_inputs): + optimized = NativeRTCache.load(model, example_inputs) + + def opt_nativert(_, example_inputs, collect_outputs=False): + example_args, example_kwargs = _normalize_bench_inputs(example_inputs) + return optimized.run(*example_args, **example_kwargs) + + return opt_nativert + + def export_aot_inductor(model, example_inputs, mode): optimized = AOTInductorModelCache.load(model, example_inputs, mode) @@ -2228,7 +2274,11 @@ def record_status(accuracy_status, dynamo_start_stats): try: model_copy = self.deepcopy_and_maybe_parallelize(model) self.init_optimizer(name, current_device, model_copy.parameters()) - if self.args.export or self.args.export_aot_inductor: + if ( + self.args.export + or self.args.export_aot_inductor + or self.args.export_nativert + ): # apply export on module directly # no need for n iterations # the logic should be the same to self.model_iter_fn (forward_pass) @@ -2624,7 +2674,7 @@ def warmup(fn, model, example_inputs, mode, niters=5): niters=1, ) - if self.args.export_aot_inductor: + if self.args.export_aot_inductor or self.args.export_nativert: optimized_model_iter_fn = optimize_ctx else: optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) @@ -3377,6 +3427,11 @@ def get_example_inputs(self): action="store_true", help="Measure pass rate with Export+AOTInductor", ) + group.add_argument( + "--export-nativert", + action="store_true", + help="Measure pass rate with Export+NativeRT", + ) group.add_argument( "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch" ) @@ -3818,6 +3873,10 @@ def run(runner, args, original_dir=None): optimize_ctx = export experiment = speedup_experiment output_filename = "export.csv" + elif args.export_nativert: + optimize_ctx = export_nativert + experiment = speedup_experiment + output_filename = "export_nativert.csv" elif args.xla: (dev,) = args.devices os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev] From 24257f5bfaa37795f74d9f64c1b43584128d4b8c Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Fri, 8 Aug 2025 04:13:44 +0000 Subject: [PATCH 0131/1424] [vllm hash update] update the pinned vllm hash (#159822) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vllm hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159822 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vllm.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt index 21863c19dec73..d5b7ebc020178 100644 --- a/.github/ci_commit_pins/vllm.txt +++ b/.github/ci_commit_pins/vllm.txt @@ -1 +1 @@ -6a39ba85fe0f2fff9494b5eccea717c93510c230 +7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad From b5c937259b17b65c1c6039a8f08ef2758ce615db Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Wed, 6 Aug 2025 16:58:26 -0700 Subject: [PATCH 0132/1424] [SymmMem] Add NVSHMEM Reduction support (sum, min, max) into Triton (#158515) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements sum_reduce, min_reduce, and max_reduce collective operations for NVSHMEM Triton kernels. Enables parallel reduction computations across PE teams for int64 data types. Tests: `python test/distributed/test_nvshmem_triton.py`
Quick debug print for sanity check ```markdown ============================================================ [Rank 1] Starting min/max reduction test with world_size=2 ============================================================ ============================================================ [Rank 0] Starting min/max reduction test with world_size=2 ============================================================ [Rank 0] Source data for min/max: [10, 20] [Rank 1] Source data for min/max: [15, 5] [Rank 1] All values across PEs: [Rank 0] All values across PEs: - Position 0: [10, 15] - Position 0: [10, 15] - Position 1: [20, 5] - Position 1: [20, 5] [Rank 1] Expected min: [10, 5] [Rank 0] Expected min: [10, 5] [Rank 1] Expected max: [15, 20] [Rank 0] Expected max: [15, 20] [Rank 0] Executing MIN reduction... [Rank 1] Executing MIN reduction... [Rank 0] Executing MAX reduction... [Rank 1] Executing MAX reduction... /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [Rank 1] Results: [Rank 0] Results: [Rank 1] MIN reduction result: [10, 5] [Rank 1] MAX reduction result: [15, 20] [Rank 0] MIN reduction result: [10, 5] [Rank 0] MAX reduction result: [15, 20] [Rank 1] ============================================================ [Rank 1] Min/Max reduction test PASSED ✓ [Rank 1] ============================================================ [Rank 0] ============================================================ [Rank 0] Min/Max reduction test PASSED ✓ [Rank 0] ============================================================ ...... ============================================================ ============================================================ [Rank 0] Starting sum reduction test with world_size=2 [Rank 1] Starting sum reduction test with world_size=2 ============================================================ ============================================================ [Rank 0] Configuration: [Rank 1] Configuration: - nreduce: 3 (number of separate reductions) - nreduce: 3 (number of separate reductions) - dtype: torch.int64 - dtype: torch.int64 [Rank 1] Source data: [2, 4, 6] [Rank 1] Contribution explanation: [Rank 0] Source data: [1, 2, 3] [Rank 0] Contribution explanation: - Element 0: 2 = (rank=1+1) * (index=0+1) - Element 0: 1 = (rank=0+1) * (index=0+1) - Element 1: 4 = (rank=1+1) * (index=1+1) - Element 1: 2 = (rank=0+1) * (index=1+1) - Element 2: 6 = (rank=1+1) * (index=2+1) - Element 2: 3 = (rank=0+1) * (index=2+1) [Rank 1] Initial destination: [-1, -1, -1] [Rank 0] Initial destination: [-1, -1, -1] [Rank 0] Expected results after reduction: [3, 6, 9] [Rank 1] Expected results after reduction: [3, 6, 9] [Rank 0] Executing sum reduction... [Rank 1] Executing sum reduction... [Rank 1] Sum reduction completed /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [Rank 0] Sum reduction completed /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [Rank 0] Results after reduction: [Rank 0] Destination buffer: [3, 6, 9] [Rank 1] Results after reduction: [Rank 0] Verification: - Reduction 0: PE0: 1 + PE1: 2 = 3 Result: 3, Match: ✓ - Reduction 1: PE0: 2 + PE1: 4 = 6 Result: 6, Match: ✓ [Rank 1] Destination buffer: [3, 6, 9] - Reduction 2: PE0: 3 + PE1: 6 = 9 [Rank 1] Verification: - Reduction 0: PE0: 1 + PE1: 2 = 3 Result: 9, Match: ✓ Result: 3, Match: ✓ - Reduction 1: PE0: 2 + PE1: 4 = 6 Result: 6, Match: ✓ - Reduction 2: PE0: 3 + PE1: 6 = 9 Result: 9, Match: ✓ [Rank 0] ============================================================ [Rank 0] Sum reduction test PASSED ✓ [Rank 0] All 3 reductions computed correctly across 2 PEs [Rank 0] ============================================================ [Rank 1] ============================================================ [Rank 1] Sum reduction test PASSED ✓ [Rank 1] All 3 reductions computed correctly across 2 PEs [Rank 1] ============================================================ ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158515 Approved by: https://github.com/mandroid6, https://github.com/ngimel --- test/distributed/test_nvshmem_triton.py | 150 ++++++++++++++++++ .../_symmetric_memory/_nvshmem_triton.py | 57 +++++++ 2 files changed, 207 insertions(+) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index c4565a96496ce..1145da014543d 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -231,6 +231,36 @@ def broadcast_kernel( nvshmem.broadcast(team_handle, dest_ptr, src_ptr, nelems, pe_root) +@triton.jit +def sum_reduce_kernel( + team_handle, + dest_ptr, + src_ptr, + nreduce, +): + nvshmem.sum_reduce(team_handle, dest_ptr, src_ptr, nreduce) + + +@triton.jit +def max_reduce_kernel( + team_handle, + dest_ptr, + src_ptr, + nreduce, +): + nvshmem.max_reduce(team_handle, dest_ptr, src_ptr, nreduce) + + +@triton.jit +def min_reduce_kernel( + team_handle, + dest_ptr, + src_ptr, + nreduce, +): + nvshmem.min_reduce(team_handle, dest_ptr, src_ptr, nreduce) + + @instantiate_parametrized_tests @requires_nvshmem() class NVSHMEMTritonTest(MultiProcContinousTest): @@ -947,6 +977,126 @@ def test_triton_broadcast(self) -> None: dst, torch.tensor(expected, device=self.device, dtype=dtype) ) + @skipIfRocm + @requires_triton() + def test_triton_sum_reduce(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Configuration + nreduce = 3 # number of separate reductions + dtype = torch.int64 + # Source buffer - each rank contributes different values + src = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + for i in range(nreduce): + src[i] = (rank + 1) * (i + 1) # Rank 0: [1,2,3], Rank 1: [2,4,6], etc. + # Destination buffer + dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Calculate expected results + expected = [] + for i in range(nreduce): + # Sum across all ranks: sum((rank+1)*(i+1) for rank in range(world_size)) + total = sum((r + 1) * (i + 1) for r in range(world_size)) + expected.append(total) + # Synchronize before reduction + dist.barrier() + # Execute reduction + team_handle = 0 # NVSHMEM_TEAM_WORLD + sum_reduce_kernel[(1,)]( + team_handle, + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + nreduce, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + # Synchronize after reduction + dist.barrier() + # Verify results + torch.testing.assert_close( + dst, torch.tensor(expected, device=self.device, dtype=dtype) + ) + + @skipIfRocm + @requires_triton() + def test_triton_minmax_reduce(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Configuration + nreduce = 2 # number of values to reduce + dtype = torch.int64 + # Source buffers for min and max + src_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + src_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + # Each rank contributes different values + # For min: rank 0: [10, 20], rank 1: [15, 5], etc. + # For max: same values + for i in range(nreduce): + if i == 0: + src_min[i] = 10 + rank * 5 # 10, 15, 20, ... + src_max[i] = 10 + rank * 5 + else: + src_min[i] = 20 - rank * 15 # 20, 5, -10, ... + src_max[i] = 20 - rank * 15 + # Destination buffers + dst_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + dst_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + src_min_hdl = symm_mem.rendezvous(src_min, group=group_name) + src_max_hdl = symm_mem.rendezvous(src_max, group=group_name) + dst_min_hdl = symm_mem.rendezvous(dst_min, group=group_name) + dst_max_hdl = symm_mem.rendezvous(dst_max, group=group_name) + # Calculate expected results + all_values = [] + for i in range(nreduce): + values = [] + for r in range(world_size): + if i == 0: + values.append(10 + r * 5) + else: + values.append(20 - r * 15) + all_values.append(values) + expected_min = [min(vals) for vals in all_values] + expected_max = [max(vals) for vals in all_values] + dist.barrier() + # Execute MIN reduction + team_handle = 0 + min_reduce_kernel[(1,)]( + team_handle, + dst_min_hdl.buffer_ptrs[rank], + src_min_hdl.buffer_ptrs[rank], + nreduce, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + # Execute MAX reduction + max_reduce_kernel[(1,)]( + team_handle, + dst_max_hdl.buffer_ptrs[rank], + src_max_hdl.buffer_ptrs[rank], + nreduce, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + dist.barrier() + # Verify results + torch.testing.assert_close( + dst_min, torch.tensor(expected_min, device=self.device, dtype=dtype) + ) + torch.testing.assert_close( + dst_max, torch.tensor(expected_max, device=self.device, dtype=dtype) + ) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index dda1885a8e167..aefb7541d8308 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -280,3 +280,60 @@ def broadcast(team, dest, source, nelems, pe_root, _builder=None): # type: igno is_pure=False, _builder=_builder, ) + + @core.extern + def sum_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] + """Sum reduction for int64""" + return core.extern_elementwise( + "", + "", + [team, dest, source, nreduce], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_int64_sum_reduce", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def max_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] + """Max reduction for int64""" + return core.extern_elementwise( + "", + "", + [team, dest, source, nreduce], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_int64_max_reduce", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def min_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] + """Min reduction for int64""" + return core.extern_elementwise( + "", + "", + [team, dest, source, nreduce], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_int64_min_reduce", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) From b0b229b19757179c7ba161e9f6ecbf435946f535 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Wed, 6 Aug 2025 16:58:27 -0700 Subject: [PATCH 0133/1424] [SymmMem] Use _get_default_group() instead of group.WORLD for group_name access (#158718) Both approaches functionally return the default process group created by `init_process_group()` but `_get_default_group()` is a dedicated function with [better error handling and type safety](https://github.com/pytorch/pytorch/blob/4869f7117009fb99a57482fce56b00c6163fbce6/torch/distributed/distributed_c10d.py#L1300-L1310). Pull Request resolved: https://github.com/pytorch/pytorch/pull/158718 Approved by: https://github.com/Skylion007, https://github.com/fduwjj ghstack dependencies: #158515 --- test/distributed/test_nvshmem_triton.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 1145da014543d..94e68d7ff100c 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -851,7 +851,7 @@ def test_triton_sync(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank numel = 1 @@ -888,7 +888,7 @@ def test_triton_alltoall(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) world_size = dist.get_world_size() rank = self.rank @@ -936,7 +936,7 @@ def test_triton_broadcast(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank # Configuration @@ -983,7 +983,7 @@ def test_triton_sum_reduce(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) world_size = dist.get_world_size() rank = self.rank @@ -1029,7 +1029,7 @@ def test_triton_minmax_reduce(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) world_size = dist.get_world_size() rank = self.rank From ea7fe0ecf62b44185181fba8263cfb6cbf58fa09 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Wed, 6 Aug 2025 16:58:27 -0700 Subject: [PATCH 0134/1424] [SymmMem] Standardize NVSHMEM Triton wrappers on byte-based APIs + improve code clarity (#159136) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Quick refactor for consistency and clarity. 1. We now standardize all NVSHMEM data-moving collectives (put, get, alltoall, broadcast) to use their byte-based *_mem_block variants. This makes the API behavior more predictable and avoids mixing paradigms. 2. Previously, some functions operated on element counts (nelems), while others expected byte sizes but still used `nelems` as the param name. That inconsistency was easy to miss and could lead to bugs, especially for devs not familiar with the NVSHMEM internals. To clean this up: • All byte-based APIs now use nbytes or nbytes_per_pe to make the units explicit. • Typed APIs consistently use nelems for element counts. • Docstrings were added or updated to clarify expected units. Also did some code cleanup — removed unused functions, fixed typos in comments, and did some general housekeeping. This should make the API more intuitive and reduce friction for developers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159136 Approved by: https://github.com/mandroid6, https://github.com/ngimel ghstack dependencies: #158515, #158718 --- test/distributed/test_nvshmem_triton.py | 132 +++++++++--------- .../_symmetric_memory/_nvshmem_triton.py | 57 +++++--- 2 files changed, 104 insertions(+), 85 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 94e68d7ff100c..1cd2247a93457 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -1,9 +1,7 @@ # Owner(s): ["oncall: distributed"] - # To run: # python test/distributed/test_nvshmem_triton.py - import triton.language as tl import torch @@ -36,37 +34,37 @@ def requires_nvshmem(): # Shared Triton JIT kernels @triton.jit -def put_kernel( +def putmem_block_kernel( dst_ptr, src_ptr, - numel, + size_bytes, peer, ): - nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, peer) @triton.jit -def get_kernel( +def getmem_block_kernel( dst_ptr, src_ptr, - numel, + size_bytes, peer, ): - nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer) + nvshmem.getmem_block(dst_ptr, src_ptr, size_bytes, peer) @triton.jit -def put_signal_kernel( +def putmem_signal_block_kernel( dst_ptr, src_ptr, - numel, + size_bytes, sig_ptr, signal_val, sig_op, peer, ): nvshmem.putmem_signal_block( - dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer + dst_ptr, src_ptr, size_bytes, sig_ptr, signal_val, sig_op, peer ) @@ -95,18 +93,8 @@ def wait_until_kernel( @triton.jit -def put_and_signal_kernel( - dst_ptr, - src_ptr, - numel, - sig_ptr, - signal_val, - sig_op, - peer, -): - nvshmem.putmem_signal_block( - dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer - ) +def fence_kernel(): + nvshmem.fence() @triton.jit @@ -117,19 +105,19 @@ def put_with_fence_kernel( src_ptr2, flag_ptr, flag_src_ptr, - numel, + size_bytes, peer, ): # First put - nvshmem.putmem_block(dst_ptr1, src_ptr1, numel, peer) + nvshmem.putmem_block(dst_ptr1, src_ptr1, size_bytes, peer) # Ensure the first put is ordered before the next. nvshmem.fence() # Second put - nvshmem.putmem_block(dst_ptr2, src_ptr2, numel, peer) + nvshmem.putmem_block(dst_ptr2, src_ptr2, size_bytes, peer) # Order the second put before flag update. nvshmem.fence() # Write the flag (single int64) to signal completion. - nvshmem.putmem_block(flag_ptr, flag_src_ptr, 1, peer) + nvshmem.putmem_block(flag_ptr, flag_src_ptr, 8, peer) # 8 bytes for int64 @triton.jit @@ -138,23 +126,23 @@ def put_with_quiet_kernel( src_ptr, flag_dst_ptr, flag_src_ptr, - numel, + size_bytes, peer, ): # Put data - nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, peer) # Call quiet to ensure put is complete nvshmem.quiet() # Only after quiet, set the completion flag # This ensures the data put is complete before flag is set - nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer) + nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 8, peer) # 8 bytes for int64 @triton.jit def barrier_test_kernel( dst_ptr, src_ptr, - numel, + size_bytes, ): # Testing barrier_all() requires coordinated operations across PEs within # the same kernel execution. Unlike other kernels that just wrap NVSHMEM @@ -162,6 +150,7 @@ def barrier_test_kernel( # device-side barrier synchronization. my_pe = nvshmem.my_pe() n_pes = nvshmem.n_pes() + # Rank 0 broadcasts its value to all other ranks if my_pe == 0: # Write initial value @@ -170,10 +159,12 @@ def barrier_test_kernel( # Put to all other ranks i = 1 while i < n_pes: - nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) + nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, i) i += 1 + # Synchronize all PEs nvshmem.barrier_all() + # Non-zero ranks increment the received value if my_pe != 0: p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) @@ -185,7 +176,7 @@ def barrier_test_kernel( def sync_test_kernel( dst_ptr, src_ptr, - numel, + size_bytes, ): my_pe = nvshmem.my_pe() n_pes = nvshmem.n_pes() @@ -198,11 +189,13 @@ def sync_test_kernel( # Put to all other ranks i = 1 while i < n_pes: - nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) + nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, i) i += 1 + # Synchronize all PEs (this is more lightweight than barrier_all() b/c it only ensures local store visibility # and doesn't wait for remote ops to complete) nvshmem.sync_all() + # Non-zero ranks increment the received value if my_pe != 0: p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) @@ -211,24 +204,24 @@ def sync_test_kernel( @triton.jit -def alltoall_kernel( +def alltoallmem_block_kernel( team_handle, dest_ptr, src_ptr, - nelems, + size_bytes_per_pe, ): - nvshmem.alltoall(team_handle, dest_ptr, src_ptr, nelems) + nvshmem.alltoallmem_block(team_handle, dest_ptr, src_ptr, size_bytes_per_pe) @triton.jit -def broadcast_kernel( +def broadcastmem_block_kernel( team_handle, dest_ptr, src_ptr, - nelems, + size_bytes, pe_root, ): - nvshmem.broadcast(team_handle, dest_ptr, src_ptr, nelems, pe_root) + nvshmem.broadcastmem_block(team_handle, dest_ptr, src_ptr, size_bytes, pe_root) @triton.jit @@ -303,10 +296,10 @@ def test_triton_put(self) -> None: if rank == 0: dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] - put_kernel[(1, 1, 1)]( + putmem_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, peer=peer, extern_libs=nvshmem_lib, ) @@ -343,10 +336,10 @@ def test_triton_get(self) -> None: # Rank 1 gets data from rank 0 dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] - get_kernel[(1, 1, 1)]( + getmem_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, peer=peer, extern_libs=nvshmem_lib, ) @@ -384,10 +377,10 @@ def test_triton_get_ring(self) -> None: # All ranks execute the get operation dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] - get_kernel[(1, 1, 1)]( + getmem_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, peer=peer, extern_libs=nvshmem_lib, ) @@ -434,10 +427,10 @@ def test_triton_put_signal_set(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - put_signal_kernel[(1, 1, 1)]( + putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, sig_ptr=sig_ptr, signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_SET, @@ -499,10 +492,10 @@ def test_triton_put_signal_add(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - put_signal_kernel[(1, 1, 1)]( + putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, sig_ptr=sig_ptr, signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_ADD, @@ -573,10 +566,10 @@ def test_triton_wait_until(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] - put_kernel[(1, 1, 1)]( + putmem_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel=numel, + size_bytes=msg_size_bytes, peer=peer, extern_libs=nvshmem_lib, ) @@ -592,10 +585,10 @@ def fence_kernel(): flag_src = torch.tensor([flag_val], dtype=torch.int64, device=self.device) flag_dst_ptr = out_hdl.signal_pad_ptrs[peer] - put_kernel[(1, 1, 1)]( + putmem_block_kernel[(1, 1, 1)]( flag_dst_ptr, flag_src.data_ptr(), - numel=1, + size_bytes=8, # 8 bytes for int64 peer=peer, extern_libs=nvshmem_lib, ) @@ -619,6 +612,7 @@ def test_triton_signal_wait_until(self) -> None: msg_size_bytes = 8 dtype = torch.int8 numel = msg_size_bytes // dtype.itemsize + val_to_put = 123 # arbitrary test value COMPLETION_FLAG_VAL = 1 @@ -637,11 +631,11 @@ def test_triton_signal_wait_until(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - put_and_signal_kernel[(1, 1, 1)]( + putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, - numel, - sig_ptr, + size_bytes=msg_size_bytes, + sig_ptr=sig_ptr, signal_val=COMPLETION_FLAG_VAL, sig_op=NVSHMEM_SIGNAL_SET, peer=peer, @@ -690,6 +684,7 @@ def test_triton_fence(self) -> None: msg_size_bytes = 8 dtype = torch.int8 numel = msg_size_bytes // dtype.itemsize + val1 = 10 val2 = 20 flag_val = 1 @@ -725,7 +720,7 @@ def test_triton_fence(self) -> None: src_ptr2, flag_ptr, flag_src_ptr, - numel, + size_bytes=msg_size_bytes, peer=peer, extern_libs=nvshmem_lib, ) @@ -763,6 +758,7 @@ def test_triton_quiet(self) -> None: msg_size_bytes = 8 dtype = torch.int8 numel = msg_size_bytes // dtype.itemsize + # Data buffers val = 15 inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) @@ -802,7 +798,7 @@ def test_triton_quiet(self) -> None: src_ptr, flag_dst_ptr, flag_src_ptr, - numel=numel, + size_bytes=msg_size_bytes, peer=peer, extern_libs=nvshmem_lib, ) @@ -818,6 +814,7 @@ def test_triton_barrier(self) -> None: rank = self.rank numel = 1 dtype = torch.int32 + size_bytes = numel * dtype.itemsize # Create symmetric buffers src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) @@ -827,7 +824,7 @@ def test_triton_barrier(self) -> None: barrier_test_kernel[(1,)]( dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], - numel=numel, + size_bytes=size_bytes, extern_libs=nvshmem_lib, launch_cooperative_grid=True, num_ctas=1, @@ -856,6 +853,7 @@ def test_triton_sync(self) -> None: rank = self.rank numel = 1 dtype = torch.int32 + size_bytes = numel * dtype.itemsize # Create symmetric buffers src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) @@ -865,7 +863,7 @@ def test_triton_sync(self) -> None: sync_test_kernel[(1,)]( dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], - numel=numel, + size_bytes=size_bytes, extern_libs=nvshmem_lib, launch_cooperative_grid=True, num_ctas=1, @@ -895,6 +893,7 @@ def test_triton_alltoall(self) -> None: # Each PE will send 2 int64 elements to every other PE nelems_per_pe = 2 dtype = torch.int64 + size_bytes_per_pe = nelems_per_pe * dtype.itemsize # Source buffer: contains data for all PEs # Layout: [data_for_pe0, data_for_pe1, ...] src_size = nelems_per_pe * world_size @@ -912,11 +911,11 @@ def test_triton_alltoall(self) -> None: dist.barrier() team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0 # Launch the kernel - alltoall_kernel[(1,)]( + alltoallmem_block_kernel[(1,)]( team_handle, dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], - nelems_per_pe, + size_bytes_per_pe=size_bytes_per_pe, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) @@ -942,6 +941,7 @@ def test_triton_broadcast(self) -> None: # Configuration nelems = 4 # number of elements dtype = torch.int64 + size_bytes = nelems * dtype.itemsize # Source buffer - only root will have meaningful data pe_root = 0 # PE 0 will be the root src = symm_mem.empty(nelems, dtype=dtype, device=self.device) @@ -960,12 +960,12 @@ def test_triton_broadcast(self) -> None: dist.barrier() # Execute broadcast team_handle = 0 # NVSHMEM_TEAM_WORLD - broadcast_kernel[(1,)]( + broadcastmem_block_kernel[(1,)]( team_handle, dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], - nelems, - pe_root, + size_bytes=size_bytes, + pe_root=pe_root, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index aefb7541d8308..3e0ee87611304 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -54,12 +54,14 @@ def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] if has_triton(): from triton.language import core + # RMA Operations (mem-based APIs - sizes in bytes) @core.extern - def putmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + def putmem_block(dst, src, size_bytes, pe, _builder=None): # type: ignore[no-untyped-def] + """Put data to remote PE. size_bytes specifies the size in bytes.""" return core.extern_elementwise( "", "", - [dst, src, nelems, pe], + [dst, src, size_bytes, pe], { ( core.dtype("int64"), @@ -73,11 +75,12 @@ def putmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untype ) @core.extern - def getmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + def getmem_block(dst, src, size_bytes, pe, _builder=None): # type: ignore[no-untyped-def] + """Get data from remote PE. size_bytes specifies the size in bytes.""" return core.extern_elementwise( "", "", - [dst, src, nelems, pe], + [dst, src, size_bytes, pe], { ( core.dtype("int64"), @@ -94,17 +97,18 @@ def getmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untype def putmem_signal_block( # type: ignore[no-untyped-def] dst, src, - nelems, + size_bytes, sig_addr, signal, sig_op, pe, _builder=None, ): # type: ignore[no-untyped-def] + """Put data to remote PE with signal. size_bytes specifies the size in bytes.""" return core.extern_elementwise( "", "", - [dst, src, nelems, sig_addr, signal, sig_op, pe], + [dst, src, size_bytes, sig_addr, signal, sig_op, pe], { ( core.dtype("int64"), @@ -120,8 +124,10 @@ def putmem_signal_block( # type: ignore[no-untyped-def] _builder=_builder, ) + # Wait and Signal Operations @core.extern def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + """Wait until a condition is met on a symmetric variable.""" return core.extern_elementwise( "", "", @@ -139,6 +145,7 @@ def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-de @core.extern def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + """Wait until a signal variable meets a condition.""" return core.extern_elementwise( "", "", @@ -156,6 +163,7 @@ def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no @core.extern def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-untyped-def] + """Perform a signal operation on a remote PE.""" return core.extern_elementwise( "", "", @@ -172,8 +180,10 @@ def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-u _builder=_builder, ) + # Memory Ordering Operations @core.extern def fence(_builder=None): # type: ignore[no-untyped-def] + """Ensure ordering of put operations.""" return core.extern_elementwise( "", "", @@ -187,6 +197,7 @@ def fence(_builder=None): # type: ignore[no-untyped-def] @core.extern def quiet(_builder=None): # type: ignore[no-untyped-def] + """Wait for completion of all outstanding put operations.""" return core.extern_elementwise( "", "", @@ -198,8 +209,10 @@ def quiet(_builder=None): # type: ignore[no-untyped-def] _builder=_builder, ) + # PE Information Operations @core.extern def my_pe(_builder=None): # type: ignore[no-untyped-def] + """Get the PE number of the calling PE.""" return core.extern_elementwise( "", "", @@ -211,6 +224,7 @@ def my_pe(_builder=None): # type: ignore[no-untyped-def] @core.extern def n_pes(_builder=None): # type: ignore[no-untyped-def] + """Get the total number of PEs.""" return core.extern_elementwise( "", "", @@ -220,8 +234,10 @@ def n_pes(_builder=None): # type: ignore[no-untyped-def] _builder=_builder, ) + # Synchronization Operations @core.extern def barrier_all(_builder=None): # type: ignore[no-untyped-def] + """Synchronize all PEs.""" return core.extern_elementwise( "", "", @@ -233,6 +249,7 @@ def barrier_all(_builder=None): # type: ignore[no-untyped-def] @core.extern def sync_all(_builder=None): # type: ignore[no-untyped-def] + """Synchronize all PEs (lightweight version, does not ensure completion of remote memory updates).""" return core.extern_elementwise( "", "", @@ -242,48 +259,50 @@ def sync_all(_builder=None): # type: ignore[no-untyped-def] _builder=_builder, ) + # Collective Operations (mem-based APIs - sizes in bytes) @core.extern - def alltoall(team, dest, source, nelems, _builder=None): # type: ignore[no-untyped-def] - """Perform alltoall operation on NVSHMEM symmetric memory""" + def alltoallmem_block(team, dest, source, size_bytes, _builder=None): # type: ignore[no-untyped-def] + """Perform alltoall operation on symmetric memory. size_bytes specifies the number of bytes to exchange per PE.""" return core.extern_elementwise( "", "", - [team, dest, source, nelems], + [team, dest, source, size_bytes], { ( core.dtype("int64"), # team handle core.dtype("int64"), # dest ptr core.dtype("int64"), # source ptr - core.dtype("int64"), # nelems - ): ("nvshmem_longlong_alltoall", core.dtype("int32")) + core.dtype("int64"), # size in bytes + ): ("nvshmemx_alltoallmem_block", core.dtype("int32")) }, is_pure=False, _builder=_builder, ) @core.extern - def broadcast(team, dest, source, nelems, pe_root, _builder=None): # type: ignore[no-untyped-def] - """Broadcasts data from a root PE to all other PEs in a team""" + def broadcastmem_block(team, dest, source, size_bytes, pe_root, _builder=None): # type: ignore[no-untyped-def] + """Broadcast data from a root PE to all other PEs in a team. size_bytes specifies the size in bytes.""" return core.extern_elementwise( "", "", - [team, dest, source, nelems, pe_root], + [team, dest, source, size_bytes, pe_root], { ( core.dtype("int64"), # team handle core.dtype("int64"), # dest ptr core.dtype("int64"), # source ptr - core.dtype("int64"), # nelems + core.dtype("int64"), # size in bytes core.dtype("int64"), # pe_root - ): ("nvshmem_longlong_broadcast", core.dtype("int32")) + ): ("nvshmemx_broadcastmem_block", core.dtype("int32")) }, is_pure=False, _builder=_builder, ) + # Reduction Operations @core.extern def sum_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] - """Sum reduction for int64""" + """Sum reduction for int64. nreduce is number of elements in the dest and source arrays.""" return core.extern_elementwise( "", "", @@ -302,7 +321,7 @@ def sum_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-u @core.extern def max_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] - """Max reduction for int64""" + """Max reduction for int64. nreduce is number of elements in the dest and source arrays.""" return core.extern_elementwise( "", "", @@ -321,7 +340,7 @@ def max_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-u @core.extern def min_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] - """Min reduction for int64""" + """Min reduction for int64. nreduce is number of elements in the dest and source arrays.""" return core.extern_elementwise( "", "", From 1783d6e966234d07cf9076ecd76b76ba28dfc031 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Wed, 6 Aug 2025 16:58:27 -0700 Subject: [PATCH 0135/1424] [SymmMem] Fix flaky wait_until test (#159215) When playing around with it, I noticed some flakiness in this test across sessions. After debugging, turns out the heavy sync primitives that I was calling (like `nvshmem_quiet()` or `nvshmem_fence()`) from inside Triton kernels was causing deadlocks. The original test tried to guarantee ordering: `put(data) -> fence/quiet -> put(flag)`. But the GPU thread got stuck in `quiet()` waiting for network confirmation while holding the SM, creating a deadlock. The fix was realizing `wait_until` already provides all the sync you need. Just do: - PE A: `nvshmem_wait_until(&ivar, ...)` - PE B: `nvshmem_put(&ivar_on_PE_A, ...)` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159215 Approved by: https://github.com/mandroid6, https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136 --- test/distributed/test_nvshmem_triton.py | 76 +++++++++++-------------- 1 file changed, 32 insertions(+), 44 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 1cd2247a93457..b0f29c0f05cb5 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -172,6 +172,11 @@ def barrier_test_kernel( tl.store(p_dst, received + 1) +@triton.jit +def barrier_all_kernel(): + nvshmem.barrier_all() + + @triton.jit def sync_test_kernel( dst_ptr, @@ -530,66 +535,49 @@ def test_triton_wait_until(self) -> None: rank = self.rank peer = (self.world_size - 1) - rank - NVSHMEM_CMP_EQ = 0 # from nvshmem.h - - # Allocate symmetric buffers - msg_size_bytes = 8 - dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize - val = 13 - flag_val = 21 + NVSHMEM_CMP_EQ = 0 # equal comparison + FLAG_INITIAL_VALUE = 0 + FLAG_FINAL_VALUE = 42 - inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) - out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + # Use a single int64 symmetric tensor as our synchronization flag. + flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_( + FLAG_INITIAL_VALUE + ) + flag_hdl = symm_mem.rendezvous(flag, group=group_name) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + barrier_all_kernel[(1,)](extern_libs=nvshmem_lib) if rank == 0: - # Rank 0 waits for the flag to be set by Rank 1, then checks the data - ivar_ptr = out_hdl.signal_pad_ptrs[rank] - - wait_until_kernel[(1, 1, 1)]( + # Rank 0 (the waiter) + ivar_ptr = flag_hdl.buffer_ptrs[rank] + wait_until_kernel[(1,)]( ivar_ptr, cmp_op=NVSHMEM_CMP_EQ, - cmp_val=flag_val, + cmp_val=FLAG_FINAL_VALUE, extern_libs=nvshmem_lib, ) + # Verification torch.testing.assert_close( - out, - val * torch.ones(numel, dtype=dtype, device=self.device), + flag, + torch.tensor([FLAG_FINAL_VALUE], dtype=torch.int64, device=self.device), ) if rank == 1: - # Rank 1 puts data into Rank 0's output buffer - dst_ptr = out_hdl.buffer_ptrs[peer] - src_ptr = inp_hdl.buffer_ptrs[rank] - - putmem_block_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - size_bytes=msg_size_bytes, - peer=peer, - extern_libs=nvshmem_lib, + # Rank 1 (the signaler) + val_to_put = torch.tensor( + [FLAG_FINAL_VALUE], dtype=torch.int64, device=self.device ) - # Fence to order data put before flag put - @triton.jit - def fence_kernel(): - nvshmem.fence() - - fence_kernel[(1, 1, 1)](extern_libs=nvshmem_lib) + # The destination is Rank 0's flag buffer. + dst_ptr = flag_hdl.buffer_ptrs[rank] - # Put the flag value (do not use signal_op here) - flag_src = torch.tensor([flag_val], dtype=torch.int64, device=self.device) - flag_dst_ptr = out_hdl.signal_pad_ptrs[peer] - - putmem_block_kernel[(1, 1, 1)]( - flag_dst_ptr, - flag_src.data_ptr(), - size_bytes=8, # 8 bytes for int64 - peer=peer, + # Launch a kernel to put the value to Rank 0. + putmem_block_kernel[(1,)]( + dst_ptr, # Destination pointer on the remote PE + val_to_put.data_ptr(), # Source data pointer (local) + size_bytes=8, # Size of one int64 + peer=peer, # The target PE (Rank 0) extern_libs=nvshmem_lib, ) From 7c4f7b93404fabe1a80f4a60c26d062154a3d95b Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Wed, 6 Aug 2025 16:58:27 -0700 Subject: [PATCH 0136/1424] [SymmMem] Add Triton 3.4 support to NVSHMEM Triton and fix CI tests (make device library discoverable + fix peer calculation bug) (#159701) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces support for Triton 3.4 and resolves several CI and test-related issues. **Triton 3.4 Compatibility** - The JIT post-compile hook has been updated from the legacy JITFunction.compiled_hook to the new API path at triton.knobs.runtime.jit_post_compile_hook. - The internal parameter for kernel semantics in extern function definitions has been updated from _semantic to _builder to align with API changes. **Fix CI Errors** - The new logic inspects the RPATH of libtorch_nvshmem.so to find the NVSHMEM device library, preventing CI tests from being skipped. - Added a decorator to run NVSHMEM tests only on H100s (compatible hardware) **Peer Rank Calculation Fix** - The peer calculation in test_nvshmem_triton.py was changed from peer = (world_size - 1) - rank to peer = 1 - rank. Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159701 Approved by: https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215 --- test/distributed/test_nvshmem_triton.py | 42 ++++-- .../_symmetric_memory/_nvshmem_triton.py | 137 ++++++++++++------ 2 files changed, 125 insertions(+), 54 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index b0f29c0f05cb5..a58fe9638b2cc 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -16,10 +16,10 @@ skip_but_pass_in_sandcastle_if, skipIfRocm, ) -from torch.testing._internal.inductor_utils import requires_triton +from torch.testing._internal.inductor_utils import IS_H100, requires_triton -# Decorator +# Decorators def requires_nvshmem(): return skip_but_pass_in_sandcastle_if( not symm_mem.is_nvshmem_available(), @@ -27,6 +27,13 @@ def requires_nvshmem(): ) +def requires_h100(): + return skip_but_pass_in_sandcastle_if( + not IS_H100, + "NVSHMEM requires H100. Skipping test on non-H100 GPU.", + ) + + # So that tests are written in device-agnostic way device_type = "cuda" device_module = torch.get_device_module(device_type) @@ -276,6 +283,7 @@ def device(self) -> torch.device: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_put(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -297,7 +305,7 @@ def test_triton_put(self) -> None: inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) - peer = (self.world_size - 1) - rank + peer = 1 - rank if rank == 0: dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] @@ -317,6 +325,7 @@ def test_triton_put(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_get(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -336,7 +345,7 @@ def test_triton_get(self) -> None: inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) dist.barrier() - peer = (self.world_size - 1) - rank + peer = 1 - rank if rank == 1: # Rank 1 gets data from rank 0 dst_ptr = out_hdl.buffer_ptrs[rank] @@ -355,6 +364,7 @@ def test_triton_get(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_get_ring(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -397,6 +407,7 @@ def test_triton_get_ring(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_put_signal_set(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -422,7 +433,7 @@ def test_triton_put_signal_set(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = (self.world_size - 1) - rank + peer = 1 - rank NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set SIGNAL_VAL = 1 # Signal completion value NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until @@ -462,6 +473,7 @@ def test_triton_put_signal_set(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_put_signal_add(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -487,7 +499,7 @@ def test_triton_put_signal_add(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = (self.world_size - 1) - rank + peer = 1 - rank NVSHMEM_SIGNAL_ADD = 5 # atomic add operation SIGNAL_VAL = 16 # val + NVSHMEM_SIGNAL_ADD NVSHMEM_CMP_EQ = 0 @@ -525,6 +537,7 @@ def test_triton_put_signal_add(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_wait_until(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -534,7 +547,7 @@ def test_triton_wait_until(self) -> None: symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = (self.world_size - 1) - rank + peer = 1 - rank NVSHMEM_CMP_EQ = 0 # equal comparison FLAG_INITIAL_VALUE = 0 FLAG_FINAL_VALUE = 42 @@ -583,6 +596,7 @@ def test_triton_wait_until(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_signal_wait_until(self) -> None: self._init_device() # Enable NVSHMEM for Triton @@ -590,7 +604,7 @@ def test_triton_signal_wait_until(self) -> None: group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = (self.world_size - 1) - rank + peer = 1 - rank # NVSHMEM constants from documentation NVSHMEM_CMP_EQ = 0 # equal comparison @@ -651,6 +665,7 @@ def test_triton_signal_wait_until(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_fence(self) -> None: """ Rank 0 performs two put operations into Rank 1's buffers with a fence @@ -667,7 +682,7 @@ def test_triton_fence(self) -> None: group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = (self.world_size - 1) - rank + peer = 1 - rank # Message configuration msg_size_bytes = 8 dtype = torch.int8 @@ -735,6 +750,7 @@ def test_triton_fence(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_quiet(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -755,7 +771,7 @@ def test_triton_quiet(self) -> None: out_hdl = symm_mem.rendezvous(out, group=group_name) # Use signal pad as completion flag flag_val = 42 - peer = (self.world_size - 1) - rank + peer = 1 - rank NVSHMEM_CMP_EQ = 0 if rank == 0: @@ -793,6 +809,7 @@ def test_triton_quiet(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_barrier(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -832,6 +849,7 @@ def test_triton_barrier(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_sync(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -870,6 +888,7 @@ def test_triton_sync(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_alltoall(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -919,6 +938,7 @@ def test_triton_alltoall(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_broadcast(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -967,6 +987,7 @@ def test_triton_broadcast(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_sum_reduce(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -1013,6 +1034,7 @@ def test_triton_sum_reduce(self) -> None: @skipIfRocm @requires_triton() + @requires_h100() def test_triton_minmax_reduce(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 3e0ee87611304..b4c2cebf16ce2 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -1,10 +1,58 @@ import os +import subprocess import sysconfig from typing import Optional from torch.utils._triton import has_triton +def _find_nvshmem_device_library() -> str: + paths = [os.path.join(sysconfig.get_path("purelib"), "nvidia", "nvshmem", "lib")] + + # Add common system installation paths + common_paths = [ + "/usr/local/lib", + "/usr/lib", + "/opt/nvidia/nvshmem/lib", + ] + paths.extend(common_paths) + + try: + import torch + + torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib") + so_path = os.path.join(torch_lib, "libtorch_nvshmem.so") + + if os.path.exists(so_path): + try: + result = subprocess.run( + ["readelf", "-d", so_path], + capture_output=True, + text=True, + check=True, + ) + + for line in result.stdout.splitlines(): + if ("RPATH" in line or "RUNPATH" in line) and "[" in line: + rpath = line.split("[", 1)[1].split("]", 1)[0] + for p in rpath.split(":"): + p = p.strip().replace("$ORIGIN", torch_lib) + if p and p not in paths: + paths.append(p) + except subprocess.CalledProcessError: + pass + + except ImportError: + pass + + for path in paths: + device_lib = os.path.join(path, "libnvshmem_device.bc") + if os.path.exists(device_lib): + return device_lib + + raise RuntimeError(f"NVSHMEM device library not found. Searched: {paths}") + + def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: """ Enable NVSHMEM device functions for Triton. It performs a NVSHMEM @@ -19,18 +67,19 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: dict[str, str]: A dictionary containing the NVSHMEM device library name and path. """ - from triton.runtime.jit import JITFunction + import triton from torch._C._distributed_c10d import _nvshmemx_cumodule_init - # Detect NVSHMEM device library path from python library path - if lib_dir is None: - py_lib_path = sysconfig.get_path("purelib") - lib_dir = py_lib_path + "/nvidia/nvshmem/lib" - - lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") - if not os.path.exists(lib_path): - raise RuntimeError("NVSHMEM device library not found") + if lib_dir is not None: + lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") + if not os.path.exists(lib_path): + raise RuntimeError( + f"NVSHMEM device library not found at specified path: {lib_path}" + ) + else: + # Otherwise, search for the library automatically. + lib_path = _find_nvshmem_device_library() extern_libs = {"libnvshmem_device": lib_path} @@ -45,7 +94,7 @@ def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] _nvshmemx_cumodule_init(kernel.module) # Register the function as a post-compile hook - JITFunction.compiled_hook = nvshmem_init_hook + triton.knobs.runtime.jit_post_compile_hook = nvshmem_init_hook # Return to user so that they can use it in Triton kernel invocation return extern_libs @@ -56,7 +105,7 @@ def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] # RMA Operations (mem-based APIs - sizes in bytes) @core.extern - def putmem_block(dst, src, size_bytes, pe, _builder=None): # type: ignore[no-untyped-def] + def putmem_block(dst, src, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] """Put data to remote PE. size_bytes specifies the size in bytes.""" return core.extern_elementwise( "", @@ -71,11 +120,11 @@ def putmem_block(dst, src, size_bytes, pe, _builder=None): # type: ignore[no-un ): ("nvshmemx_putmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def getmem_block(dst, src, size_bytes, pe, _builder=None): # type: ignore[no-untyped-def] + def getmem_block(dst, src, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] """Get data from remote PE. size_bytes specifies the size in bytes.""" return core.extern_elementwise( "", @@ -90,7 +139,7 @@ def getmem_block(dst, src, size_bytes, pe, _builder=None): # type: ignore[no-un ): ("nvshmemx_getmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern @@ -102,7 +151,7 @@ def putmem_signal_block( # type: ignore[no-untyped-def] signal, sig_op, pe, - _builder=None, + _semantic=None, ): # type: ignore[no-untyped-def] """Put data to remote PE with signal. size_bytes specifies the size in bytes.""" return core.extern_elementwise( @@ -121,12 +170,12 @@ def putmem_signal_block( # type: ignore[no-untyped-def] ): ("nvshmemx_putmem_signal_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) # Wait and Signal Operations @core.extern - def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + def wait_until(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] """Wait until a condition is met on a symmetric variable.""" return core.extern_elementwise( "", @@ -140,11 +189,11 @@ def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-de ): ("nvshmem_longlong_wait_until", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + def signal_wait_until(sig_addr, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] """Wait until a signal variable meets a condition.""" return core.extern_elementwise( "", @@ -158,11 +207,11 @@ def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no ): ("nvshmem_signal_wait_until", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-untyped-def] + def signal_op(sig_addr, signal, sig_op, pe, _semantic=None): # type: ignore[no-untyped-def] """Perform a signal operation on a remote PE.""" return core.extern_elementwise( "", @@ -177,12 +226,12 @@ def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-u ): ("nvshmemx_signal_op", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) # Memory Ordering Operations @core.extern - def fence(_builder=None): # type: ignore[no-untyped-def] + def fence(_semantic=None): # type: ignore[no-untyped-def] """Ensure ordering of put operations.""" return core.extern_elementwise( "", @@ -192,11 +241,11 @@ def fence(_builder=None): # type: ignore[no-untyped-def] (): ("nvshmem_fence", core.dtype("int32")), }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def quiet(_builder=None): # type: ignore[no-untyped-def] + def quiet(_semantic=None): # type: ignore[no-untyped-def] """Wait for completion of all outstanding put operations.""" return core.extern_elementwise( "", @@ -206,12 +255,12 @@ def quiet(_builder=None): # type: ignore[no-untyped-def] (): ("nvshmem_quiet", core.dtype("int32")), }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) # PE Information Operations @core.extern - def my_pe(_builder=None): # type: ignore[no-untyped-def] + def my_pe(_semantic=None): # type: ignore[no-untyped-def] """Get the PE number of the calling PE.""" return core.extern_elementwise( "", @@ -219,11 +268,11 @@ def my_pe(_builder=None): # type: ignore[no-untyped-def] [], {(): ("nvshmem_my_pe", core.dtype("int32"))}, is_pure=True, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def n_pes(_builder=None): # type: ignore[no-untyped-def] + def n_pes(_semantic=None): # type: ignore[no-untyped-def] """Get the total number of PEs.""" return core.extern_elementwise( "", @@ -231,12 +280,12 @@ def n_pes(_builder=None): # type: ignore[no-untyped-def] [], {(): ("nvshmem_n_pes", core.dtype("int32"))}, is_pure=True, - _builder=_builder, + _semantic=_semantic, ) # Synchronization Operations @core.extern - def barrier_all(_builder=None): # type: ignore[no-untyped-def] + def barrier_all(_semantic=None): # type: ignore[no-untyped-def] """Synchronize all PEs.""" return core.extern_elementwise( "", @@ -244,11 +293,11 @@ def barrier_all(_builder=None): # type: ignore[no-untyped-def] [], {(): ("nvshmem_barrier_all", core.dtype("int32"))}, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def sync_all(_builder=None): # type: ignore[no-untyped-def] + def sync_all(_semantic=None): # type: ignore[no-untyped-def] """Synchronize all PEs (lightweight version, does not ensure completion of remote memory updates).""" return core.extern_elementwise( "", @@ -256,12 +305,12 @@ def sync_all(_builder=None): # type: ignore[no-untyped-def] [], {(): ("nvshmem_sync_all", core.dtype("int32"))}, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) # Collective Operations (mem-based APIs - sizes in bytes) @core.extern - def alltoallmem_block(team, dest, source, size_bytes, _builder=None): # type: ignore[no-untyped-def] + def alltoallmem_block(team, dest, source, size_bytes, _semantic=None): # type: ignore[no-untyped-def] """Perform alltoall operation on symmetric memory. size_bytes specifies the number of bytes to exchange per PE.""" return core.extern_elementwise( "", @@ -276,11 +325,11 @@ def alltoallmem_block(team, dest, source, size_bytes, _builder=None): # type: i ): ("nvshmemx_alltoallmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def broadcastmem_block(team, dest, source, size_bytes, pe_root, _builder=None): # type: ignore[no-untyped-def] + def broadcastmem_block(team, dest, source, size_bytes, pe_root, _semantic=None): # type: ignore[no-untyped-def] """Broadcast data from a root PE to all other PEs in a team. size_bytes specifies the size in bytes.""" return core.extern_elementwise( "", @@ -296,12 +345,12 @@ def broadcastmem_block(team, dest, source, size_bytes, pe_root, _builder=None): ): ("nvshmemx_broadcastmem_block", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) # Reduction Operations @core.extern - def sum_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] + def sum_reduce(team, dest, source, nreduce, _semantic=None): # type: ignore[no-untyped-def] """Sum reduction for int64. nreduce is number of elements in the dest and source arrays.""" return core.extern_elementwise( "", @@ -316,11 +365,11 @@ def sum_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-u ): ("nvshmem_int64_sum_reduce", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def max_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] + def max_reduce(team, dest, source, nreduce, _semantic=None): # type: ignore[no-untyped-def] """Max reduction for int64. nreduce is number of elements in the dest and source arrays.""" return core.extern_elementwise( "", @@ -335,11 +384,11 @@ def max_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-u ): ("nvshmem_int64_max_reduce", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) @core.extern - def min_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-untyped-def] + def min_reduce(team, dest, source, nreduce, _semantic=None): # type: ignore[no-untyped-def] """Min reduction for int64. nreduce is number of elements in the dest and source arrays.""" return core.extern_elementwise( "", @@ -354,5 +403,5 @@ def min_reduce(team, dest, source, nreduce, _builder=None): # type: ignore[no-u ): ("nvshmem_int64_min_reduce", core.dtype("int32")) }, is_pure=False, - _builder=_builder, + _semantic=_semantic, ) From 1c881440f4c3ae46d409fa2206029e219b2e08c8 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Wed, 6 Aug 2025 16:58:28 -0700 Subject: [PATCH 0137/1424] [SymmMem] Initialize NVSHMEM module only for kernels that have nvshmem in their name (#159734) Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159734 Approved by: https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701 --- test/distributed/test_nvshmem_triton.py | 80 +++++++++---------- .../_symmetric_memory/_nvshmem_triton.py | 22 +++-- 2 files changed, 56 insertions(+), 46 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index a58fe9638b2cc..a02d8b58110e0 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -41,7 +41,7 @@ def requires_h100(): # Shared Triton JIT kernels @triton.jit -def putmem_block_kernel( +def nvshmem_putmem_block_kernel( dst_ptr, src_ptr, size_bytes, @@ -51,7 +51,7 @@ def putmem_block_kernel( @triton.jit -def getmem_block_kernel( +def nvshmem_getmem_block_kernel( dst_ptr, src_ptr, size_bytes, @@ -61,7 +61,7 @@ def getmem_block_kernel( @triton.jit -def putmem_signal_block_kernel( +def nvshmem_putmem_signal_block_kernel( dst_ptr, src_ptr, size_bytes, @@ -76,12 +76,12 @@ def putmem_signal_block_kernel( @triton.jit -def signal_wait_until_kernel(sig_ptr, cmp_op, cmp_val): +def nvshmem_signal_wait_until_kernel(sig_ptr, cmp_op, cmp_val): nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val) @triton.jit -def signal_op_kernel( +def nvshmem_signal_op_kernel( sig_addr, signal, sig_op, @@ -91,7 +91,7 @@ def signal_op_kernel( @triton.jit -def wait_until_kernel( +def nvshmem_wait_until_kernel( ivar_ptr, cmp_op, cmp_val, @@ -100,12 +100,12 @@ def wait_until_kernel( @triton.jit -def fence_kernel(): +def nvshmem_fence_kernel(): nvshmem.fence() @triton.jit -def put_with_fence_kernel( +def nvshmem_put_with_fence_kernel( dst_ptr1, dst_ptr2, src_ptr1, @@ -128,7 +128,7 @@ def put_with_fence_kernel( @triton.jit -def put_with_quiet_kernel( +def nvshmem_put_with_quiet_kernel( dst_ptr, src_ptr, flag_dst_ptr, @@ -146,7 +146,7 @@ def put_with_quiet_kernel( @triton.jit -def barrier_test_kernel( +def nvshmem_barrier_test_kernel( dst_ptr, src_ptr, size_bytes, @@ -180,12 +180,12 @@ def barrier_test_kernel( @triton.jit -def barrier_all_kernel(): +def nvshmem_barrier_all_kernel(): nvshmem.barrier_all() @triton.jit -def sync_test_kernel( +def nvshmem_sync_test_kernel( dst_ptr, src_ptr, size_bytes, @@ -216,7 +216,7 @@ def sync_test_kernel( @triton.jit -def alltoallmem_block_kernel( +def nvshmem_alltoallmem_block_kernel( team_handle, dest_ptr, src_ptr, @@ -226,7 +226,7 @@ def alltoallmem_block_kernel( @triton.jit -def broadcastmem_block_kernel( +def nvshmem_broadcastmem_block_kernel( team_handle, dest_ptr, src_ptr, @@ -237,7 +237,7 @@ def broadcastmem_block_kernel( @triton.jit -def sum_reduce_kernel( +def nvshmem_sum_reduce_kernel( team_handle, dest_ptr, src_ptr, @@ -247,7 +247,7 @@ def sum_reduce_kernel( @triton.jit -def max_reduce_kernel( +def nvshmem_max_reduce_kernel( team_handle, dest_ptr, src_ptr, @@ -257,7 +257,7 @@ def max_reduce_kernel( @triton.jit -def min_reduce_kernel( +def nvshmem_min_reduce_kernel( team_handle, dest_ptr, src_ptr, @@ -309,7 +309,7 @@ def test_triton_put(self) -> None: if rank == 0: dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] - putmem_block_kernel[(1, 1, 1)]( + nvshmem_putmem_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, size_bytes=msg_size_bytes, @@ -350,7 +350,7 @@ def test_triton_get(self) -> None: # Rank 1 gets data from rank 0 dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] - getmem_block_kernel[(1, 1, 1)]( + nvshmem_getmem_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, size_bytes=msg_size_bytes, @@ -392,7 +392,7 @@ def test_triton_get_ring(self) -> None: # All ranks execute the get operation dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] - getmem_block_kernel[(1, 1, 1)]( + nvshmem_getmem_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, size_bytes=msg_size_bytes, @@ -443,7 +443,7 @@ def test_triton_put_signal_set(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - putmem_signal_block_kernel[(1, 1, 1)]( + nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, size_bytes=msg_size_bytes, @@ -457,7 +457,7 @@ def test_triton_put_signal_set(self) -> None: if rank == 1: # Wait until signal flag is set by Rank 0 sig_ptr_local = out_hdl.signal_pad_ptrs[rank] - signal_wait_until_kernel[(1,)]( + nvshmem_signal_wait_until_kernel[(1,)]( sig_ptr_local, cmp_op=NVSHMEM_CMP_EQ, cmp_val=SIGNAL_VAL, @@ -509,7 +509,7 @@ def test_triton_put_signal_add(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - putmem_signal_block_kernel[(1, 1, 1)]( + nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, size_bytes=msg_size_bytes, @@ -522,7 +522,7 @@ def test_triton_put_signal_add(self) -> None: if rank == 1: sig_ptr_local = out_hdl.signal_pad_ptrs[rank] - signal_wait_until_kernel[(1, 1, 1)]( + nvshmem_signal_wait_until_kernel[(1, 1, 1)]( sig_ptr_local, cmp_op=NVSHMEM_CMP_EQ, cmp_val=SIGNAL_VAL, @@ -558,12 +558,12 @@ def test_triton_wait_until(self) -> None: ) flag_hdl = symm_mem.rendezvous(flag, group=group_name) - barrier_all_kernel[(1,)](extern_libs=nvshmem_lib) + nvshmem_barrier_all_kernel[(1,)](extern_libs=nvshmem_lib) if rank == 0: # Rank 0 (the waiter) ivar_ptr = flag_hdl.buffer_ptrs[rank] - wait_until_kernel[(1,)]( + nvshmem_wait_until_kernel[(1,)]( ivar_ptr, cmp_op=NVSHMEM_CMP_EQ, cmp_val=FLAG_FINAL_VALUE, @@ -586,7 +586,7 @@ def test_triton_wait_until(self) -> None: dst_ptr = flag_hdl.buffer_ptrs[rank] # Launch a kernel to put the value to Rank 0. - putmem_block_kernel[(1,)]( + nvshmem_putmem_block_kernel[(1,)]( dst_ptr, # Destination pointer on the remote PE val_to_put.data_ptr(), # Source data pointer (local) size_bytes=8, # Size of one int64 @@ -633,7 +633,7 @@ def test_triton_signal_wait_until(self) -> None: dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] sig_ptr = out_hdl.signal_pad_ptrs[peer] - putmem_signal_block_kernel[(1, 1, 1)]( + nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( dst_ptr, src_ptr, size_bytes=msg_size_bytes, @@ -646,7 +646,7 @@ def test_triton_signal_wait_until(self) -> None: elif rank == 1: # Consumer (rank 1): Waits on the signal variable using `signal_wait_until`. sig_ptr = out_hdl.signal_pad_ptrs[rank] - signal_wait_until_kernel[(1, 1, 1)]( + nvshmem_signal_wait_until_kernel[(1, 1, 1)]( sig_ptr, cmp_op=NVSHMEM_CMP_EQ, cmp_val=COMPLETION_FLAG_VAL, @@ -716,7 +716,7 @@ def test_triton_fence(self) -> None: flag_ptr = out2_hdl.signal_pad_ptrs[rank] flag_src_ptr = flag_update_val.data_ptr() - put_with_fence_kernel[(1, 1, 1)]( + nvshmem_put_with_fence_kernel[(1, 1, 1)]( dst_ptr1, dst_ptr2, src_ptr1, @@ -730,7 +730,7 @@ def test_triton_fence(self) -> None: elif rank == 1: # Wait until flag is set by Rank 0. ivar_ptr = out2_hdl.signal_pad_ptrs[rank] - wait_until_kernel[(1, 1, 1)]( + nvshmem_wait_until_kernel[(1, 1, 1)]( ivar_ptr, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, @@ -777,7 +777,7 @@ def test_triton_quiet(self) -> None: if rank == 0: # Rank 0 waits for flag from Rank 1 ivar_ptr = out_hdl.signal_pad_ptrs[rank] - wait_until_kernel[(1, 1, 1)]( + nvshmem_wait_until_kernel[(1, 1, 1)]( ivar_ptr, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, @@ -797,7 +797,7 @@ def test_triton_quiet(self) -> None: [flag_val], dtype=torch.int64, device=self.device ) flag_src_ptr = flag_update_val.data_ptr() - put_with_quiet_kernel[(1, 1, 1)]( + nvshmem_put_with_quiet_kernel[(1, 1, 1)]( dst_ptr, src_ptr, flag_dst_ptr, @@ -826,7 +826,7 @@ def test_triton_barrier(self) -> None: src_hdl = symm_mem.rendezvous(src, group=group_name) dst_hdl = symm_mem.rendezvous(dst, group=group_name) # Launch kernel with cooperative grid - barrier_test_kernel[(1,)]( + nvshmem_barrier_test_kernel[(1,)]( dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], size_bytes=size_bytes, @@ -866,7 +866,7 @@ def test_triton_sync(self) -> None: src_hdl = symm_mem.rendezvous(src, group=group_name) dst_hdl = symm_mem.rendezvous(dst, group=group_name) # Launch kernel with cooperative grid - sync_test_kernel[(1,)]( + nvshmem_sync_test_kernel[(1,)]( dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], size_bytes=size_bytes, @@ -918,7 +918,7 @@ def test_triton_alltoall(self) -> None: dist.barrier() team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0 # Launch the kernel - alltoallmem_block_kernel[(1,)]( + nvshmem_alltoallmem_block_kernel[(1,)]( team_handle, dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], @@ -968,7 +968,7 @@ def test_triton_broadcast(self) -> None: dist.barrier() # Execute broadcast team_handle = 0 # NVSHMEM_TEAM_WORLD - broadcastmem_block_kernel[(1,)]( + nvshmem_broadcastmem_block_kernel[(1,)]( team_handle, dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], @@ -1017,7 +1017,7 @@ def test_triton_sum_reduce(self) -> None: dist.barrier() # Execute reduction team_handle = 0 # NVSHMEM_TEAM_WORLD - sum_reduce_kernel[(1,)]( + nvshmem_sum_reduce_kernel[(1,)]( team_handle, dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], @@ -1081,7 +1081,7 @@ def test_triton_minmax_reduce(self) -> None: dist.barrier() # Execute MIN reduction team_handle = 0 - min_reduce_kernel[(1,)]( + nvshmem_min_reduce_kernel[(1,)]( team_handle, dst_min_hdl.buffer_ptrs[rank], src_min_hdl.buffer_ptrs[rank], @@ -1090,7 +1090,7 @@ def test_triton_minmax_reduce(self) -> None: launch_cooperative_grid=True, ) # Execute MAX reduction - max_reduce_kernel[(1,)]( + nvshmem_max_reduce_kernel[(1,)]( team_handle, dst_max_hdl.buffer_ptrs[rank], src_max_hdl.buffer_ptrs[rank], diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index b4c2cebf16ce2..ae09e3e05ed39 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -58,6 +58,12 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: Enable NVSHMEM device functions for Triton. It performs a NVSHMEM device-side initialization on the kernel module created by Triton. + This function sets a global hook that initializes NVSHMEM for Triton + kernels. To avoid unnecessary initializations, the hook only acts on + kernels that have "nvshmem" in their function name. Therefore, it is + required that all Triton kernels using NVSHMEM primitives follow this + naming convention. + Args: lib_dir (Optional[str]): The directory where the NVSHMEM device library is located. If not provided, it will use the default path where NVSHMEM @@ -85,13 +91,17 @@ def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: # A hook function to initialize NVSHMEM in Triton def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] - key = kwargs["key"] - device = kwargs["compile"]["device"] jit_function = kwargs["fn"].jit_function - kernel_cache, _, _, _ = jit_function.device_caches[device] - kernel = kernel_cache.get(key, None) - kernel.run - _nvshmemx_cumodule_init(kernel.module) + # Only initialize NVSHMEM module for kernels containing "nvshmem" in their name + if "nvshmem" in jit_function.fn.__name__: + key = kwargs["key"] + device = kwargs["compile"]["device"] + jit_function = kwargs["fn"].jit_function + kernel_cache, _, _, _ = jit_function.device_caches[device] + kernel = kernel_cache.get(key, None) + if kernel is not None: + kernel.run + _nvshmemx_cumodule_init(kernel.module) # Register the function as a post-compile hook triton.knobs.runtime.jit_post_compile_hook = nvshmem_init_hook From bfff2e359226be4e48216ca4ec80415eb33ca364 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Thu, 7 Aug 2025 18:40:15 -0700 Subject: [PATCH 0138/1424] =?UTF-8?q?[SymmMem]=20Refactor=20NVSHMEM=20Redu?= =?UTF-8?q?ction=20API=20to=20be=20more=20ergonomic=20with=20automatic=20d?= =?UTF-8?q?type=E2=80=90based=20dispatch=20(#159755)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64). It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce Pull Request resolved: https://github.com/pytorch/pytorch/pull/159755 Approved by: https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734 --- test/distributed/test_nvshmem_triton.py | 156 ++++++++++++++---- .../_symmetric_memory/_nvshmem_triton.py | 135 +++++++++------ 2 files changed, 214 insertions(+), 77 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index a02d8b58110e0..5a722c0bba34d 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -12,6 +12,7 @@ from torch.testing._internal.common_distributed import MultiProcContinousTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + parametrize, run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, @@ -237,33 +238,15 @@ def nvshmem_broadcastmem_block_kernel( @triton.jit -def nvshmem_sum_reduce_kernel( +def nvshmem_reduce_kernel( team_handle, dest_ptr, src_ptr, nreduce, + operation: tl.constexpr, + dtype_id: tl.constexpr, ): - nvshmem.sum_reduce(team_handle, dest_ptr, src_ptr, nreduce) - - -@triton.jit -def nvshmem_max_reduce_kernel( - team_handle, - dest_ptr, - src_ptr, - nreduce, -): - nvshmem.max_reduce(team_handle, dest_ptr, src_ptr, nreduce) - - -@triton.jit -def nvshmem_min_reduce_kernel( - team_handle, - dest_ptr, - src_ptr, - nreduce, -): - nvshmem.min_reduce(team_handle, dest_ptr, src_ptr, nreduce) + nvshmem.reduce(team_handle, dest_ptr, src_ptr, nreduce, operation, dtype_id) @instantiate_parametrized_tests @@ -988,7 +971,21 @@ def test_triton_broadcast(self) -> None: @skipIfRocm @requires_triton() @requires_h100() - def test_triton_sum_reduce(self) -> None: + @parametrize( + "dtype", + [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + ], + ) + def test_triton_sum_reduce(self, dtype) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() @@ -998,7 +995,6 @@ def test_triton_sum_reduce(self) -> None: rank = self.rank # Configuration nreduce = 3 # number of separate reductions - dtype = torch.int64 # Source buffer - each rank contributes different values src = symm_mem.empty(nreduce, dtype=dtype, device=self.device) for i in range(nreduce): @@ -1013,20 +1009,26 @@ def test_triton_sum_reduce(self) -> None: # Sum across all ranks: sum((rank+1)*(i+1) for rank in range(world_size)) total = sum((r + 1) * (i + 1) for r in range(world_size)) expected.append(total) + # Synchronize before reduction dist.barrier() - # Execute reduction + + # Execute sum reduction across all ranks team_handle = 0 # NVSHMEM_TEAM_WORLD - nvshmem_sum_reduce_kernel[(1,)]( + nvshmem_reduce_kernel[(1,)]( team_handle, dst_hdl.buffer_ptrs[rank], src_hdl.buffer_ptrs[rank], nreduce, + operation="sum", + dtype_id=src.dtype, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) + # Synchronize after reduction dist.barrier() + # Verify results torch.testing.assert_close( dst, torch.tensor(expected, device=self.device, dtype=dtype) @@ -1035,7 +1037,20 @@ def test_triton_sum_reduce(self) -> None: @skipIfRocm @requires_triton() @requires_h100() - def test_triton_minmax_reduce(self) -> None: + @parametrize( + "dtype", + [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + ], + ) + def test_triton_minmax_reduce(self, dtype) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() @@ -1045,7 +1060,6 @@ def test_triton_minmax_reduce(self) -> None: rank = self.rank # Configuration nreduce = 2 # number of values to reduce - dtype = torch.int64 # Source buffers for min and max src_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device) src_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device) @@ -1081,20 +1095,24 @@ def test_triton_minmax_reduce(self) -> None: dist.barrier() # Execute MIN reduction team_handle = 0 - nvshmem_min_reduce_kernel[(1,)]( + nvshmem_reduce_kernel[(1,)]( team_handle, dst_min_hdl.buffer_ptrs[rank], src_min_hdl.buffer_ptrs[rank], nreduce, + operation="min", + dtype_id=src_min.dtype, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) # Execute MAX reduction - nvshmem_max_reduce_kernel[(1,)]( + nvshmem_reduce_kernel[(1,)]( team_handle, dst_max_hdl.buffer_ptrs[rank], src_max_hdl.buffer_ptrs[rank], nreduce, + operation="max", + dtype_id=src_max.dtype, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) @@ -1107,6 +1125,84 @@ def test_triton_minmax_reduce(self) -> None: dst_max, torch.tensor(expected_max, device=self.device, dtype=dtype) ) + @skipIfRocm + @requires_triton() + @requires_h100() + @parametrize( + "dtype", + [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + ], + ) + def test_triton_prod_reduce(self, dtype) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.distributed_c10d._get_default_group().group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Configuration + nreduce = 3 # number of separate reductions + # Source buffer - each rank contributes different values + # Use very small values to avoid overflow, especially for small integer types + src = symm_mem.empty(nreduce, dtype=dtype, device=self.device) + for i in range(nreduce): + # Use values that won't overflow even for int8: all values 1 or 2 + if i == 0: + # For first element: rank 0,2,4... gets 1, rank 1,3,5... gets 2 + src[i] = 1 if rank % 2 == 0 else 2 + elif i == 1: + # For second element: all get 1 (no multiplication effect) + src[i] = 1 + else: + # For third element: rank 0,1 get 1, rank 2,3 get 2, etc. (groups of 2) + src[i] = 1 if (rank // 2) % 2 == 0 else 2 + # Destination buffer + dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Calculate expected results + vals = torch.empty(nreduce, world_size, dtype=dtype) + vals[0, ::2] = 1 + vals[0, 1::2] = 2 + vals[1] = 1 + vals2 = vals[2].view(-1, 2, 2) + vals2[:, 0] = 1 + vals2[:, 1] = 2 + expected = vals.prod(-1).tolist() + + # Synchronize before reduction + dist.barrier() + + # Execute product reduction across all ranks + team_handle = 0 # NVSHMEM_TEAM_WORLD + nvshmem_reduce_kernel[(1,)]( + team_handle, + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + nreduce, + operation="prod", + dtype_id=src.dtype, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + + # Synchronize after reduction + dist.barrier() + + # Verify results + torch.testing.assert_close( + dst, torch.tensor(expected, device=self.device, dtype=dtype) + ) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index ae09e3e05ed39..10f4d27c14389 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -358,60 +358,101 @@ def broadcastmem_block(team, dest, source, size_bytes, pe_root, _semantic=None): _semantic=_semantic, ) - # Reduction Operations - @core.extern - def sum_reduce(team, dest, source, nreduce, _semantic=None): # type: ignore[no-untyped-def] - """Sum reduction for int64. nreduce is number of elements in the dest and source arrays.""" - return core.extern_elementwise( - "", - "", - [team, dest, source, nreduce], - { - ( - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - ): ("nvshmem_int64_sum_reduce", core.dtype("int32")) - }, - is_pure=False, - _semantic=_semantic, - ) + # Reduction Operation + @core.extern # type: ignore[misc] + def reduce(team, dest, source, nreduce, operation: str, dtype_id, _semantic=None): # type: ignore[no-untyped-def] + """ + Performs a collective reduction operation on symmetric data across a team of PEs. + + This function provides a generic interface to NVSHMEM reduction operations, + automatically selecting the appropriate NVSHMEM function based on the data type + and operation specified. + Args: + team (int64): The team handle (0 for NVSHMEM_TEAM_WORLD). + dest (pointer): Destination pointer where reduction results are stored. + source (pointer): Source pointer containing data to be reduced. + nreduce (int64): Number of elements to reduce. + operation (str): Reduction operation ("sum", "max", "min", "prod"). + dtype_id: Data type specification - accepts torch.dtype, tl.dtype, str, or constexpr. + _semantic: Optional semantic information for Triton compilation. + + Raises: + ValueError: If the operation is not supported. + TypeError: If the data type is not supported. + + Example: + nvshmem.reduce(0, dest_ptr, src_ptr, 100, "sum", torch.float32) + """ + # Mapping from PyTorch/Triton dtype names to NVSHMEM typenames + DTYPE_TO_NVSHMEM_MAP = { + "int8": "int8", + "int16": "int16", + "int32": "int32", + "int64": "int64", + "uint8": "uint8", + "uint16": "uint16", + "uint32": "uint32", + "uint64": "uint64", + "float16": "half", + "bfloat16": "bfloat16", + "float32": "float", + "float64": "double", + } + + # Extract operation name from constexpr if needed + op_name = operation.value if hasattr(operation, "value") else operation + + # Normalize dtype_id to a canonical string name + # Handle different input formats: tl.dtype, torch.dtype, str, constexpr[dtype] + if hasattr(dtype_id, "name"): + # Triton language dtype (e.g., tl.float32) + dtype_name = dtype_id.name + elif isinstance(dtype_id, str): + # Already a plain string name + dtype_name = dtype_id + elif hasattr(dtype_id, "value"): + # Constexpr wrapper around a dtype + inner_value = dtype_id.value + if hasattr(inner_value, "name"): + # Triton dtype inside constexpr + dtype_name = inner_value.name + else: + # PyTorch dtype inside constexpr + dtype_name = str(inner_value).replace("torch.", "") + else: + # PyTorch dtype (e.g., torch.float32) + dtype_name = str(dtype_id).replace("torch.", "") + + # Validate operation is supported + supported_ops = {"sum", "max", "min", "prod"} + if op_name not in supported_ops: + raise ValueError( + f"Unsupported reduction operation: '{op_name}'. Supported ops are {supported_ops}" + ) - @core.extern - def max_reduce(team, dest, source, nreduce, _semantic=None): # type: ignore[no-untyped-def] - """Max reduction for int64. nreduce is number of elements in the dest and source arrays.""" - return core.extern_elementwise( - "", - "", - [team, dest, source, nreduce], - { - ( - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - ): ("nvshmem_int64_max_reduce", core.dtype("int32")) - }, - is_pure=False, - _semantic=_semantic, + # Map to NVSHMEM typename and validate dtype is supported + nvshmem_typename = DTYPE_TO_NVSHMEM_MAP.get(dtype_name) + if nvshmem_typename is None: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes are {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + + # Generate NVSHMEM function name + nvshmem_func = f"nvshmem_{nvshmem_typename}_{op_name}_reduce" + + # Define function signature - all parameters are int64 in Triton (they are just ptrs) + signature = ( + core.dtype("int64"), # team handle + core.dtype("int64"), # destination pointer + core.dtype("int64"), # source pointer + core.dtype("int64"), # number of elements ) - @core.extern - def min_reduce(team, dest, source, nreduce, _semantic=None): # type: ignore[no-untyped-def] - """Min reduction for int64. nreduce is number of elements in the dest and source arrays.""" return core.extern_elementwise( "", "", [team, dest, source, nreduce], - { - ( - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - ): ("nvshmem_int64_min_reduce", core.dtype("int32")) - }, + {signature: (nvshmem_func, core.dtype("int32"))}, is_pure=False, _semantic=_semantic, ) From e0d8a315c5da75840bbb4b061fdeb140959b5e60 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Thu, 7 Aug 2025 18:40:16 -0700 Subject: [PATCH 0139/1424] [SymmMem] Add helpful docstrings for all NVSHMEM APIs (#159756) Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159756 Approved by: https://github.com/mandroid6, https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755 --- .../_symmetric_memory/_nvshmem_triton.py | 511 +++++++++++++++++- 1 file changed, 497 insertions(+), 14 deletions(-) diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 10f4d27c14389..0b6eed12b2963 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -116,7 +116,40 @@ def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] # RMA Operations (mem-based APIs - sizes in bytes) @core.extern def putmem_block(dst, src, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] - """Put data to remote PE. size_bytes specifies the size in bytes.""" + """ + Put data to remote PE using block-scoped operation. + + This function copies a contiguous block of data from the local PE's memory + to a symmetric data object on the remote PE. The operation is performed at + thread block scope, meaning all threads in the block cooperate to perform + the transfer efficiently. + + Args: + dst (int64): Symmetric address of the destination data object on the remote PE. + Must be a pointer to symmetric memory allocated via NVSHMEM. + src (int64): Local address of the source data object containing data to be copied. + Can be any valid local memory address. + size_bytes (int64): Number of bytes to transfer. Must be positive. + pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that returns after data has been copied out + of the source array on the local PE. + - The operation does not guarantee delivery to the destination PE. + Use nvshmem_fence() for ordering or nvshmem_quiet() for completion. + - All threads in the block should call this function with the same parameters. + - The source memory remains valid for use immediately after the call returns. + + Example: + ```python + # Transfer 1024 bytes from local buffer to PE 1 + nvshmem.putmem_block(remote_ptr, local_ptr, 1024, 1) + ``` + """ return core.extern_elementwise( "", "", @@ -135,7 +168,39 @@ def putmem_block(dst, src, size_bytes, pe, _semantic=None): # type: ignore[no-u @core.extern def getmem_block(dst, src, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] - """Get data from remote PE. size_bytes specifies the size in bytes.""" + """ + Get data from remote PE using block-scoped operation. + + This function copies a contiguous block of data from a symmetric data object + on the remote PE to the local PE's memory. The operation is performed at + thread block scope, meaning all threads in the block cooperate to perform + the transfer efficiently. + + Args: + dst (int64): Local address of the destination data object to be updated. + Can be any valid local memory address. + src (int64): Symmetric address of the source data object on the remote PE. + Must be a pointer to symmetric memory allocated via NVSHMEM. + size_bytes (int64): Number of bytes to transfer. Must be positive. + pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that returns after data has been delivered + to the destination array on the local PE. + - All threads in the block should call this function with the same parameters. + - The destination data is guaranteed to be available for use after the call returns. + - Provides method for copying contiguous symmetric data from different PE. + + Example: + ``` + # Get 1024 bytes from PE 0 into local buffer + nvshmem.getmem_block(local_ptr, remote_ptr, 1024, 0) + ``` + """ return core.extern_elementwise( "", "", @@ -163,7 +228,46 @@ def putmem_signal_block( # type: ignore[no-untyped-def] pe, _semantic=None, ): # type: ignore[no-untyped-def] - """Put data to remote PE with signal. size_bytes specifies the size in bytes.""" + """ + Put data to remote PE with atomic signal operation using block-scoped operation. + + This function copies data from the local PE to the remote PE and then + atomically updates a signal variable on the remote PE to indicate completion. + This enables efficient point-to-point synchronization between PEs. + + Args: + dst (int64): Symmetric address of the destination data object on the remote PE. + src (int64): Local address of the source data object containing data to be copied. + size_bytes (int64): Number of bytes to transfer. Must be positive. + sig_addr (int64): Symmetric address of the signal variable (uint64_t) on the remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int64): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomic set operation + - NVSHMEM_SIGNAL_ADD (5): Atomic add operation + pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that returns after data has been copied out + of the source array and the signal has been updated on the remote PE. + - The signal update is performed atomically with respect to other signal + operations and synchronization routines. + - The signal variable must be of type uint64_t in symmetric memory. + - Use with nvshmem_signal_wait_until() for synchronization. + + Example: + ``` + # Transfer data and set completion flag to 1 + NVSHMEM_SIGNAL_SET = 0 + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, 1024, sig_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe + ) + ``` + """ return core.extern_elementwise( "", "", @@ -186,7 +290,43 @@ def putmem_signal_block( # type: ignore[no-untyped-def] # Wait and Signal Operations @core.extern def wait_until(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] - """Wait until a condition is met on a symmetric variable.""" + """ + Wait until a condition is met on a symmetric variable. + + This function blocks the calling thread until the value at the specified + symmetric memory location satisfies the given comparison condition. This + provides a mechanism for point-to-point synchronization between PEs. + + Args: + ivar (int64): Symmetric address of the variable to monitor. Must be a + pointer to symmetric memory (typically int64/uint64). + cmp (int64): Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until ivar == cmp_val + - NVSHMEM_CMP_NE (1): Wait until ivar != cmp_val + - NVSHMEM_CMP_GT (2): Wait until ivar > cmp_val + - NVSHMEM_CMP_GE (3): Wait until ivar >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until ivar < cmp_val + - NVSHMEM_CMP_LE (5): Wait until ivar <= cmp_val + cmp_val (int64): Value to compare against. + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that will wait indefinitely until the + condition is satisfied. + - The variable must be in symmetric memory and accessible from other PEs. + - Updates to the variable from remote PEs will eventually become visible. + - Can be used with put operations from other PEs for synchronization. + + Example: + ``` + # Wait until flag becomes 1 (set by another PE) + NVSHMEM_CMP_EQ = 0 + nvshmem.wait_until(flag_ptr, NVSHMEM_CMP_EQ, 1) + ``` + """ return core.extern_elementwise( "", "", @@ -204,7 +344,44 @@ def wait_until(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-d @core.extern def signal_wait_until(sig_addr, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] - """Wait until a signal variable meets a condition.""" + """ + Wait until a signal variable meets a specified condition. + + This function blocks the calling thread until the value at the specified + signal variable satisfies the given comparison condition. Signal variables + are special uint64_t symmetric objects used for efficient synchronization + with signal operations. + + Args: + sig_addr (int64): Symmetric address of the signal variable (uint64_t). + Must be 8-byte aligned symmetric memory. + cmp (int64): Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until signal == cmp_val + - NVSHMEM_CMP_NE (1): Wait until signal != cmp_val + - NVSHMEM_CMP_GT (2): Wait until signal > cmp_val + - NVSHMEM_CMP_GE (3): Wait until signal >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until signal < cmp_val + - NVSHMEM_CMP_LE (5): Wait until signal <= cmp_val + cmp_val (int64): Value to compare against. + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation designed specifically for signal variables. + - Signal variables are updated atomically by putmem_signal operations. + - More efficient than wait_until for signal-based synchronization patterns. + - Ensures the signal update is fully complete before returning. + - Commonly used with putmem_signal_block for producer-consumer patterns. + + Example: + ``` + # Wait for signal to be set to completion value + NVSHMEM_CMP_EQ = 0 + nvshmem.signal_wait_until(signal_ptr, NVSHMEM_CMP_EQ, 42) + ``` + """ return core.extern_elementwise( "", "", @@ -222,7 +399,40 @@ def signal_wait_until(sig_addr, cmp, cmp_val, _semantic=None): # type: ignore[n @core.extern def signal_op(sig_addr, signal, sig_op, pe, _semantic=None): # type: ignore[no-untyped-def] - """Perform a signal operation on a remote PE.""" + """ + Perform an atomic signal operation on a remote PE. + + This function atomically updates a signal variable on the specified remote PE + using the given operation and value. This enables efficient point-to-point + synchronization and notification between PEs. + + Args: + sig_addr (int64): Symmetric address of the signal variable (uint64_t) on the remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int64): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomically set sig_addr = signal + - NVSHMEM_SIGNAL_ADD (5): Atomically set sig_addr += signal + pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a one-sided operation - the remote PE does not need to participate. + - The signal operation is performed atomically on the remote PE. + - Can be used with signal_wait_until() on the remote PE for synchronization. + - Provides low-overhead notification mechanism between PEs. + - The signal variable must be of type uint64_t in symmetric memory. + + Example: + ```python + # Atomically set remote signal to 1 to notify completion + NVSHMEM_SIGNAL_SET = 0 + nvshmem.signal_op(remote_signal_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe) + ``` + """ return core.extern_elementwise( "", "", @@ -242,7 +452,41 @@ def signal_op(sig_addr, signal, sig_op, pe, _semantic=None): # type: ignore[no- # Memory Ordering Operations @core.extern def fence(_semantic=None): # type: ignore[no-untyped-def] - """Ensure ordering of put operations.""" + """ + Ensure ordering of put operations to each remote PE. + + This function provides a memory fence that ensures point-to-point ordering + of remote memory operations. Put operations issued before the fence are + guaranteed to be ordered before put operations issued after the fence, + when targeting the same remote PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This provides weaker ordering guarantees than quiet(). + - Operations to each PE are ordered, but operations to different PEs + may still be reordered relative to each other. + - Does not guarantee completion of operations, only ordering. + - Non-blocking operations are not ordered by fence - use quiet() instead. + - Essential for ensuring correct ordering in communication patterns. + + Memory Ordering Guarantees: + - Put operations before fence() → ordered before → Put operations after fence() + - Ordering is maintained per-destination-PE basis + - Remote PEs can observe the enforced ordering + + Example: + ``` + # Ensure first put completes before second put to same PE + nvshmem.putmem_block(dst1, src1, size, target_pe) + nvshmem.fence() # Enforce ordering + nvshmem.putmem_block(dst2, src2, size, target_pe) + ``` + """ return core.extern_elementwise( "", "", @@ -256,7 +500,41 @@ def fence(_semantic=None): # type: ignore[no-untyped-def] @core.extern def quiet(_semantic=None): # type: ignore[no-untyped-def] - """Wait for completion of all outstanding put operations.""" + """ + Wait for completion of all outstanding put operations. + + This function blocks until all outstanding remote memory operations issued + by the calling PE have completed. It provides stronger guarantees than + fence() by ensuring both ordering and completion of all operations. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that waits for completion. + - Ensures all previous put operations have been delivered to their destinations. + - Provides global ordering - operations to ALL PEs are ordered. + - Required to complete non-blocking operations. + - More expensive than fence() but provides stronger guarantees. + + Memory Ordering Guarantees: + - All put operations before quiet() are completed before any operations after quiet() + - Operations are visible to all PEs as having occurred before subsequent operations + - Both blocking and non-blocking operations are completed + + Example: + ``` + # Ensure all data transfers complete before setting completion flag + nvshmem.putmem_block(data_ptr, src_ptr, data_size, target_pe) + nvshmem.quiet() # Wait for data transfer completion + nvshmem.putmem_block( + flag_ptr, flag_src_ptr, 8, target_pe + ) # Signal completion + ``` + """ return core.extern_elementwise( "", "", @@ -271,7 +549,38 @@ def quiet(_semantic=None): # type: ignore[no-untyped-def] # PE Information Operations @core.extern def my_pe(_semantic=None): # type: ignore[no-untyped-def] - """Get the PE number of the calling PE.""" + """ + Get the PE number of the calling PE. + + This function returns the unique identifier (PE number) of the current + processing element within the NVSHMEM job. PE numbers range from 0 to + nvshmem_n_pes() - 1. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: PE number of the calling PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - This is a pure function that returns the same value throughout execution. + - PE numbering starts from 0 and is contiguous. + - Each PE has a unique identifier within the NVSHMEM job. + - Can be called from both host and device code. + - Essential for implementing PE-specific logic and communication patterns. + + Example: + ``` + # Get current PE number for conditional logic + pe = nvshmem.my_pe() + if pe == 0: + # Root PE logic + pass + else: + # Non-root PE logic + pass + ``` + """ return core.extern_elementwise( "", "", @@ -283,7 +592,38 @@ def my_pe(_semantic=None): # type: ignore[no-untyped-def] @core.extern def n_pes(_semantic=None): # type: ignore[no-untyped-def] - """Get the total number of PEs.""" + """ + Get the total number of PEs in the NVSHMEM job. + + This function returns the total count of processing elements (PEs) + participating in the current NVSHMEM job. This value remains constant + throughout the execution of the program. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Total number of PEs in the job (always ≥ 1). + + Notes: + - This is a pure function that returns the same value throughout execution. + - The value is determined at NVSHMEM initialization and never changes. + - Valid PE numbers range from 0 to n_pes() - 1. + - Can be called from both host and device code. + - Essential for implementing collective operations and communication patterns. + + Example: + ``` + # Broadcast from root to all other PEs + total_pes = nvshmem.n_pes() + my_rank = nvshmem.my_pe() + + if my_rank == 0: + # Send to all other PEs + for peer in range(1, total_pes): + nvshmem.putmem_block(dst_ptr, src_ptr, size, peer) + ``` + """ return core.extern_elementwise( "", "", @@ -296,7 +636,41 @@ def n_pes(_semantic=None): # type: ignore[no-untyped-def] # Synchronization Operations @core.extern def barrier_all(_semantic=None): # type: ignore[no-untyped-def] - """Synchronize all PEs.""" + """ + Synchronize all PEs with completion guarantee. + + This function creates a barrier across all PEs in the NVSHMEM job. It ensures + that all local and remote memory updates issued before the barrier by any PE + are completed before any PE exits the barrier. This provides both + synchronization and memory consistency. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Stronger guarantee than sync_all() - ensures completion of remote operations. + - Blocks until all PEs reach the barrier AND all memory operations complete. + - Must be called from kernels launched with cooperative launch. + - Provides full memory consistency across all PEs. + - More expensive than sync_all() due to completion guarantees. + + Memory Consistency Guarantees: + - All memory updates before barrier_all() are visible to all PEs + - All remote memory operations are completed before any PE continues + - Provides a global synchronization point with memory ordering + + Example: + ``` + # Ensure all PEs complete their work before proceeding + # All PEs execute this - it's a collective operation + nvshmem.barrier_all() + # At this point, all previous operations are complete on all PEs + ``` + """ return core.extern_elementwise( "", "", @@ -308,7 +682,41 @@ def barrier_all(_semantic=None): # type: ignore[no-untyped-def] @core.extern def sync_all(_semantic=None): # type: ignore[no-untyped-def] - """Synchronize all PEs (lightweight version, does not ensure completion of remote memory updates).""" + """ + Synchronize all PEs with local completion guarantee. + + This function creates a lightweight synchronization barrier across all PEs. + It ensures that all local store operations issued before the sync are + visible to other PEs, but does not guarantee completion of remote memory + operations initiated by the calling PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Lighter weight than barrier_all() - only ensures local store visibility. + - Does not guarantee completion of remote memory updates initiated locally. + - Must be called from kernels launched with cooperative launch. + - Suitable when only synchronization (not completion) is needed. + - More efficient than barrier_all() for synchronization-only patterns. + + Memory Consistency Guarantees: + - Local store operations are visible to other PEs + - Does NOT ensure completion of outgoing remote operations + - Provides synchronization point without full completion overhead + + Example: + ``` + # Lightweight synchronization between PEs + # All PEs execute this - it's a collective operation + nvshmem.sync_all() + # Local stores are visible, but remote ops may still be in flight + ``` + """ return core.extern_elementwise( "", "", @@ -321,7 +729,45 @@ def sync_all(_semantic=None): # type: ignore[no-untyped-def] # Collective Operations (mem-based APIs - sizes in bytes) @core.extern def alltoallmem_block(team, dest, source, size_bytes, _semantic=None): # type: ignore[no-untyped-def] - """Perform alltoall operation on symmetric memory. size_bytes specifies the number of bytes to exchange per PE.""" + """ + Perform alltoall collective operation on symmetric memory. + + This function implements an all-to-all collective communication pattern where + each PE sends a portion of its data to every other PE, and receives data from + every other PE. The operation exchanges size_bytes of data between each pair of PEs. + + Args: + team (int64): Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD + (all PEs in the job). + dest (int64): Symmetric address of the destination buffer. Must be large enough + to hold size_bytes * n_pes total bytes. + source (int64): Symmetric address of the source buffer containing data to send. + Must contain size_bytes * n_pes total bytes. + size_bytes (int64): Number of bytes to exchange with each PE. Must be positive. + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Data Layout: + - Source buffer layout: [data_for_pe0, data_for_pe1, ..., data_for_pe(n-1)] + - Destination buffer layout: [data_from_pe0, data_from_pe1, ..., data_from_pe(n-1)] + - Each segment is size_bytes in length + + Notes: + - This is a collective operation - all PEs in the team must participate. + - Must be called from kernels launched with cooperative launch. + - The source and destination buffers must not overlap. + - All PEs must call with the same size_bytes value. + - Provides efficient many-to-many data exchange pattern. + + Example: + ``` + # Each PE sends 1024 bytes to every other PE + team_world = 0 + nvshmem.alltoallmem_block(team_world, dest_ptr, src_ptr, 1024) + ``` + """ return core.extern_elementwise( "", "", @@ -340,7 +786,44 @@ def alltoallmem_block(team, dest, source, size_bytes, _semantic=None): # type: @core.extern def broadcastmem_block(team, dest, source, size_bytes, pe_root, _semantic=None): # type: ignore[no-untyped-def] - """Broadcast data from a root PE to all other PEs in a team. size_bytes specifies the size in bytes.""" + """ + Broadcast data from a root PE to all other PEs in a team. + + This function implements a collective broadcast operation where the root PE + sends its data to all other PEs in the team. All PEs (including the root) + receive a copy of the data from the root PE in their destination buffer. + + Args: + team (int64): Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD + (all PEs in the job). + dest (int64): Symmetric address of the destination buffer on all PEs. + Must be large enough to hold size_bytes. + source (int64): Symmetric address of the source buffer on the root PE. + Only the root PE's source buffer is used. + size_bytes (int64): Number of bytes to broadcast. Must be positive. + pe_root (int64): PE number of the root PE that provides the source data. + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs in the team must participate. + - Must be called from kernels launched with cooperative launch. + - Only the root PE's source buffer is read; other PEs' source buffers are ignored. + - All PEs (including root) receive the data in their destination buffer. + - All PEs must call with the same team, size_bytes, and pe_root values. + - The source and destination buffers must not overlap on any PE. + - Efficient one-to-many communication pattern. + + Example: + ``` + # PE 0 broadcasts 1024 bytes to all PEs in the team + team_world = 0 + root_pe = 0 + nvshmem.broadcastmem_block(team_world, dest_ptr, src_ptr, 1024, root_pe) + ``` + """ return core.extern_elementwise( "", "", From 3a562374401113187ce2566b87e3f1d87d7c53aa Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Thu, 7 Aug 2025 18:40:16 -0700 Subject: [PATCH 0140/1424] [SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels (#159788) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers). The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which will also be useful if you want to do for instance some local math on the data. ----- **TODO:** This is almost complete. One pending item is tensor-aware implementation of `nvshmem.putmem_signal_block `and `nvshmem.signal_wait_until` From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer ``` Pointer-Based Version: Rank 0 → Rank 1: Local buffer: 0x430300a00 (src) Remote buffer: 0x2430300c00 (dst) ← Rank 1's memory Remote signal: 0x2430301600 (sig) ← Rank 1's signal Rank 1 (waiting): Local signal: 0x430301600 (waits here) Tensor-Based Version: Rank 0 → Rank 1: Local buffer: 0x430300a00 (src) Local buffer: 0x430300c00 (dst) ← this is wrong Local signal: 0x430300e00 (sig) ← this is wrong Rank 1 (waiting): Local signal: 0x430300e00 (waits here) ``` Next Steps: Need mechanism to resolve local tensor → remote PE address, equivalent to handle.buffer_ptrs[peer] lookup. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159788 Approved by: https://github.com/mandroid6, https://github.com/ngimel ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756 --- test/distributed/test_nvshmem_triton.py | 509 +++++++++--------- .../_symmetric_memory/_nvshmem_triton.py | 375 +++++++------ 2 files changed, 462 insertions(+), 422 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 5a722c0bba34d..15dca00d01219 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -42,23 +42,23 @@ def requires_h100(): # Shared Triton JIT kernels @triton.jit -def nvshmem_putmem_block_kernel( - dst_ptr, - src_ptr, - size_bytes, - peer, +def nvshmem_put_kernel( + dest, + src, + nelems, + pe, ): - nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, peer) + nvshmem.put(dest, src, nelems, pe) @triton.jit -def nvshmem_getmem_block_kernel( - dst_ptr, - src_ptr, - size_bytes, - peer, +def nvshmem_get_kernel( + dest, + src, + nelems, + pe, ): - nvshmem.getmem_block(dst_ptr, src_ptr, size_bytes, peer) + nvshmem.get(dest, src, nelems, pe) @triton.jit @@ -93,11 +93,11 @@ def nvshmem_signal_op_kernel( @triton.jit def nvshmem_wait_until_kernel( - ivar_ptr, + ivar, cmp_op, cmp_val, ): - nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val) + nvshmem.wait_until(ivar, cmp_op, cmp_val) @triton.jit @@ -107,50 +107,50 @@ def nvshmem_fence_kernel(): @triton.jit def nvshmem_put_with_fence_kernel( - dst_ptr1, - dst_ptr2, - src_ptr1, - src_ptr2, - flag_ptr, - flag_src_ptr, - size_bytes, + dst1, + src1, + dst2, + src2, + flag_dst, + flag_src, + nelems, peer, ): # First put - nvshmem.putmem_block(dst_ptr1, src_ptr1, size_bytes, peer) + nvshmem.put(dst1, src1, nelems, peer) # Ensure the first put is ordered before the next. nvshmem.fence() # Second put - nvshmem.putmem_block(dst_ptr2, src_ptr2, size_bytes, peer) + nvshmem.put(dst2, src2, nelems, peer) # Order the second put before flag update. nvshmem.fence() # Write the flag (single int64) to signal completion. - nvshmem.putmem_block(flag_ptr, flag_src_ptr, 8, peer) # 8 bytes for int64 + nvshmem.put(flag_dst, flag_src, 1, peer) @triton.jit def nvshmem_put_with_quiet_kernel( - dst_ptr, - src_ptr, - flag_dst_ptr, - flag_src_ptr, - size_bytes, + dst, + src, + flag_dst, + flag_src, + nelems, peer, ): # Put data - nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, peer) + nvshmem.put(dst, src, nelems, peer) # Call quiet to ensure put is complete nvshmem.quiet() # Only after quiet, set the completion flag # This ensures the data put is complete before flag is set - nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 8, peer) # 8 bytes for int64 + nvshmem.put(flag_dst, flag_src, 1, peer) @triton.jit def nvshmem_barrier_test_kernel( - dst_ptr, - src_ptr, - size_bytes, + dst, + src, + nelems, ): # Testing barrier_all() requires coordinated operations across PEs within # the same kernel execution. Unlike other kernels that just wrap NVSHMEM @@ -162,12 +162,12 @@ def nvshmem_barrier_test_kernel( # Rank 0 broadcasts its value to all other ranks if my_pe == 0: # Write initial value - p_src = src_ptr.to(tl.pointer_type(tl.int32)) + p_src = src.to(tl.pointer_type(tl.int32)) tl.store(p_src, 42) # Put to all other ranks i = 1 while i < n_pes: - nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, i) + nvshmem.put(dst, src, nelems, i) i += 1 # Synchronize all PEs @@ -175,7 +175,7 @@ def nvshmem_barrier_test_kernel( # Non-zero ranks increment the received value if my_pe != 0: - p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) + p_dst = dst.to(tl.pointer_type(tl.int32)) received = tl.load(p_dst) tl.store(p_dst, received + 1) @@ -187,66 +187,61 @@ def nvshmem_barrier_all_kernel(): @triton.jit def nvshmem_sync_test_kernel( - dst_ptr, - src_ptr, - size_bytes, + local_data, + remote_data, + nelems, ): my_pe = nvshmem.my_pe() n_pes = nvshmem.n_pes() - # Rank 0 broadcasts its value to all other ranks - if my_pe == 0: - # Write initial value - p_src = src_ptr.to(tl.pointer_type(tl.int32)) - tl.store(p_src, 42) - # Put to all other ranks - i = 1 - while i < n_pes: - nvshmem.putmem_block(dst_ptr, src_ptr, size_bytes, i) - i += 1 + # Each PE writes a unique value to its local memory + p_local = local_data.to(tl.pointer_type(tl.int32)) + unique_value = my_pe + 100 # PE 0 writes 100, PE 1 writes 101, etc. + tl.store(p_local, unique_value) - # Synchronize all PEs (this is more lightweight than barrier_all() b/c it only ensures local store visibility - # and doesn't wait for remote ops to complete) + # sync_all() ensures local stores are visible to other PEs + # but doesn't guarantee completion of any remote operations nvshmem.sync_all() - # Non-zero ranks increment the received value - if my_pe != 0: - p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) - received = tl.load(p_dst) - tl.store(p_dst, received + 1) + # Now each PE reads from the next PE's memory to verify visibility + # PE 0 reads from PE 1, PE 1 reads from PE 2, ..., PE n-1 reads from PE 0 + next_pe = (my_pe + 1) % n_pes + nvshmem.get(remote_data, local_data, nelems, next_pe) + + # The get should now see the value that the next PE wrote locally + # because sync_all() made those local stores visible @triton.jit -def nvshmem_alltoallmem_block_kernel( +def nvshmem_alltoall_kernel( team_handle, - dest_ptr, - src_ptr, - size_bytes_per_pe, + dst, + src, + nelems_per_pe, ): - nvshmem.alltoallmem_block(team_handle, dest_ptr, src_ptr, size_bytes_per_pe) + nvshmem.alltoall(team_handle, dst, src, nelems_per_pe) @triton.jit -def nvshmem_broadcastmem_block_kernel( +def nvshmem_broadcast_kernel( team_handle, - dest_ptr, - src_ptr, - size_bytes, + dst, + src, + nelems, pe_root, ): - nvshmem.broadcastmem_block(team_handle, dest_ptr, src_ptr, size_bytes, pe_root) + nvshmem.broadcast(team_handle, dst, src, nelems, pe_root) @triton.jit def nvshmem_reduce_kernel( team_handle, - dest_ptr, - src_ptr, + dest_tensor, + source_tensor, nreduce, operation: tl.constexpr, - dtype_id: tl.constexpr, ): - nvshmem.reduce(team_handle, dest_ptr, src_ptr, nreduce, operation, dtype_id) + nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation) @instantiate_parametrized_tests @@ -278,32 +273,47 @@ def test_triton_put(self) -> None: symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - msg_size_bytes = 8 - dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize + # Configuration + nelems = 5 # number of elements to transfer + dtype = torch.int64 + val = 42 + rank # Each rank has different data - val = 5 - inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) - out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + # Create symmetric tensors + src = symm_mem.empty(nelems, dtype=dtype, device=self.device) + dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) + + # Fill source tensor with rank-specific pattern + for i in range(nelems): + src[i] = ( + val * 10 + i + ) # Rank 0: [420, 421, 422, 423, 424], Rank 1: [430, 431, ...] + + # Rendezvous + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + + # Synchronize before operation + dist.barrier() peer = 1 - rank if rank == 0: - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - nvshmem_putmem_block_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - size_bytes=msg_size_bytes, - peer=peer, + # Rank 0 puts its data to Rank 1 + nvshmem_put_kernel[(1,)]( + dst, + src, + nelems, + peer, extern_libs=nvshmem_lib, ) + # Synchronize after operation dist.barrier() + if rank == 1: + # Verify that rank 1 received rank 0's data + expected = [420 + i for i in range(nelems)] torch.testing.assert_close( - out, val * torch.ones(numel, dtype=dtype, device=self.device) + dst, torch.tensor(expected, device=self.device, dtype=dtype) ) @skipIfRocm @@ -317,27 +327,29 @@ def test_triton_get(self) -> None: group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - msg_size_bytes = 8 + + # Configuration + numel = 8 dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize val = 7 + + # Create symmetric tensors inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_( val if rank == 0 else -1 ) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + symm_mem.rendezvous(inp, group=group_name) + symm_mem.rendezvous(out, group=group_name) + dist.barrier() peer = 1 - rank if rank == 1: - # Rank 1 gets data from rank 0 - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - nvshmem_getmem_block_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - size_bytes=msg_size_bytes, - peer=peer, + # Rank 1 gets data from rank 0 using tensor-aware API + nvshmem_get_kernel[(1,)]( + out, + inp, + numel, + peer, extern_libs=nvshmem_lib, ) if rank == 1: @@ -357,29 +369,29 @@ def test_triton_get_ring(self) -> None: symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank world_size = dist.get_world_size() - msg_size_bytes = 8 + + # Configuration + numel = 8 dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize # Each rank fills its input buffer with its own rank value inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) + symm_mem.rendezvous(inp, group=group_name) + symm_mem.rendezvous(out, group=group_name) + dist.barrier() # Ring topology: each rank gets data from the rank to its left # rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc. peer = (rank - 1) % world_size - # All ranks execute the get operation - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - nvshmem_getmem_block_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - size_bytes=msg_size_bytes, - peer=peer, + # All ranks execute the get operation using tensor-aware API + nvshmem_get_kernel[(1,)]( + out, + inp, + numel, + peer, extern_libs=nvshmem_lib, ) @@ -539,15 +551,14 @@ def test_triton_wait_until(self) -> None: flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_( FLAG_INITIAL_VALUE ) - flag_hdl = symm_mem.rendezvous(flag, group=group_name) + symm_mem.rendezvous(flag, group=group_name) nvshmem_barrier_all_kernel[(1,)](extern_libs=nvshmem_lib) if rank == 0: # Rank 0 (the waiter) - ivar_ptr = flag_hdl.buffer_ptrs[rank] nvshmem_wait_until_kernel[(1,)]( - ivar_ptr, + flag, cmp_op=NVSHMEM_CMP_EQ, cmp_val=FLAG_FINAL_VALUE, extern_libs=nvshmem_lib, @@ -565,15 +576,12 @@ def test_triton_wait_until(self) -> None: [FLAG_FINAL_VALUE], dtype=torch.int64, device=self.device ) - # The destination is Rank 0's flag buffer. - dst_ptr = flag_hdl.buffer_ptrs[rank] - - # Launch a kernel to put the value to Rank 0. - nvshmem_putmem_block_kernel[(1,)]( - dst_ptr, # Destination pointer on the remote PE - val_to_put.data_ptr(), # Source data pointer (local) - size_bytes=8, # Size of one int64 - peer=peer, # The target PE (Rank 0) + # Launch a kernel to put the value to Rank 0's flag tensor. + nvshmem_put_kernel[(1,)]( + flag, # Destination symmetric tensor on the remote PE + val_to_put, # Source data tensor (local) + 1, # Number of elements + peer, # The target PE (Rank 0) extern_libs=nvshmem_lib, ) @@ -658,7 +666,6 @@ def test_triton_fence(self) -> None: its arrival implies that both preceding puts have been delivered in order. """ - torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() @@ -667,9 +674,8 @@ def test_triton_fence(self) -> None: rank = self.rank peer = 1 - rank # Message configuration - msg_size_bytes = 8 dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize + numel = 8 val1 = 10 val2 = 20 @@ -679,42 +685,35 @@ def test_triton_fence(self) -> None: inp2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val2) out1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) out2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp1_hdl = symm_mem.rendezvous(inp1, group=group_name) - inp2_hdl = symm_mem.rendezvous(inp2, group=group_name) - out1_hdl = symm_mem.rendezvous(out1, group=group_name) - out2_hdl = symm_mem.rendezvous(out2, group=group_name) - - # Flag buffer resides in the signal pad of out2. - flag = out2_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + symm_mem.rendezvous(inp1, group=group_name) + symm_mem.rendezvous(inp2, group=group_name) + symm_mem.rendezvous(out1, group=group_name) + symm_mem.rendezvous(out2, group=group_name) + + # Use regular symmetric memory tensor for flag + flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0) + symm_mem.rendezvous(flag, group=group_name) flag_update_val = torch.tensor( [flag_val], dtype=torch.int64, device=self.device ) NVSHMEM_CMP_EQ = 0 # compare equal if rank == 0: - dst_ptr1 = out1_hdl.buffer_ptrs[rank] - dst_ptr2 = out2_hdl.buffer_ptrs[rank] - src_ptr1 = inp1_hdl.buffer_ptrs[rank] - src_ptr2 = inp2_hdl.buffer_ptrs[rank] - flag_ptr = out2_hdl.signal_pad_ptrs[rank] - flag_src_ptr = flag_update_val.data_ptr() - - nvshmem_put_with_fence_kernel[(1, 1, 1)]( - dst_ptr1, - dst_ptr2, - src_ptr1, - src_ptr2, - flag_ptr, - flag_src_ptr, - size_bytes=msg_size_bytes, + nvshmem_put_with_fence_kernel[(1,)]( + out1, + inp1, + out2, + inp2, + flag, + flag_update_val, + nelems=numel, peer=peer, extern_libs=nvshmem_lib, ) elif rank == 1: - # Wait until flag is set by Rank 0. - ivar_ptr = out2_hdl.signal_pad_ptrs[rank] - nvshmem_wait_until_kernel[(1, 1, 1)]( - ivar_ptr, + # Wait until flag is set by Rank 0 + nvshmem_wait_until_kernel[(1,)]( + flag, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, extern_libs=nvshmem_lib, @@ -737,58 +736,52 @@ def test_triton_fence(self) -> None: def test_triton_quiet(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() - # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - msg_size_bytes = 8 - dtype = torch.int8 - numel = msg_size_bytes // dtype.itemsize + peer = 1 - rank - # Data buffers + dtype = torch.int8 + numel = 8 val = 15 + flag_val = 42 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) - inp_hdl = symm_mem.rendezvous(inp, group=group_name) - out_hdl = symm_mem.rendezvous(out, group=group_name) - # Use signal pad as completion flag - flag_val = 42 - peer = 1 - rank + flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0) + flag_update_val = torch.tensor( + [flag_val], dtype=torch.int64, device=self.device + ) + + symm_mem.rendezvous(inp, group=group_name) + symm_mem.rendezvous(out, group=group_name) + symm_mem.rendezvous(flag, group=group_name) + NVSHMEM_CMP_EQ = 0 - if rank == 0: - # Rank 0 waits for flag from Rank 1 - ivar_ptr = out_hdl.signal_pad_ptrs[rank] - nvshmem_wait_until_kernel[(1, 1, 1)]( - ivar_ptr, + dist.barrier() + if rank == 1: + nvshmem_put_with_quiet_kernel[(1,)]( + out, + inp, + flag, + flag_update_val, + nelems=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + elif rank == 0: + nvshmem_wait_until_kernel[(1,)]( + flag, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, extern_libs=nvshmem_lib, ) - # After flag is set, data should be complete due to quiet torch.testing.assert_close( out, val * torch.ones(numel, dtype=dtype, device=self.device) ) - if rank == 1: - # Rank 1 puts data and flag with quiet in between - dst_ptr = out_hdl.buffer_ptrs[rank] - src_ptr = inp_hdl.buffer_ptrs[rank] - flag_dst_ptr = out_hdl.signal_pad_ptrs[rank] - # Create a tensor for the flag value - flag_update_val = torch.tensor( - [flag_val], dtype=torch.int64, device=self.device - ) - flag_src_ptr = flag_update_val.data_ptr() - nvshmem_put_with_quiet_kernel[(1, 1, 1)]( - dst_ptr, - src_ptr, - flag_dst_ptr, - flag_src_ptr, - size_bytes=msg_size_bytes, - peer=peer, - extern_libs=nvshmem_lib, - ) + dist.barrier() @skipIfRocm @requires_triton() @@ -802,30 +795,27 @@ def test_triton_barrier(self) -> None: rank = self.rank numel = 1 dtype = torch.int32 - size_bytes = numel * dtype.itemsize - # Create symmetric buffers + src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) - # Launch kernel with cooperative grid + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + nvshmem_barrier_test_kernel[(1,)]( - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], - size_bytes=size_bytes, + dst, + src, + nelems=numel, extern_libs=nvshmem_lib, launch_cooperative_grid=True, num_ctas=1, ) - # Verify results - # Rank 0 should have 42, and then the rest should have incremented + 1 to 43 + dist.barrier() + if rank == 0: - # Rank 0 should have its original value (42) in src torch.testing.assert_close( src, torch.tensor([42], device=self.device, dtype=dtype) ) else: - # Other ranks should have received 42 and incremented to 43 torch.testing.assert_close( dst, torch.tensor([43], device=self.device, dtype=dtype) ) @@ -836,38 +826,45 @@ def test_triton_barrier(self) -> None: def test_triton_sync(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() + nvshmem_lib = nvshmem.enable_triton() group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank numel = 1 dtype = torch.int32 - size_bytes = numel * dtype.itemsize + # Create symmetric buffers - src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) - dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + local_data = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + remote_data = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(local_data, group=group_name) + symm_mem.rendezvous(remote_data, group=group_name) + # Launch kernel with cooperative grid nvshmem_sync_test_kernel[(1,)]( - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], - size_bytes=size_bytes, + local_data, + remote_data, + nelems=numel, extern_libs=nvshmem_lib, launch_cooperative_grid=True, num_ctas=1, ) + # Verify results - if rank == 0: - # Rank 0 should have its original value (42) in src - torch.testing.assert_close( - src, torch.tensor([42], device=self.device, dtype=dtype) - ) - else: - # Other ranks should have received 42 and incremented to 43 - torch.testing.assert_close( - dst, torch.tensor([43], device=self.device, dtype=dtype) - ) + # Each PE should have written rank + 100 to its local_data + expected_local = rank + 100 + torch.testing.assert_close( + local_data, torch.tensor([expected_local], device=self.device, dtype=dtype) + ) + + # Each PE should have read (next_rank + 100) into its remote_data + # PE 0 reads from PE 1, PE 1 reads from PE 2, ..., PE n-1 reads from PE 0 + next_rank = (rank + 1) % self.world_size + expected_remote = next_rank + 100 + torch.testing.assert_close( + remote_data, + torch.tensor([expected_remote], device=self.device, dtype=dtype), + ) @skipIfRocm @requires_triton() @@ -883,7 +880,6 @@ def test_triton_alltoall(self) -> None: # Each PE will send 2 int64 elements to every other PE nelems_per_pe = 2 dtype = torch.int64 - size_bytes_per_pe = nelems_per_pe * dtype.itemsize # Source buffer: contains data for all PEs # Layout: [data_for_pe0, data_for_pe1, ...] src_size = nelems_per_pe * world_size @@ -895,17 +891,17 @@ def test_triton_alltoall(self) -> None: src[i * nelems_per_pe : (i + 1) * nelems_per_pe] = value # Destination buffer dst = symm_mem.empty(src_size, dtype=dtype, device=self.device).fill_(-1) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) # Synchronize before alltoall dist.barrier() team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0 - # Launch the kernel - nvshmem_alltoallmem_block_kernel[(1,)]( + # Launch the kernel using new tensor-aware API + nvshmem_alltoall_kernel[(1,)]( team_handle, - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], - size_bytes_per_pe=size_bytes_per_pe, + dst, + src, + nelems_per_pe, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) @@ -929,13 +925,17 @@ def test_triton_broadcast(self) -> None: group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank + # Configuration nelems = 4 # number of elements dtype = torch.int64 - size_bytes = nelems * dtype.itemsize + # Source buffer - only root will have meaningful data pe_root = 0 # PE 0 will be the root src = symm_mem.empty(nelems, dtype=dtype, device=self.device) + # Destination buffer + dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) + if rank == pe_root: # Root fills with specific pattern for i in range(nelems): @@ -943,25 +943,28 @@ def test_triton_broadcast(self) -> None: else: # Non-root PEs have dummy data src.fill_(-1) - # Destination buffer - dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) + # Synchronize before broadcast dist.barrier() + # Execute broadcast team_handle = 0 # NVSHMEM_TEAM_WORLD - nvshmem_broadcastmem_block_kernel[(1,)]( + nvshmem_broadcast_kernel[(1,)]( team_handle, - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], - size_bytes=size_bytes, - pe_root=pe_root, + dst, + src, + nelems, + pe_root, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) + # Synchronize after broadcast dist.barrier() + # Verify results - all ranks should have the root's data expected = [100 + i for i in range(nelems)] torch.testing.assert_close( @@ -1001,8 +1004,8 @@ def test_triton_sum_reduce(self, dtype) -> None: src[i] = (rank + 1) * (i + 1) # Rank 0: [1,2,3], Rank 1: [2,4,6], etc. # Destination buffer dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) # Calculate expected results expected = [] for i in range(nreduce): @@ -1017,11 +1020,10 @@ def test_triton_sum_reduce(self, dtype) -> None: team_handle = 0 # NVSHMEM_TEAM_WORLD nvshmem_reduce_kernel[(1,)]( team_handle, - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], + dst, + src, nreduce, operation="sum", - dtype_id=src.dtype, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) @@ -1076,10 +1078,10 @@ def test_triton_minmax_reduce(self, dtype) -> None: # Destination buffers dst_min = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) dst_max = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) - src_min_hdl = symm_mem.rendezvous(src_min, group=group_name) - src_max_hdl = symm_mem.rendezvous(src_max, group=group_name) - dst_min_hdl = symm_mem.rendezvous(dst_min, group=group_name) - dst_max_hdl = symm_mem.rendezvous(dst_max, group=group_name) + symm_mem.rendezvous(src_min, group=group_name) + symm_mem.rendezvous(src_max, group=group_name) + symm_mem.rendezvous(dst_min, group=group_name) + symm_mem.rendezvous(dst_max, group=group_name) # Calculate expected results all_values = [] for i in range(nreduce): @@ -1097,22 +1099,20 @@ def test_triton_minmax_reduce(self, dtype) -> None: team_handle = 0 nvshmem_reduce_kernel[(1,)]( team_handle, - dst_min_hdl.buffer_ptrs[rank], - src_min_hdl.buffer_ptrs[rank], + dst_min, + src_min, nreduce, operation="min", - dtype_id=src_min.dtype, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) # Execute MAX reduction nvshmem_reduce_kernel[(1,)]( team_handle, - dst_max_hdl.buffer_ptrs[rank], - src_max_hdl.buffer_ptrs[rank], + dst_max, + src_max, nreduce, operation="max", - dtype_id=src_max.dtype, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) @@ -1167,8 +1167,8 @@ def test_triton_prod_reduce(self, dtype) -> None: src[i] = 1 if (rank // 2) % 2 == 0 else 2 # Destination buffer dst = symm_mem.empty(nreduce, dtype=dtype, device=self.device).fill_(-1) - src_hdl = symm_mem.rendezvous(src, group=group_name) - dst_hdl = symm_mem.rendezvous(dst, group=group_name) + symm_mem.rendezvous(src, group=group_name) + symm_mem.rendezvous(dst, group=group_name) # Calculate expected results vals = torch.empty(nreduce, world_size, dtype=dtype) vals[0, ::2] = 1 @@ -1186,11 +1186,10 @@ def test_triton_prod_reduce(self, dtype) -> None: team_handle = 0 # NVSHMEM_TEAM_WORLD nvshmem_reduce_kernel[(1,)]( team_handle, - dst_hdl.buffer_ptrs[rank], - src_hdl.buffer_ptrs[rank], + dst, + src, nreduce, operation="prod", - dtype_id=src.dtype, extern_libs=nvshmem_lib, launch_cooperative_grid=True, ) diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 0b6eed12b2963..c543fdffc1c76 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -1,7 +1,7 @@ import os import subprocess import sysconfig -from typing import Optional +from typing import Any, Optional from torch.utils._triton import has_triton @@ -111,106 +111,111 @@ def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] if has_triton(): + import triton + import triton.language as tl from triton.language import core - # RMA Operations (mem-based APIs - sizes in bytes) - @core.extern - def putmem_block(dst, src, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + @triton.jit # type: ignore[misc] + def put(dest, source, nelems, pe): # type: ignore[no-untyped-def] """ - Put data to remote PE using block-scoped operation. + Put tensor data from local PE to a remote PE. - This function copies a contiguous block of data from the local PE's memory - to a symmetric data object on the remote PE. The operation is performed at - thread block scope, meaning all threads in the block cooperate to perform - the transfer efficiently. + This high-level function provides a tensor-aware interface for NVSHMEM put + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. Args: - dst (int64): Symmetric address of the destination data object on the remote PE. - Must be a pointer to symmetric memory allocated via NVSHMEM. - src (int64): Local address of the source data object containing data to be copied. - Can be any valid local memory address. - size_bytes (int64): Number of bytes to transfer. Must be positive. - pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). - _semantic: Optional semantic information for Triton compilation. - - Returns: - int32: Status code (0 for success). + dest: Destination tensor on the remote PE. Type must match source. + source: Source tensor on the local PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. - This is a blocking operation that returns after data has been copied out of the source array on the local PE. - The operation does not guarantee delivery to the destination PE. Use nvshmem_fence() for ordering or nvshmem_quiet() for completion. - - All threads in the block should call this function with the same parameters. - - The source memory remains valid for use immediately after the call returns. Example: - ```python - # Transfer 1024 bytes from local buffer to PE 1 - nvshmem.putmem_block(remote_ptr, local_ptr, 1024, 1) + ``` + # Transfer 100 elements to PE 1 + nvshmem.put(dest_tensor, src_tensor, 100, 1) ``` """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return putmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes, pe + ) + + @core.extern + def putmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM put""" return core.extern_elementwise( "", "", - [dst, src, size_bytes, pe], + [dest, source, size_bytes, pe], { ( - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int64"), # pe number ): ("nvshmemx_putmem_block", core.dtype("int32")) }, is_pure=False, _semantic=_semantic, ) - @core.extern - def getmem_block(dst, src, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + @triton.jit # type: ignore[misc] + def get(dest, source, nelems, pe): # type: ignore[no-untyped-def] """ - Get data from remote PE using block-scoped operation. + Get tensor data from a remote PE to local PE. - This function copies a contiguous block of data from a symmetric data object - on the remote PE to the local PE's memory. The operation is performed at - thread block scope, meaning all threads in the block cooperate to perform - the transfer efficiently. + This high-level function provides a tensor-aware interface for NVSHMEM get + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. Args: - dst (int64): Local address of the destination data object to be updated. - Can be any valid local memory address. - src (int64): Symmetric address of the source data object on the remote PE. - Must be a pointer to symmetric memory allocated via NVSHMEM. - size_bytes (int64): Number of bytes to transfer. Must be positive. - pe (int64): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). - _semantic: Optional semantic information for Triton compilation. - - Returns: - int32: Status code (0 for success). + dest: Destination tensor on the local PE. Type must match source. + source: Source tensor on the remote PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. - This is a blocking operation that returns after data has been delivered to the destination array on the local PE. - - All threads in the block should call this function with the same parameters. - The destination data is guaranteed to be available for use after the call returns. - - Provides method for copying contiguous symmetric data from different PE. Example: ``` - # Get 1024 bytes from PE 0 into local buffer - nvshmem.getmem_block(local_ptr, remote_ptr, 1024, 0) + # Get 100 elements from PE 0 + nvshmem.get(dest_tensor, src_tensor, 100, 0) ``` """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return getmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes, pe + ) + + @core.extern + def getmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM get""" return core.extern_elementwise( "", "", - [dst, src, size_bytes, pe], + [dest, source, size_bytes, pe], { ( - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), - core.dtype("int64"), + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int64"), # pe number ): ("nvshmemx_getmem_block", core.dtype("int32")) }, is_pure=False, @@ -288,45 +293,47 @@ def putmem_signal_block( # type: ignore[no-untyped-def] ) # Wait and Signal Operations - @core.extern - def wait_until(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] + + @triton.jit # type: ignore[misc] + def wait_until(ivar, cmp_op, cmp_val): # type: ignore[no-untyped-def] """ - Wait until a condition is met on a symmetric variable. + Wait until a tensor variable meets a specified condition. - This function blocks the calling thread until the value at the specified - symmetric memory location satisfies the given comparison condition. This - provides a mechanism for point-to-point synchronization between PEs. + This high-level function provides a tensor-aware interface for NVSHMEM wait_until + operations. It automatically handles tensor address extraction, making + the API more ergonomic and type-safe. Args: - ivar (int64): Symmetric address of the variable to monitor. Must be a - pointer to symmetric memory (typically int64/uint64). - cmp (int64): Comparison operator. Common values: - - NVSHMEM_CMP_EQ (0): Wait until ivar == cmp_val - - NVSHMEM_CMP_NE (1): Wait until ivar != cmp_val - - NVSHMEM_CMP_GT (2): Wait until ivar > cmp_val - - NVSHMEM_CMP_GE (3): Wait until ivar >= cmp_val - - NVSHMEM_CMP_LT (4): Wait until ivar < cmp_val - - NVSHMEM_CMP_LE (5): Wait until ivar <= cmp_val - cmp_val (int64): Value to compare against. - _semantic: Optional semantic information for Triton compilation. - - Returns: - int32: Status code (0 for success). + ivar_tensor: Tensor to monitor (typically int64/uint64) in symmetric memory. + cmp: Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until ivar == cmp_val + - NVSHMEM_CMP_NE (1): Wait until ivar != cmp_val + - NVSHMEM_CMP_GT (2): Wait until ivar > cmp_val + - NVSHMEM_CMP_GE (3): Wait until ivar >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until ivar < cmp_val + - NVSHMEM_CMP_LE (5): Wait until ivar <= cmp_val + cmp_val: Value to compare against. Notes: - This is a blocking operation that will wait indefinitely until the condition is satisfied. - - The variable must be in symmetric memory and accessible from other PEs. - - Updates to the variable from remote PEs will eventually become visible. - - Can be used with put operations from other PEs for synchronization. + - The tensor must be in symmetric memory and accessible from other PEs. Example: ``` - # Wait until flag becomes 1 (set by another PE) + # Wait until flag tensor becomes 1 (set by another PE) NVSHMEM_CMP_EQ = 0 - nvshmem.wait_until(flag_ptr, NVSHMEM_CMP_EQ, 1) + nvshmem.wait_until_tensor(flag_tensor, NVSHMEM_CMP_EQ, 1) ``` """ + tl.static_assert( + ivar.type.element_ty.itemsize == 8, + "wait_until expects a 64-bit type for the synchronization variable", + ) + return wait_until_extern_wrapper(ivar.to(tl.int64), cmp_op, cmp_val) + + @core.extern + def wait_until_extern_wrapper(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] return core.extern_elementwise( "", "", @@ -482,9 +489,9 @@ def fence(_semantic=None): # type: ignore[no-untyped-def] Example: ``` # Ensure first put completes before second put to same PE - nvshmem.putmem_block(dst1, src1, size, target_pe) + nvshmem.put(dst, src, nelems, target_pe) nvshmem.fence() # Enforce ordering - nvshmem.putmem_block(dst2, src2, size, target_pe) + nvshmem.put(dst2, src2, nelems, target_pe) ``` """ return core.extern_elementwise( @@ -727,47 +734,44 @@ def sync_all(_semantic=None): # type: ignore[no-untyped-def] ) # Collective Operations (mem-based APIs - sizes in bytes) - @core.extern - def alltoallmem_block(team, dest, source, size_bytes, _semantic=None): # type: ignore[no-untyped-def] + @triton.jit # type: ignore[misc] + def alltoall(team, dest, source, nelems_per_pe): # type: ignore[no-untyped-def] """ - Perform alltoall collective operation on symmetric memory. + All-to-all tensor exchange between PEs in a team. - This function implements an all-to-all collective communication pattern where - each PE sends a portion of its data to every other PE, and receives data from - every other PE. The operation exchanges size_bytes of data between each pair of PEs. + This high-level function provides a tensor-aware interface for NVSHMEM alltoall + operations. Each PE sends nelems_per_pe elements to every other PE and receives + the same amount from every other PE. Args: - team (int64): Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD - (all PEs in the job). - dest (int64): Symmetric address of the destination buffer. Must be large enough - to hold size_bytes * n_pes total bytes. - source (int64): Symmetric address of the source buffer containing data to send. - Must contain size_bytes * n_pes total bytes. - size_bytes (int64): Number of bytes to exchange with each PE. Must be positive. - _semantic: Optional semantic information for Triton compilation. - - Returns: - int32: Status code (0 for success). - - Data Layout: - - Source buffer layout: [data_for_pe0, data_for_pe1, ..., data_for_pe(n-1)] - - Destination buffer layout: [data_from_pe0, data_from_pe1, ..., data_from_pe(n-1)] - - Each segment is size_bytes in length + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor. Must be large enough for nelems_per_pe * n_pes elements. + source: Source tensor containing data for all PEs. Must contain nelems_per_pe * n_pes elements. + nelems_per_pe: Number of elements to exchange with each PE. Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. - This is a collective operation - all PEs in the team must participate. - - Must be called from kernels launched with cooperative launch. - - The source and destination buffers must not overlap. - - All PEs must call with the same size_bytes value. - - Provides efficient many-to-many data exchange pattern. + - Data layout: source=[data_for_pe0, data_for_pe1, ...], dest=[data_from_pe0, data_from_pe1, ...] Example: ``` - # Each PE sends 1024 bytes to every other PE - team_world = 0 - nvshmem.alltoallmem_block(team_world, dest_ptr, src_ptr, 1024) + # Each PE exchanges 10 elements with every other PE + nvshmem.alltoall(0, dest_tensor, src_tensor, 10) ``` """ + tl.static_assert(dest.type == source.type) + size_bytes_per_pe = nelems_per_pe * dest.type.element_ty.itemsize + return alltoallmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), size_bytes_per_pe + ) + + @core.extern # type: ignore[misc] + def alltoallmem_block_extern_wrapper( + team: Any, dest: Any, source: Any, size_bytes: Any, _semantic: Any = None + ) -> None: + """Low-level extern wrapper for NVSHMEM alltoall""" return core.extern_elementwise( "", "", @@ -784,46 +788,50 @@ def alltoallmem_block(team, dest, source, size_bytes, _semantic=None): # type: _semantic=_semantic, ) - @core.extern - def broadcastmem_block(team, dest, source, size_bytes, pe_root, _semantic=None): # type: ignore[no-untyped-def] + @triton.jit # type: ignore[misc] + def broadcast(team, dest, source, nelems, pe_root): # type: ignore[no-untyped-def] """ - Broadcast data from a root PE to all other PEs in a team. + Broadcast tensor data from a root PE to all other PEs in a team. - This function implements a collective broadcast operation where the root PE - sends its data to all other PEs in the team. All PEs (including the root) - receive a copy of the data from the root PE in their destination buffer. + This high-level function provides a tensor-aware interface for NVSHMEM broadcast + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. Args: - team (int64): Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD - (all PEs in the job). - dest (int64): Symmetric address of the destination buffer on all PEs. - Must be large enough to hold size_bytes. - source (int64): Symmetric address of the source buffer on the root PE. - Only the root PE's source buffer is used. - size_bytes (int64): Number of bytes to broadcast. Must be positive. - pe_root (int64): PE number of the root PE that provides the source data. - _semantic: Optional semantic information for Triton compilation. - - Returns: - int32: Status code (0 for success). + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor with type information. All PEs receive data here. + source: Source tensor on the root PE. Type must match dest. + nelems: Number of elements to broadcast. + pe_root: PE number of the root PE that provides the source data. Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. - This is a collective operation - all PEs in the team must participate. - Must be called from kernels launched with cooperative launch. - - Only the root PE's source buffer is read; other PEs' source buffers are ignored. - - All PEs (including root) receive the data in their destination buffer. - - All PEs must call with the same team, size_bytes, and pe_root values. - - The source and destination buffers must not overlap on any PE. - - Efficient one-to-many communication pattern. Example: ``` - # PE 0 broadcasts 1024 bytes to all PEs in the team - team_world = 0 - root_pe = 0 - nvshmem.broadcastmem_block(team_world, dest_ptr, src_ptr, 1024, root_pe) + # Broadcast 100 elements from PE 0 to all PEs + nvshmem.broadcast(0, dest_tensor, src_tensor, 100, 0) ``` """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return broadcastmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), nbytes, pe_root + ) + + @core.extern # type: ignore[misc] + def broadcastmem_block_extern_wrapper( + team: Any, + dest: Any, + source: Any, + size_bytes: Any, + pe_root: Any, + _semantic: Any = None, + ) -> None: + """Low-level extern wrapper for NVSHMEM broadcast""" return core.extern_elementwise( "", "", @@ -842,10 +850,56 @@ def broadcastmem_block(team, dest, source, size_bytes, pe_root, _semantic=None): ) # Reduction Operation + @triton.jit # type: ignore[misc] + def reduce(team, dest, source, nreduce, operation: tl.constexpr): # type: ignore[no-untyped-def] + """ + Performs a collective reduction on tensors across a team of PEs. + + This high-level function provides a tensor-aware interface for NVSHMEM + reduction operations. It automatically infers the data type from the + input tensors and calls the appropriate underlying NVSHMEM function. + + Args: + team: The team handle for the collective (0 for NVSHMEM_TEAM_WORLD). + dest: Destination tensor for the reduction results. + source: Source tensor containing data to be reduced. Must be the same type as dest. + nreduce: The number of elements in the source tensor to reduce. + operation: The reduction operation to perform ("sum", "max", "min", "prod"). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - This is a collective operation that must be called by all PEs in the team. + - Requires a cooperative grid launch. + + Example: + ``` + # Perform a sum reduction on two tensors + nvshmem.reduce(0, dest_tensor, src_tensor, 100, "sum") + ``` + """ + tl.static_assert(dest.type == source.type) + dtype = dest.type.element_ty + return reduce_extern_wrapper( + team, + dest.to(tl.int64), + source.to(tl.int64), + nreduce, + operation, + dtype, + ) + @core.extern # type: ignore[misc] - def reduce(team, dest, source, nreduce, operation: str, dtype_id, _semantic=None): # type: ignore[no-untyped-def] + def reduce_extern_wrapper( + team: Any, + dest: Any, + source: Any, + nreduce: Any, + operation: str, + dtype: Any, + _semantic: Any = None, + ) -> None: """ - Performs a collective reduction operation on symmetric data across a team of PEs. + Low-level extern wrapper for NVSHMEM reduction operations. This function provides a generic interface to NVSHMEM reduction operations, automatically selecting the appropriate NVSHMEM function based on the data type @@ -856,7 +910,7 @@ def reduce(team, dest, source, nreduce, operation: str, dtype_id, _semantic=None source (pointer): Source pointer containing data to be reduced. nreduce (int64): Number of elements to reduce. operation (str): Reduction operation ("sum", "max", "min", "prod"). - dtype_id: Data type specification - accepts torch.dtype, tl.dtype, str, or constexpr. + dtype: Data type specification - accepts torch.dtype, tl.dtype, str, or constexpr. _semantic: Optional semantic information for Triton compilation. Raises: @@ -866,7 +920,7 @@ def reduce(team, dest, source, nreduce, operation: str, dtype_id, _semantic=None Example: nvshmem.reduce(0, dest_ptr, src_ptr, 100, "sum", torch.float32) """ - # Mapping from PyTorch/Triton dtype names to NVSHMEM typenames + # Mapping from Triton dtype names to NVSHMEM typenames DTYPE_TO_NVSHMEM_MAP = { "int8": "int8", "int16": "int16", @@ -876,36 +930,23 @@ def reduce(team, dest, source, nreduce, operation: str, dtype_id, _semantic=None "uint16": "uint16", "uint32": "uint32", "uint64": "uint64", - "float16": "half", - "bfloat16": "bfloat16", - "float32": "float", - "float64": "double", + "fp16": "half", + "bf16": "bfloat16", + "fp32": "float", + "fp64": "double", } + # Triton dtype names are standardized as fp16, bf16, fp32, etc. + dtype_name = str(dtype).replace("tl.", "") + + if dtype_name not in DTYPE_TO_NVSHMEM_MAP: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes: {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + # Extract operation name from constexpr if needed op_name = operation.value if hasattr(operation, "value") else operation - # Normalize dtype_id to a canonical string name - # Handle different input formats: tl.dtype, torch.dtype, str, constexpr[dtype] - if hasattr(dtype_id, "name"): - # Triton language dtype (e.g., tl.float32) - dtype_name = dtype_id.name - elif isinstance(dtype_id, str): - # Already a plain string name - dtype_name = dtype_id - elif hasattr(dtype_id, "value"): - # Constexpr wrapper around a dtype - inner_value = dtype_id.value - if hasattr(inner_value, "name"): - # Triton dtype inside constexpr - dtype_name = inner_value.name - else: - # PyTorch dtype inside constexpr - dtype_name = str(inner_value).replace("torch.", "") - else: - # PyTorch dtype (e.g., torch.float32) - dtype_name = str(dtype_id).replace("torch.", "") - # Validate operation is supported supported_ops = {"sum", "max", "min", "prod"} if op_name not in supported_ops: From 178515d0ff6833c8e9221482b2a650ab31e00019 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 8 Aug 2025 01:14:36 +0800 Subject: [PATCH 0141/1424] [BE][PYFMT] remove `black`: finish `black -> ruff format` migration (#144557) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144557 Approved by: https://github.com/ezyang --- .lintrunner.toml | 2 - pyproject.toml | 3 - tools/linter/adapters/black_linter.py | 225 -------------------------- tools/linter/adapters/pip_init.py | 7 - tools/linter/adapters/pyfmt_linter.py | 61 +------ 5 files changed, 1 insertion(+), 297 deletions(-) delete mode 100644 tools/linter/adapters/black_linter.py diff --git a/.lintrunner.toml b/.lintrunner.toml index 9c46c91b5e353..3e28de5d16b94 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1452,8 +1452,6 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - '--no-black-binary', - 'black==23.12.1', 'usort==1.0.8.post1', 'isort==6.0.1', 'ruff==0.12.2', # sync with RUFF diff --git a/pyproject.toml b/pyproject.toml index c42aa782407fa..a911a2a723b14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,9 +69,6 @@ pyyaml = ["pyyaml"] # Linter tools ################################################################# -[tool.black] -line-length = 88 - [tool.isort] src_paths = ["caffe2", "torch", "torchgen", "functorch", "test"] extra_standard_library = ["typing_extensions"] diff --git a/tools/linter/adapters/black_linter.py b/tools/linter/adapters/black_linter.py deleted file mode 100644 index c22a89032cfb3..0000000000000 --- a/tools/linter/adapters/black_linter.py +++ /dev/null @@ -1,225 +0,0 @@ -from __future__ import annotations - -import argparse -import concurrent.futures -import json -import logging -import os -import subprocess -import sys -import time -from enum import Enum -from typing import BinaryIO, NamedTuple - - -IS_WINDOWS: bool = os.name == "nt" - - -class LintSeverity(str, Enum): - ERROR = "error" - WARNING = "warning" - ADVICE = "advice" - DISABLED = "disabled" - - -class LintMessage(NamedTuple): - path: str | None - line: int | None - char: int | None - code: str - severity: LintSeverity - name: str - original: str | None - replacement: str | None - description: str | None - - -def as_posix(name: str) -> str: - return name.replace("\\", "/") if IS_WINDOWS else name - - -def _run_command( - args: list[str], - *, - stdin: BinaryIO, - timeout: int, -) -> subprocess.CompletedProcess[bytes]: - logging.debug("$ %s", " ".join(args)) - start_time = time.monotonic() - try: - return subprocess.run( - args, - stdin=stdin, - capture_output=True, - shell=IS_WINDOWS, # So batch scripts are found. - timeout=timeout, - check=True, - ) - finally: - end_time = time.monotonic() - logging.debug("took %dms", (end_time - start_time) * 1000) - - -def run_command( - args: list[str], - *, - stdin: BinaryIO, - retries: int, - timeout: int, -) -> subprocess.CompletedProcess[bytes]: - remaining_retries = retries - while True: - try: - return _run_command(args, stdin=stdin, timeout=timeout) - except subprocess.TimeoutExpired as err: - if remaining_retries == 0: - raise err - remaining_retries -= 1 - logging.warning( - "(%s/%s) Retrying because command failed with: %r", - retries - remaining_retries, - retries, - err, - ) - time.sleep(1) - - -def check_file( - filename: str, - retries: int, - timeout: int, -) -> list[LintMessage]: - try: - with open(filename, "rb") as f: - original = f.read() - with open(filename, "rb") as f: - proc = run_command( - [sys.executable, "-mblack", "--stdin-filename", filename, "-"], - stdin=f, - retries=retries, - timeout=timeout, - ) - except subprocess.TimeoutExpired: - return [ - LintMessage( - path=filename, - line=None, - char=None, - code="BLACK", - severity=LintSeverity.ERROR, - name="timeout", - original=None, - replacement=None, - description=( - "black timed out while trying to process a file. " - "Please report an issue in pytorch/pytorch with the " - "label 'module: lint'" - ), - ) - ] - except (OSError, subprocess.CalledProcessError) as err: - return [ - LintMessage( - path=filename, - line=None, - char=None, - code="BLACK", - severity=LintSeverity.ADVICE, - name="command-failed", - original=None, - replacement=None, - description=( - f"Failed due to {err.__class__.__name__}:\n{err}" - if not isinstance(err, subprocess.CalledProcessError) - else ( - "COMMAND (exit code {returncode})\n" - "{command}\n\n" - "STDERR\n{stderr}\n\n" - "STDOUT\n{stdout}" - ).format( - returncode=err.returncode, - command=" ".join(as_posix(x) for x in err.cmd), - stderr=err.stderr.decode("utf-8").strip() or "(empty)", - stdout=err.stdout.decode("utf-8").strip() or "(empty)", - ) - ), - ) - ] - - replacement = proc.stdout - if original == replacement: - return [] - - return [ - LintMessage( - path=filename, - line=None, - char=None, - code="BLACK", - severity=LintSeverity.WARNING, - name="format", - original=original.decode("utf-8"), - replacement=replacement.decode("utf-8"), - description="Run `lintrunner -a` to apply this patch.", - ) - ] - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Format files with black.", - fromfile_prefix_chars="@", - ) - parser.add_argument( - "--retries", - default=3, - type=int, - help="times to retry timed out black", - ) - parser.add_argument( - "--timeout", - default=90, - type=int, - help="seconds to wait for black", - ) - parser.add_argument( - "--verbose", - action="store_true", - help="verbose logging", - ) - parser.add_argument( - "filenames", - nargs="+", - help="paths to lint", - ) - args = parser.parse_args() - - logging.basicConfig( - format="<%(threadName)s:%(levelname)s> %(message)s", - level=logging.NOTSET - if args.verbose - else logging.DEBUG - if len(args.filenames) < 1000 - else logging.INFO, - stream=sys.stderr, - ) - - with concurrent.futures.ThreadPoolExecutor( - max_workers=os.cpu_count(), - thread_name_prefix="Thread", - ) as executor: - futures = { - executor.submit(check_file, x, args.retries, args.timeout): x - for x in args.filenames - } - for future in concurrent.futures.as_completed(futures): - try: - for lint_message in future.result(): - print(json.dumps(lint_message._asdict()), flush=True) - except Exception: - logging.critical('Failed at "%s".', futures[future]) - raise - - -if __name__ == "__main__": - main() diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index 137e4637bdb44..05a7a8acf9324 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -41,11 +41,6 @@ def main() -> None: parser.add_argument( "--dry-run", help="do not install anything, just print what would be done." ) - parser.add_argument( - "--no-black-binary", - help="do not use pre-compiled binaries from pip for black.", - action="store_true", - ) args = parser.parse_args() @@ -97,8 +92,6 @@ def main() -> None: "Package {package_name} did not have a version specified. " "Please specify a version to produce a consistent linting experience." ) - if args.no_black_binary and "black" in package_name: - pip_args.append(f"--no-binary={package_name}") dry_run = args.dry_run == "1" if dry_run: diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 927325bffeb2f..ce5f8252a20f0 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -2,7 +2,6 @@ import argparse import concurrent.futures -import fnmatch import json import logging import os @@ -13,7 +12,6 @@ from pathlib import Path from typing import NamedTuple -import black import isort import usort @@ -21,43 +19,6 @@ IS_WINDOWS: bool = os.name == "nt" REPO_ROOT = Path(__file__).absolute().parents[3] -# TODO: remove this when it gets empty and remove `black` in PYFMT -USE_BLACK_FILELIST = re.compile( - "|".join( - ( - r"\A\Z", # empty string - *map( - fnmatch.translate, - [ - # ** - # .ci/** - # .github/** - # benchmarks/** - # functorch/** - # tools/** - # torchgen/** - # test/** - # test/[a-h]*/** - # test/[i-j]*/** - # test/[k-m]*/** - # test/optim/** - # test/[p-z]*/**, - # torch/** - # torch/_[a-c]*/** - # torch/_[e-h]*/** - # torch/_i*/** - # torch/_[j-z]*/** - # torch/[a-c]*/** - # torch/d*/** - # torch/[e-m]*/** - # torch/optim/** - # torch/[p-z]*/** - ], - ), - ) - ) -) - class LintSeverity(str, Enum): ERROR = "error" @@ -117,23 +78,6 @@ def run_usort(content: str, path: Path) -> str: return usort.usort_string(content, path=path, config=usort_config) -def run_black(content: str, path: Path) -> str: - black_config = black.parse_pyproject_toml(black.find_pyproject_toml((str(path),))) # type: ignore[attr-defined,arg-type] - # manually patch options that do not have a 1-to-1 match in Mode arguments - black_config["target_versions"] = { - black.TargetVersion[ver.upper()] # type: ignore[attr-defined] - for ver in black_config.pop("target_version", []) - } - black_config["string_normalization"] = not black_config.pop( - "skip_string_normalization", False - ) - black_mode = black.Mode(**black_config) - black_mode.is_pyi = path.suffix.lower() == ".pyi" - black_mode.is_ipynb = path.suffix.lower() == ".ipynb" - - return black.format_str(content, mode=black_mode) - - def run_ruff_format(content: str, path: Path) -> str: try: return subprocess.check_output( @@ -165,10 +109,7 @@ def check_file(filename: str) -> list[LintMessage]: # NB: run isort first to enforce style for blank lines replacement = run_isort(replacement, path=path) replacement = run_usort(replacement, path=path) - if USE_BLACK_FILELIST.match(path.absolute().relative_to(REPO_ROOT).as_posix()): - replacement = run_black(replacement, path=path) - else: - replacement = run_ruff_format(replacement, path=path) + replacement = run_ruff_format(replacement, path=path) if original == replacement: return [] From 556e2a73f4f0643f7c2aeb5c7dddda43388a40ce Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Fri, 8 Aug 2025 09:56:44 +0000 Subject: [PATCH 0142/1424] [Test][Easy] Use float16 dtype in test_sort_large (#159939) The test fails with: >RuntimeError: var_mean only support floating point and complex dtypes Pull Request resolved: https://github.com/pytorch/pytorch/pull/159939 Approved by: https://github.com/eqy --- test/test_sort_and_select.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 360dc058212a0..669f165529e71 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -215,7 +215,7 @@ def test_stable_sort(self, device, dtype): ) @onlyCUDA - @dtypes(torch.uint8) + @dtypes(torch.float16) @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype) From 7f4cb4a3e018a621add2a37a3a2f67b982d51001 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Fri, 8 Aug 2025 13:49:55 +0000 Subject: [PATCH 0143/1424] [MPS] coalesce for sparse tensors (#159729) MPS coalesce function for sparse tensors Pull Request resolved: https://github.com/pytorch/pytorch/pull/159729 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- aten/src/ATen/CMakeLists.txt | 8 +- aten/src/ATen/native/native_functions.yaml | 1 + .../ATen/native/sparse/mps/SparseMPSTensor.mm | 220 ++++++++++++++++++ .../native/sparse/mps/kernels/Sparse.metal | 123 ++++++++++ c10/core/Backend.h | 4 +- c10/core/Layout.h | 2 +- c10/core/TensorImpl.h | 1 + test/test_mps.py | 59 +++++ torchgen/gen.py | 9 +- 9 files changed, 416 insertions(+), 11 deletions(-) create mode 100644 aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm create mode 100644 aten/src/ATen/native/sparse/mps/kernels/Sparse.metal diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index b02638e5b6de7..547b36f10936f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -119,6 +119,8 @@ file(GLOB_RECURSE native_mps_cpp "native/mps/*.cpp") file(GLOB_RECURSE native_mps_mm "native/mps/*.mm") file(GLOB_RECURSE native_mps_metal "native/mps/*.metal") file(GLOB_RECURSE native_mps_h "native/mps/*.h") +file(GLOB_RECURSE native_sparse_mps_mm "native/sparse/mps/*.mm") +file(GLOB_RECURSE native_mps_sparse_metal "native/sparse/mps/*.metal") file(GLOB native_sparse_cpp "native/sparse/*.cpp") file(GLOB native_quantized_cpp @@ -699,10 +701,10 @@ endif() if(USE_MPS) include(../../../cmake/Metal.cmake) - set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h}) + set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_sparse_mps_mm}) if(CAN_COMPILE_METAL) - foreach(SHADER ${native_mps_metal}) + foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal}) cmake_path(GET SHADER STEM TGT_STEM) string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air") list(APPEND AIR_BASIC ${TGT_BASIC}) @@ -717,7 +719,7 @@ if(USE_MPS) add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp) else() file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps") - foreach(SHADER ${native_mps_metal}) + foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal}) cmake_path(GET SHADER STEM TGT_STEM) string(CONCAT SHADER_HDR_NAME "${CMAKE_CURRENT_BINARY_DIR}" /native/mps/ ${TGT_STEM} "_metallib.h") metal_to_metallib_h(${SHADER} ${SHADER_HDR_NAME}) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8920864b3a719..9f3c7468a6af4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7423,6 +7423,7 @@ dispatch: SparseCPU: _coalesce_sparse_cpu SparseCUDA: _coalesce_sparse_cuda + SparseMPS: _coalesce_sparse_mps autogen: _coalesce.out - func: is_coalesced(Tensor self) -> bool diff --git a/aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm b/aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm new file mode 100644 index 0000000000000..7ccdf4077542e --- /dev/null +++ b/aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm @@ -0,0 +1,220 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +namespace at::native { + +using namespace mps; +using namespace at::sparse; + +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif + + +static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) { + + TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D"); + TORCH_CHECK(static_cast(indices.size(0)) == size.size(), + "flatten_indices: indices.size(0) must equal size.size()"); + + int64_t sparse_dim = indices.size(0); + int64_t nnz = indices.size(1); + + if (nnz == 0) { + return at::empty({0}, indices.options().dtype(kLong)); + } + + std::vector strides(sparse_dim); + strides[sparse_dim - 1] = 1; + for (int64_t i = sparse_dim - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * size[i + 1]; + } + + Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong)); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + return flat_indices; +} + +static Tensor compute_output_positions(const Tensor& is_unique) { + + int64_t nnz = is_unique.size(0); + if (nnz == 0) { + return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt)); + } + + Tensor positions = at::empty({nnz}, TensorOptions().device(kMPS).dtype(kInt)); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("compute_output_positions_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, is_unique, positions); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + return positions; +} + +static Tensor compute_output_positions_parallel(const Tensor& is_unique) { + + int64_t nnz = is_unique.size(0); + if (nnz == 0) { + return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt)); + } + + // for small arrays, use simple kernel + // speed of the naive kernel drops off after 4096 nnz elements + if (nnz <= 4096) { + return compute_output_positions(is_unique); + } + auto stream = getCurrentMPSStream(); + Tensor positions = is_unique.to(kInt); + // Kogge-Stone parallel prefix sum + Tensor positions_cloned = positions.clone(); + + for (int64_t stride = 1; stride < nnz; stride *= 2) { + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("kogge_stone_step"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, positions, positions_cloned, stride); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + std::swap(positions, positions_cloned); + } + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("shift_right_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, positions, positions_cloned); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + return positions_cloned; +} + +static std::pair mark_unique_and_count(const Tensor& flat_indices) { + + int64_t nnz = flat_indices.size(0); + if (nnz == 0) { + return {at::empty({0}, flat_indices.options().dtype(kBool)), 0}; + } + + Tensor is_unique = at::empty({nnz}, flat_indices.options().dtype(kBool)); + Tensor count_result = at::zeros({1}, flat_indices.options().dtype(kInt)); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("mark_unique_positions_and_count_kernel"); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + mtl_setArgs(encoder, flat_indices, is_unique, count_result); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + int32_t num_unique = count_result.item(); + + return {is_unique, num_unique}; +} + +SparseTensor _coalesce_sparse_mps(const SparseTensor& self) { + int64_t nnz = self._nnz(); + TORCH_INTERNAL_ASSERT(!self.is_coalesced()); + if (nnz < 2) { + SparseTensor dst = self.clone(); + dst._coalesced_(true); + return dst; + } + + Tensor indices = self._indices(); + Tensor values = self._values(); + + Tensor flat_indices = flatten_indices(indices, self.sizes()); + Tensor sorted_order = flat_indices.argsort(); + Tensor flat_indices_sorted = flat_indices.index({sorted_order}); + values = values.index({sorted_order}); + indices = indices.index_select(1, sorted_order); + + auto unique_info = mark_unique_and_count(flat_indices_sorted); + Tensor is_unique = unique_info.first; + int32_t newNnz = unique_info.second; + + Tensor output_positions = compute_output_positions_parallel(is_unique); + + Tensor out_indices = at::empty({indices.size(0), newNnz}, indices.options()); + auto outValuesSize = values.sizes().vec(); + outValuesSize[0] = newNnz; + Tensor out_values = at::zeros(outValuesSize, values.options()); + + Tensor is_unique_local = is_unique; + int64_t sparse_dim = indices.size(0); + + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto pipeline = lib.getPipelineStateForFunc("coalesce_with_positions_kernel_" + scalarToMetalTypeString(values)); + auto encoder = stream->commandEncoder(); + [encoder setComputePipelineState:pipeline]; + + const uint32_t numThreads = static_cast(nnz); + const uint32_t valueSize = static_cast(values.numel() / nnz); + mtl_setArgs(encoder, + flat_indices_sorted, + indices, + values, + is_unique_local, + output_positions, + out_indices, + out_values, + numThreads, + valueSize, + sparse_dim, + newNnz); + mtl_dispatch1DJob(encoder, pipeline, nnz); + } + }); + + SparseTensor result = _sparse_coo_tensor_unsafe_symint(out_indices, out_values, self.sym_sizes())._coalesced_(true); + return result; +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal b/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal new file mode 100644 index 0000000000000..ff76b9b6b5209 --- /dev/null +++ b/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal @@ -0,0 +1,123 @@ +#include +#include +using namespace metal; + +kernel void flatten_indices_kernel( + device const int64_t* indices [[buffer(0)]], + device const int64_t* strides [[buffer(1)]], + device int64_t* flat_indices [[buffer(2)]], + constant uint& sparse_dim [[buffer(3)]], + constant uint& nnz [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + int64_t flat_idx = 0; + for (uint d = 0; d < sparse_dim; d++) { + flat_idx += indices[d * nnz + gid] * strides[d]; + } + flat_indices[gid] = flat_idx; +} + +kernel void compute_output_positions_kernel( + device const bool* is_unique [[buffer(0)]], + device int* positions [[buffer(1)]], + uint gid [[thread_position_in_grid]]) { + int pos = 0; + for (uint i = 0; i < gid; i++) { + if (is_unique[i]) + pos++; + } + positions[gid] = pos; +} + +kernel void mark_unique_positions_and_count_kernel( + device const int64_t* flat_indices [[buffer(0)]], + device bool* is_unique [[buffer(1)]], + device atomic_int* count [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + bool unique = (tid == 0) || (flat_indices[tid] != flat_indices[tid - 1]); + is_unique[tid] = unique; + + if (unique) { + atomic_fetch_add_explicit(count, 1, memory_order_relaxed); + } +} + +// Kogge-Stone parallel prefix sum step +kernel void kogge_stone_step( + device const int* input [[buffer(0)]], + device int* output [[buffer(1)]], + constant uint& stride [[buffer(2)]], + uint gid [[thread_position_in_grid]]) { + int val = input[gid]; + if (gid >= stride) { + val += input[gid - stride]; + } + output[gid] = val; +} + +// Shift right for exclusive scan +kernel void shift_right_kernel( + device const int* input [[buffer(0)]], + device int* output [[buffer(1)]], + uint gid [[thread_position_in_grid]]) { + output[gid] = (gid == 0) ? 0 : input[gid - 1]; +} + +template +kernel void coalesce_with_positions_kernel( + device const int64_t* flat_indices [[buffer(0)]], + device const int64_t* indices [[buffer(1)]], + device const T* in_values [[buffer(2)]], + device const bool* is_unique [[buffer(3)]], + device const int* output_positions [[buffer(4)]], + device int64_t* out_indices [[buffer(5)]], + device T* out_values [[buffer(6)]], + constant uint& nnz [[buffer(7)]], + constant uint& value_size [[buffer(8)]], + constant uint& sparse_dim [[buffer(9)]], + constant uint& total_unique [[buffer(10)]], + uint gid [[thread_position_in_grid]]) { + if (!is_unique[gid]) + return; + + int out_pos = output_positions[gid]; + + for (uint d = 0; d < sparse_dim; d++) { + out_indices[d * total_unique + out_pos] = indices[d * nnz + gid]; + } + + int64_t current_index = flat_indices[gid]; + uint end = gid + 1; + while (end < nnz && flat_indices[end] == current_index) { + end++; + } + + for (uint elem = 0; elem < value_size; elem++) { + T sum = 0; + for (uint j = gid; j < end; j++) { + sum += in_values[j * value_size + elem]; + } + out_values[out_pos * value_size + elem] = sum; + } +} + +#define INSTANTIATE_COALESCE_WITH_POSITIONS(DTYPE) \ + template \ + [[host_name("coalesce_with_positions_kernel_" #DTYPE)]] [[kernel]] void \ + coalesce_with_positions_kernel( \ + device const int64_t* flat_indices [[buffer(0)]], \ + device const int64_t* indices [[buffer(1)]], \ + device const DTYPE* in_values [[buffer(2)]], \ + device const bool* is_unique [[buffer(3)]], \ + device const int* output_positions [[buffer(4)]], \ + device int64_t* out_indices [[buffer(5)]], \ + device DTYPE* out_values [[buffer(6)]], \ + constant uint& nnz [[buffer(7)]], \ + constant uint& value_size [[buffer(8)]], \ + constant uint& sparse_dim [[buffer(9)]], \ + constant uint& total_unique [[buffer(10)]], \ + uint gid [[thread_position_in_grid]]); + +INSTANTIATE_COALESCE_WITH_POSITIONS(float); +INSTANTIATE_COALESCE_WITH_POSITIONS(half); +INSTANTIATE_COALESCE_WITH_POSITIONS(bfloat); +INSTANTIATE_COALESCE_WITH_POSITIONS(bool); \ No newline at end of file diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 67c9276313bba..0497d72b95703 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -237,8 +237,6 @@ inline DeviceType backendToDeviceType(Backend b) { return DeviceType::CPU; case Backend::CUDA: case Backend::SparseCUDA: - case Backend::SparseMPS: - case Backend::SparseCsrMPS: case Backend::QuantizedCUDA: case Backend::SparseCsrCUDA: return DeviceType::CUDA; @@ -276,6 +274,8 @@ inline DeviceType backendToDeviceType(Backend b) { case Backend::Meta: return DeviceType::Meta; case Backend::MPS: + case Backend::SparseMPS: + case Backend::SparseCsrMPS: return DeviceType::MPS; case Backend::HPU: return DeviceType::HPU; diff --git a/c10/core/Layout.h b/c10/core/Layout.h index 0daa129bb5a4f..0d09e0ed46f4e 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -33,7 +33,6 @@ inline Layout layout_from_backend(Backend backend) { case Backend::SparseCPU: case Backend::SparseCUDA: case Backend::SparseMPS: - case Backend::SparseCsrMPS: case Backend::SparseHIP: case Backend::SparseVE: case Backend::SparseXPU: @@ -43,6 +42,7 @@ inline Layout layout_from_backend(Backend backend) { return Layout::Mkldnn; case Backend::SparseCsrCPU: case Backend::SparseCsrCUDA: + case Backend::SparseCsrMPS: case Backend::SparseCsrHIP: case Backend::SparseCsrVE: case Backend::SparseCsrXPU: diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 381bc65b27fbd..fcd7b4b4b31da 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2090,6 +2090,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { constexpr auto sparse_backends = DispatchKeySet( {BackendComponent::CPUBit, BackendComponent::CUDABit, + BackendComponent::MPSBit, BackendComponent::HIPBit, BackendComponent::XPUBit}); constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse); diff --git a/test/test_mps.py b/test/test_mps.py index 975ba00cc7d8a..1deee80344404 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12696,6 +12696,65 @@ def test_resize(self): sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0) self.assertEqual(sparse, sparse_cpu) + def test_coalesce(self): + indices = torch.tensor([[0, 0, 1, 1], [0, 0, 2, 2]], dtype=torch.int64, device="mps") + values = torch.tensor([1., 2., 3., 4.], dtype=torch.float32, device="mps") + size = (2, 3) + indices_cpu = indices.cpu() + values_cpu = values.cpu() + sparse_mps = torch.sparse_coo_tensor(indices, values, size, device="mps") + sparse_cpu = torch.sparse_coo_tensor(indices_cpu, values_cpu, size, device="cpu") + coalesced_mps = sparse_mps.coalesce() + coalesced_cpu = sparse_cpu.coalesce() + + self.assertTrue(coalesced_mps.is_coalesced()) + self.assertTrue(coalesced_cpu.is_coalesced()) + self.assertEqual(coalesced_mps._nnz(), 2) + self.assertEqual(coalesced_mps.cpu(), coalesced_cpu) + + def test_already_coalesced_tensor(self): + already_coalesced = self._get_basic_sparse_coo() + result = already_coalesced.coalesce() + self.assertTrue(result.is_coalesced()) + self.assertEqual(result._indices().cpu(), already_coalesced._indices().cpu()) + self.assertEqual(result._values().cpu(), already_coalesced._values().cpu()) + + def test_coalesce_empty_sparse_tensor(self): + empty_indices = torch.zeros((2, 0), dtype=torch.int64, device="mps") + empty_values = torch.tensor([], dtype=torch.float32, device="mps") + empty_sparse = torch.sparse_coo_tensor(empty_indices, empty_values, (3, 3), device="mps") + empty_coalesced = empty_sparse.coalesce() + self.assertTrue(empty_coalesced.is_coalesced()) + self.assertEqual(empty_coalesced._nnz(), 0) + + def test_coalesce_large_tensor(self): + size = (1000000, 1000000) + num_elements = 1000 + + # 800 unique random positions + unique_indices = torch.randint(0, size[0], (2, 800), dtype=torch.int64) + # 200 duplicates by repeating some of the first 200 indices + duplicate_indices = unique_indices[:, :200] + indices = torch.cat([unique_indices, duplicate_indices], dim=1) + # shuffle indices to mix duplicates with unique entries + perm = torch.randperm(indices.size(1)) + indices = indices[:, perm] + + values = torch.randn(num_elements, dtype=torch.float32) + indices_mps = indices.to("mps") + values_mps = values.to("mps") + sparse_mps = torch.sparse_coo_tensor(indices_mps, values_mps, size, device="mps") + sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu") + + self.assertFalse(sparse_mps.is_coalesced()) + coalesced_mps = sparse_mps.coalesce() + coalesced_cpu = sparse_cpu.coalesce() + self.assertTrue(coalesced_mps.is_coalesced()) + self.assertTrue(coalesced_cpu.is_coalesced()) + self.assertEqual(coalesced_mps._nnz(), coalesced_cpu._nnz()) + self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices()) + self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values()) + # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing. # This requires mps to be properly registered in the device generic test framework which is not the diff --git a/torchgen/gen.py b/torchgen/gen.py index 7d1413827f35d..b8290d6b86844 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -2849,14 +2849,13 @@ def main() -> None: # TODO: stop generating CUDA kernels for non-CUDA builds ignore_keys = set() + MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS} if options.mps or options.update_aoti_c_shim: - functions_keys.add(DispatchKey.MPS) + functions_keys.update(MPS_KEYS) aoti_backends.add(DispatchKey.MPS) else: - ignore_keys.add(DispatchKey.MPS) - - if DispatchKey.MPS in dispatch_keys: - del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] + ignore_keys.update(MPS_KEYS) + dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS] if options.xpu or options.update_aoti_c_shim: functions_keys.add(DispatchKey.XPU) From 62bac0798100e0e06a86b7a4cee1788413e3d0ca Mon Sep 17 00:00:00 2001 From: David Berard Date: Thu, 7 Aug 2025 21:58:18 -0700 Subject: [PATCH 0144/1424] [inductor][triton] support profile_scratch launcher arg (#159772) This adds support for Triton after https://github.com/triton-lang/triton/pull/7258 landed. https://github.com/triton-lang/triton/pull/7258 adds a new argument to all the Triton kernels - a profile_scratch argument, similar to global_scratch. This PR updates the static cuda launcher and the AOTI kernel callers to pass in these arguments when calling the Triton kernel. Tests: https://github.com/pytorch/pytorch/pull/159158. I also verified these test locally with triton 3.2, 3.3, and 3.4. Fixes: * static_cuda_launcher (test/repro: `python tools/dynamo/verify_dynamo.py`) * AOTI calling logic (test/repro: `TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_linalg_vander_cuda_float32`) Differential Revision: [D79825121](https://our.internmc.facebook.com/intern/diff/D79825121) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159772 Approved by: https://github.com/NikhilAPatel, https://github.com/eellison --- torch/_inductor/codegen/common.py | 4 +- torch/_inductor/codegen/cpp_wrapper_gpu.py | 43 +++++++++------ .../codegen/cuda/device_op_overrides.py | 54 +++++++++---------- .../codegen/xpu/device_op_overrides.py | 4 +- .../_inductor/runtime/static_cuda_launcher.py | 31 ++++++----- torch/_inductor/runtime/triton_heuristics.py | 19 ++++++- 6 files changed, 91 insertions(+), 64 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 471c9030f1e6c..40ebbed13ddde 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -362,8 +362,8 @@ def cpp_device_ptr(self) -> str: def tma_descriptor_helpers(self) -> str: raise NotImplementedError - def cpp_global_scratch( - self, idx: int, workspace: TritonScratchWorkspace + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: # optionally return (scratch definition, arg name) raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 430511ce4ebf0..6bbbab8599008 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -211,12 +211,17 @@ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): ] arg_types = [arg_type_loookup[name] for name in call_args] arg_signatures = [triton_meta["signature"][name] for name in call_args] + scratch_spaces = { + name: params[name] + for name in ["global_scratch", "profile_scratch"] + if params.get(name, None) is not None + } call_args_str = wrapper.generate_args_decl( prefix, call_args, arg_types, arg_signatures, - workspace_size=params.get("global_scratch") or 0, + scratch_spaces=scratch_spaces, ) prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};") launch_kernel_args = [ @@ -454,7 +459,7 @@ def generate_args_decl( arg_types, arg_signatures, is_triton_kernel=True, - workspace_size=0, + scratch_spaces: Optional[dict[str, int]] = None, ): """ Generates any declarations of args to pass into a kernel call, and then returns the arg names. @@ -572,22 +577,26 @@ def process_args(arg, arg_type, arg_signature=None): ): process_args(arg, arg_type, arg_signature) - if ( - is_triton_kernel - and ( - global_scratch := self.device_codegen.cpp_global_scratch( - next(self.arg_var_id), - workspace=TritonScratchWorkspace( - size=workspace_size, - generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)), - ), + for scratch_name, workspace_size in (scratch_spaces or {}).items(): + if ( + is_triton_kernel + and ( + scratch := self.device_codegen.cpp_scratch( + next(self.arg_var_id), + workspace=TritonScratchWorkspace( + size=workspace_size, + generate_dtype_str=( + lambda: self.codegen_dtype(torch.uint8) + ), + ), + prefix=scratch_name, + ) ) - ) - is not None - ): - global_scratch_def, global_scratch_var = global_scratch - code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def]) - new_args.append(f"&{global_scratch_var}") + is not None + ): + scratch_def, scratch_var = scratch + code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def]) + new_args.append(f"&{scratch_var}") return ", ".join(new_args) diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 0ba0677422944..147515e0decfe 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -4,7 +4,6 @@ import torch -from ...utils import triton_version_uses_attrs_dict from ..common import ( DeviceOpOverrides, register_device_op_overrides, @@ -333,34 +332,33 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "CUdeviceptr" - def cpp_global_scratch( - self, idx: int, workspace: TritonScratchWorkspace + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: - if triton_version_uses_attrs_dict(): - var_name = f"global_scratch_{idx}" - if workspace.size > 0: - size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" - stride_array = f"int64_t {var_name}_stride[] = {{1}};" - device_type = "cached_torch_device_type_cuda" - device_idx = "device_idx_" - - return ( - [ - f"{size_array}", - f"{stride_array}", - f"AtenTensorHandle {var_name}_handle;", - ( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " - f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" - ), - f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", - f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", - ], - var_name, - ) - else: - return [f"CUdeviceptr {var_name} = 0;"], var_name - return None + prefix = f"{prefix}_" if prefix else "" + var_name = f"{prefix}scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 632cfd29f174f..99502ca2dd976 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -58,8 +58,8 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "void *" - def cpp_global_scratch( - self, idx: int, workspace: TritonScratchWorkspace + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: return None diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index a52df4745f590..3290e25eeae4c 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -63,16 +63,21 @@ def __init__(self, kernel: CompiledKernel) -> None: kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared ) + def needs_scratch_arg(scratch_name: str, param_name: str) -> bool: + if hasattr(kernel.metadata, param_name): + if getattr(kernel.metadata, param_name) > 0: + raise NotImplementedError( + f"{scratch_name} scratch not yet supported" + ) + return True + return False + # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel. # Inductor never uses this field or enables it, but we still have to pass # an extra None into the set of params if its enabled - if hasattr(kernel.metadata, "global_scratch_size"): - if kernel.metadata.global_scratch_size > 0: - raise NotImplementedError("Global scratch not yet supported") - else: - self.has_global_scratch = True - else: - self.has_global_scratch = False + self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size") + # same situation for profile scratch - triton-lang/triton#7258 + self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size") self.arg_tys = self.arg_ty_from_signature(kernel.src) self.function: Optional[int] = ( @@ -214,12 +219,12 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher - # Add a None if triton wants an extra parameter to the cubin - if self.has_global_scratch: - arg_tys = self.arg_tys + "O" - args = (*args, None) - else: - arg_tys = self.arg_tys + # Add a None if triton wants extra parameters for scratch spaces + arg_tys = self.arg_tys + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) assert len(args) == len(arg_tys) # TODO: can handle grid functions here or in C++, so diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index ba8de8f9829ed..8425cba55795a 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1061,6 +1061,7 @@ def save_gpu_kernel(self, stream, launcher): "def_args": launcher.def_args, "call_args": launcher.call_args, "global_scratch": launcher.global_scratch, + "profile_scratch": launcher.profile_scratch, } from torch._inductor.codecache import CudaKernelParamCache @@ -1754,9 +1755,23 @@ def make_launcher(self) -> LauncherType: launcher.def_args = def_args launcher.call_args = call_args kernel_metadata = getattr(self.kernel, "metadata", None) - launcher.global_scratch = getattr( - kernel_metadata, "global_scratch_size", None + + # for the scratch arguments: None indicates that the kernel doesn't + # take any scratch argument; otherwise a number indicates the number + # of bytes of scratch that need to be provided. + + # in AMD's Triton backend, the global scratch size is never provided + # (but for AMD it's safe to pass an extra null arg, so always include it) + global_scratch: Optional[int] = getattr( + kernel_metadata, + "global_scratch_size", + (0 if torch.version.hip else None), + ) + profile_scratch: Optional[int] = getattr( + kernel_metadata, "profile_scratch_size", None ) + launcher.global_scratch = global_scratch + launcher.profile_scratch = profile_scratch return launcher From 9fa8ce26cf638504469852cbc3e7d04579fc8674 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 6 Aug 2025 14:03:11 -0700 Subject: [PATCH 0145/1424] Working setup with runnable PyTorch on Codex. (#159968) Sample transcript: https://chatgpt.com/s/cd_68938effc1a88191ae78bc82a8cefe94 This makes use of https://github.com/pytorch/pytorch/pull/159965 to bypass doing an actual build and use nightly. Things to improve: - Once USE_NIGHTLY is in main can remove the patching - We should just keep using the latest nightly, instead of a hard coded one Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/159968 Approved by: https://github.com/wdvr --- AGENTS.md | 16 ++++++++++++++++ codex_setup.sh | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100755 codex_setup.sh diff --git a/AGENTS.md b/AGENTS.md index daf0f491702ba..3d5436a02a85d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1 +1,17 @@ - This is the only AGENTS.md, there are no recursive AGENTS.md +- When you are working on a bug, first create a standalone file that + reproduces the bug and verify it fails in the expected way. Use this to + test if your changes work. Once the change is passing, find an appropriate + test file to add the test to and make sure to follow local conventions on + the test file. +- If you are running the real test suite, DO NOT run the entire test suite. + Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir' +- Do NOT run setup.py, you do not have a working build environment +- Do NOT run pre-commit, it is not setup +- To run lint, run 'lintrunner -a' (which will autoapply changes) +- Do NOT attempt to install dependencies, you do not have Internet access +- When you are ready to make a PR, do exactly these steps: + - git stash -u + - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch + - git stash pop + - Resolve conflicts if necessary diff --git a/codex_setup.sh b/codex_setup.sh new file mode 100755 index 0000000000000..f169a7b1f6936 --- /dev/null +++ b/codex_setup.sh @@ -0,0 +1,18 @@ +set -ex +uv venv +source .venv/bin/activate +uv pip install -r requirements.txt +uv pip install numpy +lintrunner init +NIGHTLY_PATCH=$(curl -s https://github.com/pytorch/pytorch/commit/nightly.patch | head -n20) +COMMIT=$(grep -oE '[0-9a-f]{40}' <<< "$NIGHTLY_PATCH" | head -1) +COMMIT_DATE=$(echo "$NIGHTLY_PATCH" | grep '^Date:' | sed -E 's/Date: .*, ([0-9]+) ([A-Za-z]+) ([0-9]+) .*/\3 \2 \1/' | awk 'BEGIN{split("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec", months, " "); for(i=1;i<=12;i++) month[months[i]]=sprintf("%02d",i)} {print $1 month[$2] sprintf("%02d",$3)}') +VERSION_STRING="2.9.0.dev${COMMIT_DATE}+cpu" +git rev-parse HEAD > /tmp/orig_work.txt +cp AGENTS.md /tmp +git reset --hard $COMMIT +cp /tmp/AGENTS.md . +curl https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/159965.diff | patch -p1 +USE_NIGHTLY=$VERSION_STRING python setup.py develop +git commit -asm "Agents patch" +echo "source $PWD/.venv/bin/activate" >> ~/.bashrc From b5fd7223b1bf44720dc9183bda7dfcf7aeccff02 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 8 Aug 2025 14:36:41 +0000 Subject: [PATCH 0146/1424] Improve pin_memory error message on CPU-only systems (#159994) ## Summary - clarify pin_memory error message when no accelerator backend is available ## Testing - `python repro_pin_memory.py` (fails: Need to provide pin_memory allocator to use pin memory) - `lintrunner -a` ------ https://chatgpt.com/codex/tasks/task_e_6893ba92c93483238a9bdfdd6c52812b Pull Request resolved: https://github.com/pytorch/pytorch/pull/159994 Approved by: https://github.com/albanD --- aten/src/ATen/EmptyTensor.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index 5634733325a2e..0e535ab20cd21 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -31,7 +31,9 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { return at::globalContext().getPinnedMemoryAllocator(opt_device_type); } else { TORCH_CHECK( - false, "Need to provide pin_memory allocator to use pin memory.") + false, + "pin_memory=True requires a CUDA or other accelerator backend; " + "no pinned memory allocator is available on this system.") } } From 8a37f0c90392a2c38b7c5955471fa49edcaf5cb1 Mon Sep 17 00:00:00 2001 From: zpcore Date: Fri, 8 Aug 2025 15:06:24 +0000 Subject: [PATCH 0147/1424] improve gather and scatter_add strategy (#160140) As title. This PR made a small fix on top of https://github.com/meta-pytorch/autoparallel/pull/81. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160140 Approved by: https://github.com/fmassa --- test/distributed/tensor/test_dtensor_ops.py | 1 - torch/distributed/tensor/_ops/_tensor_ops.py | 44 ++++++++++++++++++-- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index 3f724d9a85bf0..e5dcdfe11c8ce 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -160,7 +160,6 @@ def wrapped(fn): xfail("frexp"), xfail("full"), xfail("full_like"), - xfail("gather"), xfail("geometric"), xfail("geqrf"), xfail("grid_sampler_2d"), diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index 1838abdb97cab..a5a037a3c73e6 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -570,7 +570,6 @@ def replica_only_strategy(op_schema: OpSchema) -> StrategyType: aten.scatter.value, aten.scatter_.src, aten.scatter.src, - aten.scatter_add.default, ], schema_info=RuntimeSchemaInfo(1), ) @@ -597,11 +596,44 @@ def scatter_strategy(op_schema: OpSchema) -> StrategyType: return op_strategy -@register_op_strategy(aten.gather.default) +@register_op_strategy(aten.scatter_add.default, schema_info=RuntimeSchemaInfo(1)) +def scatter_add_strategy(op_schema: OpSchema) -> StrategyType: + input_strategy = op_schema.args_schema[0] + dim = op_schema.args_schema[1] + index_strategy = op_schema.args_schema[2] + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(index_strategy, OpStrategy) + assert isinstance(dim, int) + dim = normalize_dim(dim, input_strategy.ndim) + mesh = input_strategy.mesh + input_shape = input_strategy.shape + index_shape = index_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index, src] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + + if len(input_shape) == len(index_shape): + for d in range(len(input_shape)): + if d != dim and input_shape[d] == index_shape[d]: + sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)] + single_mesh_dim_strategies.append(sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + +@register_op_strategy(aten.gather.default, schema_info=RuntimeSchemaInfo(1)) def gather_strategy(op_schema: OpSchema) -> StrategyType: mesh = op_schema.get_mesh_from_args() input_strategy = cast(OpStrategy, op_schema.args_schema[0]) dim = cast(int, op_schema.args_schema[1]) + dim = normalize_dim(dim, input_strategy.ndim) index_strategy = cast(OpStrategy, op_schema.args_schema[2]) input_shape = input_strategy.shape @@ -617,7 +649,7 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType: # input sharding, input sharded, index accepts mask partial, output follows index # this only works when the input is sharded on the gather dimension, and # index has size 1 on the gather dimension - if index_shape[dim] == 1: + if dim < len(index_shape) and index_shape[dim] == 1: index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) input_sharding: PlacementList = [ index_partial_placement, @@ -631,6 +663,12 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType: index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] single_mesh_dim_strategies.append(index_sharding) + if len(input_shape) == len(index_shape): + for d in range(len(input_shape)): + if d != dim: + sharding: PlacementList = [Shard(d), Shard(d), Shard(d)] + single_mesh_dim_strategies.append(sharding) + return expand_to_full_mesh_op_strategy( mesh, op_schema, single_mesh_dim_strategies, input_index=1 ) From 50f23ff6f883db5021dd6bab4c146434f98dd15d Mon Sep 17 00:00:00 2001 From: gaoyvfeng <15834128411@126.com> Date: Fri, 8 Aug 2025 15:44:48 +0000 Subject: [PATCH 0148/1424] rename-HAS_CUDA-to-HAS_CUDA_AND_TRITON (#159883) Fixes #159399 "Modified torch.testing._internal.inductor_utils and test/inductor" Pull Request resolved: https://github.com/pytorch/pytorch/pull/159883 Approved by: https://github.com/janeyx99 --- .../fsdp/test_fully_shard_logging.py | 4 +-- test/dynamo/test_activation_checkpointing.py | 4 +-- test/dynamo/test_autograd_function.py | 4 +-- test/dynamo/test_backends.py | 4 +-- test/dynamo/test_base_hop.py | 4 +-- test/dynamo/test_callback.py | 4 +-- test/dynamo/test_compiler_bisector.py | 4 +-- test/dynamo/test_debug_utils.py | 4 +-- test/dynamo/test_higher_order_ops.py | 4 +-- test/dynamo/test_logging.py | 10 ++++--- test/dynamo/test_package.py | 27 ++++++++++--------- test/dynamo/test_structured_trace.py | 4 +-- test/dynamo/test_subclasses.py | 4 +-- test/functorch/test_ac.py | 4 +-- test/inductor/test_aot_inductor_custom_ops.py | 4 +-- test/inductor/test_benchmark_fusion.py | 6 ++--- test/inductor/test_ck_backend.py | 6 ++--- test/inductor/test_codecache.py | 20 +++++++------- test/inductor/test_combo_kernels.py | 4 +-- test/inductor/test_compiled_autograd.py | 25 ++++++++++------- test/inductor/test_cooperative_reductions.py | 4 +-- test/inductor/test_cuda_repro.py | 4 +-- test/inductor/test_cudacodecache.py | 4 +-- test/inductor/test_cudagraph_trees.py | 8 +++--- ...est_cudagraph_trees_expandable_segments.py | 11 +++++--- test/inductor/test_cutlass_backend.py | 8 +++--- test/inductor/test_cutlass_evt.py | 4 +-- test/inductor/test_decompose_mem_bound_mm.py | 22 +++++++-------- test/inductor/test_foreach.py | 4 +-- test/inductor/test_fp8.py | 4 +-- test/inductor/test_fused_attention.py | 4 +-- .../inductor/test_graph_transform_observer.py | 7 +++-- test/inductor/test_max_autotune.py | 4 +-- .../test_move_constructors_to_cuda.py | 4 +-- test/inductor/test_needs_exact_strides.py | 4 +-- test/inductor/test_online_softmax.py | 4 +-- test/inductor/test_pad_mm.py | 4 +-- test/inductor/test_perf.py | 6 ++--- test/inductor/test_profiler.py | 4 +-- test/inductor/test_smoke.py | 8 ++++-- .../test_torchinductor_dynamic_shapes.py | 2 +- test/inductor/test_torchinductor_opinfo.py | 4 +-- .../test_torchinductor_strided_blocks.py | 4 +-- test/inductor/test_triton_kernels.py | 4 +-- torch/testing/_internal/inductor_utils.py | 10 +++---- torch/testing/_internal/triton_utils.py | 4 +-- 46 files changed, 162 insertions(+), 138 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_logging.py b/test/distributed/_composable/fsdp/test_fully_shard_logging.py index 2ee46febfb24e..fac56ad0b8d42 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_logging.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_logging.py @@ -6,11 +6,11 @@ import torch.distributed as dist from torch._dynamo.test_case import run_tests from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.logging_utils import LoggingTestCase -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index d64334533f9b4..ea0882744c546 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -19,7 +19,7 @@ from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.two_tensor import TwoTensor from torch.utils.checkpoint import ( checkpoint, @@ -28,7 +28,7 @@ ) -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 6f460b402404f..d93a00f8ae106 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -8,10 +8,10 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils -from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda +from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_cuda -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: import triton from torch.testing._internal.triton_utils import add_kernel diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 9d61bbf31acb1..2b927880cae31 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -16,10 +16,10 @@ onlyHPU, ) from torch.testing._internal.common_utils import skipIfHpu -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") class Seq(torch.nn.Module): diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index 18cdf78c61f27..30252d88a3782 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -13,10 +13,10 @@ ) from torch._higher_order_ops.schema import find_hop_schema from torch.testing._internal.common_utils import instantiate_parametrized_tests -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") def normalize_graph(gm): diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index 8112a2e89e957..c45fac7933c7d 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -8,7 +8,7 @@ from torch._dynamo.test_case import run_tests, TestCase from torch._guards import CompileId from torch.testing._internal.common_utils import TEST_WITH_ROCM -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class CallbackTests(TestCase): @@ -61,7 +61,7 @@ def test_counter_assertion(self) -> None: @unittest.skipIf( TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs" ) - @unittest.skipIf(not HAS_CUDA, "requires triton") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires triton") @torch._inductor.config.patch(force_disable_caches=True) def test_triggers(self) -> None: torch._dynamo.reset() diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index a5a350c0d1ad1..cce1b7bc9183f 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -11,12 +11,12 @@ from torch._inductor.compiler_bisector import CompilerBisector from torch._inductor.test_case import TestCase from torch.library import _scoped_library, Library -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON aten = torch.ops.aten -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") f32 = torch.float32 i64 = torch.int64 diff --git a/test/dynamo/test_debug_utils.py b/test/dynamo/test_debug_utils.py index ea39f6fbd9e1e..1315fa8d9c51a 100644 --- a/test/dynamo/test_debug_utils.py +++ b/test/dynamo/test_debug_utils.py @@ -10,10 +10,10 @@ from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string from torch._dynamo.test_case import TestCase from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") f32 = torch.float32 i64 = torch.int64 diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index b9c1ff3a61fe9..441a10aeba43f 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -38,11 +38,11 @@ xfailIfTorchDynamo, ) from torch.testing._internal.hop_db import hop_db -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") def count_ops(gm, args, freq, op): diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 99d992a899dbc..bcea00cdc98f1 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -26,7 +26,10 @@ TEST_XPU, xfailIf, ) -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU_AND_TRITON +from torch.testing._internal.inductor_utils import ( + HAS_CUDA_AND_TRITON, + HAS_XPU_AND_TRITON, +) from torch.testing._internal.logging_utils import ( LoggingTestCase, make_logging_test, @@ -34,10 +37,11 @@ ) -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_gpu = unittest.skipUnless( - HAS_CUDA or HAS_XPU_AND_TRITON, "requires cuda or xpu with triton" + HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON, "requires cuda or xpu with triton" ) + requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 5739f45504a6d..fdd01135ea2ff 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -24,7 +24,10 @@ skipIfRocm, skipIfXpu, ) -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU_AND_TRITON +from torch.testing._internal.inductor_utils import ( + HAS_CUDA_AND_TRITON, + HAS_XPU_AND_TRITON, +) def compute_loss_helper(x): @@ -94,7 +97,7 @@ def forward(self, x): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_basic_fn(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -138,7 +141,7 @@ def fn(x): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_lazy_backward(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -185,7 +188,7 @@ def fn(x): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_graph_break_bomb(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -249,7 +252,7 @@ def guard_filter_fn(guards): @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda", "xpu")) def test_dynamic_shape(self, backend, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -368,7 +371,7 @@ def guard_filter_fn(guards): @parametrize("device", ("cpu", "cuda", "xpu")) def test_dynamo_cache_manual_load(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -405,7 +408,7 @@ def fn2(x): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_serialize(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -441,7 +444,7 @@ def fn2(x): @skipIfXpu @skipIfRocm def test_automatic_dynamo_autotune_cache(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -474,7 +477,7 @@ def fn(x, y): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_recompiles(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -507,7 +510,7 @@ def fn(x): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_graph_breaks(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -553,7 +556,7 @@ def guard_filter_fn(guards): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_lazy_backward(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") @@ -582,7 +585,7 @@ def fn(x): @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_call_function_from_resume(self, device): - if device == "cuda" and not HAS_CUDA: + if device == "cuda" and not HAS_CUDA_AND_TRITON: raise unittest.SkipTest("Requires CUDA/Triton") if device == "xpu" and not HAS_XPU_AND_TRITON: raise unittest.SkipTest("Requires XPU/Triton") diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 77ef75d125367..ece491d764ddf 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -22,7 +22,7 @@ from torch._logging._internal import TorchLogsFormatter from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_utils import find_free_port -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON if torch.distributed.is_available(): @@ -31,7 +31,7 @@ HAS_TLPARSE = shutil.which("tlparse") is not None requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 17a01f745d405..ef4158b4a65b6 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -31,7 +31,7 @@ parametrize, subtest, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.two_tensor import TwoTensor from torch.utils._python_dispatch import return_and_correct_aliasing @@ -145,7 +145,7 @@ def mk_subclass_dense_subclass_dense(): VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()} -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") compile_full_eager = torch.compile(backend="eager", fullgraph=True) diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py index 430d4a3d56ddd..fde84b6683edf 100644 --- a/test/functorch/test_ac.py +++ b/test/functorch/test_ac.py @@ -6,7 +6,7 @@ import torch import torch._functorch.config as config from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint from torch.utils.flop_counter import FlopCounterMode, register_flop_formula @@ -405,5 +405,5 @@ def call(): if __name__ == "__main__": # I'm using the cuda memory allocator to verify memory allocations - if HAS_CUDA and not TEST_WITH_ROCM: + if HAS_CUDA_AND_TRITON and not TEST_WITH_ROCM: run_tests() diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index aa3c589b45467..0b4f508477ac4 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -24,7 +24,7 @@ skipIfXpu, ) from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test -from torch.testing._internal.triton_utils import HAS_CUDA +from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON from torch.utils._python_dispatch import TorchDispatchMode @@ -556,5 +556,5 @@ class AOTInductorTestABICompatibleCuda(AOTICustomOpTestCase): from torch._inductor.test_case import run_tests # cpp_extension N/A in fbcode - if HAS_CUDA or sys.platform == "darwin": + if HAS_CUDA_AND_TRITON or sys.platform == "darwin": run_tests(needs="filelock") diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index b3afba7d6843f..8a61cc051c20b 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -13,7 +13,7 @@ from torch.testing._internal.inductor_utils import ( get_func_call, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, IS_BIG_GPU, ) @@ -197,7 +197,7 @@ def f(x): self.common(f, (x,)) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: class BenchmarkFusionCudaTest(TestCase): common = check_model_cuda @@ -347,5 +347,5 @@ class BenchmarkFusionCpuTest(TestCase): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 7c50ee1dbd1f6..f73a47e45a57a 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -22,11 +22,11 @@ _quantize_rowwise, _quantize_tensorwise, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, ) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: torch.cuda.memory._set_allocator_settings("expandable_segments:False") log = logging.getLogger(__name__) @@ -464,5 +464,5 @@ def compiled_bmm(x, w): from torch._inductor.utils import is_big_gpu # Set env to make it work in CI. - if HAS_CUDA and HAS_CPU and is_big_gpu(): + if HAS_CUDA_AND_TRITON and HAS_CPU and is_big_gpu(): run_tests() diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 996e81032a05d..8e53725dd159c 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -59,7 +59,7 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, - HAS_CUDA, + HAS_CUDA_AND_TRITON, HAS_GPU, HAS_MULTIGPU, HAS_TRITON, @@ -872,7 +872,7 @@ def fn(x): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") def test_no_arguments_tensor_device_guards(self): """ Usually, when there are example inputs, the device index of the inputs @@ -902,7 +902,7 @@ def f(): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") def test_tensor_device_guards_cpu_tensor(self): """ CPU tensor arguments should still cache hit @@ -2574,7 +2574,7 @@ def test_get_hash_for_files(self): class TestCudaCompileCommand(TestCase): - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") def test_cuda_compile_command(self): cmd_no_extra_args: str = cuda_compile_command( ["abc.cu", "def.cu"], "output", "so" @@ -2619,7 +2619,7 @@ def reset(self): torch._dynamo.reset() clear_caches() - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") @unittest.skipIf(not SM80OrLater, "Requires SM80+") @unittest.skipIf( TEST_WITH_ROCM, "Requires static cuda launcher, which does not support ROCM" @@ -2670,7 +2670,7 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2711,7 +2711,7 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2772,7 +2772,7 @@ def f(a, b, c, d, e, f): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") @requires_triton() - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2836,7 +2836,7 @@ def fn(x, y): class TestRemoteAOTAutogradCache(TestCase): - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) @@ -2875,7 +2875,7 @@ def f(a, b): for k in global_stats.fx_graph.cache.keys(): self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index a054464bf6689..480094dfb7481 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -10,7 +10,7 @@ instantiate_parametrized_tests, TestCase, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda @@ -558,5 +558,5 @@ def fn(x, y, z): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index b3d98a970cf65..c99ad7f2c95a9 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -47,7 +47,12 @@ skipIfWindows, ) from torch.testing._internal.hop_db import hop_db -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_CUDA_AND_TRITON, + HAS_GPU, +) from torch.testing._internal.logging_utils import logs_to_string from torch.utils._python_dispatch import TorchDispatchMode @@ -2989,7 +2994,7 @@ def backward(ctx, grad): b = MyFunc.apply(a) b.sum().backward() - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") def test_cudagraphs_cpu_division(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -3029,7 +3034,7 @@ def test_cudagraphs_cpu_graph(self): self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") def test_cudagraphs_sdpa(self): query = torch.rand( 32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True @@ -3051,7 +3056,7 @@ def test_cudagraphs_sdpa(self): 2 if inductor_config.cpp_wrapper else 0, ) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): class MyFn(torch.autograd.Function): @staticmethod @@ -3082,7 +3087,7 @@ def backward(ctx, gO): self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) @scoped_load_inline - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { @@ -3710,7 +3715,7 @@ def inner_compiler(gm_, example_inputs_): self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node)) self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node)) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") def test_flex_attention(self): def _squared(score, b, h, m, n): """Joint graph needed for correctness""" @@ -3878,7 +3883,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check), ) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") def test_cpu_offloading(self): def fn(): def pack(x): @@ -5046,7 +5051,7 @@ def wrap_test_class(orig_cls): dct[name] = unittest.expectedFailure elif name.startswith("test_"): backend = lookup_backend(name) - if not HAS_CUDA and backend == "inductor": + if not HAS_CUDA_AND_TRITON and backend == "inductor": continue ctxs = [ compiled_autograd._enable( @@ -5283,7 +5288,7 @@ def wrap_test_class(orig_cls): skipped_tests = set() -if not HAS_CUDA: +if not HAS_CUDA_AND_TRITON: # Found Tesla M60 which is too old to be supported by the triton GPU compiler skipped_tests.add("test_type_conversions") @@ -5309,7 +5314,7 @@ def wrap_test_class(orig_cls): test_higher_order_ops.ActivationCheckpointingTests ) -if torch.distributed.is_available() and HAS_CUDA: +if torch.distributed.is_available() and HAS_CUDA_AND_TRITON: test_dtensor = load_test_module("distributed/tensor/test_dtensor_compile") TestDTensorCompileWithCompiledAutograd = wrap_test_class( test_dtensor.TestDTensorCompile diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index fc296b12a9d70..0b8f60dc0d269 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -18,7 +18,7 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class TestingHeuristics(InductorChoices): @@ -381,5 +381,5 @@ def fn(x, y): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index bb59b626bef14..6037bd4d794cd 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -2216,7 +2216,7 @@ def forward(self, x): if __name__ == "__main__": from torch._inductor.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_CUDA + from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON - if HAS_CUDA and not TEST_WITH_ASAN: + if HAS_CUDA_AND_TRITON and not TEST_WITH_ASAN: run_tests(needs="filelock") diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 36f73b2004763..7a132ac2a0468 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -10,10 +10,10 @@ from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import fresh_cache -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") _SOURCE_CODE = r""" diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 688c4d87230cf..4a7f9e6e92e03 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -55,11 +55,11 @@ importlib.import_module("functorch") importlib.import_module("filelock") -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON aten = torch.ops.aten -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_multigpu = functools.partial( unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" ) @@ -124,7 +124,7 @@ def tearDown(self): torch._dynamo.reset() -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: def get_all_cudagraph_segments(): segments = torch.cuda.memory_snapshot() @@ -4057,5 +4057,5 @@ def fn(x, y): sys.exit(0) raise unittest.SkipTest("cuda graph test is skipped") - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_cudagraph_trees_expandable_segments.py b/test/inductor/test_cudagraph_trees_expandable_segments.py index 04f2ad96fdc0b..65597316091d4 100644 --- a/test/inductor/test_cudagraph_trees_expandable_segments.py +++ b/test/inductor/test_cudagraph_trees_expandable_segments.py @@ -8,13 +8,13 @@ import torch from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: try: from .test_cudagraph_trees import CudaGraphTreeTests except ImportError: @@ -32,7 +32,12 @@ sys.path.remove(str(REPO_ROOT)) if __name__ == "__main__": - if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS and HAS_CUDA: + if ( + torch.cuda.is_available() + and not IS_JETSON + and not IS_WINDOWS + and HAS_CUDA_AND_TRITON + ): get_disabled_tests(".") torch.cuda.memory._set_allocator_settings("expandable_segments:True") diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index ea0fa87382145..c29dff73f9a1e 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -58,12 +58,12 @@ _quantize_rowwise, _quantize_tensorwise, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, ) torch.set_float32_matmul_precision("high") -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: torch.cuda.memory._set_allocator_settings("expandable_segments:False") @@ -158,7 +158,7 @@ def select_no_algorithm(*args, **kwargs): @instantiate_parametrized_tests class TestCutlassBackend(TestCase): def setUp(self): - if not HAS_CUDA: + if not HAS_CUDA_AND_TRITON: self.skipTest("CUDA is not available") if torch.version.hip: self.skipTest("CUTLASS backend is not supported on HIP") @@ -2313,5 +2313,5 @@ def test_config_number_post_filtering(self) -> None: from torch._inductor.utils import is_big_gpu # Set env to make it work in CI. - if HAS_CUDA and HAS_CPU and is_big_gpu(): + if HAS_CUDA_AND_TRITON and HAS_CPU and is_big_gpu(): run_tests() diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index eb468c3910209..9c2b9a624a202 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -15,7 +15,7 @@ from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.utils import OrderedSet from torch.testing._internal.common_cuda import SM90OrLater -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON if try_import_cutlass(): @@ -571,5 +571,5 @@ def test_evt_codegen(self): if __name__ == "__main__": from torch._dynamo.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py index 8be6e23475925..919d97f987f64 100644 --- a/test/inductor/test_decompose_mem_bound_mm.py +++ b/test/inductor/test_decompose_mem_bound_mm.py @@ -15,7 +15,7 @@ parametrize, TEST_XPU, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_gpu @@ -117,7 +117,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -128,7 +128,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 3 if should_decompose and HAS_CUDA else 0 + expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -177,7 +177,7 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -224,7 +224,7 @@ def test_decompose_linear_mixed_precision( self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -269,7 +269,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -281,7 +281,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -331,7 +331,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -343,7 +343,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -367,7 +367,7 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose and HAS_CUDA else 0 + expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -381,7 +381,7 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_gradients(module, traced) expected_val = 0 - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: expected_val = 1 if has_bias else 2 self.assertEqual( diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index 8eb113f183299..f9cedf81f85b0 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -14,7 +14,7 @@ IS_FBCODE, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda from torch.utils._pytree import tree_flatten @@ -1109,5 +1109,5 @@ def ref_fn(xs): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 50044b2c1943a..11d320315cdcd 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -22,7 +22,7 @@ _quantize_tensorwise, _to_fp8_saturated, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, ) from torch.utils._triton import has_triton_tma_device @@ -766,5 +766,5 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): if __name__ == "__main__": - if HAS_CUDA or HAS_CPU: + if HAS_CUDA_AND_TRITON or HAS_CPU: run_tests() diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 19757d8942071..25e96fa9f1e9f 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -18,7 +18,7 @@ from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, HAS_XPU_AND_TRITON, ) @@ -1119,7 +1119,7 @@ def dot_prod_attention( ) -if HAS_XPU_AND_TRITON or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): +if HAS_XPU_AND_TRITON or (HAS_CUDA_AND_TRITON and PLATFORM_SUPPORTS_FUSED_ATTENTION): class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): device = GPU_TYPE diff --git a/test/inductor/test_graph_transform_observer.py b/test/inductor/test_graph_transform_observer.py index 1def72ae9e273..2bd0b6ef43f11 100644 --- a/test/inductor/test_graph_transform_observer.py +++ b/test/inductor/test_graph_transform_observer.py @@ -11,7 +11,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON try: @@ -28,7 +28,10 @@ class TestGraphTransformObserver(TestCase): def test_sdpa_rewriter(self): if not ( - HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT + HAS_CUDA_AND_TRITON + and PLATFORM_SUPPORTS_FUSED_ATTENTION + and HAS_PYDOT + and HAS_DOT ): return diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 8917c7a6ed360..93165fa2dcec8 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -68,13 +68,13 @@ get_kernel_launch, GPU_TYPE, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, HAS_GPU, ) torch.set_float32_matmul_precision("high") -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: torch.cuda.memory._set_allocator_settings("expandable_segments:False") diff --git a/test/inductor/test_move_constructors_to_cuda.py b/test/inductor/test_move_constructors_to_cuda.py index 3c3b8708c630f..b174c79f1ebd0 100644 --- a/test/inductor/test_move_constructors_to_cuda.py +++ b/test/inductor/test_move_constructors_to_cuda.py @@ -9,7 +9,7 @@ from torch.testing import FileCheck from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON requires_multigpu = functools.partial( @@ -112,5 +112,5 @@ def foo(x): if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_needs_exact_strides.py b/test/inductor/test_needs_exact_strides.py index ae80abe7c440c..2d636db3f88f1 100644 --- a/test/inductor/test_needs_exact_strides.py +++ b/test/inductor/test_needs_exact_strides.py @@ -13,7 +13,7 @@ IS_LINUX, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class TestNeedsExactStrides(InductorTestCase): @@ -98,5 +98,5 @@ def f(x, other): instantiate_parametrized_tests(TestNeedsExactStrides) if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py index 798d86b0dd617..1e94ff1f49877 100644 --- a/test/inductor/test_online_softmax.py +++ b/test/inductor/test_online_softmax.py @@ -14,7 +14,7 @@ IS_LINUX, parametrize, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" @@ -297,5 +297,5 @@ def f(x, mask): instantiate_parametrized_tests(TestOnlineSoftmax) if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index bcd1519c59350..d04bed2a90329 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -16,7 +16,7 @@ from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class PadMMTest(TestCase): @@ -541,5 +541,5 @@ def fn(x, y): if __name__ == "__main__": - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 0ca54257250f6..30a273ba17e31 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -28,13 +28,13 @@ # performance for that setting. # # Defines all the kernels for tests -from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda +from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_cuda # set so that metrics appear torch._logging.set_logs(inductor_metrics=True) -if HAS_CUDA: +if HAS_CUDA_AND_TRITON: import triton # @manual import triton.language as tl # @manual @@ -1292,5 +1292,5 @@ def f(a, b): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index 3d54c378de4a2..f22f0374813b0 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -12,7 +12,7 @@ from torch._inductor import config from torch.profiler import ProfilerActivity from torch.testing._internal.common_utils import TemporaryFileName -from torch.testing._internal.inductor_utils import HAS_CUDA, IS_BIG_GPU +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON, IS_BIG_GPU from torch.torch_version import TorchVersion from torch.utils._triton import has_triton @@ -313,5 +313,5 @@ def fn(x, y): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: run_tests() diff --git a/test/inductor/test_smoke.py b/test/inductor/test_smoke.py index 895e8ba16ab0d..2a247fddbe76e 100644 --- a/test/inductor/test_smoke.py +++ b/test/inductor/test_smoke.py @@ -6,7 +6,11 @@ import torch._logging from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CUDA_AND_TRITON, + HAS_GPU, +) class MLP(torch.nn.Module): @@ -62,5 +66,5 @@ def test_compile_invalid_options(self): from torch._inductor.test_case import run_tests if IS_LINUX and HAS_GPU: - if (not HAS_CUDA) or torch.cuda.get_device_properties(0).major <= 5: + if (not HAS_CUDA_AND_TRITON) or torch.cuda.get_device_properties(0).major <= 5: run_tests() diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index a2d5ff9be6c23..8b6d625a54471 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -147,7 +147,7 @@ class TestInductorDynamic(TestCase): compile_fn = partial(torch.compile, dynamic=True) def setUp(self): - # HAS_CUDA also checks compute capability to skip tests + # HAS_CUDA_AND_TRITON also checks compute capability to skip tests # on older devices if not HAS_GPU: self.skipTest("Triton not available") diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 2a0e4c63fb682..e8d6ce38d5af6 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -46,7 +46,7 @@ from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, - HAS_CUDA, + HAS_CUDA_AND_TRITON, has_triton, HAS_XPU_AND_TRITON, maybe_skip_size_asserts, @@ -1126,7 +1126,7 @@ def tearDown(self): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently - @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipCUDAIf(not HAS_CUDA_AND_TRITON, "Skipped! Triton not found") @skipXPUIf( not HAS_XPU_AND_TRITON, "Skipped! Supported XPU compiler and Triton not found" ) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 67d197f0750d0..c203ea661fbe7 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -26,7 +26,7 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, - HAS_CUDA, + HAS_CUDA_AND_TRITON, HAS_GPU, requires_gpu, skip_windows_ci, @@ -1349,7 +1349,7 @@ class TritonBlockPointerTestGPU(BlockDescriptorTestBase): @unittest.skipIf( - not (HAS_CUDA and torch.cuda.get_device_capability()[0] >= 9), + not (HAS_CUDA_AND_TRITON and torch.cuda.get_device_capability()[0] >= 9), "Requires Triton CUDA backend and CUDA compute capability >= 9.0", ) @config.patch({"triton.use_tensor_descriptor": True, "assume_aligned_inputs": True}) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 03ba4dc712702..87529c23dd7ad 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -33,7 +33,7 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, - HAS_CUDA, + HAS_CUDA_AND_TRITON, HAS_GPU, HAS_XPU_AND_TRITON, ) @@ -52,7 +52,7 @@ import triton from triton import language as tl - if HAS_CUDA: + if HAS_CUDA_AND_TRITON: try: from triton.language.extra.libdevice import ( # @manual fast_dividef, diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 7ce065c64317c..f1cf62aa64bd1 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -69,13 +69,13 @@ def test_cpu(): TRITON_HAS_CPU = False -HAS_CUDA = torch.cuda.is_available() and HAS_TRITON +HAS_CUDA_AND_TRITON = torch.cuda.is_available() and HAS_TRITON HAS_XPU_AND_TRITON = torch.xpu.is_available() and HAS_TRITON HAS_MPS = torch.mps.is_available() -HAS_GPU = HAS_CUDA or HAS_XPU_AND_TRITON +HAS_GPU = HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON GPU_TYPE = get_gpu_type() @@ -163,16 +163,16 @@ def inner(fn): skipCPUIf = functools.partial(skipDeviceIf, device="cpu") IS_A100 = LazyVal( - lambda: HAS_CUDA + lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 166912 ) IS_H100 = LazyVal( - lambda: HAS_CUDA + lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 232448 ) -IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu()) +IS_BIG_GPU = LazyVal(lambda: HAS_CUDA_AND_TRITON and is_big_gpu()) def dummy_graph() -> GraphLowering: """ diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 69b260d2833b5..922bde7cc4b58 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -2,11 +2,11 @@ import unittest -from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON, HAS_GPU from torch.utils._triton import has_triton -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu") if has_triton(): From 231c72240d80091f099c95e326d3600cba866eee Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 8 Aug 2025 16:03:49 +0000 Subject: [PATCH 0149/1424] CMake build: preserve PYTHONPATH (#160144) Fixes #160092 I'm very new to CMake, so let me know if there's a fancier way to do this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160144 Approved by: https://github.com/malfet Co-authored-by: Xuehai Pan --- torch/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 8d761068d1e62..1632147f0220e 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -265,7 +265,7 @@ add_custom_command( OUTPUT "${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi" COMMAND - ${CMAKE_COMMAND} -E env PYTHONPATH="${TORCH_ROOT}" + ${CMAKE_COMMAND} -E env --modify PYTHONPATH=path_list_prepend:"${TORCH_ROOT}" -- "${Python_EXECUTABLE}" ${TORCH_SRC_DIR}/utils/data/datapipes/gen_pyi.py DEPENDS "${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi.in" From a4f69a5da08eace1c1e6469dec6a18aa842da73b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 7 Aug 2025 22:14:38 -0700 Subject: [PATCH 0150/1424] [dynamo][guards] Remove guards on stdlib modules (#159913) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159913 Approved by: https://github.com/StrongerXi --- torch/_dynamo/guards.py | 4 ++++ torch/_dynamo/source.py | 23 +++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 16 ++++++++++++++++ torch/_dynamo/variables/functions.py | 8 ++++++++ 4 files changed, 51 insertions(+) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 5ffa6d06d7c4e..a32b8d686dac7 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -70,6 +70,7 @@ is_from_flatten_script_object_source, is_from_local_source, is_from_optimizer_source, + is_from_skip_guard_source, is_from_unspecialized_builtin_nn_module_source, TensorProperty, TensorPropertySource, @@ -4124,4 +4125,7 @@ def install_guard(*guards: Guard, skip: int = 0) -> None: add = TracingContext.get().guards_context.dynamo_guards.add for guard in guards: assert isinstance(guard, Guard) + + if is_from_skip_guard_source(guard.originating_source): + continue add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 3cb36a63d27ad..6897ddd9b24c7 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -402,6 +402,18 @@ def is_ephemeral(self) -> bool: return True +@dataclasses.dataclass(frozen=True) +class SkipGuardSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen") -> None: + self.base.reconstruct(codegen) + + def guard_source(self) -> GuardSource: + return self.base.guard_source() + + def name(self) -> str: + return self.base.name() + + class TensorProperty(enum.Enum): SIZE = 0 STRIDE = 1 @@ -1151,3 +1163,14 @@ def is_from_defaults(source: Source) -> bool: if isinstance(source, ChainedSource): return is_from_defaults(source.base) return False + + +@functools.lru_cache +def is_from_skip_guard_source(source: Source) -> bool: + if isinstance(source, SkipGuardSource): + return True + + if isinstance(source, ChainedSource): + return is_from_skip_guard_source(source.base) + + return False diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 546d1bc84f25e..8e5a1ef80393c 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -108,6 +108,7 @@ GlobalWeakRefSource, LocalCellSource, LocalSource, + SkipGuardSource, Source, ) from .trace_rules import is_builtin_constant, is_forbidden @@ -443,6 +444,15 @@ def impl(self: "InstructionTranslator", inst: Instruction): return impl +def is_stdlib(mod): + if sys.version_info < (3, 10): + # For < 3.10, no easy way to identify a stdlib module name. + return False + if not isinstance(mod, types.ModuleType): + return False + return mod.__name__.split(".")[0] in sys.stdlib_module_names + + def _detect_and_normalize_assert_statement( self: "InstructionTranslatorBase", truth_fn: typing.Callable[[object], bool], @@ -4100,6 +4110,12 @@ def get_globals_source_and_value(self, name): # Dont use lazy vt because we will do a setattr afterwards fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value) global_source = DictGetItemSource(globals_source, name) # type: ignore[assignment] + + if is_stdlib(fglobals_value): + # Users don't inplace mutate a stdlib attribute (like inspect, + # collections), skip guards that originate from the stdlib modules. + global_source = SkipGuardSource(global_source) # type: ignore[assignment] + return fglobals_value, fglobals_vt, global_source def _load_global(self, inst): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 0da182c022b99..be92c4eb491bc 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -62,6 +62,7 @@ ConstantSource, DefaultsSource, GetItemSource, + SkipGuardSource, ) from ..utils import ( check_constant_args, @@ -303,6 +304,13 @@ def _create_nested_fn( def fn_var_getattr(tx, fn, source, name): source = source and AttrSource(source, name) + + if source and name == "__annotations__": + # We get a large number of silly guards from annotations from inspect + # module. Changing annotations is rare, and it impacting the extracted + # graph is even rarer. So skip guards. + source = SkipGuardSource(source) + try: subobj = inspect.getattr_static(fn, name) except AttributeError: From 86eb65f7f06016bcd5d7951dc9d74bc3993a827a Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 31 Jul 2025 12:22:20 -0500 Subject: [PATCH 0151/1424] [MPS] Move max_pool2d to Metal for `stride != 1` (#157876) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR updates `max_pool2d` to use a Metal kernel instead of the old MPS graph impl. However, when the `stride` argument is 1 in all dimensions, the old implementation gives significantly better performance, so we fall back to it in that case. Below is a performance comparison of `max_pool2d` before and after this PR, obtained from this script: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/2f02f2bf7ad8e1b80d8eb728612b179d48fe92d7/max_pool_mps/perf.py
Click to expand case | before PR | after PR | speedup |   | case info -- | -- | -- | -- | -- | -- 0 | 0.014264 | 0.004473 | 3.188911245 |   | (3, 2, 2), {'kernel_size': 2, 'return_indices': True} 1 | 0.010752 | 0.00421 | 2.55391924 |   | (3, 2, 2), {'kernel_size': 2, 'return_indices': False} 2 | 0.020777 | 0.006123 | 3.393271272 |   | (3, 10, 10), {'kernel_size': 5, 'return_indices': True} 3 | 0.011065 | 0.005759 | 1.921340511 |   | (3, 10, 10), {'kernel_size': 5, 'return_indices': False} 4 | 0.01452 | 0.007829 | 1.854642994 |   | (3, 100, 100), {'kernel_size': 5, 'return_indices': True} 5 | 0.009258 | 0.007075 | 1.308551237 |   | (3, 100, 100), {'kernel_size': 5, 'return_indices': False} 6 | 0.188137 | 0.168688 | 1.115295694 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 0, 'return_indices': True} 7 | 0.161362 | 0.154746 | 1.042753932 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 0, 'return_indices': False} 8 | 0.182883 | 0.16945 | 1.079274122 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 1, 'return_indices': True} 9 | 0.156875 | 0.163346 | 0.9603847049 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 1, 'return_indices': False} 10 | 0.193433 | 0.167396 | 1.155541351 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 2, 'return_indices': True} 11 | 0.158967 | 0.151246 | 1.051049284 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 2, 'return_indices': False} 12 | 0.931071 | 0.932883 | 0.9980576342 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 0, 'return_indices': True} 13 | 0.324496 | 0.3252 | 0.9978351784 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 0, 'return_indices': False} 14 | 0.944071 | 0.936246 | 1.008357846 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 1, 'return_indices': True} 15 | 0.322171 | 0.314854 | 1.023239343 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 1, 'return_indices': False} 16 | 0.894158 | 0.886408 | 1.008743152 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 2, 'return_indices': True} 17 | 0.309338 | 0.304146 | 1.017070749 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 2, 'return_indices': False} 18 | 0.606 | 0.260546 | 2.325884873 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 0, 'return_indices': True} 19 | 0.30445 | 0.231054 | 1.317657344 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 0, 'return_indices': False} 20 | 0.474708 | 0.261925 | 1.812381407 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 1, 'return_indices': True} 21 | 0.23175 | 0.231883 | 0.9994264349 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 1, 'return_indices': False} 22 | 0.434475 | 0.266246 | 1.631855502 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 2, 'return_indices': True} 23 | 0.236942 | 0.231792 | 1.022218196 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 2, 'return_indices': False} 24 | 0.202396 | 0.174888 | 1.157289237 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 0, 'return_indices': True} 25 | 0.160679 | 0.158246 | 1.015374796 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 0, 'return_indices': False} 26 | 0.200354 | 0.184133 | 1.088093932 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 1, 'return_indices': True} 27 | 0.160779 | 0.160679 | 1.000622359 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 1, 'return_indices': False} 28 | 0.199175 | 0.178625 | 1.115045486 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 2, 'return_indices': True} 29 | 0.159458 | 0.160883 | 0.9911426316 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 2, 'return_indices': False} 30 | 0.199021 | 0.165329 | 1.203787599 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 0, 'return_indices': True} 31 | 0.156337 | 0.158213 | 0.9881425673 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 0, 'return_indices': False} 32 | 0.180146 | 0.174483 | 1.032455884 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 1, 'return_indices': True} 33 | 0.156988 | 0.158167 | 0.9925458534 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 1, 'return_indices': False} 34 | 0.182133 | 0.176521 | 1.031792251 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 2, 'return_indices': True} 35 | 0.169042 | 0.156483 | 1.080257919 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 2, 'return_indices': False} 36 | 1.767821 | 1.766254 | 1.000887188 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 0, 'return_indices': True} 37 | 1.059346 | 1.058775 | 1.000539302 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 0, 'return_indices': False} 38 | 1.85755 | 1.859429 | 0.9989894747 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 1, 'return_indices': True} 39 | 1.100417 | 1.097683 | 1.002490701 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 1, 'return_indices': False} 40 | 1.843167 | 1.847558 | 0.9976233493 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 2, 'return_indices': True} 41 | 1.090142 | 1.093163 | 0.9972364597 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 2, 'return_indices': False} 42 | 0.480867 | 0.251733 | 1.910226311 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 0, 'return_indices': True} 43 | 0.319246 | 0.236479 | 1.349997251 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 0, 'return_indices': False} 44 | 0.49315 | 0.256408 | 1.923301925 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 1, 'return_indices': True} 45 | 0.316746 | 0.227854 | 1.390127011 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 1, 'return_indices': False} 46 | 0.4912 | 0.257762 | 1.905633879 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 2, 'return_indices': True} 47 | 0.324771 | 0.229371 | 1.41592006 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 2, 'return_indices': False} 48 | 0.152904 | 0.095079 | 1.608178462 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 0, 'return_indices': True} 49 | 0.102963 | 0.089217 | 1.154073775 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 0, 'return_indices': False} 50 | 0.155158 | 0.095429 | 1.625899884 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 1, 'return_indices': True} 51 | 0.104338 | 0.089979 | 1.15958168 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 1, 'return_indices': False} 52 | 0.153121 | 0.096429 | 1.587914424 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 2, 'return_indices': True} 53 | 0.103642 | 0.090254 | 1.148336916 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 2, 'return_indices': False} 54 | 0.191071 | 0.165125 | 1.157129447 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 0, 'return_indices': True} 55 | 0.153971 | 0.149021 | 1.033216795 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 0, 'return_indices': False} 56 | 0.193192 | 0.166892 | 1.157586942 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 1, 'return_indices': True} 57 | 0.156617 | 0.15215 | 1.029359185 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 1, 'return_indices': False} 58 | 0.178033 | 0.167308 | 1.06410333 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 2, 'return_indices': True} 59 | 0.157425 | 0.164404 | 0.9575496947 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 2, 'return_indices': False} 60 | 1.757638 | 1.750896 | 1.0038506 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 0, 'return_indices': True} 61 | 1.048471 | 1.047967 | 1.000480931 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 0, 'return_indices': False} 62 | 1.790708 | 1.789767 | 1.000525767 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 1, 'return_indices': True} 63 | 1.054575 | 1.054796 | 0.9997904808 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 1, 'return_indices': False} 64 | 1.785837 | 1.784192 | 1.000921986 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 2, 'return_indices': True} 65 | 1.054713 | 1.054492 | 1.00020958 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 2, 'return_indices': False} 66 | 0.478267 | 0.261017 | 1.832321266 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 0, 'return_indices': True} 67 | 0.32005 | 0.226654 | 1.412064204 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 0, 'return_indices': False} 68 | 0.484008 | 0.254721 | 1.900149575 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 1, 'return_indices': True} 69 | 0.321 | 0.218842 | 1.466811672 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 1, 'return_indices': False} 70 | 0.482087 | 0.248771 | 1.937874591 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 2, 'return_indices': True} 71 | 0.316558 | 0.230533 | 1.373156988 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 2, 'return_indices': False} 72 | 0.137842 | 0.085088 | 1.619993419 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 0, 'return_indices': True} 73 | 0.100671 | 0.0769 | 1.309115735 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 0, 'return_indices': False} 74 | 0.148321 | 0.086967 | 1.705485989 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 1, 'return_indices': True} 75 | 0.101392 | 0.075454 | 1.343759112 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 1, 'return_indices': False} 76 | 0.150208 | 0.083742 | 1.793699697 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 2, 'return_indices': True} 77 | 0.099587 | 0.075825 | 1.313379492 |   | (3, 1000, 1000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 2, 'return_indices': False} 78 | 0.622546 | 0.602729 | 1.03287879 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 0, 'return_indices': True} 79 | 0.531696 | 0.5067 | 1.049330965 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 0, 'return_indices': False} 80 | 0.626646 | 0.617038 | 1.015571164 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 1, 'return_indices': True} 81 | 0.530354 | 0.525367 | 1.009492412 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 1, 'return_indices': False} 82 | 0.633933 | 0.577775 | 1.097197006 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 2, 'return_indices': True} 83 | 0.533067 | 0.526954 | 1.011600633 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': None, 'padding': 2, 'return_indices': False} 84 | 3.372867 | 3.386412 | 0.9960001914 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 0, 'return_indices': True} 85 | 1.155975 | 1.156604 | 0.9994561665 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 0, 'return_indices': False} 86 | 3.401921 | 3.39755 | 1.001286515 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 1, 'return_indices': True} 87 | 1.202829 | 1.192538 | 1.008629494 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 1, 'return_indices': False} 88 | 3.23675 | 3.220238 | 1.005127571 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 2, 'return_indices': True} 89 | 1.077067 | 1.085613 | 0.9921279498 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 1, 'padding': 2, 'return_indices': False} 90 | 1.572925 | 0.925625 | 1.699311276 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 0, 'return_indices': True} 91 | 0.791204 | 0.793454 | 0.9971642969 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 0, 'return_indices': False} 92 | 1.572742 | 0.922729 | 1.704446268 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 1, 'return_indices': True} 93 | 0.784292 | 0.788871 | 0.9941955022 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 1, 'return_indices': False} 94 | 1.526546 | 0.925708 | 1.649057802 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 2, 'return_indices': True} 95 | 0.769321 | 0.787675 | 0.9766985114 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 2, 'padding': 2, 'return_indices': False} 96 | 0.736033 | 0.612808 | 1.201082558 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 0, 'return_indices': True} 97 | 0.574625 | 0.530925 | 1.082309177 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 0, 'return_indices': False} 98 | 0.722021 | 0.614488 | 1.174996094 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 1, 'return_indices': True} 99 | 0.563171 | 0.533721 | 1.055178642 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 1, 'return_indices': False} 100 | 0.735725 | 0.613992 | 1.198264798 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 2, 'return_indices': True} 101 | 0.583487 | 0.532513 | 1.095723485 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 1, 'stride': 4, 'padding': 2, 'return_indices': False} 102 | 0.656383 | 0.575313 | 1.140914598 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 0, 'return_indices': True} 103 | 0.559796 | 0.509079 | 1.099625009 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 0, 'return_indices': False} 104 | 0.662046 | 0.572362 | 1.156691045 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 1, 'return_indices': True} 105 | 0.552633 | 0.508671 | 1.086425214 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 1, 'return_indices': False} 106 | 0.634108 | 0.574629 | 1.103508525 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 2, 'return_indices': True} 107 | 0.534013 | 0.510996 | 1.045043405 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': None, 'padding': 2, 'return_indices': False} 108 | 7.056642 | 7.066717 | 0.9985743026 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 0, 'return_indices': True} 109 | 4.144275 | 4.142658 | 1.000390329 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 0, 'return_indices': False} 110 | 7.172683 | 7.189867 | 0.9976099697 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 1, 'return_indices': True} 111 | 4.162538 | 4.158875 | 1.000880767 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 1, 'return_indices': False} 112 | 7.194233 | 7.181837 | 1.001726021 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 2, 'return_indices': True} 113 | 4.294083 | 4.196062 | 1.023360236 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 1, 'padding': 2, 'return_indices': False} 114 | 1.875692 | 0.891071 | 2.104986022 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 0, 'return_indices': True} 115 | 1.097479 | 0.781175 | 1.404907991 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 0, 'return_indices': False} 116 | 1.8883 | 0.89015 | 2.121327866 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 1, 'return_indices': True} 117 | 1.101329 | 0.778542 | 1.414604479 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 1, 'return_indices': False} 118 | 1.872833 | 0.893654 | 2.095702587 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 2, 'return_indices': True} 119 | 1.096712 | 0.784579 | 1.397835017 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 2, 'padding': 2, 'return_indices': False} 120 | 0.513029 | 0.374417 | 1.370207549 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 0, 'return_indices': True} 121 | 0.349546 | 0.305763 | 1.143192603 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 0, 'return_indices': False} 122 | 0.518929 | 0.377487 | 1.374693698 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 1, 'return_indices': True} 123 | 0.364662 | 0.3145 | 1.159497615 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 1, 'return_indices': False} 124 | 0.521275 | 0.375242 | 1.389170189 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 2, 'return_indices': True} 125 | 0.367488 | 0.308354 | 1.191773092 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 2, 'stride': 4, 'padding': 2, 'return_indices': False} 126 | 0.652342 | 0.569308 | 1.145850752 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 0, 'return_indices': True} 127 | 0.555696 | 0.506892 | 1.096280865 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 0, 'return_indices': False} 128 | 0.654333 | 0.570367 | 1.147213987 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 1, 'return_indices': True} 129 | 0.548925 | 0.505825 | 1.085207335 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 1, 'return_indices': False} 130 | 0.655908 | 0.571904 | 1.146884792 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 2, 'return_indices': True} 131 | 0.560808 | 0.508238 | 1.103435792 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': None, 'padding': 2, 'return_indices': False} 132 | 6.949462 | 6.949112 | 1.000050366 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 0, 'return_indices': True} 133 | 4.072913 | 4.065013 | 1.001943413 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 0, 'return_indices': False} 134 | 7.200896 | 7.197792 | 1.000431243 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 1, 'return_indices': True} 135 | 4.291367 | 4.218538 | 1.017264038 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 1, 'return_indices': False} 136 | 7.1823 | 7.306933 | 0.9829431856 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 2, 'return_indices': True} 137 | 4.151175 | 4.149592 | 1.000381483 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 1, 'padding': 2, 'return_indices': False} 138 | 1.781279 | 0.884288 | 2.014365229 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 0, 'return_indices': True} 139 | 1.050804 | 0.774362 | 1.356993241 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 0, 'return_indices': False} 140 | 1.860758 | 0.884637 | 2.103414169 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 1, 'return_indices': True} 141 | 1.099908 | 0.775887 | 1.417613647 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 1, 'return_indices': False} 142 | 1.857387 | 0.885738 | 2.096993693 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 2, 'return_indices': True} 143 | 1.105279 | 0.77365 | 1.428655077 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 2, 'padding': 2, 'return_indices': False} 144 | 0.489408 | 0.269583 | 1.815426047 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 0, 'return_indices': True} 145 | 0.322525 | 0.236979 | 1.360985573 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 0, 'return_indices': False} 146 | 0.515475 | 0.265813 | 1.93923924 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 1, 'return_indices': True} 147 | 0.315525 | 0.228146 | 1.382995976 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 1, 'return_indices': False} 148 | 0.503438 | 0.277204 | 1.816128194 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 2, 'return_indices': True} 149 | 0.335421 | 0.228275 | 1.469372467 |   | (3, 2000, 2000), {'kernel_size': 5, 'dilation': 4, 'stride': 4, 'padding': 2, 'return_indices': False} 150 | 5.72495 | 4.909554 | 1.166083518 |   | (10, 10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': None, 'return_indices': True} 151 | 4.45215 | 4.251333 | 1.047236243 |   | (10, 10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': None, 'return_indices': False} 152 | 29.953021 | 29.879879 | 1.002447868 |   | (10, 10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': 1, 'return_indices': True} 153 | 9.854683 | 9.839517 | 1.001541336 |   | (10, 10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': 1, 'return_indices': False} 154 | 6.178033 | 5.697375 | 1.084364817 |   | (10, 10, 1000, 1000), {'kernel_size': 100, 'padding': 50, 'return_indices': True} 155 | 6.280317 | 5.712525 | 1.099394226 |   | (10, 10, 1000, 1000), {'kernel_size': 100, 'padding': 50, 'return_indices': False} 156 | 10.256062 | 11.336527 | 0.9046917103 |   | (10, 10, 1000, 1000), {'kernel_size': 250, 'padding': 50, 'return_indices': True} 157 | 9.469546 | 11.33705 | 0.8352742556 |   | (10, 10, 1000, 1000), {'kernel_size': 250, 'padding': 50, 'return_indices': False} 158 | 0.119087 | 0.0797 | 1.494190715 |   | (10, 10, 100, 100), {'kernel_size': 2, 'return_indices': True} 159 | 0.098713 | 0.047173 | 2.092574142 |   | (10, 10, 100, 100), {'kernel_size': 2, 'return_indices': False} 160 | 0.960812 | 0.675762 | 1.421820108 |   | (10, 10, 300, 300), {'kernel_size': 2, 'return_indices': True} 161 | 0.536546 | 0.485958 | 1.104099531 |   | (10, 10, 300, 300), {'kernel_size': 2, 'return_indices': False} 162 | 2.555225 | 1.791567 | 1.426251432 |   | (10, 10, 500, 500), {'kernel_size': 2, 'return_indices': True} 163 | 1.419087 | 1.305137 | 1.087308842 |   | (10, 10, 500, 500), {'kernel_size': 2, 'return_indices': False} 164 | 5.182008 | 3.48085 | 1.488719135 |   | (10, 10, 700, 700), {'kernel_size': 2, 'return_indices': True} 165 | 2.831779 | 2.498537 | 1.133374851 |   | (10, 10, 700, 700), {'kernel_size': 2, 'return_indices': False} 166 | 8.546038 | 5.7783 | 1.478988284 |   | (10, 10, 900, 900), {'kernel_size': 2, 'return_indices': True} 167 | 4.731004 | 4.161975 | 1.136720908 |   | (10, 10, 900, 900), {'kernel_size': 2, 'return_indices': False} 168 | 0.084754 | 0.07435 | 1.139932751 |   | (10, 10, 100, 100), {'kernel_size': 2, 'return_indices': True} 169 | 0.057933 | 0.043096 | 1.344277891 |   | (10, 10, 100, 100), {'kernel_size': 2, 'return_indices': False} 170 | 2.568592 | 1.802117 | 1.425319222 |   | (10, 10, 500, 500), {'kernel_size': 2, 'return_indices': True} 171 | 1.433054 | 1.307342 | 1.096158465 |   | (10, 10, 500, 500), {'kernel_size': 2, 'return_indices': False} 172 | 10.3213 | 7.111604 | 1.451332217 |   | (10, 10, 1000, 1000), {'kernel_size': 2, 'return_indices': True} 173 | 5.680525 | 5.168129 | 1.099145358 |   | (10, 10, 1000, 1000), {'kernel_size': 2, 'return_indices': False} 174 | 1.02255 | 1.01375 | 1.008680641 |   | (10, 1000, 1000), {'kernel_size': 2, 'padding': 1, 'stride': 1, 'return_indices': False} 175 | 3.074233 | 3.094383 | 0.993488201 |   | (10, 1000, 1000), {'kernel_size': 2, 'padding': 1, 'stride': 1, 'return_indices': True} 176 | 1.016812 | 1.030575 | 0.9866453194 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': 1, 'return_indices': False} 177 | 3.053658 | 3.089504 | 0.9883974903 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': 1, 'return_indices': True} 178 | 1.025863 | 1.032088 | 0.9939685376 |   | (10, 1000, 1000), {'kernel_size': 8, 'padding': 1, 'stride': 1, 'return_indices': False} 179 | 3.798942 | 3.799213 | 0.9999286694 |   | (10, 1000, 1000), {'kernel_size': 8, 'padding': 1, 'stride': 1, 'return_indices': True} 180 | 4.492979 | 4.493421 | 0.999901634 |   | (10, 1000, 1000), {'kernel_size': 16, 'padding': 1, 'stride': 1, 'return_indices': False} 181 | 51.543363 | 51.266204 | 1.005406271 |   | (10, 1000, 1000), {'kernel_size': 16, 'padding': 1, 'stride': 1, 'return_indices': True} 182 | 1.018008 | 1.001587 | 1.016394981 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 0, 'stride': (1, 1), 'return_indices': False} 183 | 3.035404 | 3.003113 | 1.010752509 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 0, 'stride': (1, 1), 'return_indices': True} 184 | 0.610421 | 0.56 | 1.0900375 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 0, 'stride': (1, 4), 'return_indices': False} 185 | 1.138983 | 0.757296 | 1.504012962 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 0, 'stride': (1, 4), 'return_indices': True} 186 | 0.641558 | 0.557808 | 1.150141267 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 0, 'stride': (4, 1), 'return_indices': False} 187 | 1.181475 | 0.754725 | 1.565437742 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 0, 'stride': (4, 1), 'return_indices': True} 188 | 1.03045 | 1.026904 | 1.003453098 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': (1, 1), 'return_indices': False} 189 | 3.041421 | 3.0263 | 1.00499653 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': (1, 1), 'return_indices': True} 190 | 0.609929 | 0.572304 | 1.065743032 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': (1, 4), 'return_indices': False} 191 | 1.146875 | 0.756446 | 1.516135983 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': (1, 4), 'return_indices': True} 192 | 0.645187 | 0.561708 | 1.148616363 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': (4, 1), 'return_indices': False} 193 | 1.181721 | 0.758054 | 1.558887625 |   | (10, 1000, 1000), {'kernel_size': 4, 'padding': 1, 'stride': (4, 1), 'return_indices': True} 194 | 0.927654 | 0.925946 | 1.0018446 |   | (10, 1000, 1000), {'kernel_size': 1, 'return_indices': False} 195 | 2.749983 | 2.740354 | 1.00351378 |   | (10, 1000, 1000), {'kernel_size': 1, 'return_indices': True}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157876 Approved by: https://github.com/malfet --- .../src/ATen/native/mps/kernels/Pooling.metal | 102 +++++++++++-- .../src/ATen/native/mps/operations/Pooling.mm | 142 ++++++++++++------ 2 files changed, 186 insertions(+), 58 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal index 4eec3ed4d1b6e..45a8d680afcd0 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.metal +++ b/aten/src/ATen/native/mps/kernels/Pooling.metal @@ -88,6 +88,53 @@ void max_pool_3d_input_iter( } } +template +void max_pool_2d_input_iter( + constant T* input, + device T* output, + device int64_t* indices, + constant int32_t* input_sizes, + constant int32_t* input_strides, + thread int32_t (&pooling_dim_indices)[3], + constant int32_t* kernel_size, + constant int32_t* stride, + constant int32_t* padding, + constant int32_t* dilation) { + auto bounds0 = get_input_iter_bounds<0>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + auto bounds1 = get_input_iter_bounds<1>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + + auto d0 = dilation[0]; + auto d1 = dilation[1]; + + T max_value = input + [input_strides[0] * bounds0.start + input_strides[1] * bounds1.start]; + auto max_index = bounds0.start * input_sizes[1] + bounds1.start; + + for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) { + auto offset0 = input_strides[0] * i0; + + for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) { + auto offset1 = input_strides[1] * i1; + + auto input_value = input[offset0 + offset1]; + bool is_greater = input_value > max_value; + + max_value = is_greater ? input_value : max_value; + + if (return_indices) { + auto input_index = i0 * input_sizes[1] + i1; + max_index = is_greater ? input_index : max_index; + } + } + } + *output = max_value; + if (return_indices) { + *indices = max_index; + } +} + struct PoolOffsets { int32_t output; int32_t indices; @@ -212,7 +259,7 @@ kernel void max_pool( PoolOffsets offsets = find_pool_offsets( output_sizes, output_strides, - indices_strides, + return_indices ? indices_strides : nullptr, input_strides, pooling_dim_indices, dims, @@ -224,18 +271,47 @@ kernel void max_pool( indices += offsets.indices; input += offsets.input_leading; - max_pool_3d_input_iter( - input, - output, - indices, - input_sizes + leading_dims, - input_strides + leading_dims, - pooling_dim_indices, - kernel_size, - stride, - padding, - dilation, - return_indices); + switch (pooling_dims) { + case 2: + if (return_indices) { + return max_pool_2d_input_iter( + input, + output, + indices, + input_sizes + leading_dims, + input_strides + leading_dims, + pooling_dim_indices, + kernel_size, + stride, + padding, + dilation); + } else { + return max_pool_2d_input_iter( + input, + output, + indices, + input_sizes + leading_dims, + input_strides + leading_dims, + pooling_dim_indices, + kernel_size, + stride, + padding, + dilation); + } + case 3: + return max_pool_3d_input_iter( + input, + output, + indices, + input_sizes + leading_dims, + input_strides + leading_dims, + pooling_dim_indices, + kernel_size, + stride, + padding, + dilation, + return_indices); + } } // Finds the element in the grad input which corresponds to the index into the diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index b2bc870844a88..6ae3122cf3d19 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -297,13 +297,13 @@ static PoolSizes process_pool_sizes(const Tensor& input, pooling_dims, " ints"); - TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3, + TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == pooling_dims, op_name, ": stride must either be omitted, a single int, or a tuple of ", pooling_dims, " ints"); - TORCH_CHECK(padding.size() == 1 || padding.size() == 3, + TORCH_CHECK(padding.size() == 1 || padding.size() == pooling_dims, op_name, ": padding must either be a single int, or a tuple of ", pooling_dims, @@ -333,6 +333,22 @@ static PoolSizes process_pool_sizes(const Tensor& input, ": pad should be at most half of effective kernel size"); } + if (pooling_dims == 2) { + const auto memory_format = input.suggest_memory_format(); + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + if (memory_format == at::MemoryFormat::ChannelsLast) { + // Expect tensor in NHWC format and allow 0-dim only for N. + TORCH_CHECK((dims == 4 && valid_dims && input.size(3) != 0), + "Expected 4D (batch mode) tensor expected for input with channels_last layout" + " with optional 0 dim batch size for input, but got: ", + input.sizes()); + } else { + TORCH_CHECK((dims == 3 && input.size(0) != 0 && valid_dims) || (dims == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:", + input.sizes()); + } + } + for (const auto dim : c10::irange(static_cast(leading_dims == 2), dims)) { TORCH_CHECK(input.size(dim) > 0, op_name, ": Expected input's non-batch dimensions to have positive length"); } @@ -786,6 +802,16 @@ static void avg_pool_backward_out_mps_template(const Tensor& grad_input, } // namespace mps +// TODO: The MPS graph impl can sometimes give significantly better performance +// than the Metal impl for cases where the stride is 1 in all dimensions. There +// may be a code path in the graph kernel that specifically optimizes for that +// case. We should look into implementing a specialized case in Metal so we can +// avoid using the graph impl. +static bool use_graph_for_max_pool2d(IntArrayRef kernel_size, IntArrayRef stride_) { + IntArrayRef stride = stride_.empty() ? kernel_size : stride_; + return (stride[0] == 1) && (stride.size() == 1 || stride[1] == 1); +} + Tensor mps_max_pool2d(const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -793,24 +819,37 @@ Tensor mps_max_pool2d(const Tensor& input, IntArrayRef dilation, bool ceil_mode) { Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous); - mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { - MPSGraph* mpsGraph = cachedGraph.graph(); - return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; - }; - mps::pool2d_template(input, - output, - std::nullopt, - std::nullopt, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - false, - std::nullopt, - pooling_op_block, - "max_pool2d"); - + bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); + if (use_graph) { + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; + }; + mps::pool2d_template(input, + output, + std::nullopt, + std::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + std::nullopt, + pooling_op_block, + "max_pool2d"); + } else { + mps::max_pool_with_indices_out_mps_template(output, + std::nullopt, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + /*pooling_dims=*/2, + "max_pool2d"); + } return output; } @@ -855,32 +894,45 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& output, const Tensor& indices) { - auto indices_memory_format = indices.suggest_memory_format(); - - mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { - MPSGraph* mpsGraph = cachedGraph.graph(); - NSArray* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor - descriptor:desc - name:nil]; - cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long); - return poolOutputs[0]; - }; - mps::pool2d_template(input, - output, - indices, - std::nullopt, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - false, - std::nullopt, - pooling_op_block, - "max_pool2d_indices"); + bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); + if (use_graph) { + auto indices_memory_format = indices.suggest_memory_format(); + + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + NSArray* poolOutputs = + [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; + cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long); + return poolOutputs[0]; + }; + mps::pool2d_template(input, + output, + indices, + std::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + std::nullopt, + pooling_op_block, + "max_pool2d_indices"); + if (indices_memory_format == MemoryFormat::ChannelsLast) { + const_cast(indices) = indices.to(MemoryFormat::ChannelsLast); + } - if (indices_memory_format == MemoryFormat::ChannelsLast) { - const_cast(indices) = indices.to(MemoryFormat::ChannelsLast); + } else { + mps::max_pool_with_indices_out_mps_template(output, + indices, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + /*pooling_dims=*/2, + "max_pool2d"); } } From c5ec5458a547f7a774468ea0eb2258d3de596492 Mon Sep 17 00:00:00 2001 From: albanD Date: Fri, 8 Aug 2025 17:19:12 +0000 Subject: [PATCH 0152/1424] Don't build nccl when distributed is disabled (#160086) Because distributed doesn't build on recent compilers, I have to disable distributed, but this makes it still fail as nccl is still built Pull Request resolved: https://github.com/pytorch/pytorch/pull/160086 Approved by: https://github.com/Skylion007, https://github.com/janeyx99 --- CMakeLists.txt | 4 ++-- tools/build_pytorch_libs.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c6b662fd69c3a..558bdf2be3ee3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -260,8 +260,9 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) +option(USE_DISTRIBUTED "Use distributed" ON) cmake_dependent_option(USE_NCCL "Use NCCL" ON - "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) + "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_XCCL "Use XCCL" ON "USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) @@ -322,7 +323,6 @@ set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN}) cmake_dependent_option(USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF "USE_MKLDNN" OFF) option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF) -option(USE_DISTRIBUTED "Use distributed" ON) cmake_dependent_option( USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 457b224354fb2..9d43de80f1298 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -88,7 +88,8 @@ def build_pytorch( ) -> None: my_env = _create_build_env() if ( - not check_negative_env_flag("USE_CUDA") + not check_negative_env_flag("USE_DISTRIBUTED") + and not check_negative_env_flag("USE_CUDA") and not check_negative_env_flag("USE_NCCL") and not check_env_flag("USE_SYSTEM_NCCL") ): From d7114f05b10de8e6de81ffc567d63944c3117d51 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 8 Aug 2025 15:17:56 +0000 Subject: [PATCH 0153/1424] Add DeviceAllocator as the base device allocator (#138222) # Motivation In line with [RFC] [A device-agnostic Python device memory related API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/134978), some memory-related APIs are widely used in popular repositories, such as HuggingFace [so many if-else conditional code](https://github.com/search?q=repo%3Ahuggingface%2Faccelerate%20torch.cuda.empty_cache&type=code). We would like to introduce a generic API set under torch.accelerator namespace to generalize these user cases.
Device-specific memory APIs torch.xxx.foo Device-agnostic memory APIs torch.accelerator.foo
```python torch.xxx.empty_cache ``` ```python torch.accelerator.empty_cache ```
```python torch.xxx.reset_peak_memory_stats ``` ```python torch.accelerator.reset_peak_memory_stats ```
```python torch.xxx.reset_accumulated_memory_stats ``` ```python torch.accelerator.reset_accumulated_memory_stats ```
```python torch.xxx.memory_stats ``` ```python torch.accelerator.memory_stats ```
```python torch.xxx.memory_allocated ``` ```python torch.accelerator.memory_allocated ```
```python torch.xxx.max_memory_allocated ``` ```python torch.accelerator.max_memory_allocated ```
```python torch.xxx.memory_reserved ``` ```python torch.accelerator.memory_reserved ```
```python torch.xxx.max_memory_reserved ``` ```python torch.accelerator.max_memory_reserved ```
# Solution This design follows a similar pattern to `HostAllocator`. We're introducing a base class `DeviceAllocator`, from which `CUDAAllocator` and `XPUAllocator` will inherit. This allows us to provide a unified call path like: `torch.accelerator.empty_cache()` -> `GetDeviceAllocator(allocator)->empty_cache()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138222 Approved by: https://github.com/albanD, https://github.com/Camyll --- aten/src/ATen/cuda/CUDAGraph.cpp | 1 - aten/src/ATen/cuda/CUDAGraph.h | 1 + c10/core/CachingDeviceAllocator.cpp | 10 ++++++ c10/core/CachingDeviceAllocator.h | 53 +++++++++++++++++++++++++++++ c10/cuda/CUDACachingAllocator.cpp | 11 ++++++ c10/cuda/CUDACachingAllocator.h | 19 ++++++----- c10/cuda/CUDAGraphsC10Utils.h | 6 ---- c10/xpu/XPUCachingAllocator.cpp | 19 +++++++---- 8 files changed, 98 insertions(+), 22 deletions(-) create mode 100644 c10/core/CachingDeviceAllocator.cpp diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 7fba7c4c7424c..2800e505a9b76 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index c8cae16b624fe..4f2aa31dd1c35 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/c10/core/CachingDeviceAllocator.cpp b/c10/core/CachingDeviceAllocator.cpp new file mode 100644 index 0000000000000..582efd59cf1b1 --- /dev/null +++ b/c10/core/CachingDeviceAllocator.cpp @@ -0,0 +1,10 @@ +#include + +namespace c10 { + +// Ensures proper DLL export of this pure virtual base class on Windows, +// since it's mainly used in other DLLs outside c10.dll. +DeviceAllocator::DeviceAllocator() = default; +DeviceAllocator::~DeviceAllocator() = default; + +} // namespace c10 diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index b23490de693a8..0bec03ae417fa 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace c10::CachingDeviceAllocator { @@ -59,3 +60,55 @@ struct DeviceStats { }; } // namespace c10::CachingDeviceAllocator + +namespace c10 { + +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by Graph mode capture_begin. +// second is set if the instance is created by Graph mode graph_pool_handle. +using MempoolId_t = std::pair; + +struct C10_API DeviceAllocator : public c10::Allocator { + DeviceAllocator(); + ~DeviceAllocator() override; + + // Returns true if the allocator has been properly initialized and is ready + // for use + virtual bool initialized() = 0; + + // Releases all cached device memory from the specified memory pool back to + // the system + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; + + // Associates a memory allocation with a stream to establish dependency + // tracking. Prevents memory reuse until all operations on the specified + // stream complete + virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0; + + // Retrieves comprehensive memory statistics for the specified device, + // including allocation patterns, usage metrics + virtual CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + + // Resets cumulative allocation statistics for the specified device to zero + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + + // Resets peak memory usage statistics for the specified device + virtual void resetPeakStats(c10::DeviceIndex device) = 0; +}; + +// This function is used to get the DeviceAllocator for a specific device type +// and keep backward compatibility with c10::GetAllocator. +C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { + TORCH_CHECK( + t != DeviceType::CPU, + "getDeviceAllocator is not supported for CPU device type."); + auto* allocator = c10::GetAllocator(t); + auto* device_allocator = dynamic_cast(allocator); + TORCH_INTERNAL_ASSERT( + device_allocator, "Allocator for ", t, " is not a DeviceAllocator."); + return device_allocator; +} + +} // namespace c10 diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index c2a46ac9f3f74..59b62dcac07f0 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -4118,7 +4118,18 @@ struct BackendStaticInitializer { BackendStaticInitializer() { auto r = parseEnvForBackend(); +// Register this HIP allocator as the CUDA allocator to allow it to work +// with both c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) +// APIs. We don't perform this masquerading inside +// HIPAllocatorMasqueradingAsCUDA because it needs to happen during static +// initialization, and doing so there may introduce static initialization +// order (SIOF) issues. +#define HIP_MASQUERADING_AS_CUDA \ + "cud" \ + "a" + at::SetAllocator(c10::Device(HIP_MASQUERADING_AS_CUDA).type(), r, 0); allocator.store(r); +#undef HIP_MASQUERADING_AS_CUDA } }; diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 956411fe22827..75a2d4c8e481b 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -202,25 +202,24 @@ struct ShareableHandle { std::string handle; }; -class CUDAAllocator : public Allocator { +class CUDAAllocator : public DeviceAllocator { public: virtual void* raw_alloc(size_t nbytes) = 0; virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; - virtual bool initialized() = 0; virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; - virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; virtual void enable(bool value) = 0; virtual bool isEnabled() const = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; - virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; - virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) = 0; - virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; - virtual void resetPeakStats(c10::DeviceIndex device) = 0; + // Keep for BC only + virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; + void recordStream(const DataPtr& ptr, c10::Stream stream) override { + CUDAStream cuda_stream = CUDAStream(stream); + recordStream(ptr, cuda_stream); + } virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; virtual void beginAllocateToPool( c10::DeviceIndex device, @@ -525,6 +524,10 @@ inline void enablePeerAccess( namespace c10::cuda { +// Keep BC only +using c10::CaptureId_t; +using c10::MempoolId_t; + // MemPool represents a pool of memory in a caching allocator. Currently, // it's just the ID of the pool object maintained in the CUDACachingAllocator. // diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index eb29ca8bc9f02..936875fd71d5c 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -9,12 +9,6 @@ namespace c10::cuda { -using CaptureId_t = unsigned long long; - -// first is set if the instance is created by CUDAGraph::capture_begin. -// second is set if the instance is created by at::cuda::graph_pool_handle. -using MempoolId_t = std::pair; - // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index afae32d92a4b4..04ab3cabcbc2b 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -539,7 +539,7 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); -class XPUAllocator : public Allocator { +class XPUAllocator : public DeviceAllocator { private: std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -575,6 +575,10 @@ class XPUAllocator : public Allocator { } } + bool initialized() override { + return !device_allocators.empty(); + } + void malloc( void** devPtr, DeviceIndex device, @@ -609,13 +613,13 @@ class XPUAllocator : public Allocator { } } - void emptyCache() { + void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { for (auto& da : device_allocators) { da->emptyCache(); } } - void recordStream(const DataPtr& ptr, XPUStream stream) { + void recordStream(const DataPtr& ptr, c10::Stream stream) override { if (!ptr.get()) { return; } @@ -625,7 +629,8 @@ class XPUAllocator : public Allocator { Block* block = get_allocated_block(ptr.get()); TORCH_CHECK(block, "No allocated block can be found."); - device_allocators[block->device]->recordStream(block, stream); + c10::xpu::XPUStream xpu_stream{stream}; + device_allocators[block->device]->recordStream(block, xpu_stream); } DataPtr allocate(size_t size) override { @@ -678,17 +683,17 @@ class XPUAllocator : public Allocator { ": did you call init?"); } - DeviceStats getDeviceStats(DeviceIndex device) { + DeviceStats getDeviceStats(DeviceIndex device) override { assertValidDevice(device); return device_allocators[device]->getStats(); } - void resetPeakStats(DeviceIndex device) { + void resetPeakStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetPeakStats(); } - void resetAccumulatedStats(DeviceIndex device) { + void resetAccumulatedStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetAccumulatedStats(); } From 84f7e88aef091822f1feb1e71833571738db18fd Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 8 Aug 2025 15:17:57 +0000 Subject: [PATCH 0154/1424] Add unified memory APIs for torch.accelerator (#152932) # Motivation The following API will be put under torch.accelerator - empty_cache - max_memory_allocated - max_memory_reserved - memory_allocated - memory_reserved - memory_stats - reset_accumulated_memory_stats - reset_peak_memory_stats Pull Request resolved: https://github.com/pytorch/pytorch/pull/152932 Approved by: https://github.com/albanD ghstack dependencies: #138222 --- aten/src/ATen/DeviceAccelerator.h | 22 ++++ docs/source/accelerator.md | 23 ++++ torch/_C/__init__.pyi.in | 5 + torch/accelerator/__init__.py | 18 +++ torch/accelerator/memory.py | 201 ++++++++++++++++++++++++++++++ torch/csrc/DeviceAccelerator.cpp | 64 ++++++++++ torch/cuda/memory.py | 4 +- 7 files changed, 335 insertions(+), 2 deletions(-) create mode 100644 torch/accelerator/memory.py diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f37e492c861fe..f23b35047fcc8 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -72,6 +73,27 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +TORCH_API inline void emptyCache() { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->emptyCache(); +} + +TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); +} + +TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); +} + +TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetPeakStats(device_index); +} + } // namespace at::accelerator namespace at { diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index c6f2fb1080400..ce593a9acf518 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,3 +25,26 @@ synchronize device_index ``` + +```{eval-rst} +.. automodule:: torch.accelerator.memory +``` +```{eval-rst} +.. currentmodule:: torch.accelerator.memory +``` + +## Memory management +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + memory_allocated + memory_reserved + memory_stats + reset_accumulated_memory_stats + reset_peak_memory_stats +``` diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9e03c7dba8305..fb7e9c5ce56e0 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2435,6 +2435,11 @@ def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... +def _accelerator_isAllocatorInitialized() -> _bool: ... +def _accelerator_emptyCache() -> None: ... +def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... +def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... +def _accelerator_resetPeakStats(device_index: _int) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e9e48f1cf3061..4d1a78df1f74c 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,6 +8,16 @@ import torch from ._utils import _device_t, _get_device_index +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + memory_allocated, + memory_reserved, + memory_stats, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) __all__ = [ @@ -15,9 +25,17 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", + "empty_cache", "device_count", "device_index", "is_available", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", "set_device_idx", # deprecated "set_device_index", "set_stream", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py new file mode 100644 index 0000000000000..d34a11a3a02e5 --- /dev/null +++ b/torch/accelerator/memory.py @@ -0,0 +1,201 @@ +from collections import OrderedDict +from typing import Any + +import torch + +from ._utils import _device_t, _get_device_index + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other application. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return + torch._C._accelerator_emptyCache() + + +def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: + r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from device memory allocation. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of June 2025, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of June 2025, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed device memory allocation calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of device memory allocation calls. + - ``"num_device_free"``: number of device memory free calls. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return OrderedDict() + device_index = _get_device_index(device_index, optional=True) + stats = torch._C._accelerator_getDeviceStats(device_index) + flat_stats = [] + + def flatten(prefix: str, value: Any) -> None: + if isinstance(value, dict): + for k, v in value.items(): + nested_prefix = f"{prefix}.{k}" if prefix else k + flatten(nested_prefix, v) + else: + flat_stats.append((prefix, value)) + + flatten("", stats) + flat_stats.sort() + return OrderedDict(flat_stats) + + +def memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory occupied by tensors + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory occupied by tensors + in bytes for a given device index. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory managed by the caching allocator + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory managed by the caching allocator + in bytes for a given device index. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.peak", 0) + + +def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetAccumulatedStats(device_index) + + +def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "peak" stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetPeakStats(device_index) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 3a97c0794684f..59cb8047467c9 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -77,6 +77,70 @@ void initModule(PyObject* module) { m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); + + m.def("_accelerator_isAllocatorInitialized", []() { + const auto device_type = at::accelerator::getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->initialized(); + }); + + m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); + + m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { + using c10::CachingAllocator::Stat; + using c10::CachingAllocator::StatArray; + using c10::CachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + + const auto stats = at::accelerator::getDeviceStats(device_index); + const auto stat_to_dict = [](const Stat& stat) -> py::dict { + py::dict dict; + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict { + const std::array(StatType::NUM_TYPES)> + kStatTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(kStatTypeNames.size())) { + dict[kStatTypeNames[i]] = stat_to_dict(stats[i]); + } + return dict; + }; + + py::dict result; + result["num_alloc_retries"] = stats.num_alloc_retries; + result["num_ooms"] = stats.num_ooms; + result["max_split_size"] = stats.max_split_size; + result["num_sync_all_streams"] = stats.num_sync_all_streams; + result["num_device_alloc"] = stats.num_device_alloc; + result["num_device_free"] = stats.num_device_free; + result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes); + result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes); + result["active_bytes"] = stat_array_to_dict(stats.active_bytes); + result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes); + result["allocation"] = stat_array_to_dict(stats.allocation); + result["segment"] = stat_array_to_dict(stats.segment); + result["active"] = stat_array_to_dict(stats.active); + result["inactive_split"] = stat_array_to_dict(stats.inactive_split); + result["inactive_split_bytes"] = + stat_array_to_dict(stats.inactive_split_bytes); + result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations); + result["oversize_segments"] = stat_to_dict(stats.oversize_segments); + return result; + }); + + m.def( + "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetAccumulatedStats(device_index); + }); + + m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetPeakStats(device_index); + }); } } // namespace torch::accelerator diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 63e59096162fb..1bd6f9edc0319 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,9 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool - (as of October 2019, for size >= 1MB allocations). + (as of June 2025, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool - (as of October 2019, for size < 1MB allocations). + (as of June 2025, for size < 1MB allocations). Metric type: From da1f608ca33f3062535d0a4866d95db19e72fcbd Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 8 Aug 2025 15:17:59 +0000 Subject: [PATCH 0155/1424] Add UT for torch.accelerator memory-related API (#155200) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155200 Approved by: https://github.com/albanD ghstack dependencies: #138222, #152932 --- test/test_accelerator.py | 78 ++++++++++++++++++++++++++++++++++++++++ test/test_cuda.py | 36 +++++++++++++++++++ test/test_xpu.py | 37 +++++++++++++++++++ 3 files changed, 151 insertions(+) diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 0ea224d704cb8..21731bd275b60 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -1,5 +1,6 @@ # Owner(s): ["module: tests"] +import gc import sys import unittest @@ -156,6 +157,83 @@ def test_generic_event_behavior(self): ): event1.elapsed_time(event2) + @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!") + def test_memory_stats(self): + # Ensure that device allocator is initialized + acc = torch.accelerator.current_accelerator() + tmp = torch.randn(100, device=acc) + del tmp + gc.collect() + self.assertTrue(torch._C._accelerator_isAllocatorInitialized()) + torch.accelerator.empty_cache() + + pool_type = ["all", "small_pool", "large_pool"] + metric_type = ["peak", "current", "allocated", "freed"] + stats_type = [ + "allocated_bytes", + "reserved_bytes", + "active_bytes", + "requested_bytes", + ] + mem_stats = torch.accelerator.memory_stats() + expected_stats = [ + f"{st}.{pt}.{mt}" + for st in stats_type + for pt in pool_type + for mt in metric_type + ] + missing_stats = [stat for stat in expected_stats if stat not in mem_stats] + self.assertEqual( + len(missing_stats), + 0, + f"Missing expected memory statistics: {missing_stats}", + ) + + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertGreaterEqual(prev_allocated, 0) + self.assertGreaterEqual(prev_reserved, 0) + self.assertGreater(prev_max_allocated, 0) + self.assertGreater(prev_max_reserved, 0) + tmp = torch.ones(256, device=acc) + self.assertGreater(torch.accelerator.memory_allocated(), prev_allocated) + self.assertGreaterEqual(torch.accelerator.memory_reserved(), prev_reserved) + del tmp + gc.collect() + torch.accelerator.empty_cache() + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.memory_allocated(), prev_allocated) + self.assertEqual(torch.accelerator.memory_reserved(), prev_reserved) + torch.accelerator.reset_accumulated_memory_stats() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device=acc) + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + if __name__ == "__main__": run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index f2f3304069f1b..9755835853eed 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -373,6 +373,42 @@ def test_memory_allocation(self): torch.cuda.caching_allocator_delete(mem) self.assertEqual(torch.cuda.memory_allocated(), prev) + def test_memory_stats(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertEqual(prev_allocated, prev_max_allocated) + self.assertEqual(prev_reserved, prev_max_reserved) + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device="cuda") + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + def test_check_error(self): # Assert this call doesn't raise. torch.cuda.check_error(0) diff --git a/test/test_xpu.py b/test/test_xpu.py index cd5275418c440..beb5a53a4a6b3 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,5 +1,6 @@ # Owner(s): ["module: intel"] +import gc import re import subprocess import sys @@ -520,6 +521,42 @@ def test_device_memory_allocated(self): ) del a + def test_memory_stats(self): + gc.collect() + torch.xpu.empty_cache() + torch.xpu.reset_peak_memory_stats() + torch.xpu.reset_accumulated_memory_stats() + prev_allocated = torch.accelerator.memory_allocated() + prev_reserved = torch.accelerator.memory_reserved() + prev_max_allocated = torch.accelerator.max_memory_allocated() + prev_max_reserved = torch.accelerator.max_memory_reserved() + self.assertEqual(prev_allocated, prev_max_allocated) + self.assertEqual(prev_reserved, prev_max_reserved) + # Activate 1kB memory + prev_active_current = torch.accelerator.memory_stats()[ + "active_bytes.all.current" + ] + tmp = torch.randn(256, device="xpu") + # Detect if the current active memory is 1kB + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + 1024 + prev_active_current, + ) + self.assertEqual(torch.accelerator.memory_stats()["active_bytes.all.freed"], 0) + del tmp + gc.collect() + torch.accelerator.empty_cache() + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.current"], + prev_active_current, + ) + self.assertEqual( + torch.accelerator.memory_stats()["active_bytes.all.freed"], 1024 + ) + torch.accelerator.reset_peak_memory_stats() + self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) + self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + @skipXPUIf( int(torch.version.xpu) < 20250000, "Test requires SYCL compiler version 2025.0.0 or newer.", From 5f5f508aa836a46dfe88857fb223049616b94e93 Mon Sep 17 00:00:00 2001 From: Andres Lugo <108368282+alugorey@users.noreply.github.com> Date: Fri, 8 Aug 2025 18:40:17 +0000 Subject: [PATCH 0156/1424] [ROCm] Ck backend UX refactor (#152951) Refactors how the enablement/disablement of CK Gemms and SDPA works. - Adds USE_ROCM_CK_GEMM compile flag for enabling CK gemms. - USE_ROCM_CK_GEMM is set to True by default on Linux - Updates USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA. - USE_ROCM_CK_SDPA is set to False by default - (USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release) - Prevents these CK libraries from being used unless pytorch has been built specifically with the functionality AND is running on a system architecture that supports it. - the getters for these library backends will also do some validity checking in case the user used an environment variable to change the backend. If invalid, (i.e. one of the cases mentioned above is false) the backend will be set as the current non-CK default Pull Request resolved: https://github.com/pytorch/pytorch/pull/152951 Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus Co-authored-by: Jeff Daily Co-authored-by: Jithun Nair Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com> --- CMakeLists.txt | 2 + aten/src/ATen/CMakeLists.txt | 108 ++++++++++-------- aten/src/ATen/Context.cpp | 88 +++++++++----- aten/src/ATen/Context.h | 9 +- aten/src/ATen/cuda/CUDABlas.cpp | 10 +- aten/src/ATen/cuda/detail/CUDAHooks.cpp | 21 ++++ aten/src/ATen/cuda/detail/CUDAHooks.h | 2 + aten/src/ATen/detail/CUDAHooksInterface.h | 8 ++ aten/src/ATen/native/hip/ck_gemm.h | 3 +- aten/src/ATen/native/hip/ck_gemm_bfloat16.hip | 4 +- aten/src/ATen/native/hip/ck_gemm_float.hip | 2 + aten/src/ATen/native/hip/ck_gemm_half.hip | 2 + .../native/transformers/cuda/attention.cu | 2 +- .../transformers/cuda/attention_backward.cu | 2 +- .../hip/flash_attn/ck/me_bwd_ck.hip | 4 +- .../hip/flash_attn/ck/me_ck_api.h | 4 +- .../hip/flash_attn/ck/me_fwd_ck.hip | 4 +- .../transformers/hip/flash_attn/flash_api.h | 15 ++- caffe2/CMakeLists.txt | 4 +- cmake/Dependencies.cmake | 3 + cmake/Summary.cmake | 7 +- docs/source/notes/hip.rst | 27 +++++ setup.py | 6 + 23 files changed, 232 insertions(+), 105 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 558bdf2be3ee3..16fec0c80028c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -240,6 +240,8 @@ cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF) +cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF) +option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF) option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF) cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF) cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 547b36f10936f..5f4997357f826 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -180,26 +180,27 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") # if USE_FLASH_ATTENTION is set, ensure CK instances get generated if(USE_FLASH_ATTENTION) - if(DEFINED ENV{USE_CK_FLASH_ATTENTION}) - set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION}) - if(USE_CK_FLASH_ATTENTION STREQUAL "1") - if(DEFINED ENV{PYTORCH_ROCM_ARCH}) - list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) - if(NUM_ARCHS GREATER 1) - message(WARNING "Building CK for multiple archs can increase build time considerably! - Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") - endif() - endif() - message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") - message(STATUS "Generating CK kernel instances...") - add_subdirectory(native/transformers/hip/flash_attn/ck) - file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) - # FAv3 Generation - add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) - file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) + if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1") + message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead") + caffe2_update_option(USE_ROCM_CK_SDPA ON) + endif() + if(USE_ROCM_CK_SDPA) + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) + if(NUM_ARCHS GREATER 1) + message(WARNING "Building CK for multiple archs can increase build time considerably! + Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") endif() + endif() + message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled") + message(STATUS "Generating CK kernel instances...") + add_subdirectory(native/transformers/hip/flash_attn/ck) + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + # FAv3 Generation + add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) + file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") @@ -418,40 +419,42 @@ if(USE_CUDA) endif() if(USE_ROCM) - # NOTE: The PyTorch build does not actually add_subdirectory - # third_party/composable_kernel or use it as a CMake library. What is used - # is header only, so this should be ok, except that the CMake build generates - # a ck/config.h. We just do that part here. Without this, the ck.h from the - # ROCM SDK may get accidentally used instead. - function(_pytorch_rocm_generate_ck_conf) - set(CK_ENABLE_INT8 "ON") - set(CK_ENABLE_FP16 "ON") - set(CK_ENABLE_FP32 "ON") - set(CK_ENABLE_FP64 "ON") - set(CK_ENABLE_BF16 "ON") - set(CK_ENABLE_FP8 "ON") - set(CK_ENABLE_BF8 "ON") - set(CK_USE_XDL "ON") - set(CK_USE_WMMA "ON") - configure_file( - "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in" - "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h" - ) - endfunction() - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include) - _pytorch_rocm_generate_ck_conf() + if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM) + # NOTE: The PyTorch build does not actually add_subdirectory + # third_party/composable_kernel or use it as a CMake library. What is used + # is header only, so this should be ok, except that the CMake build generates + # a ck/config.h. We just do that part here. Without this, the ck.h from the + # ROCM SDK may get accidentally used instead. + function(_pytorch_rocm_generate_ck_conf) + set(CK_ENABLE_INT8 "ON") + set(CK_ENABLE_FP16 "ON") + set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_FP64 "ON") + set(CK_ENABLE_BF16 "ON") + set(CK_ENABLE_FP8 "ON") + set(CK_ENABLE_BF8 "ON") + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") + configure_file( + "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h" + ) + endfunction() + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include) + _pytorch_rocm_generate_ck_conf() + endif() # Next two lines are needed because TunableOp uses third-party/fmt list(APPEND ATen_HIP_INCLUDE $) list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only) -if(USE_FLASH_ATTENTION) - list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) -endif() + if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) + endif() list(APPEND ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} @@ -461,12 +464,17 @@ endif() ${native_quantized_hip_hip} ${native_transformers_hip_hip} ${native_transformers_src_hip_hip} ) - if(WIN32) # Windows doesn't support Composable Kernels + if(NOT USE_ROCM_CK_GEMM) file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip") file(GLOB native_hip_ck "native/hip/ck*.hip") exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" ${native_hip_bgemm} ${native_hip_ck}) endif() + if(WIN32) # Windows doesn't support Composable Kernels and Triton + exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" + ${native_transformers_hip_hip} ${native_transformers_hip_cpp}) + endif() + # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources) list(APPEND all_hip_cpp ${native_nested_hip_cpp} diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 2b89a46ed9af8..30c2235131fb6 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -480,6 +480,9 @@ at::BlasBackend Context::blasPreferredBackend() { // call site for blasPreferredBackend(), we set it to an actual value. if (blas_preferred_backend == at::BlasBackend::Default) { blas_preferred_backend = at::BlasBackend::Cublas; + // This logic sits in the getter because it needs to validate + // values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT + // which initialize the backend without calling the setter #ifdef USE_ROCM // AMD Instinct targets prefer hipblaslt static const bool hipblaslt_preferred = []() { @@ -509,6 +512,10 @@ at::BlasBackend Context::blasPreferredBackend() { // hipblaslt support for all archs is not as complete as hipblas if (blas_preferred_backend == at::BlasBackend::Cublaslt) { static const bool hipblaslt_unsupported = []() { + if(!hasCuBLASLt()) + { + return true; + } static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 @@ -534,6 +541,24 @@ at::BlasBackend Context::blasPreferredBackend() { return blas_preferred_backend; } +bool Context::ckSupported() { +#ifdef USE_ROCM + static const std::vector supported_archs = { + "gfx90a", "gfx942", "gfx950" + }; + for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) { + if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) { + TORCH_WARN_ONCE( + "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); + return false; + } + } + return true; +#else + return false; +#endif +} + void Context::setBlasPreferredBackend(at::BlasBackend b) { #ifdef _MSC_VER TORCH_WARN_ONCE( @@ -543,8 +568,14 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #else TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(), "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt."); - TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(), - "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm."); +#ifdef USE_ROCM + static const bool ckSupportedFlag = ckSupported(); + static const bool hasCKGEMMFlag = hasCKGEMM(); + TORCH_CHECK((b != at::BlasBackend::Ck) || (ckSupportedFlag && hasCKGEMMFlag), + "Cannot set preferred blas backend to CK since following conditions are not true: ", + "architecture supported for CK: ", ckSupportedFlag, + ", PyTorch built with CK GEMM support: ", hasCKGEMMFlag); +#endif if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) { TORCH_WARN_ONCE( "torch.backends.cuda.preferred_blas_library is an experimental feature. " @@ -556,35 +587,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #endif } -at::ROCmFABackend Context::getROCmFAPreferredBackend() const { +at::ROCmFABackend Context::getROCmFAPreferredBackend() { +#ifdef USE_ROCM + // Set potential "Default" value so we don't have to interpret at call sites. + // We use aotriton backend as the default, for now. + if(rocm_fa_preferred_backend == at::ROCmFABackend::Default) { + rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton; + } else if (rocm_fa_preferred_backend == at::ROCmFABackend::Ck) { + // This logic sits in the getter because it needs to validate + // values set via env vars such as TORCH_ROCM_FA_PREFER_CK + // which initialize the backend without calling the setter + // Perform validity checking + static const bool hasCKSDPAFlag = hasCKSDPA(); + static const bool ckSupportedFlag = ckSupported(); + if(!(hasCKSDPAFlag && ckSupportedFlag)){ + TORCH_WARN_ONCE( + "Cannot set preferred SDPA backend to CK since following conditions are not true: ", + "architecture supported for CK: ", ckSupportedFlag, + ", PyTorch built with CK SDPA support: ", hasCKSDPAFlag); + rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton; + } + } +#endif + return rocm_fa_preferred_backend; } void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { - - // TODO: add plumbing for hasCK for validity checking - TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(), - "Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm."); #ifdef USE_ROCM - if(b == at::ROCmFABackend::Ck) { - static const bool ck_unsupported = []() { - static const std::vector archs = { - "gfx90a", "gfx942" - }; - for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(archs, index)) { - TORCH_WARN_ONCE( - "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); - return true; - } - } - return false; - }(); - if(!ck_unsupported) rocm_fa_preferred_backend = b; - } - else { - rocm_fa_preferred_backend = b; - } + static const bool hasCKSDPAFlag = hasCKSDPA(); + static const bool ckSupportedFlag = ckSupported(); + TORCH_CHECK((b != at::ROCmFABackend::Ck) || (hasCKSDPAFlag && ckSupportedFlag), + "Cannot set preferred SDPA backend to CK since following conditions are not true: ", + "architecture supported for CK: ", ckSupportedFlag, + ", PyTorch built with CK SDPA support: ", hasCKSDPAFlag); #endif rocm_fa_preferred_backend = b; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 945076f3f0124..2cc12a38a0b6e 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -132,6 +132,7 @@ class TORCH_API Context { static bool hasKleidiAI(); static bool hasLAPACK(); static bool hasMKLDNN(); + static bool ckSupported(); static bool hasMAGMA() { return detail::getCUDAHooks().hasMAGMA(); } @@ -162,6 +163,12 @@ class TORCH_API Context { static bool hasROCM() { return detail::getCUDAHooks().hasROCM(); } + static bool hasCKSDPA() { + return detail::getCUDAHooks().hasCKSDPA(); + } + static bool hasCKGEMM() { + return detail::getCUDAHooks().hasCKGEMM(); + } static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } @@ -252,7 +259,7 @@ class TORCH_API Context { at::BlasBackend blasPreferredBackend(); void setBlasPreferredBackend(at::BlasBackend); - at::ROCmFABackend getROCmFAPreferredBackend() const; + at::ROCmFABackend getROCmFAPreferredBackend(); void setROCmFAPreferredBackend(at::ROCmFABackend); // Note [Enabling Deterministic Operations] diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index cf403365b2df2..0dbae4aeed5b7 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -832,7 +832,7 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::bgemm_internal_ck(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } @@ -1273,7 +1273,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(double)); #endif } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(double)); } @@ -1289,7 +1289,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100 gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); @@ -1341,7 +1341,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::Half)); } @@ -1357,7 +1357,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); } -#if defined(USE_ROCM) && !defined(_MSC_VER) +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::BFloat16)); } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 247fdb2537cb4..3dedf3fd64c72 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -207,6 +207,27 @@ bool CUDAHooks::hasCuBLASLt() const { #endif } + +bool CUDAHooks::hasCKSDPA() const { +#if !defined(USE_ROCM) + return false; +#elif defined(USE_ROCM) && defined(USE_ROCM_CK_SDPA) + return true; +#else + return false; +#endif +} + +bool CUDAHooks::hasCKGEMM() const { +#if !defined(USE_ROCM) + return false; +#elif defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) + return true; +#else + return false; +#endif +} + bool CUDAHooks::hasROCM() const { // Currently, this is same as `compiledWithMIOpen`. // But in future if there are ROCm builds without MIOpen, diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index b0dac7a71e809..2780369a37b71 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -31,6 +31,8 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool hasCuSOLVER() const override; bool hasCuBLASLt() const override; bool hasROCM() const override; + bool hasCKSDPA() const override; + bool hasCKGEMM() const override; const at::cuda::NVRTC& nvrtc() const override; DeviceIndex current_device() const override; bool isBuilt() const override {return true;} diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index f99e03d156c9b..00573e3cf701b 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -118,6 +118,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { return false; } + virtual bool hasCKSDPA() const { + return false; + } + + virtual bool hasCKGEMM() const { + return false; + } + virtual const at::cuda::NVRTC& nvrtc() const { TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP); } diff --git a/aten/src/ATen/native/hip/ck_gemm.h b/aten/src/ATen/native/hip/ck_gemm.h index 176cbabd5e01c..0d42cad56fcda 100644 --- a/aten/src/ATen/native/hip/ck_gemm.h +++ b/aten/src/ATen/native/hip/ck_gemm.h @@ -10,6 +10,7 @@ inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); } +#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); template <> @@ -18,7 +19,7 @@ template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); - +#endif } // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip index 79cb14be41031..7561cede386fb 100644 --- a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip +++ b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip @@ -1,6 +1,7 @@ #undef __HIP_NO_HALF_CONVERSIONS__ - #include + +#if defined(USE_ROCM_CK_GEMM) #include #include @@ -781,3 +782,4 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { } } // namespace at::native +#endif // USE_ROCM_CK_GEMM diff --git a/aten/src/ATen/native/hip/ck_gemm_float.hip b/aten/src/ATen/native/hip/ck_gemm_float.hip index b8301a47981c6..c4fea6088d3f0 100644 --- a/aten/src/ATen/native/hip/ck_gemm_float.hip +++ b/aten/src/ATen/native/hip/ck_gemm_float.hip @@ -1,6 +1,7 @@ #undef __HIP_NO_HALF_CONVERSIONS__ #include +#if defined(USE_ROCM_CK_GEMM) #include #include @@ -484,3 +485,4 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)) { } } // namespace at::native +#endif // USE_ROCM_CK_GEMM diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip index 552f0de845418..ebe044c389721 100644 --- a/aten/src/ATen/native/hip/ck_gemm_half.hip +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -1,6 +1,7 @@ #undef __HIP_NO_HALF_CONVERSIONS__ #include +#if defined(USE_ROCM_CK_GEMM) #include #include @@ -606,3 +607,4 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)) { } } // namespace at::native +#endif // USE_ROCM_CK_GEMM diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 80049aa9a832f..48899d4ce12fb 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1346,7 +1346,7 @@ std::tuple _efficient_ if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) std::optional out(res); std::optional seqused_k = std::nullopt; std::optional alibi_slopes = std::nullopt; diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 3888df64ad80b..c760ffe451053 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -431,7 +431,7 @@ _efficient_attention_backward( // ROCM Implementation if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float(); // Store grad_bias in optional std::optional opt_grad_bias = grad_bias; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip index 601ffd2d07525..59669afb93d2f 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip @@ -1,7 +1,7 @@ #include #include -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) namespace pytorch_flash { std::tuple< at::Tensor, // dQ @@ -117,4 +117,4 @@ mem_eff_backward_ck( } } // namespace pytorch_flash -#endif // USE_CK_FLASH_ATTENTION +#endif // USE_ROCM_CK_SDPA diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h index 6fd46467bc076..e92006ef6315c 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h @@ -3,7 +3,7 @@ #include -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) namespace pytorch_flash { std::tuple< @@ -64,4 +64,4 @@ mem_eff_backward_ck( const at::Tensor philox_offset); } // namespace pytorch_flash -#endif // USE_CK_FLASH_ATTENTION +#endif // USE_ROCM_CK_SDPA diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip index fac77821a56c1..d15c5105d0b46 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip @@ -1,7 +1,7 @@ #include #include -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) namespace pytorch_flash { std::tuple< at::Tensor, // output @@ -93,4 +93,4 @@ mem_eff_forward_ck( } } // namespace pytorch_flash -#endif // USE_CK_FLASH_ATTENTION +#endif // USE_ROCM_CK_SDPA diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index 17298aae9485d..f6f2240d4f091 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -147,7 +147,7 @@ std::tuple mha_varlen_bwd_aot( const at::Tensor& philox_seed, const at::Tensor& philox_offset); -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) // CK implementation TORCH_API std::tuple< @@ -295,7 +295,7 @@ mha_fwd( const float softcap, const bool return_softmax, std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { const int non_null_window_left = window_size_left.value_or(-1); @@ -368,7 +368,7 @@ mha_varlen_fwd( const float softcap, const bool return_softmax, std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional dummy_attn_bias = std::nullopt; @@ -441,9 +441,10 @@ inline std::tuple mha_bwd( const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { + +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { -#if defined(USE_CK_FLASH_ATTENTION) std::optional non_null_dbias = std::nullopt; const int non_null_window_left = window_size_left.value_or(-1); const int non_null_window_right = window_size_right.value_or(-1); @@ -474,10 +475,8 @@ inline std::tuple mha_bwd( philox_offset); // for FA return [dQ, dV, dK, dSoftmax] return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); -#else - TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); -#endif } +#endif return mha_bwd_aot( dout, q, @@ -530,7 +529,7 @@ inline std::tuple mha_varlen_bwd const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { -#if defined(USE_CK_FLASH_ATTENTION) +#if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional non_null_dbias = std::nullopt; diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 706b191e318e2..c346cedbcf519 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1446,8 +1446,8 @@ if(USE_ROCM) if(USE_MEM_EFF_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION) endif() - if(USE_CK_FLASH_ATTENTION) - target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION) + if(USE_ROCM_CK_SDPA) + target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA) endif() endif() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index b7f545027b02d..8836b66bc0360 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1045,6 +1045,9 @@ if(USE_ROCM) if(HIPBLASLT_VEC_EXT) list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT) endif() + if(USE_ROCM_CK_GEMM) + list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK_GEMM) + endif() list(APPEND HIP_HIPCC_FLAGS --offload-compress) if(WIN32) add_definitions(-DROCM_ON_WINDOWS) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 3c2ec74f14d17..24cfaa7f217d7 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -127,10 +127,11 @@ function(caffe2_print_configuration_summary) endif() message(STATUS " USE_ROCM : ${USE_ROCM}") if(${USE_ROCM}) - message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") - message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") - message(STATUS " USE_CK_FLASH_ATTENTION : ${USE_CK_FLASH_ATTENTION}") + message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") + message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") + message(STATUS " USE_ROCM_CK_SDPA : ${USE_ROCM_CK_SDPA}") + message(STATUS " USE_ROCM_CK_GEMM : ${USE_ROCM_CK_GEMM}") endif() message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") diff --git a/docs/source/notes/hip.rst b/docs/source/notes/hip.rst index a34535d67fc99..7ee596b53f9cc 100644 --- a/docs/source/notes/hip.rst +++ b/docs/source/notes/hip.rst @@ -179,3 +179,30 @@ by recompiling the PyTorch from source. Please add below line as an argument to cmake command parameters:: -DROCM_FORCE_ENABLE_GPU_ASSERTS:BOOL=ON + +Enabling/Disabling ROCm Composable Kernel +----------------------------------------- + +Enabling composable_kernel (CK) for both SDPA and GEMMs is a two-part process. First the user must have built +pytorch while setting the corresponding environment variable to '1' + +SDPA: +``USE_ROCM_CK_SDPA=1`` + +GEMMs: +``USE_ROCM_CK_GEMM=1`` + +Second, the user must explicitly request that CK be used as the backend library via the corresponding python +call + +SDPA: +``setROCmFAPreferredBackend('')`` + +GEMMs: +``setBlasPreferredBackend('')`` + +To enable CK in either scenario, simply pass 'ck' to those functions. + +In order to set the backend to CK, the user MUST have built with the correct environment variable. If not, +PyTorch will print a warning and use the "default" backend. For GEMMs, this will route to hipblas and +for SDPA it routes to aotriton. diff --git a/setup.py b/setup.py index e30896a2fdf4e..ad00317da0866 100644 --- a/setup.py +++ b/setup.py @@ -156,6 +156,12 @@ # USE_ROCM_KERNEL_ASSERT=1 # Enable kernel assert in ROCm platform # +# USE_ROCM_CK_GEMM=1 +# Enable building CK GEMM backend in ROCm platform +# +# USE_ROCM_CK_SDPA=1 +# Enable building CK SDPA backend in ROCm platform +# # Environment variables we respect (these environment variables are # conventional and are often understood/set by other software.) # From 72009ec6bebca7714f99c18449183787f202af4d Mon Sep 17 00:00:00 2001 From: Anshul Sinha Date: Thu, 7 Aug 2025 13:08:12 -0700 Subject: [PATCH 0157/1424] [replicate][be] improved readability and cleaned up remaining DDP code (#160133) **Summary** As much of ReplicateState functionality is copied from FSDPState, I fixed any remaining comments that incorrectly used FSDP instead of Replicate. In addition, instead of labeling modules FSDPModule or FSDPLinear, I have changed it so that is now uses Replicate____. Finally, I have removed some leftover code from the DDP implementation. I have included test cases to verify correctness. **Test Case** 1. pytest test/distributed/_composable/test_replicate_with_fsdp.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/160133 Approved by: https://github.com/mori360 ghstack dependencies: #160128 --- .../_composable/replicate_with_fsdp.py | 36 +++++-------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/torch/distributed/_composable/replicate_with_fsdp.py b/torch/distributed/_composable/replicate_with_fsdp.py index b49d240e4d75e..219501a0a7086 100644 --- a/torch/distributed/_composable/replicate_with_fsdp.py +++ b/torch/distributed/_composable/replicate_with_fsdp.py @@ -43,7 +43,7 @@ from torch.distributed.tensor import Shard -cls_to_fsdp_cls: dict[type, type] = {} +cls_to_replicate_cls: dict[type, type] = {} _ROOT_MODULE_PREFIX = "" @@ -51,10 +51,10 @@ class _ReplicateStateContext: - """This has state shared across FSDP states.""" + """This has state shared across Replicate states.""" def __init__(self) -> None: - # All FSDP states in the root state's module tree + # All Replicate states in the root state's module tree self.all_states: list[_ReplicateState] = [] # Iteration's forward root runs the once-per-forward logic; this root # may not be the overall root set by lazy initialization in cases where @@ -173,7 +173,7 @@ def replicate_impl( offload_policy: OffloadPolicy = OffloadPolicy(), ignored_params: Optional[set[nn.Parameter]] = None, ): - torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard") + torch._C._log_api_usage_once("torch.distributed._composable.replicate_with_fsdp") if isinstance(module, (nn.ModuleList, nn.ModuleDict)): raise ValueError( f"replicate does not support containers that do not implement forward: {module}" @@ -224,11 +224,11 @@ def replicate_impl( # Place Replicate leftmost for highest priority in the method resolution order for module in modules: cls = module.__class__ - new_cls = cls_to_fsdp_cls.get(cls, None) + new_cls = cls_to_replicate_cls.get(cls, None) if not new_cls: dct = {"__deepcopy__": _unimplemented_deepcopy} - new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) - cls_to_fsdp_cls[cls] = new_cls + new_cls = type(f"Replicate{cls.__name__}", (FSDPModule, cls), dct) + cls_to_replicate_cls[cls] = new_cls module.__class__ = new_cls return arg_module @@ -262,27 +262,7 @@ def replicate( ) device_mesh = kwargs.pop("device_mesh", None) - if device_mesh is not None: - from torch.distributed.device_mesh import _mesh_resources - - root_mesh = _mesh_resources.get_root_mesh(device_mesh) - # if a root mesh is not the same as device_mesh, - # meaning the device_mesh is sliced out from the root mesh. - if root_mesh != device_mesh: - # TODO: This is a temporary work around to enable DDP + TP. - # We should do the logic in DDP so that the 2D implementation is - # sound and the state_dict works out of the box. - # - # This won't conflict with what is done in DDP class as the module - # replicate is going to pass is NOT the original module. - from torch.distributed.tensor.parallel.ddp import ( - _localize_dtensor, - _reconstruct_dtensor, - ) - - module.register_forward_pre_hook(_reconstruct_dtensor) - module.register_forward_hook(_localize_dtensor) - else: + if device_mesh is None: device_mesh = replicate_mesh() module = replicate_impl(module, mesh=device_mesh, **kwargs) From c86040a8e68f754b90a84099187d3624954c7f36 Mon Sep 17 00:00:00 2001 From: James Dong Date: Fri, 8 Aug 2025 19:45:26 +0000 Subject: [PATCH 0158/1424] [torch.export] Fix test_export_api_with_dynamic_shapes (#160164) Summary: Update test KJT's dynamic_shapes to match the newly exported fields. Test Plan: ``` buck test 'fbcode//mode/opt' fbcode//caffe2/test:test_export -- --exact 'caffe2/test:test_export - test_export_api_with_dynamic_shapes_cpp_runtime_nonstrict (caffe2.test.export.test_nativert.NativeRTTestExport)' File changed: fbcode//caffe2/test/export/test_export.py Buck UI: https://www.internalfb.com/buck2/8247eaf8-eaf9-4876-95cb-7b4263d15ef2 Test UI: https://www.internalfb.com/intern/testinfra/testrun/2533275093345198 Network: Up: 100KiB Down: 0B (reSessionID-72a2579f-df3f-4262-9aa3-de0db9687 Executing actions. Remaining 0/2 Command: test. Time elapsed: 2:20.5s Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Rollback Plan: Reviewed By: malaybag Differential Revision: D79862872 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160164 Approved by: https://github.com/angelayi, https://github.com/ezyang --- test/export/test_export.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index c67657bfe3155..848373aef6841 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6349,7 +6349,9 @@ def forward(self, kjt) -> torch.Tensor: efoo = torch.export.export( foo, inputs, - dynamic_shapes={"kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}]}, + dynamic_shapes={ + "kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}, None, None] + }, ) self.assertEqual( [out.shape for out in efoo.module()(*inputs)], From 2ee22e435131369a7e4f8cc4732579acc29a941b Mon Sep 17 00:00:00 2001 From: Jovian Anthony Jaison <38627145+jovianjaison@users.noreply.github.com> Date: Fri, 8 Aug 2025 19:53:41 +0000 Subject: [PATCH 0159/1424] [pytorch][dynamo_compile] Log stack_trace to dynamo_compile (#159655) This change logs the stack trace of the code being compiled by Dynamo, improving visibility into what is compiled. It adds a stack_trace field to compilation metrics. This helps with debugging and analysis of Dynamo compilation behavior. Ref [D79287964](https://www.internalfb.com/diff/D79287964) Test Plan: $ python -m test_utils Internal: ref [D79372519](https://www.internalfb.com/diff/D79372519) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159655 Approved by: https://github.com/c00w --- test/dynamo/test_utils.py | 29 ++++++++++++++++++++++ torch/_dynamo/convert_frame.py | 44 +++++++++++++++++++--------------- torch/_dynamo/utils.py | 1 + 3 files changed, 55 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index d4206575d7b08..f77a8e6ac7f18 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -246,6 +246,32 @@ def add(x, y): utils.reset_frame_count() torch._logging._internal.structured_logging_overhead.clear() + @dynamo_config.patch({"log_compilation_metrics": True}) + @inductor_config.patch({"force_disable_caches": True}) + def test_stack_trace(self): + self.warmup() + + compilation_events = [] + with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event: + self.run_forward_backward() + compilation_events = [arg[0][0] for arg in log_event.call_args_list] + stack_trace_list = [] + for e in compilation_events: + stack_trace_list.append(e.stack_trace) + + self.assertGreater(len(stack_trace_list), 0) + result = "\n".join( + item + for sublist in stack_trace_list + if sublist + for item in (sublist if isinstance(sublist, list) else [sublist]) + ) + self.assertIn( + "test_stack_trace", + result, + "Log file does not contain the expected string: 'test_stack_trace'", + ) + @dynamo_config.patch( { "log_compilation_metrics": True, @@ -396,6 +422,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): e.cuda_version = None e.triton_version = None e.python_version = None + e.stack_trace = None # First event is for the forward. Formatting makes reading diffs # much easier. @@ -479,6 +506,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': 0, 'specialize_float': False, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, @@ -652,6 +680,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': None, 'specialize_float': None, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index bba4d9c980869..fb27c29935439 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -225,30 +225,35 @@ def fx_forward_from_src_skip_result( return result -def log_dynamo_start(code: CodeType, skip: int = 0) -> None: +def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]: convert_frame_intern = structured.intern_string(__file__) + # Extract and filter the stack + stack = list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] # Initialize the ChromiumEventLogger on start torch._logging.trace_structured( "dynamo_start", - lambda: { - "stack": list( - itertools.takewhile( - lambda f: f["filename"] != convert_frame_intern, - structured.from_traceback( - CapturedTraceback.extract(skip=4 + skip).summary() - ), - ) - ) - + [ - { - "line": code.co_firstlineno, - "name": code.co_name, - "filename": structured.intern_string(code.co_filename), - } - ] - }, + lambda: {"stack": stack}, ) + stack_strings = [ + f"Line: {frame['line']}, Name: {frame['name']}, Filename: {frame['filename']}" + for frame in stack + ] + return stack_strings + def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ @@ -1160,7 +1165,7 @@ def format_func_info(code: CodeType) -> str: # # 2 extra here # torch/_logging/_internal.py:1064 in trace_structured # torch/_dynamo/convert_frame.py:780 in - log_dynamo_start(code, skip) + stack_trace = log_dynamo_start(code, skip) start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None @@ -1300,6 +1305,7 @@ def format_func_info(code: CodeType) -> str: "dynamo_compile_time_before_restart_us": to_int_us( dynamo_time_before_restart ), + "stack_trace": stack_trace, } # TODO: replace with CompileEventLogger.compilation_metrics # There are some columns here not in PT2 Compile Events diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 588f1ddb99a19..c6707fe12fbd0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1288,6 +1288,7 @@ class CompilationMetrics: compliant_custom_ops: Optional[set[str]] = None restart_reasons: Optional[set[str]] = None dynamo_time_before_restart_s: Optional[float] = None + stack_trace: Optional[list[str]] = None # Sometimes, we will finish analyzing a frame but conclude we don't want # to install any guarded code. True means we actually decided to install # a compiled frame From 1febab2a89302464f6c7d69cfbef7a24c421ea65 Mon Sep 17 00:00:00 2001 From: Sheng Fu Date: Fri, 8 Aug 2025 20:13:30 +0000 Subject: [PATCH 0160/1424] Do not treat ReinterpretView as a realized node (#159920) Summary: Do not treat ReinterpretView as a realized node Function [gather_origins](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/utils.py#L888](https://l.facebook.com/l.php?u=https%3A%2F%2Fgithub.com%2Fpytorch%2Fpytorch%2Fblob%2Fmain%2Ftorch%2F_inductor%2Futils.py%23L888&h=AT2PYr83thTo6VUjPs26Y8QAN6Sid16rvDMHtxO-Bp9FDwHr4J5PObtH3IhNTL-LPSRVC9WVJAcmwUToVWJIrDWb84i0j61QE55ySYAkGbuigqcNc7xczlirHhbiC9vMqiz91VwWdl4Pe2yKN7VIjjCiFUqw) calls is_realized_node to decide if a FX node should be included in the origins of a IR node. ReinterpretView is considered a realized node, so it is not included in the origins. It leads to an incomplete graph. For example: ``` @torchdynamo.optimize("inductor") def fn(input_data, weight): normalized_input = input_data * weight.unsqueeze(0) return normalized_input input_data = torch.randn(4272, 192, requires_grad=True).to(device) weight = torch.randn(192, requires_grad=True).to(device) fn(input_data, weight) ``` The original FX graph returned in [get_kernel_metadata](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/utils.py#L723](https://l.facebook.com/l.php?u=https%3A%2F%2Fgithub.com%2Fpytorch%2Fpytorch%2Fblob%2Fmain%2Ftorch%2F_inductor%2Futils.py%23L723&h=AT2PYr83thTo6VUjPs26Y8QAN6Sid16rvDMHtxO-Bp9FDwHr4J5PObtH3IhNTL-LPSRVC9WVJAcmwUToVWJIrDWb84i0j61QE55ySYAkGbuigqcNc7xczlirHhbiC9vMqiz91VwWdl4Pe2yKN7VIjjCiFUqw) is the following: %primals_2 : Tensor "f32[4272, 192][192, 1]cuda:0" = PlaceHolder[target=primals_2] %primals_1 : Tensor "f32[192][1]cuda:0" = PlaceHolder[target=primals_1] %mul : Tensor "f32[4272, 192][192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_2, %unsqueeze), kwargs = {}) return %mul The unsqueeze op is missing. With this DIFF, the new FX graph is the following: %primals_2 : Tensor "f32[4272, 192][192, 1]cuda:0" = PlaceHolder[target=primals_2] %primals_1 : Tensor "f32[192][1]cuda:0" = PlaceHolder[target=primals_1] %unsqueeze : Tensor "f32[1, 192][192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%primals_1, 0), kwargs = {}) %mul : Tensor "f32[4272, 192][192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_2, %unsqueeze), kwargs = {}) return %mul Pull Request resolved: https://github.com/pytorch/pytorch/pull/159920 Approved by: https://github.com/mlazos --- torch/_inductor/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 026f5f14fe74f..f21905e16e9d7 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -895,7 +895,15 @@ def is_unrealized_node(n: IRNode) -> bool: return is_unrealized_node(n.data) if isinstance(n, ir.StorageBox): return is_unrealized_node(n.data) - return isinstance(n, ir.IRNode) and not ir.IRNode.is_realized_node(n) + return isinstance(n, ir.IRNode) and not isinstance( + n, + ( + ir.ComputedBuffer, + ir.InputsKernel, + ir.InputBuffer, + ir.TemplateBuffer, + ), + ) # kwargs and args may include a container of node, for example torch.cat([t1, t2]) # flatten them before search the unrealized nodes From 2247aa6d1d43e256255f5c74a781c3190a4387b6 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Thu, 7 Aug 2025 14:37:50 -0700 Subject: [PATCH 0161/1424] Documents tuning NVLink performance on H100/H200 (#159792) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159792 Approved by: https://github.com/ngimel --- docs/source/notes/cuda.rst | 124 +++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 5210eb4ad1495..8ad4c87a71395 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -896,6 +896,130 @@ APIs can be used for debugging purposes: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#memory-allocator +Tuning NVLink Performance with Custom Memory Allocator on H100/H200 GPUs +------------------------------------------------------------------------ +In rare cases, performance of NVLink on H100/H200 GPUs can be influenced by the physical memory +layout of data, creating an opportunity for developers to tune their applications for optimal +throughput. + +An example of how physical memory layout of data affects performance is when communication +kernels issue unbalanced NVLink read/write operations. In the following figure, we can see +that each warp accesses memory addresses with a consistent strided pattern in each single wave. +We can have a more balanced load by tuning the stride size in the workload or we can implement +a custom CUDA allocator. + +.. code:: + + _______________________________ _______________________________ _______________________________ + | Warp 0 Reading | No-reading | | Warp 1 Reading | No-reading | ... Warp N Reading | No-reading | + _______________________________ _______________________________ _______________________________ + <-----------------------------> + Stride size + +Such an allocator can maintain contiguous virtual memory addresses for the kernel while strategically +arranging the mapping to physical memory addresses (e.g., through shuffling). This technique allows +developers to explore different physical access patterns to find the most efficient one, unlocking +higher performance without modifying the kernel's logic. A practical implementation of such an allocator +can be achieved using PyTorch’s custom allocator support as mentioned before, where the malloc and free +functions are: + +.. code:: C++ + + // assuming a system with 8 GPUs + struct CustomAllocInfo { + void** devPtr; // This will be the usable virtual memory address + CUdeviceptr dptr; + size_t totalSize; // Total size of the allocated memory + size_t padded_size; + int device_id; + std::vector handles; // Handles to physical memory allocations + }; + + // loop over pages + cudaError_t customCudaMalloc(CustomAllocInfo* info) { + if (!info) return cudaErrorInvalidValue; + + CUdeviceptr dptr; + + // Handles to redundant physical memory allocations which help truncate stride pattern in physical memory + std::vector handles_redundant; + + size_t granularity = 0; + CUmemAllocationProp prop = {}; + + int currentDev = info->device_id; + size_t totalSize = info->totalSize; + + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = currentDev; + cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM); + size_t padded_size = ROUND_UP(totalSize, granularity); + + info->padded_size = padded_size; + + // loop over pages + size_t iter_granularity = granularity * 64; // 64 * granularity with shift_size = 2 works + uint32_t iteration_count = (totalSize + iter_granularity - 1) / iter_granularity; + + cuMemAddressReserve(&dptr, padded_size, 0ULL, 0ULL, 0ULL); + + const int shift_size = 2; + for (size_t i = 0; i < iteration_count; i+=shift_size) { + + CUmemGenericAllocationHandle allocHandle[shift_size]; + for (int shift = 0; (shift < shift_size)&&(i+shift < iteration_count); shift++){ + CHECK_CUDA(cuMemCreate(&allocHandle[shift], iter_granularity, &prop, 0)); + info->handles.push_back(allocHandle[shift]); + } + + for (int shift = 0; (shift < shift_size)&&(i+shift < iteration_count); shift++){ + + // mapping makes the shift (shift -> (shift+1)%shift_size ) + CHECK_CUDA(cuMemMap(dptr + (i+shift) * iter_granularity, iter_granularity, 0, allocHandle[(shift+1)%shift_size], 0)); + + setupMultiGPUAccess(dptr + (i+shift) * iter_granularity, iter_granularity, {0, 1, 2, 3, 4, 5, 6, 7}); // Enable access for all 8 GPUs + } + + // std::cout << "Here we allocate one redundant page (2MB)..." << std::endl; + // this is an extra optimization on top of the swizzling. It helps "break" + // the physical access pattern even more. It can be left out if workload is already + // performing at SOL with just swizzling. + CUmemGenericAllocationHandle allocHandle_redundant; + CHECK_CUDA(cuMemCreate(&allocHandle_redundant, granularity, &prop, 0)); + handles_redundant.push_back(allocHandle_redundant); + } + + *info->devPtr = (void*)dptr; + info->dptr = dptr; + + // Release each redundant allocation + for (auto handle : handles_redundant) { + // std::cout << "Here we release one redundant page (2MB)..." << std::endl; + CHECK_CUDA(cuMemRelease(handle)); + } + + return cudaSuccess; + } + + void customCudaFree(CustomAllocInfo* info) { + if (!info) return; + + // CHECK_CUDA(cudaSetDevice(info->device_id)); + + CHECK_CUDA(cuMemUnmap(info->dptr, info->padded_size)); + + // Unmap and release each allocation + for (auto handle : info->handles) { + CHECK_CUDA(cuMemRelease(handle)); + } + + // Unreserve the virtual address space + // CHECK_CUDA(cuMemAddressFree((CUdeviceptr)*info->devPtr, info->padded_size)); + CHECK_CUDA(cuMemAddressFree(info->dptr, info->padded_size)); + } + + cuBLAS workspaces ----------------- From 28ccc9e7247798980fe00a11bcd64a8016b5f227 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 7 Aug 2025 20:41:22 -0700 Subject: [PATCH 0162/1424] [MPS] Extend `index_put` to complex types (#160159) And delete confusing supported types check. Move all pseudo atomic (but eventually consistent) ops to `c10/metal/atomic.h` header Fixes https://github.com/pytorch/pytorch/issues/160034 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160159 Approved by: https://github.com/manuelcandales, https://github.com/dcci, https://github.com/Skylion007 --- .../ATen/native/mps/kernels/Indexing.metal | 27 ++-------- .../ATen/native/mps/operations/Indexing.mm | 22 ++------- c10/metal/atomic.h | 49 +++++++++++++++++++ torch/testing/_internal/common_mps.py | 2 + 4 files changed, 58 insertions(+), 42 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/Indexing.metal b/aten/src/ATen/native/mps/kernels/Indexing.metal index 7503d8b2b1c8b..048b2e5ae7c9a 100644 --- a/aten/src/ATen/native/mps/kernels/Indexing.metal +++ b/aten/src/ATen/native/mps/kernels/Indexing.metal @@ -5,29 +5,6 @@ using namespace metal; using namespace c10::metal; -namespace c10 { -namespace metal { -// There are no atomic 64-bit add in Metal yet, but this implements a consistent -// add I.e. if multiple threads are modify the same 64-bit value, results stored -// at the address will eventually be equal to its original value plus sum of all -// operands -template <> -struct AtomicType { - using type = ::metal::atomic; - static inline void atomic_add(device type* data, long offset, long value) { - const auto value_bits = as_type(value); - const uint low = static_cast(value_bits); - uint high = static_cast(value_bits >> 32); - auto ptr = data + (offset << 1); - auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed); - high += (old_low + low < old_low) ? 1 : 0; - atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed); - } -}; - -} // namespace metal -} // namespace c10 - struct IndexAB { constant int64_t* indexArray; }; @@ -234,13 +211,15 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial); REGISTER_INDEX_OP(put_accumulate, float, float); REGISTER_INDEX_OP(put_accumulate, half, half); +REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat); REGISTER_INDEX_OP(put_accumulate, long, long); REGISTER_INDEX_OP(put_accumulate, int, int); REGISTER_INDEX_OP(put_accumulate, short, short); REGISTER_INDEX_OP(put_accumulate, char, char); REGISTER_INDEX_OP(put_accumulate, uchar, uchar); REGISTER_INDEX_OP(put_accumulate, bool, bool); -REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat); +REGISTER_INDEX_OP(put_accumulate, float2, float2); +REGISTER_INDEX_OP(put_accumulate, half2, half2); template kernel void kernel_index_offsets( diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index 66ae1114f841d..a73866dc4357b 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -108,26 +108,12 @@ static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, - const std::string& op, - bool accumulate) { - using namespace mps; - + const std::string& op) { const auto num_indices = index_size.size(); TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels"); AT_ASSERT(num_indices == index_stride.size()); AT_ASSERT(static_cast(num_indices) == iter.ntensors() - 2); - const Tensor& inputTensor = iter.tensor(1); - const auto scalar_type = inputTensor.scalar_type(); - - if (accumulate) { - // No atomic support for the complex dtypes - TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type)); - } else { - TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) || - scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf, - getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out")); - } } static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) { @@ -158,7 +144,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter, IntArrayRef index_stride, const std::string& kernel_name, const bool serial = false) { - validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false); + validateInputData(iter, index_size, index_stride, "index.Tensor_out"); if (iter.numel() == 0) return; if (!iter.can_use_32bit_indexing()) { @@ -200,7 +186,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter, } static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) { - validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false); + validateInputData(iter, index_size, index_stride, "index.Tensor_out"); dispatch_index_kernel( iter, index_size, index_stride, fmt::format("index_select_{}", getBitSizeString(iter.tensor_base(0)))); } @@ -210,7 +196,7 @@ static void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_stride, bool accumulate) { @autoreleasepool { - validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate); + validateInputData(iter, index_size, index_stride, "index_put_impl"); if (accumulate) { dispatch_index_kernel(iter, index_size, diff --git a/c10/metal/atomic.h b/c10/metal/atomic.h index 6dcd9a706ba74..d0cbc03916989 100644 --- a/c10/metal/atomic.h +++ b/c10/metal/atomic.h @@ -124,5 +124,54 @@ struct AtomicType { } }; +// ComplexHalf atomic op +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, half2 value) { + auto ptr = data + offset; + auto old = + ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed); + while (!::metal::atomic_compare_exchange_weak_explicit( + ptr, + &old, + as_type(as_type(old) + value), + ::metal::memory_order_relaxed, + ::metal::memory_order_relaxed)) + ; + } +}; + +// There are no atomic 64-bit add in Metal yet, but templates below implements a +// consistent add I.e. if multiple threads are modify the same 64-bit value, +// results stored at the address will eventually be equal to its original value +// plus sum of all operands +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, long value) { + const auto value_bits = as_type(value); + const uint low = static_cast(value_bits); + uint high = static_cast(value_bits >> 32); + auto ptr = data + (offset << 1); + auto old_low = + atomic_fetch_add_explicit(ptr, low, ::metal::memory_order_relaxed); + high += (old_low + low < old_low) ? 1 : 0; + atomic_fetch_add_explicit(ptr + 1, high, ::metal::memory_order_relaxed); + } +}; + +// ComplexFloat atomic op, which again is not really atomic, but eventually +// consistent +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, float2 value) { + auto ptr = data + (offset << 1); + atomic_fetch_add_explicit(ptr + 0, value.x, ::metal::memory_order_relaxed); + atomic_fetch_add_explicit(ptr + 1, value.y, ::metal::memory_order_relaxed); + } +}; + } // namespace metal } // namespace c10 diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index 58afc631d21bb..fbfa5e2c9f9fb 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -25,6 +25,7 @@ def mps_ops_modifier( "__rsub__", "__getitem__", "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", "abs", "add", "alias_copy", @@ -75,6 +76,7 @@ def mps_ops_modifier( "imag", "index_copy", "index_select", + "index_put", "isfinite", "isinf", "isreal", From 206c1eef6571f906c2792d899a09136b3fce9673 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 8 Aug 2025 22:04:22 +0000 Subject: [PATCH 0163/1424] Revert "[pytorch][dynamo_compile] Log stack_trace to dynamo_compile (#159655)" This reverts commit 2ee22e435131369a7e4f8cc4732579acc29a941b. Reverted https://github.com/pytorch/pytorch/pull/159655 on behalf of https://github.com/clee2000 due to broke dynamo/test_utils.py::TestDynamoTimed::test_dynamo_timed [GH job link](https://github.com/pytorch/pytorch/actions/runs/16839294394/job/47711078667) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/2ee22e435131369a7e4f8cc4732579acc29a941b). Probably a landrace since it did run on the PR ([comment](https://github.com/pytorch/pytorch/pull/159655#issuecomment-3169400889)) --- test/dynamo/test_utils.py | 29 ---------------------- torch/_dynamo/convert_frame.py | 44 +++++++++++++++------------------- torch/_dynamo/utils.py | 1 - 3 files changed, 19 insertions(+), 55 deletions(-) diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index f77a8e6ac7f18..d4206575d7b08 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -246,32 +246,6 @@ def add(x, y): utils.reset_frame_count() torch._logging._internal.structured_logging_overhead.clear() - @dynamo_config.patch({"log_compilation_metrics": True}) - @inductor_config.patch({"force_disable_caches": True}) - def test_stack_trace(self): - self.warmup() - - compilation_events = [] - with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event: - self.run_forward_backward() - compilation_events = [arg[0][0] for arg in log_event.call_args_list] - stack_trace_list = [] - for e in compilation_events: - stack_trace_list.append(e.stack_trace) - - self.assertGreater(len(stack_trace_list), 0) - result = "\n".join( - item - for sublist in stack_trace_list - if sublist - for item in (sublist if isinstance(sublist, list) else [sublist]) - ) - self.assertIn( - "test_stack_trace", - result, - "Log file does not contain the expected string: 'test_stack_trace'", - ) - @dynamo_config.patch( { "log_compilation_metrics": True, @@ -422,7 +396,6 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): e.cuda_version = None e.triton_version = None e.python_version = None - e.stack_trace = None # First event is for the forward. Formatting makes reading diffs # much easier. @@ -506,7 +479,6 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': 0, 'specialize_float': False, - 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, @@ -680,7 +652,6 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': None, 'specialize_float': None, - 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index fb27c29935439..bba4d9c980869 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -225,35 +225,30 @@ def fx_forward_from_src_skip_result( return result -def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]: +def log_dynamo_start(code: CodeType, skip: int = 0) -> None: convert_frame_intern = structured.intern_string(__file__) - # Extract and filter the stack - stack = list( - itertools.takewhile( - lambda f: f["filename"] != convert_frame_intern, - structured.from_traceback( - CapturedTraceback.extract(skip=4 + skip).summary() - ), - ) - ) + [ - { - "line": code.co_firstlineno, - "name": code.co_name, - "filename": structured.intern_string(code.co_filename), - } - ] # Initialize the ChromiumEventLogger on start torch._logging.trace_structured( "dynamo_start", - lambda: {"stack": stack}, + lambda: { + "stack": list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] + }, ) - stack_strings = [ - f"Line: {frame['line']}, Name: {frame['name']}, Filename: {frame['filename']}" - for frame in stack - ] - return stack_strings - def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ @@ -1165,7 +1160,7 @@ def format_func_info(code: CodeType) -> str: # # 2 extra here # torch/_logging/_internal.py:1064 in trace_structured # torch/_dynamo/convert_frame.py:780 in - stack_trace = log_dynamo_start(code, skip) + log_dynamo_start(code, skip) start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None @@ -1305,7 +1300,6 @@ def format_func_info(code: CodeType) -> str: "dynamo_compile_time_before_restart_us": to_int_us( dynamo_time_before_restart ), - "stack_trace": stack_trace, } # TODO: replace with CompileEventLogger.compilation_metrics # There are some columns here not in PT2 Compile Events diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index c6707fe12fbd0..588f1ddb99a19 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1288,7 +1288,6 @@ class CompilationMetrics: compliant_custom_ops: Optional[set[str]] = None restart_reasons: Optional[set[str]] = None dynamo_time_before_restart_s: Optional[float] = None - stack_trace: Optional[list[str]] = None # Sometimes, we will finish analyzing a frame but conclude we don't want # to install any guarded code. True means we actually decided to install # a compiled frame From 334ecbd4ffe11858cae7d23d1190ddb4777c2513 Mon Sep 17 00:00:00 2001 From: Robert Hardwick Date: Fri, 8 Aug 2025 14:38:08 +0000 Subject: [PATCH 0164/1424] Add torchao to install_inductor_benchmark_deps cleanup stage (#160191) It looks like `torcho` was missed from the cleanup during torchbench setup. Fixes #160188 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160191 Approved by: https://github.com/huydhn --- .ci/docker/common/install_inductor_benchmark_deps.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index bda3aa6009564..c2601adb67e32 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -48,4 +48,4 @@ install_huggingface install_timm # Clean up -conda_run pip uninstall -y torch torchvision torchaudio triton +conda_run pip uninstall -y torch torchvision torchaudio triton torchao From 1128f4c2a822cbe34a9d966306af15097179ffe1 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Fri, 8 Aug 2025 22:22:48 +0000 Subject: [PATCH 0165/1424] [cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for `sm90`, `sm100` (#149282) cleanup tuple/tensor boilerplate in cuDNN SDPA, preparation for nested/ragged tensor backward Pull Request resolved: https://github.com/pytorch/pytorch/pull/149282 Approved by: https://github.com/drisspg Co-authored-by: Aaron Gokaslan --- aten/src/ATen/native/cudnn/MHA.cpp | 1064 +++++++++++------ aten/src/ATen/native/cudnn/MHA.h | 27 + aten/src/ATen/native/native_functions.yaml | 6 + .../cuda/NestedTensorTransformerFunctions.cpp | 57 + .../native/transformers/cuda/attention.cu | 10 - .../transformers/cuda/attention_backward.cu | 192 ++- .../native/transformers/cuda/sdp_utils.cpp | 72 +- ...asDecompTest.test_has_decomposition.expect | 1 + test/inductor/test_cuda_repro.py | 8 +- test/test_nestedtensor.py | 9 +- test/test_transformers.py | 21 +- tools/autograd/derivatives.yaml | 4 + 12 files changed, 1025 insertions(+), 446 deletions(-) diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 48119a6a3b4c3..a482c9041c906 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -2,9 +2,13 @@ #include #include -#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ - (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) +#if AT_CUDNN_ENABLED() +#include +#endif +#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ + (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) || \ + (defined(CUDNN_FRONTEND_VERSION) && CUDNN_FRONTEND_VERSION < 10100) namespace at { namespace native { @@ -84,6 +88,37 @@ void run_cudnn_SDP_bprop( false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); } +void run_cudnn_SDP_bprop_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset) { + TORCH_CHECK( + false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); +} + } // namespace native } // namespace at @@ -95,7 +130,6 @@ void run_cudnn_SDP_bprop( #include #include -#include #include #include @@ -111,40 +145,6 @@ namespace native { #include namespace fe = cudnn_frontend; -using graph_and_tensors = std::tuple< - std::shared_ptr, - std::shared_ptr, // Q, - std::shared_ptr, // K, - std::shared_ptr, // V, - std::optional>, // Bias - std::shared_ptr, // Attn_scale, - // TODO(eqy): additional options - // std::shared_ptr, // SEQ_LEN_Q, - // std::shared_ptr, // SEQ_LEN_KV, - std::shared_ptr, // Seed, - std::shared_ptr, // Offset, - // std::shared_ptr, // Dropout_mask, - // std::shared_ptr, // Dropout_scale - std::shared_ptr, // O - std::shared_ptr // Stats - >; - -using graph_and_tensors_backward = std::tuple< - std::shared_ptr, - std::shared_ptr, // Q, - std::shared_ptr, // K, - std::shared_ptr, // V, - std::optional>, // Bias, - std::shared_ptr, // Attn_scale, - std::shared_ptr, // Seed, - std::shared_ptr, // Offset, - std::shared_ptr, // O, - std::shared_ptr, // dO, - std::shared_ptr, // stats, - std::shared_ptr, // dQ, - std::shared_ptr, // dK,, - std::shared_ptr // dV, - >; #define MAX_MHA_DIM 4 @@ -298,11 +298,45 @@ struct MHAGraphCache { // @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to // be thread safe across all engines see Limitations in // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html -thread_local MHAGraphCache mhagraphcache; -thread_local MHAGraphCache - mhagraphbackwardcache; +// We also leak the caches to workaround potential teardown race issues. + +auto& getMHAGraphCache_() { + thread_local auto& instance = + *new MHAGraphCache, MHACacheKeyWrapper>; + return instance; +} + +auto& getMHAGraphBackwardCache_() { + thread_local auto& instance = + *new MHAGraphCache, MHACacheKeyWrapper>; + return instance; +} namespace { + +enum UIDS { + Q, + K, + V, + O, + BIAS, + SCALE, + SEED, + OFFSET, + LSE, + DO, + DQ, + DK, + DV, + SEQ_LEN_Q, + SEQ_LEN_KV, + RAG_Q_OFF, + RAG_K_OFF, + RAG_V_OFF, + RAG_O_OFF, + RAG_LSE_OFF +}; + // analogous to the same function in Descriptors.h for cuDNN Convolutions... auto fixSizeOneDimStrideSDPA( const IntArrayRef sizes, @@ -320,9 +354,10 @@ auto fixSizeOneDimStrideSDPA( } return strides; } + } // namespace -auto build_graph_and_tensors( +auto build_graph( int64_t b, int64_t h, int64_t s_q, @@ -355,46 +390,55 @@ auto build_graph_and_tensors( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); auto scaled_dot_product_flash_attention_options = fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA") .set_is_inference(return_softmaxstats == false) .set_causal_mask(is_causal) - .set_attn_scale(attn_scale) - .set_dropout(dropout_probability, seed, offset); - auto Q = mha_graph->tensor( + .set_attn_scale(attn_scale); + if (dropout_probability != 0.0f) { + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + scaled_dot_product_flash_attention_options.set_dropout( + dropout_probability, seed, offset); + } + auto Q_ = mha_graph->tensor( fe::graph::Tensor_attributes() + .set_uid(Q) .set_name("Q") .set_dim(q.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec()))); - auto K = mha_graph->tensor( + auto K_ = mha_graph->tensor( fe::graph::Tensor_attributes() + .set_uid(K) .set_name("K") .set_dim(k.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec()))); - auto V = mha_graph->tensor( + auto V_ = mha_graph->tensor( fe::graph::Tensor_attributes() + .set_uid(V) .set_name("V") .set_dim(v.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec()))); @@ -402,17 +446,20 @@ auto build_graph_and_tensors( if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto [O, Stats] = - mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); - O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); + auto [O_, Stats] = + mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); + O_->set_uid(O); + O_->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); if (Stats) { + Stats->set_uid(LSE); Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); } @@ -423,20 +470,10 @@ auto build_graph_and_tensors( AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return std::make_tuple( - std::move(mha_graph), - std::move(Q), - std::move(K), - std::move(V), - std::move(bias), - std::move(attn_scale), - std::move(seed), - std::move(offset), - std::move(O), - std::move(Stats)); + return mha_graph; } -auto build_graph_and_tensors_nestedtensor( +auto build_graph_nestedtensor( int64_t b, int64_t h_q, int64_t h_k, @@ -473,28 +510,22 @@ auto build_graph_and_tensors_nestedtensor( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_KV = + auto SEQ_LEN_Q_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_Q) + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_KV_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_KV) .set_name("Seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -506,41 +537,66 @@ auto build_graph_and_tensors_nestedtensor( .set_is_inference(return_softmaxstats == false) .set_causal_mask(is_causal) .set_attn_scale(attn_scale) - .set_dropout(dropout_probability, seed, offset) - .set_seq_len_q(SEQ_LEN_Q) - .set_seq_len_kv(SEQ_LEN_KV) + .set_seq_len_q(SEQ_LEN_Q_) + .set_seq_len_kv(SEQ_LEN_KV_) .set_padding_mask(true); + if (dropout_probability != 0.0f) { + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + scaled_dot_product_flash_attention_options.set_dropout( + dropout_probability, seed, offset); + } // We hardcode BSHD to cuDNN even though the underlying layout is THD auto q_strides = q.strides(); auto k_strides = k.strides(); auto v_strides = v.strides(); + // NB: cuDNN API shape is transposed constexpr int strideidx0 = 1; constexpr int strideidx1 = 0; constexpr int strideidx2 = 2; - auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]})); - auto K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]})); - auto V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]})); + auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(Q) + .set_name("Q") + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]})); + auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(K) + .set_name("K") + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]})); + auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(V) + .set_name("V") + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]})); std::optional> bias; if (attn_bias.has_value()) { TORCH_CHECK( @@ -548,44 +604,48 @@ auto build_graph_and_tensors_nestedtensor( "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto RAG_Q_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("cum_seq_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - // auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("cum_seq_stats") - // .set_dim({b + 1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - auto RAG_STATS_OFF = nullptr; - Q->set_ragged_offset(RAG_Q_OFF); - K->set_ragged_offset(RAG_K_OFF); - V->set_ragged_offset(RAG_V_OFF); - auto [O, Stats] = - mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); + auto RAG_Q_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_Q_OFF) + .set_name("cum_seq_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_K_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_K_OFF) + .set_name("cum_seq_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_V_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_V_OFF) + .set_name("cum_seq_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_O_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_O_OFF) + .set_name("cum_seq_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + Q_->set_ragged_offset(RAG_Q_OFF_); + K_->set_ragged_offset(RAG_K_OFF_); + V_->set_ragged_offset(RAG_V_OFF_); + auto [O_, Stats] = + mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); auto o_strides = o.strides(); - O->set_output(true) + O_->set_output(true) + .set_uid(O) .set_dim({b, h_q, s_q, d_v}) .set_stride( {INT_MAX, @@ -593,16 +653,20 @@ auto build_graph_and_tensors_nestedtensor( o_strides[strideidx1], o_strides[strideidx2]}); - O->set_ragged_offset(RAG_O_OFF); + O_->set_ragged_offset(RAG_O_OFF_); if (Stats) { - TORCH_CHECK( - false, - "cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation"); - // TODO(eqy): fix when stats (backward) support is added + auto RAG_STATS_OFF = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_LSE_OFF) + .set_name("cum_seq_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); Stats->set_output(true) + .set_uid(LSE) .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h_q, s_q, 1}) - .set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1}); + .set_stride({h_q * s_q, 1, h_q, 1}); Stats->set_ragged_offset(RAG_STATS_OFF); } AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); @@ -611,27 +675,10 @@ auto build_graph_and_tensors_nestedtensor( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return std::make_tuple( - std::move(mha_graph), - std::move(Q), - std::move(K), - std::move(V), - std::move(bias), - std::move(attn_scale), - std::move(seed), - std::move(offset), - std::move(O), - std::move(Stats), - std::move(RAG_Q_OFF), - std::move(RAG_K_OFF), - std::move(RAG_V_OFF), - std::move(RAG_O_OFF), - std::move(RAG_STATS_OFF), - std::move(SEQ_LEN_Q), - std::move(SEQ_LEN_KV)); + return mha_graph; } -auto build_graph_and_tensors_backward( +auto build_graph_backward( int64_t b, int64_t h, int64_t s_q, @@ -667,6 +714,7 @@ auto build_graph_and_tensors_backward( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -676,87 +724,327 @@ auto build_graph_and_tensors_backward( .set_name("CUDNN_SDPA_BACKWARD") .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim(q.sizes().vec()) - .set_stride(q.strides().vec())); - auto K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim(k.sizes().vec()) - .set_stride(k.strides().vec())); - auto V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim(v.sizes().vec()) - .set_stride(v.strides().vec())); + auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(Q) + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(q.strides().vec())); + auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(K) + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(k.strides().vec())); + auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(V) + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(v.strides().vec())); std::optional> bias; if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); sdpa_backward_options.set_bias(bias.value()); } - auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - - auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") + if (dropout_probability != 0.0f) { + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type( - dropoutoffset.dtype() == kInt + dropoutseed.dtype() == kInt ? fe::DataType_t::INT32 : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, seed, offset); + } - auto O = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim(o.sizes().vec()) - .set_stride(o.strides().vec())); - auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() + auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(O) + .set_name("O") + .set_dim(o.sizes().vec()) + .set_stride(o.strides().vec())); + auto Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(LSE) .set_name("Stats") .set_dim(softmaxstats.sizes().vec()) .set_stride(softmaxstats.strides().vec()) .set_data_type(fe::DataType_t::FLOAT)); - auto DO = mha_graph->tensor(fe::graph::Tensor_attributes() + auto Do = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(DO) .set_name("DO") .set_dim(dO.sizes().vec()) .set_stride(dO.strides().vec())); + auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( + Q_, K_, V_, O_, Do, Stats, sdpa_backward_options); + Dq->set_uid(DQ); + Dq->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); + Dk->set_uid(DK); + Dk->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); + Dv->set_uid(DV); + Dv->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); + AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); + AT_CUDNN_FRONTEND_CHECK( + mha_graph->create_execution_plans({fe::HeurMode_t::A})); + AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); + return mha_graph; +} + +auto build_graph_backward_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset, + cudnnHandle_t& handle) { + auto dtype = fe::DataType_t::HALF; + if (q.scalar_type() == kBFloat16) { + dtype = fe::DataType_t::BFLOAT16; + } + auto mha_graph = std::make_shared(); + // We're baking in float accumulation and scale types + // in theory the graph may support other types, but they + // have not been tested + mha_graph->set_io_data_type(dtype) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + auto attn_scale = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SCALE) + .set_name("Attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + auto SEQ_LEN_Q_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_Q) + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_KV_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEQ_LEN_KV) + .set_name("Seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() + .set_name("CUDNN_SDPA_NESTEDTENSOR_BACKWARD") + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale) + .set_seq_len_q(SEQ_LEN_Q_) + .set_seq_len_kv(SEQ_LEN_KV_) + .set_padding_mask(true); if (dropout_probability != 0.0f) { - sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset); + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(SEED) + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(OFFSET) + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, seed, offset); } - auto [DQ, DK, DV] = - mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options); - DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); - DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); - DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); + auto q_strides = q.strides(); + auto k_strides = k.strides(); + auto v_strides = v.strides(); + // NB: cuDNN API shape is transposed + constexpr int strideidx0 = 1; + constexpr int strideidx1 = 0; + constexpr int strideidx2 = 2; + auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(Q) + .set_name("Q") + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]})); + auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(K) + .set_name("K") + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]})); + auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(V) + .set_name("V") + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]})); + auto o_strides = o.strides(); + auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(O) + .set_name("O") + .set_dim({b, h_q, s_q, d_v}) + .set_stride( + {INT_MAX, + o_strides[strideidx0], + o_strides[strideidx1], + o_strides[strideidx2]})); + + std::optional> bias; + if (attn_bias.has_value()) { + TORCH_CHECK( + false, + "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); + bias = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(BIAS) + .set_name("bias") + .set_dim(attn_bias.value().sizes().vec()) + .set_stride(attn_bias.value().strides().vec())); + sdpa_backward_options.set_bias(bias.value()); + } + auto RAG_Q_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_Q_OFF) + .set_name("cum_seq_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_K_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_K_OFF) + .set_name("cum_seq_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_V_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_V_OFF) + .set_name("cum_seq_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_O_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_O_OFF) + .set_name("cum_seq_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_STATS_OFF_ = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(RAG_LSE_OFF) + .set_name("cum_seq_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + O_->set_ragged_offset(RAG_O_OFF_); + Q_->set_ragged_offset(RAG_Q_OFF_); + K_->set_ragged_offset(RAG_K_OFF_); + V_->set_ragged_offset(RAG_V_OFF_); + auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_uid(LSE) + .set_name("stats") + .set_dim({b, h_q, s_q, 1}) + .set_stride({s_q * h_q, 1, h_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + STATS->set_ragged_offset(RAG_STATS_OFF_); + auto do_strides = dO.strides(); + auto DO_ = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_ragged_offset(RAG_O_OFF_) + .set_uid(DO) + .set_name("DO") + .set_dim({b, h_q, s_q, d_v}) + .set_stride( + {INT_MAX, + do_strides[strideidx0], + do_strides[strideidx1], + do_strides[strideidx2]})); + auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( + Q_, K_, V_, O_, DO_, STATS, sdpa_backward_options); + Dq->set_output(true) + .set_uid(DQ) + .set_ragged_offset(RAG_Q_OFF_) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]}); + Dk->set_output(true) + .set_uid(DK) + .set_ragged_offset(RAG_K_OFF_) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]}); + Dv->set_output(true) + .set_uid(DV) + .set_ragged_offset(RAG_V_OFF_) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]}); + AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); AT_CUDNN_FRONTEND_CHECK( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return std::make_tuple( - std::move(mha_graph), - std::move(Q), - std::move(K), - std::move(V), - std::move(bias), - std::move(attn_scale), - std::move(Seed), - std::move(Offset), - std::move(O), - std::move(DO), - std::move(STATS), - std::move(DQ), - std::move(DK), - std::move(DV)); + return mha_graph; } void run_cudnn_SDP_fprop( @@ -817,12 +1105,12 @@ void run_cudnn_SDP_fprop( dropout_probability, is_causal, return_softmaxstats); - auto graph_and_tensors_ptr = mhagraphcache.find(key); - graph_and_tensors graph_and_tensors_values; - if (graph_and_tensors_ptr) { - graph_and_tensors_values = *graph_and_tensors_ptr; + auto graph_ptr = getMHAGraphCache_().find(key); + std::shared_ptr mha_graph; + if (graph_ptr) { + mha_graph = *graph_ptr; } else { - graph_and_tensors_values = build_graph_and_tensors( + mha_graph = build_graph( b, h, s_q, @@ -843,29 +1131,28 @@ void run_cudnn_SDP_fprop( _dropoutoffset, handle); } - auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] = - graph_and_tensors_values; - std::unordered_map, void*> - variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {attn_scale, &scaling_factor}, - {seed, _dropoutseed.data_ptr()}, - {offset, _dropoutoffset.data_ptr()}, - {O, o.data_ptr()}}; + std::unordered_map variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {SCALE, &scaling_factor}, + {O, o.data_ptr()}}; if (return_softmaxstats) { - variant_pack[Stats] = softmaxstats.data_ptr(); + variant_pack[LSE] = softmaxstats.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[bias.value()] = attn_bias.value().data_ptr(); + variant_pack[BIAS] = attn_bias.value().data_ptr(); + } + if (dropout_probability != 0.0f) { + variant_pack[SEED] = _dropoutseed.data_ptr(); + variant_pack[OFFSET] = _dropoutoffset.data_ptr(); } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - mhagraphcache.update(key, graph_and_tensors_values); + getMHAGraphCache_().update(key, mha_graph); } void run_cudnn_SDP_fprop_nestedtensor( @@ -904,72 +1191,55 @@ void run_cudnn_SDP_fprop_nestedtensor( if (return_softmaxstats && !softmaxstats.defined()) { softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat)); } - auto - [mha_graph, - Q, - K, - V, - bias, - attn_scale, - seed, - offset, - O, - Stats, - RAG_Q_OFF, - RAG_K_OFF, - RAG_V_OFF, - RAG_O_OFF, - RAG_STATS_OFF, - SEQ_LEN_Q, - SEQ_LEN_KV] = - build_graph_and_tensors_nestedtensor( - b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - scaling_factor, - return_softmaxstats, - is_causal, - dropout_probability, - cum_seqlen_q, - cum_seqlen_kv, - q, - k, - v, - attn_bias, - softmaxstats, - o, - dropoutseed, - dropoutoffset, - handle); + auto mha_graph = build_graph_nestedtensor( + b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + scaling_factor, + return_softmaxstats, + is_causal, + dropout_probability, + cum_seqlen_q, + cum_seqlen_kv, + q, + k, + v, + attn_bias, + softmaxstats, + o, + dropoutseed, + dropoutoffset, + handle); auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); - auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk); + auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); auto rag_stats_off = cum_seqlen_q.mul(h_q); - std::unordered_map, void*> - variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {attn_scale, &scaling_factor}, - {seed, dropoutseed.data_ptr()}, - {offset, dropoutoffset.data_ptr()}, - {O, o.data_ptr()}, - {RAG_Q_OFF, rag_q_off.data_ptr()}, - {RAG_O_OFF, rag_q_off.data_ptr()}, - {RAG_K_OFF, rag_k_off.data_ptr()}, - {RAG_V_OFF, rag_v_off.data_ptr()}, - {SEQ_LEN_Q, seqlen_q.data_ptr()}, - {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; + std::unordered_map variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {SCALE, &scaling_factor}, + {O, o.data_ptr()}, + {RAG_Q_OFF, rag_q_off.data_ptr()}, + {RAG_O_OFF, rag_q_off.data_ptr()}, + {RAG_K_OFF, rag_k_off.data_ptr()}, + {RAG_V_OFF, rag_v_off.data_ptr()}, + {SEQ_LEN_Q, seqlen_q.data_ptr()}, + {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; if (return_softmaxstats) { - variant_pack[Stats] = softmaxstats.data_ptr(); - variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr(); + variant_pack[LSE] = softmaxstats.data_ptr(); + variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr(); + } + if (dropout_probability != 0.0f) { + variant_pack[SEED] = dropoutseed.data_ptr(); + variant_pack[OFFSET] = dropoutoffset.data_ptr(); } if (attn_bias.has_value()) { TORCH_CHECK("bias not supported with nestedtensor"); @@ -1053,12 +1323,12 @@ void run_cudnn_SDP_bprop( dropout_probability, is_causal, true); - auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); - graph_and_tensors_backward graph_and_tensors_backward_values; - if (graph_and_tensors_backward_ptr) { - graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr; + auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key); + std::shared_ptr mha_graph; + if (graph_backward_ptr) { + mha_graph = *graph_backward_ptr; } else { - graph_and_tensors_backward_values = build_graph_and_tensors_backward( + mha_graph = build_graph_backward( b, h, s_q, @@ -1082,49 +1352,153 @@ void run_cudnn_SDP_bprop( _dropoutoffset, handle); } - auto - [mha_graph, - Q, - K, - V, - bias, - attn_scale, - Seed, - Offset, - O, - Do, - Stats, - Dq, - Dk, - Dv] = graph_and_tensors_backward_values; - std::unordered_map, void*> - variant_pack = {// inputs - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {O, o.data_ptr()}, - {Do, dO_.data_ptr()}, - {Stats, softmaxstats.data_ptr()}, - // outputs - {Dq, dQ.data_ptr()}, - {Dk, dK.data_ptr()}, - {Dv, dV.data_ptr()}, - // pass by value - {attn_scale, &scaling_factor}}; + std::unordered_map variant_pack = { + // inputs + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {O, o.data_ptr()}, + {DO, dO_.data_ptr()}, + {LSE, softmaxstats.data_ptr()}, + // outputs + {DQ, dQ.data_ptr()}, + {DK, dK.data_ptr()}, + {DV, dV.data_ptr()}, + {SCALE, &scaling_factor}}; if (dropout_probability != 0.0f) { - variant_pack[Seed] = _dropoutseed.data_ptr(); - variant_pack[Offset] = _dropoutoffset.data_ptr(); + variant_pack[SEED] = _dropoutseed.data_ptr(); + variant_pack[OFFSET] = _dropoutoffset.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[bias.value()] = attn_bias.value().data_ptr(); + variant_pack[BIAS] = attn_bias.value().data_ptr(); + } + auto workspace_size = mha_graph->get_workspace_size(); + auto workspace_ptr = + c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); + TORCH_CHECK(!workspace_size || workspace_ptr.get()); + TORCH_CHECK( + mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); + getMHAGraphBackwardCache_().update(key, mha_graph); +} + +void run_cudnn_SDP_bprop_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset) { + // do nothing if we got 0-element tensors + if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() || + !softmaxstats.numel()) { + return; } + + Tensor dO_ = dO; + const auto innermost_dO_stride = dO.strides()[dO.strides().size() - 1]; + if (innermost_dO_stride != 1) { + permute_to_matching_layout(o, dO_); + } + + auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); + auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); + auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); + auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); + auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); + auto rag_stats_off = cum_seqlen_q.mul(h_q); + + auto dprops = at::cuda::getCurrentDeviceProperties(); + auto _dropoutseed = dropoutseed; + auto _dropoutoffset = dropoutoffset; + // cuDNN dropout bug requires these to be in int64 + if (dprops->major == 10 && dprops->minor == 0) { + _dropoutseed = dropoutseed.to(kLong); + _dropoutoffset = dropoutoffset.to(kLong); + } + + cudnnHandle_t handle = getCudnnHandle(); + + auto mha_graph = build_graph_backward_nestedtensor( + b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + scaling_factor, + is_causal, + dropout_probability, + cum_seqlen_q, + cum_seqlen_kv, + q, + k, + v, + attn_bias, + o, + dO_, + softmaxstats, + dQ, + dK, + dV, + dropoutseed, + dropoutoffset, + handle); + + std::unordered_map variant_pack = { + // inputs + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {O, o.data_ptr()}, + {DO, dO_.data_ptr()}, + {LSE, softmaxstats.data_ptr()}, + // outputs + {DQ, dQ.data_ptr()}, + {DK, dK.data_ptr()}, + {DV, dV.data_ptr()}, + {SCALE, &scaling_factor}, + {RAG_Q_OFF, rag_q_off.data_ptr()}, + {RAG_O_OFF, rag_q_off.data_ptr()}, + {RAG_K_OFF, rag_k_off.data_ptr()}, + {RAG_V_OFF, rag_v_off.data_ptr()}, + {RAG_LSE_OFF, rag_stats_off.data_ptr()}, + {SEQ_LEN_Q, seqlen_q.data_ptr()}, + {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; + if (dropout_probability != 0.0f) { + variant_pack[SEED] = _dropoutseed.data_ptr(); + variant_pack[OFFSET] = _dropoutoffset.data_ptr(); + } + TORCH_CHECK( + !attn_bias.has_value(), + "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); + auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK(!workspace_size || workspace_ptr.get()); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - mhagraphbackwardcache.update(key, graph_and_tensors_backward_values); } } // namespace native diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 045e8cf6dee9d..620abc1aa0a8e 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -70,4 +70,31 @@ void run_cudnn_SDP_bprop( const Tensor& dropoutseed, const Tensor& dropoutoffset); +void run_cudnn_SDP_bprop_nestedtensor( + int64_t b, + int64_t h_q, + int64_t h_k, + int64_t h_v, + int64_t s_q, + int64_t s_kv, + int64_t d_qk, + int64_t d_v, + float scaling_factor, + bool is_causal, + float dropout_probability, + const Tensor& cum_seqlen_q, + const Tensor& cum_seqlen_kv, + const Tensor& q, + const Tensor& k, + const Tensor& v, + const std::optional& attn_bias, + const Tensor& o, + const Tensor& dO, + const Tensor& softmaxstats, + Tensor& dQ, + Tensor& dK, + Tensor& dV, + const Tensor& dropoutseed, + const Tensor& dropoutoffset); + } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9f3c7468a6af4..e7492f4c379af 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15013,6 +15013,7 @@ - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _scaled_dot_product_cudnn_attention_backward_cuda + NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda tags: nondeterministic_seeded - func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) @@ -15045,6 +15046,11 @@ CUDA: _cudnn_attention_forward tags: nondeterministic_seeded +- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _cudnn_attention_backward + tags: nondeterministic_seeded + - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 5b7476453407e..96c6ab8310f80 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -349,6 +349,63 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda( return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); } +std::tuple _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda( + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + const Tensor& philox_seed, + const Tensor& philox_offset, + const Tensor& attn_bias, + const Tensor& cum_seq_q, + const Tensor& cum_seq_k, + const int64_t max_q, + const int64_t max_k, + double dropout_p, + bool is_causal, + std::optional scale) { + if (!grad_out.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + auto [ + grad_out_buffer_reshaped, + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + output_buffer_reshaped] = + preprocessing::sdpa_nested_preprocessing_backward( + grad_out, + query, + key, + value, + out, + cum_seq_q, + cum_seq_k, + max_q, + max_k); + + auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped, + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + output_buffer_reshaped, + logsumexp, + philox_seed, + philox_offset, + attn_bias, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + scale); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); +} + + std::tuple _scaled_dot_product_flash_attention_backward_nested( const at::Tensor& grad_out_, const at::Tensor& query, diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 48899d4ce12fb..1a3e2825d4fa8 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -849,16 +849,6 @@ std::tuple #include +#include +#include #include #include #include @@ -184,7 +186,7 @@ std::tuple _flash_attention_backward( return std::make_tuple(Tensor(), Tensor(), Tensor()); } -std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( +std::tuple _cudnn_attention_backward( const Tensor& grad_out, const Tensor& query, const Tensor& key, @@ -211,57 +213,117 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ } } - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_qk = query.size(3); - const int64_t head_dim_v = value.size(3); + const bool is_nested = cum_seq_q.defined(); const int64_t max_seqlen_batch_q = query.size(2); const int64_t max_seqlen_batch_k = key.size(2); - // This is needed because SaveVariable automatically converts - // std::optional to undefined tensor - std::optional attn_bias_; - if (attn_bias.defined()) { - attn_bias_ = attn_bias; - } - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else { - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + if (!is_nested) { + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); + + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + } } - } - const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); - auto dq = at::empty_like(query); - auto dk = at::empty_like(key); - auto dv = at::empty_like(value); - run_cudnn_SDP_bprop(batch_size /*int64_t b*/, - num_heads /*int64_t h*/, - max_q/*int64_t s_q*/, - max_k/*int64_t s_kv*/, - head_dim_qk /*int64_t d_qk*/, - head_dim_v /*int64_t d_v*/, - softmax_scale /*float scaling_factor*/, - is_causal /*bool is_causal*/, - dropout_p /*float dropout_probability*/, - query /*const Tensor& q*/, - key /*const Tensor& k*/, - value /*const Tensor& v*/, - attn_bias_ /*const std::optional& attn_bias*/, - out /*const Tensor& o*/, - grad_out/*const Tensor& dO*/, - logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, - dq/*Tensor& dQ*/, - dk/*Tensor& dK*/, - dv/*Tensor& dV*/, - philox_seed/*Tensor& dropoutseed*/, - philox_offset/*Tensor& dropoutoffset*/); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); + const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + run_cudnn_SDP_bprop(batch_size /*int64_t b*/, + num_heads /*int64_t h*/, + max_q/*int64_t s_q*/, + max_k/*int64_t s_kv*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, + softmax_scale /*float scaling_factor*/, + is_causal /*bool is_causal*/, + dropout_p /*float dropout_probability*/, + query /*const Tensor& q*/, + key /*const Tensor& k*/, + value /*const Tensor& v*/, + attn_bias_ /*const std::optional& attn_bias*/, + out /*const Tensor& o*/, + grad_out/*const Tensor& dO*/, + logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, + dq/*Tensor& dQ*/, + dk/*Tensor& dK*/, + dv/*Tensor& dV*/, + philox_seed/*Tensor& dropoutseed*/, + philox_offset/*Tensor& dropoutoffset*/); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); + } else { + // BHSD ... + const int64_t batch_size = cum_seq_q.size(0) - 1; + const int64_t num_heads_q = query.size(-2); + const int64_t num_heads_k = key.size(-2); + const int64_t num_heads_v = value.size(-2); + const int64_t head_dim_qk = query.size(-1); + const int64_t head_dim_v = value.size(-1); + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + } + } + + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + run_cudnn_SDP_bprop_nestedtensor( + batch_size, + num_heads_q, + num_heads_k, + num_heads_v, + max_seqlen_batch_q, + max_seqlen_batch_k, + head_dim_qk, + head_dim_v, + softmax_scale, + is_causal, + dropout_p, + cum_seq_q, + cum_seq_k, + query, + key, + value, + attn_bias_, + out, + grad_out, + logsumexp, + dq, + dk, + dv, + philox_seed, + philox_offset); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); + } } std::tuple @@ -1063,4 +1125,40 @@ std::tuple _scaled_dot_product_e } } +std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + const Tensor& philox_seed, + const Tensor& philox_offset, + const Tensor& attn_bias, + const Tensor& cum_seq_q, + const Tensor& cum_seq_k, + const int64_t max_q, + const int64_t max_k, + double dropout_p, + bool is_causal, + std::optional scale) { + return at::_cudnn_attention_backward( + grad_out, + query, + key, + value, + out, + logsumexp, + philox_seed, + philox_offset, + attn_bias, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + scale); +} + } // namespace at::native diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 4b198f4d6d2de..4b85b2d28753a 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -57,21 +57,28 @@ namespace sdp { namespace { +// tracks whether we've set the default priority order once, to avoid setting +// it redundantly or overwriting a user-specified priority order +// when the priority order context manager is used before the default priority +// order is initialized the following happens: +// (1) the current priority order is queried +// (2) priority_order() is called, which initializes it to the default as init_ is false +// (3) the user-specified priority order is set +// (3.1) we are in the priority context... +// (3.2) we exit the priority context... +// (4) the previous priority order (default) is restored +bool priority_order_init_ = false; + // TODO(eqy): more benchmarking to determine whether this should include sm86/89 // Needs to be kept in-sync with test_fused_chocie in test_transformers.py bool check_prefer_cudnn_attention() { - // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0 - // see context: https://github.com/pytorch/pytorch/issues/138340 - // return false; -#if defined(CUDNN_VERSION) - -#if CUDNN_VERSION > 90000 + static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true; + if (!prefer_cudnn) { + return false; + } +#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000)) auto dprops = at::cuda::getCurrentDeviceProperties(); - return dprops->major >= 9; -#else - return false; -#endif - + return dprops->major >= 9 && !dprops->minor; #else return false; #endif @@ -79,6 +86,16 @@ bool check_prefer_cudnn_attention() { // flash_attention V2 is universally faster than efficient_attention and Math std::array priority_order(sdp_params const& params) { + if (!priority_order_init_) { + priority_order_init_ = true; + if (check_prefer_cudnn_attention()) { + const std::vector cudnn_order = {static_cast(at::SDPBackend::cudnn_attention), + static_cast(at::SDPBackend::flash_attention), + static_cast(at::SDPBackend::efficient_attention), + static_cast(at::SDPBackend::math)}; + at::globalContext().setSDPPriorityOrder(cudnn_order); + } + } return at::globalContext().sDPPriorityOrder(); } @@ -414,12 +431,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } auto head_dim_limit = 128; - if (cudnn_version >= 90501) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - if (dprops->major == 9 && !dprops->minor) { - head_dim_limit = 256; - } - } + // TODO(eqy): add head dim >= 256 cases once support is finalized if (d_qk > head_dim_limit || d_v > head_dim_limit) { if (debug) { TORCH_WARN("head_dim should be no more than ", head_dim_limit); @@ -453,9 +465,15 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } } - if (s_q == 1 || s_k == 1) { + if (s_k == 1) { + if (debug) { + TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1."); + } + return false; + } + if (s_q == 1 && params.dropout != 0.0) { if (debug) { - TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1."); + TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout."); } return false; } @@ -563,9 +581,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) { const auto dprop = at::cuda::getCurrentDeviceProperties(); // Check that the input is nested - if (dprop->major != 9 && has_for_nested_inputs(params)) { + if ((dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) { if (debug) { - TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0."); + TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0."); } return false; } @@ -589,7 +607,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { // sdp kernels if (!at::globalContext().userEnabledCuDNNSDP()) { if (debug) { - TORCH_WARN("CuDNN attention has been runtime disabled."); + TORCH_WARN("cuDNN attention has been runtime disabled."); } return false; } @@ -620,7 +638,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { #endif #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000 if (debug) { - TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)"); + TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)"); } return false; #endif @@ -630,10 +648,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { c10::array_of( check_runtime_disabled_cudnn, check_for_nested_inputs, - check_nonzero_sequence_lengths_dense, check_all_tensors_on_device, check_tensor_shapes, - check_cudnn_tensor_shapes, check_cudnn_deterministic, check_dtypes_low_precision, check_attn_mask_shape, @@ -646,8 +662,10 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { } constexpr auto dense_constraints = c10::array_of( + check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense, - check_batch_size_and_num_heads_dense + check_batch_size_and_num_heads_dense, + check_cudnn_tensor_shapes ); if (has_only_dense_inputs(params)) { @@ -864,7 +882,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) { sdp::can_use_mem_efficient_attention(kernel_params, print_debug); TORCH_WARN("Flash attention kernel not used because:"); sdp::can_use_flash_attention(kernel_params, print_debug); - TORCH_WARN("CuDNN attention kernel not used because:"); + TORCH_WARN("cuDNN attention kernel not used because:"); sdp::can_use_cudnn_attention(kernel_params, print_debug); TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return SDPBackend::error; diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 5c5795f45ce25..c650b102bf1a7 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -75,6 +75,7 @@ aten::_ctc_loss.out aten::_ctc_loss_backward aten::_ctc_loss_backward.Tensor aten::_ctc_loss_backward.out +aten::_cudnn_attention_backward aten::_cudnn_attention_forward aten::_cudnn_ctc_loss aten::_cudnn_ctc_loss.Tensor diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 6037bd4d794cd..00511c572239e 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -26,6 +26,7 @@ run_fw_bw_and_get_code, ) from torch.fx.experimental.proxy_tensor import make_fx +from torch.nn.attention import sdpa_kernel, SDPBackend from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -177,9 +178,10 @@ def test_effn_attn_bias_padding_misaligned(self): inputs = [q, k, v, mask] def f(q, k, v, mask): - return F.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0 - ) + with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): + return F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) f_compiled = torch.compile(f) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index a0c018c45d80f..f4473aacfb8bf 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -6760,11 +6760,10 @@ def check_forward_backward(skip_backward=False): and check_cudnn and (dtype == torch.float16 or dtype == torch.bfloat16) ): - with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"): - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.CUDNN_ATTENTION - ): - check_forward_backward() + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.CUDNN_ATTENTION + ): + check_forward_backward() @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") diff --git a/test/test_transformers.py b/test/test_transformers.py index 89db8d798c266..05a21569aeaca 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -49,7 +49,6 @@ PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_CUDNN_ATTENTION, - SM90OrLater, tf32_on_and_off, tf32_enabled, ) @@ -2657,6 +2656,7 @@ def test_cudnn_attention_gqa(self, device): @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + @unittest.expectedFailure # cuDNN currently doesn't support this on SM100+/fails graph validation def test_cudnn_attention_d256_heuristic(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) @@ -2667,7 +2667,7 @@ def test_cudnn_attention_d256_heuristic(self, device): v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) - with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True): + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION], set_priority=True): actual = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) actual.backward(torch.randn_like(actual)) @@ -2705,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device): @skipIfRocm # No cuDNN Attention - @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + @unittest.skipIf(True, "broken as of cuDNN 9.10") def test_cudnn_attention_fail_d128(self, device): # Test that cuDNN attention dispatching correctly bails out on d > 128 b, h = 1, 2 @@ -2720,7 +2720,6 @@ def test_cudnn_attention_fail_d128(self, device): ISSM90 = device_cap == (9, 0) ISSM100 = device_cap == (10, 0) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): - # SM90/100 support d <= 256 as of cuDNN 9.5.1+ if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501: torch.nn.functional.scaled_dot_product_attention(q, k, v) else: @@ -3156,15 +3155,19 @@ def test_fused_sdp_choice(self, device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + device_capability = None + if "cuda" in str(device): + device_capability = torch.cuda.get_device_capability() + prefer_cudnn = "TORCH_CUDNN_SDPA_PREFERRED" in os.environ + prefer_cudnn = prefer_cudnn and device_capability and (device_capability == (9, 0) or device_capability == (10, 0)) + # TODO we are currently disabling this by default, lets assert that this returns # FlashAttention, we need to change when we make remove opt-in for cudnn - if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater: - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) - with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) + if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn: + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) elif PLATFORM_SUPPORTS_FLASH_ATTENTION: self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) - elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows + elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn: # e.g., we're on Windows self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index a778c1a85da09..c050c6cbdc4c3 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2904,6 +2904,10 @@ output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) +- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) + - name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) From 9e07673deb212c87b1c6fea23799a97474c476ed Mon Sep 17 00:00:00 2001 From: Kanya-Mo <167922169+Kanya-Mo@users.noreply.github.com> Date: Fri, 8 Aug 2025 22:36:42 +0000 Subject: [PATCH 0166/1424] Fix test_fsdp_ep.py due to _MeshEnv API change (#158695) #132339 changed parent/child mesh related APIs from _MeshEnv. UT TestFSDPWithEP.test_e2e still uses old APIs and will fail: ``` File "/home/kanya/pytorch/test/distributed/checkpoint/e2e/test_fsdp_ep.py", line 77, in test_e2e mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, ("dp",)) AttributeError: '_MeshEnv' object has no attribute 'create_child_mesh' To execute this test, run the following from the base repo dir: python test/distributed/checkpoint/e2e/test_fsdp_ep.py TestFSDPWithEP.test_e2e This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0. Did you mean: 'create_sub_mesh'? ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158695 Approved by: https://github.com/Skylion007, https://github.com/nWEIdia --- test/distributed/checkpoint/e2e/test_fsdp_ep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distributed/checkpoint/e2e/test_fsdp_ep.py b/test/distributed/checkpoint/e2e/test_fsdp_ep.py index 7489317035b99..51d4b3e995372 100644 --- a/test/distributed/checkpoint/e2e/test_fsdp_ep.py +++ b/test/distributed/checkpoint/e2e/test_fsdp_ep.py @@ -73,8 +73,8 @@ def test_e2e(self): self.device_type, (2, 4), mesh_dim_names=("dp", "tp") ) # TODO: we are using an internal API atm. Change to a public API once it is ready. - mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, ("dp",)) - del _mesh_resources.child_to_parent_mapping[mesh_fsdp_ep] + mesh_fsdp_ep = _mesh_resources.create_sub_mesh(mesh_fsdp_tp, ("dp",), [(0,)]) + del _mesh_resources.child_to_root_mapping[mesh_fsdp_ep] mesh_fsdp = init_device_mesh(self.device_type, (8,)) for i, l in enumerate(model.second.ep_layers): From 4e2ddb5db67617f9f5309c8bba0c17adc84cadbc Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Fri, 8 Aug 2025 22:56:01 +0000 Subject: [PATCH 0167/1424] [Inductor][CUTLASS] Copy cutlass_mock_imports directory (#159724) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pip wheels of PyTorch nightly and 2.8 release candidates do not contain `cutlass_mock_imports`. This is the path to the source code: ``` root@8120d02fd9c5:$ tree ./torch/_inductor/codegen/cuda/cutlass_lib_extensions/ ./torch/_inductor/codegen/cuda/cutlass_lib_extensions/ ├── cutlass_mock_imports │   ├── cuda │   │   ├── __init__.py │   │   ├── cuda.py │   │   └── cudart.py │   ├── pydot │   │   └── __init__.py │   └── scipy │   ├── __init__.py │   └── special.py ├── evt_extensions.py └── gemm_operation_extensions.py 5 directories, 8 files ``` And this what installed wheel has: ``` root@8120d02fd9c5:$ tree /usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/ /usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/ ├── __init__.py ├── evt_extensions.py └── gemm_operation_extensions.py 1 directory, 3 files ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159724 Approved by: https://github.com/henrylhtsang --- test/inductor/test_cutlass_backend.py | 13 +++++++++++++ .../cutlass_mock_imports/__init__.py | 0 2 files changed, 13 insertions(+) create mode 100644 torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index c29dff73f9a1e..5889adb120ffa 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -200,6 +200,19 @@ def run_evt_test(self, model, op, shape, num_fusions=1): ) torch.testing.assert_close(result, ref_result) + def test_check_paths(self): + cutlass_mock_imports_path = os.path.join( + os.path.dirname(torch.__file__), + "_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports", + ) + cutlass_mock_cuda_path = os.path.join(cutlass_mock_imports_path, "cuda") + cutlass_mock_pydot_path = os.path.join(cutlass_mock_imports_path, "pydot") + cutlass_mock_scipy_path = os.path.join(cutlass_mock_imports_path, "scipy") + self.assertTrue(os.path.exists(cutlass_mock_imports_path)) + self.assertTrue(os.path.exists(cutlass_mock_cuda_path)) + self.assertTrue(os.path.exists(cutlass_mock_pydot_path)) + self.assertTrue(os.path.exists(cutlass_mock_scipy_path)) + @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_threshold(self): diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From 566c6d52ef1411c8262d7b9cf85e2044fdfbe1a3 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 8 Aug 2025 23:09:30 +0000 Subject: [PATCH 0168/1424] [ONNX] Fix the export of the model having none as output (#160200) Fixes #160150 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160200 Approved by: https://github.com/justinchuby Co-authored-by: Justin Chu --- test/onnx/exporter/test_api.py | 12 ++++++++++++ torch/onnx/_internal/exporter/_core.py | 6 ++++++ torch/onnx/_internal/exporter/_testing.py | 3 +++ 3 files changed, 21 insertions(+) diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 9a8a171b5fe29..593cc524ebe7e 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -600,6 +600,18 @@ def test_torchscript_exporter_raises_deprecation_warning(self): SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False ) + def test_model_output_can_be_none(self): + class ModelWithNoneOutput(torch.nn.Module): + def forward(self, x): + return x + 1, None + + onnx_program = torch.onnx.export( + ModelWithNoneOutput(), + (torch.randn(1, 1, 2),), + dynamo=True, + ) + onnx_testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index a4e3eea2e1d28..85aa513c6d023 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -726,6 +726,12 @@ def _handle_output_node( # node.args[0] can be a tuple with more than one elements. This happens when, # for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs for output in node.args[0]: # type: ignore[index,union-attr] + if output is None: + logger.warning( + "Output node %s has None output. The output is ignored in the exported graph. Please ensure the graph output order is expected", + node.name, + ) + continue output_value_name = output.name # type: ignore[union-attr] assert isinstance(output_value_name, str), ( f"Bug: Expected {output_value_name!r} to be a string" diff --git a/torch/onnx/_internal/exporter/_testing.py b/torch/onnx/_internal/exporter/_testing.py index 58f18d0cc923c..c34c2f1a38c3d 100644 --- a/torch/onnx/_internal/exporter/_testing.py +++ b/torch/onnx/_internal/exporter/_testing.py @@ -71,6 +71,9 @@ class names like "TorchExportNonStrictStrategy". # ONNX outputs are always real, so we need to convert torch complex outputs to real representations torch_outputs_adapted = [] for output in torch_outputs: + # ONNX graph does not support None outputs, so we skip them + if output is None: + continue if not isinstance(output, torch.Tensor): torch_outputs_adapted.append(torch.tensor(output)) elif torch.is_complex(output): From 731ee31f7b6ba19307daab323f6196172b71aaf8 Mon Sep 17 00:00:00 2001 From: "Yanan Cao (PyTorch)" Date: Fri, 8 Aug 2025 23:14:13 +0000 Subject: [PATCH 0169/1424] [TorchScript, PT2] Add torch._check compatibility support (#159988) Summary: Add support for torch._check() in TorchScript jit.script frontend. * It will be special cased to behave like torch._assert, turned into an if + raise exception. Test Plan: Unit tests Rollback Plan: Differential Revision: D79744604 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159988 Approved by: https://github.com/davidberard98 --- test/jit/test_builtins.py | 158 ++++++++++++++++++ torch/csrc/jit/frontend/sugared_value.cpp | 82 ++++++++- torch/csrc/jit/frontend/sugared_value.h | 18 +- .../csrc/jit/python/python_sugared_value.cpp | 2 + torch/fx/passes/runtime_assert.py | 7 +- 5 files changed, 260 insertions(+), 7 deletions(-) diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index b84bc96519cbc..781080f5deb60 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -131,6 +131,164 @@ def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6}) self.assertEqual(py_out, jit_out) + def test_torch_check(self): + """Test torch._check functionality with flexible argument handling""" + + def test_check_basic(x): + torch._check(x.sum().item() > -1000) + return x + + def test_check_with_message(x): + torch._check(x.sum().item() > -1000, "Tensor sum must be reasonable") + return x + + def test_check_with_kwarg_message(x): + torch._check( + x.sum().item() > -1000, message="Tensor sum must be reasonable" + ) + return x + + def test_check_cond_kwarg(x): + torch._check(cond=x.sum().item() > -1000) + return x + + def test_check_both_kwargs(x): + torch._check(cond=x.sum().item() > -1000, message="Both as kwargs") + return x + + def test_check_kwargs_reversed(x): + torch._check(message="Reversed order", cond=x.sum().item() > -1000) + return x + + def test_check_in_loop(x): + sizes = torch.jit.annotate(List[int], x.tolist()) + for s in sizes: + torch._check(s > -100) + return x + + test_tensor = torch.tensor([1, 2, 3]) + + # Test all variations + self.checkScript(test_check_basic, (test_tensor,)) + self.checkScript(test_check_with_message, (test_tensor,)) + self.checkScript(test_check_with_kwarg_message, (test_tensor,)) + self.checkScript(test_check_cond_kwarg, (test_tensor,)) + self.checkScript(test_check_both_kwargs, (test_tensor,)) + self.checkScript(test_check_kwargs_reversed, (test_tensor,)) + self.checkScript(test_check_in_loop, (test_tensor,)) + + # Test that the compiled functions work correctly + scripted_basic = torch.jit.script(test_check_basic) + scripted_with_message = torch.jit.script(test_check_with_message) + scripted_with_kwarg = torch.jit.script(test_check_with_kwarg_message) + scripted_cond_kwarg = torch.jit.script(test_check_cond_kwarg) + scripted_both_kwargs = torch.jit.script(test_check_both_kwargs) + scripted_kwargs_reversed = torch.jit.script(test_check_kwargs_reversed) + scripted_in_loop = torch.jit.script(test_check_in_loop) + + # These should all succeed without throwing + result1 = scripted_basic(test_tensor) + result2 = scripted_with_message(test_tensor) + result3 = scripted_with_kwarg(test_tensor) + result4 = scripted_cond_kwarg(test_tensor) + result5 = scripted_both_kwargs(test_tensor) + result6 = scripted_kwargs_reversed(test_tensor) + result7 = scripted_in_loop(test_tensor) + + # Results should be the same as input + for result in [result1, result2, result3, result4, result5, result6, result7]: + self.assertEqual(result, test_tensor) + + # Check that the message constants are present in the graphs + FileCheck().check("Tensor sum must be reasonable").run( + scripted_with_message.graph + ) + FileCheck().check("Tensor sum must be reasonable").run( + scripted_with_kwarg.graph + ) + FileCheck().check("Both as kwargs").run(scripted_both_kwargs.graph) + FileCheck().check("Reversed order").run(scripted_kwargs_reversed.graph) + + # Verify the graphs contain some computation (not just empty) + basic_graph_str = str(scripted_basic.graph) + self.assertTrue( + len(basic_graph_str) > 100, "Basic graph should contain some computation" + ) + + # Verify the loop case contains a loop + FileCheck().check("prim::Loop").run(scripted_in_loop.graph) + + for scripted_func in [ + scripted_basic, + scripted_with_message, + scripted_with_kwarg, + scripted_cond_kwarg, + scripted_both_kwargs, + scripted_kwargs_reversed, + ]: + FileCheck().check("prim::If").check("prim::RaiseException").run( + scripted_func.graph + ) + + def test_torch_check_invalid_args(self): + """Test torch._check with invalid arguments""" + + # Test too many arguments + with self.assertRaisesRegex( + RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments" + ): + + @torch.jit.script + def too_many_args(x): + torch._check(True, "msg", "extra") + return x + + # Test invalid keyword argument + with self.assertRaisesRegex(RuntimeError, "unexpected keyword argument"): + + @torch.jit.script + def invalid_kwarg(x): + torch._check(True, invalid_arg="msg") + return x + + # Test duplicate cond argument (positional + keyword) + with self.assertRaisesRegex( + RuntimeError, "multiple values for argument 'cond'" + ): + + @torch.jit.script + def duplicate_cond(x): + torch._check(True, cond=False) + return x + + # Test missing required cond argument + with self.assertRaisesRegex(RuntimeError, "missing required argument 'cond'"): + + @torch.jit.script + def missing_cond(x): + torch._check(message="msg only") + return x + + # Test no arguments at all + with self.assertRaisesRegex( + RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments" + ): + + @torch.jit.script + def no_args(x): + torch._check() + return x + + # Test too many total arguments (positional + keyword) + with self.assertRaisesRegex( + RuntimeError, "torch._check\\(\\) expects 1 or 2 arguments" + ): + + @torch.jit.script + def too_many_total_args(x): + torch._check(True, "msg", cond=False) + return x + class TestTensorBuiltins(JitTestCase): def test_tensor_properties(self): diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 5f1a3e798bf93..0e9f0c9c2178c 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -359,8 +359,8 @@ void SimpleValue::setAttr( throw( ErrorReport(loc) << "Assignment to attribute '" << field - << "' cannot be of a type that contains class " - << "'" << classType->repr_str() << "'.\n" + << "' cannot be of a type that contains class " << "'" + << classType->repr_str() << "'.\n" << "Classes that recursively contain instances of themselves" << " are not yet supported"); } @@ -826,4 +826,82 @@ SugaredValuePtr SugaredEnumClass::iter( return enum_values_list_constant; } +std::shared_ptr TorchCheckValue::call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) { + if (args.size() + kwargs.size() < 1 || args.size() + kwargs.size() > 2) { + throw( + ErrorReport(loc) << "torch._check() expects 1 or 2 arguments, got " + << (args.size() + kwargs.size())); + } + + NamedValue* cond_arg = nullptr; + NamedValue* message_arg = nullptr; + bool found_cond_kwarg = false; + bool found_message_kwarg = false; + + for (const auto& kwarg : kwargs) { + if (kwarg.name() == "cond") { + if (found_cond_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'cond'"); + } + cond_arg = const_cast(&kwarg); + found_cond_kwarg = true; + } else if (kwarg.name() == "message") { + if (found_message_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'message'"); + } + message_arg = const_cast(&kwarg); + found_message_kwarg = true; + } else { + throw( + ErrorReport(loc) << "torch._check() got unexpected keyword argument '" + << kwarg.name() << "'"); + } + } + + if (args.size() >= 1) { + if (found_cond_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'cond'"); + } + cond_arg = const_cast(&args[0]); + } + + if (args.size() >= 2) { + if (found_message_kwarg) { + throw( + ErrorReport(loc) + << "torch._check() got multiple values for argument 'message'"); + } + message_arg = const_cast(&args[1]); + } + + if (!cond_arg) { + throw( + ErrorReport(loc) << "torch._check() missing required argument 'cond'"); + } + + std::vector assert_args; + assert_args.push_back(*cond_arg); + + if (message_arg) { + assert_args.push_back(*message_arg); + } else { + Value* default_msg = insertConstant(*m.graph(), std::string(""), loc); + assert_args.emplace_back(loc, "message", default_msg); + } + + emitBuiltinCall(loc, *m.graph(), Symbol::aten("_assert"), assert_args, {}); + return std::make_shared(); +} + } // namespace torch::jit diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index d88e77b16cd1b..59ddea774d5d1 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -136,8 +136,7 @@ struct TORCH_API SugaredValue // Value * virtual Value* len(const SourceRange& loc, GraphFunction& m) { throw( - ErrorReport(loc) << "'" << kind() << "'" - << " object is not iterable"); + ErrorReport(loc) << "'" << kind() << "'" << " object is not iterable"); } // expression for ith element for iterable value @@ -858,4 +857,19 @@ struct TORCH_API SliceValue : public SugaredValue { Value* step_; }; +struct TORCH_API TorchCheckValue : public SugaredValue { + explicit TorchCheckValue() = default; + + std::string kind() const override { + return "torch._check sugared value"; + } + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; +}; + } // namespace torch::jit diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index b9db0be814e45..8b16e089aa50e 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1222,6 +1222,8 @@ std::shared_ptr toSugaredValue( } else if ( obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) { return SpecialFormValue::create(prim::isinstance); + } else if (obj.ptr() == py::module::import("torch").attr("_check").ptr()) { + return std::make_shared(); #ifdef USE_RPC // RPC module is only available when build flag "USE_DISTRIBUTED" is on. } else if ( diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index bb71a25971da7..19e101a5c120a 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -337,12 +337,13 @@ def match_symbol(symint, cb): torch._check, torch.ops.aten._assert_scalar.default, ): + cond = node.args[0] if node.args else node.kwargs.get("cond") if ( - node.args[0] == True # noqa: E712 - or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy + cond == True # noqa: E712 + or (assert_expr := _get_sym_val(cond)) in expr_to_proxy and assert_expr in added_asserts ): - arg = node.args[0] + arg = cond gm.graph.erase_node(node) if isinstance(arg, fx.Node) and not arg.users: gm.graph.erase_node(arg) From 8c41cb800ae0411f02ea5da34bd5ccc3790633b0 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 8 Aug 2025 15:11:14 -0700 Subject: [PATCH 0170/1424] [MPS][BE] Combine all pre-MacOS14 xfail lists (#160228) It does not matter whether it started to fail after 13.1 or 13.3, fact that it still fails on latest MacOS Pull Request resolved: https://github.com/pytorch/pytorch/pull/160228 Approved by: https://github.com/dcci --- torch/testing/_internal/common_mps.py | 151 ++++++-------------------- 1 file changed, 32 insertions(+), 119 deletions(-) diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index fbfa5e2c9f9fb..2aefcce61b73c 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -286,85 +286,6 @@ def mps_ops_modifier( "where", "byte", } - # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758 - MACOS_BEFORE_13_3_XFAILLIST = { - # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ - "cdist": [torch.float32], - # CPU Error: cpu not giving nan for x/0.0 - "atan2": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - # test blow pass on macOS 12 as it falls back to cpu - # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') - # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') - # Elements from index 30 and 5133 are both equal. - # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. - "argsort": [torch.float16, torch.int8, torch.uint8, torch.bool], - # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. - # The values of the sorted tensor match the CPU, - # but in case of the returned indices this results in undefined behaviour. - "sort": [torch.int8, torch.uint8, torch.bool, torch.float16], - # Unsupported dtypes - "cumsum": [torch.int64], - "cumprod": [torch.int64], - "cumulative_trapezoid": [torch.int64], - "masked.cumsum": [torch.int64], - "masked.cumprod": [torch.int64], - "linalg.vander": [torch.int64], - # Fail with `Expected 1.0 but got nan.` for empty tensors - # Caused by sample input at index 23: SampleInput( - # input=Tensor[size=(), device="mps:0", dtype=torch.float32], - # args=(0), - # kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'}, - # broadcasts_input=False, name='') - "masked.softmin": [torch.float32, torch.float16], - "masked.softmax": [torch.float32, torch.float16], - "masked.log_softmax": [torch.float32, torch.float16], - } - - MACOS_AFTER_13_1_XFAILLIST = { - # before macOS 13.2 it falls back to cpu and pass the forward pass - "grid_sampler_2d": [ - torch.float32, - torch.float16, - torch.bfloat16, - ], # Unsupported Border padding mode - } - - MACOS_13_3_XFAILLIST = { - # Failure due to precision issue for fp16 - # on both cpu and mps there are test cases that might produce inf result - # 'nn.functional.pairwise_distance': [torch.float16], - # test blow pass on macOS 12 as it falls back to cpu - # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') - # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') - # Elements from index 30 and 5133 are both equal. - # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. - "argsort": [ - torch.float16, - torch.int8, - torch.uint8, - torch.bool, - torch.bfloat16, - ], - # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. - # The values of the sorted tensor match the CPU, - # but in case of the returned indices this results in undefined behaviour. - "sort": [ - torch.int8, - torch.uint8, - torch.bool, - torch.float16, - torch.bfloat16, - ], - } MACOS_BEFORE_14_4_XFAILLIST = { # These ops work fine in 14.4 but fail in 14.2 or 13.x @@ -497,7 +418,6 @@ def mps_ops_modifier( torch.float16, ], # Unsupported dtypes - "dot": [torch.int64] if MACOS_VERSION < 14.0 else [], "histc": [torch.float16, torch.bfloat16], "index_add": [torch.int64], # GEMM on MPS is not supported for integral types @@ -519,8 +439,6 @@ def mps_ops_modifier( "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - "matmul": [torch.int64] if MACOS_VERSION < 14.0 else [], - "__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [], # returned output on CPU is float64 "bincount": [ torch.int16, @@ -625,6 +543,38 @@ def mps_ops_modifier( "linalg.matrix_rank": None, # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")` "arange": [torch.uint8], + # before macOS 13.2 it falls back to cpu and pass the forward pass + "grid_sampler_2d": [ + torch.float32, + torch.float16, + torch.bfloat16, + ], # Unsupported Border padding mode + # Failure due to precision issue for fp16 + # on both cpu and mps there are test cases that might produce inf result + # 'nn.functional.pairwise_distance': [torch.float16], + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + "argsort": [ + torch.float16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + ], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, + # but in case of the returned indices this results in undefined behaviour. + "sort": [ + torch.int8, + torch.uint8, + torch.bool, + torch.float16, + torch.bfloat16, + ], } EMPTY_OPS_SKIPLIST = { @@ -692,43 +642,6 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: ), ) - if ( - key in MACOS_BEFORE_13_3_XFAILLIST - and key not in xfail_exclusion - and (torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, - dtypes=MACOS_BEFORE_13_3_XFAILLIST[key], - ), - ) - - if ( - key in MACOS_AFTER_13_1_XFAILLIST - and key not in xfail_exclusion - and torch.backends.mps.is_macos13_or_newer(2) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, dtypes=MACOS_AFTER_13_1_XFAILLIST[key] - ), - ) - - if ( - key in MACOS_13_3_XFAILLIST - and key not in xfail_exclusion - and (MACOS_VERSION >= 13.3) - ): - addDecorator( - op, - DecorateInfo( - unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST[key] - ), - ) - # If ops is not supported for complex types, expect it to fail if key not in SUPPORTED_COMPLEX_OPS: addDecorator( From 9b803cdbe298009f08340c1aaccb25aafbca95d8 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 8 Aug 2025 21:30:05 +0000 Subject: [PATCH 0171/1424] [BE] Remove more optim entries from docs coverage ignore list (#160194) This PR does privatize ReduceLRSchedulerOnPlateau.is_better -> ReduceLRSchedulerOnPlateau._is_better because that API was never meant to be public. A GitHub search for it also reveals that the API is not commonly used much. https://github.com/search?q=.is_better%28&type=code&p=2 If you do use this API and you rely on it for some reason, please file an issue. In the meantime, you can access it through `_is_better(...)`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160194 Approved by: https://github.com/albanD, https://github.com/Skylion007 --- docs/source/conf.py | 31 ------------------------------- torch/optim/lr_scheduler.py | 13 +++++++++++-- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 07a44318ff726..4f47652e88d2d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1793,12 +1793,6 @@ # torch.optim.optimizer "register_optimizer_step_post_hook", "register_optimizer_step_pre_hook", - # torch.optim.swa_utils - "get_ema_avg_fn", - "get_ema_multi_avg_fn", - "get_swa_avg_fn", - "get_swa_multi_avg_fn", - "update_bn", # torch.overrides "enable_reentrant_dispatch", # torch.package.analyze.find_first_use_of_broken_modules @@ -2909,31 +2903,6 @@ # torch.onnx.verification "OnnxBackend", "OnnxTestCaseRepro", - # torch.optim.adamax - "Adamax", - # torch.optim.adamw - "AdamW", - # torch.optim.asgd - "ASGD", - # torch.optim.lbfgs - "LBFGS", - # torch.optim.lr_scheduler - "ChainedScheduler", - "ConstantLR", - "CosineAnnealingLR", - "CosineAnnealingWarmRestarts", - "CyclicLR", - "ExponentialLR", - "LRScheduler", - "LambdaLR", - "LinearLR", - "MultiStepLR", - "MultiplicativeLR", - "OneCycleLR", - "PolynomialLR", - "ReduceLROnPlateau", - "SequentialLR", - "StepLR", # torch.optim.optimizer "Optimizer", # torch.overrides diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 6f9f6f1a3cf0c..58ad582bebb91 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1344,7 +1344,7 @@ def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[over warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) self.last_epoch = epoch - if self.is_better(current, self.best): + if self._is_better(current, self.best): self.best = current self.num_bad_epochs = 0 else: @@ -1386,7 +1386,7 @@ def _reduce_lr(self, epoch): def in_cooldown(self): # noqa: D102 return self.cooldown_counter > 0 - def is_better(self, a, best): # noqa: D102 + def _is_better(self, a, best): # noqa: D102 if self.mode == "min" and self.threshold_mode == "rel": rel_epsilon = 1.0 - self.threshold return a < best * rel_epsilon @@ -1686,6 +1686,15 @@ def get_lr(self) -> list[float]: @override def state_dict(self) -> dict[str, Any]: # noqa: D102 + """Return the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + """ state = super().state_dict() # We are dropping the `_scale_fn_ref` attribute because it is a # `weakref.WeakMethod` and can't be pickled. From e96c7c4bb0f6aeae2ab3b6f040f7d67edbec199a Mon Sep 17 00:00:00 2001 From: Ankita George Date: Fri, 8 Aug 2025 11:17:49 -0700 Subject: [PATCH 0172/1424] [dcp][hf] Improve HF consolidation algorithm (#158648) Before we had a bunch of if-else cases based on sharding strategy to decide how to save the tensor with different logic for different strategies. This can be consolidated into one function that uses an algorithm to handle all cases by finding the max possible contiguous bytes that can be written Differential Revision: [D78489438](https://our.internmc.facebook.com/intern/diff/D78489438/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158648 Approved by: https://github.com/saumishr --- .../test_consolidate_hf_safetensors.py | 71 +++ .../checkpoint/_consolidate_hf_safetensors.py | 470 +++++------------- 2 files changed, 191 insertions(+), 350 deletions(-) diff --git a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py index ba07c62728d71..ad74c34c4e2ef 100644 --- a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py +++ b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py @@ -8,6 +8,7 @@ import torch.distributed.checkpoint as dist_cp from torch import distributed as dist from torch.distributed.checkpoint._consolidate_hf_safetensors import ( + _calculate_max_contiguous_elements, consolidate_safetensors_files, ) from torch.distributed.checkpoint._hf_utils import _metadata_fn @@ -153,6 +154,76 @@ def test_consolidate_to_two_files(self): ) dist.barrier() + def test_calculate_max_contiguous_elements_validations(self) -> None: + """Test validation logic in _calculate_max_contiguous_elements function.""" + + # Test empty lists validation + with self.assertRaisesRegex(ValueError, "Input lists cannot be empty"): + _calculate_max_contiguous_elements([], [2, 3], [4, 5]) + + # Test mismatched list lengths validation + with self.assertRaisesRegex( + ValueError, "All input lists must have the same length" + ): + _calculate_max_contiguous_elements([1], [2, 3], [4, 5]) + + # Test indices out of bounds validation + with self.assertRaisesRegex( + ValueError, "Index .* at dimension .* is out of bounds for sub-tensor shape" + ): + _calculate_max_contiguous_elements( + [2, 1], [2, 3], [4, 5] + ) # indices[0] >= sub_tensor_shape[0] + + # Test sub-tensor dimensions exceeding tensor dimensions validation + with self.assertRaisesRegex( + ValueError, + "Sub-tensor dimension .* at position .* exceeds tensor dimension", + ): + _calculate_max_contiguous_elements( + [1, 2], [2, 6], [4, 5] + ) # sub_tensor_shape[1] > tensor_shape[1] + + def test_calculate_max_contiguous_elements_valid_cases(self) -> None: + """Test valid cases for _calculate_max_contiguous_elements function.""" + + # Test 1D case - simple remaining elements + result = _calculate_max_contiguous_elements([2], [5], [10]) + self.assertEqual(result, 3) # 5 - 2 = 3 elements remaining + + # Test 2D case - at start of row, can write complete rows + result = _calculate_max_contiguous_elements([1, 0], [3, 4], [6, 4]) + self.assertEqual(result, 8) # 2 rows * 4 columns = 8 elements + + # Test 2D case - middle of row, only remaining in current row + result = _calculate_max_contiguous_elements([1, 2], [3, 4], [6, 8]) + self.assertEqual(result, 2) # 4 - 2 = 2 elements remaining in row + + # Test 3D case - at start of 2D slice, can write complete slices + result = _calculate_max_contiguous_elements([1, 0, 0], [3, 2, 4], [5, 2, 4]) + self.assertEqual(result, 16) # 2 slices * 2 rows * 4 columns = 16 elements + + # Test edge case - at last position + result = _calculate_max_contiguous_elements([2, 3], [3, 4], [6, 8]) + self.assertEqual(result, 1) # Only 1 element remaining + + # Test case where sub-tensor spans full width + result = _calculate_max_contiguous_elements([0, 0], [2, 5], [4, 5]) + self.assertEqual(result, 10) # 2 rows * 5 columns = 10 elements + + # Test column-wise sharded case - sub-tensor doesn't span full width + # Even at start of row, can only write width of one row due to column sharding + result = _calculate_max_contiguous_elements([1, 0], [3, 2], [4, 8]) + self.assertEqual( + result, 2 + ) # Only 2 elements (width of sub-tensor) can be written contiguously + + # Test another column-wise sharded case - middle of tensor + result = _calculate_max_contiguous_elements([0, 0], [2, 3], [6, 10]) + self.assertEqual( + result, 3 + ) # Only 3 elements (width of sub-tensor) can be written contiguously + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index dc988e999c4ed..8577180e9f893 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -358,189 +358,6 @@ def _write_data( raise -def _write_row_wise_tensor( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - full_tensor_strides: list[int], - sub_tensor_strides: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Writes a row-wise sharded tensor to the output file. - - This is an optimized path for tensors that are sharded along the first dimension, - with all other dimensions being complete. This allows writing entire rows at once. - - Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - full_tensor_strides: Strides of the full tensor - sub_tensor_strides: Strides of the sub-tensor - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor - sub_tensor_shape: The shape of the sub-tensor - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file - """ - # Open the output file in read+binary mode to allow seeking and writing - with fs.open(output_file_path, "r+b") as out_f: - # Calculate the number of elements in each row - elements_per_row = full_tensor_strides[ - 0 - ] # This is the stride of the first dimension - - # For each row in the sub-tensor - for row_idx in range(sub_tensor_shape[0]): - # Calculate the row index in the full tensor - full_row_idx = sub_tensor_offsets[0] + row_idx - - # Calculate the position in the full tensor - full_pos = full_row_idx * full_tensor_strides[0] - full_byte_offset = output_start_byte + full_pos * element_size - - # Calculate the position in the sub-tensor - sub_pos = row_idx * sub_tensor_strides[0] - sub_byte_offset = sub_pos * element_size - - # Extract the row data from the sub-tensor - row_size = elements_per_row * element_size - row_data = sub_tensor_bytes[sub_byte_offset : sub_byte_offset + row_size] - - # Seek to the correct position in the output file and write the data - out_f.seek(full_byte_offset) - out_f.write(row_data) - - -def _write_column_wise_tensor( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - tensor_shape: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Writes a column-wise sharded 2D tensor to the output file. - - This is an optimized path for 2D tensors that are sharded along the second dimension, - with the first dimension being complete. This requires writing column by column. - - Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - tensor_shape: The shape of the overall tensor - sub_tensor_strides: Strides of the sub-tensor - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor - sub_tensor_shape: The shape of the sub-tensor - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file - """ - # Open the output file in read+binary mode to allow seeking and writing - with fs.open(output_file_path, "r+b") as out_f: - # For each column in the sub-tensor - for col_idx in range(sub_tensor_shape[1]): - # Calculate the column index in the full tensor - full_col_idx = sub_tensor_offsets[1] + col_idx - - # For each row in the column - for row_idx in range(sub_tensor_shape[0]): - # Calculate the position in the full tensor - full_pos = row_idx * tensor_shape[1] + full_col_idx - full_byte_offset = output_start_byte + full_pos * element_size - - # Calculate the position in the sub-tensor - sub_pos = row_idx * sub_tensor_shape[1] + col_idx - sub_byte_offset = sub_pos * element_size - - # Extract the element data from the sub-tensor - element_data = sub_tensor_bytes[ - sub_byte_offset : sub_byte_offset + element_size - ] - - # Seek to the correct position in the output file and write the data - out_f.seek(full_byte_offset) - out_f.write(element_data) - - -def _write_element_by_element( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - tensor_shape: list[int], - full_tensor_strides: list[int], - sub_tensor_strides: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Writes a sub-tensor to the output file using a general element-by-element approach. - - This is a general approach that works for any sharding pattern, but is less efficient - than the specialized approaches for row-wise or column-wise sharding. - - Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - tensor_shape: The shape of the overall tensor - full_tensor_strides: Strides of the full tensor - sub_tensor_strides: Strides of the sub-tensor - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor - sub_tensor_shape: The shape of the sub-tensor - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file - """ - # Open the output file in read+binary mode to allow seeking and writing - with fs.open(output_file_path, "r+b") as out_f: - # Create a list to hold the current indices for each dimension - indices = [0] * len(tensor_shape) - - # Calculate the total number of elements in the sub-tensor - total_elements = 1 - for dim_size in sub_tensor_shape: - total_elements *= dim_size - - # Process each element in the sub-tensor - for element_idx in range(total_elements): - # Calculate the indices for this element in the sub-tensor - sub_idx = element_idx - for dim in range(len(sub_tensor_shape) - 1, -1, -1): - indices[dim] = sub_idx % sub_tensor_shape[dim] - sub_idx //= sub_tensor_shape[dim] - - # Calculate the position of this element in the sub-tensor's byte array - sub_pos = 0 - for dim in range(len(sub_tensor_shape)): - sub_pos += indices[dim] * sub_tensor_strides[dim] - sub_byte_offset = sub_pos * element_size - - # Calculate the position of this element in the full tensor - full_pos = 0 - for dim in range(len(tensor_shape)): - # The global index is the local index plus the offset for this dimension - global_idx = indices[dim] + sub_tensor_offsets[dim] - full_pos += global_idx * full_tensor_strides[dim] - full_byte_offset = output_start_byte + full_pos * element_size - - # Extract the element data from the sub-tensor - element_data = sub_tensor_bytes[ - sub_byte_offset : sub_byte_offset + element_size - ] - - # Seek to the correct position in the output file and write the data - out_f.seek(full_byte_offset) - out_f.write(element_data) - - def _write_sub_tensor_to_file_optimized( fs: fsspec.AbstractFileSystem, sub_tensor_bytes: bytes, @@ -552,12 +369,14 @@ def _write_sub_tensor_to_file_optimized( output_start_byte: int, ) -> None: """ - Optimized version of _write_sub_tensor_to_file with enhanced sharding pattern detection. + Optimized version that writes the maximum number of contiguous bytes possible. - Uses advanced pattern detection to optimize common sharding patterns: - - Row-wise sharding with memory-efficient bulk copying - - Contiguous chunk detection for direct memory operations - - General fallback for arbitrary patterns + Uses a unified algorithm that calculates the maximum contiguous bytes that can be + written in each iteration and continues until the entire subtensor is written. + Handles all sharding patterns efficiently: + - Full sub-tensor at once for row-wise sharding + - Row-by-row for column-wise sharding + - Optimized chunks for other patterns Args: fs: Filesystem interface for file operations @@ -573,184 +392,135 @@ def _write_sub_tensor_to_file_optimized( if not tensor_shape or not sub_tensor_shape: return - # Enhanced row-wise sharding detection - if len(tensor_shape) >= 2 and len(sub_tensor_shape) >= 2: - # Check if this is a row-wise chunk (all dims except first are complete) - is_row_wise = all( - sub_tensor_shape[i] == tensor_shape[i] and sub_tensor_offsets[i] == 0 - for i in range(1, len(tensor_shape)) - ) + # Calculate tensor strides for efficient indexing + tensor_strides = [1] + for i in range(len(tensor_shape) - 1, 0, -1): + tensor_strides.insert(0, tensor_strides[0] * tensor_shape[i]) - if is_row_wise: - # Optimized row-wise copy using bulk memory operations - _write_row_wise_tensor_optimized( - fs, - sub_tensor_bytes, - element_size, - tensor_shape, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) - return - - # Fall back to the original implementation for complex patterns - _write_sub_tensor_to_file( - fs, - bytearray(sub_tensor_bytes), - element_size, - tensor_shape, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) + sub_tensor_strides = [1] + for i in range(len(sub_tensor_shape) - 1, 0, -1): + sub_tensor_strides.insert(0, sub_tensor_strides[0] * sub_tensor_shape[i]) + total_elements = math.prod(sub_tensor_shape) -def _write_row_wise_tensor_optimized( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytes, - element_size: int, - tensor_shape: list[int], - sub_tensor_offsets: list[int], - sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: - """ - Optimized row-wise tensor writing using bulk memory operations. - - This function an optimization strategy: - - Direct memory copy for contiguous rows - - Minimal file seeking operations - - Bulk data transfer instead of element-by-element - """ with fs.open(output_file_path, "r+b") as out_f: - # Optimized row-wise copy - elements_per_row = math.prod(tensor_shape[1:]) - bytes_per_row = elements_per_row * element_size + elements_written = 0 + + while elements_written < total_elements: + # Convert linear index to multi-dimensional indices + temp_idx = elements_written + indices = [] + for dim_size in reversed(sub_tensor_shape): + indices.append(temp_idx % dim_size) + temp_idx //= dim_size + indices.reverse() + + # Calculate maximum contiguous elements we can write from this position + max_contiguous = _calculate_max_contiguous_elements( + indices, sub_tensor_shape, tensor_shape + ) + + # Calculate source position in bytes + src_pos = sum( + idx * stride for idx, stride in zip(indices, sub_tensor_strides) + ) + src_byte_offset = src_pos * element_size - start_row = sub_tensor_offsets[0] - num_rows = sub_tensor_shape[0] + # Calculate destination position in bytes + dest_indices = [ + idx + offset for idx, offset in zip(indices, sub_tensor_offsets) + ] + dest_pos = sum( + idx * stride for idx, stride in zip(dest_indices, tensor_strides) + ) + dest_byte_offset = output_start_byte + dest_pos * element_size - # Calculate byte positions - tensor_start_byte = output_start_byte + start_row * bytes_per_row - chunk_size_bytes = num_rows * bytes_per_row + # Write the contiguous chunk + bytes_to_write = max_contiguous * element_size + out_f.seek(dest_byte_offset) + chunk_data = sub_tensor_bytes[ + src_byte_offset : src_byte_offset + bytes_to_write + ] + out_f.write(chunk_data) - # Direct memory copy for contiguous rows - out_f.seek(tensor_start_byte) - out_f.write(sub_tensor_bytes[:chunk_size_bytes]) + elements_written += max_contiguous -def _write_sub_tensor_to_file( - fs: fsspec.AbstractFileSystem, - sub_tensor_bytes: bytearray, - element_size: int, - tensor_shape: list[int], - sub_tensor_offsets: list[int], +def _calculate_max_contiguous_elements( + indices: list[int], sub_tensor_shape: list[int], - output_file_path: str, - output_start_byte: int, -) -> None: + tensor_shape: list[int], +) -> int: """ - Original implementation - writes a sub-tensor from a byte array into a file representing the full tensor at specified offsets. + Calculate the maximum number of contiguous elements that can be written from current position. - This function handles the complex task of placing a tensor shard (sub-tensor) at the correct - position within the consolidated tensor file. It works by calculating the exact byte offsets - for each slice of data and writing them to the appropriate positions. This implementation - supports tensors of any dimensionality with optimized paths for common sharding patterns: - - Row-wise sharding (optimized path) - - Column-wise sharding for 2D tensors (optimized path) - - Any other arbitrary sharding pattern (general element-by-element approach) + This determines the largest chunk by checking how elements are laid out in memory + and finding natural boundaries where contiguity breaks. Args: - fs: Filesystem interface for file operations - sub_tensor_bytes: Byte array containing the sub-tensor data - element_size: The size of each element in bytes - tensor_shape: The shape of the overall tensor (list) - sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor (list) - sub_tensor_shape: The shape of the sub-tensor (list) - output_file_path: The path to the file where the full tensor is stored - output_start_byte: The starting byte of the full tensor in the file + indices: Current position indices in the sub-tensor + sub_tensor_shape: Shape of the sub-tensor being written + tensor_shape: Shape of the full tensor + + Raises: + ValueError: If input lists are empty, have mismatched lengths, or contain invalid values """ - # Handle the case of empty tensors - if not tensor_shape or not sub_tensor_shape: - return + # Validate input lists are not empty + if not indices or not sub_tensor_shape or not tensor_shape: + raise ValueError("Input lists cannot be empty") - # Calculate strides for the full tensor (row-major order, C-style) - # Stride is the number of elements to skip to move to the next element in that dimension - full_tensor_strides = [1] * len(tensor_shape) - for i in range(len(tensor_shape) - 2, -1, -1): - full_tensor_strides[i] = full_tensor_strides[i + 1] * tensor_shape[i + 1] - - # Calculate strides for the sub-tensor (row-major order, C-style) - sub_tensor_strides = [1] * len(sub_tensor_shape) - for i in range(len(sub_tensor_shape) - 2, -1, -1): - sub_tensor_strides[i] = sub_tensor_strides[i + 1] * sub_tensor_shape[i + 1] - - # Check if this is a row-wise sharded tensor - # Row-wise sharding is detected when the last dimension is complete - # and only the first dimension is partial - is_row_wise = False - if len(tensor_shape) >= 2: - # Check if all dimensions except the first are complete - all_other_dims_complete = True - for i in range(1, len(tensor_shape)): - if sub_tensor_shape[i] != tensor_shape[i]: - all_other_dims_complete = False - break - - # Row-wise sharding: first dimension is partial, all others are complete - is_row_wise = all_other_dims_complete and sub_tensor_shape[0] < tensor_shape[0] - - # Check if this is a column-wise sharded 2D tensor - # Column-wise sharding is detected when the first dimension is complete - # and the second dimension is partial (only for 2D tensors) - is_column_wise = False - if len(tensor_shape) == 2: - is_column_wise = ( - sub_tensor_shape[0] == tensor_shape[0] - and sub_tensor_shape[1] < tensor_shape[1] + # Validate all lists have the same length (same number of dimensions) + if not (len(indices) == len(sub_tensor_shape) == len(tensor_shape)): + raise ValueError( + f"All input lists must have the same length. Got indices: {len(indices)}, " + f"sub_tensor_shape: {len(sub_tensor_shape)}, tensor_shape: {len(tensor_shape)}" ) - # Call the appropriate function based on the sharding pattern - if is_row_wise: - _write_row_wise_tensor( - fs, - sub_tensor_bytes, - element_size, - full_tensor_strides, - sub_tensor_strides, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) - elif is_column_wise: - _write_column_wise_tensor( - fs, - sub_tensor_bytes, - element_size, - tensor_shape, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) - else: - _write_element_by_element( - fs, - sub_tensor_bytes, - element_size, - tensor_shape, - full_tensor_strides, - sub_tensor_strides, - sub_tensor_offsets, - sub_tensor_shape, - output_file_path, - output_start_byte, - ) + # Validate indices are within bounds of sub_tensor_shape + for i, (idx, sub_dim) in enumerate(zip(indices, sub_tensor_shape)): + if idx >= sub_dim: + raise ValueError( + f"Index {idx} at dimension {i} is out of bounds for sub-tensor shape {sub_tensor_shape}" + ) + + # Validate sub_tensor dimensions don't exceed tensor dimensions + for i, (sub_dim, tensor_dim) in enumerate(zip(sub_tensor_shape, tensor_shape)): + if sub_dim > tensor_dim: + raise ValueError( + f"Sub-tensor dimension {sub_dim} at position {i} exceeds tensor dimension {tensor_dim}" + ) + + # Start with elements remaining in the last dimension + max_contiguous = sub_tensor_shape[-1] - indices[-1] + + # Check if we can extend across multiple dimensions + # We can write across dimension boundaries if we're writing complete "rows" + # and the layout in destination tensor maintains contiguity + + # For 2D case: check if we can write multiple complete rows + if len(sub_tensor_shape) >= 2: + # If we're at the start of a row and can write complete rows + if indices[-1] == 0: # At start of last dimension (column) + rows_remaining = sub_tensor_shape[-2] - indices[-2] # Rows left to write + + # Check if writing complete rows maintains contiguity in destination + # This is true for row-wise sharding or when sub-tensor spans full width + if sub_tensor_shape[-1] == tensor_shape[-1]: # Full width + max_contiguous = rows_remaining * sub_tensor_shape[-1] + + # For higher dimensions, check if we can extend further + if len(sub_tensor_shape) >= 3 and indices[-2] == 0: + # Check if we can write complete 2D slices + remaining_in_dim = sub_tensor_shape[-3] - indices[-3] + if ( + sub_tensor_shape[-1] == tensor_shape[-1] + and sub_tensor_shape[-2] == tensor_shape[-2] + ): + max_contiguous = ( + remaining_in_dim * sub_tensor_shape[-2] * sub_tensor_shape[-1] + ) + + return max_contiguous def _write_overall_metadata_file( @@ -846,7 +616,7 @@ def consolidate_safetensors_files( for fqn, index in fqn_to_index_mapping.items(): # Generate names like "model-00001-of-00005.safetensors" file_name = _gen_file_name(index, max(fqn_to_index_mapping.values())) - output_path = f"{local_output_dir}/{file_name}" + output_path = os.path.join(local_output_dir, file_name) if output_path not in output_files_data: output_files_data[output_path] = _OutputFileData( @@ -857,7 +627,7 @@ def consolidate_safetensors_files( else: # If no mapping is provided, create a single output file file_name = _gen_file_name(1, 1) - output_path = f"{local_output_dir}/{file_name}" + output_path = os.path.join(local_output_dir, file_name) output_files_data[output_path] = _OutputFileData() # Find all safetensors files in the input directory From 11a3565f1872bbad9c253a127e8d4ce7a1b40ec8 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Sat, 9 Aug 2025 01:04:21 +0000 Subject: [PATCH 0173/1424] [Torch Native] Add test for packaging weight (#158750) Add test that require weights to be packaged for torch native For now, we need `package_weights_in_so=True` for compile standalone. The constants are in a `.o` file and will be added as a source to the CMakeLists.txt of the model. After we added weight deduping, we should be able to let this config be False. ``` python test/inductor/test_aot_inductor_package.py -k test_compile_with_exporter_weights ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158750 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor_package.py | 51 +++++++++++++++++++++- torch/export/experimental/__init__.py | 3 +- torch/export/experimental/_utils.py | 14 ++++-- 3 files changed, 62 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 2809f5533bd9c..46152103836a4 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -157,6 +157,7 @@ def cmake_compile_and_run(self, base_dir): check=True, ) subprocess.run(["make"], cwd=build_path, check=True) + result = subprocess.run( ["./build/main"], cwd=base_dir, @@ -502,16 +503,62 @@ def default(*args, **kwargs): if self.device == GPU_TYPE: self.assertEqual( result.stdout, - "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + "output_tensor1\n 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2\n 0 0 0\n" " 0 0 0\n 0 0 0\n[ CUDAFloatType{3,3} ]\n", ) else: self.assertEqual( result.stdout, - "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + "output_tensor1\n 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2\n 0 0 0\n" " 0 0 0\n 0 0 0\n[ CPUFloatType{3,3} ]\n", ) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) + @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") + @skipIfRocm # doesn't support multi-arch binary + @skipIfXpu # doesn't support multi-arch binary + @torch._inductor.config.patch("test_configs.use_libtorch", True) + def test_compile_with_exporter_weights(self): + self.check_package_cpp_only() + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.fc1(x) + return x + + def default(*args, **kwargs): + return None + + example_inputs = (torch.ones(3, 3).to(self.device),) + + package = _ExportPackage() + m1 = Model().to(self.device) + exporter1 = package._exporter("Model", m1)._define_overload("default", default) + exporter1(*example_inputs) + expected_res = m1(*example_inputs) + + package_example_inputs = True + with ( + tempfile.TemporaryDirectory() as tmp_dir, + ): + package._compiled_and_package( + tmp_dir + "/package.pt2", True, package_example_inputs + ) + + # Test compiling generated files + self.cmake_compile_and_run(tmp_dir) + tensor_model = torch.load( + tmp_dir + "/output_tensor1.pt", weights_only=False + ) + true_res = next(iter(tensor_model.parameters())) + self.assertEqual(expected_res, true_res) + def test_metadata(self): class Model(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 372eb3a29533d..1c87bb29bfe96 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -360,7 +360,8 @@ def _compiled_and_package( "aot_inductor.package": True, "aot_inductor.package_cpp_only": True, "always_keep_tensor_constants": True, - "aot_inductor.package_constants_in_so": False, + # we'll change this back to False once we enable weight deduping for standalone mode + "aot_inductor.package_constants_in_so": standalone, "aot_inductor.compile_standalone": standalone, } aoti_files_map = {} diff --git a/torch/export/experimental/_utils.py b/torch/export/experimental/_utils.py index 910c45c2ceb9d..67bda0c34ce4f 100644 --- a/torch/export/experimental/_utils.py +++ b/torch/export/experimental/_utils.py @@ -1,9 +1,11 @@ +import logging import typing from torch._inductor.utils import IndentedBuffer __all__ = [] # type: ignore[var-annotated] +logger = logging.getLogger(__name__) def _get_main_cpp_file( @@ -125,8 +127,10 @@ def _get_main_cpp_file( [ f"auto constants_map{i + 1} = std::make_shared();", f"auto constants_array{i + 1} = std::make_shared>();", - f"auto model{i + 1} = AOTInductorModel{model_name}::Create(", - f" constants_map{i + 1}, constants_array{i + 1}, device_str,", + f"auto model{i + 1} = std::make_unique(", + f" std::move(constants_map{i + 1}),", + f" std::move(constants_array{i + 1}),", + " device_str,", f' "{package_name}/data/aotinductor/{model_name}/");', f"model{i + 1}->load_constants();", ] @@ -154,7 +158,10 @@ def _get_main_cpp_file( ib.writeline("\n// Validate outputs") for i in range(len(model_names)): ib.writeline( - f"""std::cout << "output_tensor{i + 1}" << output_tensor{i + 1} << std::endl;""" + f"""std::cout << "output_tensor{i + 1}\\n" << output_tensor{i + 1} << std::endl;""" + ) + ib.writeline( + f"""torch::save(output_tensor{i + 1}, "output_tensor{i + 1}.pt");""" ) ib.writeline("return 0;") @@ -205,6 +212,7 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str model_libs = " ".join(model_names) ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})") + if cuda: ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") From 10e3514c962b58cbbee994257872a626ff76d51b Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 9 Aug 2025 02:21:22 +0000 Subject: [PATCH 0174/1424] Remove tensorexpr tests (#158928) The tests are not maintained. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158928 Approved by: https://github.com/albanD, https://github.com/malfet --- .ci/pytorch/build.sh | 4 - .ci/pytorch/test.sh | 10 - aten/src/ATen/test/thread_init_test.cpp | 11 +- caffe2/CMakeLists.txt | 4 - test/cpp/tensorexpr/CMakeLists.txt | 83 - test/cpp/tensorexpr/README.md | 55 - test/cpp/tensorexpr/gtest_assert_float_eq.h | 119 - test/cpp/tensorexpr/padded_buffer.cpp | 37 - test/cpp/tensorexpr/padded_buffer.h | 242 - test/cpp/tensorexpr/test_approx.cpp | 96 - test/cpp/tensorexpr/test_aten.cpp | 1068 --- test/cpp/tensorexpr/test_base.h | 89 - test/cpp/tensorexpr/test_boundsinference.cpp | 1019 --- test/cpp/tensorexpr/test_conv.cpp | 234 - test/cpp/tensorexpr/test_cpp_codegen.cpp | 259 - test/cpp/tensorexpr/test_cuda.cpp | 2344 ------ test/cpp/tensorexpr/test_dynamic_shapes.cpp | 701 -- test/cpp/tensorexpr/test_expr.cpp | 836 -- test/cpp/tensorexpr/test_external_calls.cpp | 1061 --- test/cpp/tensorexpr/test_graph_opt.cpp | 319 - test/cpp/tensorexpr/test_ir_printer.cpp | 98 - test/cpp/tensorexpr/test_ir_verifier.cpp | 191 - test/cpp/tensorexpr/test_kernel.cpp | 2133 ----- test/cpp/tensorexpr/test_llvm.cpp | 1799 ----- test/cpp/tensorexpr/test_loopnest.cpp | 6894 ----------------- test/cpp/tensorexpr/test_memdependency.cpp | 3252 -------- test/cpp/tensorexpr/test_memplanning.cpp | 708 -- test/cpp/tensorexpr/test_ops.cpp | 78 - test/cpp/tensorexpr/test_quantization.cpp | 452 -- test/cpp/tensorexpr/test_reductions.cpp | 1928 ----- test/cpp/tensorexpr/test_registerizer.cpp | 3702 --------- test/cpp/tensorexpr/test_simplify.cpp | 5680 -------------- test/cpp/tensorexpr/test_te_fuser_pass.cpp | 402 - test/cpp/tensorexpr/test_type.cpp | 202 - .../tensorexpr/test_type_specializations.cpp | 75 - test/cpp/tensorexpr/test_utils.h | 78 - test/cpp/tensorexpr/tutorial.cpp | 542 -- test/test_jit_fuser_te.py | 5 +- torch/csrc/jit/runtime/static/ops.cpp | 2 +- 39 files changed, 10 insertions(+), 36802 deletions(-) delete mode 100644 test/cpp/tensorexpr/CMakeLists.txt delete mode 100644 test/cpp/tensorexpr/README.md delete mode 100644 test/cpp/tensorexpr/gtest_assert_float_eq.h delete mode 100644 test/cpp/tensorexpr/padded_buffer.cpp delete mode 100644 test/cpp/tensorexpr/padded_buffer.h delete mode 100644 test/cpp/tensorexpr/test_approx.cpp delete mode 100644 test/cpp/tensorexpr/test_aten.cpp delete mode 100644 test/cpp/tensorexpr/test_base.h delete mode 100644 test/cpp/tensorexpr/test_boundsinference.cpp delete mode 100644 test/cpp/tensorexpr/test_conv.cpp delete mode 100644 test/cpp/tensorexpr/test_cpp_codegen.cpp delete mode 100644 test/cpp/tensorexpr/test_cuda.cpp delete mode 100644 test/cpp/tensorexpr/test_dynamic_shapes.cpp delete mode 100644 test/cpp/tensorexpr/test_expr.cpp delete mode 100644 test/cpp/tensorexpr/test_external_calls.cpp delete mode 100644 test/cpp/tensorexpr/test_graph_opt.cpp delete mode 100644 test/cpp/tensorexpr/test_ir_printer.cpp delete mode 100644 test/cpp/tensorexpr/test_ir_verifier.cpp delete mode 100644 test/cpp/tensorexpr/test_kernel.cpp delete mode 100644 test/cpp/tensorexpr/test_llvm.cpp delete mode 100644 test/cpp/tensorexpr/test_loopnest.cpp delete mode 100644 test/cpp/tensorexpr/test_memdependency.cpp delete mode 100644 test/cpp/tensorexpr/test_memplanning.cpp delete mode 100644 test/cpp/tensorexpr/test_ops.cpp delete mode 100644 test/cpp/tensorexpr/test_quantization.cpp delete mode 100644 test/cpp/tensorexpr/test_reductions.cpp delete mode 100644 test/cpp/tensorexpr/test_registerizer.cpp delete mode 100644 test/cpp/tensorexpr/test_simplify.cpp delete mode 100644 test/cpp/tensorexpr/test_te_fuser_pass.cpp delete mode 100644 test/cpp/tensorexpr/test_type.cpp delete mode 100644 test/cpp/tensorexpr/test_type_specializations.cpp delete mode 100644 test/cpp/tensorexpr/test_utils.h delete mode 100644 test/cpp/tensorexpr/tutorial.cpp diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index c7d2cb93a64b9..65f97389324a5 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -50,9 +50,6 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi -# Enable LLVM dependency for TensorExpr testing -export USE_LLVM=/opt/llvm -export LLVM_DIR=/opt/llvm/lib/cmake/llvm if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with @@ -192,7 +189,6 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then export USE_ASAN=1 export REL_WITH_DEB_INFO=1 export UBSAN_FLAGS="-fno-sanitize-recover=all" - unset USE_LLVM fi if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 84d40a2e458a1..473a125475c4e 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1051,20 +1051,10 @@ test_libtorch_api() { mkdir -p $TEST_REPORTS_DIR OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml - "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml else # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest" - # On s390x, pytorch is built without llvm. - # Even if it would be built with llvm, llvm currently doesn't support used features on s390x and - # test fails with errors like: - # JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer - # unknown file: Failure - # C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) } - if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then - python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr - fi fi # quantization is not fully supported on s390x yet diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 7ad7a18e9c660..60dd52d1dffcb 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -1,7 +1,8 @@ +#include + #include #include #include -#include #include @@ -9,7 +10,7 @@ // numbers of threads set and also whether the scheduler // will throw an exception when multiple threads call // their first parallel construct. -void test(int given_num_threads) { +static void test(int given_num_threads) { auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat)); ASSERT_TRUE(given_num_threads >= 0); ASSERT_EQ(at::get_num_threads(), given_num_threads); @@ -19,7 +20,7 @@ void test(int given_num_threads) { } } -int main() { +TEST(ThreadInitTest, ThreadInit) { at::init_num_threads(); at::set_num_threads(4); @@ -32,13 +33,11 @@ int main() { #if !AT_PARALLEL_NATIVE at::set_num_threads(5); - ASSERT_TRUE(at::get_num_threads() == 5); + ASSERT_EQ(at::get_num_threads(), 5); #endif // test inter-op settings at::set_num_interop_threads(5); ASSERT_EQ(at::get_num_interop_threads(), 5); ASSERT_ANY_THROW(at::set_num_interop_threads(6)); - - return 0; } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index c346cedbcf519..96ed0c3b918e7 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1345,10 +1345,6 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) - add_subdirectory( - ${TORCH_ROOT}/test/cpp/tensorexpr - ${CMAKE_BINARY_DIR}/test_tensorexpr - ) if(USE_DISTRIBUTED) add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) if(NOT WIN32) diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt deleted file mode 100644 index 8fe6ffd525e98..0000000000000 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ /dev/null @@ -1,83 +0,0 @@ -set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr) - -set(TENSOREXPR_TEST_SRCS - ${TENSOREXPR_TEST_ROOT}/test_approx.cpp - ${TENSOREXPR_TEST_ROOT}/test_aten.cpp - ${TENSOREXPR_TEST_ROOT}/test_boundsinference.cpp - ${TENSOREXPR_TEST_ROOT}/test_conv.cpp - ${TENSOREXPR_TEST_ROOT}/test_cpp_codegen.cpp - ${TENSOREXPR_TEST_ROOT}/test_dynamic_shapes.cpp - ${TENSOREXPR_TEST_ROOT}/test_expr.cpp - ${TENSOREXPR_TEST_ROOT}/test_external_calls.cpp - ${TENSOREXPR_TEST_ROOT}/test_graph_opt.cpp - ${TENSOREXPR_TEST_ROOT}/test_ir_printer.cpp - ${TENSOREXPR_TEST_ROOT}/test_ir_verifier.cpp - ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp - ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp - ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp - ${TENSOREXPR_TEST_ROOT}/test_ops.cpp - ${TENSOREXPR_TEST_ROOT}/test_quantization.cpp - ${TENSOREXPR_TEST_ROOT}/test_memplanning.cpp - ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp - ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp - ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp - ${TENSOREXPR_TEST_ROOT}/test_te_fuser_pass.cpp - ${TENSOREXPR_TEST_ROOT}/test_type.cpp - ${TENSOREXPR_TEST_ROOT}/test_type_specializations.cpp -) - -if(USE_CUDA) - list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp) -endif() - -if(USE_LLVM AND LLVM_FOUND) - list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp) -endif() - -add_executable(test_tensorexpr - ${TORCH_ROOT}/test/cpp/common/main.cpp - ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp - ${TENSOREXPR_TEST_SRCS}) - -target_link_libraries(test_tensorexpr PRIVATE torch gtest_main) -target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) -target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) - -add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) -target_link_libraries(tutorial_tensorexpr PRIVATE torch) -target_include_directories(tutorial_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) - -# The test case depends on the xnnpack header which in turn depends on the -# pthreadpool header. For some build environment we need add the dependency -# explicitly. -if(USE_PTHREADPOOL) - target_link_libraries(test_tensorexpr PRIVATE pthreadpool_interface) -endif() -if(USE_CUDA) - target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA) - target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA) -elseif(USE_ROCM) - target_link_libraries(test_tensorexpr PRIVATE - hiprtc::hiprtc - hip::amdhip64 - ${TORCH_CUDA_LIBRARIES}) - target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) - - target_link_libraries(tutorial_tensorexpr PRIVATE - hiprtc::hiprtc - hip::amdhip64 - ${TORCH_CUDA_LIBRARIES}) - target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) -endif() - -if(INSTALL_TEST) - set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS test_tensorexpr DESTINATION bin) - set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS tutorial_tensorexpr DESTINATION bin) - # Install PDB files for MSVC builds - if(MSVC AND BUILD_SHARED_LIBS) - install(FILES $ DESTINATION bin OPTIONAL) - install(FILES $ DESTINATION bin OPTIONAL) - endif() -endif() diff --git a/test/cpp/tensorexpr/README.md b/test/cpp/tensorexpr/README.md deleted file mode 100644 index f86a50a65e804..0000000000000 --- a/test/cpp/tensorexpr/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# TensorExpr C++ Tests - -## How to add a new test -First, create a new test file. Test files should have be placed in this -directory, with a name that starts with `test_`, like `test_foo.cpp`. - -Here is an example test file you can copy-paste. -```cpp -#include - -// Tests go in torch::jit -namespace torch { -namespace jit { - -// 1. Test cases are void() functions. -// 2. They start with the prefix `test` -void testCaseOne() { - // ... -} - -void testCaseTwo() { - // ... -} -} -} -``` - -Then, register your test in `tests.h`: -```cpp -// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests -#define TH_FORALL_TESTS(_) \ - _(ADFormulas) \ - _(Attributes) \ - ... - _(CaseOne) // note that the `test` prefix is omitted. - _(CaseTwo) -``` - -We glob all the test files together in `CMakeLists.txt` so that you don't -have to edit it every time you add a test. Unfortunately, this means that in -order to get the build to pick up your new test file, you need to re-run -cmake: -```bash -CMAKE_FRESH=1 python setup.py build -``` - -## How do I run the tests? -The following commands assume you are in PyTorch root. - - ```bash - # (re)build the test binary - ninja build/bin/test_tensorexpr - # run - build/bin/test_tensorexpr --gtest_filter='glob_style_filter*' - ``` diff --git a/test/cpp/tensorexpr/gtest_assert_float_eq.h b/test/cpp/tensorexpr/gtest_assert_float_eq.h deleted file mode 100644 index f85264a8f5d3c..0000000000000 --- a/test/cpp/tensorexpr/gtest_assert_float_eq.h +++ /dev/null @@ -1,119 +0,0 @@ -#pragma once - -#include -// Copyright 2005, Google Inc. -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// -// The Google C++ Testing and Mocking Framework (Google Test) -// -// This header file declares functions and macros used internally by -// Google Test. They are subject to change without notice. - -using Bits = uint32_t; - -// this avoids the "dereferencing type-punned pointer -// will break strict-aliasing rules" error -union Float { - float float_; - Bits bits_; -}; - -// # of bits in a number. -static const size_t kBitCount = 8 * sizeof(Bits); -// The mask for the sign bit. -static const Bits kSignBitMask = static_cast(1) << (kBitCount - 1); - -// GOOGLETEST_CM0001 DO NOT DELETE - -// Converts an integer from the sign-and-magnitude representation to -// the biased representation. More precisely, let N be 2 to the -// power of (kBitCount - 1), an integer x is represented by the -// unsigned number x + N. -// -// For instance, -// -// -N + 1 (the most negative number representable using -// sign-and-magnitude) is represented by 1; -// 0 is represented by N; and -// N - 1 (the biggest number representable using -// sign-and-magnitude) is represented by 2N - 1. -// -// Read http://en.wikipedia.org/wiki/Signed_number_representations -// for more details on signed number representations. -static Bits SignAndMagnitudeToBiased(const Bits& sam) { - if (kSignBitMask & sam) { - // sam represents a negative number. - return ~sam + 1; - } else { - // sam represents a positive number. - return kSignBitMask | sam; - } -} - -// Given two numbers in the sign-and-magnitude representation, -// returns the distance between them as an unsigned number. -static Bits DistanceBetweenSignAndMagnitudeNumbers( - const Bits& sam1, - const Bits& sam2) { - const Bits biased1 = SignAndMagnitudeToBiased(sam1); - const Bits biased2 = SignAndMagnitudeToBiased(sam2); - return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); -} - -// How many ULP's (Units in the Last Place) we want to tolerate when -// comparing two numbers. The larger the value, the more error we -// allow. A 0 value means that two numbers must be exactly the same -// to be considered equal. -// -// The maximum error of a single floating-point operation is 0.5 -// units in the last place. On Intel CPU's, all floating-point -// calculations are done with 80-bit precision, while double has 64 -// bits. Therefore, 4 should be enough for ordinary use. -// -// See the following article for more details on ULP: -// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ -static const size_t kMaxUlps = 4; - -// Returns true if and only if this number is at most kMaxUlps ULP's away -// from rhs. In particular, this function: -// -// - returns false if either number is (or both are) NAN. -// - treats really large numbers as almost equal to infinity. -// - thinks +0.0 and -0.0 are 0 DLP's apart. -inline bool AlmostEquals(float lhs, float rhs) { - // The IEEE standard says that any comparison operation involving - // a NAN must return false. - if (std::isnan(lhs) || std::isnan(rhs)) - return false; - - Float l = {lhs}; - Float r = {rhs}; - - return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps; -} diff --git a/test/cpp/tensorexpr/padded_buffer.cpp b/test/cpp/tensorexpr/padded_buffer.cpp deleted file mode 100644 index 424d82c77453c..0000000000000 --- a/test/cpp/tensorexpr/padded_buffer.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "test/cpp/tensorexpr/padded_buffer.h" - -#include -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -int PaddedBufferBase::Index(const std::vector& indices) const { - TORCH_DCHECK_EQ(dims_.size(), indices.size()); - int total_index = 0; - for (const auto i : c10::irange(dims_.size())) { - total_index += indices[i] * strides_[i]; - } - return total_index; -} - -PaddedBufferBase::PaddedBufferBase( - const std::vector& dims, - // NOLINTNEXTLINE(modernize-pass-by-value) - const std::string& name) - : dims_(dims), name_(name), strides_(dims.size()) { - for (int i = (int)dims.size() - 1; i >= 0; --i) { - if (i == (int)dims.size() - 1) { - strides_[i] = 1; - } else { - strides_[i] = strides_[i + 1] * dims[i + 1]; - } - } - total_size_ = strides_[0] * dims[0]; -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.h b/test/cpp/tensorexpr/padded_buffer.h deleted file mode 100644 index b3e5227ae7e62..0000000000000 --- a/test/cpp/tensorexpr/padded_buffer.h +++ /dev/null @@ -1,242 +0,0 @@ -#pragma once - -#include -#include - -#include -#include "torch/csrc/jit/tensorexpr/eval.h" - -namespace torch { -namespace jit { -namespace tensorexpr { - -template -struct DefaultPaddedValue; - -template <> -struct DefaultPaddedValue { - static const int kValue = static_cast(0xDEADBEEF); -}; - -template <> -struct DefaultPaddedValue { - static const int8_t kValue = static_cast(0xBE); -}; - -template <> -struct DefaultPaddedValue { - static const uint8_t kValue = static_cast(0xBE); -}; - -template <> -struct DefaultPaddedValue { - static const int16_t kValue = static_cast(0xBEEF); -}; - -template <> -struct DefaultPaddedValue { - static const int64_t kValue = static_cast(0xDEADBEEF); -}; - -template <> -struct DefaultPaddedValue { - static constexpr float kValue = 0.1357; -}; - -template <> -struct DefaultPaddedValue { - // at::Half ctor isn't constexpr, so just fill it with bits. - static constexpr uint16_t kValue = 1357; -}; - -template <> -struct DefaultPaddedValue { - static constexpr double kValue = 0.1357; -}; - -// A concrete base to be used in PaddedBase. -class PaddedBufferBase { - public: - const std::string& name() const { - return name_; - } - - int size() const { - return total_size_; - } - - int raw_size() const { - return total_size_ + 2 * kPaddingSize; - } - - virtual ~PaddedBufferBase() {} - - protected: - explicit PaddedBufferBase( - const std::vector& dims, - const std::string& name); - int Index(const std::vector& indices) const; - - std::vector dims_; - std::string name_; - std::vector strides_; - int total_size_; // total number of useful element, does not include the - // paddings - static constexpr int kPaddingSize = 64; -}; - -// A padded buffer with wartermarks for testing. -// The buffer carries padded watermarks on both sides to catch potential -// out-of-bounds writes. For read-only data that are not supposed to change, it -// can also make a backup and be compared later. -template -class PaddedBuffer : public PaddedBufferBase { - public: - PaddedBuffer(int d0, const std::string& name = "") - : PaddedBuffer(std::vector({d0}), name) {} - PaddedBuffer(int d0, int d1, const std::string& name = "") - : PaddedBuffer(std::vector({d0, d1}), name) {} - PaddedBuffer(int d0, int d1, int d2, const std::string& name = "") - : PaddedBuffer(std::vector({d0, d1, d2}), name) {} - PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "") - : PaddedBuffer(std::vector({d0, d1, d2, d3}), name) {} - PaddedBuffer(const std::vector& dims, const std::string& name = "") - : PaddedBufferBase(dims, name) { - data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue); - } - PaddedBuffer(const PaddedBuffer& other, const std::string& name) - : PaddedBuffer(other) { - this->name_ = name; - } - - T* data() { - return data_.data() + kPaddingSize; - } - const T* data() const { - return const_cast(this)->data(); - } - T* raw_data() { - return data_.data(); - } - const T* raw_data() const { - return const_cast(this)->raw_data(); - } - T& operator()(int i0) { - // There is a bit performance impact with forming a vector here. But this - // data structure is for testing only, and not performance critical. - return this->operator()(std::vector({i0})); - } - const T& operator()(int i0) const { - return const_cast(this)->operator()(i0); - } - T& operator()(int i0, int i1) { - return this->operator()(std::vector({i0, i1})); - } - const T& operator()(int i0, int i1) const { - return const_cast(this)->operator()(i0, i1); - } - T& operator()(int i0, int i1, int i2) { - return this->operator()(std::vector({i0, i1, i2})); - } - const T& operator()(int i0, int i1, int i2) const { - return const_cast(this)->operator()(i0, i1, i2); - } - T& operator()(int i0, int i1, int i2, int i3) { - return this->operator()(std::vector({i0, i1, i2, i3})); - } - const T& operator()(int i0, int i1, int i2, int i3) const { - return const_cast(this)->operator()(i0, i1, i2, i3); - } - T& operator()(const std::vector& indices) { - return data_[kPaddingSize + Index(indices)]; - } - const T& operator()(const std::vector& indices) const { - return const_cast(this)->operator()(indices); - } - - template - friend void ExpectAllNear( - const PaddedBuffer& v1, - const PaddedBuffer& v2, - float abs_error); - template - friend void ExpectAllEqual( - const PaddedBuffer& v1, - const PaddedBuffer& v2); - void Backup() { - backup_data_ = data_; - } - - // Verify the watermarks in the paddings are intact. - void ValidateWatermark() const { - for (const auto i : c10::irange(kPaddingSize)) { - ASSERT_EQ(data_[i], kPaddingValue); - ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue); - } - } - - void CheckBackup() const { - ValidateWatermark(); - DCHECK(backup_data_.size() == data_.size()) - << "Please make sure you have call Backup() before calling CheckBackup()"; - for (const auto i : c10::irange(total_size_)) { - ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]); - } - } - - private: - std::vector data_; - std::vector backup_data_; - T kPaddingValue = DefaultPaddedValue::kValue; -}; - -template -inline CodeGen::CallArg::CallArg(const PaddedBuffer& buffer) - : data_(const_cast(buffer.data())) {} - -template -std::string CompareErrorMsg( - const PaddedBuffer& v1, - const PaddedBuffer& v2, - int index) { - std::ostringstream oss; - oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index) - << ")" - << ", v2: (" << v2.name() << ", " << v2(index) << ")"; - return oss.str(); -} - -template -void ExpectAllEqual(const PaddedBuffer& f1, const PaddedBuffer& f2) { - const std::vector& v1 = f1.data_; - const std::vector& v2 = f2.data_; - const int kPaddingSize = f1.kPaddingSize; - const int total_size = f1.total_size_; - ASSERT_EQ(v1.size(), v2.size()); - f1.ValidateWatermark(); - f2.ValidateWatermark(); - for (const auto i : c10::irange(total_size)) { - ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]); - } -} - -template -void ExpectAllNear( - const PaddedBuffer& f1, - const PaddedBuffer& f2, - float abs_error) { - const std::vector& v1 = f1.data_; - const std::vector& v2 = f2.data_; - const int kPaddingSize = f1.kPaddingSize; - const int total_size = f1.total_size_; - ASSERT_EQ(v1.size(), v2.size()); - f1.ValidateWatermark(); - f2.ValidateWatermark(); - for (const auto i : c10::irange(total_size)) { - ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error); - } -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp deleted file mode 100644 index e1a576aecf526..0000000000000 --- a/test/cpp/tensorexpr/test_approx.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#ifdef TORCH_ENABLE_LLVM - -#include -#include -#include -#include -#include -#include -#include - -using namespace torch::indexing; -namespace te = torch::jit::tensorexpr; - -static void vectorize(te::LoopNest* ln, te::Tensor target, int width) { - auto loops = ln->getLoopStmtsFor(target); - te::ForPtr inner, tail; - ln->splitWithTail(loops[0], width, &inner, &tail); - ASSERT_TRUE(te::LoopNest::vectorize(inner)); -} - -std::string diffs(const at::Tensor& a, const at::Tensor& b) { - auto diff = torch::abs(a.flatten() - b.flatten()); - auto count_diffs = torch::sum(diff > 0.f); - auto greatest_diff_index = torch::argmax(diff); - std::stringstream ss; - ss << "Found " << count_diffs << " unequal element(s). " - << "The greatest difference was " << diff.index({greatest_diff_index}) - << " at index " << greatest_diff_index; - return ss.str(); -} - -TEST(Approx, log_vml) { - te::VarHandle N("N", te::kInt); - te::BufHandle A("A", {N}, te::kFloat); - te::Tensor B = te::Compute( - "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); }); - - te::LoopNest ln({B}); - ln.prepareForCodegen(); - vectorize(&ln, B, 8); - te::StmtPtr s = ln.root_stmt(); - s = te::IRSimplifier::simplify(s); - te::LLVMCodeGen cg(s, {A, B, N}); - - auto eps = std::numeric_limits::epsilon(); - auto test = [&](const at::Tensor& A_t) { - at::Tensor B_ref = at::log(A_t); - at::Tensor B_t = at::empty_like(A_t); - auto ap = A_t.data_ptr(); - auto bp = B_t.data_ptr(); - cg.call({ap, bp, A_t.numel()}); - // Results should be bit-identical. - ASSERT_TRUE(torch::allclose( - B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true)) - << "Input[:8]\n" - << A_t.index({Slice(0, 8)}) << "\n" - << "Test[:8]\n" - << B_t.index({Slice(0, 8)}) << "\n" - << "Ref[:8]\n" - << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref); - }; - - // Generate every single-precision FP value in [1.0, 2.0). - at::Tensor A_t = torch::arange(1.0f, 2.0f, eps); - ASSERT_EQ(A_t.numel(), 1 << 23); - - test(A_t); - - test(A_t * 2.0f); - test(A_t * 0.5f); - - test(A_t * 4.0f); - test(A_t * 0.25f); - - test(A_t * powf(2.0f, 16)); - test(A_t * powf(2.0f, -16)); - - test(A_t * powf(2.0f, 126)); - test(A_t * powf(2.0f, -126)); - - test(torch::full({32}, INFINITY)); - test(torch::full({32}, NAN)); - - auto min = std::numeric_limits::min(); - auto denorm_min = std::numeric_limits::denorm_min(); - - // Denormals aren't bit precise, because sleef isn't bit-precise either. - A_t = torch::arange(0.0f, min, denorm_min); - ASSERT_EQ(A_t.numel(), 1 << 23); - auto B_ref = at::log(A_t); - auto B_t = at::empty_like(B_ref); - cg.call({A_t.data_ptr(), B_t.data_ptr(), A_t.numel()}); - ASSERT_TRUE(torch::allclose(B_t, B_ref)); -} - -#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp deleted file mode 100644 index 34ce2bd069d55..0000000000000 --- a/test/cpp/tensorexpr/test_aten.cpp +++ /dev/null @@ -1,1068 +0,0 @@ -#include -#include -#include - -#include - -#include -#include -#include "test/cpp/tensorexpr/padded_buffer.h" -#include "test/cpp/tensorexpr/test_base.h" -#include "torch/csrc/jit/tensorexpr/ir_printer.h" - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(ATen, _cast_Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Cast::make(kFloat, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), static_cast(i)); - } -} - -TEST(ATen, negInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Sub::make(0, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), -static_cast(i)); - } -} - -TEST(ATen, negFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Sub::make(0, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), -i); - } -} - -TEST(ATen, addInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); - } -} - -TEST(ATen, addFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); - } -} - -TEST(ATen, subInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); - } -} - -TEST(ATen, subFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); - } -} - -TEST(ATen, lerp) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))); - } -} - -TEST(ATen, addcmulInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - ExprHandle load_d = d_buf.load(index); - StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - PaddedBuffer e_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - d_v(i) = 5 * i + 3; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); - ir_eval(a_v, b_v, c_v, d_v, e_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), 5 * i + 3); - ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); - } -} - -TEST(ATen, addcmulFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - ExprHandle load_d = d_buf.load(index); - StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - PaddedBuffer e_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - d_v(i) = 5 * i + 3; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); - ir_eval(a_v, b_v, c_v, d_v, e_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), 5 * i + 3); - ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); - } -} - -TEST(ATen, mulInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a * load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); - } -} - -TEST(ATen, mulFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a * load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); - } -} - -TEST(ATen, divInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a / load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = 2 * i + 1; - b_v(i) = i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), 2 * i + 1); - ASSERT_EQ(b_v(i), i + 1); - ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); - } -} - -TEST(ATen, divFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a / load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = 2 * i + 1; - b_v(i) = i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), 2 * i + 1); - ASSERT_EQ(b_v(i), i + 1); - ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); - } -} - -TEST(ATen, maxInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i))); - } -} - -TEST(ATen, maxFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))); - } -} - -TEST(ATen, minInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i))); - } -} - -TEST(ATen, minFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))); - } -} - -void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 1.0f / i); - } -} - -TEST(ATen, reluInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i - 64; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i - 64); - ASSERT_EQ(b_v(i), std::max(a_v(i), 0)); - } -} - -TEST(ATen, reluFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store( - {index}, Max::make(load_a, 0, false) // relu does not propagate nans - ); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i - 64; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i - 64); - ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0)); - } -} - -TEST(ATen, logFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log(a_v(i))); - } -} - -TEST(ATen, fastLogFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::log(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_FLOAT_EQ(test, ref); - } - } -} - -TEST(ATen, fastTanhFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::tanh(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_NEAR(test, ref, 1e-6); - } - } -} - -TEST(ATen, fastSigmoidFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - at::Tensor t = at::ones({1}) * a_v(i); - float ref = at::sigmoid(t).item().to(); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_NEAR(test, ref, 1e-6); - } - } -} - -TEST(ATen, log10Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log10(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log10(a_v(i))); - } -} - -TEST(ATen, log2Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log2(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log2(a_v(i))); - } -} - -TEST(ATen, expFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, exp(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::exp(a_v(i))); - } -} - -TEST(ATen, erfFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, erf(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::erf(a_v(i))); - } -} - -TEST(ATen, cosFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, cos(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::cos(a_v(i))); - } -} - -TEST(ATen, eqInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, geInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGE))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, gtInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 6); - std::vector b_buffer(N, 3); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGT))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, leInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLE))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, ltInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLT))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 0); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_base.h b/test/cpp/tensorexpr/test_base.h deleted file mode 100644 index 68b96fe6c90f7..0000000000000 --- a/test/cpp/tensorexpr/test_base.h +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once - -#if defined(USE_GTEST) -#include -#include -#else -#include -#include "c10/util/Exception.h" -#include "test/cpp/tensorexpr/gtest_assert_float_eq.h" -#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__) -#define ASSERT_FLOAT_EQ(x, y, ...) \ - TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__) -#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__) -#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__) -#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__) -#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__) -#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__) - -#define ASSERT_NEAR(x, y, a, ...) \ - TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__) - -#define ASSERT_TRUE TORCH_INTERNAL_ASSERT -#define ASSERT_FALSE(x) ASSERT_TRUE(!(x)) -#define ASSERT_THROWS_WITH(statement, substring) \ - try { \ - (void)statement; \ - ASSERT_TRUE(false); \ - } catch (const std::exception& e) { \ - ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ - } -#define ASSERT_ANY_THROW(statement) \ - { \ - bool threw = false; \ - try { \ - (void)statement; \ - } catch (const std::exception& e) { \ - threw = true; \ - } \ - ASSERT_TRUE(threw); \ - } - -#endif // defined(USE_GTEST) -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -template -void ExpectAllNear( - const std::vector& v1, - const std::vector& v2, - V threshold, - const std::string& name = "") { - ASSERT_EQ(v1.size(), v2.size()); - for (size_t i = 0; i < v1.size(); i++) { - ASSERT_NEAR(v1[i], v2[i], threshold); - } -} - -template -void ExpectAllNear( - const std::vector& vec, - const U& val, - V threshold, - const std::string& name = "") { - for (size_t i = 0; i < vec.size(); i++) { - ASSERT_NEAR(vec[i], val, threshold); - } -} - -template -static void assertAllEqual(const std::vector& vec, const T& val) { - for (auto const& elt : vec) { - ASSERT_EQ(elt, val); - } -} - -template -static void assertAllEqual(const std::vector& v1, const std::vector& v2) { - ASSERT_EQ(v1.size(), v2.size()); - for (size_t i = 0; i < v1.size(); ++i) { - ASSERT_EQ(v1[i], v2[i]); - } -} -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp deleted file mode 100644 index 2605842d6e74d..0000000000000 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ /dev/null @@ -1,1019 +0,0 @@ -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -static void verifyConstBounds( - const TensorAccessBoundsInfo& access_info, - const std::vector>& ref) { - size_t ndim = ref.size(); - ASSERT_EQ(access_info.start.size(), ndim); - ASSERT_EQ(access_info.stop.size(), ndim); - for (const auto i : c10::irange(ndim)) { - if (ref[i].first >= 0) { // Negative values are used to skip the check - ASSERT_TRUE(access_info.start[i]->isConstant()); - int start_i = immediateAs(access_info.start[i]); - ASSERT_EQ(start_i, ref[i].first); - } - if (ref[i].second >= 0) { - ASSERT_TRUE(access_info.stop[i]->isConstant()); - int stop_i = immediateAs(access_info.stop[i]); - ASSERT_EQ(stop_i, ref[i].second); - } - } -} - -TEST(BoundsInference, _1) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); -} - -TEST(BoundsInference, _2) { - // Verify that bounds inference works for the following example: - // for i in 0..n: - // b[i] = a[i] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}}); -} - -TEST(BoundsInference, _3) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] * a[i+10] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} - ExprHandle n(100); - BufHandle a("a", {n + 10}, kFloat); - Tensor b = Compute( - "b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); -} - -TEST(BoundsInference, _4) { - // Verify that bounds inference works for the following example: - // - // for y in 0..200: - // for x in 0..320: - // b[y,x] = x*y - // for y in 0..200: - // for x in 0..320: - // c[y,x] = a[y,x] * b[y,x] - ExprHandle W(320); - ExprHandle H(200); - BufHandle a("a", {H, W}, kFloat); - Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return x * y; - }); - Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y, x) * b.load(y, x); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - StmtPtr body = l.getLoopBodyFor(c); - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}}); - } - { - // Infer bounds on the inner loop body's scope - auto bounds_info = inferBounds(body); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); - } -} - -TEST(BoundsInference, _5) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] - // - // ==> split ==> - // - // for i_outer in 0..100/16: - // for i_inner in 0..16: - // b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner] - // for i_tail in 0..100%16: - // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - ForPtr inner; - ForPtr tail; - std::vector loops = l.getLoopStmtsFor(b); - LoopNest::splitWithTail(loops[0], 16, &inner, &tail); - ForPtr outer = loops[0]; - - { - // Verify inferred bounds for the outer loop - auto bounds_info = inferBounds(outer); - ASSERT_EQ(bounds_info.size(), 2); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}}); - } - { - // Verify inferred bounds for the tail loop - auto bounds_info = inferBounds(tail); - ASSERT_EQ(bounds_info.size(), 2); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}}); - } -} - -TEST(BoundsInference, _6) { - // Verify that bounds inference works for the following example: - // - // for y in 0..200: - // for x in 0..320: - // b[y,x] = x*y - // for y in 0..20: - // for x in 0..32: - // c[y,x] = a[y+100,x+100] * b[y*2,x*5] - ExprHandle W(320); - ExprHandle H(200); - ExprHandle CW(32); - ExprHandle CH(20); - BufHandle a("a", {H, W}, kFloat); - Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return x * y; - }); - Tensor c = - Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y + 100, x + 100) * b.load(y * 2, x * 5); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - StmtPtr body = l.getLoopBodyFor(c); - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}}); - } - { - // Infer bounds on the inner loop body's scope - auto bounds_info = inferBounds(body); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); - } -} - -TEST(BoundsInference, Adjacent) { - ExprHandle H(6); - BufHandle a("a", {20}, kFloat); - Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); }); - Tensor c = - Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); }); - LoopNest l({b, c}); - std::vector loops = NodeFinder::find(l.root_stmt()); - - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 2); - - // reads from a[0:5], writes to b[0:5] - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 2); - - // reads from a[0+6:5+6], writes to c[0:5] - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); - } - { - // Infer bounds on the high level program. - auto bounds_info = inferBounds(l.root_stmt()); - ASSERT_EQ(bounds_info.size(), 3); - - // Should be union of above 2 bounds, but this time the bounds of A can be - // merged. - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); - } -} - -TEST(BoundsInference, MultipleTopLoopLoad) { - BufHandle a("a", {100}, kFloat); - Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); }); - Tensor c = - Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); }); - Tensor d = - Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); }); - LoopNest l({b, c, d}); - - auto bounds_info = inferBounds(l.root_stmt()); - - ASSERT_EQ(bounds_info.size(), 4); - - // a only read. - { - auto bounds = bounds_info[a.node()]; - ASSERT_EQ(bounds.size(), 1); - // One dimension. - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); - // Bounds: - // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b). - // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 = - // 96 + 2 - 1 (d). - verifyConstBounds(bound, {{0, 97}}); - } - - // b, c, d only written. - { - auto bounds = bounds_info[b.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for b. - verifyConstBounds(bound, {{0, 63}}); - } - { - auto bounds = bounds_info[c.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for c. - verifyConstBounds(bound, {{0, 31}}); - } - { - auto bounds = bounds_info[d.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for d. - verifyConstBounds(bound, {{0, 95}}); - } -} - -TEST(BoundsInference, MultipleTopLoopStore) { - BufHandle a("a", {100}, kFloat); - BufHandle b("b", {100}, kFloat); - BufHandle c("c", {100}, kFloat); - BufHandle d("d", {100}, kFloat); - VarHandle x("x", kInt); - - // Same as above but the offsets are on the Store now. - // Can't do this through ComputeAPI without transforms we don't have yet. - StmtPtr stmt = Block::make( - {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))), - For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))), - For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))}); - - auto bounds_info = inferBounds(stmt); - - ASSERT_EQ(bounds_info.size(), 4); - - // a only read. - { - auto bounds = bounds_info[a.node()]; - ASSERT_EQ(bounds.size(), 1); - // One dimension. - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); - // Bounds: there are no offsets, so this is just the max loop bounds. - verifyConstBounds(bound, {{0, 95}}); - } - - // b, c, d only written. - { - auto bounds = bounds_info[b.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the b loop. - // b loop has no offset, so just the loop extents. - verifyConstBounds(bound, {{0, 63}}); - } - { - auto bounds = bounds_info[c.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the c loop. - // Offset is 10, extent is 32-1. - verifyConstBounds(bound, {{10, 41}}); - } - { - auto bounds = bounds_info[d.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the d loop. - // Offset is 2, extent is 96-1. - verifyConstBounds(bound, {{2, 97}}); - } -} - -TEST(BoundsInference, CacheReads) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 3); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}); - auto bounds_info_before = inferBounds(l.root_stmt()); - - StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - - auto bounds_info_after = inferBounds(l.root_stmt()); - - // CacheAccesses should not change existing bounds, but add a new one for the - // cache. - for (auto& pair : bounds_info_after) { - auto beforeIt = bounds_info_before.find(pair.first); - if (beforeIt != bounds_info_before.end()) { - // Same number of TensorAccessBoundInfos. - ASSERT_EQ(pair.second.size(), beforeIt->second.size()); - - for (const auto i : c10::irange(pair.second.size())) { - TensorAccessBoundsInfo& after = pair.second[i]; - TensorAccessBoundsInfo& before = beforeIt->second[i]; - // Same number of dimensions. - ASSERT_EQ(before.start.size(), after.start.size()); - - // Bounds are equal. - for (const auto j : c10::irange(before.start.size())) { - ASSERT_TRUE(exprEquals(before.start[j], after.start[j])); - ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j])); - } - } - } else { - // This should be the cache. - ASSERT_EQ(pair.first->name_hint(), "A_local"); - // Should have both a load and a store. - ASSERT_EQ(pair.second.size(), 2); - TensorAccessBoundsInfo& first = pair.second[0]; - TensorAccessBoundsInfo& second = pair.second[1]; - - ASSERT_NE(first.kind, second.kind); - // 2 dimensions. - ASSERT_EQ(first.start.size(), second.start.size()); - ASSERT_EQ(first.start.size(), 2); - - // bounds for load and store are equal. - for (const auto j : c10::irange(first.start.size())) { - ASSERT_TRUE(exprEquals(first.start[j], second.start[j])); - ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j])); - } - } - } -} - -TEST(BoundsInference, Flattened) { - Tensor b = Compute( - "b", - {3, 4, 5}, - [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { - return x * y + z; - }); - - LoopNest l({b}); - // Flatten indices. - l.prepareForCodegen(); - auto bounds_info = inferBounds(l.root_stmt()); - - // There's only one buffer. - ASSERT_EQ(bounds_info.size(), 1); - auto& TABI = bounds_info[b.buf()][0]; - ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); - // Flattened bounds should have a single dimension. - ASSERT_EQ(TABI.start.size(), 1); - ASSERT_EQ(TABI.stop.size(), 1); - - // Bounds should be 0 -> (3*4*5)-1 - ASSERT_TRUE(exprEquals(TABI.start[0], alloc(0))); - ASSERT_TRUE(exprEquals(TABI.stop[0], alloc(3 * 4 * 5 - 1))); -} - -TEST(BoundsInference, GetPotentialHazards) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* - * A[0] = B[0]; - * B[0] = 3; WAR on B - * A[0] = B[0]; WAW on A, RAW on B - * C[0] = 5; - */ - - StorePtr store1 = Store::make(a, {0}, Load::make(b, {0})); - StorePtr store2 = Store::make(b, {0}, 3); - StorePtr store3 = Store::make(a, {0}, Load::make(b, {0})); - StorePtr store4 = Store::make(c, {0}, 5); - StmtPtr stmt = Block::make({store1, store2, store3, store4}); - - MemDependencyChecker analyzer; - stmt->accept(&analyzer); - - ASSERT_EQ( - HazardKind::WriteAfterRead, - getPotentialHazards(analyzer, store1, store2)); - - ASSERT_EQ( - HazardKind::ReadAfterWrite, - getPotentialHazards(analyzer, store2, store3)); - - ASSERT_EQ( - HazardKind::WriteAfterWrite, - getPotentialHazards(analyzer, store1, store3)); - - // Fourth store has no dependencies - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store1, store4)); - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store2, store4)); - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store3, store4)); - } -} - -TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return (i + 1) * (j + 1); - }); - - LoopNest l({A, B}); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; - ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; - - // No dependencies between loops. - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, loopRootA, loopRootB)); -} - -TEST(BoundsInference, GetPotentialHazardsLoopCall) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j) + 5; - }); - - LoopNest l({A, B}); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; - ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; - - ASSERT_EQ( - HazardKind::ReadAfterWrite, - getPotentialHazards(analyzer, loopRootA, loopRootB)); -} - -TEST(BoundsInference, GetPotentialHazardsLoopSplit) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - - LoopNest l({A}); - ForPtr inner, tail; - - // Splitting with tail by something offset creates a tail which also writes to - // A. - ForPtr outer = l.getLoopStmtsFor(A)[0]; - // `outer` loop get transformed to the outer loop after splitting. - LoopNest::splitWithTail(outer, 5, &inner, &tail); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ASSERT_EQ( - HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k-1] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // B[k] = A[k]; - // } - BufHandle a_buf("A", {200}, kInt); - BufHandle b_buf("B", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k}))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k+100] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // A[m+1,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // A[m+20,n+100] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // B[m,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { - // Input IR: - // for (const auto j : c10::irange(100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(100)) { - // B[k] = 20 * A[99-k]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { - // Input IR: - // for (const auto k : c10::irange(100)) { - // B[k] = 20 * A[99-k]; - // } - // for (const auto j : c10::irange(100)) { - // A[j] = 10 * j; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto par = Block::make({forK, forJ}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapWithLoads) { - // Input IR: - // for (const auto k : c10::irange(10, 100)) { - // B[k] = 20 * A[99-k]; - // } - // for (const auto j : c10::irange(10, 100)) { - // C[j] = 10 * A[j]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - BufHandle c_buf("C", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 10, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make( - j, - 10, - 100, - Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j})))); - auto par = Block::make({forK, forJ}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, IsOverlapping) { - // Input IR: - // for (const auto i : c10::irange(100)) { - // A[i] = i * 10; // storeA1 - // B[i] = A[99-i] * 20; // loadA1 - // C[i] = A[i + 100] * 10; // loadA2 - // A[i + 50] = i * 50; // storeA2 - // A[i + 150] = i * 150; // storeA3 - // } - BufHandle a_buf("A", {300}, kInt); - BufHandle b_buf("B", {100}, kInt); - BufHandle c_buf("C", {100}, kInt); - VarHandle i("i", kInt); - auto storeA1 = Store::make(a_buf, {i}, i * 10); - auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i}); - auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20)); - auto loadA2 = Load::make(a_buf, {i + 100}); - auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10)); - auto storeA2 = Store::make(a_buf, {i + 50}, i * 50); - auto storeA3 = Store::make(a_buf, {i + 150}, i * 150); - auto forI = For::make( - i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3})); - tensorexpr::analysis::MemDependencyChecker analyzer; - forI->accept(&analyzer); - ASSERT_TRUE(isOverlapping(analyzer, storeA1, to(loadA1.node()))); - ASSERT_FALSE(isOverlapping(analyzer, storeA1, to(loadA2.node()))); - ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2)); - ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3)); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp deleted file mode 100644 index e72303873a6cf..0000000000000 --- a/test/cpp/tensorexpr/test_conv.cpp +++ /dev/null @@ -1,234 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -namespace te = torch::jit::tensorexpr; -namespace F = torch::nn::functional; - -#ifdef TORCH_ENABLE_LLVM - -// Generate test data with few bits of precision, to minimize error -// accumulation from floating-point reordering. -static at::Tensor genTestData(c10::IntArrayRef args) { - return at::trunc(at::randn(args) * 256.0f) / 256.0f; -} - -TEST(Conv, DepthwiseConv2D) { - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - te::BufHandle input("input", {N, C, H, W}, te::kFloat); - te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); - te::BufHandle bias("bias", {K}, te::kFloat); - te::Tensor output = - te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output}); - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto bt = genTestData({K}); - auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - cg.call( - {it.data_ptr(), - wt.data_ptr(), - bt.data_ptr(), - ot.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -TEST(Conv, DepthwiseConv2DNoBias) { - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - te::BufHandle input("input", {N, C, H, W}, te::kFloat); - te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); - te::Tensor output = - te::conv2d_depthwise(input, weight, kStride, kPad, kGroups); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output}); - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto ref = - at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - cg.call({it.data_ptr(), wt.data_ptr(), ot.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -TEST(Conv, DepthwiseConv2DDynamicShapes) { - te::VarHandle N_var("N", te::kInt); - te::VarHandle C_var("C", te::kInt); - te::VarHandle H_var("H", te::kInt); - te::VarHandle W_var("W", te::kInt); - te::VarHandle K_var("K", te::kInt); - te::VarHandle CperG_var("CperG", te::kInt); - te::VarHandle R_var("R", te::kInt); - te::VarHandle S_var("S", te::kInt); - te::VarHandle kPad_var("kPad", te::kInt); - te::VarHandle kStride_var("kStride", te::kInt); - te::VarHandle kGroups_var("kGroups", te::kInt); - - te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat); - te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat); - te::Tensor output = te::conv2d_depthwise( - input, - weight, - N_var, - C_var, - H_var, - W_var, - K_var, - CperG_var, - R_var, - S_var, - kStride_var, - kPad_var, - kGroups_var); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - std::vector buffer_args = { - input, - weight, - N_var, - C_var, - H_var, - W_var, - K_var, - CperG_var, - R_var, - S_var, - kPad_var, - kStride_var, - kGroups_var, - output}; - te::LLVMCodeGen cg(loop.root_stmt(), buffer_args); - - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto ref = - at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - std::vector call_args = { - it.data_ptr(), - wt.data_ptr(), - N, - C, - H, - W, - K, - CperG, - R, - S, - kPad, - kStride, - kGroups, - ot.data_ptr()}; - cg.call(call_args); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -#endif - -TEST(Conv, Conv2D) { - // Input dimensions. - constexpr int N = 1; - constexpr int C = 3; - constexpr int H = 11; - constexpr int W = 11; - - // Filter dimensions. - constexpr int K = 8; - constexpr int R = 3; - constexpr int S = 3; - - // Output dims. - constexpr int OH = H - R + 1; - constexpr int OW = W - S + 1; - - // Compute reference result. - at::Tensor input = torch::randn({N, C, H, W}); - at::Tensor filter = torch::randn({K, C, R, S}); - at::Tensor ref = F::conv2d(input, filter); - - // Double check the output size is as expected. - ASSERT_EQ(ref.size(0), N); - ASSERT_EQ(ref.size(1), K); - ASSERT_EQ(ref.size(2), OH); - ASSERT_EQ(ref.size(3), OW); - - te::BufHandle inputB("input", {N, C, H, W}, te::kFloat); - te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat); - - te::Tensor conv = te::Reduce( - "conv", - {N, K, OH, OW}, - te::Sum(), - // FIXME: We have to use a `std::vector` parameter here and then unpack - // it, because we don't have an overload allowing for an arbitrary number - // of ExprHandle/VarHandle parameters. - [&](const std::vector& v) { - auto const& n = v[0]; - auto const& k = v[1]; - auto const& oh = v[2]; - auto const& ow = v[3]; - auto const& c = v[4]; - auto const& r = v[5]; - auto const& s = v[6]; - // FIXME: We have to use `call` and construct a `std::vector` here - // because the `operator()` overload is only specialized for a small - // number of arguments. - return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); - }, - // FIXME: If you forget one of the reduction dims, you get a segfault. - // Could that be caught by a verifier? - {C, R, S}); - - // FIXME: It'd be nice to have a single header that pulls in things like - // LoopNest, IRSimplifier, etc. - te::LoopNest loop({conv}); - loop.prepareForCodegen(); - te::StmtPtr s = loop.root_stmt(); - s = te::IRSimplifier::simplify(s); - - at::Tensor result = at::empty_like(ref); - te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); - cg.call( - {input.data_ptr(), - filter.data_ptr(), - result.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_cpp_codegen.cpp b/test/cpp/tensorexpr/test_cpp_codegen.cpp deleted file mode 100644 index ed7679053637c..0000000000000 --- a/test/cpp/tensorexpr/test_cpp_codegen.cpp +++ /dev/null @@ -1,259 +0,0 @@ -#include - -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -#define STR_CHECK(node, expected) \ - std::stringstream ss; \ - CppPrinter printer(&ss); \ - printer.visit(node); \ - ASSERT_EQ(ss.str(), expected) - -#define FILE_CHECK(node, pattern) \ - std::stringstream ss; \ - CppPrinter printer(&ss); \ - printer.visit(node); \ - torch::jit::testing::FileCheck().run(pattern, ss.str()) - -TEST(CppPrinter, IntImm) { - auto i = alloc(10); - STR_CHECK(i, "10"); -} - -TEST(CppPrinter, FloatImm) { - auto f = alloc(10); - STR_CHECK(f, "10.f"); -} - -TEST(CppPrinter, FloatImm1) { - auto f = alloc(10); - STR_CHECK(f, "10.f"); -} - -TEST(CppPrinter, DoubleImm) { - auto d = alloc(10); - STR_CHECK(d, "10.0"); -} - -TEST(CppPrinter, DoubleImm1) { - auto d = alloc(10.1); - STR_CHECK(d, "10.1"); -} - -TEST(CppPrinter, HalfImm) { - auto h = alloc(10); - STR_CHECK(h, "10"); -} - -TEST(CppPrinter, Add) { - auto add = alloc(alloc(1), alloc(2)); - STR_CHECK(add, "1 + 2"); -} - -TEST(CppPrinter, AddExpr1) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc(alloc(2), alloc(3))); - STR_CHECK(add, "(0 + 1) + (2 - 3)"); -} - -TEST(CppPrinter, AddExpr2) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc(alloc(2), alloc(3))); - STR_CHECK(add, "0 * 1 + (2 - 3)"); -} - -TEST(CppPrinter, AddExpr3) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc
(alloc(2), alloc(3))); - STR_CHECK(add, "(0 + 1) + 2 / 3"); -} - -TEST(CppPrinter, Mod) { - auto mod = alloc(alloc(1), alloc(2)); - STR_CHECK(mod, "1 % 2"); -} - -TEST(CppPrinter, ModFloat) { - auto mod = alloc(alloc(1), alloc(2)); - STR_CHECK(mod, "std::fmod(1.f, 2.f)"); -} - -TEST(CppPrinter, Max) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "std::max(1, 2)"); -} - -TEST(CppPrinter, MaxFloat) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "std::max(1.f, 2.f)"); -} - -TEST(CppPrinter, MaxHalf) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "(1 < 2) ? 2 : 1"); -} - -TEST(CppPrinter, And) { - auto v = alloc(alloc(1), alloc(2)); - STR_CHECK(v, "1 & 2"); -} - -TEST(CppPrinter, CompareSelect) { - auto cs = alloc( - alloc(1), - alloc(2), - alloc(1), - alloc(2), - CompareSelectOperation::kLE); - STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)"); -} - -TEST(CppPrinter, IfThenElse) { - auto cond = alloc(alloc(1), alloc(2)); - auto true_value = alloc(alloc(0), alloc(1)); - auto false_value = alloc(alloc(2), alloc(3)); - auto v = alloc(cond, true_value, false_value); - STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)"); -} - -TEST(CppPrinter, AllocateFree) { - BufHandle buf("x", {2, 3}, kInt); - AllocatePtr alloc = Allocate::make(buf); - FreePtr free = Free::make(buf); - BlockPtr block = Block::make({alloc, free}); - - const std::string pattern = R"( - # CHECK: { - # CHECK: int* x = static_cast(malloc(24)); - # CHECK: free(x); - # CHECK: } - )"; - FILE_CHECK(block, pattern); -} - -TEST(CppPrinter, LoadStore) { - BufHandle a("A", {2, 3}, kInt); - BufHandle b("B", {3, 4}, kInt); - auto store = b.store({2, 2}, a.load(1, 1)); - STR_CHECK( - store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n"); -} - -TEST(CppPrinter, Var) { - auto var = alloc("x", kInt); - STR_CHECK(var, "x"); -} - -TEST(CppPrinter, Cast) { - auto cast = alloc(kFloat, alloc(1)); - STR_CHECK(cast, "static_cast(1)"); -} - -TEST(CppPrinter, BitCast) { - auto cast = alloc(kInt, alloc(20)); - STR_CHECK(cast, "std::bitcast(20.f)"); -} - -TEST(CppPrinter, Let) { - auto var = alloc("x", kFloat); - auto val = alloc(2); - auto let = alloc(var, val); - STR_CHECK(let, "float x = 2.f;\n"); -} - -TEST(CppPrinter, For) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); - const std::string pattern = R"( - # CHECK: for (int i = 0; i < 1024; i++) { - # CHECK: C[i] = (A[i]) + (B[i]); - # CHECK: } - )"; - FILE_CHECK(f, pattern); -} - -TEST(CppPrinter, Cond) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = - Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - const std::string pattern = R"( - # CHECK: if (((X[0] < 10) ? 1 : 0)) { - # CHECK: X[0] = (X[0]) + 1; - # CHECK: } else { - # CHECK: X[0] = (X[0]) - 1; - # CHECK: } - )"; - FILE_CHECK(cond, pattern); -} - -TEST(CppPrinter, Intrinsics) { - const std::unordered_set> unsupported_ops{ - kRand, kSigmoid}; - for (const auto i : c10::irange(static_cast(kMaxIntrinsicsOp))) { - IntrinsicsOp op = static_cast(i); - if (unsupported_ops.count(op)) { - continue; - } - - if (Intrinsics::OpArgCount(op) == 1) { - auto v = alloc(op, alloc(2.0f)); - STR_CHECK(v, "std::" + v->func_name() + "(2.f)"); - } else { - auto v = - alloc(op, alloc(1.0f), alloc(2.0f)); - STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)"); - } - } -} - -TEST(CppPrinter, ExternalCall) { - std::vector dims{alloc(2), alloc(2)}; - auto output = alloc("out", dims, kFloat); - auto buf_arg1 = alloc("a", dims, kFloat); - auto buf_arg2 = alloc("b", dims, kFloat); - auto scalar_arg = alloc(alloc(1), alloc(2)); - std::vector buf_args{buf_arg1, buf_arg2}; - std::vector scalar_args{scalar_arg}; - auto call = - alloc(output, "nnc_aten_matmul", buf_args, scalar_args); - const std::string pattern = R"( - # CHECK: { - # CHECK: void* buf_ptrs[]{out, a, b}; - # CHECK: int64_t buf_ranks[]{2, 2, 2}; - # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; - # CHECK: int8_t buf_dtypes[]{6, 6, 6}; - # CHECK: int64_t extra_args[]{1 + 2}; - # CHECK: nnc_aten_matmul( - # CHECK: 3, - # CHECK: buf_ptrs, - # CHECK: buf_ranks, - # CHECK: buf_dims, - # CHECK: buf_dtypes, - # CHECK: 1, - # CHECK: extra_args); - # CHECK: } - )"; - FILE_CHECK(call, pattern); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp deleted file mode 100644 index 2e1e84e758db3..0000000000000 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ /dev/null @@ -1,2344 +0,0 @@ -#ifdef USE_CUDA - -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; -using namespace torch::jit::tensorexpr; - -template -static void testCudaTestVectorAdd01_impl() { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Dtype dtype = ToDtype(); - BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); - BufHandle b_buf("b", {num_iter, block_count, block_size}, dtype); - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); - const int N = block_count * block_size * num_iter; - PaddedBuffer a_v(N); - PaddedBuffer b_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = ctype(i); - b_v(i) = ctype(i * 3 + 7); - c_ref(i) = a_v(i) + b_v(i); - } - - // TODO: move gpu support into PaddedBuffer - ctype* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(ctype))); - ctype* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(ctype))); - ctype* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(ctype))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -float sigmoid(float x) { - return 1.0f / (1.0f + expf(-0.0f - x)); -} - -TEST(Cuda, Sigmoid_CUDA) { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Dtype dtype = ToDtype(); - BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf); - const int N = block_count * block_size * num_iter; - PaddedBuffer a_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = float(i); - c_ref(i) = sigmoid(sigmoid(a_v(i))); - } - - // TODO: move gpu support into PaddedBuffer - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, TestVectorAdd01_CUDA) { - // floating types. - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - - // integer types. - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); -} - -static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) { - BufHandle a_buf("a", {N}, kFloat); - BufHandle b_buf("b", {N}, kFloat); - Tensor c = Compute("c", {N}, [&](const VarHandle& n) { - return a_buf.load(n) + b_buf.load(n); - }); - LoopNest l({c}); - ForPtr n_inner; - std::vector loops = l.getLoopStmtsFor(c); - l.splitWithMask(loops[0], block_size, &n_inner); - loops[0]->set_gpu_block_index(0); - n_inner->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); - PaddedBuffer a_v(N); - PaddedBuffer b_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = i; - b_v(i) = i * 3 + 7; - c_ref(i) = a_v(i) + b_v(i); - } - - // TODO: move gpu support into PaddedBuffer - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, TestVectorAdd02_CUDA) { - testCudaTestVectorAdd02_impl(1024, 128); - testCudaTestVectorAdd02_impl(1030, 128); -} - -TEST(Cuda, HalfCast_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor b = Compute("b", {4}, [&](const VarHandle& i) { - return Cast::make(kFloat, a.load(i)); - }); - - LoopNest l({b}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b}); - - std::vector aData(4, 2.0f); - std::vector bData(4, 0.0f); - at::Half* aDev = nullptr; - float* bDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = bData.size() * sizeof(bData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(bData.data(), bDev, bSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(bData, 2.0f); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); -} - -TEST(Cuda, DynamicShape2D_CUDA) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, c, m, n}); - - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - float* aDev = nullptr; - float* bDev = nullptr; - float* cDev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); - C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); - C10_CUDA_CHECK(cudaMalloc(&cDev, cData.size() * sizeof(cData[0]))); - C10_CUDA_CHECK(cudaMemcpy( - aDev, - aData.data(), - aData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - bDev, - bData.data(), - bData.size() * sizeof(bData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - cDev, - cData.data(), - cData.size() * sizeof(cData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, cDev, M, N}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy( - cData.data(), - cDev, - cData.size() * sizeof(cData[0]), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(cDev)); - }; - testWithSize(32, 32); - testWithSize(1, 16); - testWithSize(27, 13); -} - -TEST(Cuda, TestRand01_CUDA) { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return Intrinsics::make(IntrinsicsOp::kRand, kFloat); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c); - const int N = block_count * block_size * num_iter; - PaddedBuffer c_v(N); - - // TODO: move gpu support into PaddedBuffer - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - float sum1 = 0; - float sum2 = 0; - float sum3 = 0; - for (const auto i : c10::irange(N)) { - float v = c_v.data()[i]; - sum1 += v; - sum2 += v * v; - sum3 += v * v * v; - ASSERT_TRUE(v >= 0 && v < 1); - } - sum1 /= N; - sum2 /= N; - sum3 /= N; - float sum1_mean = 1.f / 2; - float sum2_mean = 1.f / 3; - float sum3_mean = 1.f / 4; - - ASSERT_NEAR(sum1, sum1_mean, 2e-2); - ASSERT_NEAR(sum2, sum2_mean, 2e-2); - ASSERT_NEAR(sum3, sum3_mean, 2e-2); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, DynamicShapeSplit_CUDA) { - constexpr int64_t N = 4096; - VarHandle n("n", kLong); - BufHandle a("a", {n}, kFloat); - Tensor b = - Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); - LoopNest l({b}); - ForPtr inner; - std::vector loops = l.getLoopStmtsFor(b); - l.splitWithMask(loops[0], 1024, &inner); - loops[0]->set_gpu_block_index(0); - inner->set_gpu_thread_index(0); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, n}); - - std::vector aData(N, 1.0f); - std::vector bData(N, 1.0f); - float* aDev = nullptr; - float* bDev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); - C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); - C10_CUDA_CHECK(cudaMemcpy( - aDev, - aData.data(), - aData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - bDev, - bData.data(), - bData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, N}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy( - bData.data(), - bDev, - bData.size() * sizeof(aData[0]), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(bData, std::vector(N, 2.0f), 1e-7); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); -} - -TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { - const static int N = 1024; - BufHandle data_buf("data", {N}, kFloat); - BufHandle output_buf("output", {1}, kFloat); - - // The test adds the following code for trivial reduction: - // for (const auto bidx : c10::irange(1)) { // blockIdx.x - // for (const auto tidx : c10::irange(1)) { // threadIdx.x - // output[0] = 0.f; - // for (const auto i1 : c10::irange(1024)) { - // output[0] = output[0] + data[i1]; - // } - // } - // } - - StorePtr init_store = output_buf.store({0}, 0.f); - VarHandle i1("i1", kInt); - ExprHandle load_data = Load::make(data_buf, {i1}); - ExprHandle load_output = Load::make(output_buf, {0}); - ExprHandle add_value = load_output + load_data; - StorePtr store_output = output_buf.store({0}, add_value); - ForPtr for_output = For::make(i1, 0, N, store_output); - StmtPtr reduce_block = Block::make({init_store, for_output}); - VarHandle thread_idx("tidx", kInt); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr thread_idx_loop = - For::make(thread_idx, 0, 1, reduce_block, thread_idx_options); - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf); - PaddedBuffer data_v(N); - PaddedBuffer output_v(1, "output_v"); - PaddedBuffer output_ref(1, "output_ref"); - - output_ref(0) = 0; - for (const auto i : c10::irange(N)) { - data_v(i) = i; - output_ref(0) += data_v(i); - } - - float* data_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&data_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - data_dev, data_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - float* output_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&output_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(data_dev, output_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - output_v.data(), output_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(output_v, output_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(data_dev)); - C10_CUDA_CHECK(cudaFree(output_dev)); -} - -TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { - const static int N = 1024; - - // This test does the following reduction: - // clang-format off - // for b in 0..1 // block-idx - // for t in 0..1024: // thread-idx - // if t < 1: - // b[0] = 0 - // // implied sync_threads - // for t in 0..1024: // thread-idx - // b[0] = b[0] + a[t] // implied atomic - // clang-format on - - BufHandle a_buf("a", {N}, kFloat); - BufHandle b_buf("b", {1}, kFloat); - - StorePtr init_store = b_buf.store({0}, 0.f); - VarHandle t("t", kInt); - VarHandle b("b", kInt); - - // for t in 0..1024: // thread-idx - // if t < 1: - // b[0] = 0 - ExprHandle cond_t_lt_1 = - CompareSelect::make(t, 1, CompareSelectOperation::kLT); - CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); - - // for t in 0..1024: // thread-idx - // b[0] = b[0] + a[t] // implied atomic - ExprHandle load_a = Load::make(a_buf, {t}); - ExprHandle load_b = Load::make(b_buf, {0}); - ExprHandle add_value = load_b + load_a; - StorePtr store_b = b_buf.store({0}, add_value); - ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options); - - StmtPtr reduce_block = Block::make({for_init, for_b}); - - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, reduce_block, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); - PaddedBuffer a_v(N); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(N)) { - a_v(i) = i; - b_ref(0) += a_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, NoThreadIdxWrite_1_CUDA) { - // This test does the following reduction: - // - // for k in 0..1: // block-idx - // a[0] = 0 - // for n in 0..2: - // a[0] = a[0] + n - // for m in 0..1024: // thread-idx - // b[m] = m - // a[1] = 1 - // for l in 0..2: - // a[1] = a[1] + n - // - // note that the statements not covered by thread-idx are supposed to be - // covered by its own thread-idx - - const static int N = 1024; - BufHandle a_buf("a", {2}, kFloat); - BufHandle b_buf("b", {N}, kFloat); - - VarHandle k("k", kInt); - VarHandle l("l", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - // a[0] = 0 - // for n in 0..2: - // a[0] = a[0] + n - StorePtr store_a0_0 = a_buf.store({0}, 0.f); - ExprHandle load_a0 = Load::make(a_buf, {0}); - ExprHandle v1 = load_a0 + n; - StorePtr store_a0_v1 = a_buf.store({0}, v1); - ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1); - - // for m in 0..1024: // thread-idx - // b[m] = m - StorePtr store_bm_m = b_buf.store({m}, m + 0.f); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); - - // a[1] = 1 - // for l in 0..2: - // a[1] = a[1] + l - StorePtr store_a1_1 = a_buf.store({1}, 1.f); - ExprHandle load_a1 = a_buf.load(1); - ExprHandle v2 = load_a1 + l; - StorePtr store_a1_v2 = a_buf.store({1}, v2); - ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2); - - StmtPtr reduce_block = - Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1}); - - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, reduce_block, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); - PaddedBuffer a_v(2); - PaddedBuffer b_v(N, "b_v"); - PaddedBuffer a_ref(2, "a_ref"); - PaddedBuffer b_ref(N, "b_ref"); - - a_ref(0) = 0; - for (const auto i : c10::irange(2)) { - a_ref(0) += i; - } - a_ref(1) = a_ref(0) + 1; - for (const auto i : c10::irange(N)) { - b_ref(i) = i; - } - - // TODO: add check of the generated code. - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, 2 * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(a_v.data(), a_dev, 2 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(a_v, a_ref, 1e-5); - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, SharedMemReduce_1_CUDA) { - // FIXME: this test is flaky in CI. - // This test does the following: - // for k in 0..1: // block-idx - // alloc(c, 64) - // for n in 0..64: // thread-idx - // c(n) = 0 - // for m in 0..128: - // for n in 0..64: // thread_idx - // c(n) = c(n) + a(k, m, n) - // b(k) = 0 - // for n in 0..64: // thread_idx - // b(k) = b(k) + c(n) - // free(c) - - const int M = 128; - const int N = 64; - const int kTotalSize = M * N; - LoopOptions thread_idx_opt; - thread_idx_opt.set_gpu_thread_index(0); - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - BufHandle a("a", {1, M, N}, kFloat); - BufHandle b("b", {1}, kFloat); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - std::vector block; - std::vector dims; - dims.push_back(ExprHandle(N).node()); - BufHandle c{alloc("c", dims, kFloat)}; - { - // alloc(c, 64); - AllocatePtr alloc = Allocate::make(c); - block.push_back(alloc); - } - - { - // for n in 0..64: // thread-idx - // c(n) = 0 - StorePtr store_cn_0 = Store::make(c, {n}, 0.f); - ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); - block.push_back(loop_n1); - } - - { - // for m in 0..128: - // for n in 0..64: // thread_idx - // c(n) = c(n) + a(k, m, n) - ExprHandle load_cn = Load::make(kFloat, c, {n}); - ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}); - ExprHandle v_add = load_cn + a_kmn; - StorePtr store_cn_v = Store::make(c, {n}, v_add); - ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); - ForPtr loop_m1 = For::make(m, 0, M, loop_n2); - block.push_back(loop_m1); - } - - { - // b(k) = 0 - // for n in 0..64: // thread_idx - // b(k) = b(k) + c(n) - StorePtr store_bk_0 = b.store({k}, 0.f); - block.push_back(store_bk_0); - ExprHandle load_bk = b.load(k); - ExprHandle load_cn = Load::make(kFloat, c, {n}); - ExprHandle v_add = load_bk + load_cn; - StorePtr store_bk = b.store({k}, v_add); - ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); - block.push_back(loop_n3); - } - - { - // free(c) - FreePtr free_stmt = Free::make(c); - block.push_back(free_stmt); - } - - BlockPtr reduce_body = Block::make(block); - ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); - - // TODO: check the generated code for correctness. - CudaCodeGen cuda_cg(loop_k1, a, b); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is not masked, but the d write is. - const std::string& verification_pattern = - R"IR( -# CHECK: c_1 = 0 -# CHECK: for (int m = 0; m < 128 -# CHECK: c_1 = c_1 + -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<1 -# CHECK: b[blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: atomicAdd(&b[blockIdx.x], c_1) -)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, LocalMemReduce_1_CUDA) { - // This test does the following: - // for k in 0..1: // block-idx - // b(k) = 0 - // for n in 0..64: // thread-idx - // alloc(c, 1) - // c(0) = 0 - // for m in 0..128: - // c(0) = c(0) + a(k, m, n) - // b(k) = b(k) + c(0) - // free(c) - - const int M = 128; - const int N = 64; - const int kTotalSize = M * N; - LoopOptions thread_idx_opt; - thread_idx_opt.set_gpu_thread_index(0); - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - BufHandle a("a", {1, M, N}, kFloat); - BufHandle b("b", {1}, kFloat); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle c{ - alloc("c", std::vector({alloc(1)}), kFloat)}; - std::vector block_k; - { - // b(k) = 0 - StorePtr store_bk_0 = b.store({k}, 0.f); - block_k.push_back(store_bk_0); - } - std::vector block_n; - { - // alloc(c, 1); - AllocatePtr alloc = Allocate::make(c); - block_n.push_back(alloc); - } - { - // c(0) = 0 - StorePtr store_c0_0 = Store::make(c, {0}, 0.f); - block_n.push_back(store_c0_0); - } - { - // for m in 0..128: - // c(0) = c(0) + a(k, m, n) - ExprHandle load_c0 = Load::make(kFloat, c, {0}); - ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); - ExprHandle v_add = load_c0 + a_kmn; - StorePtr store_c0_v = Store::make(c, {0}, v_add); - ForPtr loop_m = For::make(m, 0, M, store_c0_v); - block_n.push_back(loop_m); - } - { - // b(k) = b(k) + c(0) - ExprHandle load_bk = b.load(k); - ExprHandle load_c0 = Load::make(kFloat, c, {0}); - ExprHandle v_add = load_bk + load_c0; - StorePtr store_bk = b.store({k}, v_add); - block_n.push_back(store_bk); - } - { - // free(c) - FreePtr free_stmt = Free::make(c); - block_n.push_back(free_stmt); - } - { - BlockPtr block_n_stmt = Block::make(block_n); - ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); - block_k.push_back(for_n); - } - BlockPtr block_k_stmt = Block::make(block_k); - ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); - - CudaCodeGen cuda_cg(loop_k, a, b); - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, HalfSupport_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor b = Compute("b", {4}, [&](const VarHandle& i) { - return Cast::make(half, ExprHandle(2.0f) * a.load(i)); - }); - - Tensor c = Compute("c", {4}, [&](const VarHandle& i) { - return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i)); - }); - - Tensor d = Compute("d", {4}, [&](const VarHandle& i) { - return Cast::make(half, c.load(i)); - }); - - LoopNest l({b, c, d}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, c, d}); - - std::vector aData(4, 2.0f); - std::vector cData(4, 0.0f); - std::vector dData(4, 0.0f); - at::Half* aDev = nullptr; - at::Half* bDev = nullptr; - at::Half* cDev = nullptr; - at::Half* dDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = aData.size() * sizeof(aData[0]); - auto cSize = cData.size() * sizeof(float); - auto dSize = dData.size() * sizeof(dData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMalloc(&cDev, cSize)); - C10_CUDA_CHECK(cudaMalloc(&dDev, dSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(cDev, cData.data(), cSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(dDev, dData.data(), dSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, cDev, dDev}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(cData.data(), cDev, cSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(dData.data(), dDev, dSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(cData, 46.0f); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(cDev)); - C10_CUDA_CHECK(cudaFree(dDev)); -} - -TEST(Cuda, HalfPropagation_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { - return Max::make(a.load(i), ExprHandle(alloc(0)), true); - }); - - LoopNest l({relu}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, relu}); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the types used by the Max are Float. - const std::string& verification_pattern = - R"IR( -# CHECK: for ( -# CHECK: float v = float(a[i]); -# CHECK: relu[i] = half(Max(v, 0.f -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector aData(4, 2.0f); - std::vector reluData(4, 0.0f); - at::Half* aDev = nullptr; - at::Half* reluDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto reluSize = reluData.size() * sizeof(reluData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, reluDev}); - C10_CUDA_CHECK( - cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(aData, reluData); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(reluDev)); -} - -TEST(Cuda, UnusedHalfArgument_CUDA) { - BufHandle a("a", {4}, kFloat); - auto half = ToDtype(); - BufHandle b("b", {4}, half); - Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { - return Max::make(a.load(i), ExprHandle(alloc(0)), true); - }); - - LoopNest l({relu}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, relu}); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the types used by the Max are Float. - const std::string& verification_pattern = - R"IR( -# CHECK: for ( -# CHECK: float v = a[i]; -# CHECK: relu[i] = Max(v, 0.f -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // Sanity Cbeck; - std::vector aData(4, 2.0f); - std::vector bData(4, 2.0f); - std::vector reluData(4, 0.0f); - at::Half* aDev = nullptr; - at::Half* bDev = nullptr; - at::Half* reluDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = bData.size() * sizeof(bData[0]); - auto reluSize = reluData.size() * sizeof(reluData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, reluDev}); - C10_CUDA_CHECK( - cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(aData, reluData); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(reluDev)); -} - -TEST(Cuda, PrioritizeDependents_CUDA) { - BufHandle a("a", {10}, kFloat); - BufHandle b("b", {12}, kFloat); - BufHandle c("c", {12}, kFloat); - - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - /* - * for (const auto i : c10::irange(12)) { - * c[i] = (i < 10 ? a[i] + b[i] : b[i]); - * } - */ - ExprHandle load_a = a.load({i}); - ExprHandle load_b = b.load({i}); - ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); - ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); - - ForPtr loop = - For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt); - - CudaCodeGen cuda_cg(loop, a, b, c); - - PaddedBuffer a_v(10, "a_v"); - PaddedBuffer b_v(12, "b_v"); - PaddedBuffer c_v(12, "c_v"); - PaddedBuffer c_ref(12, "c_ref"); - - for (const auto i : c10::irange(10)) { - a_v(i) = i * 100; - b_v(i) = i; - c_v(i) = 0; - } - - for (const auto i : c10::irange(10, 12)) { - b_v(i) = i; - c_v(i) = 0; - } - - float* a_dev = nullptr; - float* b_dev = nullptr; - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, 10 * sizeof(float))); - C10_CUDA_CHECK(cudaMalloc(&b_dev, 12 * sizeof(float))); - C10_CUDA_CHECK(cudaMalloc(&c_dev, 12 * sizeof(float))); - - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), 10 * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), 12 * sizeof(float), cudaMemcpyHostToDevice)); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, 12 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - for (const auto i : c10::irange(12)) { - if (i < 10) { - c_ref(i) = i + i * 100; - } else { - c_ref(i) = i; - } - } - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -/// Tests the case where there are two loops which have different extents bound -/// to the same block dimension. We must mask the smaller extent loop body. -TEST(Cuda, MaskBlockDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is not masked, but the d write is. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if (blockIdx -# CHECK: c[blockIdx.x] = -# CHECK: if (blockIdx.x<50 -# CHECK: d[blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(1))); - - // Sanity check that the kernel works. - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case with two loops, which have different extents that are bound -/// to the same thread dimension. This is the same as the above - the smaller -/// rank write should be masked. But this time we also need to syncthreads. -TEST(Cuda, MaskThreadDim_CUDA) { - int A_SIZE = 50; - int B_SIZE = 100; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i / 2) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is masked, but the d write is not. - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x<50 -# CHECK: c[threadIdx.x] = -# CHECK: __syncthreads(); -# CHECK-NOT: if (threadIdx.x -# CHECK: d[threadIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i / 2) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where there are two loops, and each is bound to a different -/// block dimension. In this case all writes should be masked since they occur -/// in distinct dimensions. -// Note: this is an extremely dumb pattern which we should never see, but is a -// useful edge case to make sure we've got things covered. -TEST(Cuda, MaskMultiBlockDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(1); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Write to c should be masked against y, write to d against x. - const std::string& verification_pattern = - R"IR( -# CHECK: if (blockIdx.y<1 -# CHECK: c[blockIdx.x] = -# CHECK: if (blockIdx.x<1 -# CHECK: d[blockIdx.y] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(blockExtents[1], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where both the blockDim and threadDim are bound to different -/// loops. In this instance both stores should be masked since they are -/// distinct. -// Note: this is an extremely dumb pattern which we should never see, but is a -// useful edge case to make sure we've got things covered. -TEST(Cuda, MaskBlockAndThreadDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x<1 -# CHECK: c[blockIdx.x] = -# CHECK: } -# CHECK: if (blockIdx.x<1 -# CHECK: d[threadIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where the loopnest has two loops of depth two: each with the -/// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In -/// this case all writes with a rank smaller than the max should be masked. -TEST(Cuda, MaskMultiDim_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: C[threadIdx.x + 100 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case where loop extents are symbolic and not known at compile time. -// In this case both stores must be masked against the extent of the other loop, -// in case it is larger. -TEST(Cuda, MaskMultiDimSymbolic_CUDA) { - VarHandle OUTER_SIZE("OUTER_SIZE", kLong); - VarHandle A_SIZE("A_SIZE", kLong); - VarHandle B_SIZE("B_SIZE", kLong); - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Since we don't know which is bigger (A_SIZE or B_SIZE) we must mask both. - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x(A_SIZE.node(), B_SIZE.node(), true))); - - int64_t OUTER_EXTENT = 10; - int64_t A_EXTENT = 100; - int64_t B_EXTENT = 50; - - PaddedBuffer a_v(OUTER_EXTENT, A_EXTENT); - PaddedBuffer b_v(OUTER_EXTENT, B_EXTENT); - PaddedBuffer c_v(OUTER_EXTENT, A_EXTENT); - PaddedBuffer d_v(OUTER_EXTENT, B_EXTENT); - - PaddedBuffer c_ref(OUTER_EXTENT, A_EXTENT); - PaddedBuffer d_ref(OUTER_EXTENT, B_EXTENT); - - for (const auto o : c10::irange(OUTER_EXTENT)) { - for (const auto i : c10::irange(A_EXTENT)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_EXTENT)) { - for (const auto i : c10::irange(B_EXTENT)) { - b_v(o, i) = (float)(B_EXTENT - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_EXTENT * A_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_EXTENT * B_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_EXTENT * A_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_EXTENT * B_EXTENT * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, OUTER_EXTENT, A_EXTENT, B_EXTENT, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_EXTENT * A_EXTENT * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_EXTENT * B_EXTENT * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case where two loops are fused at a common parent loop, which is -// bound to the block dimension. Internally the inner loops have different -// extents but are bound to the same thread dimension. The smaller loop should -// be masked. -TEST(Cuda, MaskCompoundInnerLoop_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); - - // Can't build this using Compute and transforms yet. - LoopOptions blockBound; - blockBound.set_gpu_block_index(0); - LoopOptions threadBound; - threadBound.set_gpu_thread_index(0); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - - StmtPtr stmt = For::make( - i, - 0, - OUTER_SIZE, - Block::make( - {For::make( - j, - 0, - A_SIZE, - c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), - threadBound), - For::make( - k, - 0, - B_SIZE, - d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), - threadBound)}), - blockBound); - - stmt = FlattenIndexes(stmt); - stmt = IRSimplifier::simplify(stmt); - - CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: c[threadIdx.x + 100 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev, d_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loops fused into a common parent, which is not bound -// to any block or thread dimension - however it's two inner loops are bound to -// the first thread dimensions. This should work just like the MaskThreadDim -// test where the bigger loop is unmasked but the smaller is masked. -TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); - - // Can't build this using Compute and transforms yet. - LoopOptions blockBound; - blockBound.set_gpu_block_index(0); - LoopOptions threadBound; - threadBound.set_gpu_thread_index(0); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - - StmtPtr stmt = For::make( - i, - 0, - OUTER_SIZE, - Block::make( - {For::make( - j, - 0, - A_SIZE, - c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), - threadBound), - For::make( - k, - 0, - B_SIZE, - d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), - threadBound)})); - - stmt = FlattenIndexes(stmt); - stmt = IRSimplifier::simplify(stmt); - - CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The other loop remains the D write is masked. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 10 -# CHECK-NOT: if ( -# CHECK: c[threadIdx.x + 100 * i] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: d[threadIdx.x + 50 * i] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev, d_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loop nests, each of which bound to the same block -// size, but with internal loops bound to different thread rank (ie x and y). In -// this case both bodies must be masked against the other dimension being > 0. -// Note: this is a bit degenerate no one would actually write this for perf. -TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 30; - int B_SIZE = 15; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(1); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Both stores masked against the other thread dim < 1. - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.y<1 -# CHECK: C[threadIdx.x + 30 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<1 -# CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loop nests, each bound to both Block and Thread but -// the second loop is smaller in both cases - the second store must be masked -// for both the block and thread dimension. -TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { - int OUTER_A_SIZE = 10; - int OUTER_B_SIZE = 5; - int A_SIZE = 30; - int B_SIZE = 15; - BufHandle a_buf("a", {OUTER_A_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_B_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked twice, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: C[threadIdx.x + 30 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (blockIdx.x<5 -# CHECK: if (threadIdx.x<15 -# CHECK: D[threadIdx.x + 15 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_A_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_B_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_A_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_B_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_A_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_B_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_A_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_B_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -} // namespace jit -} // namespace torch - -#endif diff --git a/test/cpp/tensorexpr/test_dynamic_shapes.cpp b/test/cpp/tensorexpr/test_dynamic_shapes.cpp deleted file mode 100644 index 07b9872fb8325..0000000000000 --- a/test/cpp/tensorexpr/test_dynamic_shapes.cpp +++ /dev/null @@ -1,701 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -TEST(DynamicShapes, SimpleGraph) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Tensor, - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - return (%4))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto x_type = TensorType::create(at::rand({10, 5})); - std::vector x_sym_dims( - {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); - auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); - graph->inputs().at(0)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-2), SS(-3)), - // %SS_2 : int, - // %SS_3 : int): - // %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x) - // %4 : Float(SS(-2), SS(-3)) = aten::erf(%3) - // return (%4) - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - std::vector symbolic_shape_inputs = c10::fmap( - x_sym_dims, - [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::erf(at::tanh(a)); - - std::vector stack = fmap(std::vector({a})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::erf(at::tanh(a)); - - std::vector stack = fmap(std::vector({a})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWith2InputsSameDims) { -#ifdef TORCH_ENABLE_LLVM - // The two inputs in this graph must have the same dims. - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Tensor, - %y : Tensor, - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({10, 5})); - std::vector x_sym_dims( - {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); - auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-4), SS(-5)), - // %y : Float(SS(-4), SS(-5)), - // %SS_2 : int, - // %SS_3 : int): - // %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x) - // %5 : Float(SS(-4), SS(-5)) = aten::erf(%4) - // %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y) - // return (%6) - - std::vector symbolic_shape_inputs = c10::fmap( - x_sym_dims, - [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWith2InputsAndBroadcast) { -#ifdef TORCH_ENABLE_LLVM - // The second input to the graph has a dim of size 1 which should be - // broadcasted in the at::mul op. - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(10, 5, requires_grad=0, device=cpu), - %y : Float(1, 5, requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({10, 5})); - auto y_type = TensorType::create(at::rand({1, 5})); - auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes( - std::vector({x_dim0_sym, x_dim1_sym})); - auto y_sym_type = y_type->withSymbolicShapes(std::vector( - {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(y_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-6), SS(-7)), - // %y : Float(1, SS(-7)), - // %SS_2 : int, - // %SS_3 : int): - // %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x) - // %5 : Float(SS(-6), SS(-7)) = aten::erf(%4) - // %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y) - // return (%6) - - std::vector symbolic_shape_inputs( - {x_dim0_sym.value(), x_dim1_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) { -#ifdef TORCH_ENABLE_LLVM - // The second input to the graph has a dim of size 1 which should be - // broadcasted in the at::mul op. - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(1, 5, requires_grad=0, device=cpu), - %y : Float(1, 5, requires_grad=0, device=cpu), - %SS_2 : int): - %4 : Tensor = aten::tanh(%x) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({1, 5})); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes(std::vector( - {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(1, SS(-2)), - // %y : Float(1, SS(-2)), - // %SS_2 : int): - // %3 : Float(1, SS(-2)) = aten::tanh(%x) - // %4 : Float(1, SS(-2)) = aten::mul(%3, %y) - // return (%4) - - std::vector symbolic_shape_inputs({x_dim1_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::tanh(a), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::tanh(a), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithSymbolicStrides) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %SS_3 : int, - %SS_2 : int): - %15 : int = prim::Constant[value=1]() - %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15) - %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0) - return (%22))IR"; - parseIR(graph_string, &*graph); - - std::vector input_desc = { - torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE}; - std::vector output_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = output_desc; - std::vector symbolic_shape_inputs = {-3, -2}; - TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - { - auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(x0, x1, 1), x0); - - std::vector inputs = {x0, x1}; - std::vector stack = at::fmap(inputs); - stack.push_back(32); - stack.push_back(10); - k.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - { - auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto out = - at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(x0, x1, 1), x0); - - std::vector inputs = {out, x0, x1}; - std::vector stack = at::fmap(inputs); - stack.push_back(32); - stack.push_back(10); - k.runWithAllocatedOutputs(stack); - - ASSERT_TRUE(at::allclose(out, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithCatAndBroadcast) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(10, 5, requires_grad=0, device=cpu), - %y : Float(4, 5, requires_grad=0, device=cpu), - %z : Float(1, 1, requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int): - %11 : int = prim::Constant[value=0]() - %3 : Tensor = aten::tanh(%x) - %out1 : Tensor = aten::erf(%3) - %out2 : Tensor = aten::relu(%y) - %10 : Tensor[] = prim::ListConstruct(%out1, %out2) - %25 : Tensor = aten::cat(%10, %11) - %28 : Tensor = aten::hardswish(%25) - %29 : Tensor = aten::mul(%28, %z) - return (%29))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto z_inp = graph->inputs()[2]; - auto x_type = TensorType::create(at::rand({10, 5})); - auto y_type = TensorType::create(at::rand({4, 5})); - auto z_type = TensorType::create(at::rand({1, 1})); - auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes( - std::vector({x_dim0_sym, x_dim1_sym})); - auto y_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto y_sym_type = y_type->withSymbolicShapes( - std::vector({y_dim0_sym, x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(y_sym_type); - auto cat_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto cat_out_type = x_type->withSymbolicShapes( - std::vector({cat_dim0_sym, x_dim1_sym})); - auto nodeIt = graph->nodes().begin(); - ++nodeIt; - nodeIt->output()->setType(x_sym_type); // aten::tanh - ++nodeIt; - nodeIt->output()->setType(x_sym_type); // aten::erf - ++nodeIt; - nodeIt->output()->setType(y_sym_type); // aten::relu - ++nodeIt; - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::cat - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::hardswish - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::mul - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-2), SS(-3)), - // %y : Float(SS(-4), SS(-3)), - // %z : Float(1, 1), - // %SS_2 : int, - // %SS_3 : int, - // %SS_4 : int, - // %SS_5 : int): - // %7 : int = prim::Constant[value=0]() - // %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x) - // %9 : Float(SS(-2), SS(-3)) = aten::erf(%8) - // %10 : Float(SS(-4), SS(-3)) = aten::relu(%y) - // %11 : Tensor[] = prim::ListConstruct(%9, %10) - // %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7) - // %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12) - // %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z) - // return (%14) - - std::vector symbolic_shape_inputs( - {x_dim0_sym.value(), - x_dim1_sym.value(), - y_dim0_sym.value(), - cat_dim0_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[z_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul( - at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c); - - std::vector stack = fmap(std::vector({a, b, c})); - stack.push_back(10); - stack.push_back(5); - stack.push_back(4); - stack.push_back(14); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -#endif -} - -TEST(DynamicShapes, GraphFromModel) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu), - %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu), - %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu), - %4 : Float(SS(-7), requires_grad=0, device=cpu), - %5 : Float(SS(-7), requires_grad=0, device=cpu), - %SS_10 : int, - %SS_9 : int, - %SS_8 : int, - %SS_7 : int, - %SS_6 : int, - %SS_5 : int, - %SS_4 : int, - %SS_3 : int, - %SS_2 : int): - %15 : int = prim::Constant[value=1]() - %16 : bool = prim::Constant[value=0]() - %17 : int = prim::Constant[value=6]() - %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16) - %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2) - %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15) - %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15) - %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4) - return (%22))IR"; - parseIR(graph_string, &*graph); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->inputs().at(3)] = input_desc; - symbolic_strides[graph->inputs().at(4)] = input_desc; - symbolic_strides[graph->inputs().at(5)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - std::vector symbolic_shape_inputs = { - -10, -9, -8, -7, -6, -5, -4, -3, -2}; - TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - int64_t i2 = 10; - int64_t i3 = 32; - int64_t i4 = 19; - int64_t i5 = 71; - int64_t i6 = 139; - int64_t i7 = 261; - int64_t i8 = 261; - int64_t i9 = 261; - int64_t i10 = 261; - auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong)); - auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4); - - { - std::vector inputs = {x0, x1, x2, x3, x4, x5}; - std::vector stack = at::fmap(inputs); - stack.emplace_back(i10); - stack.emplace_back(i9); - stack.emplace_back(i8); - stack.emplace_back(i7); - stack.emplace_back(i6); - stack.emplace_back(i5); - stack.emplace_back(i4); - stack.emplace_back(i3); - stack.emplace_back(i2); - k.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - { - auto out = - at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - std::vector inputs = {out, x0, x1, x2, x3, x4, x5}; - std::vector stack = at::fmap(inputs); - stack.emplace_back(i10); - stack.emplace_back(i9); - stack.emplace_back(i8); - stack.emplace_back(i7); - stack.emplace_back(i6); - stack.emplace_back(i5); - stack.emplace_back(i4); - stack.emplace_back(i3); - stack.emplace_back(i2); - k.runWithAllocatedOutputs(stack); - - ASSERT_TRUE(at::allclose(out, ref)); - } -#endif -} - -TEST(DynamicShapes, MultiThreadedExecution) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_template = R"IR( - graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), - %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), - %SS_2 : int, - %SS_3 : int): - %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x) - %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3) - %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y) - return (%5))IR"; - for (bool use_cuda : {false, true}) { - if (!torch::cuda::is_available() && use_cuda) { - continue; - } - auto device = use_cuda ? at::kCUDA : at::kCPU; - at::jit::TemplateEnv env; - env.s("device", use_cuda ? "cuda:0" : "cpu"); - const auto graph_string = format(graph_template, env); - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - auto run_kernel = [&](int dim1, int dim2) { - auto a = - at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); - auto b = - at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); - - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - // Run the kernel in parallel to ensure that the run() method calls in - // TensorExprKernel are not changing any state. - constexpr size_t kNumThreads = 4; - std::vector threads; - for (size_t id = 0; id < kNumThreads; ++id) { - threads.emplace_back(run_kernel, id + 5, id + 20); - } - for (auto& t : threads) { - t.join(); - } - } -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp deleted file mode 100644 index eb2d6296b2299..0000000000000 --- a/test/cpp/tensorexpr/test_expr.cpp +++ /dev/null @@ -1,836 +0,0 @@ -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -using SimpleIRExprEval = ExprEval; - -TEST(Expr, BasicValueTest) { - ExprHandle a = IntImm::make(2), b = IntImm::make(3); - ExprHandle c = Add::make(a, b); - SimpleIRExprEval eval(c); - ASSERT_EQ(eval.value(), 5); -} - -TEST(Expr, BasicValueTest02) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - SimpleIRExprEval eval(f); - ASSERT_EQ(eval.value(), -4.0f); -} - -TEST(Expr, IsChannelsLastContiguous) { - std::vector vars = { - VarHandle("var1", kLong), - VarHandle("var2", kLong), - VarHandle("var3", kLong), - VarHandle("var4", kLong), - VarHandle("var5", kLong)}; - - // { - // key: ndims, - // value: [ - // ... - // [dim_2, dim_1, ..., dim_n] - // ] - // } - using shapGenInfo = std::unordered_map>>; - - // { - // size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n], - // strides: [ - // ... - // [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z] - // ] - // } - using shapeInfo = - std::pair, std::vector>>; - - std::vector dims = {3, 4, 5}; - - std::unordered_map> dims_expr_vec_conf = { - {3, std::vector(vars.begin(), vars.begin() + 2)}, - {4, std::vector(vars.begin(), vars.begin() + 3)}, - {5, std::vector(vars.begin(), vars.begin() + 4)}, - }; - - shapGenInfo channels_last_cont_shape_conf = { - {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}}; - shapGenInfo channels_last_non_cont_shape_conf = { - {3, {{2, 1, 0}, {1, 0, 2}}}, - {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}}, - {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}}; - - shapGenInfo cont_shape_conf = { - {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}}; - - auto shape_gen_fn = [dims_expr_vec_conf]( - int ndims, shapGenInfo shape_gen_info) -> shapeInfo { - auto dims_expr_vec = dims_expr_vec_conf.at(ndims); - std::vector> strides_expr_vec; - for (size_t i = 0; i < strides_expr_vec.size(); i++) { - strides_expr_vec[i].resize(ndims); - } - - auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) { - if (indicator % 2 == 0) { - return a * b; - } else { - return b * a; - } - }; - - auto stride_order_vec = shape_gen_info.at(ndims); - for (size_t i = 0; i < strides_expr_vec.size(); i++) { - auto stride_order = stride_order_vec[i]; - - strides_expr_vec[i][stride_order[0]] = 1; - for (size_t j = 1; j < stride_order.size(); j++) { - auto cur_dim_idx = stride_order[j]; - auto adjacent_dim_idx = stride_order[j - 1]; - - strides_expr_vec[i][cur_dim_idx] = stride_gen_fn( - i, - dims_expr_vec[adjacent_dim_idx], - strides_expr_vec[i][adjacent_dim_idx]); - } - } - - return {dims_expr_vec, strides_expr_vec}; - }; - - auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool { - if (ndims == 3) { - return buf_handle.is_channels_last_1d_contiguous(); - } else if (ndims == 4) { - return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast); - } else { - return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d); - } - }; - - // channels-last contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true); - } - } - - // channels-last non-contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false); - } - } - - // contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(buf_handle.is_contiguous(), true); - } - } - - // non-contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(buf_handle.is_contiguous(), false); - } - } -} - -TEST(Expr, LetTest01) { - VarHandle x("x", kFloat); - ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, LetTest02) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = - ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - eval.bindVar(y, ExprHandle(6.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4 * 6)); -} - -TEST(Expr, LetStmtTest01) { - BufHandle a_buf("a", {1}, kFloat); - BufHandle b_buf("b", {1}, kFloat); - - ExprHandle load_a = a_buf.load(0); - VarHandle var = VarHandle("v", kFloat); - StmtPtr let_store = Let::make(var, load_a); - StmtPtr store_b = b_buf.store({0}, var); - BlockPtr block = Block::make({let_store, store_b}); - - SimpleIREvaluator eval(block, {a_buf, b_buf}); - - PaddedBuffer a_v(1); - PaddedBuffer b_v(1); - PaddedBuffer b_ref(1); - - a_v(0) = 23; - b_ref(0) = a_v(0); - eval(a_v, b_v); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(Expr, IntTest) { - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, FloatTest) { - VarHandle x("x", kFloat); - ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, ByteTest) { - VarHandle x("x", kByte); - ExprHandle body = ExprHandle((uint8_t)2) + - (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((uint8_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, CharTest) { - VarHandle x("x", kChar); - ExprHandle body = ExprHandle((int8_t)2) + - (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int8_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, ShortTest) { - VarHandle x("x", kShort); - ExprHandle body = ExprHandle((int16_t)2) + - (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int16_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, LongTest) { - VarHandle x("x", kLong); - ExprHandle body = ExprHandle((int64_t)2) + - (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int64_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, HalfTest) { - VarHandle x("x", kHalf); - ExprHandle body = ExprHandle((at::Half)2) + - (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((at::Half)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, DoubleTest) { - VarHandle x("x", kDouble); - ExprHandle body = ExprHandle((double)2) + - (x * ExprHandle((double)3) + ExprHandle((double)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((double)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, VectorAdd01) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a_buf("A", {kTotalSize}, kFloat); - BufHandle b_buf("B", {kTotalSize}, kFloat); - BufHandle c_buf("C", {kTotalSize}, kFloat); - - /* - Build the following: - for (const auto index : c10::irange(kVectorCount)) { - store(c_buf, ramp(index * 8, 1, 8), - load(a_buf, ramp(index * 8, 1, 8) + - load(b_buf, ramp(index * 8, 1, 8)))) - } - */ - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = - a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); - ExprHandle load_b = - b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); - ExprHandle value = load_a + load_b; - StmtPtr store_c = - c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value); - StmtPtr stmt = For::make(index, 0, kVectorCount, store_c); - - ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); - ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); - ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize)); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer c_ref(kTotalSize); - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i * i; - b_v(i) = i * i * 4; - c_ref(i) = a_v(i) + b_v(i); - } - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(Expr, CompareSelectEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 1); -} - -TEST(Expr, CompareSelectDtypes) { - // LHS and RHS expressions should have the same dtype, but this dtype could - // differ from the dtype of the return values (but dtypes of true and false - // return values should be the same). - // This test constructs a CompareSelect expression where the input dtype is - // different from the output dtype and verifies that it works correctly: - // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0.0f); - std::vector c_ref(N, 3.14f); - - VarHandle i("i", kInt); - // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f - // A and B are int, C is float. - auto select_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), - b.load(i), - FloatImm::make(3.14f), - FloatImm::make(2.78f), - CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(select_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - assertAllEqual(b_buffer, 1); - ExpectAllNear(c_buffer, c_ref, 1e-7); -} - -TEST(Expr, IntrinsicsDtypes) { - constexpr int N = 256; - BufHandle a("A", {N}, kDouble); - BufHandle b("B", {N}, kDouble); - std::vector a_buffer(N, -10.0); - std::vector b_buffer(N, 0.0); - std::vector b_ref(N, 10.0); - - VarHandle i("i", kInt); - auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i)))); - - SimpleIREvaluator ir_eval(abs_expr, {a, b}); - ir_eval(a_buffer, b_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - - assertAllEqual(a_buffer, -10.0); - ExpectAllNear(b_buffer, b_ref, 1e-7); -} - -TEST(Expr, Substitute01) { - VarPtr x = alloc("x", kFloat); - VarPtr y = alloc("y", kFloat); - ExprPtr e = - alloc(alloc(x, alloc(1.0f)), alloc(x, y)); - - VarPtr z = alloc("z", kFloat); - ExprPtr e2 = Substitute(e, {{x, alloc(z, alloc(5.0f))}}); - ExprPtr e2_ref = alloc( - alloc(alloc(z, alloc(5.0f)), alloc(1.0f)), - alloc(alloc(z, alloc(5.0f)), y)); - std::ostringstream oss; - oss << *e2; - std::string e2_str = oss.str(); - - oss.str(""); - oss << *e2_ref; - std::string e2_ref_str = oss.str(); - ASSERT_EQ(e2_str, e2_ref_str); -} - -TEST(Expr, Math01) { - ExprHandle v = sin(ExprHandle(1.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "sin(1.f)"); - - SimpleIRExprEval eval(v); - float v_ref = std::sin(1.0f); - float res = eval.value(); - ASSERT_NEAR(res, v_ref, 1e-6); -} - -TEST(Expr, UnaryMath01) { - struct TestConfig { - std::function func; - std::function ref_func; - }; - - std::vector test_configs = { - {[](const ExprHandle& v) { return sin(v); }, - [](float v) { return std::sin(v); }}, - {[](const ExprHandle& v) { return sin(v); }, - [](float v) { return std::sin(v); }}, - {[](const ExprHandle& v) { return tan(v); }, - [](float v) { return std::tan(v); }}, - {[](const ExprHandle& v) { return asin(v); }, - [](float v) { return std::asin(v); }}, - {[](const ExprHandle& v) { return acos(v); }, - [](float v) { return std::acos(v); }}, - {[](const ExprHandle& v) { return atan(v); }, - [](float v) { return std::atan(v); }}, - {[](const ExprHandle& v) { return sinh(v); }, - [](float v) { return std::sinh(v); }}, - {[](const ExprHandle& v) { return cosh(v); }, - [](float v) { return std::cosh(v); }}, - {[](const ExprHandle& v) { return tanh(v); }, - [](float v) { return std::tanh(v); }}, - {[](const ExprHandle& v) { return exp(v); }, - [](float v) { return std::exp(v); }}, - {[](const ExprHandle& v) { return tensorexpr::abs(v); }, - [](float v) { return std::fabs(v); }}, - {[](const ExprHandle& v) { return log(v); }, - [](float v) { return std::log(v); }}, - {[](const ExprHandle& v) { return log2(v); }, - [](float v) { return std::log2(v); }}, - {[](const ExprHandle& v) { return log10(v); }, - [](float v) { return std::log10(v); }}, - {[](const ExprHandle& v) { return erf(v); }, - [](float v) { return std::erf(v); }}, - {[](const ExprHandle& v) { return sqrt(v); }, - [](float v) { return std::sqrt(v); }}, - {[](const ExprHandle& v) { return rsqrt(v); }, - [](float v) { return 1.0f / std::sqrt(v); }}, - {[](const ExprHandle& v) { return ceil(v); }, - [](float v) { return std::ceil(v); }}, - {[](const ExprHandle& v) { return floor(v); }, - [](float v) { return std::floor(v); }}, - {[](const ExprHandle& v) { return round(v); }, - [](float v) { return std::round(v); }}, - {[](const ExprHandle& v) { return trunc(v); }, - [](float v) { return std::trunc(v); }}, - }; - - for (const TestConfig& test_config : test_configs) { - const float input_v = 0.8765f; - ExprHandle v = test_config.func(ExprHandle(input_v)); - float v_ref = test_config.ref_func(input_v); - SimpleIRExprEval eval(v); - ASSERT_NEAR(eval.value(), v_ref, 1e-6); - } - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - for (float input_v : {std::nan("1"), 0., .5}) { - ExprHandle v = FloatImm::make(input_v); - SimpleIRExprEval eval(Intrinsics::make(kIsNan, v)); - ASSERT_NEAR(eval.value(), std::isnan(input_v), 0); - } -} - -TEST(Expr, BinaryMath01) { - struct TestConfig { - std::function func; - std::function ref_func; - }; - - std::vector test_configs = { - {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); }, - [](float v1, float v2) { return std::pow(v1, v2); }}, - {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); }, - [](float v1, float v2) { return std::fmod(v1, v2); }}, - }; - - for (const TestConfig& test_config : test_configs) { - const float v1 = 0.8765f; - float v2 = 1.2345f; - ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2)); - float v_ref = test_config.ref_func(v1, v2); - SimpleIRExprEval eval(v_expr); - ASSERT_NEAR(eval.value(), v_ref, 1e-6); - } -} - -TEST(Expr, LogicalOps01) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.69f); - ExprHandle f1 = (a > b) && (c > d); - ExprHandle f2 = (a > b) && (c < d); - ExprHandle f3 = (a < b) && (c > d); - ExprHandle f4 = (a < b) && (c < d); - ExprHandle f5 = (a < b) || (c > d); - ExprHandle f6 = (a < b) || (c < d); - ExprHandle f7 = (a > b) || (c < d); - ExprHandle f8 = (a > b) || (c > d); - - SimpleIRExprEval eval1(f1); - SimpleIRExprEval eval2(f2); - SimpleIRExprEval eval3(f3); - SimpleIRExprEval eval4(f4); - SimpleIRExprEval eval5(f5); - SimpleIRExprEval eval6(f6); - SimpleIRExprEval eval7(f7); - SimpleIRExprEval eval8(f8); - ASSERT_EQ(eval1.value(), 1); - ASSERT_EQ(eval2.value(), 0); - ASSERT_EQ(eval3.value(), 0); - ASSERT_EQ(eval4.value(), 0); - ASSERT_EQ(eval5.value(), 1); - ASSERT_EQ(eval6.value(), 0); - ASSERT_EQ(eval7.value(), 1); - ASSERT_EQ(eval8.value(), 1); -} - -TEST(Expr, LogicalOps02) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.72f); - - ExprHandle f1 = (a > b) || (c > d); - ExprHandle f2 = (a > b) && (c <= d); - ExprHandle f3 = (a > b) && (c > d); - ExprHandle ff1 = f1 && f2; - ExprHandle ff2 = f2 || f3; - - SimpleIRExprEval eval1(ff1); - SimpleIRExprEval eval2(ff2); - ASSERT_EQ(eval1.value(), 1); - ASSERT_EQ(eval2.value(), 1); -} - -TEST(Expr, LogicalOps03) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.69f); - - // Bool types - ExprHandle bool_f1 = (a > b) && BoolImm::make(true); - ExprHandle bool_f2 = (c <= d) || BoolImm::make(true); - - // Int types - ExprHandle int_f1 = (a > b) && IntImm::make(1); - ExprHandle int_f2 = (c <= d) || IntImm::make(1); - - // Short types - ExprHandle short_f1 = (a > b) && ShortImm::make(1); - ExprHandle short_f2 = (c <= d) || ShortImm::make(1); - - // Long types - ExprHandle long_f1 = (a > b) && LongImm::make(1); - ExprHandle long_f2 = (c <= d) || LongImm::make(1); - - // Char types - ExprHandle char_f1 = (a > b) && CharImm::make(1); - ExprHandle char_f2 = (c <= d) || CharImm::make(1); - - // Byte types - ExprHandle byte_f1 = (a > b) && ByteImm::make(1); - ExprHandle byte_f2 = (c <= d) || ByteImm::make(1); - - SimpleIRExprEval eval1(bool_f1); - SimpleIRExprEval eval2(bool_f2); - SimpleIRExprEval eval3(int_f1); - SimpleIRExprEval eval4(int_f2); - SimpleIRExprEval eval5(short_f1); - SimpleIRExprEval eval6(short_f2); - SimpleIRExprEval eval7(long_f1); - SimpleIRExprEval eval8(long_f2); - SimpleIRExprEval eval9(char_f1); - SimpleIRExprEval eval10(char_f2); - SimpleIRExprEval eval11(byte_f1); - SimpleIRExprEval eval12(byte_f2); - - ASSERT_EQ(eval1.value(), true); - ASSERT_EQ(eval2.value(), true); - ASSERT_EQ(eval3.value(), 1); - ASSERT_EQ(eval4.value(), 1); - ASSERT_EQ(eval5.value(), 1); - ASSERT_EQ(eval6.value(), 1); - ASSERT_EQ(eval7.value(), 1); - ASSERT_EQ(eval8.value(), 1); - ASSERT_EQ(eval9.value(), 1); - ASSERT_EQ(eval10.value(), 1); - ASSERT_EQ(eval11.value(), 1); - ASSERT_EQ(eval12.value(), 1); -} - -TEST(Expr, BitwiseOps) { - ExprHandle a(59); - ExprHandle b(11); - ExprHandle c(101); - ExprHandle d(2); - ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; - - SimpleIRExprEval eval(f); - ASSERT_EQ(eval.value(), 11); -} - -TEST(Expr, DynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(Expr, OutOfBounds) { - ExprHandle N(10); - ExprHandle start(0); - ExprHandle stop(15); - VarHandle i("i", kInt); - - BufHandle X("X", {N}, kInt); - - auto body = Store::make(X, {i}, i); - auto stmt = For::make(i, start, stop, body); - - PaddedBuffer data(20); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); -} - -TEST(Expr, OutOfBounds2d) { - std::vector> size_options = {{10, 15}, {15, 10}}; - for (auto sizes : size_options) { - ExprHandle N(sizes.first); - ExprHandle M(sizes.second); - ExprHandle start(0); - ExprHandle stopInner(15); - ExprHandle stopOuter(15); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - BufHandle X("X", {N, M}, kInt); - - auto body = Store::make(X, {i, j}, i); - auto inner = For::make(j, start, stopInner, body); - auto stmt = For::make(i, start, stopOuter, inner); - - PaddedBuffer data(400); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); - } -} - -TEST(Expr, OutOfBounds2dFlattenedIndex) { - ExprHandle buf_size(149); - ExprHandle start(0); - ExprHandle stopInner(15); - ExprHandle stopOuter(10); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - BufHandle X("X", {buf_size}, kInt); - - auto idx = Add::make(Mul::make(i, stopInner), j); - auto body = Store::make(X, {idx}, i); - auto inner = For::make(j, start, stopInner, body); - auto stmt = For::make(i, start, stopOuter, inner); - - PaddedBuffer data(400); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); -} - -void testCond01() { - const int N = 16; - PaddedBuffer a_v(N); - BufHandle a_buf("a", {N}, kFloat); - VarHandle index = VarHandle("index", kInt); - StmtPtr assign_x2 = a_buf.store({index}, cast(index) * 2); - StmtPtr assign_x3 = a_buf.store({index}, cast(index) * 3); - ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); - StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3); - StmtPtr for_stmt = For::make(index, 0, N, assign); - SimpleIREvaluator(for_stmt, {a_buf})(a_v); - - PaddedBuffer a_ref(N); - for (const auto i : c10::irange(N)) { - if (i % 2 == 0) { - a_ref(i) = i * 2; - } else { - a_ref(i) = i * 3; - } - } - ExpectAllNear(a_v, a_ref, 1e-5); -} - -void testIfThenElse01() { - ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 1.0f); -} - -void testIfThenElse02() { - ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 2.0f); -} - -void testIfThenElse03() { - ExprHandle v = - ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 2.0f); -} - -void testStmtClone() { - const int N = 16; - - BufHandle a_buf("a", {N}, kInt); - VarHandle index = VarHandle("index", kInt); - StmtPtr body = a_buf.store({index}, 5); - StmtPtr loop = For::make(index, 0, N, body); - - StmtPtr cloned_loop = Stmt::clone(loop); - std::vector orig_loop_results(N); - std::vector cloned_loop_results(N); - SimpleIREvaluator(loop, {a_buf})(orig_loop_results); - SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results); - - assertAllEqual(orig_loop_results, 5); - assertAllEqual(cloned_loop_results, 5); - - // Let's add another assign to the body in the cloned loop and verify that the - // original statement hasn't changed while the cloned one has. - StmtPtr body_addition = a_buf.store({index}, 33); - BlockPtr cloned_body = static_to(static_to(cloned_loop)->body()); - cloned_body->append_stmt(body_addition); - - std::vector orig_loop_results_after_mutation(N); - std::vector cloned_loop_results_after_mutation(N); - SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation); - SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation); - - assertAllEqual(orig_loop_results_after_mutation, 5); - assertAllEqual(cloned_loop_results_after_mutation, 33); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp deleted file mode 100644 index 49f43d16b499d..0000000000000 --- a/test/cpp/tensorexpr/test_external_calls.cpp +++ /dev/null @@ -1,1061 +0,0 @@ -#include - -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -TEST(ExternalCall, Conv1d_float) { - BufHandle Input("Input", {1, 100, 115}, kFloat); - BufHandle Weight("Weight", {100, 1, 7}, kFloat); - BufHandle Bias("Bias", {100}, kFloat); - BufHandle ResultBuf("Result", {1, 100, 115}, kFloat); - int64_t stride = 1; - int64_t pad = 3; - int64_t dilation = 1; - int64_t groups = 100; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv1d", - {Input, Weight, Bias}, - {stride, pad, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 100, 115}, options) * 5.f; - at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f; - at::Tensor bias = at::ones({100}, options) * 11.f; - at::Tensor ref = - at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 100 * 115, 5.f); - std::vector weight_buf(100 * 1 * 7, 6.f); - std::vector bias_buf(100, 11.f); - std::vector result_buf(1 * 100 * 115, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv1d_int) { - // A similar test, but now using kInt tensors - BufHandle Input("Input", {1, 100, 115}, kInt); - BufHandle Weight("Weight", {100, 1, 7}, kInt); - BufHandle Bias("Bias", {100}, kInt); - BufHandle ResultBuf("Result", {1, 100, 115}, kInt); - int64_t stride = 1; - int64_t pad = 3; - int64_t dilation = 1; - int64_t groups = 100; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv1d", - {Input, Weight, Bias}, - {stride, pad, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kInt) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 100, 115}, options) * 5; - at::Tensor weight = at::ones({100, 1, 7}, options) * 6; - at::Tensor bias = at::ones({100}, options) * 11; - at::Tensor ref = - at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 100 * 115, 5); - std::vector weight_buf(100 * 1 * 7, 6); - std::vector bias_buf(100, 11); - std::vector result_buf(1 * 100 * 115, -1); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv1d_nobias_noargs) { - BufHandle Input("Input", {1, 1, 115}, kFloat); - BufHandle Weight("Weight", {10, 1, 7}, kFloat); - BufHandle ResultBuf("Result", {1, 10, 109}, kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 1, 115}, options) * 5.f; - at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f; - at::Tensor ref = at::conv1d(input, weight); - - at::Tensor nnc_result; - std::vector input_buf(1 * 1 * 115, 5.f); - std::vector weight_buf(10 * 1 * 7, 6.f); - std::vector result_buf(1 * 10 * 109, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); - - llvm_codegen.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); - - ir_eval.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_float) { - BufHandle Input("Input", {1, 3, 224, 224}, kFloat); - BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat); - BufHandle Bias("Bias", {16}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv2d", - {Input, Weight, Bias}, - {stride, stride, pad, pad, dilation, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f; - at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f; - at::Tensor bias = at::ones({16}, options) * 11.f; - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 3 * 224 * 224, 5.f); - std::vector weight_buf(16 * 3 * 3 * 3, 6.f); - std::vector bias_buf(16, 11.f); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_int) { - // A similar test, but now using kInt tensors - - BufHandle Input("Input", {1, 3, 224, 224}, kInt); - BufHandle Weight("Weight", {16, 3, 3, 3}, kInt); - BufHandle Bias("Bias", {16}, kInt); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv2d", - {Input, Weight, Bias}, - {stride, stride, pad, pad, dilation, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kInt) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5; - at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6; - at::Tensor bias = at::ones({16}, options) * 11; - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 3 * 224 * 224, 5); - std::vector weight_buf(16 * 3 * 3 * 3, 6); - std::vector bias_buf(16, 11); - std::vector result_buf(1 * 16 * 112 * 112, -1); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_nobias_noargs) { - BufHandle Input("Input", {1, 16, 112, 112}, kFloat); - BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f; - at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; - at::Tensor ref = at::conv2d(input, weight); - - at::Tensor nnc_result; - std::vector input_buf(1 * 16 * 112 * 112, 5.f); - std::vector weight_buf(16 * 16 * 1 * 1, 6.f); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); - - llvm_codegen.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); - - ir_eval.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Addmm_float) { - BufHandle Input("Input", {100, 300}, kFloat); - BufHandle Mat1("Mat1", {100, 200}, kFloat); - BufHandle Mat2("Mat2", {200, 300}, kFloat); - BufHandle ResultBuf("Result", {100, 300}, kFloat); - int64_t beta = 2; - int64_t alpha = 2; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({100, 300}, options) * 5.f; - at::Tensor mat1 = at::ones({100, 200}, options) * 6.f; - at::Tensor mat2 = at::ones({200, 300}, options) * 11.f; - at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha); - - at::Tensor nnc_result; - std::vector input_buf(100 * 300, 5.f); - std::vector mat1_buf(100 * 200, 6.f); - std::vector mat2_buf(200 * 300, 11.f); - std::vector result_buf(100 * 300, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result}); - - llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result}); - - ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Embedding) { - BufHandle Weight("Weight", {256, 100}, kFloat); - BufHandle Indices("Indices", {1, 115}, kLong); - BufHandle ResultBuf("Result", {1, 115, 100}, kFloat); - int64_t padding_idx = -1; - bool scale_grad_by_freq = false; - bool sparse = false; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_embedding", - {Weight, Indices}, - {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - - at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f; - at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6; - at::Tensor ref = - at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); - - at::Tensor nnc_result; - std::vector weight_buf(256 * 100, 5.f); - std::vector indices_buf(1 * 115, 6); - std::vector result_buf(1 * 115 * 100, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result}); - - llvm_codegen.call({weight_buf, indices_buf, result_buf}); - nnc_result = at::from_blob( - result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result}); - - ir_eval.call({weight_buf, indices_buf, result_buf}); - nnc_result = at::from_blob( - result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, MaxReduction) { - BufHandle Input("Input", {1, 115, 152}, kFloat); - BufHandle ResultBuf("Result", {1, 152}, kFloat); - int64_t dim = 1; - bool keep_dim = false; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - - at::Tensor input = at::ones({1, 115, 152}, options) * 5.f; - at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim)); - - at::Tensor nnc_result; - std::vector input_buf(1 * 115 * 152, 5.f); - std::vector result_buf(1 * 152, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result}); - - llvm_codegen.call({input_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result}); - - ir_eval.call({input_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -#ifdef USE_XNNPACK - -TEST(ExternalCall, Prepacked_Linear_float) { - using namespace at::native::xnnpack; - - BufHandle Input("Input", {100, 200}, kFloat); - BufHandle ResultBuf("Result", {100, 300}, kFloat); - - // Calculate reference result using at::linear. - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = - at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200}); - at::Tensor weight = - at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200}); - at::Tensor bias = at::linspace(-10.0, 10.0, 300, options); - at::Tensor ref = at::linear(input, weight, bias); - - // Create prepacked xnnpack context object. - auto linear_clamp_prepack_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("prepacked::linear_clamp_prepack", "") - .typed( - at::Tensor, - std::optional, - const std::optional&, - const std::optional&)>(); - auto prepacked = linear_clamp_prepack_op.call( - weight, bias, std::optional(), std::optional()); - - BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_prepacked_linear_clamp_run", - {Input, DummyPrepacked}, - {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - at::Tensor nnc_result; - std::vector input_buf( - input.data_ptr(), input.data_ptr() + 100 * 200); - std::vector result_buf(100 * 300, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); - - llvm_codegen.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); - - ir_eval.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Prepacked_Conv2d_float) { - using namespace at::native::xnnpack; - - BufHandle Input("Input", {1, 3, 224, 224}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - // Calculate reference result using at::conv2d. - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options) - .resize_({1, 3, 224, 224}); - at::Tensor weight = - at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3}); - at::Tensor bias = at::linspace(-10.0, 10.0, 16, options); - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - // Create prepacked xnnpack context object. - auto conv2d_clamp_prepack_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "") - .typed( - at::Tensor, - std::optional, - std::vector, - std::vector, - std::vector, - int64_t, - const std::optional&, - const std::optional&)>(); - auto prepacked = conv2d_clamp_prepack_op.call( - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups, - std::optional(), - std::optional()); - - BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_prepacked_conv2d_clamp_run", - {Input, DummyPrepacked}, - {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - at::Tensor nnc_result; - std::vector input_buf( - input.data_ptr(), input.data_ptr() + 1 * 3 * 224 * 224); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); - - llvm_codegen.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); - - ir_eval.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); -} - -#endif // USE_XNNPACK - -TEST(ExternalCall, BinaryFloat) { - using TensorFunc = std::function; - using Test = std::tuple< - std::vector, - std::vector, - std::vector, - TensorFunc, - std::string>; - std::vector tests = {}; - tests.push_back( - Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"}); - tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"}); - tests.push_back(Test{ - {100, 200}, - {200, 300}, - {100, 300}, - [&](const at::Tensor& a, const at::Tensor& b) { return at::mm(a, b); }, - "nnc_aten_mm"}); - for (auto curTest : tests) { - auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest; - auto toExprHandleVec = [](std::vector v) { - auto intV = std::vector(v.begin(), v.end()); - return std::vector(intV.begin(), intV.end()); - }; - BufHandle A("A", toExprHandleVec(aShape), kFloat); - BufHandle B("B", toExprHandleVec(bShape), kFloat); - BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, externCallName, {A, B}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; - at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f; - at::Tensor ref = torchFunc(a, b); - - auto prod = [](std::vector v) { - // NOLINTNEXTLINE(modernize-use-transparent-functors) - return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); - }; - - at::Tensor nnc_result; - std::vector a_buf(prod(aShape), 5.f); - std::vector b_buf(prod(bShape), 6.f); - std::vector result_buf(prod(resShape), -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result}); - - llvm_codegen.call({a_buf, b_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result}); - ir_eval.call({a_buf, b_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); - } -} - -TEST(ExternalCall, UnaryFloat) { - using TensorFunc = std::function; - auto toExprHandleVec = [](std::vector v) { - auto intV = std::vector(v.begin(), v.end()); - return std::vector(intV.begin(), intV.end()); - }; - using Test = std::tuple< - std::vector, - std::vector, - TensorFunc, - std::string, - std::vector>; - std::vector tests = {}; - tests.push_back(Test{ - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - {1, 64, 8, 9}, - {1, 64, 5, 7}, - [](at::Tensor x) { return at::adaptive_avg_pool2d(x, {5, 7}); }, - "nnc_aten_adaptive_avg_pool2d", - toExprHandleVec({5, 7})}); - tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - {100, 200}, - {100}, - [](at::Tensor x) { return at::mean(x, {1}); }, - "nnc_aten_mean", - toExprHandleVec({1, /*keepdim=*/0})}); - for (auto curTest : tests) { - auto [aShape, resShape, torchFunc, externCallName, externCallArgs] = - curTest; - BufHandle A("A", toExprHandleVec(aShape), kFloat); - BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs)); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; - at::Tensor ref = torchFunc(a); - - auto prod = [](std::vector v) { - // NOLINTNEXTLINE(modernize-use-transparent-functors) - return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); - }; - - at::Tensor nnc_result; - std::vector a_buf(prod(aShape), 5.f); - std::vector result_buf(prod(resShape), -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result}); - - llvm_codegen.call({a_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result}); - ir_eval.call({a_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); - } -} - -TEST(ExternalCall, ComputeInterop) { - // This test verifies that Tensors using external calls can be used by and can - // use Tensors built with Compute API. - - BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat); - BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat); - - Tensor Input = Compute( - "Input", - {1, 16, 32, 32}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { return FloatImm::make(5.0f); }); - Tensor Weight = Compute( - "Weight", - {16, 16, 1, 1}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { return FloatImm::make(6.0f); }); - - Tensor ConvResult = Tensor( - ConvResultBuf.node(), - ExternalCall::make( - ConvResultBuf, - "nnc_aten_conv2d", - {BufHandle(Input.buf()), BufHandle(Weight.buf())}, - {})); - Tensor MatmulResult = Tensor( - MatmulResultBuf.node(), - ExternalCall::make( - MatmulResultBuf, - "nnc_aten_matmul", - {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())}, - {})); - Tensor Result = Compute( - "Result", - {1, 16, 32, 32}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { - return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w); - }); - - LoopNest l({Input, Weight, ConvResult, MatmulResult, Result}); - - // Inlining should not inline anything here since all Bufs are either defined - // or used in ExternalCalls - we run it just for testing - l.inlineIntermediateBufs(true); - - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f; - at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; - at::Tensor t = at::conv2d(input, weight); - at::Tensor t2 = at::matmul(t, t); - at::Tensor ref = t + t2; - - at::Tensor nnc_result; - std::vector input_buf(1 * 16 * 32 * 32, 5.f); - std::vector weight_buf(16 * 16 * 1 * 1, 6.f); - std::vector conv_result_buf(1 * 16 * 32 * 32, -1.f); - std::vector matmul_result_buf(1 * 16 * 32 * 32, -1.f); - std::vector result_buf(1 * 16 * 32 * 32, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen( - l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); - - llvm_codegen.call( - {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval( - l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); - - ir_eval.call( - {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Inlining) { - // This test verifies that Tensors using external calls can be used by and - // can use Tensors built with Compute API. - - BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat); - - Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { - return FloatImm::make(5.0f); - }); - Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { - return FloatImm::make(4.0f); - }); - Tensor MatmulResult = Tensor( - MatmulResultBuf.node(), - ExternalCall::make( - MatmulResultBuf, - "nnc_aten_matmul", - {BufHandle(A.buf()), BufHandle(B.buf())}, - {})); - Tensor Result = - Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { - return MatmulResult.load(i, j) + FloatImm::make(3.0f); - }); - - StmtPtr root_stmt = alloc(std::vector( - {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()})); - LoopNest l(root_stmt, {Result.buf()}); - - // Inlining should not inline anything here since all Bufs are either - // defined or used in ExternalCalls - l.inlineIntermediateBufs(false); - - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones({8, 8}, options) * 5.f; - at::Tensor b = at::ones({8, 8}, options) * 4.f; - at::Tensor t = at::matmul(a, b); - at::Tensor ref = t + 3.f; - - at::Tensor nnc_result; - std::vector result_buf(8 * 8); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Result}); - - llvm_codegen.call({result_buf}); - nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Result}); - - ir_eval.call({result_buf}); - nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, JitCustomFusionOp) { - const char* custom_op_schema_literal = - "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor"; - const char* external_func_name = "nnc_add_mul"; - - auto add_mul_lowering_func = - [external_func_name]( - const std::vector& inputs, - const std::vector& output_shape, - const std::vector& output_strides, - const std::optional& output_type, - at::Device device) { - auto output_dtype = Dtype(*output_type); - torch::jit::tensorexpr::BufHandle result_buf( - "nnc_add_mul_res_buf", output_shape, output_dtype); - const torch::jit::tensorexpr::BufHandle& a = - std::get(inputs[0]); - const torch::jit::tensorexpr::BufHandle& b = - std::get(inputs[1]); - const torch::jit::tensorexpr::BufHandle& c = - std::get(inputs[1]); - torch::jit::tensorexpr::StmtPtr s = - torch::jit::tensorexpr::ExternalCall::make( - result_buf, external_func_name, {a, b, c}, {}); - return Tensor(result_buf.node(), s); - }; - - auto add_mul_external_func = [](int64_t bufs_num, - void** buf_data, - int64_t* buf_ranks, - int64_t* buf_dims, - int64_t* buf_strides, - int8_t* buf_dtypes, - int64_t args_num, - int64_t* extra_args) {}; - - torch::jit::RegisterOperators reg({Operator( - custom_op_schema_literal, - [](const Node* node) -> Operation { - return [](Stack& _stack) { - auto a = std::move(peek(_stack, 0, 3)).toTensor(); - auto b = std::move(peek(_stack, 1, 3)).toTensor(); - auto c = std::move(peek(_stack, 2, 3)).toTensor(); - drop(_stack, 3); - auto result = (a + b) * c; - pack(_stack, std::move(result)); - return 0; - }; - }, - c10::AliasAnalysisKind::FROM_SCHEMA)}); - - auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet(); - custom_operator_set.insert({custom_op_schema_literal}); - - auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry(); - te_lowering_registry.insert( - parseSchema(custom_op_schema_literal), add_mul_lowering_func); - - auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry(); - te_nnc_func_registry[external_func_name] = add_mul_external_func; - - std::string graph_string = R"IR( - graph(%a : Float(10, 20, strides=[20, 1], device=cpu), - %b : Float(10, 20, strides=[20, 1], device=cpu), - %c : Float(10, 20, strides=[20, 1], device=cpu)): - %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c) - return (%res))IR"; - - auto graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::string shape_compute_python_string = R"PY( - def computOutput(a: List[int], b: List[int], c: List[int]): - expandedSizes: List[int] = [] - dimsA = len(a) - dimsB = len(b) - dimsC = len(c) - ndim = max(dimsA, dimsB, dimsC) - for i in range(ndim): - offset = ndim - 1 - i - dimA = dimsA - 1 - offset - dimB = dimsB - 1 - offset - dimC = dimsC - 1 - offset - sizeA = a[dimA] if (dimA >= 0) else 1 - sizeB = b[dimB] if (dimB >= 0) else 1 - sizeC = a[dimC] if (dimC >= 0) else 1 - - if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1: - # TODO: only assertion error is bound in C++ compilation right now - raise AssertionError( - "The size of tensor a {} must match the size of tensor b (" - "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i) - ) - - expandedSizes.append(max(sizeA, sizeB, sizeC)) - - return expandedSizes - )PY"; - auto cu_ptr = torch::jit::compile(shape_compute_python_string); - torch::jit::GraphFunction* gf = - (torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput"); - ASSERT_TRUE(gf); - -#ifdef TORCH_ENABLE_LLVM - auto static_graph_case = graph->copy(); - FuseTensorExprs(static_graph_case, 1); - torch::jit::testing::FileCheck() - .check("prim::TensorExprGroup_") - ->check("nnc_custom::add_mul") - ->run(*static_graph_case); - - auto dynamic_graph_case = graph->copy(); - auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal); - ASSERT_TRUE(custom_op); - torch::jit::RegisterShapeComputeGraphForSchema( - custom_op->schema(), gf->graph()); - FuseTensorExprs(dynamic_graph_case, 1, false, true); - torch::jit::testing::FileCheck() - .check("prim::TensorExprGroup_") - ->check("nnc_custom::add_mul") - ->run(*dynamic_graph_case); -#else - torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph); -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp deleted file mode 100644 index aed73d09d14d5..0000000000000 --- a/test/cpp/tensorexpr/test_graph_opt.cpp +++ /dev/null @@ -1,319 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -class GraphOpt : public ::testing::Test { - public: - void SetUp() override { - old_cat_wo_conditionals_ = getCatWoConditionals(); - getCatWoConditionals() = true; - } - - void TearDown() override { - getCatWoConditionals() = old_cat_wo_conditionals_; - } - - private: - bool old_cat_wo_conditionals_; -}; - -TEST_F(GraphOpt, OptimizeCat) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::log` op must be moved to the inputs of `aten::cat`. - testing::FileCheck() - .check("aten::log") - ->check("aten::log") - ->check("aten::log") - ->check("aten::cat") - ->check_not("aten::log") - ->run(*kernel.graph()); - - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::log(at::cat({x, y, z}, 0)); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCat2) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) - %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::log` and `aten::tanh` ops must be moved to the inputs of - // `aten::cat`. - testing::FileCheck() - .check("aten::log") - ->check("aten::log") - ->check("aten::log") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check_not("aten::log") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::tanh(at::log(at::cat({x, y, z}, 0))); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCat3) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%a : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) - %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::tanh` op must be moved to the inputs of `aten::cat`. - // But the `aten::mul` op must not be moved since it is not a single-tensor - // op (it has 2 tensor inputs). - testing::FileCheck() - .check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check("aten::mul") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto a = at::rand({60}, at::kFloat); - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::tanh(at::cat({x, y, z}, 0)) * a; - - std::vector inputs = {a, x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Int(10, strides=[1], device=cpu), - %y : Int(20, strides=[1], device=cpu), - %z : Int(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::tanh` op must be moved to the inputs of `aten::cat`. - // The scalar type of the inputs to `cat` should now be `Float` since they - // are the result of `tanh` which does the type promotion. - testing::FileCheck() - .check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto x = at::randint(std::numeric_limits::max(), {10}, at::kInt); - auto y = at::randint(std::numeric_limits::max(), {20}, at::kInt); - auto z = at::randint(std::numeric_limits::max(), {30}, at::kInt); - auto ref = at::tanh(at::cat({x, y, z}, 0)); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Double(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation should have happened because the `aten::cat` op performs - // type promotion. This case is currently not handled. - testing::FileCheck() - .check("aten::cat") - ->check("aten::log") - ->check_not("aten::cat") - ->check_not("aten::log") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation is expected since the consumers of cat are not - // single-tensor element-wise ops. - testing::FileCheck() - .check("aten::cat") - ->check("aten::mul") - ->check_not("aten::cat") - ->check_not("aten::mul") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(60, strides=[1], device=cpu), - %1 : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %one : int = prim::Constant[value=1]() - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) - %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation is expected since the consumers of cat are not - // single-tensor element-wise ops. - testing::FileCheck() - .check("aten::cat") - ->check("aten::mul") - ->check("aten::add") - ->check_not("aten::cat") - ->check_not("aten::mul") - ->check_not("aten::add") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, AOTGraphPrepPasses) { - const auto graph_string = R"IR( - graph(%x, %y, %z, %i : int): - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - return (%xyz_list, %i))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - removeGraphOutput(g, 1); - replaceListOutputWithTuple(g); - LowerAllTuples(g); - - testing::FileCheck().check("return (%x, %y, %z)")->run(*g); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp deleted file mode 100644 index 4d2f8c6e906ee..0000000000000 --- a/test/cpp/tensorexpr/test_ir_printer.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include - -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -#include -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(IRPrinter, BasicValueTest) { - ExprHandle a = IntImm::make(2), b = IntImm::make(3); - ExprHandle c = Add::make(a, b); - - std::stringstream ss; - ss << c; - ASSERT_EQ(ss.str(), "2 + 3"); -} - -TEST(IRPrinter, BasicValueTest02) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - - std::stringstream ss; - ss << f; - ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)"); -} - -TEST(IRPrinter, BasicValueTest03) { - ExprHandle a(3.402823466385289e+38f); - ExprHandle b(-3.402823466385289e+38f); - std::stringstream ss; - ss << a << ", " << b; - ASSERT_EQ(ss.str(), "3.402823466385289e+38f, -3.402823466385289e+38f"); -} - -TEST(IRPrinter, CastTest) { - VarHandle x("x", kHalf); - VarHandle y("y", kFloat); - ExprHandle body = ExprHandle(2.f) + - (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y); - - std::stringstream ss; - ss << body; - ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)"); -} - -TEST(IRPrinter, FunctionName) { - int M = 4; - int N = 20; - - Tensor producer = Compute( - "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return m * n; - }); - - Tensor chunk_0 = Compute( - "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer.load(m, n); - }); - - Tensor chunk_1 = Compute( - "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer.load(m, n + ExprHandle(N / 2)); - }); - - Tensor consumer = Compute( - "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) { - return i * chunk_1.load(i, j); - }); - - LoopNest l({chunk_0, chunk_1, consumer}); - auto body = LoopNest::sanitizeNames(l.root_stmt()); - - std::stringstream ss; - ss << *body; - - const std::string& verification_pattern = - R"IR( - # CHECK: for (int i_2 - # CHECK: for (int j_2 - # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, ss.str()); -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_verifier.cpp b/test/cpp/tensorexpr/test_ir_verifier.cpp deleted file mode 100644 index 886213ea9c760..0000000000000 --- a/test/cpp/tensorexpr/test_ir_verifier.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include - -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -#include -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(IRVerifier, BitwiseOps) { - VarPtr X = alloc("x", kInt); - VarPtr Y = alloc("y", kFloat); - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, CompareSelect) { - ExprPtr X = alloc(1); - ExprPtr Y = alloc(3.14f); - { - auto a = alloc(X, X, X, Y, kEQ); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y, X, X, kEQ); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Ramp) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kFloat); - { - auto a = alloc(I, J, 4); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Load) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - BufPtr B = alloc( - "b", - std::vector({alloc(10), alloc(20)}), - kFloat); - { - // Indices with different int dtypes (kInt, kLong) are ok - auto a = alloc(B, std::vector({I, J})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_NO_THROW(verify(a)); - } - { - // Float index - auto a = alloc(B, std::vector({K, K})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Multilanes are only allowed in flattened indices - auto multilane_index = alloc(I, alloc(1), 4); - auto a = alloc(B, std::vector({I, multilane_index})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, IfThenElse) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - { - // Condition must be integral - auto a = alloc(K, I, I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Dtypes of true and false exprs must match - auto a = alloc(I, I, J); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Can't have multiple lanes in condition expr - auto a = alloc(alloc(I, 4), I, I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, For) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kInt); - StmtPtr body = alloc(std::vector({})); - { - // Can't have nullptr as a Var - auto a = alloc(nullptr, I, J, body); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Block) { - VarPtr I = alloc("i", kInt); - BufPtr B = alloc("B", std::vector({alloc(10)}), kInt); - { - StmtPtr store = alloc(B, std::vector({I}), I); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - StmtPtr block1 = alloc(std::vector({store})); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - StmtPtr block2 = alloc(std::vector({store})); - // Stmt can't have multiple parents, thus inserting it into several blocks - // is illegal - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(block2)); - } -} - -TEST(IRVerifier, Store) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - BufPtr B = alloc( - "b", - std::vector({alloc(10), alloc(20)}), - kFloat); - { - // Indices with different int dtypes (kInt, kLong) are ok - auto a = alloc(B, std::vector({I, J}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_NO_THROW(verify(a)); - } - { - // Float index - auto a = alloc(B, std::vector({K, K}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Multilanes are only allowed in flattened indices - auto multilane_index = alloc(I, alloc(1), 4); - auto a = alloc(B, std::vector({I, multilane_index}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Value and buf dtypes mismatch - auto a = alloc(B, std::vector({I}), I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp deleted file mode 100644 index dc67928b111a0..0000000000000 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ /dev/null @@ -1,2133 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -class Kernel : public ::testing::Test { - public: - void SetUp() override { - getTEMustUseLLVMOnCPU() = false; - } -}; - -TEST_F(Kernel, ParallelExternalCallBuf) { - const auto graph_string = R"IR( - graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu), - %1 : Float(1000, 5000, strides=[5000, 1], device=cpu), - %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)): - %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1) - %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) - return (%4))IR"; - auto graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); -#ifdef TORCH_ENABLE_LLVM - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR"; - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -#endif -} - -TEST_F(Kernel, InliningIntermediates) { - // here, each mul has only one use, so it should be completely inlined - { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %one : int = prim::Constant[value=1]() - %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) - return (%5))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); - } - { - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=${device}), - %1 : Float(5, 3, strides=[3, 1], device=${device})): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %one : int = prim::Constant[value=1]() - %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one) - %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one) - %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0) - return (%4, %5))IR"; - for (bool use_cuda : {false, true}) { - if (!torch::cuda::is_available() && use_cuda) { - continue; - } - - at::jit::TemplateEnv env; - env.s("device", use_cuda ? "cuda:0" : "cpu"); - const auto graph_string = format(graph_template, env); - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - // aten_mul only has one use, inlined completely - torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); - - // aten_sub should be removed by the CUDA backend by metavar rewriting - // and by the CPU backend by horizontal fusion. - torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str()); - } - } -} - -TEST_F(Kernel, PreAllocIntermediateBufs) { - const auto graph_string = R"IR( -graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu), - %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)): - %2 : int = prim::Constant[value=1]() - %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12 - %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15 - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::matmul(a, b) + a; - TensorExprKernel k(graph, {}, {}, true); - - std::vector inputs = {a, b}; - auto stmt = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *stmt; - - // Check whether the intermediate buffer has been added to constants - auto constants = k.getConstantDescriptors(); - ASSERT_EQ(constants.size(), 1); - - // Check the IR we produced - torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str()); - torch::jit::testing::FileCheck().check_not("Free")->run(oss.str()); - - // Check correctness - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, _1) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, _2) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, _3) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[12, 2], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2), Slice(None, None, 2)}); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, Huge) { - const auto graph_string = R"IR( - graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): - %1 : int = prim::Constant[value=0]() - %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) - %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - std::ostringstream oss; - oss << *k.getCodeGenStmt(); - // The 4000000000 iterations loop will be split into 500000000 x 8 and the - // outer loop will be parallel. If LLVM is not present, it will not be split, - // and to cover both of these cases we're looking for 00000000ll; in the - // output. - const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST_F(Kernel, ParallelStrided) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), - %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)): - %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat)) - .index( - {Slice(None, None, 2), - Slice(None, None, 2), - Slice(None, None, 2)}); - auto ref = a * (a * b); - auto o = at::zeros_like(ref); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, DISABLED_Shape_Inference) { - // disabled: doesn't do stride propagation, and isn't being used currently - - // Test TensorExpr shape inference capabilities: it should only require shapes - // for the inputs - { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[12, 2], device=cpu)): - %2 : Tensor = aten::mul(%0, %1) - %3 : Tensor = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2), Slice(None, None, 2)}); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - const auto graph_string = R"IR( - graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), - %1 : Float(8, 8, strides=[8, 1], device=cpu)): - %2 : Tensor = aten::mul(%0, %1) - %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2) - %r : Tensor = aten::mul(%3, %4) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat)); - auto t = torch::chunk(a * b, 2, 1); - auto ref = t[0] * t[1]; - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - TORCH_CHECK_EQ(o.sizes()[0], 8); - TORCH_CHECK_EQ(o.sizes()[1], 4); - for (size_t i = 0; i < 8 * 4; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that shape inference handles aten::unsqueeze - - const auto graph_string = R"IR( - graph(%a : Float(4, 2, strides=[2, 1], device=cpu), - %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu), - %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)): - %one : int = prim::Constant[value=1]() - %minus_one : int = prim::Constant[value=-1]() - %three : int = prim::Constant[value=3]() - %minus_four : int = prim::Constant[value=-4]() - %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2] - %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1] - %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1] - %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2] - %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1] - %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2] - return (%abc))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) * - at::unsqueeze(c, -4); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_mul)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that shape inference handles aten::cat - - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({a, b, c}, 1); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that we throw an error when input list for aten::cat is empty - - const auto graph_string = R"IR( - graph(): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct() - %r : Tensor = aten::cat(%inputs, %dim) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - auto compile = [&]() { - TensorExprKernel k(graph); - k.getCodeGenStmt(); - }; - ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat"); - } - { - // Test that we throw an error when 'dim' passed to aten::cat is invalid - - const auto ir_dim_99 = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=99]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) - return (%r))IR"; - const auto ir_dim_minus_6 = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=-6]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) - return (%r))IR"; - - auto compile = [](const std::string& graph_string) { - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - k.getCodeGenStmt(); - }; - ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index"); - ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index"); - } -} - -TEST_F(Kernel, CatInputTypesPromotion) { - { - // Test that we properly promote input types for aten::cat - - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble)); - auto ref = at::cat({a, b, c}, 1); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); - } - } -} - -TEST_F(Kernel, ToDType) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): - %1 : NoneType = prim::Constant() - %2 : bool = prim::Constant[value=0]() - %3 : int = prim::Constant[value=6]() - %4 : int = prim::Constant[value=15]() - %5 : int = prim::Constant[value=5]() - %6 : bool = prim::Constant[value=1]() - %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1) - %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4) - %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6) - %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1) - %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1) - %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1) - return (%k.3))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_to -# CHECK-NEXT: } -# CHECK-NEXT: })IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16)); - auto ref = - at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat)); - - std::vector inputs = {a}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); -#endif -} - -TEST_F(Kernel, CatAndInlineWithAConstantDim) { - const auto graph_string = R"IR( - graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu), - %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)): - %2 : bool = prim::Constant[value=0]() - %3 : int = prim::Constant[value=1]() - %4 : Tensor[] = prim::ListConstruct(%0, %1) - %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3) - %6 : Tensor[] = prim::ListConstruct(%5) - %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3) - %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2) - return (%8, %7))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - - auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::_cast_Float(at::cat({a, b}, 1), 0); - - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, CatWithEmptyInputs) { - bool curr_cat_wo_conditionals = getCatWoConditionals(); - for (auto cat_wo_conditionals : {true, false}) { - getCatWoConditionals() = cat_wo_conditionals; - const auto graph_string = R"IR( - graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu), - %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)): - %3 : int = prim::Constant[value=0]() - %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0) - %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1) - %10 : Tensor[] = prim::ListConstruct(%6, %7) - %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3) - return (%11))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - - auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0); - - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); - } - getCatWoConditionals() = curr_cat_wo_conditionals; -} - -TEST_F(Kernel, CatWoConditionals) { - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) - return (%r))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK: for -# CHECK: for -# CHECK: aten_cat -# CHECK: for -# CHECK: for -# CHECK: aten_cat -# CHECK: for -# CHECK: for -# CHECK: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({a, b, c}, 1); - - std::vector inputs = {a, b, c}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - getCatWoConditionals() = old_cat_wo_conditionals; -} - -TEST_F(Kernel, OptimizeConditionals) { - bool old_cat_wo_conditionals = getCatWoConditionals(); - bool old_opt_conditionals = getOptConditionals(); - getCatWoConditionals() = false; - getOptConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(5, 3, strides=[3, 1], device=cpu), - %b : Float(5, 7, strides=[7, 1], device=cpu), - %c : Float(5, 9, strides=[9, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim) - %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r) - return (%t))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_relu -# CHECK: for -# CHECK-NEXT: aten_relu -# CHECK: for -# CHECK-NEXT: aten_relu -# CHECK-NOT: Allocate -# CHECK-NOT: Free)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat)); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::relu(at::cat({a, b, c}, 1)); - - std::vector inputs = {a, b, c}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - getOptConditionals() = old_opt_conditionals; - getCatWoConditionals() = old_cat_wo_conditionals; -} - -namespace { - -std::string dtypeConstant(ScalarType scalar_type) { - if (scalar_type == ScalarType::Undefined) { - return "None = prim::Constant()"; - } else { - at::jit::TemplateEnv env_dtype; - env_dtype.d("scalar_type", static_cast(scalar_type)); - return format("int = prim::Constant[value=${scalar_type}]()", env_dtype); - } -} - -at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { - int64_t numel = std::accumulate( - sizes.begin(), - sizes.end(), - 1, - // NOLINTNEXTLINE(modernize-use-transparent-functors) - std::multiplies()); - std::vector values(numel); - std::iota(values.begin(), values.end(), 0); - auto a = at::tensor(values, options); - return a.reshape(sizes); -} - -} // namespace - -TEST_F(Kernel, SumAllAxes) { - // Test lowering of sum on all axes. - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): - %1 : ${dtype} - %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1) - return (%2))IR"; - auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { - at::jit::TemplateEnv env; - env.s("dtype", dtypeConstant(scalar_type)); - if (scalar_type == ScalarType::Undefined) { - env.s("out_dtype", "Float"); - } else { - env.s("out_dtype", "Double"); - } - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto o = at::empty({}, TensorOptions(kCPU)); - std::optional dtype; - if (scalar_type != ScalarType::Undefined) { - dtype = static_cast(scalar_type); - } - auto ref = a.sum(/*dtype=*/dtype); - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); - } -} - -std::string li_to_str(at::ArrayRef li) { - std::stringstream out; - bool first = true; - for (auto elem : li) { - if (!first) { - out << ", "; - } - out << elem; - first = false; - } - return out.str(); -} - -TEST_F(Kernel, SumOneAxis) { - // Test lowering of sum on one axis. - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): - %1 : int[] = prim::Constant[value=[${dim}]]() - %2 : bool = prim::Constant[value=${keepdim}]() - %3 : ${dtype} - %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3) - return (%4))IR"; - auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - for (int dim = -a.dim(); dim < a.dim(); ++dim) { - for (bool keepdim : {false, true}) { - for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { - at::jit::TemplateEnv env; - env.d("dim", dim); - env.d("keepdim", keepdim); - env.s("dtype", dtypeConstant(scalar_type)); - std::optional dtype; - if (scalar_type != ScalarType::Undefined) { - dtype = static_cast(scalar_type); - } - auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype); - if (scalar_type == ScalarType::Undefined) { - env.s("out_dtype", "Float"); - } else { - env.s("out_dtype", "Double"); - } - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - const auto graph_string = format(graph_template, env); - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto o = at::empty({}, TensorOptions(kCPU)); - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t -# CHECK-NEXT: sum -# CHECK-NEXT: for (int64_t -# CHECK-NEXT: sum)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); - } - } - } -} - -TEST_F(Kernel, SumMultipleAxes) { - // Test lowering of sum on multiple axes. - const auto graph_template = R"IR( - graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)): - %1 : int = prim::Constant[value=${dim1}]() - %2 : int = prim::Constant[value=${dim2}]() - %3 : int[] = prim::ListConstruct(%1, %2) - %4 : bool = prim::Constant[value=${keepdim}]() - %5 : ${dtype} - %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5) - return (%6))IR"; - auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - // Only iterate over positive values of axes to keep the running time - // reasonable, since the number of pairs is quadratic. - for (const auto dim1 : c10::irange(a.dim())) { - for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) { - for (bool keepdim : {false, true}) { - at::jit::TemplateEnv env; - env.d("dim1", dim1); - env.d("dim2", dim2); - env.d("keepdim", keepdim); - env.s("dtype", dtypeConstant(ScalarType::Undefined)); - auto o = at::empty({}, TensorOptions(kCPU)); - auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); - - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t -# CHECK: for (int64_t -# CHECK: for (int64_t -# CHECK: for (int64_t -# CHECK: sum)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); - } - } - } -} - -// This test and the following ones testing Softmax only tests with dim set -// to one of the valid input dimensions. It does not test with dim=None -// because that is supposed to be deprecated. -TEST_F(Kernel, Softmax2D) { - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %dt_float : int = prim::Constant[value=7]() - %dt_none : NoneType = prim::Constant() - %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt}) - return (%4))IR"; - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 5 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 - # CHECK-NEXT: aten_softmax)IR"; - - for (bool empty_dtype : {false, true}) { - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - auto other_dim = (softmax_dim + 1) % a.dim(); - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? "log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - env.s("dt", empty_dtype ? "dt_none" : "dt_float"); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("other_dim", other_dim); - ver_env.d("other_dim_size", a.sizes()[other_dim]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = - format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, - // oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } - } -} - -TEST_F(Kernel, Softmax3D) { - const auto graph_template = R"IR( - graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %2 : int = prim::Constant[value=7]() - %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) - return (%3))IR"; - - auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} - # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 3 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 - # CHECK-NEXT: aten_softmax)IR"; - - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - std::vector other_dims; - for (const auto i : c10::irange(a.dim())) { - if (i != softmax_dim) { - other_dims.push_back(i); - } - } - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? "log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("dim1", other_dims[0]); - ver_env.d("dim1_size", a.sizes()[other_dims[0]]); - ver_env.d("dim2", other_dims[1]); - ver_env.d("dim2_size", a.sizes()[other_dims[1]]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } -} - -TEST_F(Kernel, Softmax4D) { - const auto graph_template = R"IR( - graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %2 : int = prim::Constant[value=7]() - %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) - return (%3))IR"; - - auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} - # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} - # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 2 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 - # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 - # CHECK-NEXT: aten_softmax)IR"; - - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - std::vector other_dims; - for (const auto i : c10::irange(a.dim())) { - if (i != softmax_dim) { - other_dims.push_back(i); - } - } - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? "log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("dim1", other_dims[0]); - ver_env.d("dim1_size", a.sizes()[other_dims[0]]); - ver_env.d("dim2", other_dims[1]); - ver_env.d("dim2_size", a.sizes()[other_dims[1]]); - ver_env.d("dim3", other_dims[2]); - ver_env.d("dim3_size", a.sizes()[other_dims[2]]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } -} - -TEST_F(Kernel, SignTest) { - const auto graph_template = R"IR( - graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)): - %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0) - return (%2))IR"; - - auto run_test = [](const std::string& graph_string, const at::Tensor& input) { - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - - std::vector inputs = {input}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto ref = at::sign(input); - ASSERT_TRUE(at::allclose(o, ref)); - }; - auto common_options = at::TensorOptions() - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - int default_input_size = 100; - for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) { - at::Tensor corner_case_inputs; - at::jit::TemplateEnv env; - auto options = common_options; - switch (scalar_type) { - case ScalarType::Float: { - env.s("dtype", "Float"); - options = options.dtype(at::kFloat); - std::vector input_float = { - 0.0f, - -0.0f, - std::numeric_limits::infinity(), - -std::numeric_limits::infinity(), - std::nanf("1"), - -std::nanf("1")}; - corner_case_inputs = at::from_blob( - input_float.data(), - {static_cast(input_float.size())}, - options); - auto rand_input = at::rand({default_input_size}, options); - auto input = at::cat({rand_input, corner_case_inputs}); - env.d("size", at::numel(input)); - const auto graph_string = format(graph_template, env); - run_test(graph_string, input); - break; - } - case ScalarType::Double: { - env.s("dtype", "Double"); - options = options.dtype(at::kDouble); - std::vector input_double = { - 0.0, - -0.0, - std::numeric_limits::infinity(), - -std::numeric_limits::infinity(), - std::nan("1"), - -std::nan("1")}; - corner_case_inputs = at::from_blob( - input_double.data(), - {static_cast(input_double.size())}, - options); - auto rand_input = at::rand({default_input_size}, options); - auto input = at::cat({rand_input, corner_case_inputs}); - env.d("size", at::numel(input)); - const auto graph_string = format(graph_template, env); - run_test(graph_string, input); - break; - } - default: - throw unsupported_dtype(); - } - } -} - -TEST_F(Kernel, InlineProducerIntoReduction) { - // Inline producer (mul) into reduction (sum). - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1) - %3 : int = prim::Constant[value=7]() - %4 : Double(device=cpu) = aten::sum(%2, %3) - return (%4))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - // Check the IR we produced. - // We should have only one loop in the end. - const std::string& verification_pattern = - R"IR( - # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 - # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 - # CHECK-NEXT: sum - # CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto ref = (a * b).sum(at::kDouble); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, InlineReductionIntoConsumer) { - // Inline producer (mul %2) into reduction (sum %4) but DO NOT - // inline the reduction into consumer (mul %4). - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : int = prim::Constant[value=6]() - %4 : Float(device=cpu) = aten::sum(%2, %3) - %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4) - return (%5))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - // Check the IR we produced. - // We should have two loops in the end. - const std::string& verification_pattern = - R"IR( - # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 - # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 - # CHECK-NEXT: sum - # CHECK: for (int64_t i_2 = 0ll; i_2 < 5 - # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3 - # CHECK-NEXT: aten_mul - # CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto ref = (a * b).sum(at::kFloat) * (a * b); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, SanitizeNames_CUDA) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0), - %1 : Float(5, 3, strides=[3, 1], device=cuda:0)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%4))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - graph->inputs().at(0)->setDebugName("aten::add:"); - graph->inputs().at(1)->setDebugName("aten::add_"); - TensorExprKernel k(graph); - auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); - auto ref = a * (a * b); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, SanitizeConstants_CUDA) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)): - %none : NoneType = prim::Constant() - %size : int = prim::Constant[value=16]() - %sizes : int[] = prim::ListConstruct(%size, %size) - %30 : Device = prim::Constant[value="cuda"]() - %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none) - %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we insert a call to - // aten::ones and then const-prop it - ConstantPropagation(graph); - - // We set the name of the constant to include special characters that are - // not allowed. This should be fixed by the sanitizer in TensorExprKernel. - graph->nodes().front()->output()->setDebugName("illegal.name"); - - // Check if we have a constant node with illegal name in the graph. - auto const_node = graph->nodes().front(); - ASSERT_EQ(const_node->kind(), prim::Constant); - ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, ConstantTensors) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %size : int = prim::Constant[value=16]() - %sizes : int[] = prim::ListConstruct(%size, %size) - %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we insert a call to - // aten::ones and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, ConstantTensorsNonContiguous) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %dtype : int = prim::Constant[value=6]() - %c0 : int = prim::Constant[value=0]() - %c256 : int = prim::Constant[value=256]() - %c16 : int = prim::Constant[value=16]() - %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) - %sizes : int[] = prim::ListConstruct(%c16, %c16) - %y_t : Tensor = aten::view(%y_flat, %sizes) - %y : Tensor = aten::t(%y_t) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we generate several aten - // calls to produce non-contiguous constant tensor and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat)) - .view({16, 16}) - .t(); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, RunFast) { -#ifdef TORCH_ENABLE_LLVM - // TODO: Implement call_raw in IREval and remove the ifdef - - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - - k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()}); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, RunWithAllocatedOutputs) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - - std::vector args = {o, a, b}; - std::vector stack = fmap(args); - k.runWithAllocatedOutputs(stack); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, CodegenInspection) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %dtype : int = prim::Constant[value=6]() - %c0 : int = prim::Constant[value=0]() - %c256 : int = prim::Constant[value=256]() - %c16 : int = prim::Constant[value=16]() - %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) - %sizes : int[] = prim::ListConstruct(%c16, %c16) - %y_t : Tensor = aten::view(%y_flat, %sizes) - %y : Tensor = aten::t(%y_t) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we generate several aten - // calls to produce non-contiguous constant tensor and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - // Check that we could retrieve generated assembly - auto asm_str = k.getCodeText("asm"); - const std::string& asm_verification_pattern = - R"ASM( - # CHECK: .text - # CHECK: retq)ASM"; - torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str); - - // Check that we could retrieve info about codegen parameters - auto constants = k.getConstantDescriptors(); - auto buf_args = k.getBufferArgs(); - // Expected buf args: [input0, output0, constant0] - ASSERT_EQ(buf_args.size(), 3); - ASSERT_EQ(constants.size(), 1); - ASSERT_TRUE( - !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar()); -#endif -} - -Tensor lowerNanToNum( - const std::vector& inputs, - const std::vector& outputShape, - const std::vector& outputStrides, - const std::optional& outputType, - at::Device device) { - auto input_buf = std::get(inputs[0]); - auto e = Compute( - "custom_nan_to_num", - outputShape, - outputStrides, - [&](const std::vector& axes) { - std::vector indices(axes.begin(), axes.end()); - auto load = input_buf.load(indices); - return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load); - }); - return e; -} - -TEST_F(Kernel, CustomLowering) { - const auto graph_string = R"IR( - graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): - %none : NoneType = prim::Constant() - %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) - return (%y) -)IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - std::unordered_map lowerings = { - {aten::nan_to_num, lowerNanToNum}}; - TensorExprKernel k(graph, lowerings); - - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - - // Check that our custom lowering is actually used - torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str()); - torch::jit::testing::FileCheck().check("isnan")->run(oss.str()); -} - -TEST_F(Kernel, Vectorize) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(100, 16, strides=[16, 1], device=cpu), - %1 : Float(100, 16, strides=[16, 1], device=cpu)): - %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1) - %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 100 * 16; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first. -TEST_F(Kernel, DISABLED_FlattenVectorize) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(100, 3, strides=[3, 1], device=cpu), - %1 : Float(100, 3, strides=[3, 1], device=cpu)): - %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 100 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, Strided1dWithinBounds) { - auto ir = R"IR( - graph(%0 : Float(3, strides=[1], device=cpu), - %1 : Float(3, strides=[2], device=cpu)): - %2 : int = prim::Constant[value=1]() - %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2) - return (%3))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - - auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2)}); - auto expect = a + b; - - std::vector inputs = {a, b}; - - std::vector stack = fmap(inputs); - k.run(stack); - - auto output = stack[0].toTensor(); - - for (size_t i = 0; i < 3; ++i) { - TORCH_CHECK_EQ( - ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); - } -} - -TEST_F(Kernel, InputAsOutput) { - const auto graph_string = R"IR( - graph(%x : Float(5, 3, strides=[3, 1], device=cpu), - %y : Float(5, 3, strides=[1, 5], device=cpu)): - return (%x, %y))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - TensorExprKernel k(graph); - std::vector inputs = {x, y}; - - std::vector stack = fmap(inputs); - k.run(stack); - CHECK(at::allclose(x, stack[0].toTensor())); - CHECK(at::allclose(y, stack[1].toTensor())); -} - -TEST_F(Kernel, ScalarOut) { - auto ir = R"IR( -graph(%x : int, %y : int): - %z : int = aten::mul(%x, %y) - %r : int = aten::mul(%z, %x) - return (%r, %z))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - - // Verify the generated IR. We expect to see a scalar variable (Let) followed - // by a store to a 0-dim buffer. - const std::string& verification_pattern = R"IR( -# CHECK: int64_t -# CHECK-NEXT: [0ll] = -# CHECK-NEXT: int64_t -# CHECK-NEXT: [0ll] = -)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - int64_t x = 2, y = 3, r = 0, z = 0; - - // Verify that TEK::runFast works correctly with scalar outputs - std::vector inputs = {&x, &y}; - std::vector outputs = {&r, &z}; - k.runFast(inputs, outputs); - TORCH_CHECK_EQ(z, x * y); - TORCH_CHECK_EQ(r, z * x); - - // Verify that TEK::run works correctly with scalar outputs - std::vector stack = {x, y}; - k.run(stack); - TORCH_CHECK_EQ(stack[0], x * y * x); - TORCH_CHECK_EQ(stack[1], x * y); -} - -TEST_F(Kernel, ScalarTensorOut) { - auto ir = R"IR( -graph(%x : int, - %xt : Long(3, strides=[1], device=cpu), - %y : int, - %yt : Long(3, strides=[1], device=cpu)): - %z : int = aten::mul(%x, %y) - %r : int = aten::mul(%z, %x) - %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y) - %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt) - return (%r, %rt, %z, %zt))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - int64_t x = 2, y = 3, r = 0, z = 0; - auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2; - auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3; - auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); - auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); - - // Verify that TEK::runFast works correctly with mixed scalar and tensor - // inputs/outputs - std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; - std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; - k.runFast(inputs, outputs); - TORCH_CHECK_EQ(z, x * y); - TORCH_CHECK_EQ(r, z * x); - ASSERT_TRUE(at::equal(zt, xt * yt)); - ASSERT_TRUE(at::equal(rt, zt * xt)); - - // Verify that TEK::run works correctly with mixed scalar and tensor - // inputs/outputs - std::vector stack = {x, xt, y, yt}; - k.run(stack); - TORCH_CHECK_EQ(stack[0], x * y * x); - ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt)); - TORCH_CHECK_EQ(stack[2], x * y); - ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt)); -} - -TEST_F(Kernel, FuseLoopsWithVariableBounds) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu), - %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim1, int dim2) { - auto a = - at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto c = - at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b, c}, 1); - - std::vector stack = - fmap(std::vector({a, b, c})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -TEST_F(Kernel, FuseLoopsWithVariableConcatDim) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3, -4, -5}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim1, int dim2, int dim3) { - auto a = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto c = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b, c}, 1); - - std::vector stack = - fmap(std::vector({a, b, c})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - stack.emplace_back(dim3); - stack.emplace_back(3 * dim3); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20, 15); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int, - %SS_6 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3, -4, -5, -6}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t j -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) { - auto a = - at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b}, 1); - - std::vector stack = fmap(std::vector({a, b})); - stack.emplace_back(dim2); - stack.emplace_back(dim3); - stack.emplace_back(dim4); - stack.emplace_back(dim5); - stack.emplace_back(dim4 + dim5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20, 15, 8); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp deleted file mode 100644 index f6ffc84f62c09..0000000000000 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ /dev/null @@ -1,1799 +0,0 @@ -#ifdef TORCH_ENABLE_LLVM -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -using LLVMExprEval = ExprEval; - -// Typed tests, can't use gtest params here due to the way we instantiate tests. -#define TEST_LLVM_SCALAR_TYPES(_) \ - _(uint8_t, Byte, 24) \ - _(int8_t, Char, -20) \ - _(int16_t, Short, 3332) \ - _(int, Int, 123456) \ - _(int64_t, Long, 2631563121321) \ - _(float, Float, 0.122) \ - _(double, Double, 0.21312) \ - _(at::Half, Half, 0.128f) - -#define IMM_TEST(Type, Name, Val) \ - TEST(LLVM, Name##ImmTest) { \ - auto a = Name##Imm::make(Val); \ - LLVMExprEval cg(a); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(IMM_TEST) -#undef IMM_TEST - -#define ADD_TEST(Type, Name, Val) \ - TEST(LLVM, Name##AddTest) { \ - auto a = Name##Imm::make(Val); \ - auto b = Name##Imm::make(Val * 2); \ - auto c = Add::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val * 3, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val * 3); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(ADD_TEST) -#undef ADD_TEST - -#define SUB_TEST(Type, Name, Val) \ - TEST(LLVM, Name##SubTest) { \ - auto a = Name##Imm::make(Val * 2); \ - auto b = Name##Imm::make(Val); \ - auto c = Sub::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(SUB_TEST) -#undef SUB_TEST - -#define MUL_TEST(Type, Name, Val) \ - TEST(LLVM, Name##MulTest) { \ - auto a = Name##Imm::make(Val); \ - auto b = Name##Imm::make((Type)4); \ - auto c = Mul::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val * 4, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val * 4); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(MUL_TEST) -#undef MUL_TEST - -#define DIV_TEST(Type, Name, Val) \ - TEST(LLVM, Name##DivTest) { \ - auto a = Name##Imm::make((Type)6); \ - auto b = Name##Imm::make((Type)3); \ - auto c = Div::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), 2, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), 2); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(DIV_TEST) -#undef DIV_TEST - -TEST(LLVM, IntToFloatCastTest) { - auto a = IntImm::make(2); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b, {}); - ASSERT_EQ(cg.value(), 2.0); -} - -TEST(LLVM, FloatToIntCastTest) { - auto a = FloatImm::make(2.0); - auto b = Cast::make(kInt, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, IntToLongCastTest) { - auto a = IntImm::make(12345); - auto b = Cast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 12345); -} - -TEST(LLVM, ByteToCharCastTest) { - auto a = ByteImm::make(250); - auto b = Cast::make(kChar, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), (int8_t)250); -} - -TEST(LLVM, HalfToLongCastTest) { - auto a = HalfImm::make(2.0); - auto b = Cast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, ByteToDoubleCastTest) { - auto a = ByteImm::make(2); - auto b = Cast::make(kDouble, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, FloatToByteCastTest) { - auto a = FloatImm::make(254.0); - auto b = Cast::make(kByte, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 254); -} - -TEST(LLVM, FloatToCharCastTest) { - auto a = FloatImm::make(-2.0); - auto b = Cast::make(kChar, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), -2); -} - -TEST(LLVM, ByteToFloatCastTest) { - auto a = ByteImm::make(254); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 254.0); -} - -TEST(LLVM, CharToFloatCastTest) { - auto a = CharImm::make(-2); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), -2.0); -} - -TEST(LLVM, BitCast) { - /* constexpr int16_t ref16 = 1337; */ - constexpr int32_t ref32 = 1337; - constexpr int64_t ref64 = 1337; - constexpr float reff32 = 1337.0f; - constexpr double reff64 = 1337.0f; - - // this is broken - /*{ - at::Half k_; - at::Half* k = &k_; - *reinterpret_cast(k) = ref16; - auto a = HalfImm::make(k); - auto b = BitCast::make(kShort, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref16); - }*/ - - { - float k = raw_bitcast(ref32); - auto a = FloatImm::make(k); - auto b = BitCast::make(kInt, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref32); - } - - { - double k = raw_bitcast(ref64); - auto a = DoubleImm::make(k); - auto b = BitCast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref64); - } - - { - int64_t k = raw_bitcast(reff64); - auto a = LongImm::make(k); - auto b = BitCast::make(kDouble, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), reff64); - } - - { - int32_t k = raw_bitcast(reff32); - auto a = IntImm::make(k); - auto b = BitCast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), reff32); - } -} - -TEST(LLVM, fastLogFloat) { - const int kTotalSize = 128 * 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - LLVMCodeGen ir_eval(stmt, {a_buf, b_buf}); - ir_eval.call({a_v, b_v}); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::log(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_FLOAT_EQ(test, ref); - } - } -} - -TEST(LLVM, LetTest01) { - BufHandle a("A", {1}, kFloat); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kFloat); - auto block = Block::make({ - Let::make(x, 3.f), - a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))), - }); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 4.f); -} - -TEST(LLVM, LetTest02) { - BufHandle a("A", {1}, kFloat); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - auto block = Block::make( - {Let::make(x, 3.f), - Let::make(y, 6.f), - a.store( - {IntImm::make(0)}, - ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))}); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 6.f * 4.f); -} - -TEST(LLVM, LetTestMultitype) { - BufHandle a("A", {1}, kDouble); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kByte); - VarHandle y("y", kHalf); - auto block = Block::make( - {Let::make(x, 3), - Let::make(y, 6.f), - a.store( - {0}, - Cast::make( - kDouble, - ExprHandle(2.f) + - (x * ExprHandle(3.f) + y * ExprHandle(4.f))))}); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3 * 3.f + 6.f * 4.f); -} - -TEST(LLVM, BufferTest) { - BufHandle a("A", {32}, kFloat); - std::vector v(5); - std::vector args({v.data()}); - auto rv = IntImm::make(0); - LLVMExprEval cg(rv, {a}); - ASSERT_EQ(cg.value(args), 0); -} - -TEST(LLVM, BlockTest) { - BufHandle a("A", {32}, kInt); - std::vector v = {1, 2}; - std::vector args({v.data()}); - - auto block = Block::make({ - a.store({0}, 3), - a.store({1}, 4), - a.store({0}, 4), - }); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 4); - ASSERT_EQ(v[1], 4); -} - -TEST(LLVM, LoadStoreTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - - auto store = b.store({0}, a.load(0)); - LLVMCodeGen cg(store, {a, b}); - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 42); - ASSERT_EQ(b_buffer[0], 42); -} - -TEST(LLVM, IfThenElseTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - std::vector c_buffer = {1}; - - auto store = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0)); - LLVMCodeGen cg(store, {a, b, c}); - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 42); - ASSERT_EQ(b_buffer[0], 42); -} - -// if (x < 10) x = x + 1 -TEST(LLVM, CondNoFalseBlockTest) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); - - for (int32_t x_value : {0, 10, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(cond, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - ASSERT_EQ(x_buffer[0], x_value + 1); - } else { - ASSERT_EQ(x_buffer[0], x_value); - } - } -} - -// if (x < 10) { -// x = x + 1; -// } else { -// x = x - 1; -// } -TEST(LLVM, CondTest) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = - Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - auto block = Block::make({ - cond, - x.store({0}, x.load(0) * 2), - }); - - for (int32_t x_value : {0, 10, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(block, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); - } else { - ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); - } - } -} - -// if (x < 10) { -// if (x > 5) { -// x = x + 1; -// } else { -// x = x - 1; -// } -// } else { -// if (x <= 15) { -// x = x + 2; -// } else { -// x = x - 2; -// } -// } -TEST(LLVM, CondNestedTest) { - BufHandle x("X", {1}, kInt); - auto true_cmp = - CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); - auto true_cond = Cond::make( - true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - auto false_cmp = - CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); - auto false_cond = Cond::make( - false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = Cond::make(cmp, true_cond, false_cond); - - for (int32_t x_value : {0, 8, 15, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(cond, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - if (x_value > 5) { - ASSERT_EQ(x_buffer[0], x_value + 1); - } else { - ASSERT_EQ(x_buffer[0], x_value - 1); - } - } else { - if (x_value <= 15) { - ASSERT_EQ(x_buffer[0], x_value + 2); - } else { - ASSERT_EQ(x_buffer[0], x_value - 2); - } - } - } -} - -TEST(LLVM, DirectVectorization) { - constexpr int M = 3; - constexpr int N = 64; - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {M, N}, kFloat); - BufHandle c("c", {M, N}, kFloat); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - StmtPtr s = For::make( - m, - 0, - M, - Store::make( - c, - {Ramp::make(m * 64, 1, 64)}, - Load::make({kFloat, 64}, a, {Ramp::make(m * 64, 1, 64)}) * - Load::make({kFloat, 64}, b, {Ramp::make(m * 64, 1, 64)}))); - LLVMCodeGen cg(s, {a, b, c}); -} - -TEST(LLVM, VecLoadStoreTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {1, 1, 1, 1}; - std::vector b_buffer = {2, 2, 2, 2}; - - auto store = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)})); - LLVMCodeGen cg(store, {a, b}); - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 1); - ASSERT_EQ(a_buffer[1], 1); - ASSERT_EQ(a_buffer[2], 1); - ASSERT_EQ(a_buffer[3], 1); - ASSERT_EQ(b_buffer[0], 1); - ASSERT_EQ(b_buffer[1], 1); - ASSERT_EQ(b_buffer[2], 1); - ASSERT_EQ(b_buffer[3], 1); -} - -#define FLOAT_INTRINSICS_TEST(Name, Lanes) \ - TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \ - BufHandle a("A", {1}, kFloat); \ - BufHandle b("B", {1}, kFloat); \ - float val = 0.5f; \ - std::vector a_buffer(Lanes, val); \ - std::vector b_buffer(Lanes, val); \ - auto store = b.store( \ - {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ - LLVMCodeGen cg(store, {a, b}); \ - std::vector args({a_buffer.data(), b_buffer.data()}); \ - ASSERT_EQ(cg.value(args), 0); \ - for (const auto i : c10::irange(Lanes)) { \ - ASSERT_FLOAT_EQ(a_buffer[i], val); \ - } \ - } // namespace jit -FLOAT_INTRINSICS_TEST(erf, 4) -FLOAT_INTRINSICS_TEST(erfc, 4) -FLOAT_INTRINSICS_TEST(acos, 4) -FLOAT_INTRINSICS_TEST(asin, 4) -FLOAT_INTRINSICS_TEST(atan, 4) -FLOAT_INTRINSICS_TEST(cosh, 4) -FLOAT_INTRINSICS_TEST(sinh, 4) -FLOAT_INTRINSICS_TEST(tanh, 4) -FLOAT_INTRINSICS_TEST(expm1, 4) -FLOAT_INTRINSICS_TEST(lgamma, 4) -FLOAT_INTRINSICS_TEST(erf, 8) -FLOAT_INTRINSICS_TEST(erfc, 8) -FLOAT_INTRINSICS_TEST(acos, 8) -FLOAT_INTRINSICS_TEST(asin, 8) -FLOAT_INTRINSICS_TEST(atan, 8) -FLOAT_INTRINSICS_TEST(cosh, 8) -FLOAT_INTRINSICS_TEST(sinh, 8) -FLOAT_INTRINSICS_TEST(tanh, 8) -FLOAT_INTRINSICS_TEST(expm1, 8) -FLOAT_INTRINSICS_TEST(lgamma, 8) -#undef FLOAT_INTRINSICS_TEST - -#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \ - TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \ - BufHandle a("A", {1}, kDouble); \ - BufHandle b("B", {1}, kDouble); \ - float val = 0.5f; \ - std::vector a_buffer(Lanes, val); \ - std::vector b_buffer(Lanes, val); \ - auto store = b.store( \ - {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ - LLVMCodeGen cg(store, {a, b}); \ - std::vector args({a_buffer.data(), b_buffer.data()}); \ - ASSERT_EQ(cg.value(args), 0); \ - for (const auto i : c10::irange(Lanes)) { \ - ASSERT_FLOAT_EQ(a_buffer[i], val); \ - } \ - } // namespace jit -DOUBLE_INTRINSICS_TEST(erf, 2) -DOUBLE_INTRINSICS_TEST(erfc, 2) -DOUBLE_INTRINSICS_TEST(acos, 2) -DOUBLE_INTRINSICS_TEST(asin, 2) -DOUBLE_INTRINSICS_TEST(atan, 2) -DOUBLE_INTRINSICS_TEST(cosh, 2) -DOUBLE_INTRINSICS_TEST(sinh, 2) -DOUBLE_INTRINSICS_TEST(tanh, 2) -DOUBLE_INTRINSICS_TEST(expm1, 2) -DOUBLE_INTRINSICS_TEST(lgamma, 2) -DOUBLE_INTRINSICS_TEST(erf, 4) -DOUBLE_INTRINSICS_TEST(erfc, 4) -DOUBLE_INTRINSICS_TEST(acos, 4) -DOUBLE_INTRINSICS_TEST(asin, 4) -DOUBLE_INTRINSICS_TEST(atan, 4) -DOUBLE_INTRINSICS_TEST(cosh, 4) -DOUBLE_INTRINSICS_TEST(sinh, 4) -DOUBLE_INTRINSICS_TEST(tanh, 4) -DOUBLE_INTRINSICS_TEST(expm1, 4) -DOUBLE_INTRINSICS_TEST(lgamma, 4) -#undef DOUBLE_INTRINSICS_TEST - -TEST(LLVM, VectorizerLoadStoreTest) { - BufHandle a("A", {1}, kInt); - - Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); - - ASSERT_TRUE(to(to(s)->front()) == nullptr); - - LLVMCodeGen cg(s, {a, c_buf}); - - std::vector a_vec(4, 21); - std::vector c_vec(4, 0); - std::vector args({a_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 21); -} - -TEST(LLVM, VectorizeBitCast) { - BufHandle a("A", {128}, kInt); - - Tensor c = Compute("c", {128}, [&](const VarHandle& i) { - return bitcast(a.load(i)); - }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); - ASSERT_TRUE(to(to(s)->front()) == nullptr); - - LLVMCodeGen cg(s, {a, c_buf}); - - std::vector a_vec(128); - std::vector c_vec(128); - for (const auto i : c10::irange(128)) { - a_vec[i] = raw_bitcast(1337.f); - } - std::vector args({a_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 1337.f); -} - -TEST(LLVM, MemcpyTest) { - constexpr int N = 32; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - std::vector a_buffer(N, 42); - std::vector b_buffer(N, 0); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, b.store({i}, a.load(i))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, 42); - assertAllEqual(b_buffer, 42); -} - -TEST(LLVM, BzeroTest) { - constexpr int N = 32; - BufHandle b("B", {N}, kInt); - std::vector b_buffer(N, 11); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, b.store({i}, 0)); - - LLVMCodeGen cg(expr, {b}); - - std::vector args({b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(b_buffer, 0); -} - -TEST(LLVM, ElemwiseAdd) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 42); -} - -TEST(LLVM, ElemwiseAddFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 42.0f); -} - -TEST(LLVM, ElemwiseLog10Float) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - std::vector a_buffer(N, 10.0f); - std::vector b_buffer(N, 2.0f); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N / 4, - b.store( - {Ramp::make(i * 4, 1, 4)}, log10(a.load({Ramp::make(i * 4, 1, 4)})))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, 10.0f); - assertAllEqual(b_buffer, 1.0f); -} - -TEST(LLVM, ElemwiseLog1pFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - std::vector a_buffer(N, expf(3.0f) - 1); - std::vector b_buffer(N, 42.0f); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N / 4, - b.store( - {Ramp::make(i * 4, 1, 4)}, log1p(a.load({Ramp::make(i * 4, 1, 4)})))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, expf(3.0f) - 1); - ExpectAllNear(b_buffer, 3.0f, 1e-5f); -} - -TEST(LLVM, ElemwiseMaxInt) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 41); -} - -TEST(LLVM, ElemwiseMinInt) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 1); -} - -TEST(LLVM, ElemwiseMaxFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 41.0f); -} - -TEST(LLVM, ElemwiseMaxNaNFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, NAN); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(b_buffer, 1.0f); - for (auto const& elt : c_buffer) { - ASSERT_TRUE(std::isnan(elt)); - } -} - -TEST(LLVM, ElemwiseMinFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 1.0f); -} - -TEST(LLVM, ElemwiseMinNaNFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, NAN); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(b_buffer, 1.0f); - for (auto const& elt : c_buffer) { - ASSERT_TRUE(std::isnan(elt)); - } -} - -TEST(LLVM, ElemwiseMod) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 23); - std::vector c_buffer(N, 18); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i)))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 23); - assertAllEqual(c_buffer, 18); -} - -TEST(LLVM, CompareSelectIntEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - for (int i = 0; i < N / 2; i++) { - b_buffer[i] = 0; - c_ref[i] = 0; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectFloatEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1.0f); - std::vector b_buffer(N, 1.0f); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 1); -} - -TEST(LLVM, CompareSelectByteGT) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 0); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 0); - - for (int i = 0; i < N / 2; i++) { - a_buffer[i] = 128; - c_ref[i] = 1; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGT))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(0)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteGE) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 0); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGE))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(0)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteLT) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 128); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - for (int i = 0; i < N / 2; i++) { - a_buffer[i] = 128; - c_ref[i] = 0; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLT))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(128)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteLE) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 128); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLE))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(128)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, StoreFloat) { - BufHandle result("result", {1}, kFloat); - std::vector result_buffer = {0.0f}; - auto expr = result.store({0}, FloatImm::make(3.14f)); - LLVMCodeGen cg(expr, {result}); - std::vector args({result_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(result_buffer[0], 3.14f); -} - -TEST(LLVM, SimpleMath01) { - const int N = 1024; - Tensor tensor = Compute( - "f", {N}, [](const VarHandle& i) { return cast(i * i + 1); }); - LoopNest l({tensor}); - StmtPtr stmt = l.root_stmt(); - BufHandle f_buf(tensor.buf()); - LLVMCodeGen cg(stmt, {f_buf}); - - PaddedBuffer f_v(N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(N, "f_ref"); - for (const auto i : c10::irange(N)) { - f_ref(i) = i * i + 1; - } - ExpectAllNear(f_v, f_ref, 1e-5); -} - -TEST(LLVM, ComputeMul) { - const int N = 1024; - BufHandle a("a", {N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute( - "c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - - LLVMCodeGen cg(s, {a, b, c_buf}); - - std::vector a_vec(N, 21.0f); - std::vector b_vec(N, 2.0f); - std::vector c_vec(N, 0.0f); - std::vector args({a_vec.data(), b_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 42.0f); -} - -TEST(LLVM, BroadcastAdd) { - const int M = 32; - const int N = 1024; - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(j); - }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - LLVMCodeGen cg(s, {a, b, c_buf}); - - std::vector av(M * N); - std::iota(av.begin(), av.end(), 0); - std::vector bv(N); - std::iota(bv.begin(), bv.end(), 0); - std::vector cv(M * N, 0); - std::vector args({av.data(), bv.data(), cv.data()}); - ASSERT_EQ(cg.value(args), 0); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - ASSERT_EQ(cv[i * N + j], av[i * N + j] + bv[j]); - } - } -} - -TEST(LLVM, BitwiseOps) { - auto a = IntImm::make(59); - auto b = IntImm::make(11); - auto c = IntImm::make(101); - auto d = IntImm::make(2); - - ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; - LLVMExprEval cg(f); - - ASSERT_EQ(cg.value(), 11); -} - -TEST(LLVM, ArithmeticRightShift) { - auto a = CharImm::make(-4); - auto b = CharImm::make(1); - ExprHandle f = a >> b; - LLVMExprEval cg(f); - ASSERT_EQ(cg.value(), -2); -} - -TEST(LLVM, LogicalRightShift) { - auto a = ByteImm::make(0xfc); - auto b = ByteImm::make(1); - ExprHandle f = a >> b; - LLVMExprEval cg(f); - ASSERT_EQ(cg.value(), 0x7e); -} - -TEST(LLVM, DynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - LLVMCodeGen cg(s, {a, b, c, n}); - std::vector args({aData.data(), bData.data(), cData.data(), &size}); - cg.value(args); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, BindDynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - LLVMCodeGen cg(s, {a, b, c, n}); - cg.call({aData, bData, cData, size}); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, TensorDynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - Tensor c = Compute( - "c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); }); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - LLVMCodeGen cg(s, {a, b, c, n}); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - cg.call({aData, bData, cData, size}); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, DynamicShape2D) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - LLVMCodeGen cg(s, {a, b, c, m, n}); - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - cg.call({aData, bData, cData, M, N}); - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - }; - testWithSize(1, 8); - testWithSize(16, 32); - testWithSize(37, 11); -} - -TEST(LLVM, EmptyStmt) { - StmtPtr s = alloc(std::vector({})); - - LLVMCodeGen cg(s, {}); - cg.call({}); - // Just don't crash. -} - -TEST(LLVM, EliminatedStmt) { - BufHandle a("a", {1}, kFloat); - - Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; }); - - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - LLVMCodeGen cg(s, {a, c}); - std::vector aData(1, 1.0f); - std::vector cData(0, 0.0f); - cg.call({aData, cData}); -} - -TEST(LLVM, SimpleReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loop({b}); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(LLVM, RFactorReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loop({b}); - - std::vector loops = loop.getLoopStmtsFor(b); - ForPtr loop_m = loops.at(1); - ForPtr loop_n = loops.at(2); - loop.reorderAxis(loop_m, loop_n); - - loops = loop.getLoopStmtsFor(b); - loop_m = loops.at(2); - loop_n = loops.at(1); - auto b_body = loop.getAllWritesToBuf(b.buf())[1]; - ASSERT_TRUE(loop.rfactor(b_body, loop_n)); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(LLVM, RFactorVectorizedReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loopnest({b}); - std::vector loops = loopnest.getLoopStmtsFor(b); - // Reorder n and m loops - loopnest.reorderAxis(loops.at(1), loops.at(2)); - auto b_body = loopnest.getAllWritesToBuf(b.buf()).at(1); - auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b.buf()); - ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3); - ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1])); - auto distributed_loops = loopnest.distributeLoop(all_loops[1][1]); - - // Vectorize initializer of rfac_buf - ASSERT_TRUE(LoopNest::vectorize(distributed_loops[0])); - // Vectorize producer of rfac_buf - ASSERT_TRUE(LoopNest::vectorize(distributed_loops[1])); - loopnest.simplify(); - - loopnest.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt()); - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -template -static void testSimpleParallel() { - // Compute a simple operation, and try all loop-axis combination to be - // parallel or sequential. - const int M = 4; - const int N = 6; - Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) { - return cast(m + n); - }); - LoopNest loop_nest({f}); - auto const& loops = loop_nest.getLoopStmtsFor(f); - ForPtr m = loops[0]; - ForPtr n = loops[1]; - if (outer) { - m->set_parallel(); - } - if (inner) { - n->set_parallel(); - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {f}); - - PaddedBuffer f_v(M, N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(M, N, "f_ref"); - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - f_ref(m, n) = m + n; - } - } - ExpectAllNear(f_v, f_ref, 1e-5); -} - -TEST(LLVM, SimpleParallelSS) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelSP) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelPS) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelPP) { - testSimpleParallel(); -} - -TEST(LLVM, CompositeParallel) { - int loop_count = 6; - int test_count = 1 << loop_count; - // Compute a composite operation, and try all loop-axis combination to be - // parallel or sequential. - for (const auto test_cfg : c10::irange(test_count)) { - int M = 5; - int N = 7; - Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; }); - Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; }); - Tensor t3 = - Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) { - return t1.load(m) * t2.load(n); - }); - Tensor t4 = - Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) { - return t3.load(m, n) + m + n; - }); - LoopNest loop_nest({t4}, {t1, t2, t3, t4}); - std::vector loop_list; - { - auto const& loops = loop_nest.getLoopStmtsFor(t1); - loop_list.push_back(loops[0]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t2); - loop_list.push_back(loops[0]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t3); - loop_list.push_back(loops[0]); - loop_list.push_back(loops[1]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t4); - loop_list.push_back(loops[0]); - loop_list.push_back(loops[1]); - } - ASSERT_EQ(loop_list.size(), loop_count); - for (const auto i : c10::irange(loop_count)) { - if (test_cfg & (1 << i)) { - loop_list[i]->set_parallel(); - } - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {t4}); - - PaddedBuffer t4_v(M, N, "t4_v"); - std::vector args({t4_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer t4_ref(M, N, "t4_ref"); - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - t4_ref(m, n) = (m + 1) * (n + 2) + m + n; - } - } - ExpectAllNear(t4_v, t4_ref, 1e-5); - } -} - -TEST(LLVM, VectorizedGEMM) { - int M = 32; - int N = 32; - int K = 48; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - LoopNest loop({CT}); - - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr m = loops[0]; - loop.splitWithMask(m, 16); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr n = loops[2]; - loop.splitWithMask(n, 16); - } - // mo, mi, no, ni, k -> - // mo, no, mi, ni, k - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[1]; - ForPtr no = loops[2]; - loop.reorderAxis(mi, no); - } - // mo, no, mi, ni, k -> - // mo, no, mi, k, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr ni = loops[3]; - ForPtr k = loops[4]; - loop.reorderAxis(ni, k); - } - // mo, no, mi, k, ni -> - // mo, no, k, mi, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[2]; - ForPtr k = loops[3]; - loop.reorderAxis(mi, k); - } - { - auto loops = NodeFinder::find(loop.root_stmt()); - ASSERT_TRUE(LoopNest::vectorize(loops[3])); - ASSERT_TRUE(LoopNest::vectorize(loops.back())); - } - - loop.prepareForCodegen(); - - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - LLVMCodeGen cg(s, {AP, BP, CT}); - - PaddedBuffer a_v(M, K, "a_v"); - PaddedBuffer b_v(K, N, "b_v"); - PaddedBuffer c_v(M, N, "c_v"); - PaddedBuffer c_ref(M, N, "c_ref"); - - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - c_ref(m, n) = 0.f; - for (const auto k : c10::irange(K)) { - c_ref(m, n) += a_v(m, k) * b_v(k, n); - } - } - } - - cg.call({a_v, b_v, c_v}); - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LLVM, CallRaw) { - const int M = 32; - VarHandle N("N", kInt); - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(j); - }); - - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - int32_t N_value = 1024; - std::vector av(M * N_value); - std::iota(av.begin(), av.end(), 0); - std::vector bv(N_value); - std::iota(bv.begin(), bv.end(), 0); - std::vector cv(M * N_value, 0); - std::vector args({av.data(), bv.data(), cv.data(), &N_value}); - - LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N}); - cg.call_raw(args); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N_value)) { - ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); - } - } - - SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N}); - eval.call_raw(args); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N_value)) { - ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); - } - } -} - -TEST(LLVM, CustomTarget) { - constexpr int M = 16; - BufHandle a("a", {M}, kFloat); - BufHandle b("b", {M}, kFloat); - BufHandle c("c", {M}, kFloat); - Tensor d = Compute("d", {M}, [&](const VarHandle& m) { - return a.load(m) * b.load(m) + c.load(m); - }); - LoopNest nest({d}); - nest.prepareForCodegen(); - auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d}) - .triple("i686-elf") - .cpu("i386") - .build(); - std::ostringstream ss; - ss << cg->getCodeText("asm"); - torch::jit::testing::FileCheck() - .check("fadds") - ->check("fmuls") - ->check_not("vfmadd") - ->run(ss.str()); -} - -TEST(LLVM, CodeGenKernelFuncName) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - auto store = b.store({0}, a.load(0)); - - { - LLVMCodeGen cg(store, {a, b}); - // Check that the kernel function name used by LLVMCodeGen - // is not empty. - ASSERT_NE(cg.kernel_func_name(), ""); - } - - { - LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func"); - // Check that the kernel function name used by LLVMCodeGen - // is the one that was given above. - ASSERT_EQ(cg.kernel_func_name(), "new_func"); - } -} - -} // namespace jit -} // namespace torch - -#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp deleted file mode 100644 index a8bda8814dbae..0000000000000 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ /dev/null @@ -1,6894 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -void checkIR(StmtPtr s, const std::string& pattern) { - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(pattern, oss.str()); -} - -void checkExprIR(ExprPtr e, const std::string& pattern) { - std::string prefixed_pattern = "# CHECK: " + pattern + "\n"; - std::ostringstream oss; - oss << *e << "\n"; - torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); -} - -void checkExprIR(const ExprHandle& e, const std::string& pattern) { - checkExprIR(e.node(), pattern); -} - -TEST(LoopNest, ExprSimple01) { - Tensor tensor = - Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - LoopNest::splitWithTail(loops[0], 2); - LoopNest::splitWithTail(loops[0], 2); -} - -TEST(LoopNest, ExprLower01) { - Tensor tensor = - Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 20); - ASSERT_LT(oss.str().size(), 200); -} - -TEST(LoopNest, ExprSimple02) { - auto func = [](const ExprHandle& x, const ExprHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }; - Tensor tensor = Compute("f", {26, 5}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - LoopNest::splitWithTail(loops[0], 4); - - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 200); - ASSERT_LT(oss.str().size(), 600); - - { - // Compare to a reference loop structure structure. - VarHandle x_outer("i_outer", kInt); - VarHandle x_inner("i_inner", kInt); - VarHandle y("i", kInt); - VarHandle x_tail("i_tail", kInt); - BufHandle f("f", {26, 5}, kFloat); - ExprHandle x_1 = x_outer * 4 + x_inner; - ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; - ForPtr stmt1 = For::make( - x_outer, - 0, - x_outer_end, - For::make( - x_inner, - 0, - 4, - For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); - ExprHandle x_2 = x_tail + x_outer_end * 4; - ForPtr stmt2 = For::make( - x_tail, - 0, - (ExprHandle(26) - 0) % 4, - For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); - StmtPtr stmt = Block::make({stmt1, stmt2}); - - std::ostringstream oss_ref; - oss_ref << *stmt; - ASSERT_EQ(oss.str(), oss_ref.str()); - } - - { - PaddedBuffer f_v(26, 5, "f_v"); - PaddedBuffer f_ref(26, 5, "f_res"); - - stmt = FlattenIndexes(stmt); - SimpleIREvaluator ir_eval(stmt, {tensor}); - ir_eval(f_v); - - for (int x = 0; x < 26; x++) { - for (int y = 0; y < 5; y++) { - f_ref(x, y) = 1 + x * x + y * y; - } - } - - ExpectAllNear(f_v, f_ref, 1e-5); - } -} - -BlockPtr getSimplifiedBody(const LoopNest& l) { - StmtPtr stmt = l.root_stmt(); - StmtPtr simplified = IRSimplifier::simplify(stmt); - return to(simplified); -} - -void assertForRange(ForPtr f, int expected_start, int expected_stop) { - ASSERT_NE(f, nullptr); - IntImmPtr start = to(f->start()); - ASSERT_NE(start, nullptr); - ASSERT_EQ(start->value(), expected_start); - IntImmPtr stop = to(f->stop()); - ASSERT_NE(stop, nullptr); - ASSERT_EQ(stop->value(), expected_stop); -} - -void assertForRanges( - BlockPtr body, - const std::vector>& start_stops) { - ASSERT_EQ(body->nstmts(), start_stops.size()); - - auto it = body->begin(); - for (size_t i = 0; i < start_stops.size(); i++, it++) { - ForPtr loop = to(*it); - assertForRange(loop, start_stops[i].first, start_stops[i].second); - } -} - -TEST(LoopNest, ExprSliceHeadWithLoopOptions) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::sliceHead(loops[0], 2, &head, &tail); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {0, 8}}); - - ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); - ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - ASSERT_TRUE(head->loop_options().isDefault()); -} - -TEST(LoopNest, ExprSliceTailWithLoopOptions) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 4, &head, &tail); - - ForPtr tail_head; - ForPtr tail_tail; - tail->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); - - ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); - ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - ASSERT_TRUE(head->loop_options().isDefault()); - ASSERT_TRUE(tail_tail->loop_options().isDefault()); -} - -TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 10, &head, &tail); - - ASSERT_EQ(head, loops[0]); - ASSERT_EQ(tail, nullptr); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 100, &head, &tail); - - ASSERT_EQ(head, loops[0]); - ASSERT_EQ(tail, nullptr); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceHead) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 4, &head, &tail); - - ASSERT_NE(head, nullptr); - ASSERT_NE(head, loops[0]); - ASSERT_NE(tail, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 4}, {4, 10}}); -} - -TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceTail(loops[0], 4, &head, &tail); - // head: [0, 6) - // tail: [6, 10) - - LoopNest::sliceHead(tail, 2); - // tail_head: [6, 8) - // tail_tail: [8, 10) - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); -} - -TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 10, &head, &tail); - - ASSERT_EQ(head, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 100, &head, &tail); - - ASSERT_EQ(head, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceTail) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 4, &head, &tail); - - ASSERT_NE(head, nullptr); - ASSERT_EQ(head, loops[0]); - ASSERT_NE(tail, nullptr); - ASSERT_NE(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {6, 10}}); -} - -TEST(LoopNest, ExprSplitAndSlice) { - // 0: splitWithTail - // 1: sliceTail on inner loop - // 2: sliceHead on outer loop - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {100}, func); - LoopNest l({tensor}); - - ForPtr inner; - ForPtr tail; - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // outer: [0, 4) - // inner: [0, 21) - // tail: [84, 100) - LoopNest::splitWithTail(loops[0], 21, &inner, &tail); - LoopNest::sliceTail(inner, 2); - LoopNest::sliceHead(loops[0], 2); - - // for (int x_outer = 0; x_outer < 2; x_outer++) { - // for (int x_inner = 0; x_inner < 19; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // for (int x_inner = 19; x_inner < 21; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // } - // for (int x_outer = 2; x_outer < 4; x_outer++) { - // for (int x_inner = 0; x_inner < 19; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // for (int x_inner = 19; x_inner < 21; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // } - // for (int x_tail = 0; x_tail < 16; x_tail++) { - // f[x_tail + 84] = 1.f + float(x_tail + 84); - // } - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); - - auto biter = body->begin(); - - ForPtr loop = to(*biter++); - assertForRanges(loop->body(), {{0, 19}, {19, 21}}); - - loop = to(*biter); - assertForRanges(loop->body(), {{0, 19}, {19, 21}}); -} - -TEST(LoopNest, ExprSliceAndNormalize) { - // 0: sliceHead - // 1: normalize tail - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceHead(loops[0], 2, &head, &tail); - // head: [0, 2) - // tail: [2, 10) - - LoopNest::normalize(tail); - // normalized_tail: [0, 8) - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {0, 8}}); -} - -template -T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { - ExprEval eval(expr, {var}); - return eval.value(value); -} - -TEST(LoopNest, ExprSliceWithVariableDimension) { - auto testWithDimension = - [](int dimension, - const std::vector>& expected_for_ranges) { - VarHandle dim("dim", kInt); - Tensor tensor = - Compute("f", {dim}, [](const ExprHandle& x) { return x; }); - LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceHead(loops[0], 2, &head, &tail); - - LoopNest::sliceTail(tail, 2); - - BlockPtr body = getSimplifiedBody(l); - ASSERT_EQ(expected_for_ranges.size(), 3); - auto it = body->begin(); - for (auto& start_stop : expected_for_ranges) { - ForPtr loop = to(*it++); - int start = evalExpr(ExprHandle(loop->start()), dim, dimension); - int stop = evalExpr(ExprHandle(loop->stop()), dim, dimension); - ASSERT_EQ(start, start_stop.first); - ASSERT_EQ(stop, start_stop.second); - } - }; - - testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}}); - testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}}); - testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}}); - testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}}); - testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}}); - testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); -} - -TEST(LoopNest, ExprSplitWithTail) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {199}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - LoopNest::splitWithTail(loops[0], 17); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - LoopNest::splitWithTail(loops[0], 7); - - StmtPtr stmt = l.root_stmt(); - StmtPtr simplified = IRSimplifier::simplify(stmt); - BlockPtr body = to(simplified); - ASSERT_EQ(body->nstmts(), 3); - auto biter = body->begin(); - - // Verify that the split loops are ordered correctly. - ForPtr loop = to(*biter++); - assertForRange(loop, 0, 7); - - loop = to(*biter++); - assertForRange(loop, 0, 4); - - loop = to(*biter); - assertForRange(loop, 0, 12); -} - -TEST(LoopNest, ExprSplitWithTailNone) { - auto func = [](const ExprHandle& x, const ExprHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }; - Tensor tensor = Compute("f", {24, 5}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithTail(loops[0], 4); - - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 200); - ASSERT_LT(oss.str().size(), 600); - - { - // Compare to a reference loop structure structure. - VarHandle x_outer("i_outer", kInt); - VarHandle x_inner("i_inner", kInt); - VarHandle y("i", kInt); - VarHandle x_tail("i_tail", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - BufHandle f("f", {24, 5}, kFloat); - ExprHandle x_1 = x_outer * 4 + x_inner; - ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; - StmtPtr stmt = alloc(std::vector({For::make( - x_outer, - 0, - x_outer_end, - For::make( - x_inner, - 0, - 4, - For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); - - std::ostringstream oss_ref; - oss_ref << *stmt; - ASSERT_EQ(oss.str(), oss_ref.str()); - } - - { - PaddedBuffer f_v(24, 5, "f_v"); - PaddedBuffer f_ref(24, 5, "f_res"); - - SimpleIREvaluator ir_eval(stmt, {tensor}); - ir_eval(f_v); - - for (int x = 0; x < 24; x++) { - for (int y = 0; y < 5; y++) { - f_ref(x, y) = 1 + x * x + y * y; - } - } - - ExpectAllNear(f_v, f_ref, 1e-5); - } -} - -TEST(LoopNest, ExprSplitWithMask01) { - const int M = 26; - const int N = 5; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithMask(loops[1], 4); - - StmtPtr stmt = l.root_stmt(); - - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -// Tests the case where we split a loop cleanly multiple times, we should not -// insert any masks. -TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { - const int M = 64; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - LoopNest::splitWithMask(loops[0], 4); - - StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); - - // Two splits mean 3 loops, but should need no masks in this case. - checkIR(stmt1, R"IR( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: f[)IR"); -} - -TEST(LoopNest, getLoopAt) { - // Input IR: - // for (int i = 0; i < 100; i++) { - // for (int j = 0; j < 100; j++) { - // A[i, j] = sin(i * j); - // for (int k1 = 0; k1 < 200; k1++) { - // B[i, j, k1] = (A[i, j]) / (k1 + 1); - // } - // for (int k2 = 0; k2 < 300; k2++) { - // C[i, j, k2] = (A[i, j]) * (k2 + 1); - // } - // } - // } - BufPtr A = alloc( - "A", - std::vector({alloc(100), alloc(100)}), - kInt); - BufPtr B = alloc( - "B", - std::vector( - {alloc(100), alloc(100), alloc(200)}), - kInt); - BufPtr C = alloc( - "C", - std::vector( - {alloc(100), alloc(100), alloc(300)}), - kInt); - BufHandle a_buf(A); - BufHandle b_buf(B); - BufHandle c_buf(C); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k1("k1", kInt); - VarHandle k2("k2", kInt); - auto store1 = Store::make(a_buf, {i, j}, sin(i * j)); - auto store2 = Store::make( - b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1))); - auto store3 = Store::make( - c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1))); - auto for_k2 = For::make(k2, 0, 300, Block::make({store3})); - auto for_k1 = For::make(k1, 0, 200, Block::make({store2})); - auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2})); - auto for_i = For::make(i, 0, 100, for_j); - LoopNest l(Block::make({for_i}), {B, C}); - auto ret_k2 = l.getLoopAt(for_i, {0, 2}); - TORCH_CHECK(ret_k2 == for_k2); - - std::ostringstream oss; - oss << *ret_k2; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int k2 -# CHECK-NEXT: C[i, j, k2] = - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, TileSimple) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 64, N = 64; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - l.tile(loops[0], loops[1], 4, 8); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: f[ -# CHECK-NOT: for (int i_tail -# CHECK-NOT: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, TileWithTails) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 64, N = 64; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - l.tile(loops[0], loops[1], 5, 9); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: f[ -# CHECK: for (int i_inner -# CHECK: f[ -# CHECK: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, TileInMiddle) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 8, N = 8, L = 8, K = 8; - BufHandle a_buf("a", {M, N, L, K}, kFloat); - BufHandle b_buf("b", {M, N, L, K}, kFloat); - Tensor tensor = Compute( - "f", - {M, N, L, K}, - [&](const ExprHandle& m, - const ExprHandle& n, - const ExprHandle& l, - const ExprHandle& k) { - return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f; - }); - - LoopNest nest({tensor}); - std::vector loops = - nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - nest.tile(loops[1], loops[2], 3, 3); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: for (int i_1 -# CHECK: f[ -# CHECK: for (int i_tail_1 -# CHECK: for (int i_inner_1 -# CHECK: for (int i_1 -# CHECK: f[ -# CHECK: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, L, K, "a"); - PaddedBuffer b_v(M, N, L, K, "b"); - PaddedBuffer c_v(M, N, L, K, "c"); - PaddedBuffer c_ref(M, N, L, K, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int l = 0; l < L; l++) { - for (int k = 0; k < K; k++) { - a_v(m, n, l, k) = 2 * (m + l); - b_v(m, n, l, k) = 3 * (n + k); - c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f; - } - } - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, SplitWithTailWithLoopOptions) { - const int M = 21; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - ForPtr inner, tail; - - LoopNest l({tensor}); - auto loops = NodeFinder::find(l.root_stmt()); - ASSERT_GT(loops.size(), 0); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::splitWithTail(loops[0], 4, &inner, &tail); - ASSERT_NE(inner, nullptr); - ASSERT_NE(tail, nullptr); - ForPtr outer = loops[0]; - - // Outer loop carries loop axis bindings. - ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); - ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - // Inner loop has none. - ASSERT_TRUE(inner->loop_options().isDefault()); - - // Tail loop has none. - ASSERT_TRUE(tail->loop_options().isDefault()); -} - -TEST(LoopNest, SplitWithMaskWithLoopOptions) { - const int M = 21; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - ForPtr inner; - - LoopNest l({tensor}); - auto loops = NodeFinder::find(l.root_stmt()); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::splitWithMask(loops[0], 4, &inner); - ForPtr outer = loops[0]; - - // Outer loop carries loop axis bindings. - ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); - ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - // Inner loop has none. - ASSERT_TRUE(inner->loop_options().isDefault()); -} - -TEST(LoopNest, ScheduleBroadcastAddBuffer) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - Tensor c = Compute( - "broadcast_add", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - LoopNest l({c}); - StmtPtr stmt = l.root_stmt(); - - PaddedBuffer a_v(M, N, "a_v"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 7 * m * n; - } - } - a_v.Backup(); - - PaddedBuffer b_v(N, K, "b_v"); - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - b_v(n, k) = 11 * n * k; - } - } - b_v.Backup(); - - PaddedBuffer c_v(M, N, K, "c_buf"); - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); - ir_eval(a_v, b_v, c_v); - - a_v.CheckBackup(); - b_v.CheckBackup(); - PaddedBuffer c_ref(M, N, K, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - c_ref(m, n, k) = 7 * m * n + 11 * n * k; - } - } - } - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, ScheduleFunctionCall01) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - Tensor c = Compute( - "broadcast_add", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 100); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N, K); - PaddedBuffer d_v(M, N, K); - PaddedBuffer d_ref(M, N, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - for (int k = 0; k < K; k++) { - d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); - eval(a_v, b_v, d_v); - - ExpectAllNear(d_v, d_ref, 1e-5); -} - -TEST(LoopNest, ScheduleInlineSimple) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer y_1(M, N, K); - PaddedBuffer y_2(M, N, K); - - eval1(a_v, b_v, c_v, d_v, y_1); - eval2(a_v, b_v, c_v, d_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -static std::string remove_space(const std::string& str) { - std::string str_new = str; - str_new.erase( - remove_if(str_new.begin(), str_new.end(), isspace), str_new.end()); - return str_new; -} - -void InlineFunc01Helper(const std::vector& inline_order) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - Tensor z = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + y.load(m, n, k); - }); - - LoopNest l({z}, {x, y, z}); - for (const std::string& order : inline_order) { - if (order == "x") { - l.computeInline(x.buf()); - } else if (order == "y") { - l.computeInline(y.buf()); - } else { - throw std::runtime_error("Invalid order: " + order); - } - } - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - - std::ostringstream oss; - oss << *stmt; - std::string str1 = remove_space(oss.str()); - - { - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer z_v(M, N, K); - PaddedBuffer z_ref(M, N, K); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); - eval(a_v, b_v, c_v, d_v, z_v); - ExpectAllNear(z_v, z_ref, 1e-5); - } - - if (inline_order.size() == 2) { - Tensor z2 = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k) + - (c_buf.load(m, n) * d_buf.load(m, k) + - a_buf.load(m, n) * b_buf.load(n, k)); - }); - LoopNest l2({z2}); - l2.prepareForCodegen(); - StmtPtr stmt2 = l2.root_stmt(); - - std::ostringstream oss2; - oss2 << *stmt2; - std::string str2 = remove_space(oss2.str()); - - ASSERT_EQ(str1, str2); - ASSERT_GT(str1.size(), 100); - } -} - -TEST(LoopNest, ScheduleInlineFunc01) { - InlineFunc01Helper({"x", "y"}); - InlineFunc01Helper({"y", "x"}); - InlineFunc01Helper({"x"}); - InlineFunc01Helper({"y"}); - InlineFunc01Helper({}); -} - -// Make sure we cache random vars if we should. -TEST(LoopNest, ScheduleInlineRandom) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Mod::make(Intrinsics::make(kRand, kInt), 5); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + x.load(m, n, k); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: int x = rand(); -# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); -} - -// Make sure we don't cache random vars that are not being inlined. -TEST(LoopNest, ScheduleInlineRandomUnrelated) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return m * n * k; - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + - Intrinsics::make(kRand, kInt); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR"); -} - -// Make sure we generate the right number of random values == the dimensionality -// of the production tensor. -TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute("x", {M}, [&](const VarHandle& m) { - return Mod::make(Intrinsics::make(kRand, kInt), 5); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m) + x.load(m); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: int x = rand(); -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); -} - -// Make sure we don't screw up intrinsics thinking they're rand. -TEST(LoopNest, ScheduleInlineIntrinsics) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x.load(m, n, k)); - }); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M, N, K); - PaddedBuffer y_2(M, N, K); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -// Make sure we can handle rand and non-rand intrinsics. -TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kRand, kFloat); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x.load(m, n, k)); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: float x = rand(); -# CHECK: y[i, i_1, i_2] = sqrt(x);)IR"); -} - -// Split a Compute then inline it into another compute. -TEST(LoopNest, ScheduleSplitAThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Split a Compute then inline another Compute into it. -TEST(LoopNest, ScheduleSplitBThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); - LoopNest::splitWithMask(loops[0], 3); - l.computeInline(a.buf()); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - - std::vector output(6, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Split a Compute twice then inline it. -TEST(LoopNest, ScheduleSplitTwiceThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - ForPtr i_inner; - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4, &i_inner); - LoopNest::splitWithMask(i_inner, 2); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Inline a Compute, then split. -TEST(LoopNest, ScheduleInlineThenSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - l.computeInline(a.buf()); - - std::vector loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.back(), 3); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(6, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Split a Compute, inline it, then split the result. -TEST(LoopNest, ScheduleSplitInlineThenSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - auto loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.back(), 2); - l.computeInline(a.buf()); - - loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.front(), 2); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(16, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 16; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Oversplit a loop that is simplified out after inlining. -TEST(LoopNest, ScheduleSplitInlineSimplify) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { - return ExprHandle(4) * i - ExprHandle(2) * i; - }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Inline a Compute with two consumers. -TEST(LoopNest, ScheduleInlineThreeMixedOnce) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(a.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Inline Compute A into B, then inline B into C. -TEST(LoopNest, ScheduleInlineThreeMixedTwice) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(a.buf()); - l.computeInline(b.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Inline a Compute that is both a producer and consumer. -TEST(LoopNest, ScheduleInlineThreeMixedInner) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(b.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Split 3 Computes, then inline the first two into the last. -TEST(LoopNest, ScheduleInlineThreeMixedSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); - LoopNest::splitWithMask(loops[0], 3); - loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::splitWithMask(loops[0], 2); - - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Check that inlining works for output tensors too -TEST(LoopNest, ScheduleInlineOutputTensors) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return m * n * k; - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + m; - }); - - LoopNest l1({x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: x[i, i_1, i_2] = (i * i_1) * i_2; -# CHECK: for (int i_3 = 0; i_3 < 4; i_3++) -# CHECK: for (int i_4 = 0; i_4 < 5; i_4++) -# CHECK: for (int i_5 = 0; i_5 < 6; i_5++) -# CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR"); -} - -TEST(LoopNest, ScheduleInlineWithCompoundIndices) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[i*2,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[0, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {static_cast(0), j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - // Inlining should fail since the producer has compound expr as index. - ASSERT_FALSE(l.computeInline(a_buf.node())); - - // The input statement must remain as is. - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t i = 0; - # CHECK-NEXT: A[ - # CHECK: for (int64_t j = 0; - # CHECK-NEXT: B[)IR"); -} - -TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[0ll,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make( - a_buf, - {static_cast(0), i}, - Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {0, j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - ASSERT_TRUE(l.computeInline(a_buf.node())); - - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t j = 0; j < 100; j++) { - # CHECK: B[0ll, j] = j * 500ll + j * 100ll; - # CHECK: })IR"); -} - -TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[(int64_t)0,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[0ll, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make(a_buf, {0, i}, Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {static_cast(0), j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - ASSERT_TRUE(l.computeInline(a_buf.node())); - - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t j = 0; j < 100; j++) { - # CHECK: B[0ll, j] = j * 500ll + j * 100ll; - # CHECK: })IR"); -} - -TEST(LoopNest, ScheduleFuserStyle) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - - Tensor b = - Compute("f", {kTotalSize}, [&](const std::vector& axes) { - return a_buf.load(axes[0]) + 11.0f; - }); - - Tensor c = - Compute("g", {kTotalSize}, [&](const std::vector& axes) { - return b.load(axes[0]) + 1.0f; - }); - - LoopNest l({b, c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - std::vector a_data(kTotalSize, 7.0f); - std::vector b_data(kTotalSize, 0.0f); - std::vector c_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); - - for (int i = 0; i < kTotalSize; i++) { - ASSERT_EQ(b_data[i], 18.0f); - ASSERT_EQ(c_data[i], 19.0f); - } -} - -TEST(LoopNest, ScheduleFuserThreeArg) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat); - - Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) { - return a.load(i) + b.load(i); - }); - Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) { - return e.load(i) + c.load(i); - }); - Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) { - return f.load(i) + d.load(i); - }); - - LoopNest l({g}, {e, f, g}); - l.computeInline(l.getLoopBodyFor(e)); - l.computeInline(l.getLoopBodyFor(f)); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - std::vector a_data(kTotalSize, 1.0f); - std::vector b_data(kTotalSize, 2.0f); - std::vector c_data(kTotalSize, 3.0f); - std::vector d_data(kTotalSize, 4.0f); - std::vector g_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); - - for (int i = 0; i < kTotalSize; i++) { - ASSERT_EQ(g_data[i], 10.0f); - } -} - -TEST(LoopNest, ScheduleDynamicShape2D) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - SimpleIREvaluator cg(s, {a, b, c, m, n}); - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - cg.call({aData, bData, cData, M, N}); - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - }; - testWithSize(1, 8); - testWithSize(16, 32); - testWithSize(37, 11); -} - -TEST(LoopNest, LoopNestComputeAt_1) { - // Verify that compute_at works on the following example: - // - // for (int i_a = 0; i_a < N; i_a++) { - // A[i_a] = i_a * i_a - // } - // for (int i_b = 0; i_b < N; i_b++) { - // B[i_b] = A[i_b] - // } - // - // After the transformation the i_b loop should have an allocation for a temp - // buffer and that buffer should be used in computation of B. No use of A - // should be in that loop after the transformation. Also, computation of A - // should not be inlined into B. Instead, it should be computed into the temp, - // and the temp should be used in B. - VarHandle N("N", kInt); - Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; }); - Tensor B = - Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); }); - LoopNest l({B}, {A, B}); - std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {B, N}); - StmtPtr s = cg.stmt(); - - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1] -# CHECK: for (int i = 0; i < N; i++) -# CHECK: temp[ -# CHECK-NOT: A[ -# CHECK: B[i_1] = temp[0] -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector b_data(100, 0); - cg.call({b_data, 100}); - - std::vector b_ref(100, 0); - for (int i = 0; i < 100; i++) { - b_ref[i] = i * i; - } - assertAllEqual(b_data, b_ref); -} - -TEST(LoopNest, LoopNestComputeAt_2) { - // Verify that compute_at works on the following example: - // - // for (int py = 0; py < H+1; py++) { - // for (int px = 0; px < W+1; px++) { - // p[py, px] = py*px - // } - // } - // for (int cy = 0; cy < H; cy++) { - // for (int cx = 0; cx < W; cx++) { - // c[py, px] = p[cy,cx] + p[cy+1,cx] + - // p[cy,cx+1] + p[cy+1,cx+1] - // } - // } - - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - Tensor p = Compute( - "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) { - return px * py; - }); - Tensor c = - Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + - p.load(y + 1, x + 1); - }); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); - } - } - LoopNest orig_loopnest({c}, {p, c}); - - { - // First let's try to compute P at axis cy (the outer loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) -# CHECK: for -# CHECK: for -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) -# CHECK-NOT: prod[ -# CHECK: cons[ -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute P at axis cx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) -# CHECK: for -# CHECK: for -# CHECK-NOT: prod[ -# CHECK: cons[ -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } -} - -TEST(LoopNest, LoopNestComputeAt_3) { - // Verify that compute_at works on the following example: - // - // A(x,y) = x*y - // B(x,y) = A(x, y) - // C(x,y) = B(x+1, y) - // D(x,y) = A(x, y+1) + C(x, y) - // - // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. - - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - Tensor A = Compute( - "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) { - return ax * ay; - }); - Tensor B = Compute( - "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) { - return A.load(by, bx); - }); - Tensor C = - Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) { - return B.load(cy, cx + 1); - }); - Tensor D = - Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) { - return A.load(dy + 1, dx) + C.load(dy, dx); - }); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = (y + 1) * x + y * (x + 1); - } - } - - LoopNest orig_loopnest({D}, {A, B, C, D}); - { - // First let's try to compute A at axis dy (the outer loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1, W] -# CHECK: for (int i = 0; i < H + 1; i++) -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) -# CHECK: A[ -# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) -# CHECK: B[ -# CHECK: for (int i_4 = 0; i_4 < H; i_4++) -# CHECK: for (int i_5 = 0; i_5 < W; i_5++) -# CHECK: C[ -# CHECK: for (int i_6 = 0; i_6 < H; i_6++) -# CHECK: for (int i_7 = 0; i_7 < W; i_7++) -# CHECK-NOT: A[)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute A at axis dx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1, 1] -# CHECK: for (int i = 0; i < H + 1; i++) -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) -# CHECK: A[ -# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) -# CHECK: B[ -# CHECK: for (int i_4 = 0; i_4 < H; i_4++) -# CHECK: for (int i_5 = 0; i_5 < W; i_5++) -# CHECK: C[ -# CHECK: for (int i_6 = 0; i_6 < H; i_6++) -# CHECK: for (int i_7 = 0; i_7 < W; i_7++) -# CHECK-NOT: A[)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } -} - -using Axis = const VarHandle&; - -TEST(LoopNest, Reduce2dComputeAt) { - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - - Tensor p = Compute( - "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; }); - Tensor c = Reduce( - "cons", - {H, W}, - Sum(), - [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, - {2, 2}); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); - } - } - LoopNest orig_loopnest({c}, {p, c}); - checkIR(orig_loopnest.root_stmt(), R"IR( -# CHECK: for (int i = 0; i < H + 1; i++) { -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) { -# CHECK: prod[i, i_1] = i_1 * i; -# CHECK: } -# CHECK: } -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) { -# CHECK: cons[i_2, i_3] = int(0); -# CHECK: for (int i_4 = 0; i_4 < 2; i_4++) { -# CHECK: for (int i_5 = 0; i_5 < 2; i_5++) { -# CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5}); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - - { - // First let's try to compute P at axis cy (the outer loop) - LoopNest l(orig_loopnest); - auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); - // FIXME: Calling simplify here breaks the IR: - // MALFORMED INPUT: could not find base node in Load - temp[...] - // l.simplify(); - l.eliminateDeadStores(); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] -# CHECK: for (int i = 0; i < H; i++) { -# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) { -# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0); -# CHECK: } -# CHECK: } -# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0); -# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: Free(temp); -)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute P at axis cx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); - l.simplify(); - l.eliminateDeadStores(); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] -# CHECK: for (int i = 0; i < H; i++) { -# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { -# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) { -# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1); -# CHECK: } -# CHECK: } -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0; -# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: Free(temp); -)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - assertAllEqual(c_data, c_ref); - } -} - -TEST(LoopNest, DISABLED_Conv1d_NH) { - // Lots of stuff is broken here. The computeAt swaps the axes for some odd - // reason. Even without that, the index flattener fails due to "dimensions - // mismatch in flatten index". - - int N = 4; - int H = 256; - int R = 3; - int Pad = 1; - BufHandle IP("input", {H}, kFloat); - - Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) { - auto cond = CompareSelect::make(h, Pad, 1, 0, kLT); - cond = CompareSelect::make(h, H + Pad, 1, cond, kGE); - return ifThenElse(cond, 0.f, IP.load(n, h - Pad)); - }); - Tensor B = Reduce( - "B", - {N, H}, - Sum(), - [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); }, - {R}); - LoopNest l({B}); - checkIR(l.root_stmt(), R"IR( -# CHECK: for (int np = 0; np < 4; np++) { -# CHECK: for (int hp = 0; hp < 258; hp++) { -# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]); -# CHECK: } -# CHECK: } -# CHECK: for (int n = 0; n < 4; n++) { -# CHECK: for (int h = 0; h < 256; h++) { -# CHECK: B[n, h] = float(0); -# CHECK: for (int r = 0; r < 3; r++) { -# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r}); -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - // FIXME: The current IR is totally broken. The body of the inlined loop is: - - // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0), - // 0.f, input[idx1 + 0, (idx0 + n) - 1]); - - // Which seems to mix up the axes. The CHECK below is my best guess at what - // the input "should" look like - - checkIR(l.root_stmt(), R"IR( -# CHECK: for (int n = 0; n < 4; n++) { -# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) { - temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]); -# CHECK: } -# CHECK: } -# CHECK: for (int h = 0; h < 256; h++) { -# CHECK: B[n, h] = float(0); -# CHECK: for (int r = 0; r < 3; r++) { -# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r}); -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - - l.simplify(); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - SimpleIREvaluator cg(s, {IP, B}); - // auto At = at::ones({N, H}, at::kFloat); - auto At = at::arange(N * H, at::kFloat).reshape({N, H}); - auto Rt = at::conv1d( - At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3); - auto Bt = at::empty_like(Rt); - cg.call({At.data_ptr(), Bt.data_ptr()}); - ASSERT_TRUE(at::allclose(Rt, Bt)); -} - -class LoopOrderHelper : public IRVisitor { - std::stringstream ordering; - - public: - std::string getOrder(StmtPtr s) { - ordering.str(""); - s->accept(this); - return ordering.str(); - } - - void visit(const ForPtr& v) final { - ordering << v->var()->name_hint() << ","; - IRVisitor::visit(v); - } -}; - -TEST(LoopNest, LoopNestReorderAxis1) { - Tensor tensor = - Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector stmt1_output(6, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - ASSERT_NE(stmt1, stmt2); - LoopOrderHelper loopOrderHelper; - std::string order1 = loopOrderHelper.getOrder(stmt1); - std::string order2 = loopOrderHelper.getOrder(stmt2); - - ASSERT_EQ(order1, "j,i,"); - ASSERT_EQ(order2, "i,j,"); - - std::vector stmt2_output(6, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg.call({stmt2_output}); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } - - // Reorder them back. - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - StmtPtr stmt3 = l.root_stmt(); - - std::string order3 = loopOrderHelper.getOrder(stmt3); - ASSERT_EQ(order3, order1); - - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt3; - - // Should be identical to the unreordered statement. - ASSERT_EQ(oss1.str(), oss2.str()); -} - -TEST(LoopNest, LoopNestReorderPartialAxes) { - Tensor tensor = Compute( - "f", - {2, 3, 4}, - [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,"); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,"); - - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } - - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[1], loops[2]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,"); - - StmtPtr stmt3 = Stmt::clone(l.root_stmt()); - - std::vector stmt3_output(24, 0); - SimpleIREvaluator cg3(stmt3, {tensor}); - cg3.call({stmt3_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt3_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderInternalAxis) { - Tensor tensor = Compute( - "f", - {1, 2, 3, 4}, - [](const VarHandle& w, - const VarHandle& x, - const VarHandle& y, - const VarHandle& z) { - return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,"); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[2], loops[1]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,"); - - StmtPtr stmt2 = l.root_stmt(); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderEnclosingAxis) { - Tensor tensor = Compute( - "f", - {1, 2, 3, 4}, - [](const VarHandle& w, - const VarHandle& x, - const VarHandle& y, - const VarHandle& z) { - return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[3]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,"); - - StmtPtr stmt2 = l.root_stmt(); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderSameAxis) { - Tensor tensor = - Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt1 = Stmt::clone(l.root_stmt()); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[1], loops[1]); - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::ostringstream oss, oss2; - oss << *stmt1; - oss2 << *stmt2; - ASSERT_EQ(oss.str(), oss2.str()); -} - -TEST(LoopNest, LoopNestReorderExtraStatements) { - /* We're going for a structure like this: - * for i in ... - * Stmt 1 - * for j in ... - * Stmt 2 - * for k in ... - * Stmt 3 - * Stmt 4 - */ - - Tensor tensor = Compute( - "f", - {2, 3, 4}, - [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - BufHandle extra("res", {6, 3}, kFloat); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - VarHandle i = VarHandle(loops[0]->var()); - - StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f); - StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f); - // stmt 3 is the Function body. - StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f); - - loops[0]->body()->prepend_stmt(store_1); - loops[1]->body()->prepend_stmt(store_2); - loops[1]->body()->append_stmt(store_3); - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector extra1(6, 0); - std::vector res1(24, 0); - SimpleIREvaluator cg(stmt1, {tensor, extra}); - cg.call({res1, extra1}); - - /* Then we reorder loop y and z, we want it to look like: - * - * for i in ... - * Stmt 1 - * for j in ... - * Stmt 2 - * for j_1 in ... - * for k in ... - * Stmt 3 - * for j_2 in ... - * Stmt 4 - * - * We need extra loops because we don't have dependency info about stmt 3 - * and 4. - * - */ - - LoopNest::reorderAxis(loops[1], loops[2]); - StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - // Check the IR we produced - checkIR(stmt2, R"IR( -# CHECK: for -# CHECK: res[i, 0] = 1 -# CHECK: for -# CHECK: res[i, 1] = 2 -# CHECK: for -# CHECK: for -# CHECK: f[ -# CHECK: for -# CHECK: res[i, 2] = 4 -)IR"); - - std::vector extra2(6, 0); - std::vector res2(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor, extra}); - cg2.call({res2, extra2}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(res1[i], res2[i]); - } - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(extra1[i], extra2[i]); - } - - /* Now reorder x and the y above stmt 3: - * - * - * for x in ... - * Stmt 1 - * for y in ... - * Stmt 2 - * - * for y in ... - * for z in ... - * for x in ... - * Stmt 3 - * - * for x in ... - * for y in ... - * Stmt 4 - * - * - */ - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[2]); - StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - // Check the IR we produced - checkIR(stmt3, R"IR( -# CHECK: for -# CHECK: res[i, 0] = 1 -# CHECK: for -# CHECK: res[i, 1] = 2 -# CHECK: for -# CHECK: for -# CHECK: for -# CHECK: f[ -# CHECK: for -# CHECK: for -# CHECK: res[i_2, 2] = 4 -)IR"); - - std::vector extra3(6, 0); - std::vector res3(24, 0); - SimpleIREvaluator cg3(stmt3, {tensor, extra}); - cg3.call({res3, extra3}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(res1[i], res3[i]); - } - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(extra1[i], extra3[i]); - } -} - -void LoopNestReorderTestHelper( - bool prepend, - bool append, - int index1, - int index2) { - Tensor c = Compute( - "5d", {2, 3, 2, 3, 2}, [](const std::vector&) { return -1; }); - LoopNest l({c}); - - BufHandle extra("extra", {5}, kInt); - - auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - int j = 0; - for (auto l : loops) { - // Add an increment at each layer of the loop which counts the number of - // times the loop executes. - LoadPtr load = - alloc(extra.node(), std::vector({alloc(j)})); - AddPtr add = alloc(load, alloc(1)); - StmtPtr store = alloc( - extra.node(), std::vector({alloc(j)}), add); - if (prepend) { - l->body()->prepend_stmt(store); - } - if (append) { - l->body()->append_stmt(Stmt::clone(store)); - } - - j++; - } - - StmtPtr stmt1 = Stmt::clone(l.root_stmt()); - - std::vector extra1(5, 0); - std::vector res1(2 * 3 * 2 * 3 * 2, 0); - SimpleIREvaluator cg(stmt1, {c, extra}); - cg.call({res1, extra1}); - - std::vector loopExtents = {2, 3, 2, 3, 2}; - - int expected_loops = 0; - if (prepend) { - expected_loops++; - } - if (append) { - expected_loops++; - } - for (int i = 0; i < 5; ++i) { - expected_loops *= loopExtents[i]; - ASSERT_EQ(extra1[i], expected_loops); - } - - loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::reorderAxis(loops[index1], loops[index2]); - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::ostringstream oss, oss2; - oss << *stmt1; - oss2 << *stmt2; - ASSERT_NE(oss.str(), oss2.str()); - - std::vector extra2(5, 0); - std::vector res2(2 * 3 * 2 * 3 * 2, 0); - SimpleIREvaluator cg2(stmt2, {c, extra}); - cg2.call({res2, extra2}); - - expected_loops = 0; - if (prepend) { - expected_loops++; - } - if (append) { - expected_loops++; - } - - for (int i = 0; i < 5; ++i) { - expected_loops *= loopExtents[i]; - ASSERT_EQ(extra2[i], expected_loops); - } - - for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) { - ASSERT_EQ(res2[i], res1[i]); - } -} - -TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(true, false, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(false, true, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderLongStringFull) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(true, true, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderInternalLoopNest) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - Tensor z = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + y.load(m, n, k); - }); - - LoopNest l({z}, {x, y, z}); - ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2]; - ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0]; - LoopNest::reorderAxis(a, b); - - l.prepareForCodegen(); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - - // Check the IR we produced has the 3 nests in the right order, but k and m - // swapped in the middle. - checkIR(stmt, R"IR( -# CHECK: < 4 -# CHECK: < 5 -# CHECK: < 6 -# CHECK: < 6 -# CHECK: < 5 -# CHECK: < 4 -# CHECK: < 4 -# CHECK: < 5 -# CHECK: < 6)IR"); - - { - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer z_v(M, N, K); - PaddedBuffer z_ref(M, N, K); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); - eval(a_v, b_v, c_v, d_v, z_v); - ExpectAllNear(z_v, z_ref, 1e-5); - } -} - -TEST(LoopNest, OuterLoopVectorization) { - Tensor tensor = - Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - - ASSERT_TRUE( - LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); - - StmtPtr root_stmt = l.root_stmt(); - BlockPtr outer_block = to(root_stmt); - ASSERT_NE(outer_block, nullptr); - while (BlockPtr inner_block = to(outer_block->front())) { - outer_block = inner_block; - } - - // Verify that we have only a single loop level remaining after - // vectorization. - ASSERT_EQ(outer_block->nstmts(), 1); - ForPtr for_loop = to(outer_block->front()); - ASSERT_NE(for_loop, nullptr); - BlockPtr for_body = for_loop->body(); - ASSERT_EQ(for_body->nstmts(), 1); - ASSERT_EQ(to(for_body->front()), nullptr); -} - -TEST(LoopNest, VectorizeLoopNotNormalized) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 1; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 1, 5, for_body); - auto outer_for = For::make(i, 0, 10, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - ASSERT_TRUE(LoopNest::vectorize(inner_for)); - ASSERT_EQ(outer_for->body()->nstmts(), 1); - ASSERT_EQ(to(outer_for->body()->front()), nullptr); -} - -namespace { - -std::string constantUpperBoundLoopIR(int upper_bound_val) { - ExprHandle upper_bound(upper_bound_val); - Tensor A = - Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - std::ostringstream oss; - oss << *unrolled; - return oss.str(); -} - -} // namespace - -TEST(LoopNest, Unroll) { - const std::string actual = constantUpperBoundLoopIR(3); - const std::string& verification_pattern = - R"IR( -# CHECK: A[0] = 0; -# CHECK: A[1] = 2; -# CHECK: A[2] = 4)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, actual); -} - -TEST(LoopNest, UnrollOuter) { - ExprHandle outer_bound(3); - ExprHandle inner_bound(4); - Tensor A = Compute( - "A", - {outer_bound, inner_bound}, - [&](const VarHandle& x, const VarHandle& y) { return x + y; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - checkIR(unrolled, R"IR( -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[0, i] = i; -# CHECK: } -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[1, i] = i + 1; -# CHECK: } -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[2, i] = i + 2; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollInner) { - ExprHandle outer_bound(3); - ExprHandle inner_bound(4); - Tensor A = Compute( - "A", - {outer_bound, inner_bound}, - [&](const VarHandle& x, const VarHandle& y) { return x + y; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll( - static_to(loops[0]->body()->stmts().front()), &unrolled); - checkIR(loops[0], R"IR( -# CHECK: for (int i = 0; i < 3; i++) { -# CHECK: A[i, 0] = i; -# CHECK: A[i, 1] = i + 1; -# CHECK: A[i, 2] = i + 2; -# CHECK: A[i, 3] = i + 3; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollMultipleStatements) { - const int kTotalSize = 3; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle x("x", kInt); - auto f = For::make( - x, - 0, - kTotalSize, - Block::make( - {Store::make(a_buf, {x}, x * 2), - Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); - auto parent_block = Block::make({f}); - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(f, &unrolled); - checkIR(unrolled, R"IR( -# CHECK: A[0] = 0; -# CHECK: B[0] = A[0]; -# CHECK: A[1] = 2; -# CHECK: B[1] = A[1]; -# CHECK: A[2] = 4 -# CHECK: B[2] = A[2];)IR"); -} - -TEST(LoopNest, UnrollNonLiteralConstantBounds) { - // Input IR: - // for (int i = 2 - 1; i < 12 / 3; i++) { - // for (int j = 0; j < 4; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {3, 4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 4, for_body); - auto outer_for = For::make( - i, - IntImm::make(2) - IntImm::make(1), - IntImm::make(12) / IntImm::make(3), - inner_for); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto b = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - checkIR(unrolled, R"IR( -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[1, j] = j; -# CHECK: } -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[2, j] = 2 * j; -# CHECK: } -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[3, j] = 3 * j; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollNonConstantBounds) { - // Input IR: - // for (int i = 0; i < M; i++) { - // for (int j = 0; j < N; j++) { - // A[i, j] = i * j; - // } - // } - VarHandle M("M", kInt); - VarHandle N("N", kInt); - BufHandle a_buf("A", {M, N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, N, for_body); - auto outer_for = For::make(i, 0, M, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - LoopNest::unroll(inner_for, 8); - l.simplify(); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) { - # CHECK: A[i, 8 * j_outer] = - # CHECK: A[i, 8 * j_outer + 1] = - # CHECK: A[i, 2 * (4 * j_outer + 1)] = - # CHECK: A[i, 8 * j_outer + 3] = - # CHECK: A[i, 4 * (2 * j_outer + 1)] = - # CHECK: A[i, 8 * j_outer + 5] = - # CHECK: A[i, 8 * j_outer + 6] = - # CHECK: A[i, 8 * j_outer + 7] = - # CHECK: } - # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) { - # CHECK: A[i, 8 * (N / 8) + j_tail] = - # CHECK: } - # CHECK: } - )IR"); -} - -TEST(LoopNest, UnrollByFactorsLessThan2) { - // Input IR: - // for (int i = 0; i < M; i++) { - // for (int j = 0; j < N; j++) { - // A[i, j] = i * j; - // } - // } - VarHandle M("M", kInt); - VarHandle N("N", kInt); - BufHandle a_buf("A", {M, N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, N, for_body); - auto outer_for = For::make(i, 0, M, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - // Unrolling by factor = 1 should do nothing. - LoopNest::unroll(inner_for, 1); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); - - // Unrolling by factor = 0 should do nothing. - LoopNest::unroll(inner_for, 0); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); - - // Unrolling by negative factor should do nothing. - LoopNest::unroll(inner_for, -2); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); -} - -TEST(LoopNest, UnrollByFactorEqualToIters) { - // Input IR: - // for (int i = 0; i < 5; i++) { - // A[i] = i * i; - // } - BufHandle a_buf("A", {5}, kInt); - VarHandle i("i", kInt); - auto for_body = Block::make({Store::make(a_buf, {i}, i * i)}); - auto for_loop = For::make(i, 0, 5, for_body); - auto block = Block::make({for_loop}); - LoopNest l(block, {a_buf.node()}); - - LoopNest::unroll(for_loop, 5); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++) - # CHECK: A[5 * i_outer] - # CHECK: A[5 * i_outer + 1] - # CHECK: A[5 * i_outer + 2] - # CHECK: A[5 * i_outer + 3] - # CHECK: A[5 * i_outer + 4] - )IR"); -} - -TEST(LoopNest, UnrollEmpty) { - const std::string actual = constantUpperBoundLoopIR(0); - const std::string& verification_pattern = R"IR( -# CHECK-NOT: A[ - )IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, actual); -} - -TEST(LoopNest, NoUnroll) { - VarHandle upper_bound("N", kInt); - Tensor A = - Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - ASSERT_THROWS_WITH( - LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop"); -} - -TEST(LoopNest, UnrollWithLet) { - const int kTotalSize = 3; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle e("e", kInt); - VarHandle x("x", kInt); - auto f = For::make( - x, - 0, - kTotalSize, - Block::make( - {Let::make(e, 7), - Store::make(a_buf, {x}, e), - Store::make(b_buf, {x}, e + 1)})); - auto parent_block = Block::make({f}); - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(f, &unrolled); - std::ostringstream oss; - oss << *unrolled; - const std::string& verification_pattern = - R"IR( -# CHECK: int e = 7; -# CHECK: A[0] = e; -# CHECK: B[0] = e + 1; -# CHECK: A[1] = e; -# CHECK: B[1] = e + 1; -# CHECK: A[2] = e; -# CHECK: B[2] = e + 1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector a_v(kTotalSize, 0); - std::vector b_v(kTotalSize, 0); - SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); - eval(a_v, b_v); - for (int i = 0; i < kTotalSize; ++i) { - ASSERT_EQ(a_v[i], 7); - ASSERT_EQ(b_v[i], 8); - } -} - -TEST(LoopNest, IsNormalized) { - // Input IR: - // for (int i = 50; i < 100; i++) { - // A[i] = B[i]; - // } - BufHandle a_buf("A", {ExprHandle(100)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto for_stmt = - For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i}))); - Block::make({for_stmt}); - ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); - - for_stmt->set_start(alloc(0)); - ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); - - VarHandle N("N", kInt); - for_stmt->set_start(N.node()); - ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); -} - -TEST(LoopNest, NormalizeStartPositive) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - const int kTotalSize = 50; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, 50, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 50; x++) { - # CHECK: A[x + 50] = B[x + 50]; - # CHECK: B[x + 50] = 2 * (x + 50); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartNegative) { - // Input IR: - // for (int x = -50; x < 100; x++) { - // A[x + 50] = B[x + 50]; - // B[x + 50] = x * 2; - // } - const int kTotalSize = 150; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})), - Store::make(b_buf, {x + 50}, x * 2)}); - auto for_stmt = For::make(x, -50, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 150; x++) { - # CHECK: A[x] = B[x]; - # CHECK: B[x] = 2 * (x - 50); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartZero) { - // Input IR: - // for (int x = 0; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - // Should not be modified. - - const int kTotalSize = 100; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, 0, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 100; x++) { - # CHECK: A[x] = B[x]; - # CHECK: B[x] = 2 * x; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartVariable) { - // Input IR: - // for (int x = y; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - - const int kTotalSize = 100; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, y, 100, for_body); - auto parent_block = Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 100 - y; x++) { - # CHECK: A[x + y] = B[x + y]; - # CHECK: B[x + y] = 2 * (x + y); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeOnNestedOuterLoop) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // for (int y = 10; y < 100; y++) { - // A[x] = A[x] + B[y] + y * 2; - // } - // } - - BufHandle a_buf("A", {ExprHandle(50)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto inner_for_body = Store::make( - a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); - auto inner_for = For::make(y, 10, 100, inner_for_body); - auto for_stmt = For::make(x, 50, 100, inner_for); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 50; x++) { - # CHECK: for (int y = 10; y < 100; y++) { - # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeOnNestedInnerLoop) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // for (int y = 10; y < 100; y++) { - // A[x] = A[x] + B[y] + y * 2; - // } - // } - - BufHandle a_buf("A", {ExprHandle(50)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto inner_for_body = Store::make( - a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); - auto inner_for = For::make(y, 10, 100, inner_for_body); - auto for_stmt = For::make(x, 50, 100, inner_for); - Block::make({for_stmt}); - - LoopNest::normalize(inner_for); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 50; x < 100; x++) { - # CHECK: for (int y = 0; y < 90; y++) { - # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeAndSplitWithTail) { - // Create a dummy tensor to construct LoopNest. - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - // Input IR: - // for (int x = 5; x < 10; x++) { - // A[x] = x * 2; - // } - const int kTotalSize = 5; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); - auto parent_block = Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - ForPtr x_inner; - ForPtr x_tail; - LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); - - auto x_outer_result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss_outer; - oss_outer << *x_outer_result; - const std::string& expected_outer_ir = - R"IR( - # CHECK: { - # CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); - - auto x_tail_result = IRSimplifier::simplify(x_tail); - std::ostringstream oss_tail; - oss_tail << *x_tail_result; - const std::string& expected_tail_ir = - R"IR( - # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) { - # CHECK: A[x_tail + 5] = 2 * (x_tail + 5); - )IR"; - torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); -} - -TEST(LoopNest, NotNormalizeAndSplitWithTail) { - // Create a dummy tensor to construct LoopNest. - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - // Input IR: - // for (int x = 5; x < 15; x++) { - // A[x] = x * 2; - // } - const int kTotalSize = 10; - BufHandle a_buf("A", {kTotalSize}, kInt); - VarHandle x("x", kInt); - auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); - auto parent_block = Block::make({for_stmt}); - - ForPtr x_inner; - ForPtr x_tail; - LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); - - auto x_outer_result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss_outer; - oss_outer << *x_outer_result; - const std::string& expected_outer_ir = - R"IR( - # CHECK: { - # CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); - - auto x_tail_result = IRSimplifier::simplify(x_tail); - std::ostringstream oss_tail; - oss_tail << *x_tail_result; - const std::string& expected_tail_ir = - R"IR( - # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) { - # CHECK: A[x_tail + 13] = 2 * (x_tail + 13); - )IR"; - torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); -} - -TEST(LoopNest, FlattenSimpleLoopNest2D) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 5, for_body); - auto outer_for = For::make(i, 0, 10, inner_for); - auto parent_block = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { - # CHECK: A[i_flat / 5, i_flat % 5] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenSimpleLoopNest3D) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // for (int k = 0; k < 7; k++) { - // A[i,j,k] = i + j * k; - // } - // } - // } - BufHandle a_buf("A", {10, 5, 7}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)}); - auto for1 = For::make(k, 0, 7, for_body); - auto for2 = For::make(j, 0, 5, for1); - auto for3 = For::make(i, 0, 10, for2); - auto parent_block = Block::make({for3}); - - std::vector loops = {for3, for2, for1}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { - # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5, 7); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5, 7); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenLoopNestAfterNormalize) { - // Input IR: - // for (int i = 2; i < 10; i++) { - // for (int j = 3; j < 15; j++) { - // A[i - 2,j - 3] = i * j; - // } - // } - BufHandle a_buf("A", {8, 12}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); - auto inner_for = For::make(j, 3, 15, for_body); - auto outer_for = For::make(i, 2, 10, inner_for); - auto parent_block = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { - # CHECK: A[i_flat / 12, i_flat % 12] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(8, 12); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(8, 12); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { - // Input IR: - // for (int i = 0; i < 15-5; i++) { - // for (int j = 0; j < 20/4; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = - For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body); - auto outer_for = - For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto b = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - checkIR(result, R"IR( - # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { - # CHECK: A[i_flat / 5, i_flat % 5] = - )IR"); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenImperfectLoopNest) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // A[i, i] = 0; - // for (int j = 0; j < 15; j++) { - // A[i,j] = i * j; - // } - // } - // Do not flatten. - - BufHandle a_buf("A", {10, 15}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 15, for_body); - auto outer_for = For::make( - i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for})); - auto par = Block::make({outer_for}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenReductionLoopNest) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // S[i] = 0; - // for (int j = 0; j < 15; j++) { - // S[i] = S[i] + A[i,j]; - // } - // } - // Do not flatten. - - BufHandle a_buf("A", {10, 15}, kInt); - BufHandle s_buf("S", {10}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make( - s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))}); - auto inner_for = For::make(j, 0, 15, for_body); - auto outer_for = - For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for})); - auto par = Block::make({outer_for}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenReductionLoopNestFromTensor) { - const int M = 3; - const int N = 7; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle b("b", {m, n}, kFloat); - Tensor c = Reduce("sum", {M}, Sum(), b, {N}); - LoopNest loop({c}); - HashProvider hasher; - auto hash_before = hasher.hash(loop.root_stmt()); - - auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(loop.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenIncorrectLoopsAsInput) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - // for (int x = 0; x < 10; x++) { - // for (int y = 0; y < 5; y++) { - // A[x,y] = A[x,y] + x + y; - // } - // } - // Flatten({For_i, For_y}) => should not succeed - - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - auto par = Block::make({outer_for1, outer_for2}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for1, inner_for2}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, DetectInlineRankMismatch) { - const int kTotalSize = 8; - - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - Tensor a = Compute( - "a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); }); - Tensor reshape = Compute( - "reshape", - {kTotalSize / 2, 2}, - [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); - LoopNest l({reshape}, {a, reshape}); - ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a))); -} - -TEST(LoopNest, CacheReadsSimple) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 3); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - // just this once: verify the whole thing. - checkIR(result, R"IR( -#CHECK: Allocate(A); // dtype=int, dims=[64, 64] -#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10] -#CHECK: for (int i -#CHECK: for (int j -#CHECK: A[ -#CHECK: } -#CHECK: } -#CHECK: for (int i_1 -#CHECK: for (int j_1 -#CHECK: A_local[j_1] = A[ -#CHECK: } -#CHECK: for (int j_2 -#CHECK: B[j_2 + 10 * i_1] = A_local[j_2]; -#CHECK: } -#CHECK: } -#CHECK: for (int i_2 -#CHECK: for (int j_3 -#CHECK: C[ -#CHECK: } -#CHECK: } -#CHECK: Free(A_local); -#CHECK: Free(A); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 3); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsOuter) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0]; - LoopNest::cacheAccesses(A.buf(), "A_local", i_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] -#CHECK: A_local[j_1 + 11 * i_1] = -#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsInternal) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] -#CHECK: A_local[k + 11 * j_1] = -#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsInner) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - // note im changing the offset of the first arg of the first call to A. - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 34, j + 40) + A.load(i + 30, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr body = l.getLoopBodyFor(B); - LoopNest::cacheAccesses(A.buf(), "A_local", body); - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] -#CHECK: A_local[l + 2 * k] = -#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheWritesSimple) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", a_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64] -#CHECK: for (int j = 0; j < 64 -#CHECK: A_local[j] = i * j; -#CHECK: for (int j_1 = 0; j_1 < 64 -#CHECK: A[j_1 + 64 * i] = A_local[ -#CHECK: Free(A_local); -#CHECK-NOT: A_local - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, DeadStoreElimination) { - VarHandle y("y", kInt); - VarHandle x("x_tail", kInt); - BufHandle f("f", {26, 5}, kInt); - BufHandle g("g", {26, 5}, kInt); - ExprHandle x_outer_end = 5; - ExprHandle x_2 = x + x_outer_end * 4; - ForPtr stmt1 = For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - Block::make({ - Store::make(f, {x_2, y}, (x_2 + y)), - Store::make(g, {x_2, y}, (x_2 * y)), - }))); - StmtPtr stmt = Block::make({stmt1}); - - // Will eliminate if not used by an output. - LoopNest loop(Stmt::clone(stmt), {f.node()}); - loop.eliminateDeadStores(); - - checkIR(loop.root_stmt(), R"IR( -#CHECK: f[x_tail + 5 * 4, y] -#CHECK-NOT: g[x_tail + 5 * 4, y] - )IR"); - - // But won't eliminate if used by different outputs. - LoopNest loop2(stmt, {f.node(), g.node()}); - loop2.eliminateDeadStores(); - - checkIR(loop2.root_stmt(), R"IR( -#CHECK: f[x_tail + 5 * 4, y] -#CHECK: g[x_tail + 5 * 4, y] - )IR"); -} - -TEST(LoopNest, DeadStoreEliminationWithIntermediates) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - BufHandle f("f", {26 * 5}, kInt); - BufHandle g("g", {26 * 5}, kInt); - BufHandle h("h", {26, 5}, kInt); - ExprHandle x_outer_end = 5; - ExprHandle x_2 = x + x_outer_end * 4; - ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); - ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); - ForPtr stmt3 = For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - Block::make({ - Store::make(h, {x, y}, Load::make(f, {x * y})), - }))); - StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); - - // Will eliminate the write to g, but not f since it used by the producer of - // h. - LoopNest loop(Stmt::clone(stmt), {h.node()}); - loop.eliminateDeadStores(); - - checkIR(loop.root_stmt(), R"IR( - #CHECK: f[x] = x; - #CHECK-NOT: g[z] = - #CHECK: h[x, y] = f[x * y]; - )IR"); - - // Sanity check won't eliminate if g is an output. - LoopNest loop2(stmt, {h.node(), g.node()}); - loop2.eliminateDeadStores(); - - checkIR(loop2.root_stmt(), R"IR( - #CHECK: f[x] = x; - #CHECK: g[z] = z + 1; - #CHECK: h[x, y] = f[x * y]; - )IR"); -} - -TEST(LoopNest, CompoundTensorSimple) { - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - BlockPtr body = Block::make({outer_for1, outer_for2}); - - Tensor A = Tensor(a_buf.node(), body); - - LoopNest l({A}); - l.prepareForCodegen(); - - std::vector a_data(50, 0); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg(s, {A}); - - std::vector a_ref(50, 0); - - for (int i = 0; i < 10; ++i) { - for (int j = 0; j < 5; ++j) { - a_ref[i * 5 + j] = (i * j) + i + j; - } - } - cg.call({a_data}); - - assertAllEqual(a_data, a_ref); -} - -TEST(LoopNest, InlineConstantIndex) { - const int N = 10; - BufHandle x_buf("a", {1, N, 1}, kFloat); - Tensor y = Compute( - "f", - {1, N, 1}, - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { - return x_buf.load(m, n, o); - }); - Tensor z = Compute( - "f", - {1, N, 1}, - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { - return y.load(m, n, o); - }); - - LoopNest l({z}, {y, z}); - l.simplify(); - ASSERT_TRUE(l.computeInline(y.buf())); -} - -TEST(LoopNest, CompoundTensorUsed) { - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - BlockPtr body = Block::make({outer_for1, outer_for2}); - - Tensor A = Tensor(a_buf.node(), body); - Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j + 1) + A.load(i, j + 2); - }); - - LoopNest l({B}, {A, B}); - ASSERT_FALSE(l.computeInline(A.buf())); - l.prepareForCodegen(); - - std::vector a_data(50, 0); - std::vector b_data(50, 0); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg(s, {B}); - - std::vector b_ref(50, 0); - - auto AT = [](int i, int j) { return i * j + i + j; }; - for (int i = 0; i < 10; ++i) { - for (int j = 0; j < 3; ++j) { - b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); - } - } - cg.call({b_data}); - - assertAllEqual(b_data, b_ref); -} - -TEST(LoopNest, InlineFromLoad) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store_a = For::make(i, 0, N, Store::make(a, {i}, i)); - auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j}))); - LoopNest l(Block::make({store_a, store_b}), {b.node()}); - - l.computeInline(a.node()); - - // Check that A[j] is replaced with j after inlining - std::ostringstream oss; - oss << *l.root_stmt(); - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: for (int j -# CHECK-NOT: B[j] = A[j] -# CHECK-NEXT: B[j] = j -)IR", - oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsSimple) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {15}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 15 -# CHECK-NEXT: A[i + 5] = C[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsNestedConditions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK: for (int i = 0; i < 10 -# CHECK-NEXT: A[i + 10] = D[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsMultipleStores) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // } - // for (int j = 0; j < 100; j++) { - // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto storeA = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, storeA); - auto storeB = Store::make( - b_buf, - {j}, - IfThenElse::make( - CompareSelect::make(j, 30, kLT), - Load::make(c_buf, {j}), - Load::make(d_buf, {j}))); - auto forJ = For::make(j, 0, 100, storeB); - auto par = Block::make({forI, forJ}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 15 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK: for (int j = 0; j < 30 -# CHECK-NEXT: B[j] = C[j] -# CHECK: for (int j = 0; j < 70 -# CHECK-NEXT: B[j + 30] = D[j + 30] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { - // Input IR: - // for (int i = 0; i < 50; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) - // } - // Only the first conditional, in the write to A, will be optimized. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {100}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto storeA = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - auto storeB = Store::make( - b_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 30, kLT), - Load::make(c_buf, {i}), - Load::make(d_buf, {i}))); - auto forI = For::make(i, 0, 50, Block::make({storeA, storeB})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK-NEXT: B[i] = C[i] -# CHECK: for (int i = 0; i < 45 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) - // } - // } - // Currently, this case where the condition variable `i` is not the - // inner-most loop variable, is not optimized. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store)); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10]) - // } - // No optimization should be done here because one of the conditions use '>'. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i'. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - VarHandle N("N", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, N, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsInvalidCondition) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10]) - // } - // No optimization should be done here because one of the conditions use '>'. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kGT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(10 colReduce(int M, int N) { - BufHandle a("a", {M, N}, kFloat); - Tensor t = Reduce( - "b", - {N}, - Sum(), - [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); }, - {M}); - return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))}; -} - -static StmtPtr splitTailReorder(Tensor b) { - constexpr int kVectorWidth = 8; - LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; - nest.splitWithTail(loops[0], kVectorWidth); - // Now the loopnests will look like: - // - // for (int i_outer = 0; ... - // for (int i_inner = 0; ... - // b[i_outer * 8 + i_inner] = float(0); - // for (int j = 0; ... - // b[i_outer * 8 + i_inner] = ReduceOp(...); - // - // for (int i_tail = 0; ... - // b[i_tail + ((100 - 0) / 8) * 8] = float(0); - // for (int j = 0; ... - // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...); - // - // Since there are 4 writes to b, we will get 4 loopnests from the - // call to `getAllLoopNestsWritingToBuf` below. - // - // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)" - // Loopnest #2: {i_outer, i_inner, j}; - // We will have to reorder i_inner and j. - auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); - LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); - nest.prepareForCodegen(); - return nest.root_stmt(); -} - -static StmtPtr splitMaskReorder(Tensor b) { - constexpr int kVectorWidth = 8; - LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; - nest.splitWithMask(loops[0], kVectorWidth); - loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; - LoopNest::reorderAxis(loops[1], loops[2]); - nest.prepareForCodegen(); - return nest.root_stmt(); -} - -static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) { - int M = immediateAs(p.dim(0)); - int N = immediateAs(p.dim(1)); - PaddedBuffer a(M, N); - PaddedBuffer b(N); - PaddedBuffer ref(N); - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a(i, j) = 1.0f; - } - } - for (int i = 0; i < N; i++) { - b(i) = 0.0f; - } - for (int i = 0; i < N; i++) { - ref(i) = 76.0f; - } - SimpleIREvaluator(s, {p, t}).call({a, b}); - ExpectAllNear(b, ref, 1e-5); -} - -TEST(LoopNest, ColReduceSplitTailEvenReorder) { - constexpr int M = 76, N = 128; - auto p = colReduce(M, N); - StmtPtr s = splitTailReorder(p.second); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i_outer -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int j -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitTailUnevenReorder) { - constexpr int M = 76, N = 100; - auto p = colReduce(M, N); - StmtPtr s = splitTailReorder(p.second); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i_outer -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int j -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int i_tail -# CHECK-NEXT: b[ -# CHECK-NEXT: for (int j -# CHECK-NEXT: b[ - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitMaskEvenReorder) { - constexpr int M = 76, N = 128; - auto p = colReduce(M, N); - StmtPtr s = splitMaskReorder(p.second); - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { - constexpr int M = 76, N = 100; - auto p = colReduce(M, N); - StmtPtr s = splitMaskReorder(p.second); - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ReorderAxisWithMultipleConds) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // if i > 5 { - // if i < 10 { - // for (int j = 0; j < 100; j++) { - // A[i] = i * j; - // } - // } - // } - // } - BufHandle a_buf("A", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j))); - auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr); - auto outer_cond = - Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); - auto forI = For::make(i, 0, 20, outer_cond); - StmtPtr par = Block::make({forI}); - LoopNest l(par, {a_buf.node()}); - LoopNest::reorderAxis(forI, forJ); - ASSERT_EQ(par, l.root_stmt()); - par = IRSimplifier::simplify(par); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: for (int i -# CHECK-NEXT: if (i>5 -# CHECK-NEXT: if (i<10 -# CHECK-NEXT: A[i] = i * j -# CHECK-NOT: for ( - )IR"; - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, VectorizeUse) { - constexpr int N = 8; - BufHandle a("a", {N}, kFloat); - Tensor b = - Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); - Tensor c = - Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); - LoopNest nest({c}, {b, c}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; - ASSERT_TRUE(LoopNest::vectorize(loops[0])); - loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; - ASSERT_TRUE(LoopNest::vectorize(loops[0])); - nest.prepareForCodegen(); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - StmtPtr s = nest.root_stmt(); - std::ostringstream oss; - oss << *nest.root_stmt(); - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: c[Ramp -)IR", - oss.str()); -} - -const char* int64Loop = R"IR( -# CHECK: for (int64_t i = 0ll; i < 12ll; i++) { -# CHECK: b[i] = (a[i]) + 1ll; -# CHECK: } -)IR"; - -TEST(LoopNest, Int64Direct) { - constexpr int64_t N = 12; - BufHandle a("a", {N}, kLong); - BufHandle b("b", {N}, kLong); - VarHandle n("i", kLong); - StmtPtr s = For::make( - n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(int64Loop, oss.str()); -} - -TEST(LoopNest, Int64Compute) { - constexpr int64_t N = 12; - BufHandle a("a", {N}, kLong); - Tensor b = Compute("b", {N}, [&](const VarHandle& n) { - return a.load(n) + LongImm::make(1l); - }); - LoopNest nest({b}); - nest.prepareForCodegen(); - nest.simplify(); - std::ostringstream oss; - oss << *nest.root_stmt(); - torch::jit::testing::FileCheck().run(int64Loop, oss.str()); -} - -TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB}); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI, {forJ}); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopWithoutAnyPivot) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopOverInnerLoops) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { - // Input IR: - // for (int m = 0; m < 50; m++) { - // for (int i = 0; i < 20; i++) { - // A[m,i] = 0; - // for (int j = 0; j < 100; j++) { - // A[m,i] = A[m,i] + i * j; - // } - // B[m,i] = A[m,i]; - // for (int k = 0; k < 50; k++) { - // B[m,i] = B[m,i] + i * k; - // } - // } - // } - BufHandle a_buf("A", {100, 100}, kInt); - BufHandle b_buf("B", {100, 100}, kInt); - VarHandle m("m", kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m, i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, - {m, i}, - Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, - {m, i}, - Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - - { - // Check the case of distributing loop and its parents over all the - // statements in the loop. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: A[m, i] = 0 -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m, i] = -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: B[m, i] = A[m, i] -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m, i] = -# CHECK-NOT: for ( - )IR"; - - auto newForI = to(Stmt::clone(forI)); - auto forM = For::make(m, 0, 50, newForI); - auto par = Block::make({forM}); - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto newLoops = LoopNest::distributeLoopAndParents(newForI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(newLoops.front(), forM); - } - - { - // Check the case of distributing loop and its parents over all the inner - // loops. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: A[m, i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m, i] = -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: B[m, i] = A[m, i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m, i] = -# CHECK-NOT: for ( - )IR"; - - auto newForI = to(Stmt::clone(forI)); - auto forM = For::make(m, 0, 50, newForI); - auto par = Block::make({forM}); - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be same as the original For. - ASSERT_EQ(newLoops.front(), forM); - } -} - -TEST(LoopNest, fuseLoopsSimple) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsMultiple) { - // Input IR: - // for (int i = 0; i < 100; i++) { - // A[i+100] = 20 + i; - // } - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forI = - For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forI, forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i + 100] = -# CHECK-NEXT: A[i] = -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsNested) { - // Input IR: - // for (int m = 0; m < 20; m++) { - // A[m] = 0; - // for (int j = 0; j < 100; j++) { - // A[m] = A[m] + m * j; - // } - // } - // for (int n = 0; n < 20; n++) { - // B[n] = A[n]; - // for (int k = 0; k < 50; k++) { - // B[n] = B[n] + n * k; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); - auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); - auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); - auto forN = For::make(n, 0, 20, Block::make({initB, forK})); - auto par = Block::make({forM, forN}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: A[m] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m] = -# CHECK: B[m] = A[m] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forM); -} - -TEST(LoopNest, fuseLoopsNested2D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // B[m,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto forI = For::make( - i, - 0, - 20, - For::make( - j, - 0, - 100, - Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); - auto forM = For::make( - m, - 0, - 20, - For::make( - n, - 0, - 50, - Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int n -# CHECK-NEXT: B[i, n] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsNested2DInner) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // for (int n = 0; n < 100; n++) { - // B[i,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle n("n", kInt); - auto forJ = For::make( - j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); - auto forN = For::make( - n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); - auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); - - std::ostringstream oss; - oss << *forI; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK-NEXT: B[i, j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsDifferentStopBounds) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 50; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsDifferentStartBounds) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 50; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsNotContiguous) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // B[0] = 0; - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto initB = Store::make(b_buf, {0}, 0); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, initB, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithDifferentParents) { - // Input IR: - // for (int i = 0; i < 50; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j; - // } - // } - // B[0] = 0; - // for (int k = 50; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {50, 100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j))); - auto forI = For::make(i, 0, 50, forJ); - auto initB = Store::make(b_buf, {0}, 0); - auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI, initB, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithVariableBounds) { - // Input IR: - // for (int j = 0; j < N; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithExprBounds) { - // Input IR: - // for (int j = 0; j < M + N; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < M + N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { - // Input IR: - // for (int j = M; j < N * 2; j++) { - // A[j] = 10 * j; - // } - // for (int k = M; k < N + N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k+100] = 30 * k - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); - auto par = Block::make({forJ, forK}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: A[j + 100] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // A[m+20,n+100] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int n -# CHECK-NEXT: A[i + 20, n + 100] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithReductions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0 - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + B[i,j]; - // } - // } - // for (int m = 0; m < 20; m++) { - // C[m] = A[m]; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - BufHandle c_buf("C", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto sumA = Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j}))); - auto forJ = For::make(j, 0, 100, sumA); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ})); - auto forM = - For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = (A[i]) + -# CHECK-NOT: for ( -# CHECK: C[i] = A[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWith2DReductions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 50; j++) { - // A[i,j] = 0 - // for (int k = 0; k < 100; k++) { - // A[i,j] = A[i,j] + B[i,j,k]; - // } - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 40; n++) { - // C[m,n] = A[m,n]; - // } - // } - BufHandle a_buf("A", {20, 50}, kInt); - BufHandle b_buf("B", {20, 50, 100}, kInt); - BufHandle c_buf("C", {20, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto initA = Store::make(a_buf, {i, j}, 0); - auto sumA = Store::make( - a_buf, - {i, j}, - Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k}))); - auto forK = For::make(k, 0, 100, sumA); - auto forJ = For::make(j, 0, 50, Block::make({initA, forK})); - auto forI = For::make(i, 0, 20, forJ); - auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK-NEXT: for (int k -# CHECK-NEXT: A[i, j] = (A[i, j]) + -# CHECK: for (int n -# CHECK-NEXT: C[i, n] = A[i, n] -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithComplexIndices) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,j*20+j+2] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[m,n*20+n+2]; - // } - // } - BufHandle a_buf("A", {20, 400}, kInt); - BufHandle b_buf("B", {20, 400}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = - Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j -# CHECK: for (int n -# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2] -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,i*20+j] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[m,m*20+n]; // Both indices of A use m - // } - // } - BufHandle a_buf("A", {20, 500}, kInt); - BufHandle b_buf("B", {20, 500}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithTranspose) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,j] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[n,m]; // Transpose - // } - // } - BufHandle a_buf("A", {20, 20}, kInt); - BufHandle b_buf("B", {20, 20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, j}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies1) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k-1] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies2) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k+50] = 20 * k; - // } - BufHandle a_buf("A", {150}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies3) { - // Input IR: - // for (int m = 0; m < 20; m++) { - // A[m] = 0; - // for (int j = 0; j < 100; j++) { - // A[m] = A[m] + m * j; - // } - // } - // for (int n = 0; n < 20; n++) { - // B[n] = A[n+1]; - // for (int k = 0; k < 50; k++) { - // B[n] = B[n] + n * k; - // } - // } - BufHandle a_buf("A", {25, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); - auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); - auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); - auto forN = For::make(n, 0, 20, Block::make({initB, forK})); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forM, forN}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies4) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // A[m+1,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {30, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto forI = For::make( - i, - 0, - 20, - For::make( - j, - 0, - 100, - Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); - auto forM = For::make( - m, - 0, - 20, - For::make( - n, - 0, - 50, - Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies5) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // for (int n = 0; n < 100; n++) { - // A[i,n+1] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle n("n", kInt); - auto forJ = For::make( - j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); - auto forN = For::make( - n, - 0, - 100, - Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies6) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * A[99-k]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies7) { - // Input IR: - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * A[99-k]; - // } - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forK, forJ}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); -} - -TEST(LoopNest, areLoopsPerfectlyNested) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Specifying the loops in any other order fails. - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK})); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ})); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI})); - - // Adding a statement to forK body should be OK. - auto init = Store::make(a_buf, {i, j}, 0); - forK->body()->insert_stmt_before(init, store); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Adding a statement in forJ body should fail this test. - forK->body()->remove_stmt(init); - forJ->body()->insert_stmt_before(init, forK); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Similarly, adding a statement in forI body should fail this test. - forJ->body()->remove_stmt(init); - forI->body()->insert_stmt_before(init, forJ); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); -} - -TEST(LoopNest, reorderNestedLoops2D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store = Store::make(a_buf, {i, j}, Mul::make(i, j)); - auto forJ = For::make(j, 0, 30, store); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ}, {1, 0}); - - ASSERT_EQ(reordered[0], forJ); - ASSERT_EQ(reordered[1], forI); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI})); - ASSERT_EQ(forJ->get_parent(), par); - ASSERT_EQ(store->get_parent(), forI->body()); -} - -TEST(LoopNest, reorderNestedLoops3D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1}); - - ASSERT_EQ(reordered[0], forK); - ASSERT_EQ(reordered[1], forI); - ASSERT_EQ(reordered[2], forJ); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ})); - ASSERT_EQ(forK->get_parent(), par); - ASSERT_EQ(store->get_parent(), forJ->body()); -} - -TEST(LoopNest, reorderNestedLoops4D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // for (int l = 0; l < 50; l++) { - // A[i,j,k,l] = i * j * k * l * 500; - // } - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle l("l", kInt); - auto store = Store::make( - a_buf, - {i, j, k, l}, - Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500)); - auto forL = For::make(l, 0, 50, store); - auto forK = For::make(k, 0, 40, forL); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1}); - - ASSERT_EQ(reordered[0], forK); - ASSERT_EQ(reordered[1], forI); - ASSERT_EQ(reordered[2], forL); - ASSERT_EQ(reordered[3], forJ); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ})); - ASSERT_EQ(forK->get_parent(), par); - ASSERT_EQ(store->get_parent(), forJ->body()); -} - -TEST(LoopNest, reorderTrivialPermutation) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2}); - - ASSERT_EQ(reordered[0], forI); - ASSERT_EQ(reordered[1], forJ); - ASSERT_EQ(reordered[2], forK); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - ASSERT_EQ(forI->get_parent(), par); - ASSERT_EQ(store->get_parent(), forK->body()); -} - -TEST(LoopNest, reorderInvalidPermutations) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}), - "invalid permutation size"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 2}), - "invalid permutation size"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}), - "invalid permutation for reorder"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}), - "invalid permutation for reorder"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}), - "invalid permutation for reorder"); -} - -TEST(LoopNest, reorderInvalidLoopNest) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // A[i,j] = 0 - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - - // Specifying the loops in incorrect order fails. - ASSERT_THROWS_WITH( - LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); - - // Adding a statement to forJ loop fails. - auto init = Store::make(a_buf, {i}, 0); - forJ->body()->insert_stmt_before(init, forK); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); - - // Moving that statement to forI loop also fails. - forJ->body()->remove_stmt(init); - forI->body()->insert_stmt_before(init, forJ); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); -} - -TEST(LoopNest, compressBufferSimple) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferMultipleDims) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // B[i,j] = A[i,j] + A[i,j] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store1 = Store::make(aBuf, {i, j}, sin(i * j)); - auto store2 = Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j}))); - auto forJ = For::make(j, 0, 200, Block::make({store1, store2})); - auto forI = For::make(i, 0, 100, forJ); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, 0] = -# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); -} - -TEST(LoopNest, compressBufferMultipleDims2) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // for (int k = 0; k < 300; ++k) { - // A[i,j,k] = sin(i*j*k) - // } - // for (int k = 0; k < 299; ++j) { - // B[i,j,k] = A[i,j,k] + A[i,j,k+1] - // } - // } - // } - BufHandle aBuf("A", {100, 200, 300}, kInt); - BufHandle bBuf("B", {100, 200, 300}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k)); - auto forK1 = For::make(k, 0, 300, store1); - auto store2 = Store::make( - bBuf, - {i, j, k}, - Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1}))); - auto forK2 = For::make(k, 0, 299, store2); - auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2})); - auto forI = For::make(i, 0, 100, forJ); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: for (int k -# CHECK-NEXT: A[0, 0, k] = -# CHECK: for (int k -# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 3); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300); -} - -TEST(LoopNest, compressBufferDifferentOrderIndices) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[j, i] = sin(i*j) - // } - // for (int j = 0; j < 99; ++j) { - // B[i, j] = A[j, i] + A[j+1, 0] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 99, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[j, 0] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); -} - -TEST(LoopNest, compressBufferVariableBounds) { - // Input IR: - // for (int i = 0; i < M; ++i) { - // for (int j = 0; j < N; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int j = 0; j < N-1; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - N - 1, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferNoCommonParentLoops) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // } - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - auto forI1 = For::make(i, 0, 100, forJ1); - auto forI2 = For::make(i, 0, 100, forJ2); - auto par = Block::make({forI1, forI2}); - LoopNest::compressBuffer(aBuf.node(), par); - - // There should be no change in the buffer or code. - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferIndicesMixed) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i + j, j] = sin(i*j) - // } - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i + j, j] + A[i + j, j+1] - // } - // } - BufHandle aBuf("A", {300, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make( - Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - // There should be no change in the buffer or code. - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i + j, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressMultipleBuffers) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int k = 0; k < 199; ++k) { - // B[i,k] = A[i,k] + A[i, k+1] - // } - // for (int m = 0; m < 50; ++m) { - // C[i,m] = B[i,m] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - BufHandle cBuf("C", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forK = For::make( - k, - 0, - 199, - Store::make( - bBuf, - {i, k}, - Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1})))); - auto forM = - For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}))); - auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM})); - auto par = Block::make({forI}); - - // This should compress all buffers A, B, and C as follows: - // A[100, 200] -> A[1, 200] - // B[100, 200] -> B[1, 200] - // C[100, 200] -> C[1, 1] - LoopNest::compressAllBuffers(par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int k -# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1]) -# CHECK: for (int m -# CHECK-NEXT: C[0, 0] = B[0, m] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); - ASSERT_EQ(bBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200); - ASSERT_EQ(cBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1); -} - -TEST(LoopNest, sanitizeNames) { - std::vector dim_args; - // Let's pick names that would overlap with default index names if not - // sanitized properly: - dim_args.emplace_back(ExprHandle(alloc("i", kInt))); - dim_args.emplace_back(ExprHandle(alloc("N:2", kInt))); - // Now let's create a many dimensions so that we had to use the same letter - // for different loops - for (int i = 0; i < 10; i++) { - dim_args.emplace_back(ExprHandle(alloc("N", kInt))); - } - - // Now create two Computes with conflicting after sanitization names: - Tensor X = Compute("$X:!", dim_args, [&](const std::vector& v) { - return v[0] + v[1] + v[9] + 1; - }); - Tensor Y = Reduce( - "%X\"+", - {}, - Sum(), - [&](const std::vector& v) { return X.load(v); }, - dim_args); - - // Finally, let's verify what we got after sanitization: - LoopNest l({X, Y}); - StmtPtr s = l.root_stmt(); - LoopNest::sanitizeNames(s); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < i_1; i++) { -# CHECK-NEXT: for (int j = 0; j < N_2_1; j++) { -# CHECK-NEXT: for (int k = 0; k < N_9; k++) { -# CHECK-NEXT: for (int l = 0; l < N_8; l++) { -# CHECK-NEXT: for (int m = 0; m < N_7; m++) { -# CHECK-NEXT: for (int n = 0; n < N_6; n++) { -# CHECK-NEXT: for (int o = 0; o < N_5; o++) { -# CHECK-NEXT: for (int p = 0; p < N_4; p++) { -# CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) { -# CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) { -# CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) { -# CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) { -# CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1; -# CHECK: v_X___1 = int(0); -# CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) { -# CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) { -# CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) { -# CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) { -# CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) { -# CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) { -# CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) { -# CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) { -# CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) { -# CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) { -# CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) { -# CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) { -# CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1}); - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp deleted file mode 100644 index 5db84eab1f509..0000000000000 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ /dev/null @@ -1,3252 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -// Test helper function used to determine if two regions of a buffer have an -// overlap. No Overlap & partial overlap is obvious. Contains means A is -// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal -// ranges are ContainedOrEqual. -TEST(MemDependency, BoundOverlap) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - // Sanity check 3 overlap cases. - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1))); - - // Partial overlap works in either order. - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10))); - - // Total Overlap works when one bound encloses the other, and returns which. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16))); - - // Total overlap works when the bounds are an identical range, returns - // ContainedOrEqual. - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15))); - - // Total overlap when only one end of the bound matches. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15))); - - // No overlap when a < b. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130))); - - // No overlap when a > b. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120))); - - // No overlap when adjacent. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1))); - - // Partial overlap when middle bounds match. - ASSERT_EQ( - OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4))); - ASSERT_EQ( - OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2))); - - // Total overlap when one bound is single length over one end of the other. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15))); -} - -TEST(MemDependency, BoundComparison) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ)); - - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE)); - - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT)); - - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE)); - - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT)); - - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE)); -} - -TEST(MemDependency, BoundOverlapSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - VarHandle w("w", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - // Sanity check cases where the start and end is symbolic but the diff is - // constant. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x))); - ASSERT_EQ( - OverlapKind::PartialOverlap, - boundOverlap(CB(x, x + 3), CB(x + 2, x + 5))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1))); - - // We can't infer the sign of y, so cannot tell whether adding y is larger or - // smaller than y/2. - ASSERT_EQ( - OverlapKind::PartialOverlap, - boundOverlap(CB(x, x + y), CB(x, x + y / 2))); - - // No information about this bound, have to take the most conservative option: - // there may be an overlap. - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w))); - - // Math on opaque terms works. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - boundOverlap(CB(x + w, y - z), CB(x + w, y - z))); - // Even requiring simplification. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - boundOverlap(CB(x - w - w, y), CB(x - w * 2, y))); -} - -// Tests the helper function for overlap of multi dimensional indices bounds. -// This uses boundOverlap on each dimension and return the "lowest" kind of -// overlap. -TEST(MemDependency, BoundOverlapMultiDim) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - // Sanity check one dimensional cases. - ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); - ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)})); - ASSERT_EQ( - OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)})); - - // Total overlap in 3 dims. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)})); - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)})); - - // Total overlap in 2 dims, no overlap in another. - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); - - // Total overlap in 2 dims, partial overlap in another. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); - // This case is most important, so verify the overlap in any dim. (dim 2) - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)})); - // Dim 1. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)})); - // Total overlap in 1 dim, partial in 2. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)})); - // Total overlap, partial overlap, no overlap. - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)})); - - // Total overlap (B) in 2 dims, total overlap (A) in another. - ASSERT_EQ( - OverlapKind::Contains, - overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)})); - - // Total overlap (A) in 2 dims, total overlap (B) in another. - ASSERT_EQ( - OverlapKind::Contains, - overlaps( - {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)})); - - // Total (B), No Overlap, Total (A). - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)})); -} - -// Test the helper we use to subtract bounds: returns the regions(s) of A which -// remain after removing the region of B. -TEST(MemDependency, BoundSubtract) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - // One element subtract. - ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0); - ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0); - - // No Overlap. - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)})); - - // one side overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)})); - ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)})); - ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)})); - - // both sides overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {})); - - // internal overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)})); -} - -TEST(MemDependency, BoundSubtractSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - VarHandle w("w", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - // One element subtract. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {})); - - // Subtract constant range low. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)})); - // Subtract constant range high. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)})); - // Subtract constant range total overlap. - ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {})); - // Subtract constant range internal. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)), - {CB(x, x + 2), CB(x + 8, x + 10)})); - - // Size is inferable but not constant, only works with a single var. - ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)})); - - // Size is not inferable. - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)})); -} - -// Tests the helper function that does subtraction, but for multi dimensional -// indices bounds. -TEST(MemDependency, BoundSubtractMultiDim) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](std::vector x, std::vector y) { - if (x.size() != y.size()) { - return false; - } - for (auto i = 0U; i < x.size(); ++i) { - if (!indexBoundsEquals(x[i], y[i])) { - return false; - } - } - return true; - }; - - // sanity check one dimension. - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {})); - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {})); - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}})); - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}})); - - // Multi dim total overlap. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); - - // Multi dim one way partial in dim 1. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), - {{CB(4, 9), CB(0, 2)}})); - - // Multi dim one way partial in dim 2. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), - {{CB(0, 9), CB(11, 20)}})); - - // Partial overlap in 2 dims. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}), - {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}})); - - // Partial overlap in 3 dims. - ASSERT_TRUE( - EQ(subtractIndicesBounds( - {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}), - {{CB(0, 1), CB(0, 5), CB(0, 5)}, - {CB(2, 5), CB(0, 1), CB(0, 5)}, - {CB(2, 5), CB(2, 5), CB(0, 1)}})); -} - -// Tests the multi dimensional subtraction code for bounds that cannot be fully -// materialized. -TEST(MemDependency, BoundSubtractMultiDimSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](std::vector x, std::vector y) { - if (x.size() != y.size()) { - return false; - } - for (auto i = 0U; i < x.size(); ++i) { - if (!indexBoundsEquals(x[i], y[i])) { - return false; - } - } - return true; - }; - - // Cannot determine overlaps. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}})); - - // Various total Overlaps. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {})); - - // one-way overlap in first dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}), - {{CB(x - 4, x), CB(0, y)}})); - // second dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}), - {{CB(0, x), CB(0, 4)}})); - - // Internal overlap in first dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}), - {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}})); - // second dim. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}), - {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}})); - - // Overlap in both dimensions. - ASSERT_TRUE( - EQ(subtractIndicesBounds( - {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}), - { - {CB(0, 4), CB(0, y)}, - {CB(x - 4, x), CB(0, y)}, - {CB(0, x), CB(0, 9)}, - {CB(0, x), CB(y - 9, y)}, - })); -} - -// Simple check that the analyzer does anything at all... -TEST(MemDependency, MemDependencyCheckerSimple) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, bStore}); - - stmt->accept(&analyzer); - - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); - // sanity check, but anything that depends directly must depend indirectly. - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore)); -} - -// Check that there is a difference between direct and indirect dependence. -TEST(MemDependency, MemDependencyCheckerMultiStmt) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * B[0] = A[0]; - * C[0] = B[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, bStore, cStore}); - - stmt->accept(&analyzer); - - // C depends on A indirectly. - ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore)); - ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore)); - - // C depends on B directly, which depends on A directly. - ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore)); - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - - // Dependency goes top to bottom only. - ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore)); -} - -// Verify that we do filter writes that are totally overlapped by later writes. -TEST(MemDependency, MemDependencyCheckerOverlap) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * A[0] = 6; - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr a2Store = Store::make(a, {0}, 6); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, a2Store, bStore}); - - stmt->accept(&analyzer); - - // B store depends on second A store but not first since it is completely - // overlapped. - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store)); - ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore)); - - // No dependency between either A store. - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store)); - ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore)); -} - -// Verify that bounds match loop iterations, and that dependencies progress -// across loop scopes. -TEST(MemDependency, MemDependencyCheckerLoop) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * for (int x = 0; x < 10; ++x) { - * A[x] = x; - * } - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {x}, x); - StmtPtr loop = For::make(x, 0, 10, aStore); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); - - StmtPtr stmt = Block::make({loop, bStore}); - - stmt->accept(&analyzer); - - // Same A->B dependency. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop but does not depend on any loop iteration. - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop)); - - auto aStoreAccess = analyzer.accessFor(aStore); - ASSERT_NE(aStoreAccess, nullptr); - - // It should have bounds covering the range of x: 0 <= x < 10. - ASSERT_TRUE(indexBoundsEquals( - aStoreAccess->bounds(), {Bound(alloc(0), alloc(9))})); -} - -// Reductions should promote dependencies as well. -TEST(MemDependency, MemDependencyCheckerLoopReduce) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * A[0] = 0; - * for (int x = 0; x < 10; ++x) { - * A[0] = A[x] + 1; - * } - * B[0] = A[0]; - */ - - StorePtr aInit = Store::make(a, {0}, 0); - ExprHandle reduce = Sum()(a, 1, {x}, {x}); - StorePtr aReduce = Store::make(a, {0}, reduce); - StmtPtr loop = For::make(x, 0, 10, aReduce); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - - StmtPtr stmt = Block::make({aInit, loop, bStore}); - - stmt->accept(&analyzer); - - // B -> A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); - - // B depends indirectly on the initializer of A, since the reduction depends - // on it. - ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); - - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop and depends on other iterations. - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); - - // The loop contents depend on the initializer too. - ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); - - // Find loads within the reduction: - auto reduceLoads = NodeFinder::find(reduce.node()); - // Pull out the access for the load inside the loop. - for (auto load : reduceLoads) { - auto loopLoad = analyzer.accessFor(load); - // It should have 10 element long bounds. - ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); - } -} - -// Lowering a reduction doesn't affect dependency analysis. -TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * A[0] = 0; - * for (int x = 0; x < 10; ++x) { - * A[0] = A[x] + 1; - * } - * B[0] = A[0]; - */ - - StorePtr aInit = Store::make(a, {0}, 0); - ExprHandle aLoad = Load::make(a, {x}); - StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); - StmtPtr loop = For::make(x, 0, 10, aReduce); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - - StmtPtr stmt = Block::make({aInit, loop, bStore}); - - stmt->accept(&analyzer); - - // B -> A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); - - // B depends indirectly on the initializer of A, since the reduction depends - // on it. - ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); - - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop and depends on other iterations. - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); - - // The loop contents depend on the initializer too. - ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); - - // Pull out the access for the store inside the loop. - auto loopLoad = analyzer.accessFor(aLoad.node()); - // It should have 10 element long bounds. - ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); -} - -// Can determine dependencies of outputs, through to inputs. -TEST(MemDependency, MemDependencyCheckerInputsOutputs) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - // initialize analyzer with inputs and outputs. - analysis::MemDependencyChecker analyzer({a}, {b}); - - // Here's a Relu. - /* - * for (int x = 0; x < 10; ++x) { - * B[x] = Max(A[x], 0); - * } - */ - - ExprHandle aLoad = Load::make(a, {x}); - StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); - StmtPtr loop = For::make(x, 0, 10, bStore); - - StmtPtr stmt = Block::make({loop}); - - stmt->accept(&analyzer); - - // Output depends indirectly on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - // aLoad depends directly on the input A. - ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node())); - // bStore therefore depends directly on the input A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node())); - // The output depends directly on the store. - ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); - - // Check AccessInfo based overloads. - auto input = analyzer.input(a.node()); - auto output = analyzer.output(b.node()); - - // Output depends indirectly on input. - ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); - // Not directly. - ASSERT_FALSE(analyzer.dependsDirectly(output, input)); - // Not in reverse order. - ASSERT_FALSE(analyzer.dependsIndirectly(input, output)); - - // output -> bStore -> bLoad -> input. - auto storeAccess = analyzer.accessFor(bStore); - auto loadAccess = analyzer.accessFor(aLoad.node()); - - ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess)); - ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input)); -} - -// Can tell if an output does not depend on an input. -TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - // initialize analyzer with inputs and outputs. - analysis::MemDependencyChecker analyzer({a}, {b}); - - // Here's a dumb Relu. - /* - * for (int x = 0; x < 10; ++x) { - * B[x] = Max(x, 0); - * } - */ - - StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true)); - StmtPtr loop = For::make(x, 0, 10, bStore); - - StmtPtr stmt = Block::make({loop}); - - stmt->accept(&analyzer); - - // Output does not depend indirectly on input. - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node())); - - // The output still depends directly on the store. - ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); - - // Check AccessInfo based overloads. - auto input = analyzer.input(a.node()); - auto output = analyzer.output(b.node()); - - // Output does not depend indirectly on input. - ASSERT_FALSE(analyzer.dependsIndirectly(output, input)); -} - -// Verify different loop extents produce accesses with different bounds, and -// that later accesses find dependencies that overlap their entire bound range. -TEST(MemDependency, MemDependencyCheckerLoopBounds) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - using namespace analysis; - - MemDependencyChecker analyzer({a}, {c}); - - // This enables using the execution order of the loops to determine if some - // loops are self dependent or not. - analyzer.allowLoopExecutionOrderAnalysis(); - - /* - * for (int x = 1; x < 10; ++x) { - * B[x] = A[x]; - * } - * for (int x = 1; x < 9; ++x) { - * B[x] = B[x] * 2; - * } - * for (int x = 3; x < 4; ++x) { - * C[x] = A[x]; - * } - * for (int x = 0; x < 10; ++x) { - * C[x] = B[x]; - * } - */ - - std::vector stmts( - {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))), - For::make( - x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))), - For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))}); - - StmtPtr stmt = Block::make(stmts); - - stmt->accept(&analyzer); - - auto input = analyzer.input(a.node()); - auto output = analyzer.output(c.node()); - - // sanity check Output -> Input. - ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); - - // Check the For loop dependencies: - - // Last write to C depends on both writes to B since they contain the last - // write to at least one element. - ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1])); - ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0])); - - // The last write to C does not depend on the other write to C. - ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - /* 0. Input: A[(0, 9)] - dependents: 1 5 - * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2 - * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7 - * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4 - * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7 - * 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6 - * 6. Store: C[(3, 3)] - depends on: 5 - * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8 - * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9 - * 9. Output: C[(0, 9)] - depends on: 8 - */ - - // Now let's look at the bounds of each access. - // There are 9 accesses in this Stmt, so this is exhaustive, we won't do this - // much. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 10); - VarPtr aVar = a.node()->base_handle(); - VarPtr bVar = b.node()->base_handle(); - VarPtr cVar = c.node()->base_handle(); - - // The first access is the input A. - ASSERT_EQ(history[0]->type(), AccessType::Input); - ASSERT_EQ(history[0]->var(), aVar); - // It has the bounds of the producing Input. - ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); - // sanity check the input we retrieved earlier matches. - ASSERT_EQ(history[0], input); - - // The second access is the load of A in the first loop. - ASSERT_EQ(history[1]->type(), AccessType::Load); - ASSERT_EQ(history[1]->var(), aVar); - // It has the bounds of the loop, i.e. start == 1. - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)})); - // It reads from A, so it should have a dependency on the last write to this - // range - with is the input. - ASSERT_EQ(history[1]->dependencies().size(), 1); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - // The third access is the store into B in the first loop. - ASSERT_EQ(history[2]->type(), AccessType::Store); - ASSERT_EQ(history[2]->var(), bVar); - // It also has the bounds of the loop, i.e. start == 1. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); - // The previous load is in its RHS, so it depends on it. - ASSERT_EQ(history[2]->dependencies().size(), 1); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The third access is the load from B in the second loop. - ASSERT_EQ(history[3]->type(), AccessType::Load); - ASSERT_EQ(history[3]->var(), bVar); - // It has the bounds of the second loop, i.e. >= 1 < 9. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)})); - // It reads from B in a smaller range, so should depend on the previous - // store. - ASSERT_EQ(history[3]->dependencies().size(), 1); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The fourth: the store to B in the second loop. - ASSERT_EQ(history[4]->type(), AccessType::Store); - ASSERT_EQ(history[4]->var(), bVar); - // It also has the bounds of the second loop. - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)})); - // The previous load is in its RHS, so it depends on it as before. - ASSERT_EQ(history[4]->dependencies().size(), 1); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // The fifth access is the load is from the 3rd loop, and skips previous B - // accesses. - ASSERT_EQ(history[5]->type(), AccessType::Load); - ASSERT_EQ(history[5]->var(), aVar); - // It has the bounds of the third loop: >= 3 < 4. - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)})); - // It depends on the last thing to write to A, which is the A input. - ASSERT_EQ(history[5]->dependencies().size(), 1); - ASSERT_TRUE(history[5]->hasDependency(history[0])); - - // Sixth: the store into the output C. - ASSERT_EQ(history[6]->type(), AccessType::Store); - ASSERT_EQ(history[6]->var(), cVar); - // It also has the bounds of the third loop. - ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)})); - // The previous load is in its RHS, so it depends on it as always. - ASSERT_EQ(history[6]->dependencies().size(), 1); - ASSERT_TRUE(history[6]->hasDependency(history[5])); - - // The seventh access is the load of B in the fourth loop. - ASSERT_EQ(history[7]->type(), AccessType::Load); - ASSERT_EQ(history[7]->var(), bVar); - // It has the bounds of the final loop, >= 0 < 10 - ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); - // The bounds of this read are larger than the bounds of the previous write, - // so it depends on both previous Stores to B. - ASSERT_EQ(history[7]->dependencies().size(), 2); - ASSERT_TRUE(history[7]->hasDependency(history[2])); - ASSERT_TRUE(history[7]->hasDependency(history[4])); - - // Eight: the final store into the output C. - ASSERT_EQ(history[8]->type(), AccessType::Store); - ASSERT_EQ(history[8]->var(), cVar); - // It also has the bounds of the final loop. - ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); - // The previous load is in its RHS, so it depends on it as always. - ASSERT_EQ(history[8]->dependencies().size(), 1); - ASSERT_TRUE(history[8]->hasDependency(history[7])); - - // The last access represents the output Buf. - ASSERT_EQ(history[9]->type(), AccessType::Output); - ASSERT_EQ(history[9]->var(), cVar); - // It has the bounds of the output Buf. - ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)})); - // sanity check the input we retrieved earlier matches. - ASSERT_EQ(history[9], output); - // It depends on the last write to C only. - ASSERT_EQ(history[9]->dependencies().size(), 1); - ASSERT_TRUE(history[9]->hasDependency(history[8])); -} - -// Verify that we can still infer bounds when the loop var is offset. -TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer({a}, {b}); - - // This enables using the execution order of the loops to determine if some - // loops are self dependent or not. - analyzer.allowLoopExecutionOrderAnalysis(); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * for (int x = 0; x < 9; x++) { - * A[x] = A[x + 1]; - * } - * for (int x = 0; x < 9; x++) { - * A[9 - x] = A[8 - x]; - * } - * for (int x = 0; x < 10; x++) { - * A[x] = A[9 - x]; - * } - * for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - */ - - StmtPtr stmt = Block::make( - {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), - For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))), - For::make( - x, - 0, - 9, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))), - For::make( - x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))), - For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - // Sanity check output depends on Input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - /* 0. Input: A[(0, 9)] - dependents: 1 - * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2 - * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3 - * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4 - * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7 - * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6 - * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7 - * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8 - * 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9 - * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10 - * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11 - * 11. Output: B[(0, 9)] - depends on: 10 - */ - - // Now let's look at the bounds of each access. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 12); - VarPtr aVar = a.node()->base_handle(); - VarPtr bVar = b.node()->base_handle(); - - // The first access is the input A. - ASSERT_EQ(history[0]->type(), AccessType::Input); - ASSERT_EQ(history[0]->var(), aVar); - // It has the bounds of the producing Input. - ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); - - // The second access is the load A[x-1]. - ASSERT_EQ(history[1]->type(), AccessType::Load); - ASSERT_EQ(history[1]->var(), aVar); - // It has the bounds of the loop modified by the offset of each index, in - // this case -1. - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); - // It depends on the input, but also the store in the same loop, since - // different iterations of the loop depend on each other. - ASSERT_EQ(history[1]->dependencies().size(), 2); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - ASSERT_TRUE(history[1]->hasDependency(history[2])); - - // The third access is the Store to A[x] in the first loop. - ASSERT_EQ(history[2]->type(), AccessType::Store); - ASSERT_EQ(history[2]->var(), aVar); - // It has no offset on x, so should have the same bounds as the loop. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); - - // The fourth access is the load A[x+1] in the second loop. - ASSERT_EQ(history[3]->type(), AccessType::Load); - ASSERT_EQ(history[3]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9) modified by the offset of each - // index, in this case 1. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)})); - // This load totally overlaps the previous write to A, so it depends only on - // it and not the input. - ASSERT_EQ(history[3]->dependencies().size(), 1); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The fifth access is the store to A[x] in the second loop. - ASSERT_EQ(history[4]->type(), AccessType::Store); - ASSERT_EQ(history[4]->var(), aVar); - // It has no offset on x, so should have the same bounds as the loop. - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)})); - - // The sixth access is the load to A[8 - x] in the third loop. - ASSERT_EQ(history[5]->type(), AccessType::Load); - ASSERT_EQ(history[5]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9) modified by the offset of each - // index, in this case 8 - x. - // This access has a negative stride, which will be normalized. - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)})); - // This load totally overlaps the most recent write to A, so it depends only - // on it and not the input or the first write to A. - ASSERT_EQ(history[5]->dependencies().size(), 1); - ASSERT_TRUE(history[5]->hasDependency(history[4])); - - // The seventh access is the store to A[9 - x] in the third loop. - ASSERT_EQ(history[6]->type(), AccessType::Store); - ASSERT_EQ(history[6]->var(), aVar); - // This store has a negative stride on it's indices, but is normalized - // internally. - ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)})); - - // The eighth access is the load A[9-x] in the second loop. - ASSERT_EQ(history[7]->type(), AccessType::Load); - ASSERT_EQ(history[7]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x, - // which essentially traverses the loop backwards. - ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); - // This Load has three write dependencies: - ASSERT_EQ(history[7]->dependencies().size(), 3); - // * The previous store (#6) for elements 1-9 - ASSERT_TRUE(history[7]->hasDependency(history[6])); - // * An earlier store (#4) covering element 0 - ASSERT_TRUE(history[7]->hasDependency(history[4])); - // * A future store inside this loop, since this loop modifies the buffer - // in a non distinct way (due to the load and store having different access - // strides). - ASSERT_TRUE(history[7]->hasDependency(history[8])); - - // The ninth access is the store to A[x] in the fourth loop. - ASSERT_EQ(history[8]->type(), AccessType::Store); - ASSERT_EQ(history[8]->var(), aVar); - // This store has a negative stride on it's indices, but is normalized - // internally. - ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); - - // The tenth and 11th accesses are the copy from A[x] to B[x]. - ASSERT_EQ(history[9]->type(), AccessType::Load); - ASSERT_EQ(history[9]->var(), aVar); - ASSERT_EQ(history[10]->type(), AccessType::Store); - ASSERT_EQ(history[10]->var(), bVar); - - // The last access represents the output Buf. - ASSERT_EQ(history[11]->type(), AccessType::Output); - ASSERT_EQ(history[11]->var(), bVar); - // It has the bounds of the output Buf. - ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)})); - // It depends on the last write to B only. - ASSERT_EQ(history[11]->dependencies().size(), 1); - ASSERT_TRUE(history[11]->hasDependency(history[10])); - - // ok that's enough of that. -} - -// Check many different cases of loop self dependency - when a load within a -// loop is dependent on a Store later in the same loop but in different -// iteration. This is affected by whether or not we can trust the execution -// order of the loop. -TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - using namespace analysis; - - // This check assumes that the Stmt has a single Store with a single Load on - // the RHS. - auto isSelfDependent = - [](const std::vector>& history) -> bool { - return history.front()->hasDependency(history.back()); - }; - - { - /* for (int y = 0; y < 10; y++) { - * A[y] = (A[y]) + 1; - * } */ - - // Not self dependent since all loop iterations use a different y. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - y, - 0, - 10, - Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))})); - - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int y = 0; y < 10; y++) { - * A[y + 1] = (A[y + 1]) + 1; - * } - */ - - // Not self dependent due to different y (with offset). - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - y, - 0, - 10, - Block::make( - {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))})); - - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - // Is self dependent since all loops use a common constant element of A. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[0] = (B[0]) + x; - * } - */ - - // Is not self dependent because there is no store to the buffer that is - // read. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))})); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[y] = (A[y]) + x; - * } - */ - - // Is self dependent since all loops use a common symbolic element of A. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 1]; - * } - */ - - // In this case it depends if we are considering execution order. - - MemDependencyChecker analyzer; - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - // With analysis of order disabled, this is self dependent since the read - // from X+1 and the write to X+1 could be in reverse order. - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 1]; - * } - */ - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - // If order analysis is enabled, this is not dependent since the read for - // each element occurs before the write to that element. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - MemDependencyChecker analyzer; - - StmtPtr stmt = - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - // In this case, even with order analysis the Load is dependent on the - // Store, since the write to X occurs before the read from X. - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 9; x++) { - * A[9 - x] = A[8 - x]; - * } - */ - - // Still works if the execution order is reversed, so long as the read - // comes before the write. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); - stmt->accept(&analyzer); - - // However here was can determine the A store is earlier in the order than - // the load. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 9; x++) { - * A[8 - x] = A[9 - x]; - * } - */ - - // But not if it doesn't. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 9; x++) { - * A[9 - x] = A[8 - x]; - * } - */ - - // And not if we're not relying on execution order. - - MemDependencyChecker analyzer; - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 3; x < 10; x++) { - * A[x - 2] = A[x - 1]; - * } - */ - - // Forward order but negative indices. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - // However here was can determine the A store is earlier in the order than - // the load. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2]; - * } - */ - - // With an access stride. - - MemDependencyChecker analyzer; - // Execution order doesn't matter since the read and the write are totally - // distinct. - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 1]; - * } - */ - - // Here we can use the common stride of the accesses to determine they are - // distinct. - // Note, this is the only place (loop self dependency) we use this stride - // to avoid unnecessary dependence. - - MemDependencyChecker analyzer; - // Execution order doesn't matter since the read and the write are totally - // distinct. - - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 - 1]; - * } - */ - - // same if the read is behind the write so long as they are distinct. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 2]; - * } - */ - - // But not if the offset is in the stride. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 - 2]; - * } - */ - - // Works with negative offsets too. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 7]; - * } - */ - - // Detects accesses are distinct when offset is large but not a multiple - // of stride. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 4]; - * } - */ - - // Works with offsets which are multiples of the stride. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 6] = A[x * 6 + 5]; - * } - */ - - // detects accesses are distinct with large strides when the offset is - // within. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6]; - * } - */ - - // detects accesses are overlapping when stride is different but a - // multiple. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 4] = A[x * 2]; - * } - */ - - // still works when the read axis is the smaller stride. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6 + 1]; - * } - */ - - // detects accesses are distinct when stride is different but a multiple - // and there is an offset. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6 + 4]; - * } - */ - - // The smaller stride determines whether there is overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2 + 3] = A[x * 6]; - * } - */ - - // The smaller stride determines whether there is overlap, not the larger. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 3 + 1]; - * } - */ - - // If they have strides with no common multiple > 1, they overlap. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 10]; - * } - */ - - // If the offset is greater than the size of the loop, they can't overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[9 - x]; - * } - */ - - // If they have different execution orders they may overlap. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[19 - x * 2]; - * } - */ - - // Or they may not, depending on their start offset and strides. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x / 2] = A[x / 2]; - * } - */ - - // If the stride is not monotonic, they overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x / 2] = A[x / 2] + 1; - * } - */ - - // If the stride is not monotonic, they overlap - even with an offset. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x % 2] = A[x % 2]; - * } - */ - - // Mod too... - - analysis::MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = y; x < z; x++) { - * A[x] = A[x + 1]; - * } - */ - - // Still works with symbolic loop extents. - - { - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - StmtPtr stmt = - For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - } -} - -// Verify that a strided access still works. -// TODO: actually this only works because of the size of the ranges, revisit -// this test after strided overlap is implemented. -TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - BufHandle a("A", {20}, kInt); - BufHandle b("B", {20}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - MemDependencyChecker analyzer({a.node()}, {b.node()}); - StmtPtr stmt = Block::make( - {For::make( - x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), - For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))) - - }); - stmt->accept(&analyzer); - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 2 dependencies... the store in each loop. - auto outputAccess = analyzer.output(b.node()); - ASSERT_EQ(outputAccess->dependencies().size(), 2); -} - -/* TODO(nickg) - this test will fail due to the lack of stride math in Bound -TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - BufHandle a("A", {20}, kInt); - BufHandle b("B", {20}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), - For::make( - x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))) - - }); - stmt->accept(&analyzer); - - std::cout << *stmt << "\n"; - for (auto& wi : analyzer.getHistory()) { - wi->print(); - } - } -}*/ - -// analysis on Stmts using Cond. -TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * C[0] = (B[0]) + 1; - * } else { - * C[0] = (B[1]) + 1; - * } - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)), - Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 3); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * C[x] = B[x]; - * } - * } else { - * for (int x = 0; x < 10; x++) { - * C[x] = (B[x]) + 1; - * } - * } - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))), - For::make( - x, - 0, - 10, - Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 3); - - // TODO(nickg): actually since the true and false branch cover the total - // range of the first store this should have 2 dependencies, but we don't - // do that yet. - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * C[x] = (B[x]) + 1; - * } - * } - */ - - // Only has true branch. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - For::make( - x, - 0, - 10, - Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))), - nullptr)}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * } else { - * for (int x = 0; x < 10; x++) { - * C[x] = (B[x]) + 1; - * } - * } - */ - - // Only has false branch. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - nullptr, - For::make( - x, - 0, - 10, - Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (C[0]<5 ? 1 : 0) { - * C[0] = 5; - * } - */ - - // Cond's Condition depends on a previous access. - - MemDependencyChecker analyzer({a}, {c}); - StorePtr initStore = Store::make(c, {x}, Load::make(a, {x})); - ExprHandle conditionalLoad = Load::make(c, {0}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, initStore), - Cond::make( - CompareSelect::make( - conditionalLoad, 5, CompareSelectOperation::kLT), - Store::make(c, {0}, 5), - nullptr)}); - - stmt->accept(&analyzer); - - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - - ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore)); - ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node())); - } -} - -// Stmts using IfThenElse. -TEST(MemDependency, MemDependencyCheckerIfThenElse) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1; - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Add::make(Load::make(b, {0}), 1), - Add::make(Load::make(b, {1}), 1))); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - ifStore}); - - stmt->accept(&analyzer); - - // Output C should have 2 dependencies, each of the two stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - - // Now we need to check the Store containing the IfThenElse. - auto ifStoreAccess = analyzer.accessFor(ifStore); - - // It should have 2 dependencies. - ASSERT_EQ(ifStoreAccess->dependencies().size(), 2); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * C[0] = (y < 5 ? (B[0]) + 1 : 42; - */ - - // If the load appears in only one side of an IfThenElse the output may be - // dependent on it. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Add::make(Load::make(b, {0}), 1), - 42)); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - ifStore}); - - stmt->accept(&analyzer); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = (x < 5 ? B[x] : A[x]; - * } - */ - - // In this case C is dependent on both A and B. - - // TODO: in cases like this it would be possible to split the range of B - // into two bounds, one dependent on A and one dependent on B. We'd need to - // examine conditions relative to previously encountered loop variables. I'm - // uncertain if this would be helpful. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Load::make(b, {x}), - Load::make(a, {x}))); - StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)}); - - stmt->accept(&analyzer); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } -} - -// Cutting a loop with single elem writes -TEST(MemDependency, MemDependencyCheckerCutLoop) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - * B[5] = 100; - */ - - // Cutting a loop with single element writes. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))), - Store::make(b, {5}, 100)}); - - stmt->accept(&analyzer); - - // Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 2 dependencies. - auto outputAccess = analyzer.output(b.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - } - - { - /* for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - * for (int x = 4; x < 7; x++) { - * B[x] = B[x] + 3; - * } - * B[5] = 100; - * B[6] = 101; - * B[7] = 102; - */ - - // Cutting a loop with a smaller loop but then totally overlap that second - // loop with one element writes. - - MemDependencyChecker analyzer({a}, {b}); - ForPtr firstLoop = - For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))); - StorePtr secondStore = - Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); - ForPtr secondLoop = For::make(x, 4, 7, secondStore); - - StmtPtr stmt = Block::make( - {firstLoop, - secondLoop, - Store::make(b, {4}, 100), - Store::make(b, {5}, 101), - Store::make(b, {6}, 102)}); - - stmt->accept(&analyzer); - - // Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 4 dependencies. - auto outputAccess = analyzer.output(b.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 4); - - // Second loop depends on first loop. - ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop)); - - // Output does not depend on second loop or store. - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop)); - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore)); - } -} - -// Dynamic shapes (load in indices). -TEST(MemDependency, MemDependencyCheckerDynamicShapes) { - BufHandle a("A", {100}, kInt); - BufHandle b("B", {100}, kInt); - BufHandle c("C", {100}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - { - /* for (int x = 0; x < B[0]; x++) { - * C[x] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 - * 1. Input: A[(0, 99)] - dependents: 3 - * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4 - * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4 - * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Output dependent on A input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - // Also dependent on B input to determine the size of the region written. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The accesses in the loop depend on the load in the stop condition. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // Make a load from B to compare against. - ExprHandle loadFromB = Load::make(b, {0}); - - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)})); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)})); - } - - { - /* for (int x = B[0]; x < B[1]; x++) { - * C[x] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, - Load::make(b, {0}), - Load::make(b, {1}), - Store::make(c, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 3 - * 1. Input: A[(0, 99)] - dependents: 4 - * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5 - * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5 - * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5 - * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6 - * 6. Output: C[(0, 99)] - depends on: 5 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 7); - - // The accesses in the loop depend on the load in the start condition. - ASSERT_TRUE(history[5]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[2])); - - // also the stop condition. - ASSERT_TRUE(history[5]->hasDependency(history[3])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // Make loads from B to compare against. - ExprHandle loadFromB0 = Load::make(b, {0}); - ExprHandle loadFromB1 = Load::make(b, {1}); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[B[x]]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 - * 1. Input: A[(0, 99)] - dependents: 3 - * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4 - * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4 - * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads, the load of A depends on the load of B. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The loads in the indices depend on the relevant input buffer. - ASSERT_TRUE(history[3]->hasDependency(history[1])); - ASSERT_TRUE(history[2]->hasDependency(history[0])); - - // The load from B has the loop bounds. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - - // The load from A has bounds B[0] to B[9]. - ExprHandle loadFromB0 = Load::make(b, {0}); - ExprHandle loadFromB9 = Load::make(b, {9}); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[B[x]] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 3 - * 1. Input: A[(0, 99)] - dependents: 2 - * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4 - * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4 - * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads, neither load is dependent. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - ASSERT_FALSE(history[3]->hasDependency(history[2])); - ASSERT_FALSE(history[2]->hasDependency(history[3])); - - // The loads each depend on their relevant input. (but accesses are in a - // different order than the last case). - ASSERT_TRUE(history[3]->hasDependency(history[0])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The load from B has the loop bounds. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)})); - - // And so does the load from A. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[B[A[x]]] = x; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 3 - * 1. Input: A[(0, 99)] - dependents: 2 - * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4 - * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4 - * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // The outer load depends on the inner. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The loads each depend on their relevant input. (but accesses are in a - // different order than the last case). - ASSERT_TRUE(history[3]->hasDependency(history[0])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The load from A has the loop bounds. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - // The load from B as bounds A[0] to A[9]. - ExprHandle loadFromA0 = Load::make(a, {0}); - ExprHandle loadFromA9 = Load::make(a, {9}); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)})); - - // The store has bounds of B[A[0]] to B[A[9]]. - ExprHandle loadFromBA0 = Load::make(b, {loadFromA0}); - ExprHandle loadFromBA9 = Load::make(b, {loadFromA9}); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)})); - } -} - -// Verify multi dimensional bounds work. -TEST(MemDependency, MemDependencyCheckerMultiDim) { - int M = 10, N = 9, K = 12; - BufHandle a("A", {M, N, K}, kInt); - BufHandle b("B", {M, N, K}, kInt); - BufHandle c("C", {M, K}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 9; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, y, z] = A[x, y, z]; - * } - * } - * } - */ - // Full range. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - N, - For::make( - z, - 0, - K, - Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 5; x++) { - * for (int y = 0; y < 5; y++) { - * for (int z = 0; z < 5; z++) { - * B[x, y, z] = A[x, y, z]; - * } - * } - * } - */ - // Partial range. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - For::make( - z, - 0, - 5, - Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 12; y++) { - * B[x, 0, y] = A[x, 0, y]; - * } - * } - */ - - // Partial loops. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - N, - For::make( - y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 100; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]); - * } - * } - * } - */ - - // Loops that don't correspond to an index, bufs with different - // dimensionality. - - MemDependencyChecker analyzer({a, c}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - 100, - For::make( - z, - 0, - K, - Store::make( - b, - {x, 0, z}, - Add::make( - Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on both inputs. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node())); - - // 6 accesses: 2 inputs, 2 loads, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // Simple chain from input to output over the A buf. - // history[0] is the C input, history[3] is the load from C. - ASSERT_TRUE(history[5]->hasDependency(history[4])); - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - // The store also depends on the load from the C input. - ASSERT_TRUE(history[4]->hasDependency(history[3])); - ASSERT_TRUE(history[3]->hasDependency(history[0])); - - // A Buf accesses. - ASSERT_TRUE( - EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); - - // C buf access. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 9; x++) { - * for (int y = 0; y < 10; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]); - * } - * } - * } - */ - // Multi-dim reductions. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - N, - For::make( - z, - 0, - K, - Store::make( - b, - {x, 0, 0}, - Add::make( - Load::make(b, {x, y, z}), - Load::make(a, {x, y, z}))))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, 2 loads, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 5); - - // Simple chain from input to output. - ASSERT_TRUE(history[4]->hasDependency(history[3])); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[3]->hasDependency(history[1])); - ASSERT_TRUE(history[2]->hasDependency(history[0])); - - // The load from B depends on the store to B. - ASSERT_TRUE(history[1]->hasDependency(history[3])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)})); - } -} - -// Various tests using the external Compute/Reduce API. -TEST(MemDependency, MemDependencyCheckerComputeAPI) { - using namespace analysis; - - /* for (int m = 0; m < 4; m++) { - * for (int n = 0; n < 5; n++) { - * for (int k = 0; k < 6; k++) { - * broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]); - * } - * } - * } - * for (int m_1 = 0; m_1 < 4; m_1++) { - * for (int n_1 = 0; n_1 < 5; n_1++) { - * for (int k_1 = 0; k_1 < 6; k_1++) { - * d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1); - * } - * } - * } - */ - - // Can determine if 2 loops created by Compute are dependent. - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - - MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); - - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); - - // Second loop depends on first loop. - auto c_loop = l.getLoopStmtsFor(c)[0]; - auto d_loop = l.getLoopStmtsFor(d)[0]; - ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); -} - -TEST(MemDependency, MemDependencyCheckerComputeInline) { - using namespace analysis; - - /* for (int m = 0; m < 4; m++) { - * for (int n = 0; n < 5; n++) { - * for (int k = 0; k < 6; k++) { - * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1); - * } - * } - * } - */ - - // Check inlining affects the number of accesses returned. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - l.computeInline(c.buf()); - - MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); - - // broadcast_add tensor should not appear in trace at all. - for (auto& wi : analyzer.getHistory()) { - ASSERT_NE(wi->var(), c.buf()->base_handle()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeSplit) { - using namespace analysis; - // Split an axis, so the number of loops != the number of dimensions. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - - LoopNest l({c}); - - MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); - l.root_stmt()->accept(&analyzer_before); - - l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); - - MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - stmt->accept(&analyzer_after); - - // Splitting should not change accesses at all. - auto history_before = analyzer_before.getHistory(); - auto history_after = analyzer_after.getHistory(); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - ASSERT_EQ( - history_before[i]->bounds().size(), history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - ASSERT_EQ( - history_before[i]->dependencies().size(), - history_after[i]->dependencies().size()); - ASSERT_EQ( - history_before[i]->dependents().size(), - history_after[i]->dependents().size()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeReorder) { - using namespace analysis; - // Reorder an axis, so the loop order doesn't match the indexing order. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - - LoopNest l({c}); - - MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); - l.root_stmt()->accept(&analyzer_before); - - auto loops = l.getLoopStmtsFor(c); - l.reorderAxis(loops[0], loops[1]); - - MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - stmt->accept(&analyzer_after); - - // Reordering should not change accesses at all. - auto history_before = analyzer_before.getHistory(); - auto history_after = analyzer_after.getHistory(); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - ASSERT_EQ( - history_before[i]->bounds().size(), history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - ASSERT_EQ( - history_before[i]->dependencies().size(), - history_after[i]->dependencies().size()); - ASSERT_EQ( - history_before[i]->dependents().size(), - history_after[i]->dependents().size()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeReduce) { - using namespace analysis; - /* for (int l2 = 0; l2 < 2; l2++) { - * for (int n1 = 0; n1 < 3; n1++) { - * for (int m1 = 0; m1 < 6; m1++) { - * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]); - * } - * } - * } - * for (int l1 = 0; l1 < 2; l1++) { - * sum[l1] = float(0); - * for (int n1_1 = 0; n1_1 < 3; n1_1++) { - * for (int m1_1 = 0; m1_1 < 6; m1_1++) { - * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)), - * out_args={l1}, reduce_args={n1, m1}); - * } - * } - * } - */ - - // Can determine dependencies of a Reduction. - - BufHandle a("a", {2, 3, 6}, kFloat); - BufHandle b("b", {2, 3, 6}, kFloat); - - Tensor c = Compute( - "scale", - {2, 3, 6}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6}); - LoopNest l({d}, {c, d}); - - MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()}); - - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node())); - - // Second loop depends on first loop. - auto c_loop = l.getLoopStmtsFor(c)[0]; - auto d_loop = l.getLoopStmtsFor(d)[0]; - ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); - - // Reduction depends on both inputs. - auto reduces = NodeFinder::find(l.root_stmt()); - ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node())); -} - -TEST(MemDependency, MemDependencyCheckerComputeGEMM) { - int M = 1024; - int N = 1024; - int K = 2048; - using namespace analysis; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - LoopNest loop({CT}); - - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr m = loops[0]; - loop.splitWithMask(m, 4); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr n = loops[2]; - loop.splitWithMask(n, 16); - } - // mo, mi, no, ni, k -> - // mo, no, mi, ni, k - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[1]; - ForPtr no = loops[2]; - loop.reorderAxis(mi, no); - } - // mo, no, mi, ni, k -> - // mo, no, mi, k, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr ni = loops[3]; - ForPtr k = loops[4]; - loop.reorderAxis(ni, k); - } - // mo, no, mi, k, ni -> - // mo, no, k, mi, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[2]; - ForPtr k = loops[3]; - loop.reorderAxis(mi, k); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); - } - - MemDependencyChecker analyzer_unlowered( - loop.getInputBufs(), loop.getOutputBufs()); - - MemDependencyChecker analyzer_lowered( - loop.getInputBufs(), loop.getOutputBufs()); - - // Test both unlowered and lowered form. - { - StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); - stmt->accept(&analyzer_unlowered); - - // Outputs depend on inputs. - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node())); - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node())); - - // The last write to gemm should cover the total bound of the output. - std::shared_ptr outputAccess = - analyzer_unlowered.output(CT.buf()); - // A single dependency. - ASSERT_EQ(outputAccess->dependencies().size(), 1); - - // dependencies is a set with 1 element, so can just deref begin(). - std::shared_ptr gemmStore = - outputAccess->dependencies().begin()->second; - // Check its a store. - ASSERT_EQ(gemmStore->type(), AccessType::Store); - - ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds())); - - // Likewise the first read from each input cover the entire range of the - // input. - auto aInput = analyzer_unlowered.input(AP.node()); - auto bInput = analyzer_unlowered.input(BP.node()); - - // A single dependent each. - ASSERT_EQ(aInput->dependents().size(), 1); - ASSERT_EQ(bInput->dependents().size(), 1); - - // They're both loads. - std::shared_ptr aLoad = aInput->dependents().begin()->second; - std::shared_ptr bLoad = bInput->dependents().begin()->second; - ASSERT_EQ(aLoad->type(), AccessType::Load); - ASSERT_EQ(bLoad->type(), AccessType::Load); - - ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds())); - ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds())); - } - - loop.prepareForCodegen(); - SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT}); - - // now check lowered dependency graph. - { - StmtPtr stmt = IRSimplifier::simplify(cg.stmt()); - stmt->accept(&analyzer_lowered); - - // Lowering will change the dimensionality of all bounds due to index - // flattening and will insert Allocates and Frees. - - auto history_before = analyzer_unlowered.getHistory(); - auto history_after = analyzer_lowered.getHistory(); - - ASSERT_EQ(history_before.size() + 2, history_after.size()); - - // Filter out the alloc/free; - auto isAllocFree = [](const auto& info) { - return info->type() == AccessType::Alloc || - info->type() == AccessType::Free; - }; - history_after.erase( - std::remove_if(history_after.begin(), history_after.end(), isAllocFree), - history_after.end()); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - - if (history_before[i]->dependencies().size() != - history_after[i]->dependencies().size()) { - // Must depend on an Alloc. - ASSERT_TRUE(std::any_of( - history_after[i]->dependencies().begin(), - history_after[i]->dependencies().end(), - [](const auto& pair) { - return pair.second->type() == AccessType::Alloc; - })); - - ASSERT_EQ( - history_before[i]->dependencies().size() + 1, - history_after[i]->dependencies().size()); - } - - if (history_before[i]->dependents().size() != - history_after[i]->dependents().size()) { - // Must depend on an Free. - ASSERT_TRUE(std::any_of( - history_after[i]->dependents().begin(), - history_after[i]->dependents().end(), - [](const auto& pair) { - return pair.second->type() == AccessType::Free; - })); - - ASSERT_EQ( - history_before[i]->dependents().size() + 1, - history_after[i]->dependents().size()); - } - - // Inputs and outputs are not flattened, only accesses. - if (history_before[i]->type() == AccessType::Input || - history_before[i]->type() == AccessType::Output) { - ASSERT_EQ( - history_before[i]->bounds().size(), - history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - } else { - ASSERT_EQ(history_after[i]->bounds().size(), 1); - ExprPtr flat_bounds = alloc(1); - - for (auto& b : history_before[i]->bounds()) { - flat_bounds = - alloc(flat_bounds, alloc(b.end, alloc(1))); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); - } - - flat_bounds = IRSimplifier::simplify(flat_bounds); - ExprPtr after_bounds = IRSimplifier::simplify( - alloc(history_after[i]->bounds()[0].end, alloc(1))); - ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); - } - } - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_memplanning.cpp b/test/cpp/tensorexpr/test_memplanning.cpp deleted file mode 100644 index f5ee8747650fc..0000000000000 --- a/test/cpp/tensorexpr/test_memplanning.cpp +++ /dev/null @@ -1,708 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -extern void checkIR(StmtPtr s, const std::string& pattern); - -TEST(BufLiveRange, SingleRangeLine) { - VarHandle i("i", kInt), j("j", kInt); - BufHandle a("a", {32}, kFloat); - BufHandle b("b", {32, 32}, kFloat); - - // Construct Stmt: - // { - // for (int i = 0; i < 32; i++) { - // a[i] = 0; - // for (int j = 0; j < 32; j++) { - // a[i] = (a[i]) + (b[i, j]); - // } - // } - // } - - StorePtr aInit = Store::make(a, {i}, 0); - ExprHandle reduce = a.load({i}) + b.load({i, j}); - StorePtr aReduce = Store::make(a, {i}, reduce); - StmtPtr loop = - For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)})); - - StmtPtr stmt = Block::make({loop}); - - auto range = BufLiveRange::liveRange(stmt, a.node()); - ASSERT_TRUE(std::get<0>(range) == 0); - ASSERT_TRUE(std::get<1>(range) == 0); -} - -TEST(BufLiveRange, MulRangeLine) { - VarHandle i("i", kInt); - BufHandle a("a", {32}, kFloat); - BufHandle b("b", {32}, kFloat); - - // Construct Stmt: - // { - // for (int i = 0; i < 32; i++) { - // if (i<10 ? 1 : 0) { - // a[i] = i + i; - // b[i] = i * i; - // } - // } - // for (int i = 0; i < 32; i++) { - // if (i>10 ? 1 : 0) { - // a[i] = i * i; - // b[i] = i + i; - // } - // } - // } - - StorePtr aStore_1 = Store::make(a, {i}, i + i); - StorePtr bStore_1 = Store::make(b, {i}, i * i); - StmtPtr loop_1 = For::make( - i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL)); - - StorePtr aStore_2 = Store::make(a, {i}, i * i); - StorePtr bStore_2 = Store::make(b, {i}, i + i); - StmtPtr loop_2 = For::make( - i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL)); - - StmtPtr stmt = Block::make({loop_1, loop_2}); - - auto range_a = BufLiveRange::liveRange(stmt, a.node()); - ASSERT_TRUE(std::get<0>(range_a) == 0); - ASSERT_TRUE(std::get<1>(range_a) == 1); - - auto range_b = BufLiveRange::liveRange(stmt, b.node()); - ASSERT_TRUE(std::get<0>(range_b) == 0); - ASSERT_TRUE(std::get<1>(range_b) == 1); -} - -TEST(MemPlanning, MemReuseWithTypeCast) { - int M = 4; - int N = 4; - int K = 4; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return CompareSelect::make( - CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n)); - }); - Tensor FT = - Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n); - }); - StmtPtr stmt = - tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are - // different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for 'E' - // with typecasting. - //{ - // for (int i = 0; i < 4; i++) { - // for (int i_1 = 0; i_1 < 4; i_1++) { - // gemm[i, i_1] = float(0); - // for (int i_2 = 0; i_2 < 4; i_2++) { - // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, - // i_1]), reduce_args={i_2}); - // } - // } - // } - // for (int i_3 = 0; i_3 < 4; i_3++) { - // for (int i_4 = 0; i_4 < 4; i_4++) { - // relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]); - // } - // } - // for (int i_5 = 0; i_5 < 4; i_5++) { - // for (int i_6 = 0; i_6 < 4; i_6++) { - // E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6])); - // } - // } - // for (int i_7 = 0; i_7 < 4; i_7++) { - // for (int i_8 = 0; i_8 < 4; i_8++) { - // F[i_7, i_8] = E[i_7, i_8]; - // } - // } - //} - - LoopNest l(stmt, {FT.buf()}); - l.prepareForCodegen(); - SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] -# CHECK: Alias(E,gemm); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - PaddedBuffer a_v(M, K, "a"); - PaddedBuffer b_v(K, N, "b"); - PaddedBuffer o1(M, N, "e_before"); - PaddedBuffer o2(M, N, "e_after"); - - for (const auto m : c10::irange(M)) { - for (const auto k : c10::irange(K)) { - a_v(m, k) = at::randn({1}).item().to(); - } - } - - for (const auto k : c10::irange(K)) { - for (const auto n : c10::irange(N)) { - b_v(k, n) = at::randn({1}).item().to(); - } - } - - cg.call({a_v, b_v, o1}); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg_llvm.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] -# CHECK: Alias(E,gemm); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - cg_llvm.call({a_v, b_v, o2}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(o1, o2, 1e-5); -#endif -} - -TEST(MemPlanning, NoMemReuseForLargerType) { - int M = 4; - int N = 4; - int K = 4; - - BufHandle AP("A", {M, K}, kShort); - BufHandle BP("B", {K, N}, kShort); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - auto zero = Cast::make(CT.buf()->dtype(), 0); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n)); - }); - Tensor FT = - Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n); - }); - StmtPtr stmt = - tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are - // different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm' for - // 'E'. - //{ - // for (int i = 0; i < 4; i++) { - // for (int i_1 = 0; i_1 < 4; i_1++) { - // gemm[i, i_1] = int16_t(0); - // for (int i_2 = 0; i_2 < 4; i_2++) { - // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, - // i_1]), reduce_args={i_2}); - // } - // } - // } - // for (int i_3 = 0; i_3 < 4; i_3++) { - // for (int i_4 = 0; i_4 < 4; i_4++) { - // relu[i_3, i_4] = (gemm[i_3, i_4]) a_v(M, K, "a"); - PaddedBuffer b_v(K, N, "b"); - PaddedBuffer o1(M, N, "e_before"); - PaddedBuffer o2(M, N, "e_after"); - - for (const auto m : c10::irange(M)) { - for (const auto k : c10::irange(K)) { - a_v(m, k) = at::randn({1}).item().to(); - } - } - - for (const auto k : c10::irange(K)) { - for (const auto n : c10::irange(N)) { - b_v(k, n) = at::randn({1}).item().to(); - } - } - - cg.call({a_v, b_v, o1}); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg_llvm.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4] -# CHECK: Allocate(E); // dtype=float, dims=[4, 4] -# CHECK: Free(E); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - cg_llvm.call({a_v, b_v, o2}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(o1, o2, 1e-5); -#endif -} - -TEST(MemPlanning, SameBufSizeMemReuse) { - int M = 1024; - int N = 1024; - int K = 2048; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - auto zero = Cast::make(CT.buf()->dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3] Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm' - // for 'add'. - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - Tensor GT = - Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return FT.load(m, n) - ET.load(m, n); - }); - - auto stmt = - Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3], mul [3, 4] Buffer 'gemm', 'relu, ''add' and 'mul' are the same - // size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul' - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - Tensor GT = - Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return FT.load(m, n) - 1; - }); - Tensor HT = - Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return GT.load(m, n) / 2; - }); - - auto stmt = Block::make( - {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3], mul [3, 4], sub [4, 5] Buffer 'gemm', 'relu, ''add', 'mul' and - // 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu' for - // 'mul', and reuse 'gemm' for 'sub'. - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = Compute( - "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) { - return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2); - }); - Tensor FT = Compute( - "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) { - return ET.load(fm, fn) * ET.load(fm, fn); - }); - auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3] We do not reuse buffer 'gemm' for 'add' because the size of - // buffer 'gemm' is smaller. - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1]) -#include -#include -#include -#include -#include - -using namespace torch::jit::tensorexpr; - -using Tensors = std::vector; -using Args = std::vector; -std::unique_ptr compile( - const Args& inputs, - const Tensors& outputs) { - LoopNest nest({outputs}); - nest.prepareForCodegen(); - nest.simplify(); - auto join = inputs; - join.insert(join.end(), outputs.begin(), outputs.end()); - return std::make_unique(nest.root_stmt(), join); -} - -TEST(Ops, Sum) { - constexpr int M = 8; - constexpr int N = 16; - std::vector testDims = {{0}, {1}, {0, 1}}; - std::vector> outputShapes = {{N}, {M}, {}}; - for (unsigned idx = 0; idx < testDims.size(); idx++) { - const auto& dims = testDims[idx]; - const auto& outShape = outputShapes[idx]; - - BufHandle a("a", {M, N}, kFloat); - std::vector outStrides = - c10::fmap(make_contiguous_strides(outShape)); - Tensor b = computeSum( - {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); - auto cg = compile({a}, {b}); - - auto at = at::arange(M * N, at::kFloat).view({M, N}); - auto ref = at::sum(at, dims); - auto bt = at::empty_like(ref); - - cg->call({at.data_ptr(), bt.data_ptr()}); - - ASSERT_TRUE(at::allclose(bt, ref)); - } -} - -TEST(Ops, ChannelsLastSum) { - constexpr int A = 2; - constexpr int B = 3; - constexpr int C = 4; - constexpr int D = 5; - constexpr int E = 6; - std::vector testDims = {{0}, {1}, {0, 1}}; - - std::vector> outputShapes = { - {B, C, D, E}, {A, C, D, E}, {C, D, E}}; - for (unsigned idx = 0; idx < testDims.size(); idx++) { - const auto& dims = testDims[idx]; - const auto& outShape = outputShapes[idx]; - - BufHandle a("a", {A, B, C, D, E}, kFloat); - std::vector outStrides = - c10::fmap(make_channels_last_strides(outShape)); - Tensor b = computeSum( - {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); - auto cg = compile({a}, {b}); - - auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E}); - auto ref = at::sum(at, dims); - auto bt = at::empty_like(ref); - - cg->call({at.data_ptr(), bt.data_ptr()}); - - ASSERT_TRUE(at::allclose(bt, ref)); - } -} diff --git a/test/cpp/tensorexpr/test_quantization.cpp b/test/cpp/tensorexpr/test_quantization.cpp deleted file mode 100644 index af6b539ff33e9..0000000000000 --- a/test/cpp/tensorexpr/test_quantization.cpp +++ /dev/null @@ -1,452 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir.h" - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; -using SimpleIRExprEval = ExprEval; -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -class Quantization : public ::testing::Test { - public: - void SetUp() override { - getTEMustUseLLVMOnCPU() = false; - } -}; - -TEST_F(Quantization, QuantDequantInt8) { - const auto graph_string = R"IR( - graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=12]() - %3 : int = prim::Constant[value=13]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantDequantUInt8) { - const auto graph_string = R"IR( - graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %3 : int = prim::Constant[value=122]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantDequantUInt8_NLC) { - const auto graph_string = R"IR( - graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)): - %2 : int = prim::Constant[value=13]() - %3 : int = prim::Constant[value=122]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(1, 2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - x.unsafeGetTensorImpl()->set_sizes_and_strides( - std::initializer_list{1, 2, 2}, {4, 1, 2}); - auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_add( - at::Tensor x1, - at::Tensor x2, - double scale, - int64_t zero) { - const auto qadd_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::add", "") - .typed(); - return qadd_op.call(x1, x2, scale, zero); -} - -TEST_F(Quantization, QuantAddDequantInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=12]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8); - auto qa = quantized_add(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantAddDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); - auto qa = quantized_add(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantSigmoidDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %qa : QUInt8(2, 2) = aten::sigmoid(%q1) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto qs = at::sigmoid(q1); - auto y_expected = at::dequantize(qs); - - TensorExprKernel k(graph); - std::vector inputs = {x1}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "qs:\n" << qs << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_mul( - at::Tensor x1, - at::Tensor x2, - double scale, - int64_t zero) { - const auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::mul", "") - .typed(); - return op.call(x1, x2, scale, zero); -} - -TEST_F(Quantization, QuantMulDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); - auto qa = quantized_mul(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %4 : NoneType = prim::Constant() - %3 : int[] = prim::Constant[value=[6, 6]]() - %qz : int = prim::Constant[value=13]() - %qs : float = prim::Constant[value=0.1]() - %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2) - %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4) - %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); - auto qu = at::upsample_nearest2d(q, {6, 6}); - auto y_expected = at::dequantize(qu); - - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "q:\n" << q << std::endl; - std::cout << "qu:\n" << qu << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, UpsampleNearst2d) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): - %4 : NoneType = prim::Constant() - %3 : int[] = prim::Constant[value=[4, 4]]() - %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4) - return (%u))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y_expected = at::upsample_nearest2d(x, {4, 4}); - - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_cat( - c10::List const& xs, - int64_t dim, - double scale, - int64_t zero) { - const auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::cat", "") - .typed const&, - int64_t, - std::optional, - std::optional)>(); - return op.redispatch( - DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero); -} - -TEST_F(Quantization, QuantCatDequantUInt8) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): - %qdt : int = prim::Constant[value=13]() - %qxz : int = prim::Constant[value=13]() - %qxs : float = prim::Constant[value=0.1]() - %qyz : int = prim::Constant[value=16]() - %qys : float = prim::Constant[value=0.15]() - %qzz : int = prim::Constant[value=19]() - %qzs : float = prim::Constant[value=0.2]() - %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt) - %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt) - %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt) - %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz) - %catd : int = prim::Constant[value=0]() - %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz) - %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat) - return (%cat))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); - auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8); - auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8); - auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13); - auto expected = at::dequantize(qcat); - - TensorExprKernel k(graph); - std::vector inputs = {x, y, z}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto result = stack[0].toTensor(); - bool check = at::allclose(expected, result); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y:\n" << y << std::endl; - std::cout << "z:\n" << z << std::endl; - std::cout << "qx:\n" << qx << std::endl; - std::cout << "qy:\n" << qy << std::endl; - std::cout << "qz:\n" << qz << std::endl; - std::cout << "qcat:\n" << qcat << std::endl; - std::cout << "expected:\n" << expected << std::endl; - std::cout << "result:\n" << result << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp deleted file mode 100644 index fb83ab85b71ed..0000000000000 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ /dev/null @@ -1,1928 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(Reductions, ReduceSum0D_1) { - const int M = 10; - - BufHandle b("b", {M}, kFloat); - std::vector in(M); - for (const auto j : c10::irange(M)) { - in[j] = j; - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], in[i]); - } -} - -TEST(Reductions, ReduceSum0D_2) { - BufHandle b("b", {}, kFloat); - std::vector in(1); - in[0] = 77.7; - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], in[0]); -} - -// Sum an array to a single value. -TEST(Reductions, ReduceSum1D) { - BufHandle b("b", {10}, kFloat); - std::vector in(10); - for (const auto j : c10::irange(10)) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {10}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 45); -} -// Sum a 2D tensor to a 1D tensor with dynamic shapes. -TEST(Reductions, ReduceSum2D) { - const int M = 3; - const int N = 7; - - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle b("b", {m, n}, kFloat); - std::vector in(M * N); - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - in[i * N + j] = j; - } - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, n, m}); - - cg.call({in, out, 5, 7}); - - float expected = 0; - for (const auto i : c10::irange(N)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], expected); - } -} - -// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to -// check our work. -TEST(Reductions, ReduceSum3D) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m}); - - std::vector bData(2 * 3 * M, 0); - std::vector cData(2 * 3, 6.0f); - std::vector dData(2, 1.0f); - std::vector eData(2, 1.0f); - - for (int i = 0; i < 2 * 3; ++i) { - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j; - } - } - - cg.call({bData, cData, M}); - float expected = 0; - for (const auto i : c10::irange(M)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - - for (int i = 0; i < 2 * 3; ++i) { - ASSERT_EQ(cData[i], expected); - } - - Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m}); - LoopNest loop2({d}); - loop2.prepareForCodegen(); - StmtPtr s2 = loop2.root_stmt(); - s2 = IRSimplifier::simplify(s2); - - SimpleIREvaluator cg2(s2, {b, d, m}); - cg2.call({bData, dData, M}); - - // We're combining an additional dimension of 3, so the sum is 3x. - expected = expected * 3; - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(dData[i], expected); - } - - // This is the same as just reducing the original result across that axis. - BufHandle c_buf(c.buf()); - Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3}); - LoopNest loop3({e}); - loop3.prepareForCodegen(); - StmtPtr s3 = loop3.root_stmt(); - s3 = IRSimplifier::simplify(s3); - - SimpleIREvaluator cg3(s3, {c, e}); - cg3.call({cData, eData}); - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(eData[i], expected); - } -} - -// Sum a large (10 D) Tensor 5 dimensions in. -TEST(Reductions, ReduceSum10D) { - BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat); - const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; - BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat); - const int OutputSize = 2 * 3 * 2 * 3 * 2; - - std::vector in(InputSize, 1.f); - std::vector out(OutputSize, -1.f); - - Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, c}); - - cg.call({in, out}); - - // NOLINTNEXTLINE(bugprone-integer-division) - float expected = InputSize / OutputSize; - for (const auto i : c10::irange(OutputSize)) { - ASSERT_EQ(out[i], expected); - } -} - -// Reduce via Mul rather than Add using a custom Reducer. -TEST(Reductions, ReduceProduct) { - const int M = 4; - const int N = 4; - - BufHandle b("b", {M, N}, kFloat); - std::vector in(M * N); - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - in[i * N + j] = 2 + j; - } - } - - std::vector out(M, -1.f); - - Reducer product( - ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; }); - - Tensor c = Reduce("product", {M}, product, b, {N}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - - float expected = 1; - for (const auto i : c10::irange(N)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected *= 2 + i; - } - - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], expected); - } -} - -// Maximum reductions. -TEST(Reductions, ReduceMax) { - BufHandle in_("b", {10}, kFloat); - - std::vector in(10); - std::vector out(1, -1.f); - for (const auto j : c10::irange(10)) { - in[j] = j; - } - - Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10}); - - LoopNest loop({dm1}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - SimpleIREvaluator cg(s, {in_, dm1}); - - cg.call({in, out}); - - ASSERT_EQ(out[0], 9); - - BufHandle in2_("b", {2, 5}, kFloat); - std::vector out2(2, -1.f); - - Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5}); - - LoopNest loop2({m2d}); - loop2.prepareForCodegen(); - s = loop2.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg2(s, {in2_, m2d}); - cg2.call({in, out2}); - - ASSERT_EQ(out2[0], 4); - ASSERT_EQ(out2[1], 9); -} - -// Minimum reduction, with custom initialization. -TEST(Reductions, ReduceMinCustomInitializer) { - VarHandle minInit("minInit", kFloat); - BufHandle in_("b", {10}, kFloat); - - std::vector in(10); - std::vector out(1, -1.f); - for (const auto j : c10::irange(10)) { - in[j] = 10 + j; - } - - Tensor min = Reduce( - "min", - {}, - Minimum(ExprHandle(minInit)), - [&](ParameterList& v) { return in_.load(v); }, - {10}); - - LoopNest loop({min}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, min, minInit}); - - // Works normally (note that out data starts lower than the correct - // minimum). - cg.call({in, out, std::numeric_limits::max()}); - ASSERT_EQ(out[0], 10); - - // With an initializer lower than the min, that's the min. - cg.call({in, out, 5.f}); - ASSERT_EQ(out[0], 5); -} - -// Example implementation of Any/All. -// TODO: this is very awkward without logical And/Or operators. -TEST(Reductions, ReduceAnyAll) { - VarHandle searchValue("searchValue", kInt); - BufHandle b("b", {4, 10}, kInt); - - Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) { - return CompareSelect::make(a, 1, 1, b, kEQ); - }); - - Tensor any = Reduce( - "anyEqual", - {4}, - anyEqSV, - [&](const auto& i, const auto& j) { - return CompareSelect::make(b.load(i, j), searchValue, kEQ); - }, - {10}); - - LoopNest loop({any}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, any, searchValue}); - - std::vector in(40, 0); - std::vector out(4, 0); - - // input has 0-39 in 4 rows. - for (const auto i : c10::irange(40)) { - in[i] = i; - } - cg.call({in, out, 1}); - - // only the first row has 1 - ASSERT_EQ(out[0], 1); - ASSERT_EQ(out[1], 0); - ASSERT_EQ(out[2], 0); - ASSERT_EQ(out[3], 0); - - cg.call({in, out, 15}); - - // 15 in the 3rd row - ASSERT_EQ(out[0], 0); - ASSERT_EQ(out[1], 1); - ASSERT_EQ(out[2], 0); - ASSERT_EQ(out[3], 0); - - Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) { - return CompareSelect::make(a, 0, 0, b, kEQ); - }); - - Tensor allGreaterThan = Reduce( - "allGreaterThan", - {4}, - allGTSV, - [&](const auto& i, const auto& j) { - return CompareSelect::make(b.load(i, j), searchValue, kGT); - }, - {10}); - - LoopNest loop2({allGreaterThan}); - loop2.prepareForCodegen(); - s = loop2.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue}); - - cg2.call({in, out, 11}); - - // 11 is in row 2. - ASSERT_EQ(out[0], 0); - ASSERT_EQ(out[1], 0); - ASSERT_EQ(out[2], 1); - ASSERT_EQ(out[3], 1); - - cg2.call({in, out, -3}); - - // All are positive. - ASSERT_EQ(out[0], 1); - ASSERT_EQ(out[1], 1); - ASSERT_EQ(out[2], 1); - ASSERT_EQ(out[3], 1); -} - -TEST(Reductions, ReduceMatmul2D) { - BufHandle tA("tA", {3, 2}, kFloat); - BufHandle tB("tB", {2, 3}, kFloat); - - std::vector tA_(6); - std::vector tB_(6); - - std::vector out(9, -1.f); - for (const auto i : c10::irange(3)) { - for (const auto j : c10::irange(2)) { - tA_[i * 2 + j] = i * 2 + j; - tB_[j * 3 + i] = i * 2 + j; - } - } - - Tensor mm = Reduce( - "mm", - {3, 3}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return tA.load(m, k) * tB.load(k, n); - }, - {2}); - - LoopNest loop({mm}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {tA, tB, mm}); - cg.call({tA_, tB_, out}); - - std::vector expected( - {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f}); - - for (const auto i : c10::irange(9)) { - ASSERT_EQ(out[i], expected[i]); - } -} - -TEST(Reductions, ReduceRfactorLike) { - BufHandle in("in", {10, 10}, kFloat); - std::vector in_(100); - for (const auto i : c10::irange(100)) { - in_[i] = i; - } - std::vector in_rf_(10, -2.f); - std::vector out(1, -1.f); - - Tensor l1 = Reduce("l1", {10}, Sum(), in, {10}); - BufHandle in_rf(l1.buf()); - - Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10}); - - LoopNest loop({l1, l2}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, l1, l2}); - cg.call({in_, in_rf_, out}); - - ASSERT_EQ(out[0], 99 * 50); -} - -TEST(Reductions, ReduceAsProducer) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle a("a", {2, 3}, kFloat); - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); - Tensor d = - Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) { - return c.load(l, n) * a.load(l, n); - }); - LoopNest loop({d}, {c, d}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {a, b, d, m}); - - std::vector aData(2 * 3, 0); - std::vector bData(2 * 3 * M, 0); - std::vector dData(2 * 3, 6.0f); - - for (int i = 0; i < 2 * 3; ++i) { - aData[i] = 6 - i; - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j; - } - } - - cg.call({aData, bData, dData, M}); - float expected = 0; - for (const auto i : c10::irange(M)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - for (int i = 0; i < 2 * 3; ++i) { - ASSERT_EQ(dData[i], expected * (6 - i)); - } -} - -TEST(Reductions, ReduceAsConsumer) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle a("a", {2, 3, m}, kFloat); - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Compute( - "scale", - {2, 3, m}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {2}, Sum(), c, {3, m}); - LoopNest loop({d}, {c, d}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {a, b, d, m}); - - std::vector aData(2 * 3 * M, 0); - std::vector bData(2 * 3 * M, 0); - std::vector dData(2, 6.0f); - - for (int i = 0; i < 2 * 3; ++i) { - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j + 1; - aData[i * M + j] = 6 - i; - } - } - - cg.call({aData, bData, dData, M}); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float expected[2] = {0, 0}; - for (const auto i : c10::irange(2)) { - for (const auto j : c10::irange(3)) { - for (const auto k : c10::irange(M)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected[i] += (k + 1) * (6 - (i * 3 + j)); - } - } - } - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(dData[i], expected[i]); - } -} - -TEST(Reductions, SplitReduceAxis) { - BufHandle in("in", {16, 8}, kFloat); - - std::vector in_(16 * 8); - for (const auto i : c10::irange(16)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out(16, -1.f); - - Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); - LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::splitWithTail(loops[1], 2); - - l.prepareForCodegen(); - - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, tensor}); - cg.call({in_, out}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out[i], i * 8); - } -} - -TEST(Reductions, SplitNonReduceAxis) { - BufHandle in("in", {16, 8}, kFloat); - - std::vector in_(16 * 8); - for (const auto i : c10::irange(16)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out(16, -1.f); - Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); - LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::splitWithTail(loops[0], 2); - LoopNest::splitWithTail(loops[0], 2); - - l.prepareForCodegen(); - - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, tensor}); - cg.call({in_, out}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out[i], i * 8); - } -} - -TEST(Reductions, ReorderedReductionInitializer) { - /* From the quip: - for k in 0..1: // blockIdx - for m in 0..128: - for n in 0..64: // threadIdx - SumOp(c(k, n), 0, a(k, m, n), {m}) - */ - - BufHandle in("in", {1, 12, 6}, kFloat); - std::vector in_(12 * 6, 1.f); - - Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6}); - LoopNest l_({tensor_}); - - l_.prepareForCodegen(); - StmtPtr s_ = Stmt::clone(l_.root_stmt()); - s_ = IRSimplifier::simplify(s_); - - Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6}); - LoopNest l({tensor}); - - auto loops = l.getLoopStmtsFor(tensor); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - LoopNest::reorderAxis(loops[1], loops[2]); - - StmtPtr s = l.root_stmt(); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - s = IRSimplifier::simplify(s); - - l.prepareForCodegen(); - - s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - std::vector out1(16, -1.f); - SimpleIREvaluator cg(s_, {in, tensor_}); - cg.call({in_, out1}); - - std::vector out2(16, -1.f); - SimpleIREvaluator cg2(s, {in, tensor}); - cg2.call({in_, out2}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out1[i], out2[i]); - } -} - -TEST(Reductions, ReduceRfactor) { - const int M = 10; - const int N = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle b("b", {m, n}, kFloat); - std::vector in(M * N); - for (int j = 0; j < M * N; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 2); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n}); - - cg.call({in, out, M, N}); - ASSERT_EQ(out[0], 4950); -} - -TEST(Reductions, Reduce3DRfactorInner) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("b", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 1); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, Reduce3DRfactorOuter) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("b", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 2); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n, k}); - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReduceRepeatedInternalRfactor) { - BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat); - const int InputSize = 2 * 3 * 4 * 5 * 6; - - std::vector in(InputSize, 1.f); - std::vector out(1, -1.f); - std::vector ref(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6}); - LoopNest orig_loop({c}); - - // Try rfactoring N outer loops - for (const auto rfac_number : c10::irange(1, 5)) { - LoopNest refloop(orig_loop); - LoopNest loop(orig_loop); - refloop.prepareForCodegen(); - SimpleIREvaluator ref_cg( - IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); - ref_cg.call({in, ref}); - - BufPtr tmp_buf = c.buf(); - - for (const auto idx : c10::irange(rfac_number)) { - auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; - ASSERT_TRUE(loop.rfactor( - reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf)); - } - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, c}); - cg.call({in, out}); - - ASSERT_EQ(ref[0], out[0]); - } -} - -// Split a reduction axis with a tail loop. -TEST(Reductions, ReduceSplitTail) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 8); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis cleanly so there is no tail loop. -TEST(Reductions, ReduceSplitNoTail) { - const int M = 10; - const int N = 10; - const int K = 10; - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 5); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with only a tail loop (the split loop will be size 0 -// and eliminated out). -TEST(Reductions, ReduceOverSplitTail) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 16); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with a mask. -TEST(Reductions, ReduceSplitMask) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 8); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis cleanly not requiring a mask. -TEST(Reductions, ReduceSplitNoMask) { - const int M = 10; - const int N = 10; - const int K = 10; - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 5); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with all logic in the mask. -TEST(Reductions, ReduceOverSplitMask) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 16); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Test an rfactor when there are two ReduceOps in the graph due to a -// splitWithTail. -TEST(Reductions, ReduceSplitRfactor) { - const int M = 2; - const int N = 10; - const int K = 10; - const int SPLIT_FACTOR = 4; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (const auto m : c10::irange(M)) { - for (int j = 0; j < N * K; ++j) { - in[m * N * K + j] = j; - } - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); - - auto c_body = loop.getAllWritesToBuf(c.buf())[2]; - auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); - LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); - all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); - ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = loop.root_stmt(); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - for ([[maybe_unused]] const auto i : c10::irange(M)) { - ASSERT_EQ(out[0], 4950); - } -} - -// Test an rfactor which ends up being eliminated since the total loop size is -// smaller than the split factor. -TEST(Reductions, ReduceOverSplitRfactor) { - const int N = 10; - const int K = 10; - const int SPLIT_FACTOR = 16; - - BufHandle b("b", {N, K}, kFloat); - std::vector in(N * K); - for (int j = 0; j < N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - ForPtr i, t; - LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); - LoopNest::reorderAxis(loops[0], i); - - auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); - - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = loop.root_stmt(); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the IR to verify the rfactored reduce is eliminated. - // TODO: The alloc free should be eliminated here since it is size 0. - /* - const std::string& verification_pattern = - R"IR( -# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] -# CHECK: sum[0] = 0.f; -# CHECK: for (int n = 0; n < 10; n++) { -# CHECK: for (int k_tail = 0; k_tail < 10; k_tail++) { -# CHECK: sum[0] = (sum[0]) + (b[k_tail + 10 * n]); -# CHECK: } -# CHECK: } -# CHECK: Free(tmp_buf);)IR"; - */ - // TODO: rfactor output is not consistent yet, will fix (@nickg). - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Reductions, ReduceInlineReduction) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K}); - Tensor y = Compute( - "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); }); - - PaddedBuffer a_v(M); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - a_v(i) = i * i; - } - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - b_v(i, j, k) = j * j * k; - } - } - } - - LoopNest l1({y}, {x, y}); - // Cannot inline a reduction computation - ASSERT_FALSE(l1.computeInline(x.buf())); -} - -TEST(Reductions, ReduceInlineConsumer) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M, N, K}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n, k) + b_buf.load(m, n, k); - }); - Tensor y = Reduce("y", {M}, Sum(), x, {N, K}); - - PaddedBuffer a_v(M, N, K); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - a_v(i, j, k) = i * i + k; - b_v(i, j, k) = j * j + k; - } - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M); - PaddedBuffer y_2(M); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -TEST(Reductions, ReduceInlineReducerInternal) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M, N, K}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n, k) + b_buf.load(m, n, k); - }); - - Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { - return Add::make(ExprHandle(1.f), Min::make(a, b, false)); - }); - Tensor y = Reduce("y", {M}, minimum, x, {N, K}); - - PaddedBuffer a_v(M, N, K); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - a_v(i, j, k) = i * i + k; - b_v(i, j, k) = j * j + k; - } - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M); - PaddedBuffer y_2(M); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -TEST(Reductions, ReductionCacheAccessesOperatorAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before( - LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[4] -#CHECK: for (int i_2 -#CHECK: d_local[i_2] = 0.f -#CHECK: for (int -#CHECK: for (int -#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[ -#CHECK: } -#CHECK: } -#CHECK: } -#CHECK: for (int i_3 -#CHECK: sum[i_3] = d_local[i_3] -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[1] -#CHECK: sum[i_1] = 0 -#CHECK: d_local[0] = sum[i_1] -#CHECK: for (int j_1 -#CHECK: for (int k_1 -#CHECK: d_local[0] = (d_local[0]) + (scale[ -#CHECK: } -#CHECK: } -#CHECK: sum[i_1] = d_local[0] -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[1] -#CHECK: sum[i_1] = 0 -#CHECK: for (int -#CHECK: d_local[0] = 0 -#CHECK: for (int -#CHECK: d_local[0] = (d_local[0]) + (scale[ -#CHECK: } -#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0]) -#CHECK: } -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheBodyAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(c.buf(), "scale_local", d_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12] -#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) { -#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) { -#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1]; -#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]); -#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]); -#CHECK: Free(scale_local); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); - - StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; - l.cacheAccesses(d.buf(), "sum_local", e_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Alias(sum_local,scale); -#CHECK: sum[i_1] = (sum[i_1]) + (scale[ -#CHECK: for (int j_2 = 0; j_2 < 4 -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; -#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionSplitCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - ForPtr inner; - - // Split outer reduction axis. - LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner); - - // Split reduction consumer. - LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); - - l.cacheAccesses(d.buf(), "sum_local", inner); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - // reduction changes but cache does not. - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Alias(sum_local,scale); -#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]); -#CHECK: for (int i_2 = 0; i_2 < 6 -#CHECK: for (int j_2 = 0; j_2 < 4 -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; -#CHECK: for (int j_3 = 0; j_3 < 4 -#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionReorderCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - ForPtr inner; - - // reorder outer reduction axes. - auto loops = l.getLoopStmtsFor(d); - LoopNest::reorderAxis(loops[0], loops[1]); - - // Split reduction consumer. - LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); - - l.cacheAccesses(d.buf(), "sum_local", inner); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - // neither reduction body not cache changes. - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]); -#CHECK: for (int i_3 = 0; i_3 < 6; -#CHECK: for (int j_2 = 0; j_2 < 4; -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3]; -#CHECK: for (int j_3 = 0; j_3 < 4; -#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionRfactorCacheTempOuter) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("B", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::reorderAxis(loops.at(0), loops.at(1)); - loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - BufPtr rfac_buf; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); - loop.distributeLoop(loops.at(0)); - - auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); - - all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]); - loop.simplify(); - loop.prepareForCodegen(); - StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] -#CHECK: Allocate(tmp); // dtype=float, dims=[n] -#CHECK: for (int i_1 = 0; i_1 < m -#CHECK: for (int j = 0; j < n -#CHECK: tmp[j] = 0 -#CHECK: } -#CHECK: for (int j_1 = 0; j_1 < n -#CHECK: for (int k -#CHECK: tmp[j_1] = (tmp[j_1]) + (B[ -#CHECK: } -#CHECK: } -#CHECK: for (int j_2 = 0; j_2 < n -#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]); -#CHECK: } -#CHECK: Free(tmp); -#CHECK-NOT: tmp - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReductionRfactorCacheTempInner) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("B", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - - LoopNest::reorderAxis(loops.at(0), loops.at(1)); - loops = loop.getLoopStmtsFor(c); - BufPtr rfac_buf; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); - loop.distributeLoop(loops.at(0)); - auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); - - all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]); - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] -#CHECK: Allocate(tmp); // dtype=float, dims=[1] -#CHECK: for (int i_1 = 0; i_1 < m -#CHECK: for (int j = 0; j < n -#CHECK: tmp[0] = 0 -#CHECK: for (int k -#CHECK: tmp[0] = (tmp[0]) + (B[ -#CHECK: } -#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]); -#CHECK: Free(tmp); -#CHECK-NOT: tmp - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReductionVectorize) { - std::vector in_(8 * 8); - for (const auto i : c10::irange(8)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out_before(8, -1.f); - std::vector out_after(8, -1.f); - - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); - LoopNest l_before({tensor}); - LoopNest l(l_before); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); - cg_before.call({in_, out_before}); - - ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0])); - - StmtPtr s = l.root_stmt(); - s = LoopNest::sanitizeNames(IRSimplifier::simplify(s)); - - std::ostringstream oss; - oss << *s; - const std::string& expected_ir = - R"IR( -#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8); -#CHECK: for (int i = 0; i < 8; i++) { -#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i}); -#CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - // Vectorizing should not change result. - l.prepareForCodegen(); - s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg_after(s, {in, tensor}); - cg_after.call({in_, out_after}); - for (const auto i : c10::irange(8)) { - ASSERT_EQ(out_before[i], out_after[i]); - } -} - -TEST(Reductions, ReductionVectorizeInner) { - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); - LoopNest l({tensor}); - - ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); -} - -TEST(Reductions, ReductionVectorizeRfactor) { - std::vector in_(8 * 8); - for (const auto i : c10::irange(8)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out_before(1, -1.f); - std::vector out_after(1, -1.f); - - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8}); - - LoopNest l_before({tensor}); - LoopNest l(l_before); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); - cg_before.call({in_, out_before}); - - ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); - - // But if we rfactor this so it's not a reduce axis we can vectorize that - // loop. - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::reorderAxis(loops[0], loops[1]); - loops = l.getLoopStmtsFor(tensor); - auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1]; - BufPtr rfac_buf = nullptr; - ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf)); - - LoopNest::distributeLoop(loops.at(0)); - auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf); - - ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0])); - l.simplify(); - - StmtPtr s = LoopNest::sanitizeNames(l.root_stmt()); - - std::ostringstream oss; - oss << *s; - const std::string& expected_ir = - R"IR( -#CHECK: sum = 0.f; -#CHECK: for (int i = 0; i < 8; i++) { -#CHECK: sum_rfac[i] = 0.f; -#CHECK: } -#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) { -#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1}); -#CHECK: } -#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) { -#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2}); -#CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - // Vectorizing should not change result. - l.prepareForCodegen(); - s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg_after(s, {in, tensor}); - cg_after.call({in_, out_after}); - - ASSERT_EQ(out_before[0], out_after[0]); -} - -TEST(Reductions, InitFunction) { - constexpr int M = 32; - constexpr int N = 16; - BufHandle A("A", {M, N}, kFloat); - BufHandle B("B", {N}, kFloat); - Tensor C = Reduce( - "C", - {N}, - Sum(), - [&](const std::vector& v) { return B.load(v[0]); }, - [&](const std::vector& v) { return A.load(v[1], v[0]); }, - {M}); - LoopNest nest({C}); - nest.prepareForCodegen(); - StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt())); - std::ostringstream oss; - oss << *s << "\n"; - const std::string& expected_ir = - R"IR( -#CHECK: for (int i = 0; i < 16; i++) { -#CHECK: C[i] = B[i]; -#CHECK: for (int j = 0; j < 32; j++) { -#CHECK: C[i] = (C[i]) + (A[i + 16 * j]); -#CHECK: } -#CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp deleted file mode 100644 index 6cbd04264c321..0000000000000 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ /dev/null @@ -1,3702 +0,0 @@ -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include "test/cpp/tensorexpr/test_utils.h" -#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" -#include "torch/csrc/jit/tensorexpr/registerizer.h" - -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -// Can replace a simple scalar access with a local variable. -TEST(Registerizer, RegisterizerSimple) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't do replacement of a loop access. -TEST(Registerizer, RegisterizerLoop) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - // No change. - stmt = registerize(stmt); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: A[0] = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: A[x] = -# CHECK-NOT: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't replace even if the load is a fixed scalar, since the store could -// invalidate it. -TEST(Registerizer, RegisterizerLoopFixedLoad) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[0]) + x; - * } - */ - - // No change. - stmt = registerize(stmt); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[0]) + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: A[0] = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: A[x] = -# CHECK-NOT: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// We can registerize accesses that occur entirely within inner scopes, even if -// they depend on the loop var. -TEST(Registerizer, RegisterizerLoopInternal) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * A[x] = (A[x]) + x; - * } - */ - - stmt = registerize(stmt); - - // TODO: the order of terms in addition changes and in general depends on - // some hash value. This results in unpredictable swaps of the operands from - // random changes, which is not great. Ideally, we should ensure some - // specific order (ideally, the original one). - /* - * for (int x = 0; x < 10; x++) { - * int A_1 = A[x]; - * A_1 = x + A_1; - * A_1 = x + A_1; - * A[x] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: int A_1 = A[x]; -# CHECK: A_1 = A_1 + x; -# CHECK: A_1 = A_1 + x; -# CHECK: A[x] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An access can be overlapped by another read in the same Expr. In this case -// B[z] and B[y] overlap and prevent registerization of both accesses. -TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (B[y]) + (B[z]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeated) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})) - - }); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = x + (A[1]); - * A[0] = x + (A[1]); - * } - * for (int x = 0; x < 10; x++) { - * A[0] = x + (A[1]); - * A[0] = x + (A[1]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[1]; - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * A_2 = A_1 + x; - * } - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * A_2 = A_1 + x; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[1]; -# CHECK: int A_2 = A[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = A_1 + x; -# CHECK: A_2 = A_1 + x; -# CHECK: } -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = A_1 + x; -# CHECK: A_2 = A_1 + x; -# CHECK: } -# CHECK-NOT: A[1] -# CHECK: A[0] = A_2; -# CHECK-NOT: A[1] -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})) - - }); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = IRSimplifier::simplify(Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), - Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), - Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})) - - })); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Will registerize multiple accesses of different items of the same buffer. -TEST(Registerizer, RegisterizerMultiVar) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), - }); - - /* - * A[0] = 0; - * A[1] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * A[1] = (A[1]) - x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * int A_2 = 0; - * for (int x = 0; x < 10; x++) { - * A_2 = x + A_2; - * A_1 = A_1 - x; - * } - * A[1] = A_2; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: int A_2 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A_2 = -# CHECK: A[1] = A_2 -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Will registerize the valid accesses while skipping invalid replacements. -TEST(Registerizer, RegisterizerVariableLoad) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle x2("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make(x, 0, 10, Store::make(b, {x}, x)), - For::make( - x2, - 0, - 10, - Block::make({Store::make( - a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x; - * } - * for (int x_1 = 0; x_1 < 10; x_1++) { - * A[0] = (A[0]) + (B[x_1]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x; - * } - * for (int x_1 = 0; x_1 < 10; x_1++) { - * A_1 = A_1 + (B[x_1]); - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: B[x] = x -# CHECK: for (int x_1 = 0; x_1 < 10; x_1++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize variable accesses so long as the variable does not change. -TEST(Registerizer, RegisterizerSymbolicIndices) { - VarHandle i("i", kInt); - VarHandle N("N", kInt); - BufHandle a("A", {N}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {i}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))}); - - /* - * A[i] = 0; - * for (int x = 0; x < 10; x++) { - * A[i] = (A[i]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[i] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[i] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize accesses dependent on multiple loop vars. -TEST(Registerizer, RegisterizerMultiLoop) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - Block::make({Store::make( - a, - {0}, - Mul::make(Add::make(Load::make(a, {0}), x), y))})))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * A[0] = x * y + (A[0]) * y; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * A_1 = x * y + y * A_1; - * } - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: for (int y = 0; y < 10; y++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize correctly if scalars already exist in the program. -TEST(Registerizer, RegisterizerRepeated) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), - }); - - // Registerize manually to make sure we only replace a single target. - { - registerizer::RegisterizerAnalysis analysis; - stmt->accept(&analysis); - auto candidates = analysis.getCandidates(); - ASSERT_EQ(candidates.size(), 2); - - candidates.pop_back(); - registerizer::RegisterizerReplacer replacer(candidates); - stmt = stmt->accept_mutator(&replacer); - } - - // Re-analyze and replace the second target. - { - registerizer::RegisterizerAnalysis analysis; - stmt->accept(&analysis); - auto candidates = analysis.getCandidates(); - ASSERT_EQ(candidates.size(), 1); - - registerizer::RegisterizerReplacer replacer(candidates); - stmt = stmt->accept_mutator(&replacer); - } - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: int A_1_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A_1_1 = -# CHECK: A[1] = A_1_1; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize the load of A. -TEST(Registerizer, RegisterizerNoLoads) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = x + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + 1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize the load of A but not the store of B. -TEST(Registerizer, RegisterizerNoRepeatedStores) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - // TODO: its unnecessary to reorder the initializer of A[0], but it's not - // actually worse so lets not worry for now. - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: B[x] = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't registerize if there are multiple accesses which may overlap. -TEST(Registerizer, RegisterizerMultiVarOverlap) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})), - }); - stmt = IRSimplifier::simplify(stmt); - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerAllocs) { - BufHandle a("A", {2}, kInt); - BufHandle c("C", {1}, kInt); - VarHandle x("x", kInt); - - BufHandle b("B", {Load::make(c, {0})}, kInt); - - StmtPtr stmt = Block::make( - {Allocate::make(b), - Store::make(a, {0}, Load::make(c, {0})), - Store::make(b, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)), - Store::make(a, {0}, Load::make(c, {0}))})), - Free::make(b)}); - - /* - * Allocate(B, int, {C[0]}); - * A[0] = C[0]; - * B[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[0] = (B[0]) + x; - * A[0] = C[0]; - * } - * Free(B); - */ - - stmt = registerize(stmt); - - /* - * int C_1 = C[0]; - * Allocate(B, int, {C_}); - * int A_1 = C_1; - * int B_1 = 0; - * for (int x = 0; x < 10; x++) { - * B_1 = B_1 + x; - * A_1 = C_1; - * } - * B[0] = B_1; - * A[0] = A_1; - * Free(B); - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int C_1 = C[0]; -# CHECK: Allocate(B -# CHECK: int A_1 = C_1; -# CHECK: int B_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: B_1 = -# CHECK: A_1 = C_ -# CHECK: B[0] = B_1; -# CHECK: A[0] = A_1; -# CHECK: Free(B)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNoInitializer) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNoInitializerLoopVar) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoadThenStore) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {0}, Load::make(b, {0}))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * B[0] = (A[0]) + x; - * A[0] = B[0]; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * int B_1 = B[0]; - * for (int x = 0; x < 10; x++) { - * B_1 = x + A_1; - * A_1 = B_1; - * } - * B[0] = B_1; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: int B_1 = B[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: B[ -# CHECK: B_1 = -# CHECK-NOT: A[ -# CHECK: A_1 = B_ -# CHECK: B[0] = B_ -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerParallelized) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - LoopOptions loopOpts; - loopOpts.set_gpu_block_index(0); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}), - loopOpts)}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - ASSERT_THROWS_WITH( - registerize(stmt), - "Registerization must occur after parallelism flattening"); -} - -// Should be able to registerize this since the scalar would exist before the -// branch. -TEST(Registerizer, RegisterizerConditionAfter) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * C[x] = A_1; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Should be able to registerize this since the scalar exists in the same form -// after the branch and there is no overlap. -TEST(Registerizer, RegisterizerConditionBefore) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x}))}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * A[x] = B[x]; - * C[x] = A[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_ 1 = A[x]; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A_1 = B[x]; - * C[x] = A_1; - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Should be able to registerize this as the combination of the two above rules. -TEST(Registerizer, RegisterizerConditionInside) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * A[x] = C[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * C[x] = A_1; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * B[x] = A_1; - * A_1 = C[x]; - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: B[x] = A_1; -# CHECK: A_1 = C[x]; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An example where an access is cut by an overlapping access inside a -// condition, and both sides are large enough to be registerized but cannot be -// because there is no safe place to put the initializer or finalizer. -TEST(Registerizer, RegisterizerConditionInsideOverlap1) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({ - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {0}, 3), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - }), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * A[x] = C[x]; - */ - - // The A[0] store overlaps, A[x] cutting the region that can be registerized - // into two groups. - // Each group has 2 loads and 2 stores however, so we could registerize it, - // but the first group would need to be finalized inside the condition block, - // the second would need to be initialized inside the condition block. There's - // no safe place to put these that's visible to the other uses in the group - // and so neither registerization is possible. - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Same as the above, but the access group before the condition (and after the -// condition) are large enough to be registerized without needing the access -// from the loop. Registerization occurs but does not include any accesses in -// the condition, and the first group must be finalized before the Cond, the -// second initialized after it. -TEST(Registerizer, RegisterizerConditionInsideOverlap2) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(a, {x}, Load::make(b, {x + 1})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({ - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {0}, 3), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - }), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(b, {x + 1}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * A[x] = B[x + 1]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * B[x + 1] = A[x]; - * A[x] = C[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; // A_1 initializer - * A_1 = B[x + 1]; // - * C[x] = A_1; // - * A[x] = A_1; // A_1 finalizer - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * int A_2 = A[x]; // A_2 initializer - * B[x] = A_2; // - * B[x + 1] = A_2; // - * A_2 = C[x]; // - * A[x] = A_2; // A_2 finalizer - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: A_1 = B[x + 1]; -# CHECK: C[x] = A_1; -# CHECK: A[x] = A_1; -# CHECK: if ( -# CHECK-NOT: A_1 = A_1 + 1; -# CHECK: A[x] = (A[x] -# CHECK: A[0] = -# CHECK: A[x] = (A[x] -# CHECK: } -# CHECK: int A_2 = A[x]; -# CHECK: B[x] = A_2; -# CHECK: B[x + 1] = A_2; -# CHECK: A_2 = C[x]; -# CHECK: A[x] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// When accesses are within conditional blocks they are not visible to the wider -// program, because we don't know if the branch would be taken and if it isn't -// the accesses in it don't need to be valid (think size checks on the index). -// In this case the accesses cannot be registerized. -TEST(Registerizer, RegisterizerConditionHidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * if (x>5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// But... if the same access is found in a non conditional scope, that means -// that that access is valid in the higher scope (or at least if its not it's -// the user's fault). It "unhides" the conditional accesses, allowing -// registerization to occur. -TEST(Registerizer, RegisterizerConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * A[x] = (A[x]) + 1; <-- this is doing the unhiding. - * if (x>5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A_1 = A_1 + 1; - * if (x>5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if (x<5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x>5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize a load that occurs in the condition of a Cond. -TEST(Registerizer, RegisterizerCondCondition) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), - nullptr)}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if ((A[x])<5 ? 1 : 0) { - * C[x] = (C[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * int C_1 = A_1; - * if (A_1<5 ? 1 : 0) { - * C_1 = C_1 + 1; - * } - * C[x] = C_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: int C_1 = A_1; -# CHECK: if (A_1<5 -# CHECK: C_1 = C_1 + 1; -# CHECK: C[x] = C_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Appearing in the condition of a Cond makes it visible to the enclosing scope, -// and so we can registerize internal usages. -TEST(Registerizer, RegisterizerCondConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))}); - - /* - * if ((A[x])<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } else { - * A[x] = (A[x]) + 10; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * if (A_1<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } else { - * A_1 = A_1 + 10; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if (A_1<5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } else { -# CHECK: A_1 = A_1 + 10; -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Conditional hiding also works for IfThenElse exprs. -TEST(Registerizer, RegisterizerIfThenElseHidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - {Store::make( - b, - {y}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - Store::make( - b, - {y + 1}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2)))}); - - /* - * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Conditional unhiding also works for IfThenElse exprs. -TEST(Registerizer, RegisterizerIfThenElseUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make({ - Store::make(a, {x}, 0), - Store::make( - b, - {y}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - Store::make( - b, - {y + 1}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - }); - - /* - * A[x] = 0; - * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); -# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Nested IfThenElse exprs can't promote to higher level scopes. -TEST(Registerizer, RegisterizerIfThenElseNested) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - BufHandle d("D", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - IfThenElse::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Load::make(d, {x}), - Load::make(b, {x})), - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kEQ), - Load::make(c, {x}), - Load::make(d, {x}))))}); - - /* - * A[x] = IfThenElse(x<3 ? 1 : 0, - * IfThenElse(x==2 ? 1 : 0, D[x], B[x]), - * IfThenElse(x==5 ? 1 : 0, C[x], D[x])); - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Cannot registerize an access completely contained within an IfThenElse -// branch, since it is not a Stmt and cannot hold variable definitions. We need -// to check that we don't promote the initializer/finalizer to the enclosing -// Block. -TEST(Registerizer, RegisterizerIfThenElseInternal) { - // Making these floats so they don't get simplified to a single access. - BufHandle a("A", {5}, kFloat); - BufHandle b("B", {5}, kFloat); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Add::make(Load::make(b, {x}), Load::make(b, {x})), - Load::make(b, {x})))}); - - /* - * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // If this was a Cond instead of an IfThenElse then we could registerize the - // two accesses to B[x] in the True branch. - - // Actually lets verify that. - - stmt = Block::make({Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))), - Store::make(a, {x}, Load::make(b, {x})))}); - - /* - * if (x<3 ? 1 : 0) { - * A[x] = (B[x]) + (B[x]); - * } else { - * A[x] = B[x]; - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<3 ? 1 : 0) { - * float B_1 = B[x]; - * A[x] = B_1 + B_1; - * } else { - * A[x] = B[x]; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK-NOT: float -# CHECK: if (x<3 -# CHECK: float B_1 = -# CHECK: A[x] = B_1 + B_1 -# CHECK: } else { -# CHECK: A[x] = B[x] -# CHECK: } -# CHECK-NOT: A[x] -# CHECK-NOT: B[x])IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize a load that occurs in the condition of an IfThenElse; -TEST(Registerizer, RegisterizerIfThenElseCondition) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(a, {x})), - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Load::make(b, {0}), - Load::make(c, {0})))}); - - /* - * A[x] = A[x]; <---- just here so there are enough accesses to combine. - * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]); - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * A_1 = A_1; - * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Appearing in the condition of a Cond makes it visible to the enclosing scope, -// and so we can registerize internal usages. -TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Store::make( - b, - {x}, - IfThenElse::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x}), 10)))}); - - /* - * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10); - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10); - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Cannot promote accesses internal to IfThenElse branches even if the enclosing -// scope if conditional. -TEST(Registerizer, RegisterizerConditionBranchOnly) { - BufHandle a("A", {5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({ - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), x), - Add::make(Load::make(a, {x - 5}), x))), - Store::make( - a, - {x - 5}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), x), - Add::make(Load::make(a, {x - 5}), x)))), - }))}); - stmt = IRSimplifier::simplify(stmt); - - std::ostringstream before; - before << *stmt; - - /* for (int x = 0; x < 10; x++) { - * if (x<5 ? 1 : 0) { - * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } else { - * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); - * } - * } - */ - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// We can registerize an IfThenElse that appears in the condition branch of a -// Cond. This is a weird but valid thing to do. -TEST(Registerizer, RegisterizerCondIfThenElse) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make( - IfThenElse::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(b, {x})), - x, - CompareSelectOperation::kEQ), - Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), - nullptr)}); - - /* - * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) { - * C[x] = (C[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - // access to A can be registerized, but not B or C - - /* - * int A_1 = A[x]; - * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) { - * C[x] = (C[x]) + 1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x] -# CHECK: C[x] = (C[x]) + 1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize a conditional access in the RHS of a store unhidden by it's -// LHS, and hoist it out of a loop. -TEST(Registerizer, RegisterizerIfThenElseLoop) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = For::make( - y, - 0, - 10, - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(b, {y})))); - - /* - * for (int y = 0; y < 10; y++) { - * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * for (int y = 0; y < 10; y++) { - * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: for ( -# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Cannot registerize if the RHS overlaps the access creating visibility. -TEST(Registerizer, RegisterizerIfThenElseLoopCut) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make({For::make( - y, - 0, - 10, - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(a, {y}))))}); - - /* - * for (int y = 0; y < 10; y++) { - * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Simple case where an access is cut by an overlapping access later in the -// program, we can registerize up until the overlap. -TEST(Registerizer, RegisterizerPartialAfter) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})), - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x] = A[x - 1]; -# CHECK: } -# CHECK-NOT: A)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// We can registerize an access which overlaps a previous access, the -// initializer must be inserted after the previous access. -TEST(Registerizer, RegisterizerPartialBefore) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), - Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: for ( -# CHECK: A[x] = A[x - 1]; -# CHECK: } -# CHECK: int A_1 = 0; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// The combination of the previous two tests, an access is cut by an overlapping -// access in both directions. -TEST(Registerizer, RegisterizerPartialInside) { - BufHandle a("A", {1}, kInt); - VarHandle x1("x1", kInt); - VarHandle x2("x2", kInt); - VarHandle x3("x3", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 2), - For::make( - x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))), - For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))), - For::make( - x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))}); - - /* - * A[0] = 2; - * for (int x1 = 0; x1 < 10; x1++) { - * A[0] = (A[0]) + x1; - * } - * for (int x2 = 1; x2 < 10; x2++) { - * A[x2] = A[x2 - 1]; - * } - * for (int x3 = 0; x3 < 10; x3++) { - * A[0] = (A[0]) + x3; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 2; - * for (int x1 = 0; x1 < 10; x1++) { - * A_1 = A_1 + x1; - * } - * A[0] = A_1; - * for (int x2 = 1; x2 < 10; x2++) { - * A[x2] = A[x2 - 1]; - * } - * int A_2 = A[0]; - * for (int x3 = 0; x3 < 10; x3++) { - * A_2 = A_2 + x3; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 2; -# CHECK: for ( -# CHECK: A_1 = A_1 + x1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x2] = -# CHECK: } -# CHECK: int A_2 = A[0]; -# CHECK: for ( -# CHECK: A_2 = A_2 + x3; -# CHECK: } -# CHECK: A[0] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An element could be registerized program wide but is cut by a conditional -// access, we should break this into two scalars and write back to the buffer -// before the condition. -TEST(Registerizer, RegisterizerPartialCondition) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 2), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Load::make(a, {x - 1})), - nullptr), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))}); - - /* - * A[0] = 2; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - * if (x<5 ? 1 : 0) { - * A[x] = A[x - 1]; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 2; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - * if (x<5 ? 1 : 0) { - * A[x] = A[x - 1]; - * } - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + x; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 2; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: if ( -# CHECK: A[x] = -# CHECK: } -# CHECK: int A_2 = A[0]; -# CHECK: for ( -# CHECK: A_2 = A_2 + x; -# CHECK: } -# CHECK: A[0] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Tests case where an access is cut by an internal conditional access which -// itself is registerized. -TEST(Registerizer, RegisterizerPartialConditionInternalCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 1), - Store::make(a, {0}, 3), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), - nullptr), - Store::make(a, {0}, 4), - Store::make(a, {0}, 6)}); - - /* - * A[0] = 1; - * A[0] = 3; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * A[x] = 3; - * } - * A[0] = 4; - * A[0] = 6; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 1; - * A_1 = 3; - * A[0] = A_1; - * if (x<5 ? 1 : 0) { - * int A_2 = 1; - * A_2 = 3; - * A[x] = A_2; - * } - * int A_3 = 4; - * A_3 = 6; - * A[0] = A_3; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 1; -# CHECK: A_1 = 3 -# CHECK: A[0] = A_1; -# CHECK: if ( -# CHECK: int A_2 = 1; -# CHECK: A_2 = 3; -# CHECK: A[x] = A_2; -# CHECK: } -# CHECK: int A_3 = 4; -# CHECK: A_3 = 6; -# CHECK: A[0] = A_3;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// First statement in condition closes outer access, but can be registerized -// with later statements. -TEST(Registerizer, RegisterizerPartialConditionInternalStart) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 1), - Store::make(a, {0}, 3), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), - nullptr), - Store::make(a, {x}, 4), - Store::make(a, {x}, 6)}); - - /* - * A[0] = 1; - * A[0] = 3; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * A[x] = 3; - * } - * A[x] = 4; - * A[x] = 6; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 1; - * A_1 = 3; - * A[0] = A_1; - * int A_2 = A[x]; <--- must read from the input here. - * if (x<5 ? 1 : 0) { - * A_2 = 1; - * A_2 = 3; - * } - * A_2 = 4; - * A_2 = 6; - * A[x] = A_2; - */ - - // TODO: I suppose we could refactor with a conditional initializer? - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 1; -# CHECK: A_1 = 3 -# CHECK: A[0] = A_1; -# CHECK: int A_2 = A[x]; -# CHECK: if ( -# CHECK: A_2 = 1; -# CHECK: A_2 = 3; -# CHECK: } -# CHECK: A_2 = 4; -# CHECK: A_2 = 6; -# CHECK: A[x] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An access cuts two open overlaps and creates four scalar variables. -TEST(Registerizer, RegisterizerPartialOverlapsTwo) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {1}, Load::make(a, {0})), - Store::make(a, {0}, Load::make(a, {1})), - Store::make(a, {0}, Load::make(a, {1})), - For::make(x, 1, 10, Store::make(a, {x}, x)), - Store::make(a, {1}, Load::make(a, {0})), - Store::make(a, {0}, Load::make(a, {1})), - Store::make(a, {0}, Load::make(a, {1}))}); - - /* - * A[1] = A[0]; - * A[0] = A[1]; - * A[0] = A[1]; - * for (int x = 1; x < 10; x++) { - * A[x] = x; - * } - * A[1] = A[0]; - * A[0] = A[1]; - * A[0] = A[1]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * int A_2 = A_1; - * A_1 = A_2; - * A_1 = A_2; - * A[1] = A_2; - * A[0] = A_1; - * for (int x = 1; x < 10; x++) { - * A[x] = x; - * } - * int A_3 = A[0]; - * int A_4 = A_3; - * A_3 = A_4; - * A_3 = A_4; - * A[1] = A_4; - * A[0] = A_3; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: int A_2 = A_1; -# CHECK: A_1 = A_2; -# CHECK: A_1 = A_2; -# CHECK: A[1] = A_2; -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x] = x; -# CHECK: } -# CHECK: int A_3 = A[0]; -# CHECK: int A_4 = A_3; -# CHECK: A_3 = A_4; -# CHECK: A_3 = A_4; -# CHECK: A[1] = A_4; -# CHECK: A[0] = A_3;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Nested blocks will automatically be flattened and do not provent -// registerization of enclosed accesses. -TEST(Registerizer, RegisterizerNestedBlocks) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}), - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)), - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})}); - - /* - * A[0] = (A[0]) + 1; - * { - * A[0] = (A[0]) + 2; - * } - * { - * A[0] = (A[0]) + 3; - * { - * A[0] = (A[0]) + 4; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * A_1 = A_1 + 1; - * A_1 = A_1 + 2; - * A_1 = A_1 + 3; - * A_1 = A_1 + 4; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: A_1 = A_1 + 2; -# CHECK: A_1 = A_1 + 3; -# CHECK: A_1 = A_1 + 4; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// The access can be registerized internally to a condition, but must ensure -// that both initializer and finalizer are within the same condition. -TEST(Registerizer, RegisterizerNestedConditions) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * if (x==2 ? 1 : 0) { - * - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * A_1 = A_1 + 1; - * if (x==2 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x==2 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// If an access exists outside the scope of the condition then we can lift -// nested conditional usages into the same scalar. -TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {1}, 1), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * A[0] = (A[0]) + 1; - * if (x<5 ? 1 : 0) { - * A[1] = 1; - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * A_1 = A_1 + 1; - * if (x<5 ? 1 : 0) { - * A[1] = 1; - * if (x==2 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x<5 -# CHECK: A[1] = 1; -# CHECK: if (x==2 -# CHECK: A_1 = A_1 + 1; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * if (x<5 ? 1 : 0) { - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - stmt = registerize(stmt); -} - -TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - stmt = registerize(stmt); -} - -// If an access is cut by another access internal to a condition block, it still -// cuts the access. -TEST(Registerizer, RegisterizerNestedConditionsCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {x}, 1), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * A[0] = (A[0]) + 1; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * if (x==2 ? 1 : 0) { - * - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}))}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * for (int x = 0; x < 10; x++) { - * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Three loops and four element regions, three of which should be registerized -// at different levels of the IR. -TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {4}, 0), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kGT), - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kGT), - Block::make({ - Cond::make( - CompareSelect::make(x, 4, CompareSelectOperation::kGT), - Block::make({ - Store::make( - a, {1}, Add::make(Load::make(a, {1}), 1)), - Store::make( - a, {2}, Add::make(Load::make(a, {2}), 1)), - Store::make( - a, {3}, Add::make(Load::make(a, {3}), 1)), - Store::make( - a, {4}, Add::make(Load::make(a, {4}), 1)), - Store::make( - a, {1}, Add::make(Load::make(a, {1}), 1)), - }), - nullptr), - Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)), - }), - nullptr), - nullptr)}); - - /* - * A[4] = 0; - * if (x>2 ? 1 : 0) { - * if (x>3 ? 1 : 0) { - * if (x>4 ? 1 : 0) { - * A[1] = (A[1]) + 1; - * A[2] = (A[2]) + 1; - * A[3] = (A[3]) + 1; - * A[4] = (A[4]) + 1; - * A[1] = (A[1]) + 1; - * } - * A[2] = (A[2]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * if (x>2 ? 1 : 0) { - * if (x>3 ? 1 : 0) { - * int A_3 = A[2]; - * if (x>4 ? 1 : 0) { - * int A_2 = A[1]; - * A_2 = A_2 + 1; - * A_3 = A_3 + 1; - * A[3] = (A[3]) + 1; - * A_1 = A_1 + 1; - * A_2 = A_2 + 1; - * A[1] = A_2; - * } - * A_3 = A_3 + 1; - * A[2] = A_3; - * } - * } - * A[4] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: if (x>2 ? 1 : 0) { -# CHECK: if (x>3 ? 1 : 0) { -# CHECK: int A_3 = A[2]; -# CHECK: if (x>4 ? 1 : 0) { -# CHECK: int A_2 = A[1]; -# CHECK: A_2 = A_2 + 1; -# CHECK: A_3 = A_3 + 1; -# CHECK: A[3] = (A[3]) + 1; -# CHECK: A_1 = A_1 + 1; -# CHECK: A_2 = A_2 + 1; -# CHECK: A[1] = A_2; -# CHECK: } -# CHECK: A_3 = A_3 + 1; -# CHECK: A[2] = A_3; -# CHECK: } -# CHECK: } -# CHECK: A[4] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can replace a simple scalar access with a local variable even when that -// variable is an outer loop var. -TEST(Registerizer, RegisterizerNestedLoopSimple) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({For::make( - y, - 0, - 10, - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))}); - - /* - * for (int y = 0; y < 10; y++) { - * for (int x = 0; x < 10; x++) { - * A[y] = (A[y]) + x; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * for (int y = 0; y < 10; y++) { - * int A_1 = A[y]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[y] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int y -# CHECK: int A_1 = A[y]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[y] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Test the positive case of the hiddenAccess split, where an internal -// conditional access can be hoisted up through a loop to match an existing -// access in a higher scope and the two can be registerized. -TEST(Registerizer, RegisterizerHiddenAccessYes) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kEQ), - For::make( - y, - 0, - 10, - Store::make( - a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A[0] = (A[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A_1 = A_1 + 1; - * } - * } - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: int A_1 = 0; -# CHECK: for (int x -# CHECK: B[x] = 0; -# CHECK: if (x==3 -# CHECK: for (int y -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Test the negative case of the hiddenAccess split, where the hoisted access is -// never unhidden at a higher scope and registerization occurs at the lower -// scope. -TEST(Registerizer, RegisterizerHiddenAccessNo) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kEQ), - For::make( - y, - 0, - 10, - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A[0] = (A[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * int A_1 = A[0]; - * for (int y = 0; y < 10; y++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * } - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: for (int x -# CHECK: B[x] = 0; -# CHECK: if (x==3 -# CHECK: int A_1 = A[0]; -# CHECK: for (int y -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: } -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// In this case the conditional access must be hoisted by two loops, there are -// two accesses here one is unhidden and the other isn't. A[0] can be -// registerized but B[0] cannot. -TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - Block::make({Cond::make( - CompareSelect::make(y, 3, CompareSelectOperation::kEQ), - Block::make( - {Store::make( - a, {0}, Add::make(Load::make(a, {0}), 1)), - Store::make( - b, {0}, Add::make(Load::make(b, {0}), 1))}), - nullptr)})))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * if (y==3 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * B[0] = (B[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * if (y==3 ? 1 : 0) { - * A_1 = A_1 + 1; - * B[0] = (B[0]) + 1; - * } - * } - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: int A_1 = 0; -# CHECK: for (int x -# CHECK: for (int y -# CHECK: if (y==3 -# CHECK: A_1 = A_1 + 1; -# CHECK: B[0] = (B[0]) + 1; -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Accesses are registerized inside two conditions, but the immediate parent is -// not a condition. -TEST(Registerizer, RegisterizerTwoConditionalLoops) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - * if (x>5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * if (x>5 ? 1 : 0) { - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + 1; - * } - * A[0] = A_2; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: if (x>5 -# CHECK: int A_2 = A[0]; -# CHECK: for (int x -# CHECK: A_2 = A_2 + 1; -# CHECK: } -# CHECK: A[0] = A_2; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Accesses are registerized inside two conditions, cut in the middle. -TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr), - For::make(x, 0, 10, Store::make(a, {x}, 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - * for (int x = 0; x < 10; x++) { - * A[x] = 1; - * } - * if (x>5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * for (int x = 0; x < 10; x++) { - * A[x] = 1; - * } - * if (x>5 ? 1 : 0) { - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + 1; - * } - * A[0] = A_2; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: for (int x -# CHECK: A[x] = 1; -# CHECK: if (x>5 -# CHECK: int A_2 = A[0]; -# CHECK: for (int x -# CHECK: A_2 = A_2 + 1; -# CHECK: } -# CHECK: A[0] = A_2; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// references a Let var in a local scope which cannot be hoisted out of the -// loop. -TEST(Registerizer, RegisterizerLoopLetVar) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make( - x, - 0, - 10, - Block::make( - {Let::make(y, 30), - Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))})); - - /* - * for (int x = 0; x < 10; x++) { - * int y = 30; - * A[y] = x + (A[y]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// references a Let var in an outer scope that does not prevent hoisting the -// initializer. -TEST(Registerizer, RegisterizerLoopLetVarOuter) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Let::make(y, 30), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}); - - /* - * int y = 30; - * for (int x = 0; x < 10; x++) { - * A[y] = x + (A[y]); - * } - */ - - stmt = registerize(stmt); - - /* - * int y = 30; - * int A_1 = A[y]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[y] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int y = 30; -# CHECK: int A_1 = A[y]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + x; -# CHECK: A[y] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Okay so the registerizer generally goes after index flattening, but just in -// case. Test multi index registerization. -TEST(Registerizer, RegisterizerMultiDim) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))}); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, 1, 2] = (A[0, 1, 2]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0, 1, 2] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0, 1, 2] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't registerize if only some dims match, but will still registerize -// distinct elements. -TEST(Registerizer, RegisterizerMultiDimPartial) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))}); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, 2, 2] = (A[0, 1, 4]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * A[0, 1, 2] = 0; - * int A_1 = A[0, 1, 4]; - * int A_2 = A[0, 2, 2]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * } - * A[0, 2, 2] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: A[0, 1, 2] = 0; -# CHECK: int A_1 = A[0, 1, 4]; -# CHECK: int A_2 = A[0, 2, 2]; -# CHECK: for ( -# CHECK: A_2 = A_1 + x; -# CHECK: A[0, 2, 2] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// If they could overlap across all dimensions we cannot registerize. -TEST(Registerizer, RegisterizerMultiDimOverlap) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = (A[y, 2, 2]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// But, if one dimension is known to be distinct they do not overlap. -TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))}); - - /* - * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. - * } - */ - - stmt = registerize(stmt); - - /* - * A[0, 1, 2] = 0; - * int A_1 = A[y, 2, 4]; - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = A_1 + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: A[0, 1, 2] = 0; -# CHECK: int A_1 = A[y, 2, 4]; -# CHECK: for ( -# CHECK: A[0, x, 2] = A_1 + x; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// A 3D reduction with different input dimensionality. -TEST(Registerizer, RegisterizerMultiDim3DReduction1) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10, 10}, kInt); - BufHandle c("C", {10, 10, 10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - For::make( - z, - 0, - 10, - Store::make( - c, - {x, y, z}, - Add::make( - Load::make(c, {x, y, z}), - Mul::make(Load::make(b, {x, y}), Load::make(a, {x}))))))); - - /* - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * for (int z = 0; z < 10; z++) { - * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); - * } - * } - * } - */ - - // We can registerize the A and B access since they can be hoisted before - // hitting a dependent loop var. - - stmt = registerize(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * int A_1 = A[x]; - * for (int y = 0; y < 10; y++) { - * int B_1 = B[x, y]; - * for (int z = 0; z < 10; z++) { - * C[x, y, z] = A_1 * B_1 + (C[x, y, z]); - * } - * } - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int x -# CHECK: int A_1 = A[x]; -# CHECK: for (int y -# CHECK: int B_1 = B[x, y]; -# CHECK: for (int z -# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]); -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// A 3D reduction with the same smaller dimensionality using different loop -// vars. -TEST(Registerizer, RegisterizerMultiDim3DReduction2) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = For::make( - x, - 0, - 10, - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - For::make( - y, - 0, - 10, - For::make( - z, - 0, - 10, - Store::make( - c, - {x}, - Add::make( - Load::make(c, {x}), - Mul::make(Load::make(b, {y}), Load::make(a, {x}))))))); - - /* - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * for (int z = 0; z < 10; z++) { - * C[x] = (C[x]) + (B[y]) * (A[x]); - * } - * } - * } - */ - - // We can registerize all accesses, the A and C access can be hoisted to the - // outer loop since they depend only on it's loop var while the B can only be - // raised to the loop of y. - - stmt = registerize(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * int A_1 = A[x]; - * int C_1 = C[x]; - * for (int y = 0; y < 10; y++) { - * int B_1 = B[y]; - * for (int z = 0; z < 10; z++) { - * C_1 = A_1 * B_1 + C_1; - * } - * } - * C[x] = C_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int x -# CHECK: int A_1 = A[x]; -# CHECK: int C_1 = C[x]; -# CHECK: for (int y -# CHECK: int B_1 = B[y]; -# CHECK: for (int z -# CHECK: C_1 = A_1 * B_1 + C_1; -# CHECK: } -# CHECK: } -# CHECK: C[x] = C_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp deleted file mode 100644 index 7ca2b74eaa766..0000000000000 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ /dev/null @@ -1,5680 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include - -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; -using SimpleIRExprEval = ExprEval; - -TEST(Simplify, ConstantFoldSimple) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle f = (a + b); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 5); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 5.f); -} - -TEST(Simplify, ConstantFoldTwoLayer) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), -4); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), -4.f); -} - -TEST(Simplify, ConstantFoldShifts) { - ExprHandle a(7); - ExprHandle b(2); - ExprHandle c(3); - ExprHandle f = ((a << b) << b) >> c; - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 14); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 7 << (4 - 3)); -} - -TEST(Simplify, ConstantFoldBitwise) { - ExprHandle a(59); - ExprHandle b(22); - ExprHandle c(101); - ExprHandle f = (a ^ b) & c; - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 37); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), (59 ^ 22) & 101); -} - -TEST(Simplify, ConstantFoldMultiOp) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle e(6.0f); - ExprHandle f(7.0f); - ExprHandle fn = ((a / e) - (c + d)) * (f / b); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - - SimpleIRExprEval eval(newF); - SimpleIRExprEval ref(fn); - - ASSERT_EQ(eval.value(), ref.value()); -} - -TEST(Simplify, ConstantFoldMinMax) { - ExprHandle a(12.0f); - ExprHandle b(15.0f); - ExprHandle c(17.0f); - - // x = max(12, min(15, 17)). - ExprHandle minHandle = Min::make(b, c, true); - ExprHandle fn = Max::make(a, minHandle, false); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 15.f); -} - -TEST(Simplify, ConstantFoldIntrinsics) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle powHandle = Intrinsics::make(kPow, a, b); - ExprHandle sinHandle = Intrinsics::make(kSin, powHandle); - ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); - ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); - ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); - ExprHandle fn = Intrinsics::make(kAbs, rndHandle); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - SimpleIRExprEval ref(fn); - - ASSERT_EQ(eval.value(), ref.value()); -} - -TEST(Simplify, ConstantFoldCastToBool) { - ExprHandle f = Cast::make(kBool, IntImm::make(0)); - ExprHandle newF = IRSimplifier::simplify(f); - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), false); -} - -TEST(Simplify, ConstantFoldWithVar) { - { - VarHandle x("x", kInt); - ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); - - ExprHandle newF = IRSimplifier::simplify(body); - MulPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_NE(to(root->lhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3)); - ASSERT_EQ(eval.value(), 3 * (2 + 4)); - } - - { - VarHandle x("x", kFloat); - ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); - - ExprHandle newF = IRSimplifier::simplify(body); - MulPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_NE(to(root->rhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 3 * (2 + 4)); - } -} - -TEST(Simplify, ConditionalSelectFoldSimple) { - ExprHandle a(3.0f); - ExprHandle b(4.0f); - ExprHandle c(3.0f); - { - ExprHandle f = (a > b); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } - { - ExprHandle f = (a < b); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a == c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a != c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, ConditionalSelectFoldTwoLayer) { - ExprHandle a(3.0f); - ExprHandle b(2.0f); - ExprHandle c(2.0f); - ExprHandle d(1.0f); - { - ExprHandle f = (a + b < c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } - { - ExprHandle f = (a + b > c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a + d == b + c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a + d != b + c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, ConditionalSelectFoldWithVar) { - VarHandle x("x", kFloat); - ExprHandle f = x < 4.f; - - ExprHandle newF = IRSimplifier::simplify(f); - IntImmPtr folded = newF.AsNode(); - ASSERT_EQ(folded, nullptr); - - { - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 1); - } - { - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(5.f)); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, UnFoldableExpr) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); - - ExprHandle newF = IRSimplifier::simplify(body); - AddPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_EQ(to(root->lhs()), nullptr); - ASSERT_EQ(to(root->rhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - eval.bindVar(y, ExprHandle(2.f)); - ASSERT_EQ(eval.value(), 9 + 10); -} - -TEST(Simplify, HashSimple) { - VarHandle x("x", kFloat); - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle f = a + b * x; - - HashProvider hasher; - - auto hash_x = hasher.hash(x.node()); - auto hash_a = hasher.hash(a.node()); - auto hash_f = hasher.hash(f.node()); - - ASSERT_NE(hash_x, (size_t)0); - ASSERT_NE(hash_a, (size_t)0); - ASSERT_NE(hash_f, (size_t)0); - ASSERT_NE(hash_x, hash_a); - ASSERT_NE(hash_x, hash_f); - ASSERT_NE(hash_a, hash_f); -} - -TEST(Simplify, HashEquivalence) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle f = (x * y) + (x * y); - - AddPtr root = f.AsNode(); - ASSERT_NE(root, nullptr); - - HashProvider hasher; - auto hash_f = hasher.hash(f.node()); - auto hash_l = hasher.hash(root->lhs()); - auto hash_r = hasher.hash(root->rhs()); - - // Root not equal to either branch. - ASSERT_NE(hash_f, hash_l); - ASSERT_NE(hash_f, hash_r); - // but branches are equal. - ASSERT_EQ(hash_l, hash_r); - - // Still equivalent if separate. - ExprHandle a(2); - ExprHandle f2 = x + a / y; - ExprHandle b(2); - ExprHandle f3 = x + b / y; - ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node())); - - // Not equivalent if different vars (even with same name). - VarHandle z("x", kFloat); - ExprHandle f4 = z + b / y; - ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node())); - - // Intrinsics sanity check. - ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x); - ASSERT_NE(hasher.hash(f5.node()), (size_t)0); -} - -TEST(Simplify, HashEquivalenceRand) { - ExprHandle f = - Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); - - AddPtr root = f.AsNode(); - ASSERT_NE(root, nullptr); - - HashProvider hasher; - auto hash_f = hasher.hash(f.node()); - auto hash_l = hasher.hash(root->lhs()); - auto hash_r = hasher.hash(root->rhs()); - - // Root not equal to either branch. - ASSERT_NE(hash_f, hash_l); - ASSERT_NE(hash_f, hash_r); - // and branches are NOT equal. - ASSERT_NE(hash_l, hash_r); -} - -TEST(Simplify, HashEquivalenceAfterFolding) { - VarHandle x("x", kFloat); - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(5.0f); - - ExprHandle f1 = ((a + b) * x); - ExprHandle f2 = (c * x); - - HashProvider hasher; - auto hash_l = hasher.hash(f1.node()); - auto hash_r = hasher.hash(f2.node()); - - // Root not equal to either branch, and branches not equal. - ASSERT_NE(hash_l, hash_r); - - ExprHandle ff1 = IRSimplifier::simplify(f1); - ExprHandle ff2 = IRSimplifier::simplify(f2); - - auto hash_l_n = hasher.hash(ff1.node()); - auto hash_r_n = hasher.hash(ff2.node()); - // but branches are now equal. - ASSERT_EQ(hash_l_n, hash_r_n); -} - -TEST(Simplify, HashDifferenceTypes) { - HashProvider hasher; - std::vector immediates; - - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - // NOLINTNEXTLINE(modernize-use-bool-literals) - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - - // Immediates of different types are not equal. - for (unsigned int i = 0; i < immediates.size(); ++i) { - for (unsigned int j = i + 1; j < immediates.size(); ++j) { - ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j])); - } - } - - // But coerced immediates are if they are the same type: - ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1); - ExprHandle f2 = Cast::make(kFloat, IntImm::make(3)); - - ExprHandle ff1 = IRSimplifier::simplify(f1); - ExprHandle ff2 = IRSimplifier::simplify(f2); - - ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); -} - -TEST(Simplify, HashLargeExpression) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto memcpy_stmt = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - CompareSelect::make( - Load::make(a, {i}), - Load::make(b, {i}), - CompareSelectOperation::kEQ))); - - BufHandle d("D", {1}, kInt); - BufHandle e("E", {1}, kInt); - auto store_ramp_stmt = Store::make( - e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)})); - - auto if_stmt = Cond::make( - CompareSelect::make( - Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE), - memcpy_stmt, - store_ramp_stmt); - - HashProvider hasher; - auto hash_r = hasher.hash(if_stmt); - // We should not have to do any more work. - ASSERT_TRUE(hasher.cachedHash(memcpy_stmt)); - auto hash_t = hasher.hash(memcpy_stmt); - ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt)); - auto hash_f = hasher.hash(store_ramp_stmt); - - // Root not equal to either branch, and branches not equal. - ASSERT_NE(hash_r, hash_t); - ASSERT_NE(hash_r, hash_f); - ASSERT_NE(hash_t, hash_f); -} - -TEST(Simplify, HashForLoopOptions) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto for_stmt = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - CompareSelect::make( - Load::make(a, {i}), - Load::make(b, {i}), - CompareSelectOperation::kEQ))); - - HashProvider hasher; - auto hash_before = hasher.hash(for_stmt); - hasher.clearCache(); - - for_stmt->set_gpu_block_index(LoopOptions::IDX_X); - auto hash_block_idx = hasher.hash(for_stmt); - hasher.clearCache(); - - ASSERT_NE(hash_before, hash_block_idx); - - for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET); - auto hash_reset = hasher.hash(for_stmt); - hasher.clearCache(); - - ASSERT_EQ(hash_before, hash_reset); - for_stmt->set_gpu_thread_index(LoopOptions::IDX_X); - auto hash_thread_idx = hasher.hash(for_stmt); - - ASSERT_NE(hash_before, hash_thread_idx); - ASSERT_NE(hash_block_idx, hash_thread_idx); -} - -/// (2 + x) + 4 => x + 6 -TEST(Simplify, SimplifyAdd) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle m("m", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle n("n", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle n_1("n_1", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); - - ExprHandle simplified = IRSimplifier::simplify(body); - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - VarPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->name_hint(), "x"); - IntImmPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->value(), 6.f); -} - -/// (2 - x) - 4 => -2 - x -TEST(Simplify, SimplifySub) { - VarHandle x("x", kInt); - ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); - - ExprHandle simplified = IRSimplifier::simplify(body); - SubPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - IntImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), -2.f); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// 2 * (1 - x) - 4 => 2 * (-3 - x) -TEST(Simplify, SimplifyMultiLayer) { - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_IMM_WITH_VAL(Int, sub->lhs(), -3); - IS_VAR_WITH_NAME(sub->rhs(), "x"); -} - -/// 2 * (3 * x) - (x * 4) => 2 * x -TEST(Simplify, SimplifyMultiTerm) { - VarHandle x("x", kInt); - ExprHandle body = - (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); - - ExprHandle simplified = IRSimplifier::simplify(body); - MulPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - IntImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), 2); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// 2 * (3 * (long)x) - (x * 4) => 2 * x -TEST(Simplify, SimplifyCasts) { - VarHandle x("x", kLong); - ExprHandle body = - (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); - - ExprHandle simplified = IRSimplifier::simplify(body); - MulPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - LongImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), 2); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// (x + 0) * 1 => x -TEST(Simplify, SimplifyEliminatesNoOps) { - VarHandle x("x", kInt); - ExprHandle body = (x + ExprHandle(0)) * 1; - - ExprHandle simplified = IRSimplifier::simplify(body); - VarPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_EQ(root->name_hint(), "x"); -} - -/// Cannot simplify this. -TEST(Simplify, SimplifyMultiVar) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = x * 24 + y * 34; - - ExprHandle simplified = IRSimplifier::simplify(body); - - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - MulPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - VarPtr varX = to(lhs->rhs()); - ASSERT_NE(varX, nullptr); - ASSERT_EQ(varX->name_hint(), "x"); - MulPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - VarPtr varY = to(rhs->rhs()); - ASSERT_NE(varY, nullptr); - ASSERT_EQ(varY->name_hint(), "y"); -} - -// x + 2 + y => x + y + 2 -TEST(Simplify, DISABLED_SimplifyReorderings) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = x + 2 + y; - ExprHandle simplified = IRSimplifier::simplify(body); - - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - - IS_NODE_WITH_NAME(Add, root->lhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - IS_IMM_WITH_VAL(Int, root->rhs(), 2); -} - -/// y + x * 0 => y -TEST(Simplify, SimplifyEliminatesVar) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = y + x * ExprHandle(0); - - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); -} - -TEST(Simplify, SimplifyAdds) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) + (x + y) => 2 * (x + y) - ExprHandle body = (x + y) + (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Add, root->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // (x * y) + (x * y) => 2 * (x * y) - ExprHandle body = (x * y) + (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Mul, root->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) + (x - y) => 2 * (x - y) - ExprHandle body = (x - y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // (x + x + x + x) => 4 * x - ExprHandle body = (x + x + x + x); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 4); - IS_VAR_WITH_NAME(root->rhs(), "x"); - } - - { - // (x + 0) => x. - ExprHandle body = x + 0; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x + 0.f) => float(x). - ExprHandle body = x + 0.f; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } -} - -TEST(Simplify, SimplifyMuls) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) * (x + y) => (x + y) * (x + y) - // We don't attempt to simplify multiplication of polynomials since the - // result is only very rarely more efficient. - ExprHandle body = (x + y) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x * y * x * y => x * x * y * y - // These get reordered only. - ExprHandle body = x * y * x * y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); - IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); - IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); - IS_VAR_WITH_NAME(mul1->rhs(), "y"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - IS_VAR_WITH_NAME(mul3->lhs(), "x"); - IS_VAR_WITH_NAME(mul3->rhs(), "x"); - } - - { - // 1 * (x * 1) => x - // Ones cancel cleanly. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // 1.f * (x * 1.f) => x - // Even float ones cancel cleanly, but carry their type. - ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // 1 * (x * 1.f) => x - // One float is enough to cast the expr. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // 1 * (x * 0) => 0 - // Zeroes are eliminated. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // 1 * (x * 0) => 0 - // But not for Float since nan * 0 = nan. - ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Cast, mul->lhs(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0); - } - - { - // (x - y) * (x - y) => (x - y) * (x - y) - // As with Add we don't attempt simplification of this. - ExprHandle body = (x - y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // (x + y) * (x - y) => (x + y) * (x - y) - // Don't simplify with different ops on each side. - ExprHandle body = (x + y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with no scalar, poly with non-identity scalar. - // x * (y + 1) => x + x * y - ExprHandle body = x * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with identity scalar, poly with non-identity scalar. - // (x * 1) * (y + 1) => x + x * y - ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with non-identity scalar, poly with non-identity scalar. - // (x * 2) * (y + 1) => 2 * (x + x * y) - ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with non-identity scalar, poly with identity scalar. - // (x * 2) * (y + 0) => 2 * (x * y) - ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with identity scalar, poly with identity scalar. - // (x * 1) * (y + 0) => x * y - ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with no scalar, poly with identity scalar. - // x * (y + 0) => x * y - ExprHandle body = x * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } -} - -// Sub an expr from itself will result in zero. -TEST(Simplify, SimplifySubs) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) - (x + y) => 0 - ExprHandle body = (x + y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x * y) - (x * y) => 0 - ExprHandle body = (x * y) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x - y) - (x - y) => 0 - ExprHandle body = (x - y) - (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x + y) - 2 * (x + y) => -1 * x - y - ExprHandle body = (x + y) - ExprHandle(2) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // (x + y) - y => x - ExprHandle body = (x + y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0) => x. - ExprHandle body = x - 0; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0.f) => x. - // Simple enough to cancel in float. - ExprHandle body = x - ExprHandle(0.f); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // (x - (float)(y - y)) => x. - ExprHandle body = x - Cast::make(kFloat, y - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // (x - y) - y => x - 2 * y - ExprHandle body = (x - y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // 2 * x - x => x - ExprHandle body = (ExprHandle(2) * x) - x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // x - 2 * x = -1 * x - // We don't have a unary negate, but this could be 0 -x I guess? - ExprHandle body = x - (ExprHandle(2) * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // (x + y + 5) * (x - x) => 0 - // Cancelling out one side of Mul cancels both. - ExprHandle body = (x + y + 5) * (x - x); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Cancel out opaque modulus. - ExprHandle body = (x % y + 2) - (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // Cancel out opaque modulus with a bit more going on. - ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // Sub where result is negative. - ExprHandle body = x - (x + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), -1); - } - - { - // Sub where result is positive due to negative scalar on RHS. - ExprHandle body = x - (x - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 1); - } - - { - // Term - Polynomial sub where RHS must be negated. - ExprHandle body = (x * 2) - (x * 2 + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), -1); - } - - { - // Term - Polynomial sub where the result is a Term. - ExprHandle body = (y * x * 2) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Term - Polynomial sub where the result is a Polynomial. - ExprHandle body = (x * 2) - (x + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_IMM_WITH_VAL(Int, sub->rhs(), 1); - } -} - -TEST(Simplify, SimplifyDiv) { - VarHandle x("x", kInt); - - { - ExprHandle body = ExprHandle(0) / x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - ExprHandle body = x / 1; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } -} - -TEST(Simplify, SimplifyDivWithLoopContext0) { - // Stmt to simplify: - // for (int i = 0; i < 100; i++) { - // A[i] = i / 100; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {100}, kInt); - auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext1) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext2) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i + 25) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext3) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) / (-6); - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = -4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext4) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i - 5) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = 0; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext5) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) / 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NEXT: A[i, j] = j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext6) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (int j = -1; j < 9; j++) { - // A[i, j+1] = (i + 6*j) / 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext7) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) / (-6); - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = -j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext0) { - // Stmt to simplify: - // for (const auto i : c10::irange(100)) { - // A[i] = i % 100; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {100}, kInt); - auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext1) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext2) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i + 25) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i + 1; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext3) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) % (-6); - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext4) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i - 5) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = i - 5; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext5) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) % 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NEXT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext6) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (int j = -1; j < 9; j++) { - // A[i, j+1] = (i + 6*j) % 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext7) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) % (-6); - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyMod) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Constant folding works. - ExprHandle body = ExprHandle(10) % 8; - ExprHandle simplified = IRSimplifier::simplify(body); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // x % x => 0 - ExprHandle body = x % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // 0 % x => 0 - ExprHandle body = ExprHandle(0) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // x % 1 => 0 - ExprHandle body = x % 1; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Doesn't change unknown mods. - // x % y => x % y - ExprHandle body = x % y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); - } - - { - // don't touch if RHS is unknown. - // 4 % x => 4 % x - ExprHandle body = ExprHandle(4) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_IMM_WITH_VAL(Int, mod->lhs(), 4); - IS_VAR_WITH_NAME(mod->rhs(), "x"); - } - - { - // don't touch if LHS is unknown. - // x % 4 => x % 4 - ExprHandle body = x % 4; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 4); - } - - { - // if LHS is a multiple of RHS, mod is zero. - // 2 * x % x => 0 - ExprHandle body = (x * 2) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true even if the multiple is not constant. - // x * y % x => 0 - ExprHandle body = (x * y) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true with multiple unknown values in LHS. - // x * y * z % x => 0 - ExprHandle body = (x * y * z) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true if the denom is compound. - // x * y * z % y * z => 0 - ExprHandle body = (x * y * z) % (y * z); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Sanity check true with scalars that are multiples. - // 12 * x % 4 => 0 - ExprHandle body = (x * 12) % 4; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Sanity check not true if the smaller scalar is on LHS. - // 4 * x % 12 => 4 * x % 12 - ExprHandle body = (x * 4) % 12; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 4); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 12); - } - - { - // Both scalar and symbolic in multiple. - // (6 * x * y) % (3 * x * y) => 0 - ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } -} - -// Test that mixing ops together simplifies as expected. -TEST(Simplify, SimplifyMultiOp) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x * y) + (x - y) => (x + x * y) - y - ExprHandle body = (x * y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // (x + y) - x * y => (x + y) - x * y - ExprHandle body = (x + y) - x * y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) - (x + y) => -2 * y - ExprHandle body = (x - y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - 0) + (x * 1) - (x + 0) => x - ExprHandle body = (x - 0) + (x * 1) - (x + 0); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x) - // Even in Float simple terms cancel out, but the variable ones cannot. - ExprHandle body = - (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Cast, add->lhs(), cast1); - IS_VAR_WITH_NAME(cast1->src_value(), "x"); - IS_NODE_WITH_NAME(Cast, add->rhs(), cast2); - IS_VAR_WITH_NAME(cast2->src_value(), "x"); - IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3); - IS_VAR_WITH_NAME(cast3->src_value(), "x"); - } -} - -// Test that chaining many ops together works as expected. -TEST(Simplify, SimplifyManyOps) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x - ExprHandle body = x + y + x + x + y + y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); - } - - { - // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y - ExprHandle body = x - y + x + x - y - y + x - y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x + y + x - x - y - y + x + y + x = 3 * x - ExprHandle body = x + y + x - x - y - y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 3); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } -} - -TEST(Simplify, SimplifyFactorization) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (2 * x) + (2 * y) => 2 * (x + y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Factorization when scalars have common divider. - // (2 * x) + (4 * y) => 2 * (2 * y + x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Factorization attempt without a common divider. - // (2 * x) + (5 * y) => (5 * y) + (2 * x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Factorization after merging. - // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + - (ExprHandle(8) * x + ExprHandle(6) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 10); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Factorization with common divider but different signs. - // (2 * x) + (-4 * y) => 2 * (x - 2 * y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Factorization with all negative numbers. - // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y) - ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), -1); - IS_VAR_WITH_NAME(mul2->rhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3); - IS_IMM_WITH_VAL(Int, mul3->lhs(), 2); - IS_VAR_WITH_NAME(mul3->rhs(), "y"); - } - - { - // The following test ensures that there in no infinite recursion during - // factorization when negative numbers are involved. - VarHandle a("a", kInt); - VarHandle b("b", kInt); - VarHandle c("c", kInt); - VarHandle d("d", kInt); - VarHandle e("e", kInt); - VarHandle f("f", kInt); - VarHandle g("g", kInt); - VarHandle h("h", kInt); - - ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 + - f * 32 + g * (-1024) + h * (-32); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR( - simplified, - "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h"); - } -} - -// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) -TEST(Simplify, SimplifyFactorizeUneven) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = - (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Add, root->rhs(), add1); - IS_NODE_WITH_NAME(Add, add1->lhs(), add2); - - IS_VAR_WITH_NAME(add2->lhs(), "y"); - IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul); - IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); - - IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); - IS_VAR_WITH_NAME(xmul->rhs(), "x"); - - IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); - IS_VAR_WITH_NAME(zmul->rhs(), "z"); -} - -// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) -// This is kind of a placeholder test for variable factorization. -TEST(Simplify, SimplifyDeeperTerms) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); - IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); - IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); -} - -// Tests the difference between two less trivial expressions. -// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 -TEST(Simplify, SimplifyDeeperDifference) { - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); -} - -// Test constant folding into the difference between expressions. -// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 -TEST(Simplify, SimplifyFoldComplexDifference) { - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (IntImm::make(2) + - (Cast::make( - kChar, - (m * (ExprHandle(1) * n_1) + (n + 1)) - - (m * (ExprHandle(1) * n_1) + n)))); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 3); -} - -TEST(Simplify, SimplifyIfComponents) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make( - ((ExprHandle(5) - ExprHandle(4)) * x) > y, - ExprHandle(2) * x - x, - ExprHandle(2) * y - y); - - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); - - IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); - ASSERT_EQ(cmp->compare_select_op(), kGT); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_VAR_WITH_NAME(cmp->rhs(), "y"); - - IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); - IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); -} - -TEST(Simplify, SimplifyOpaqueTerms) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // 2 * x/y * y - x/y * y => x/y * y - ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Div, mul->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // x%y - (x%y - 1) => 1 - ExprHandle body = (x % y) - ((x % y) - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); - } -} - -TEST(Simplify, SimplifySymbolicMinMax) { - { - // Minimum with constant difference between terms. - VarHandle x("x", kInt); - ExprHandle body = Min::make(x + 3, x + 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 3); - } - - { - // Maximum with constant difference between terms. - VarHandle x("x", kInt); - ExprHandle body = Max::make(x + 3, x + 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 7); - } - - { - // Can't simplify multiples because of signedness of variable component. - // TODO: maybe we could for unsigned types? - VarHandle x("x", kInt); - ExprHandle body = Max::make(x * 3, x * 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE(Max, simplified.node()); - } -} - -TEST(Simplify, SimplifyNestedMax) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Max(x + y, x + y) => x + y - ExprHandle body = Max::make(x + y, x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); - } - - { - // Max(x + y, Max(x + y, z)) => Max(x + y, z) - ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(x + y, Max(z, x + y)) => Max(x + y, z) - ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(x + y, z), x + y) => Max(x + y, z) - ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(z, x + y), x + y) => Max(x + y, z) - ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(x, y), x) => Max(Max(x, y), x) - // Nested Max ops with different propagate_nans should not be simplified. - ExprHandle body = Max::make(Max::make(x, y, true), x, false); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y"); - ASSERT_TRUE(max1->propagate_nans()); - IS_VAR_WITH_NAME(max->rhs(), "x"); - ASSERT_FALSE(max->propagate_nans()); - } - - { - // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(x, y, true), Min::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(x, y, true), Min::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) - // When all the ops in the pattern do not have the same propagate_nans, - // it should not be simplified. - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(z, x, false), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y"); - ASSERT_TRUE(min1->propagate_nans()); - IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z"); - ASSERT_FALSE(min2->propagate_nans()); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(5, Max(x, 8)) => Max(x, 8) - ExprHandle body = Max::make(5, Max::make(x, 8, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(8, Max(x, 5)) => Max(x, 8) - ExprHandle body = Max::make(8, Max::make(x, 5, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(Max(x, 8), 5) => Max(x, 8) - ExprHandle body = Max::make(Max::make(x, 8, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(Max(x, 5), 8) => Max(x, 8) - ExprHandle body = Max::make(Max::make(x, 5, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8) - // Do not simplify when all the Max ops do not have the same - // propagate_nans. - ExprHandle body = Max::make( - Max::make(Max::make(Max::make(z, 5, true), y, false), x, true), - 8, - false); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)"); - } - - { - // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } -} - -TEST(Simplify, SimplifyNestedMin) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Min(x + y, x + y) => x + y - ExprHandle body = Min::make(x + y, x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); - } - - { - // Min(x + y, Min(x + y, z)) => Min(x + y, z) - ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(x + y, Min(z, x + y)) => Min(x + y, z) - ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(x + y, z), x + y) => Min(x + y, z) - ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(z, x + y), x + y) => Min(x + y, z) - ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(x, y), x) => Min(Min(x, y), x) - // Nested Min ops with different propagate_nans should not be simplified. - ExprHandle body = Min::make(Min::make(x, y, true), x, false); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y"); - ASSERT_TRUE(min2->propagate_nans()); - IS_VAR_WITH_NAME(min1->rhs(), "x"); - ASSERT_FALSE(min1->propagate_nans()); - } - - { - // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(x, y, true), Max::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(x, y, true), Max::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(y, x, true), Max::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x) - ExprHandle body = - Min::make(Max::make(y, x, true), Max::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); - } - - { - // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z)) - // When all the ops in the pattern do not have the same propagate_nans, - // it should not be simplified. - ExprHandle body = - Min::make(Max::make(y, x, true), Max::make(z, x, false), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y"); - ASSERT_TRUE(max1->propagate_nans()); - IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z"); - ASSERT_FALSE(max2->propagate_nans()); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(5, Min(x, 8)) => Min(x, 8) - ExprHandle body = Min::make(5, Min::make(x, 8, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(8, Min(x, 5)) => Min(x, 8) - ExprHandle body = Min::make(8, Min::make(x, 5, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(Min(x, 8), 5) => Min(x, 8) - ExprHandle body = Min::make(Min::make(x, 8, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(Min(x, 5), 8) => Min(x, 8) - ExprHandle body = Min::make(Min::make(x, 5, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); - ASSERT_TRUE(min->propagate_nans()); - } - - { - // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) - // Do not simplify when all the Min ops do not have the same - // propagate_nans. - ExprHandle body = Min::make( - Min::make(Min::make(Min::make(z, 5, true), y, false), x, true), - 8, - false); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)"); - } - - { - // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } -} - -TEST(Simplify, SimplifyWontReorderFloat) { - { - // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) - // This is an expression we can simplify. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 9); - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). - // If the vars are floating point, ops are not associative and we can't - // reorder. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); - IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); - IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); - IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y"); - } - - { - // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). - // We will simplify subexprs if they dont reorder floating point ops. - VarHandle x("x", kDouble); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); - IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); - IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); - IS_VAR_WITH_NAME(rhsMul->rhs(), "y"); - } - - { - // Prevent reordering if FP propagated from dtypes. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - - ExprHandle(3) * (ExprHandle(3.f) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); - IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); - IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); - IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); - IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); - IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); - IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); - IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); - IS_VAR_WITH_NAME(yCast->src_value(), "y"); - } - - { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - // x%y - (x%y - 1) => x%y - (x%y - 1). - // We won't reorder opaque ops if they are FP. - ExprHandle body = (x % y) - ((x % y) - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); - IS_VAR_WITH_NAME(lhsMod->lhs(), "x"); - IS_VAR_WITH_NAME(lhsMod->rhs(), "y"); - - IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); - IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); - IS_VAR_WITH_NAME(rhsMod->lhs(), "x"); - IS_VAR_WITH_NAME(rhsMod->rhs(), "y"); - IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); - } -} - -TEST(Simplify, SimplifyRoundModPattern) { - { - // (x/y)*y + x%y => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Reverse order. - // x%y + (x/y)*y => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x % y) + ((x / y) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Non opaque denominator. - // (x / (4+y)) * (4+y)) + (x % (y + 4)) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) + - (x % (y + ExprHandle(4))); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Reverse order. - // (x % (y + 4)) + (x / (4+y)) * (4+y)) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x % (y + ExprHandle(4))) + - ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Opaque denominator. - // (x / (2/y)) * (2/y)) + (x % (2/y)) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) + - (x % (ExprHandle(2) / y)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Non opaque numerator - // ((2*x)/y * y) + ((2*x) % y) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Opaque numerator. - // ((x/2) / y * y) + (x/2 % y) => x / 2. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_IMM_WITH_VAL(Int, div->rhs(), 2); - } - - { - // Numerator and denominator. - // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) + - ((ExprHandle(2) * x) % (ExprHandle(2) * y)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Reverse order. - // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) + - (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Negated Subtraction of Round Mod. - // (x/y) * y - (0 - x%y) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Other terms are preserved. - // (x/y)*y + x%y + (y * x) => x + (y * x). - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) + (x % y) + (y * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Sanity checking we won't do the optimization on floats. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = ((x / y) * y) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); - IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv); - IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); - IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); - IS_VAR_WITH_NAME(roundMul->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); - } - - { - // Sanity check we won't do it if the mod term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = ((x / y) * y) + (x % z); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "(x / y) * y + x % z"); - } - - { - // Sanity check we won't do it if the div term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = (y * (x / z)) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "x % y + (x / z) * y"); - } - - { - // Sanity check we won't do it if the mul term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = ((x / y) * z) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "x % y + (x / y) * z"); - } -} - -TEST(Simplify, SimplifyRoundModPatternFactorization) { - { - // Full factorization. - // 2 * (x/y * y) + 2 * (x%y) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Partial Factorization. - // 32 * (x/8) + 4 * (x % 8) => 4 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 4); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Factorization requiring constant folding. - // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x. - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) + - (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 5); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - VarHandle x("x", kInt); - ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - VarHandle x("x", kInt); - ExprHandle body = (x / 10) * 0 + x % 5; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 5); - } -} - -TEST(Simplify, SimplifyRoundModPatternMultivar) { - { - // Multivar. - // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) + - (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Find the right var. - // (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8 - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = - (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Add, add->lhs(), add2); - IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod); - IS_VAR_WITH_NAME(xMod->lhs(), "x"); - IS_IMM_WITH_VAL(Int, xMod->rhs(), 8); - IS_VAR_WITH_NAME(add2->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), zMod); - IS_VAR_WITH_NAME(zMod->lhs(), "z"); - IS_IMM_WITH_VAL(Int, zMod->rhs(), 8); - } - - { - // Compound. - // (x + (z + 512 * y) % 16) + 16 * ((z + 512 * y) / 16) - // => (z + 512 * y) + x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "x + (z + 512 * y)"); - } -} - -TEST(Simplify, SimplifyModRoundModPattern) { - { - // t/7 % 9 * 7 + t % 7 => t%63 - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 7 + t % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63 - VarHandle t("t", kInt); - ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t/x % y * x + t % x => t%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (t / x % y) * x + t % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // k*t/x % y * x + k*t % x => k*t%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = (k * t / x % y) * x + k * t % x; - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "(k * t) % (x * y)"); - } - - { - // t/k/x % y * x + t/k % x => t/k%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = (t / k / x % y) * x + t / k % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Div, mod->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "t"); - IS_VAR_WITH_NAME(div->rhs(), "k"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Sanity checking we won't do the optimization on floats. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - VarHandle z("z", kFloat); - ExprHandle body = ((x / y % z) * y) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), mul); - IS_NODE_WITH_NAME(Mod, mul->lhs(), mod); - IS_NODE_WITH_NAME(Div, mod->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - IS_VAR_WITH_NAME(mod->rhs(), "z"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); - IS_VAR_WITH_NAME(mod2->lhs(), "x"); - IS_VAR_WITH_NAME(mod2->rhs(), "y"); - } -} - -TEST(Simplify, SimplifyModRoundModPatternFactorization) { - { - // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63) - VarHandle t("t", kInt); - ExprHandle body = - ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63) - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63 - VarHandle t("t", kInt); - ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Div, mod->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "t"); - IS_IMM_WITH_VAL(Int, div->rhs(), 2); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - } - - { - // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189 - VarHandle t("t", kInt); - ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 + - t % (ExprHandle(7) * ExprHandle(3)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 189); - } - - { - // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y)) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } -} - -TEST(Simplify, SimplifyModRoundModPatternMultivar) { - { - // t/7 % 9 * 7 + t % 7 + t => t % 63 + t - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t; - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "t % 63 + t"); - } - - { - // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72 - VarHandle t("t", kInt); - ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mod, add->lhs(), mod1); - IS_VAR_WITH_NAME(mod1->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod1->rhs(), 63); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); - IS_VAR_WITH_NAME(mod2->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod2->rhs(), 72); - } - - { - // k + t/x % y * x + t % x => k + t%(x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = k + (t / x % y) * x + t % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "k"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x - // => t%(x*y) + t/k % (x*y) - VarHandle t("t", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle k("k", kInt); - ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x; - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)"); - } - - { - // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63) - // => io_flat - VarHandle t("io_flat", kInt); - ExprHandle body = - ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } - - { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) + - // (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 + - // (i0_flat / (9 * 7) % 10) * 7 * 9 + - // (i0_flat / 7 % 9) * 7 + - // i0_flat % 7 => io_flat - VarHandle t("io_flat", kInt); - ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) + - (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 + - (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } - - { - // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) * - // (i0_flat / (m * n)) => io_flat - VarHandle t("io_flat", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } - - { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) + - // (i0_flat / (l * n * m) % k) * m * n * l + - // (i0_flat / (n * m) % l) * m * n + - // (i0_flat / m % n) * m + - // i0_flat % m => io_flat - VarHandle t("io_flat", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle l("l", kInt); - VarHandle k("k", kInt); - ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) + - (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n + - (t / m % n) * m + t % m; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "io_flat"); - } -} - -TEST(Simplify, SimplifyDivisionScalarFactorization) { - { - // Simple factorization of numerator and denominator. - // 8x / 4y => 2x / y. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * 8) / (y * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - } - - { - // Don't change anything if we can't factorize. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * 7) / (y * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 7); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Don't reorder floats. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = (x * 8) / (y * 4); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f); - IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "y"); - IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f); - } - - { - // Sanity check we do nothing if there are only scalar parts. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * 1) / (y * 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - } - - { - // Can factorize amounts of variables. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x + x + x + x) / (y + y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - } -} - -TEST(Simplify, SimplifyConstantBranches) { - { - // If the condition is constant true then take the true_value. - // 1 ? x : y => x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle t(1); - ExprHandle body = IfThenElse::make(t, x, y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // If the condition is constant false then take the false_value. - // 0 ? x : y => y - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle t(0); - ExprHandle body = IfThenElse::make(t, x, y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); - } - - { - // condition is simplified before checking. - // (x-x) ? x : y => y - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make(x - x, x, y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); - } - - { - // If both branches are the same then don't do the condition. - // y ? x : x => x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make(y, x, x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // If both branches simplify to the same thing it still works. - // y ? (x + x) : (2 * x) => x - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } -} - -TEST(Simplify, SimplifyConstantCond) { - { - // If the condition is constant true then take the true_value. - // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - ExprHandle condition(1); - StmtPtr true_val = Store::make(a, {0}, 1); - StmtPtr false_val = Store::make(b, {0}, 1); - - CondPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "A"); - } - - { - // If the condition is constant false then take the false_value. - // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1 - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - ExprHandle condition(0); - StmtPtr true_val = Store::make(a, {0}, 1); - StmtPtr false_val = Store::make(b, {0}, 1); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "B"); - } - - { - // condition is simplified before checking. - // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1 - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - ExprHandle condition(x - x); - StmtPtr true_val = Store::make(a, {0}, 1); - StmtPtr false_val = Store::make(b, {0}, 1); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "B"); - } - - { - // If both branches are the same then don't do the condition. - // x ? A[0] = x : A[0] = x => A[0] = x - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - ExprHandle condition(x - x); - StmtPtr true_val = Store::make(a, {0}, x); - StmtPtr false_val = Store::make(a, {0}, x); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "A"); - } - - { - // If both branches simplify to the same thing it still works. - // x ? (x + x) : (2 * x) => x - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - ExprHandle condition(x - x); - StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x); - StmtPtr false_val = Store::make(a, {0}, x + x); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "A"); - } - - { - // But not if they dont - // x ? x : (2 * x) => x ? x : (2 * x) - VarHandle x("x", kInt); - BufHandle a("A", {1}, kInt); - ExprHandle condition(x); - StmtPtr true_val = Store::make(a, {0}, x); - StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x); - - StmtPtr body = alloc(condition.node(), true_val, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block, nullptr); - } - - { - StmtPtr cond = alloc( - ExprHandle(false).node(), - alloc(std::vector({})), - nullptr); - StmtPtr simplified = IRSimplifier::simplify(cond); - ASSERT_EQ(simplified, nullptr); - } - - { - StmtPtr cond = alloc( - ExprHandle(true).node(), - nullptr, - alloc(std::vector({}))); - StmtPtr simplified = IRSimplifier::simplify(cond); - ASSERT_EQ(simplified, nullptr); - } -} - -TEST(Simplify, SimplifyEliminateEmptyCond) { - // If the branches are empty in different ways, eliminate. - { - VarHandle x("x", kInt); - ExprHandle condition(x); - StmtPtr true_val = alloc(std::vector({})); - - StmtPtr body = alloc(condition.node(), true_val, nullptr); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_NE(block, nullptr); - ASSERT_EQ(block->nstmts(), 0); - } - - { - VarHandle x("x", kInt); - ExprHandle condition(x); - StmtPtr false_val = alloc(std::vector({})); - - StmtPtr body = alloc(condition.node(), nullptr, false_val); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_NE(block, nullptr); - ASSERT_EQ(block->nstmts(), 0); - } -} - -TEST(Simplify, SimplifyConstantComparisons) { - auto ComparisonTest = - [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { - ExprHandle body = CompareSelect::make(a, b, op); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), result); - }; - - // Equals. - ComparisonTest(2, 2, kEQ, 1); - ComparisonTest(1, 2, kEQ, 0); - ComparisonTest(2, 1, kEQ, 0); - - // Greater than. - ComparisonTest(2, 2, kGT, 0); - ComparisonTest(1, 2, kGT, 0); - ComparisonTest(2, 1, kGT, 1); - - // Greater or Equal. - ComparisonTest(2, 2, kGE, 1); - ComparisonTest(1, 2, kGE, 0); - ComparisonTest(2, 1, kGE, 1); - - // Less Than. - ComparisonTest(2, 2, kLT, 0); - ComparisonTest(1, 2, kLT, 1); - ComparisonTest(2, 1, kLT, 0); - - // Less or Equal. - ComparisonTest(2, 2, kLE, 1); - ComparisonTest(1, 2, kLE, 1); - ComparisonTest(2, 1, kLE, 0); - - // Not equal. - ComparisonTest(2, 2, kNE, 0); - ComparisonTest(1, 2, kNE, 1); - ComparisonTest(2, 1, kNE, 1); - - // With specified results: - ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 42); -} - -TEST(Simplify, SimplifySymbolicComparisons) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); }; - auto TookFalseBranch = [](ExprHandle a) { - IS_IMM_WITH_VAL(Int, a.node(), 0); - }; - - // EQ - - // x == x => 1 - ExprHandle body = CompareSelect::make(x, x, kEQ); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x == x+1 => 0 - body = CompareSelect::make(x, x + 1, kEQ); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x == x * 2 cannot simplify since we don't know x is nonzero. - body = CompareSelect::make(x, x * 2, kEQ); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); - - // x == x * 1 => 1 - body = CompareSelect::make(x, x * 1, kEQ); - TookTrueBranch(IRSimplifier::simplify(body)); - - { - // x == y => x == y - body = CompareSelect::make(x, y, kEQ); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); - ASSERT_EQ(cmp->compare_select_op(), kEQ); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_VAR_WITH_NAME(cmp->rhs(), "y"); - } - - { - // x == 5 => x == 5 - body = CompareSelect::make(x, 5, kEQ); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); - ASSERT_EQ(cmp->compare_select_op(), kEQ); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_IMM_WITH_VAL(Int, cmp->rhs(), 5); - } - - // GT - - // x+1 > x => 1 - body = CompareSelect::make(x + 1, x, kGT); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x > x + 1 => 0 - body = CompareSelect::make(x, x + 1, kGT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x > x - 1 => 1 - body = CompareSelect::make(x, x - 1, kGT); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x - 1 > x => 0 - body = CompareSelect::make(x - 1, x, kGT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x > x => 0 - body = CompareSelect::make(x, x, kGT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x * 2 > x => x * 2 > x - // since we don't know the sign of x. - body = CompareSelect::make(x * 2, x, kGT); - IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); - - // GE - - // x+1 >= x => 1 - body = CompareSelect::make(x + 1, x, kGE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x >= x + 1 => 0 - body = CompareSelect::make(x, x + 1, kGE); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x >= x => 1 - body = CompareSelect::make(x, x, kGE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x * 2 >= x => x * 2 >= x - // since we don't know the sign of x. - body = CompareSelect::make(x * 2, x, kGE); - IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); - - // LT - - // x+1 < x => 0 - body = CompareSelect::make(x + 1, x, kLT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x < x + 1 => 1 - body = CompareSelect::make(x, x + 1, kLT); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x < x => 0 - body = CompareSelect::make(x, x, kLT); - TookFalseBranch(IRSimplifier::simplify(body)); - - // LE - - // x+1 <= x => 0 - body = CompareSelect::make(x + 1, x, kLE); - TookFalseBranch(IRSimplifier::simplify(body)); - - // x <= x + 1 => 1 - body = CompareSelect::make(x, x + 1, kLE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x <= x => 1 - body = CompareSelect::make(x, x, kLE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // NE - - // x+1 != x => 1 - body = CompareSelect::make(x + 1, x, kNE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x != x + 1 => 1 - body = CompareSelect::make(x, x + 1, kNE); - TookTrueBranch(IRSimplifier::simplify(body)); - - // x != x => 0 - body = CompareSelect::make(x, x, kNE); - TookFalseBranch(IRSimplifier::simplify(body)); -} - -TEST(Simplify, SimplifyEliminateZeroLengthFor) { - { - // Will eliminate zero loop For. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // still works if start is not zero. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // works if both terms are variable. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // works if one term simplifies down. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - ASSERT_EQ(block->nstmts(), 0); - } - - { - // Sanity check does nothing if the condition is not met. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE(For, simplified); - } -} - -TEST(Simplify, SimplifyOneLoopFor) { - { - // Will remove the loop if the body is run once. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // still works if start is not zero. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 2); - } - - { - // works if both terms are variable. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_VAR_WITH_NAME(store->flat_index(), "x"); - } - - { - // works if one term simplifies down. - VarHandle x("x", kInt); - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = - For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // Sanity check does nothing if the condition is not met. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE(For, simplified); - } -} - -TEST(Simplify, SimplifyForWontLoseLoopOptions) { - { - // Sanity check does nothing if the condition is not met. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - LoopOptions options; - options.set_gpu_block_index(LoopOptions::IDX_W); - auto body = - For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, for_); - LoopOptions options2 = for_->loop_options(); - ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index()); - } -} - -TEST(Simplify, SimplifyMultilevelFor) { - { - // Multiple layers of For will be simplified out. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); - auto outer = For::make(j, 0, 1, body); - StmtPtr simplified = IRSimplifier::simplify(outer); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // Will maintain an outer loop if the inner loop is eliminated. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); - auto outer = For::make(j, 0, 2, body); - StmtPtr simplified = IRSimplifier::simplify(outer); - ForPtr for__ = static_to(simplified); - IS_NODE_WITH_NAME(For, for__, for_); - IS_VAR_WITH_NAME(for_->var(), "j"); - IS_IMM_WITH_VAL(Int, for_->start(), 0); - IS_IMM_WITH_VAL(Int, for_->stop(), 2); - BlockPtr block = to(for_->body()); - ASSERT_NE(block, nullptr); - IS_NODE_WITH_NAME(Store, block->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_IMM_WITH_VAL(Int, store->flat_index(), 0); - } - - { - // Will maintain inner loop if outer loops is eliminated. - BufHandle a("A", {4}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i}))); - auto outer = For::make(j, 0, 1, body); - StmtPtr simplified = IRSimplifier::simplify(outer); - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(For, block->front(), for_); - IS_VAR_WITH_NAME(for_->var(), "i"); - IS_IMM_WITH_VAL(Int, for_->start(), 0); - IS_IMM_WITH_VAL(Int, for_->stop(), 2); - IS_NODE_WITH_NAME(Store, for_->body()->front(), store); - IS_VAR_WITH_NAME(store->base_handle(), "C"); - IS_VAR_WITH_NAME(store->flat_index(), "i"); - } -} - -TEST(Simplify, SimplifyForCleansUp) { - { - BufHandle a("a", {1, 12, 1}, kFloat); - VarHandle x("x", kInt); - Tensor b = Compute( - "x", - {1, 12, 1}, - [](const VarHandle& i, const VarHandle& m, const VarHandle& n) { - return i + m + n; - }); - LoopNest l({b}); - l.prepareForCodegen(); - - StmtPtr body = LoopNest::sanitizeNames(l.root_stmt()); - StmtPtr simplified = IRSimplifier::simplify(body); - - BlockPtr block = to(simplified); - IS_NODE_WITH_NAME(For, block->front(), for_); - // for is over "m". - IS_VAR_WITH_NAME(for_->var(), "j"); - // x[m] = m; - IS_NODE_WITH_NAME(Store, for_->body()->front(), store); - IS_VAR_WITH_NAME(store->flat_index(), "j"); - IS_VAR_WITH_NAME(store->value(), "j"); - } -} - -TEST(Simplify, SimplifyEliminateEmptyFor) { - { - // Flatten many layers around an empty block to an empty block. - StmtPtr last = alloc(std::vector({})); - for ([[maybe_unused]] const auto i : c10::irange(11)) { - VarHandle loopVar("loopVar", kInt); - last = For::make(loopVar, 0, 10, last); - } - - StmtPtr simplified = IRSimplifier::simplify(last); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 0); - } -} - -TEST(Simplify, SimplifyFlattenBlock) { - { - // Flatten multiple blocks down to one. - // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } - BufHandle a("A", {1}, kInt); - StorePtr store1 = Store::make(a, {0}, 1); - StorePtr store2 = Store::make(a, {0}, 0); - - BlockPtr block1 = alloc(std::vector({store1, store2})); - BlockPtr block2 = alloc(std::vector({block1})); - - BlockPtr enclosing = alloc(std::vector({block2})); - StmtPtr simplified = IRSimplifier::simplify(enclosing); - - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - - IS_NODE_WITH_NAME(Store, block->front(), store1_); - IS_NODE_WITH_NAME(Store, block->back(), store2_); - - ASSERT_EQ(store1->value(), store1_->value()); - ASSERT_EQ(store2->value(), store2_->value()); - } - - { - // Flatten multiple sub blocks containing statements. - // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } - BufHandle a("A", {1}, kInt); - StorePtr store1 = Store::make(a, {0}, 1); - StorePtr store2 = Store::make(a, {0}, 0); - - BlockPtr block1 = alloc(std::vector({store1})); - BlockPtr block2 = alloc(std::vector({store2})); - - BlockPtr enclosing = alloc(std::vector({block1, block2})); - StmtPtr simplified = IRSimplifier::simplify(enclosing); - - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - - IS_NODE_WITH_NAME(Store, block->front(), store1_); - IS_NODE_WITH_NAME(Store, block->back(), store2_); - - ASSERT_EQ(store1->value(), store1_->value()); - ASSERT_EQ(store2->value(), store2_->value()); - } - - { - // Flatten sub blocks with different depths. - // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } - BufHandle a("A", {1}, kInt); - StorePtr store1 = Store::make(a, {0}, 1); - StorePtr store2 = Store::make(a, {0}, 0); - - BlockPtr block1 = alloc(std::vector({store2})); - BlockPtr block2 = alloc(std::vector({block1})); - - BlockPtr enclosing = alloc(std::vector({store1, block2})); - StmtPtr simplified = IRSimplifier::simplify(enclosing); - - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - - IS_NODE_WITH_NAME(Store, block->front(), store1_); - IS_NODE_WITH_NAME(Store, block->back(), store2_); - - ASSERT_EQ(store1->value(), store1_->value()); - ASSERT_EQ(store2->value(), store2_->value()); - } - - { - // Flatten many layers around an empty block to an empty block. - StmtPtr last = alloc(std::vector({})); - for ([[maybe_unused]] const auto i : c10::irange(11)) { - last = alloc(std::vector({last})); - } - - StmtPtr simplified = IRSimplifier::simplify(last); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 0); - } -} - -TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { - { - // Simple positive case. - BufHandle b("x", {0}, kInt); - - AllocatePtr alloc_ = Allocate::make(b); - FreePtr free_ = Free::make(b); - - BlockPtr block1 = alloc(std::vector({alloc_, free_})); - ASSERT_EQ(block1->nstmts(), 2); - - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 0); - } - - { - // Simple negative case. - BufHandle b("x", {2}, kInt); - - AllocatePtr alloc_ = Allocate::make(b); - FreePtr free_ = Free::make(b); - - BlockPtr block1 = alloc(std::vector({alloc_, free_})); - ASSERT_EQ(block1->nstmts(), 2); - - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 2); - } - - { - // Finds right Alloc/Free. - BufHandle b1("x", {0}, kInt); - BufHandle b2("y", {2}, kInt); - - AllocatePtr alloc1 = Allocate::make(b1); - AllocatePtr alloc2 = Allocate::make(b2); - FreePtr free2_ = Free::make(b2); - FreePtr free1_ = Free::make(b1); - - BlockPtr block1 = - alloc(std::vector({alloc1, alloc2, free2_, free1_})); - ASSERT_EQ(block1->nstmts(), 4); - - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 2); - IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc); - IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y"); - IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free); - ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var()); - } - - { - // Dynamic shape. - VarHandle z("z", kInt); - BufHandle b1("x", {0}, kInt); - BufHandle b2("y", {z}, kInt); - - AllocatePtr alloc1 = Allocate::make(b1); - AllocatePtr alloc2 = Allocate::make(b2); - FreePtr free2_ = Free::make(b2); - FreePtr free1_ = Free::make(b1); - - BlockPtr block1 = - alloc(std::vector({alloc1, alloc2, free2_, free1_})); - ASSERT_EQ(block1->nstmts(), 4); - StmtPtr simplified = IRSimplifier::simplify(block1); - IS_NODE_WITH_NAME(Block, simplified, block2); - ASSERT_EQ(block2->nstmts(), 2); - } -} - -TEST(Simplify, DontSimplifyRand) { - { - // rand() + rand() = rand() + rand() NOT 2 * rand(). - ExprHandle body = - Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_RAND(add->lhs()); - IS_RAND(add->rhs()); - } - - { - // rand() - rand() = rand() - rand() NOT 0. - ExprHandle body = - Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_RAND(sub->lhs()); - IS_RAND(sub->rhs()); - } - - { - // rand() * rand() = rand() * rand(). - ExprHandle body = - Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_RAND(mul->lhs()); - IS_RAND(mul->rhs()); - } -} - -TEST(Simplify, SimplifyReorderForCond) { - BufHandle a("A", {4}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - { - // for ( if ( ... ) ) => if ( for ( ... ) ). - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(c, {i}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } - - { - // Can't reorder if condition is dependent on the loop var. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make(i, 2, CompareSelectOperation::kEQ), - Store::make(c, {i}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } - - { - // Can't reorder if condition is dependent on a var that is modified inside - // the loop. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(c, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } - - { - // Condition based on buffer not referenced in body. Can reorder here. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(b, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } - - { - // Condition based on buffer read only in body. Can reorder here. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(a, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } - - { - // Condition depends on Let in the loop. Cannot reorder. - auto body = For::make( - i, - 0, - 4, - Block::make( - {Let::make(j, 3), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - nullptr)})); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Let, loop->body()->front(), let); - IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond); - } - - { - // Multi level Ifs where all conditions are distinct. Move BOTH Cond - // statements outside the loop. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(a, {0}), 10, CompareSelectOperation::kLT), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kEQ), - Store::make(c, {0}, Load::make(a, {i})), - nullptr), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(Cond, true_block->front(), cond2); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2); - IS_NODE_WITH_NAME(For, true_block2->front(), loop); - } - - { - // Multi level Ifs where the inner condition does depend on a loop var, - // reorder only the first Cond. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(a, {0}), 10, CompareSelectOperation::kLT), - Cond::make( - CompareSelect::make(i, 3, CompareSelectOperation::kEQ), - Store::make(c, {0}, Load::make(a, {i})), - nullptr), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - IS_NODE_WITH_NAME(Block, loop->body(), loop_body); - IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2); - } - - { - // Don't reorder if there's an else block of the Cond. - // We could, but is it much better? - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(c, {0}, Load::make(a, {i})), - Store::make(c, {0}, 0))); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } - - { - // Condition uses distinct region of Tensor. - // We could reorder here with better analysis, but we don't. Included for - // completeness. - auto body = For::make( - i, - 0, - 4, - Cond::make( - CompareSelect::make( - Load::make(c, {0}), 10, CompareSelectOperation::kLT), - Store::make(c, {1}, Load::make(a, {i})), - nullptr)); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(For, simplified, loop); - IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); - } -} - -TEST(Simplify, SimplifyFuseConditions) { - BufHandle a("A", {2}, kInt); - BufHandle b("B", {2}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - { - // Can fuse since the conditions are identical. - // if (A) { X }; if (A) { Y }; => if (A) { X; Y } - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can't fuse, conditions are not identical in lhs (i != j). - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - { - // Can't fuse, conditions are not identical in rhs (10 != 11). - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 11, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - - { - // Can't fuse, conditions are not identical in operation (LT vs GT). - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kGT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - - { - // Can't fuse, CompareSelect results are different. - // Actually we totally could if we normalized CompareSelect results, but - // TODO for later. - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - - IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); - IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt1->nstmts(), 1); - ASSERT_EQ(true_stmt2->nstmts(), 1); - - ASSERT_EQ(cond1->false_stmt(), nullptr); - ASSERT_EQ(cond2->false_stmt(), nullptr); - } - - { - // Can fuse with false stmt only. - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - nullptr, - Store::make(a, {0}, i)), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - nullptr, - Store::make(a, {1}, i))}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt); - ASSERT_EQ(false_stmt->nstmts(), 2); - ASSERT_EQ(cond->true_stmt(), nullptr); - } - - { - // Can fuse with both true and false stmt. - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - Store::make(b, {0}, i)), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - Store::make(b, {1}, i))}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); - ASSERT_EQ(false_stmt->nstmts(), 2); - } - - { - // Can fuse with mismatched true / false stmt existing - auto body = Block::make( - {Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - nullptr, - Store::make(b, {1}, i))}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 1); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); - ASSERT_EQ(false_stmt->nstmts(), 1); - } - - { - // Can fuse partial block contents, ie when there are non fused stmts before - // and after. - // before: - // if (j < 10) { A[0] = j; } - // if (i < 10) { A[0] = i; } - // if (i < 10) { A[1] = i; } - // if (i < 11) { A[1] = j; } - // - // after: - // - // if (j < 10) { A[0] = j; } - // if (i < 10) { - // A[0] = i; - // A[1] = i; - // } - // if (i < 11) { A[1] = j; } - - auto body = Block::make({ - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, j), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 11, CompareSelectOperation::kLT), - Store::make(a, {1}, j), - nullptr), - }); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - auto it = block->begin(); - it++; - IS_NODE_WITH_NAME(Cond, *it, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can fuse longer sequences of identical conditions. - auto body = Block::make({ - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, j), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, j), - nullptr), - }); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 4); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can't fuse through a non condition. - auto body = Block::make({ - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, j), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Store::make(b, {1}, i + j), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr), - Cond::make( - CompareSelect::make(i, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, j), - nullptr), - }); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2); - ASSERT_EQ(true_stmt2->nstmts(), 2); - ASSERT_EQ(cond2->false_stmt(), nullptr); - - auto it = block->begin(); - it++; - IS_NODE_WITH_NAME(Store, *it, middle); - } - - { - // Can fuse if the conditions simplify to the same thing. - auto body = Block::make( - {Cond::make( - CompareSelect::make( - i * 2, - ExprHandle(87) % ExprHandle(11), - CompareSelectOperation::kLT), - Store::make(a, {0}, i), - nullptr), - Cond::make( - CompareSelect::make( - i * 2, - ExprHandle(300) / ExprHandle(30), - CompareSelectOperation::kLT), - Store::make(a, {1}, i), - nullptr)}); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Can fuse non-CompareSelects. - // if (i) { X } if (i) { Y } => if (i) { X; Y } - auto body = Block::make( - {Cond::make(i, Store::make(a, {0}, i), nullptr), - Cond::make(i, Store::make(a, {1}, i), nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - IS_NODE_WITH_NAME(Cond, block->front(), cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); - ASSERT_EQ(true_stmt->nstmts(), 2); - ASSERT_EQ(cond->false_stmt(), nullptr); - } - - { - // Sanity check won't fuse different non-CompareSelects. - auto body = Block::make( - {Cond::make(i, Store::make(a, {0}, i), nullptr), - Cond::make(j, Store::make(a, {1}, i), nullptr)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Cond, block->front(), cond1); - IS_NODE_WITH_NAME(Cond, block->back(), cond2); - } - - { - // Sanity check constant condition elimination still occurs when merging is - // possible. - auto body = Block::make( - {Cond::make(1, Store::make(a, {0}, i), nullptr), - Cond::make(1, Store::make(a, {1}, i), nullptr)}); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 2); - IS_NODE_WITH_NAME(Store, block->front(), store1); - IS_NODE_WITH_NAME(Store, block->back(), store2); - } - - { - // Sanity check for-cond reordering occurs after fusing. - auto body = For::make( - i, - 0, - 4, - Block::make( - {Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {1}, Load::make(b, {0})), - nullptr), - Cond::make( - CompareSelect::make(j, 10, CompareSelectOperation::kLT), - Store::make(a, {2}, Load::make(b, {0})), - nullptr)})); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cond, simplified, cond); - IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); - IS_NODE_WITH_NAME(For, true_block->front(), loop); - } -} - -TEST(Simplify, SimplifySyncThreads) { - BufHandle a("A", {4}, kInt); - VarHandle i("i", kInt); - - { - // Merge two inner SyncThreads. - auto body = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {0}, 1), - alloc(), - alloc(), - Store::make(a, {1}, 0)}); - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - auto it = block->begin(); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - } - - { - // Eliminate outer SyncThreads. - auto body = Block::make( - {alloc(), Store::make(a, {1}, 0), alloc()}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - auto it = block->begin(); - IS_NODE(Store, *it); - } - - { - // Merge many inner SyncThreads. - auto body = Block::make( - {Store::make(a, {0}, 1), - alloc(), - alloc(), - alloc(), - alloc(), - alloc(), - Store::make(a, {1}, 0)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 3); - auto it = block->begin(); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - } - - { - // Merge multiple outer SyncThreads. - auto body = Block::make( - {alloc(), - alloc(), - Store::make(a, {1}, 0), - alloc(), - alloc(), - alloc(), - alloc()}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - auto it = block->begin(); - IS_NODE(Store, *it); - } - - { - // Merge multiple sections; - auto body = Block::make( - {Store::make(a, {0}, 1), - alloc(), - alloc(), - Store::make(a, {1}, 0), - Store::make(a, {2}, 0), - alloc(), - alloc(), - alloc(), - Store::make(a, {3}, 0)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 6); - auto it = block->begin(); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - } -} - -TEST(Simplify, SimplifyRampSubBroadcast) { - int num_lanes = 4; - ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); - ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); - ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); - RampPtr newRamp = simplified.AsNode(); - IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); - ASSERT_EQ(base->value(), 5); - IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); - ASSERT_EQ(stride->value(), 6); - ASSERT_EQ(newRamp->lanes(), num_lanes); -} - -TEST(Simplify, SimplifyBroadcastTermExpander) { - int num_lanes = 8; - ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); - ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); - ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes); - // NB: We need a term in the middle which isn't simplified to trigger the - // relevant path in TermExpander::mutate. The two bc1 terms are brought - // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. - ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); - BufHandle buf("buf", {num_lanes}, kInt); - // The result isn't fully simplified currently and thus would be brittle to - // match. Observe its value instead. - auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified); - SimpleIREvaluator eval(store, {buf}); - std::vector output(num_lanes); - eval(output); - for (const auto i : c10::irange(num_lanes)) { - ASSERT_EQ(output[i], 2); - } -} - -TEST(Simplify, CompareSelectLoopBounds) { - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - VarHandle m("m", kInt); - VarHandle var_N("var_N", kInt); - VarHandle var_M("var_M", kInt); - - auto test_case_fn = [](const VarHandle& n, - const BufHandle& b, - const ExprHandle& start, - const ExprHandle& stop, - const int& cmp_val, - const CompareSelectOperation& cmp_op, - const std::string& check_string) { - StmtPtr s = For::make( - n, - start, - stop, - b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - std::string target_string = "# CHECK: "; - target_string += check_string; - torch::jit::testing::FileCheck().run(target_string, oss.str()); - }; - - auto test_case_nest_loops_fn = [](const VarHandle& n, - const VarHandle& m, - const BufHandle& b, - const ExprHandle& n_start, - const ExprHandle& n_stop, - const ExprHandle& m_start, - const ExprHandle& m_stop, - const CompareSelectOperation& cmp_op, - const std::string& check_string) { - StmtPtr s = For::make( - m, - m_start, - m_stop, - b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op))); - StmtPtr root_s = For::make(n, n_start, n_stop, s); - root_s = IRSimplifier::simplify(root_s); - std::ostringstream oss; - oss << *root_s; - std::string target_string = "# CHECK: "; - target_string += check_string; - torch::jit::testing::FileCheck().run(target_string, oss.str()); - }; - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, 2)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, 2)) { - // b[1] = 0.f; - // } - test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kNE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kNE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - test_case_nest_loops_fn( - n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 31, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 31, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - test_case_nest_loops_fn( - n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n < m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kLT, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kLT, - "b[n, m] = n m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kGT, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kGT, - "b[n, m] = n>m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n > m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 1.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kGT, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kGT, - "b[n, m] = n>m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n >= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 31, - kGE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 31, - kGE, - "b[n, m] = n>=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n >= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 1.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kGE, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kGE, - "b[n, m] = n>=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n <= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kLE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kLE, - "b[n, m] = n<=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = (n <= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kLE, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kLE, - "b[n, m] = n<=m ? 0.f : 1.f;"); -} - -TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) { - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - StmtPtr s = For::make( - n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[n] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, IfThenCondAlwaysInLoopBounds) { - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f); - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - StmtPtr s = - For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[n] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) { - // This test mimics the unpadded region of a conv2d. We want to remove any - // conditional that is provably satisfied (or unsatisfied) by the entire loop - // range. - // Before: - // for (const auto i : c10::irange(1, 7)) { - // for (const auto j : c10::irange(1, 7)) { - // b[i, j] = IfThenElse( - // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f); - // After: - // for (const auto i : c10::irange(1, 7)) { - // for (const auto j : c10::irange(1, 7)) { - // b[i, j] = 1.f; - constexpr int N = 8; - BufHandle b("b", {N, N}, kFloat); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto csel = CompareSelect::make(i, 1, kLT); - csel = CompareSelect::make(j, 1, 1, csel, kLT); - csel = CompareSelect::make(i, N - 1, 1, csel, kGE); - csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); - s = For::make(j, 1, N - 1, s); - s = For::make(i, 1, N - 1, s); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[i, j] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, DISABLED_SimplifyLoopBounds) { - // This test mimics the padded region of a conv2d. We want to adjust the - // loop bounds such that the condition will be always met. Note that this - // could be solved by peeling, and applying the range-based conditional - // simplification in the previous tests. - // Before: - // for (const auto i : c10::irange(3)) { - // for (const auto j : c10::irange(3)) { - // b[i, j] = (b[i, j]) + (IfThenElse( - // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j])); - // After: - // for (const auto i : c10::irange(1, 3)) { - // for (const auto j : c10::irange(1, 3)) { - // b[i, j] = (b[i, j]) + 1.f; - constexpr int N = 8; - constexpr int K = 3; - BufHandle a("a", {N, N}, kFloat); - BufHandle b("b", {N, N}, kFloat); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto csel = CompareSelect::make(i, 1, kLT); - csel = CompareSelect::make(j, 1, 1, csel, kLT); - csel = CompareSelect::make(i, N - 1, 1, csel, kGE); - csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - StmtPtr s = b.store( - {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); - s = For::make(j, 0, K, s); - s = For::make(i, 0, K, s); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: for (const auto i : c10::irange(1, 3)) { -# CHECK: for (const auto j : c10::irange(1, 3)) { -# CHECK-NOT: IfThenElse -)IR", - oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp deleted file mode 100644 index 56535de914e43..0000000000000 --- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp +++ /dev/null @@ -1,402 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -struct WithCPUFuser { - WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { - overrideCanFuseOnCPU(val); - } - - ~WithCPUFuser() { - overrideCanFuseOnCPU(cpuFuserEnabled); - } - - bool cpuFuserEnabled; -}; - -TEST(TEFuserPass, FuserPass_1) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(128, strides=[1], device=cpu), - %1 : Float(128, strides=[1], device=cpu)): - %12 : int = prim::Constant[value=1]() - %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) - %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) - %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) - %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) - %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should not be able to fuse across the in-place operation here. - testing::FileCheck() - .check("prim::TensorExprGroup_") - ->check("aten::add_") - ->check("prim::TensorExprGroup_") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_2) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(128, strides=[1], device=cpu), - %1 : Float(128, strides=[1], device=cpu)): - %12 : int = prim::Constant[value=1]() - %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) - %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) - %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) - %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) - return (%d))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should not be able to fuse across the in-place operation here. - testing::FileCheck() - .check("aten::add_") - ->check("prim::TensorExprGroup_0") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_3) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(128, strides=[1], device=cpu), - %y : Float(128, strides=[1], device=cpu)): - %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%r))IR"; - { - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // We should not create a fusion group since its size would be too small - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should create a fusion group since its size is above the threshold - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); - } -} - -TEST(TEFuserPass, FuserPass_0DimInput) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(device=cpu), - %y : Float(device=cpu)): - %one : int = prim::Constant[value=1]() - %a : Float(device=cpu) = aten::mul(%x, %y) - %b : Float(device=cpu) = aten::add(%x, %a, %one) - return (%b))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should fuse 0-dim tensors too - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_UnfusibleDevice) { - WithCPUFuser cf(false); - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(10, strides=[1], device=cpu)): - %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%a))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // Test that we're not starting fusion groups from nodes with unfusible device - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_UnknownShapes) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Tensor, - %y : Tensor): - %a : Tensor = aten::mul(%x, %y) - %b : Tensor = aten::mul(%x, %a) - return (%b))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // Test that we're not generating fusion groups when shapes are not known - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_Multidevice) { - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should be able to fuse this - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cuda:0), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should not fuse this aten::cat since its inputs are from different - // devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(10, strides=[1], device=cuda:0)): - %dim : int = prim::Constant[value=0]() - %xy_list : Tensor[] = prim::ListConstruct(%x, %y) - %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) - %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) - return (%r))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // Test that we check device before merging one node (cat) into another - // (mul) - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(10, strides=[1], device=cuda:0)): - %z2 : Tensor = aten::mul(%z, %z) - %dim : int = prim::Constant[value=0]() - %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // Test that we check device before merging one node (mul) into another - // (cat) - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cuda:0)): - %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%r))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should not fuse this graph since its inputs are from different devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cuda:0), - %y : Float(20, strides=[1], device=cuda:1), - %z : Float(20, strides=[1], device=cpu)): - %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) - %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) - %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) - return (%x2, %y2, %z2))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // We should not fuse these two computations since they use different - // devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } -} - -TEST(TEFuserPass, FuserPass_MergeGroups) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%a : Float(128, strides=[1], device=cpu), - %b : Float(128, strides=[1], device=cpu)): - %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) - %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) - return (%x, %y))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // The %x and %y computations are completely independent and yet we should put - // them into a single fusion group rather than having two separate ones. - testing::FileCheck() - .check("= prim::TensorExprGroup_") - ->check_not("= prim::TensorExprGroup_") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Bool(8, strides=[1], device=cpu), - %y : Bool(8, strides=[1], device=cpu)): - %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) - %b : Tensor = aten::__or__(%a, %y) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_Where) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(8, strides=[1], device=cpu), - %y : Float(8, strides=[1], device=cpu), - %z : Float(8, strides=[1], device=cpu)): - %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) - %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_WhereList) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(8, strides=[1], device=cpu), - %y : Float(8, strides=[1], device=cpu), - %z : Float(8, strides=[1], device=cpu)): - %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) - %b : Tensor[] = aten::where(%cond) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, DynamicShapeFusion) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(10, 5, strides=[5, 1], device=cpu), - %1 : Float(10, 5, strides=[5, 1], device=cpu)): - %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1) - %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1) - return (%3))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs( - g, - /* min_group_size = */ 2, - /* add_composed_op = */ true, - /* fuse_to_dynamic_shapes = */ true); - Code code(g, ""); - - testing::FileCheck() - .check("prim::TensorExprDynamicGroup_") - ->check("prim::TensorExprDynamicGuard") - ->check("prim::TensorExprGroup_") - ->run(*g); - - auto run_and_compare = [&](const std::vector& inputs) { - TORCH_INTERNAL_ASSERT(inputs.size() == 2); - - auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]); - - InterpreterState interp(code); - Stack stack(inputs.begin(), inputs.end()); - interp.run(stack); - at::Tensor out = pop(stack).toTensor(); - ASSERT_TRUE(at::allclose(out, ref)); - }; - - std::vector inputs = {at::rand({10, 5}), at::rand({10, 5})}; - run_and_compare(inputs); - - std::vector inputs2 = {at::rand({20, 5}), at::rand({20, 5})}; - run_and_compare(inputs2); - - std::vector inputs3 = {at::rand({25, 60}), at::rand({25, 60})}; - run_and_compare(inputs3); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp deleted file mode 100644 index 6758503f4de79..0000000000000 --- a/test/cpp/tensorexpr/test_type.cpp +++ /dev/null @@ -1,202 +0,0 @@ -#include - -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir.h" -#include "torch/csrc/jit/tensorexpr/tensor.h" - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -TEST(Type, Test01) { - { - Dtype dt1 = kInt; - ASSERT_EQ(dt1, kInt); - } - { - Dtype dt2_a(kInt, 8); - Dtype dt2_b(kInt, 4); - Dtype dt2_c(ScalarType::Int, 8); - ASSERT_EQ(dt2_a, dt2_c); - ASSERT_NE(dt2_a, dt2_b); - } - { - ASSERT_EQ(kInt, ToDtype()); - ASSERT_EQ(kFloat, ToDtype()); - ASSERT_EQ(kByte, ToDtype()); - ASSERT_EQ(kChar, ToDtype()); - ASSERT_EQ(kShort, ToDtype()); - ASSERT_EQ(kLong, ToDtype()); - ASSERT_EQ(kHalf, ToDtype()); - ASSERT_EQ(kDouble, ToDtype()); - ASSERT_EQ(kBool, ToDtype()); - } - { - Dtype int32x8(kInt, 8); - Dtype float32x8(kFloat, 8); - ASSERT_NE(int32x8, float32x8); - ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8)); - ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8)); - ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8)); - ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8)); - } -} - -TEST(Type, BitCasting) { - { - VarHandle x("x", kFloat); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kInt); - } - { - VarHandle x("x", kInt); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kFloat); - } - { - VarHandle x("x", kShort); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kHalf); - } - { - VarHandle x("x", kHalf); - ExprHandle y = bitcast(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kShort); - } - - constexpr int32_t ref32 = 1337; - constexpr int64_t ref64 = 1337; - constexpr float reff32 = 1337.0f; - constexpr double reff64 = 1337.0f; - using SimpleIRExprEval = ExprEval; - // this is broken - /*{ - constexpr int16_t ref16 = 1337; - at::Half k_; - at::Half* k = &k_; - *reinterpret_cast(k) = ref16; - auto a = HalfImm::make(*k); - auto b = BitCast::make(kShort, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), ref16); - }*/ - - { - float k = raw_bitcast(ref32); - auto a = FloatImm::make(k); - auto b = BitCast::make(kInt, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), ref32); - } - - { - double k = raw_bitcast(ref64); - auto a = DoubleImm::make(k); - auto b = BitCast::make(kLong, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), ref64); - } - - { - int64_t k = raw_bitcast(reff64); - auto a = LongImm::make(k); - auto b = BitCast::make(kDouble, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), reff64); - } - - { - int32_t k = raw_bitcast(reff32); - auto a = IntImm::make(k); - auto b = BitCast::make(kFloat, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value(), reff32); - } - - // This segfaults :( - /*{ - VarHandle x("x", kDouble); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kFloat); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kLong); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kShort); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kInt); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - }*/ -} - -TEST(Type, Propagation) { - // Same types: - { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = FloatImm::make(2.f) + - (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y); - ASSERT_EQ(body.dtype(), kFloat); - } - // Int to bigger int: - { - VarHandle x("x", kShort); - VarHandle y("y", kLong); - ExprHandle body = - ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y); - ASSERT_EQ(body.dtype(), kLong); - } - // Float to bigger float: - { - VarHandle x("x", kHalf); - VarHandle y("y", kDouble); - ExprHandle body = - HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y); - ASSERT_EQ(body.dtype(), kDouble); - } - // Int to Float: - { - VarHandle x("x", kFloat); - VarHandle y("y", kInt); - ExprHandle body = - IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y); - ASSERT_EQ(body.dtype(), kFloat); - } - // Smaller float, bigger Int: - { - VarHandle x("x", kHalf); - VarHandle y("y", kLong); - ExprHandle body = - HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y); - ASSERT_EQ(body.dtype(), kHalf); - } - // Bigger float, smaller Int: - { - VarHandle x("x", kChar); - VarHandle y("y", kDouble); - ExprHandle body = - CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); - ASSERT_EQ(body.dtype(), kDouble); - } - // Sign change char/byte upgrades to short: - { - VarHandle x("x", kChar); - VarHandle y("y", kByte); - ExprHandle body = - CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); - ASSERT_EQ(body.dtype(), kShort); - } -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_type_specializations.cpp b/test/cpp/tensorexpr/test_type_specializations.cpp deleted file mode 100644 index d9756627fa74d..0000000000000 --- a/test/cpp/tensorexpr/test_type_specializations.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include - -// Test that tensor type specializations are available in -// the custom passes - -namespace torch { -namespace jit { - -namespace { - -bool hasTensorTypeSpecializations(torch::jit::Block* block) { - for (Value* v : block->inputs()) { - if (hasTensorTypeSpecialization(v)) - return true; - } - for (Node* n : block->nodes()) { - for (torch::jit::Block* b : n->blocks()) { - if (hasTensorTypeSpecializations(b)) - return true; - } - for (Value* v : n->outputs()) { - if (hasTensorTypeSpecialization(v)) - return true; - } - } - return false; -} - -static bool hasSpecializations = false; -void detectTTSpecializationPass(std::shared_ptr& graph) { - GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: ", graph); - hasSpecializations = hasTensorTypeSpecializations(graph->block()); -} - -} // namespace - -TEST(SpecializationsInCustomPasses, Basic) { - RegisterPass p(detectTTSpecializationPass); - hasSpecializations = false; - std::shared_ptr graph = std::make_shared(); - parseIR( - R"IR( -graph(%a.1 : Tensor, - %b.1 : Tensor): - %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8 - %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8 - return (%d.1) - )IR", - &*graph); - - IValue ival = IValue(torch::randn({22}, at::kCPU)); - std::vector stack = {ival, ival}; - auto run = [&](std::shared_ptr& graph, std::vector stack) { - GraphExecutor executor(graph, ""); - executor.run(stack); - return stack; - }; - run(graph, stack); - - // Profiling mode will not be run with simple executor - if (!getExecutorMode()) { - EXPECT_TRUE(hasSpecializations); - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h deleted file mode 100644 index 065e513c1a645..0000000000000 --- a/test/cpp/tensorexpr/test_utils.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -#include -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -#define IS_NODE(T, node) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - } - -#define IS_NODE_WITH_NAME(T, node, name) \ - auto name = to(node); \ - ASSERT_NE(nullptr, name); - -#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ - NodePtr name = nullptr; \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ - name = to(node_->src_value()); \ - } \ - ASSERT_NE(nullptr, name); - -#define IS_IMM_WITH_VAL(T, node, val) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->value(), val); \ - } - -#define IS_VAR_WITH_NAME(node, name) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->name_hint(), name); \ - } - -#define IS_BINOP_W_VARS(T, node, name, v1, v2) \ - NodePtr name = nullptr; \ - { \ - name = to(node); \ - ASSERT_NE(nullptr, name); \ - IS_VAR_WITH_NAME(name->lhs(), v1); \ - IS_VAR_WITH_NAME(name->rhs(), v2); \ - } - -#define IS_BINOP_W_CONST(T, node, name, v, c) \ - NodePtr name = nullptr; \ - { \ - name = to(node); \ - ASSERT_NE(nullptr, name); \ - IS_VAR_WITH_NAME(name->lhs(), v); \ - IS_IMM_WITH_VAL(Int, name->rhs(), c); \ - } - -#define IS_RAND(node) \ - { \ - auto node_ = to(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->op_type(), kRand); \ - } - -void checkIR(StmtPtr s, const std::string& pattern); -void checkExprIR(ExprPtr e, const std::string& pattern); -void checkExprIR(const ExprHandle& e, const std::string& pattern); - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp deleted file mode 100644 index 3f4c32af463b6..0000000000000 --- a/test/cpp/tensorexpr/tutorial.cpp +++ /dev/null @@ -1,542 +0,0 @@ -// *** Tensor Expressions *** -// -// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to -// work with them, and outlines how they are used in the overall TorchScript -// compilation pipeline. This doc is permanently a "work in progress" since NNC -// is under active development and things change fast. -// -// This Tutorial's code is compiled in the standard pytorch build, and the -// executable can be found in `build/bin/tutorial_tensorexpr`. -// -// *** What is NNC *** -// -// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT -// and it performs on-the-fly code generation for kernels, which are often a -// combination of multiple aten (torch) operators. -// -// When the JIT interpreter executes a torchscript model, it automatically -// extracts subgraphs from the torchscript IR graph for which specialized code -// can be JIT generated. This usually improves performance as the 'combined' -// kernel created from the subgraph could avoid unnecessary memory traffic that -// is unavoidable when the subgraph is interpreted as-is, operator by operator. -// This optimization is often referred to as 'fusion'. Relatedly, the process of -// finding and extracting subgraphs suitable for NNC code generation is done by -// a JIT pass called 'fuser'. -// -// *** What is TE *** -// -// TE stands for Tensor Expressions. TE is a commonly used approach for -// compiling kernels performing tensor (~matrix) computation. The idea behind it -// is that operators are represented as a mathematical formula describing what -// computation they do (as TEs) and then the TE engine can perform mathematical -// simplification and other optimizations using those formulas and eventually -// generate executable code that would produce the same results as the original -// sequence of operators, but more efficiently. -// -// NNC's design and implementation of TE was heavily inspired by Halide and TVM -// projects. -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace torch::jit::tensorexpr; - -#ifdef TORCH_ENABLE_LLVM - -// Helper function to print a snippet from a big multi-line string -static void printLinesToFrom(const std::string& input_str, int from, int to); - -#endif - -int main(int argc, char* argv[]) { - std::cout << "*** Structure of tensor expressions and statements ***" - << std::endl; - { - // A tensor expression is a tree of expressions. Each expression has a type, - // and that type defines what sub-expressions the current expression has. - // For instance, an expression of type 'Mul' would have a type 'kMul' and - // two subexpressions: LHS and RHS. Each of these two sub-expressions could - // also be a 'Mul' or some other expression. - // - // Let's construct a simple TE: - ExprPtr lhs = alloc(5); - ExprPtr rhs = alloc("x", kInt); - ExprPtr mul = alloc(lhs, rhs); - std::cout << "Tensor expression: " << *mul << std::endl; - // Prints: Tensor expression: 5 * x - - // Here we created an expression representing a 5*x computation, where x is - // an int variable. - - // Another, probably a more convenient, way to construct tensor expressions - // is to use so called expression handles (as opposed to raw expressions - // like we did in the previous example). Expression handles overload common - // operations and allow us to express the same semantics in a more natural - // way: - ExprHandle l = 5; - ExprHandle r = Var::make("x", kInt); - ExprHandle m = l * r; - std::cout << "Tensor expression: " << *m.node() << std::endl; - // Prints: Tensor expression: 5 * x - - // Converting from handles to raw expressions and back is easy: - ExprHandle handle = Var::make("x", kInt); - ExprPtr raw_expr_from_handle = handle.node(); - ExprPtr raw_expr = alloc("x", kInt); - ExprHandle handle_from_raw_expr = ExprHandle(raw_expr); - - // We could construct arbitrarily complex expressions using mathematical - // and logical operations, casts between various data types, and a bunch of - // intrinsics. - ExprHandle a = Var::make("a", kInt); - ExprHandle b = Var::make("b", kFloat); - ExprHandle c = Var::make("c", kFloat); - ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f); - std::cout << "Tensor expression: " << *x.node() << std::endl; - // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f) - - // An ultimate purpose of tensor expressions is to optimize tensor - // computations, and in order to represent accesses to tensors data, there - // is a special kind of expression - a load. - // To construct a load we need two pieces: the base and the indices. The - // base of a load is a Buf expression, which could be thought of as a - // placeholder similar to Var, but with dimensions info. - // - // Let's construct a simple load: - BufHandle A("A", {64, 32}, kInt); - VarPtr i_var = alloc("i", kInt), j_var = alloc("j", kInt); - ExprHandle i(i_var), j(j_var); - ExprHandle load = Load::make(A.dtype(), A, {i, j}); - std::cout << "Tensor expression: " << *load.node() << std::endl; - // Prints: Tensor expression: A[i, j] - - // Tensor Expressions constitute Tensor Statements, which are used to - // represent computation of a given operator or a group of operators from a - // fusion group. - // - // There are three main kinds of tensor statements: - // - block - // - store - // - loop - // - // A Store represents a store to a single element of a tensor (or to a - // group of elements if it's a vectorized store). Store statements, - // similarly to Load expressions, have a base and indices, but on top of - // that they also include a value - an expression representing what needs - // to be stored at the given memory location. Let's create a Store stmt: - StmtPtr store_a = Store::make(A, {i, j}, i + j); - std::cout << "Store statement: " << *store_a << std::endl; - // Prints: Store statement: A[i, j] = i + j; - - // An operator fills the entire tensor, not just a single element, and to - // represent this we need to use For stmt: let's wrap our store stmt with - // two nested loops to represent that variables i and j need to iterate - // over some ranges. - ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a); - ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a); - - std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl; - // Prints: - // Nested for loops: - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // A[i, j] = i + j; - // } - // } - - // A Block statement is used when we need a sequence of other statements. - // E.g. if a fusion group contains several operators, we initially define - // separate loopnest for each of them and put them all into a common block: - BufHandle B("B", {64, 32}, kInt); - StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j)); - ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b); - ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b); - - BlockPtr block = Block::make({loop_i_a, loop_i_b}); - std::cout << "Compound Block statement: " << std::endl - << *block << std::endl; - // Prints: - // Compound Block statement: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // A[i, j] = i + j; - // } - // } - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // B[i, j] = A[i, j]; - // } - // } - // } - - // Manually constructing nested loops and blocks to represent a computation - // might be laborious, and instead we can use a 'Compute' API. This API - // requires us to specify dimensions and a lambda to compute a single - // element of the resulting tensor and returns a `Tensor` structure. This - // structure is simply a pair of a buffer that was created to represent the - // result of the computation (BufPtr) and a statement representing the - // computation itself (StmtPtr). - Tensor C = - Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - std::cout << "Stmt produced by 'Compute' API: " << std::endl - << *C.stmt() << std::endl; - // Prints: - // Stmt produced by 'Compute' API: - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * j; - // } - // } - - // To construct statements to represent computations with reductions, we - // can use a 'Reduce' API - it is similar to 'Compute' but takes a couple - // of extra arguments defining how to perform the reduction. Let's define a - // simple 2D sum of C using that: - Tensor D = Reduce( - "D", - {}, - Sum(), - [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); }, - {64, 32}); - std::cout << "Stmt produced by 'Reduce' API: " << std::endl - << *D.stmt() << std::endl; - } - - std::cout << "*** Loopnests transformations ***" << std::endl; - { - // When a statement for the computation is generated, we might want to - // apply some optimizations to it. These transformations allow us to end up - // with a statement producing the same results, but more efficiently. - // - // Let's look at a couple of transformations that are used in NNC. We will - // begin with constructing a Block statement like we did before. - - Tensor C = - Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return i * (j + 1); - }); - BufHandle c_buf(C.buf()); - Tensor D = - Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return c_buf.load(i, j) - i; - }); - StmtPtr block = Block::make({C.stmt(), D.stmt()}); - std::cout << "Stmt produced by 'Compute' API: " << std::endl - << *block << std::endl; - // Prints: - // Stmt produced by 'Compute' API: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * (j + 1); - // } - // } - // for (const auto i_1 : c10::irange(64)) { - // for (const auto j_1 : c10::irange(32)) { - // D[i_1, j_1] = (C[i_1, j_1]) - i_1; - // } - // } - // } - - // One transformation we can apply to this computation is inlining: i.e. - // taking the expression that defines values of C and substituting a load - // from C with it. - // To do that, we first need to create a special object called LoopNest - - // all transformations are methods of this class. To create a loopnest we - // need to provide a list of output buffers and the root statement: - LoopNest nest(block, {D.buf()}); - - // We can always retrieve the Stmt back from LoopNest: - std::cout << "LoopNest root stmt: " << std::endl - << *nest.root_stmt() << std::endl; - // Prints: - // LoopNest root stmt: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * (j + 1); - // } - // } - // for (const auto i_1 : c10::irange(64)) { - // for (const auto j_1 : c10::irange(32)) { - // D[i_1, j_1] = (C[i_1, j_1]) - i_1; - // } - // } - // } - - // Now we can apply the inlining transformation: - nest.computeInline(C.buf()); - std::cout << "Stmt after inlining:" << std::endl - << *nest.root_stmt() << std::endl; - // Prints: - // Stmt after inlining: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // D[i, j] = i * (j + 1) - i; - // } - // } - // } - - // We can also apply algebraic simplification to a statement: - StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt()); - std::cout << "Stmt after simplification:" << std::endl - << *simplified << std::endl; - // Prints: - // Stmt after simplification: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // D[i, j] = i * j; - // } - // } - // } - - // Many loopnest transformations are stateless and can be applied without - // creating a LoopNest object. In fact, we plan to make all transformations - // stateless. - // splitWithTail is one such transformation: it splits an iteration space - // of a given loop into two with a given factor. - ForPtr outer_loop = to(to(simplified)->stmts().front()); - LoopNest::splitWithTail(outer_loop, 13); - // Call simplifier once more to fold some arithmetic. - simplified = IRSimplifier::simplify(simplified); - std::cout << "Stmt after splitWithTail:" << std::endl - << *simplified << std::endl; - // Prints: - // Stmt after splitWithTail: - // { - // for (const auto i_outer : c10::irange(4)) { - // for (const auto i_inner : c10::irange(13)) { - // for (const auto j : c10::irange(32)) { - // D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j); - // } - // } - // } - // for (const auto i_tail : c10::irange(12)) { - // for (const auto j : c10::irange(32)) { - // D[i_tail + 52, j] = i_tail * j + 52 * j; - // } - // } - // } - - // NNC supports a wide range of loop nest transformations, which we are not - // listing here. Please refer to documentation in - // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h - // for more details. - } - - std::cout << "*** Codegen ***" << std::endl; - { - // An ultimate goal of tensor expressions is to be provide a mechanism to - // execute a given computation in the fastest possible way. So far we've - // looked at how we could describe what computation we're interested in, but - // we haven't looked at how to actually execute it. - // - // All we've been dealing with was just symbols with no actual data - // associated, in this section we would look at how we can bridge that gap. - - // Let's start by constructing a simple computation for us to work with: - BufHandle A("A", {64, 32}, kInt); - BufHandle B("B", {64, 32}, kInt); - Tensor X = - Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j) + B.load(i, j); - }); - - // And let's lower it to a loop nest, as we did in the previous section. We - // can pass Tensor object directly: - LoopNest loopnest({X}); - std::cout << *loopnest.root_stmt() << std::endl; - // Prints: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // X[i, j] = (A[i, j]) + (B[i, j]); - // } - // } - - // Now imagine that we have two actual tensors 64x32 that we want sum - // together, how do we pass those tensors to the computation and how do we - // carry it out? - // - // Codegen object is aimed at providing exactly that functionality. Codegen - // is an abstract class and concrete codegens are derived from it. - // Currently, we have three codegens: - // 1) Simple Evaluator, - // 2) LLVM Codegen for CPU, - // 3) CUDA Codegen. - // In this example we will be using Simple Evaluator, since it's available - // everywhere. - - // To create a codegen, we need to provide the statement - it specifies the - // computation we want to perform - and a list of placeholders and tensors - // used in the computation. The latter part is crucial since that's the only - // way the codegen could use to correlate symbols in the statement to actual - // data arrays that we will be passing when we will actually be performing - // the computation. - // - // Let's create a Simple IR Evaluator codegen for our computation: - SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X}); - - // We are using the simplest codegen and in it almost no work is done at the - // construction step. Real codegens such as CUDA and LLVM perform - // compilation during that stage so that when we're about to run the - // computation everything is ready. - - // Let's now create some inputs and run our computation with them: - std::vector data_A(64 * 32, 3); // This will be the input A - std::vector data_B(64 * 32, 5); // This will be the input B - std::vector data_X(64 * 32, 0); // This will be used for the result - - // Now let's invoke our codegen to perform the computation on our data. We - // need to provide as many arguments as how many placeholders and tensors we - // passed at the codegen construction time. A position in these lists would - // define how real data arrays from the latter call (these arguments are - // referred to as 'CallArg's in our codebase) correspond to symbols - // (placeholders and tensors) used in the tensor expressions we constructed - // (these are referred to as 'BufferArg'). - // Thus, we will provide three arguments: data_A, data_B, and data_X. data_A - // contains data for the placeholder A, data_B - for the placeholder B, and - // data_X would be used for contents of tensor X. - ir_eval(data_A, data_B, data_X); - - // Let's print one of the elements from each array to verify that the - // computation did happen: - std::cout << "A[10] = " << data_A[10] << std::endl - << "B[10] = " << data_B[10] << std::endl - << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl; - // Prints: - // A[10] = 3 - // B[10] = 5 - // X[10] = A[10] + B[10] = 8 - } - - std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl; - { - // This section requires a LLVM-enabled PyTorch build, so we have to use a - // guard: -#ifdef TORCH_ENABLE_LLVM - - // Often we would like to convert a TorchScript IR to TE rather than - // construct TE IR from scratch. NNC provides an API to perform such - // lowering: it takes a TorchScript graph and returns an object that can be - // used to invoke the generated kernel. - // This API is currently used by the TorchScript JIT fuser and can also be - // used ahead of time to pre-compile parts of a model. - // - // To get familiar with this API let's first start with defining a simple - // TorchScript graph: - const auto graph_string = R"IR( - graph(%A : Float(5, 3, strides=[3, 1], device=cpu), - %B : Float(5, 3, strides=[3, 1], device=cpu)): - %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B) - %one : int = prim::Constant[value=1]() - %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB) - %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one) - return (%AAB_plus_B))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - // This graph defines a simple computation of A*A*B + B where A and B are - // input 5x3 tensors. - - // To lower this TorchScript graph to TE, we just need to create a - // TensorExprKernel object. In its constructor it constructs the - // corresponding TE IR and compiles it for the given backend (in this - // example for CPU using LLVM compiler). - TensorExprKernel kernel(graph); - - // We can retrieve the generated TE stmt from the kernel object: - StmtPtr kernel_stmt = kernel.getCodeGenStmt(); - std::cout << "TE Stmt constructed from TorchScript: " << std::endl - << *kernel_stmt << std::endl; - // Prints: - // TE Stmt constructed from TorchScript: - // { - // for (const auto v : c10::irange(5)) { - // for (const auto _tail_tail : c10::irange(3)) { - // aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) * - // ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) + - // (tB[_tail_tail + 3 * v]); - // } - // } - // } - - // We can also examine generated LLVM IR and assembly code: - std::cout << "Generated LLVM IR: " << std::endl; - auto ir_str = kernel.getCodeText("ir"); - printLinesToFrom(ir_str, 15, 20); - // Prints: - // Generated LLVM IR: - // %9 = bitcast float* %2 to <8 x float>* - // %10 = load <8 x float>, <8 x float>* %9 ... - // %11 = bitcast float* %5 to <8 x float>* - // %12 = load <8 x float>, <8 x float>* %11 ... - // %13 = fmul <8 x float> %10, %12 - // %14 = fmul <8 x float> %10, %13 - - std::cout << "Generated assembly: " << std::endl; - auto asm_str = kernel.getCodeText("asm"); - printLinesToFrom(asm_str, 10, 15); - // Prints: - // Generated assembly: - // vmulps %ymm1, %ymm0, %ymm2 - // vfmadd213ps %ymm1, %ymm0, %ymm2 - // vmovups %ymm2, (%rax) - // vmovss 32(%rcx), %xmm0 - // vmovss 32(%rdx), %xmm1 - // vmulss %xmm1, %xmm0, %xmm2 - - // We can also execute the generated kernel: - auto A = - at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * - 2.0; - auto B = - at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * - 3.0; - std::vector inputs = {A, B}; - std::vector stack = torch::fmap(inputs); - kernel.run(stack); - auto R = stack[0].toTensor(); - - // Let's print one of the elements from the result tensor to verify that the - // computation did happen and was correct: - std::cout << "R[2][2] = " << R[2][2] << std::endl; - // Prints: - // R[2][2] = 15 - // [ CPUFloatType{} ] -#endif - } - return 0; -} - -void printLinesToFrom(const std::string& input_str, int from, int to) { - std::istringstream f(input_str); - std::string s; - int idx = 0; - while (getline(f, s)) { - if (idx > from) { - std::cout << s << "\n"; - } - if (idx++ > to) { - break; - } - } -} diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 8d3a8090c67a3..c3e26d37da1b2 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -2939,7 +2939,10 @@ def test_unsupported(self, device, dtype, op): @slowTest @onlyCPU - @ops(op_db, dtypes=OpDTypes.supported) + @ops( + [op for op in op_db if get_name(op) not in known_failures], + dtypes=OpDTypes.supported, + ) def test_nnc_correctness(self, device, dtype, op): if not op.supports_tracing: self.skipTest("Requires tracing support") diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index d5586a5b9cd7b..9e408682ca6c3 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1910,7 +1910,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { } auto& out_t = p_node->Output(0).toTensor(); - if (in0_t.sizes() == in1_t.sizes() && + if (te && te->checkInput(in0_t) && in0_t.sizes() == in1_t.sizes() && in0_t.scalar_type() == in1_t.scalar_type() && in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() && in0_t.scalar_type() == at::kFloat) { From e07c52b2c0b3aa81f082be03234c0aa0a1418029 Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Fri, 8 Aug 2025 23:26:49 +0000 Subject: [PATCH 0175/1424] [dynamo] Improve support for itertools.product (#159693) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159693 Approved by: https://github.com/guilhermeleobas, https://github.com/mlazos --- test/dynamo/cpython/3_13/test_itertools.diff | 20 +++++++++++++++- test/dynamo/cpython/3_13/test_itertools.py | 4 ++-- test/dynamo/test_functions.py | 17 +++++++++++++ ...3-test_itertools-TestBasicOps.test_product | 0 ...3-test_itertools-TestExamples.test_product | 0 ...thon313-test_itertools-TestGC.test_product | 0 torch/_dynamo/graph_break_registry.json | 10 ++++++++ torch/_dynamo/variables/iter.py | 24 +++++++++++++------ 8 files changed, 65 insertions(+), 10 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_product delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_product delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_product diff --git a/test/dynamo/cpython/3_13/test_itertools.diff b/test/dynamo/cpython/3_13/test_itertools.diff index df7205a1c9033..027e958a4b6f8 100644 --- a/test/dynamo/cpython/3_13/test_itertools.diff +++ b/test/dynamo/cpython/3_13/test_itertools.diff @@ -1,5 +1,5 @@ diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py -index 7d5ba727389..98f962e4353 100644 +index 7d5ba727389..f1cabfe2111 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -1,3 +1,25 @@ @@ -210,6 +210,24 @@ index 7d5ba727389..98f962e4353 100644 self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')], list(zip('abc', 'def'))) +@@ -1296,7 +1320,6 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(list(product(*(args*r))), + list(product(*args, **dict(repeat=r)))) + self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) +- self.assertRaises(TypeError, product, range(6), None) + + def product1(*args, **kwds): + pools = list(map(tuple, args)) * kwds.get('repeat', 1) +@@ -1336,7 +1359,8 @@ class TestBasicOps(unittest.TestCase): + argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3), + set('abcdefg'), range(11), tuple(range(13))] + for i in range(100): +- args = [random.choice(argtypes) for j in range(random.randrange(5))] ++ with torch._dynamo.set_fullgraph(fullgraph=False): ++ args = [random.choice(argtypes) for j in range(random.randrange(5))] + expected_len = prod(map(len, args)) + self.assertEqual(len(list(product(*args))), expected_len) + self.assertEqual(list(product(*args)), list(product1(*args))) @@ -1767,6 +1791,7 @@ class TestBasicOps(unittest.TestCase): script_helper.assert_python_ok("-c", script) diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py index 98f962e435365..f1cabfe211132 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -1320,7 +1320,6 @@ def test_product(self): self.assertEqual(list(product(*(args*r))), list(product(*args, **dict(repeat=r)))) self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) - self.assertRaises(TypeError, product, range(6), None) def product1(*args, **kwds): pools = list(map(tuple, args)) * kwds.get('repeat', 1) @@ -1360,7 +1359,8 @@ def product2(*iterables, repeat=1): argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3), set('abcdefg'), range(11), tuple(range(13))] for i in range(100): - args = [random.choice(argtypes) for j in range(random.randrange(5))] + with torch._dynamo.set_fullgraph(fullgraph=False): + args = [random.choice(argtypes) for j in range(random.randrange(5))] expected_len = prod(map(len, args)) self.assertEqual(len(list(product(*args))), expected_len) self.assertEqual(list(product(*args)), list(product1(*args))) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 4afb6acc5d87f..8bd1222a55988 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -268,6 +268,23 @@ def test_itertools_product(a, b): v = v + x * i return v + def test_itertools_product_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(*args, **kwargs): + return torch.tensor(list(itertools.product(*args, **kwargs))) + + self.assertRaises(Unsupported, fn, [1, 2, 3], fake_arg=1) + + @make_test + def test_itertools_product_various_iterators(a, b): + itertools.product( + [a, b], + zip([1, 2], [3, 4]), + map(lambda x: x, [1, 2]), + filter(lambda x: True, [1, 2]), + ) + return a + @make_test def test_itertools_chain(a, b): v = a diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_product b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_product deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_product b/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_product deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_product b/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_product deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 15920eb33c3d1..7c25d683b4753 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2680,5 +2680,15 @@ "Use method calls instead of attribute access." ] } + ], + "GB0268": [ + { + "Gb_type": "Unsupported kwargs for itertools.product", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Expected kwargs: 'repeat', but got {','.join(set(kwargs.keys()) - {'repeat'})}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } ] } diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 3db4daefc978e..c6441b884156f 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -59,14 +59,24 @@ def call_function( ) -> "VariableTracker": # See also: module `torch._dynamo.polyfills.itertools` - if ( - self.value is itertools.product - and not kwargs - and all(arg.has_unpack_var_sequence(tx) for arg in args) - ): - seqs = [arg.unpack_var_sequence(tx) for arg in args] + if self.value is itertools.product: + if any(kw != "repeat" for kw in kwargs.keys()): + unimplemented_v2( + gb_type="Unsupported kwargs for itertools.product", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'repeat', but got " + f"{','.join(set(kwargs.keys()) - {'repeat'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + if "repeat" in kwargs.keys(): + r = kwargs["repeat"].as_python_constant() + else: + r = 1 + seqs = [arg.force_unpack_var_sequence(tx) for arg in args] items = [ - variables.TupleVariable(list(item)) for item in itertools.product(*seqs) + variables.TupleVariable(list(item)) + for item in itertools.product(*seqs, repeat=r) ] return variables.ListIteratorVariable( items, mutation_type=ValueMutationNew() From 5ed4f9177907fe403ec4c4499d0d0e9be6b68fcf Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Fri, 8 Aug 2025 23:26:50 +0000 Subject: [PATCH 0176/1424] [dynamo] support itertools.permutations (#159694) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159694 Approved by: https://github.com/guilhermeleobas ghstack dependencies: #159693 --- test/dynamo/cpython/3_13/test_itertools.diff | 82 +++++++++++++------ test/dynamo/cpython/3_13/test_itertools.py | 11 +-- test/dynamo/test_functions.py | 25 ++++++ ...t_itertools-TestBasicOps.test_permutations | 0 ...t_itertools-TestExamples.test_permutations | 0 ...13-test_itertools-TestGC.test_permutations | 0 torch/_dynamo/variables/iter.py | 18 ++++ 7 files changed, 102 insertions(+), 34 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_permutations delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_permutations delete mode 100644 test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_permutations diff --git a/test/dynamo/cpython/3_13/test_itertools.diff b/test/dynamo/cpython/3_13/test_itertools.diff index 027e958a4b6f8..21763d689ac6a 100644 --- a/test/dynamo/cpython/3_13/test_itertools.diff +++ b/test/dynamo/cpython/3_13/test_itertools.diff @@ -1,5 +1,5 @@ diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py -index 7d5ba727389..f1cabfe2111 100644 +index 7d5ba727389..d15d83a2184 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -1,3 +1,25 @@ @@ -50,7 +50,41 @@ index 7d5ba727389..f1cabfe2111 100644 def pickletest(self, protocol, it, stop=4, take=1, compare=None): """Test that an iterator is the same after pickling, also when part-consumed""" -@@ -756,7 +778,7 @@ class TestBasicOps(unittest.TestCase): +@@ -454,14 +476,8 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1) + self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1) + +- @pickle_deprecated + def test_permutations(self): +- self.assertRaises(TypeError, permutations) # too few arguments +- self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments +- self.assertRaises(TypeError, permutations, None) # pool is not iterable +- self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative + self.assertEqual(list(permutations('abc', 32)), []) # r > n +- self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None + self.assertEqual(list(permutations(range(3), 2)), + [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) + +@@ -498,7 +514,7 @@ class TestBasicOps(unittest.TestCase): + if len(set(indices)) == r: + yield tuple(pool[i] for i in indices) + +- for n in range(7): ++ for n in range(5): + values = [5*x-12 for x in range(n)] + for r in range(n+2): + result = list(permutations(values, r)) +@@ -515,9 +531,6 @@ class TestBasicOps(unittest.TestCase): + self.assertEqual(result, list(permutations(values, None))) # test r as None + self.assertEqual(result, list(permutations(values))) # test default r + +- for proto in range(pickle.HIGHEST_PROTOCOL + 1): +- self.pickletest(proto, permutations(values, r)) # test pickling +- + @support.bigaddrspacetest + def test_permutations_overflow(self): + with self.assertRaises((OverflowError, MemoryError)): +@@ -756,7 +769,7 @@ class TestBasicOps(unittest.TestCase): def test_cycle(self): self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) self.assertEqual(list(cycle('')), []) @@ -59,7 +93,7 @@ index 7d5ba727389..f1cabfe2111 100644 self.assertRaises(TypeError, cycle, 5) self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0]) -@@ -888,7 +910,7 @@ class TestBasicOps(unittest.TestCase): +@@ -888,7 +901,7 @@ class TestBasicOps(unittest.TestCase): # Check normal pickled for proto in range(pickle.HIGHEST_PROTOCOL + 1): dup = [] @@ -68,7 +102,7 @@ index 7d5ba727389..f1cabfe2111 100644 for elem in g: self.assertEqual(k, elem[0]) dup.append(elem) -@@ -896,8 +918,8 @@ class TestBasicOps(unittest.TestCase): +@@ -896,8 +909,8 @@ class TestBasicOps(unittest.TestCase): # Check nested case dup = [] @@ -79,7 +113,7 @@ index 7d5ba727389..f1cabfe2111 100644 for elem in ig: self.assertEqual(k, elem[0]) self.assertEqual(ik, elem[2]) -@@ -907,8 +929,8 @@ class TestBasicOps(unittest.TestCase): +@@ -907,8 +920,8 @@ class TestBasicOps(unittest.TestCase): # Check nested and pickled for proto in range(pickle.HIGHEST_PROTOCOL + 1): dup = [] @@ -90,7 +124,7 @@ index 7d5ba727389..f1cabfe2111 100644 for elem in ig: self.assertEqual(k, elem[0]) self.assertEqual(ik, elem[2]) -@@ -917,7 +939,7 @@ class TestBasicOps(unittest.TestCase): +@@ -917,7 +930,7 @@ class TestBasicOps(unittest.TestCase): # Check case where inner iterator is not used @@ -99,7 +133,7 @@ index 7d5ba727389..f1cabfe2111 100644 expectedkeys = set([r[0] for r in s]) self.assertEqual(set(keys), expectedkeys) self.assertEqual(len(keys), len(expectedkeys)) -@@ -925,7 +947,7 @@ class TestBasicOps(unittest.TestCase): +@@ -925,7 +938,7 @@ class TestBasicOps(unittest.TestCase): # Check case where inner iterator is used after advancing the groupby # iterator s = list(zip('AABBBAAAA', range(9))) @@ -108,7 +142,7 @@ index 7d5ba727389..f1cabfe2111 100644 _, g1 = next(it) _, g2 = next(it) _, g3 = next(it) -@@ -936,7 +958,7 @@ class TestBasicOps(unittest.TestCase): +@@ -936,7 +949,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(g3), []) for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -117,7 +151,7 @@ index 7d5ba727389..f1cabfe2111 100644 _, g = next(it) next(it) next(it) -@@ -1002,27 +1024,29 @@ class TestBasicOps(unittest.TestCase): +@@ -1002,27 +1015,29 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2]) self.assertEqual(list(filter(bool, [0,1,0,2,0])), [1,2]) self.assertEqual(take(4, filter(isEven, count())), [0,2,4,6]) @@ -166,7 +200,7 @@ index 7d5ba727389..f1cabfe2111 100644 @pickle_deprecated def test_filterfalse(self): -@@ -1047,8 +1071,8 @@ class TestBasicOps(unittest.TestCase): +@@ -1047,8 +1062,8 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(take(3,zip('abcdef', count())), lzip('abcdef', range(3))) self.assertEqual(list(zip('abcdef')), lzip('abcdef')) self.assertEqual(list(zip()), lzip()) @@ -177,7 +211,7 @@ index 7d5ba727389..f1cabfe2111 100644 self.assertEqual([tuple(list(pair)) for pair in zip('abc', 'def')], lzip('abc', 'def')) self.assertEqual([pair for pair in zip('abc', 'def')], -@@ -1105,19 +1129,19 @@ class TestBasicOps(unittest.TestCase): +@@ -1105,19 +1120,19 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(zip_longest('abc', 'defg', **{})), list(zip(list('abc')+[None], 'defg'))) # empty keyword dict @@ -210,7 +244,7 @@ index 7d5ba727389..f1cabfe2111 100644 self.assertEqual([tuple(list(pair)) for pair in zip_longest('abc', 'def')], list(zip('abc', 'def'))) -@@ -1296,7 +1320,6 @@ class TestBasicOps(unittest.TestCase): +@@ -1296,7 +1311,6 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(product(*(args*r))), list(product(*args, **dict(repeat=r)))) self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) @@ -218,7 +252,7 @@ index 7d5ba727389..f1cabfe2111 100644 def product1(*args, **kwds): pools = list(map(tuple, args)) * kwds.get('repeat', 1) -@@ -1336,7 +1359,8 @@ class TestBasicOps(unittest.TestCase): +@@ -1336,7 +1350,8 @@ class TestBasicOps(unittest.TestCase): argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3), set('abcdefg'), range(11), tuple(range(13))] for i in range(100): @@ -228,7 +262,7 @@ index 7d5ba727389..f1cabfe2111 100644 expected_len = prod(map(len, args)) self.assertEqual(len(list(product(*args))), expected_len) self.assertEqual(list(product(*args)), list(product1(*args))) -@@ -1767,6 +1791,7 @@ class TestBasicOps(unittest.TestCase): +@@ -1767,6 +1782,7 @@ class TestBasicOps(unittest.TestCase): script_helper.assert_python_ok("-c", script) # Issue 13454: Crash when deleting backward iterator from tee() @@ -236,7 +270,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_tee_del_backward(self): forward, backward = tee(repeat(None, 20000000)) try: -@@ -1920,7 +1945,7 @@ class TestBasicOps(unittest.TestCase): +@@ -1920,7 +1936,7 @@ class TestBasicOps(unittest.TestCase): tp.foobar = 1 @@ -245,7 +279,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_accumulate(self): self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15]) -@@ -2032,7 +2057,7 @@ class TestExamples(unittest.TestCase): +@@ -2032,7 +2048,7 @@ class TestExamples(unittest.TestCase): self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4]) @@ -254,7 +288,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_batched_recipe(self): def batched_recipe(iterable, n): -@@ -2081,6 +2106,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): +@@ -2081,6 +2097,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): for i, element in zip(range(i + 1, stop), iterable): pass @@ -262,7 +296,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_islice_recipe(self): self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB')) self.assertEqual(list(self.islice('ABCDEFG', 2, 4)), list('CD')) -@@ -2265,7 +2291,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): +@@ -2265,7 +2282,7 @@ class TestPurePythonRoughEquivalents(unittest.TestCase): raise @@ -271,7 +305,7 @@ index 7d5ba727389..f1cabfe2111 100644 def makecycle(self, iterator, container): container.append(iterator) -@@ -2465,7 +2491,7 @@ def L(seqn): +@@ -2465,7 +2482,7 @@ def L(seqn): return chain(map(lambda x:x, R(Ig(G(seqn))))) @@ -280,7 +314,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_accumulate(self): s = [1,2,3,4,5] -@@ -2644,7 +2670,7 @@ class TestVariousIteratorArgs(unittest.TestCase): +@@ -2644,7 +2661,7 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, tee, N(s)) self.assertRaises(ZeroDivisionError, list, tee(E(s))[0]) @@ -289,7 +323,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_repeat(self): self.assertEqual(operator.length_hint(repeat(None, 50)), 50) -@@ -2657,7 +2683,7 @@ class LengthTransparency(unittest.TestCase): +@@ -2657,7 +2674,7 @@ class LengthTransparency(unittest.TestCase): self.assertEqual(operator.length_hint(repeat(None, times=-1)), 0) self.assertEqual(operator.length_hint(repeat(None, times=-2)), 0) @@ -298,7 +332,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_sf_793826(self): # Fix Armin Rigo's successful efforts to wreak havoc -@@ -2718,6 +2744,7 @@ class RegressionTests(unittest.TestCase): +@@ -2718,6 +2735,7 @@ class RegressionTests(unittest.TestCase): @support.skip_if_pgo_task @support.requires_resource('cpu') @@ -306,7 +340,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_long_chain_of_empty_iterables(self): # Make sure itertools.chain doesn't run into recursion limits when # dealing with long chains of empty iterables. Even with a high -@@ -2750,7 +2777,7 @@ class RegressionTests(unittest.TestCase): +@@ -2750,7 +2768,7 @@ class RegressionTests(unittest.TestCase): next(g, None) # shouldn't crash @@ -315,7 +349,7 @@ index 7d5ba727389..f1cabfe2111 100644 def test_keywords_in_subclass(self): # count is not subclassable... testcases = [ -@@ -2805,49 +2832,5 @@ class SubclassWithKwargsTest(unittest.TestCase): +@@ -2805,49 +2823,5 @@ class SubclassWithKwargsTest(unittest.TestCase): self.assertEqual(u.newarg, 3) diff --git a/test/dynamo/cpython/3_13/test_itertools.py b/test/dynamo/cpython/3_13/test_itertools.py index f1cabfe211132..d15d83a2184d6 100644 --- a/test/dynamo/cpython/3_13/test_itertools.py +++ b/test/dynamo/cpython/3_13/test_itertools.py @@ -476,14 +476,8 @@ def test_combinations_with_replacement_tuple_reuse(self): self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1) self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1) - @pickle_deprecated def test_permutations(self): - self.assertRaises(TypeError, permutations) # too few arguments - self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments - self.assertRaises(TypeError, permutations, None) # pool is not iterable - self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative self.assertEqual(list(permutations('abc', 32)), []) # r > n - self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -520,7 +514,7 @@ def permutations2(iterable, r=None): if len(set(indices)) == r: yield tuple(pool[i] for i in indices) - for n in range(7): + for n in range(5): values = [5*x-12 for x in range(n)] for r in range(n+2): result = list(permutations(values, r)) @@ -537,9 +531,6 @@ def permutations2(iterable, r=None): self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values))) # test default r - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - self.pickletest(proto, permutations(values, r)) # test pickling - @support.bigaddrspacetest def test_permutations_overflow(self): with self.assertRaises((OverflowError, MemoryError)): diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 8bd1222a55988..4d415e19b3c36 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -285,6 +285,31 @@ def test_itertools_product_various_iterators(a, b): ) return a + def test_itertools_permutations_basic(self): + def fn(): + return torch.tensor(list(itertools.permutations([1, 2, 3], 2))) + + actual = torch.compile(fn, backend="eager", fullgraph=True)() + expected = fn() + self.assertEqual(actual, expected) + + def test_itertools_permutations_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(*args, **kwargs): + return torch.tensor(list(itertools.permutations(*args, **kwargs))) + + self.assertRaises(Unsupported, fn) + self.assertRaises(Unsupported, fn, [1, 2, 3], 1, 2) + self.assertRaises(Unsupported, fn, [1, 2, 3], fake_arg=1) + + @make_test + def test_itertools_permutations_various_iterators(a, b): + itertools.permutations([a, b]) + itertools.permutations(zip([1, 2], [3, 4])) + itertools.permutations(map(lambda x: x, [1, 2])) + itertools.permutations(filter(lambda x: True, [1, 2])) + return a + @make_test def test_itertools_chain(a, b): v = a diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_permutations b/test/dynamo_expected_failures/CPython313-test_itertools-TestBasicOps.test_permutations deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_permutations b/test/dynamo_expected_failures/CPython313-test_itertools-TestExamples.test_permutations deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_permutations b/test/dynamo_expected_failures/CPython313-test_itertools-TestGC.test_permutations deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index c6441b884156f..75c6712609e90 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -190,6 +190,24 @@ def keyfunc(x): return variables.CountIteratorVariable( *args, mutation_type=ValueMutationNew() ) + elif ( + self.value is itertools.permutations + and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant())) + and not kwargs + ): + if len(args) == 2: + r = args[1].as_python_constant() + else: + r = None + items = [ + variables.TupleVariable(list(item)) + for item in itertools.permutations( + args[0].force_unpack_var_sequence(tx), r + ) + ] + return variables.ListIteratorVariable( + items, mutation_type=ValueMutationNew() + ) else: return super().call_function(tx, args, kwargs) From 0d88593dd826544c9e7bd4aa615ef86847a78d2b Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sat, 9 Aug 2025 04:01:27 +0000 Subject: [PATCH 0177/1424] [audio hash update] update the pinned audio hash (#160153) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160153 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index cdfbede9e8f09..83860798279ad 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -0c22347335f4c9a5b92a2f5bad65e05e2464c184 +e500f0cf88bc57ffd8b0029033da305eef24ae25 From 303c614f3df95ae2b659c5f6c1838b14e4776ce6 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 8 Aug 2025 17:36:36 -0700 Subject: [PATCH 0178/1424] [dynamo] Be consistent with UserMethodVariable source (#160155) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160155 Approved by: https://github.com/StrongerXi --- test/dynamo/test_functions.py | 23 +++++++++++ torch/_dynamo/variables/functions.py | 29 +++++++++---- torch/_dynamo/variables/higher_order_ops.py | 2 +- torch/_dynamo/variables/misc.py | 6 +-- torch/_dynamo/variables/torch_function.py | 6 +-- torch/_dynamo/variables/user_defined.py | 45 +++++++++++++++------ 6 files changed, 85 insertions(+), 26 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 4d415e19b3c36..6e28264d54669 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -5072,6 +5072,29 @@ def __getattribute__(self, name): with self.assertRaises(Unsupported): a.call_function(None, [], {}) + def test_inspect_method_source(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + + def check(self, x): + return x * 2 + + def forward(self, x): + return x * 2 + + mod = Mod() + + def fn(x): + inspect.signature(mod.check).parameters.items() + return mod(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + instantiate_parametrized_tests(FunctionTests) instantiate_parametrized_tests(DefaultsTests) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index be92c4eb491bc..050f39f55895c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -424,6 +424,13 @@ def has_self(self): def get_globals(self): return self.fn.__globals__ + def get_source(self): + source = self.source + + if source and isinstance(self, variables.UserMethodVariable): + source = self.source_fn + return source + def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: """ Assume `args` and `kwargs` are VariableTracker arguments for a call to @@ -436,7 +443,9 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: if not isinstance(fn, FunctionType): raise TypeError("Only supports regular Python functions.") root_tx = parent.output.root_tx - result = bind_args_cached(fn, root_tx, self.source, args, kwargs) + + source = self.get_source() + result = bind_args_cached(fn, root_tx, source, args, kwargs) init_cellvars(parent, result, fn.__code__) closure = self.fn.__closure__ or () @@ -449,8 +458,8 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: if cell in side_effects: cell_var = side_effects[cell] - elif self.source: - closure_cell = GetItemSource(ClosureSource(self.source), idx) + elif source: + closure_cell = GetItemSource(ClosureSource(source), idx) closure_cell_contents = AttrSource(closure_cell, "cell_contents") try: contents_var = VariableTracker.build( @@ -480,7 +489,8 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: def var_getattr(self, tx: "InstructionTranslator", name: str): if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) - return fn_var_getattr(tx, self.fn, self.source, name) + source = self.get_source() + return fn_var_getattr(tx, self.fn, source, name) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str @@ -1052,9 +1062,12 @@ def _build_inline_tracer(self, tx, args, kwargs): class UserMethodVariable(UserFunctionVariable): """Some unsupported user-defined method""" - def __init__(self, fn, obj, **kwargs) -> None: + def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: super().__init__(fn=fn, **kwargs) self.obj = obj + self.source_fn = source_fn + if source_fn is None and kwargs.get("source") is not None: + self.source_fn = AttrSource(kwargs.get("source"), "__func__") def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fn}, {self.obj})" @@ -1130,11 +1143,13 @@ def inspect_parameter_names(self): return super().inspect_parameter_names()[1:] def var_getattr(self, tx: "InstructionTranslator", name: str): - source = self.source and AttrSource(self.source, name) if name == "__self__": return self.obj if name == "__func__": - return VariableTracker.build(tx, self.fn, source) + # We might have a better way to access the function object, this + # information is stored in self.source_fn, use that to construct the + # variable tracker. + return VariableTracker.build(tx, self.fn, self.source_fn) return super().var_getattr(tx, name) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 8c0730907a4d5..ea935ae5f7afa 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -938,7 +938,7 @@ def _call_function( torch._dynamo.variables.UserDefinedObjectVariable( self.value, source=self.source ), - source=AttrSource(AttrSource(self.source, "__call__"), "__func__"), + source=AttrSource(self.source, "__call__"), ).call_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 18eda602dbdc0..f75f5b180c72d 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -251,9 +251,9 @@ def call_method( tx, self.objvar.value_type, cls_source ) - return variables.UserMethodVariable( - inner_fn.__func__, cls_variable, source=source - ).call_function(tx, args, kwargs) + return variables.UserFunctionVariable( + inner_fn.__func__, source=AttrSource(source, "__func__") + ).call_function(tx, [cls_variable, *args], kwargs) elif isinstance(inner_fn, types.FunctionType): return variables.UserFunctionVariable( inner_fn, source=source diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index c48c7c3f24844..4458468d8118c 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -59,7 +59,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable -from .functions import UserMethodVariable +from .functions import UserFunctionVariable, UserMethodVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -620,8 +620,8 @@ def var_getattr(self, tx: "InstructionTranslator", name): elif isinstance(attr, property): getter_source = AttrSource(attr_source, "fget") getter = attr.fget - getter_var = UserMethodVariable(getter, self, source=getter_source) - return getter_var.call_function(tx, [], {}) + getter_var = UserFunctionVariable(getter, source=getter_source) + return getter_var.call_function(tx, [self], {}) elif isinstance(attr, classmethod): return UserMethodVariable( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 7cb21ab372801..95b1a37b677fc 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1009,17 +1009,18 @@ def call_method( # check for methods implemented in C++ if isinstance(method, types.FunctionType): - source = None - if self.source: - source = self.get_source_by_walking_mro(name) + source = self.source + source_fn = None + if source: + source_fn = self.get_source_by_walking_mro(name) # TODO(jansel): add a guard to check for monkey patching? from ..mutation_guard import unpatched_nn_module_init if method is torch.nn.Module.__init__: method = unpatched_nn_module_init - return UserMethodVariable(method, self, source=source).call_function( - tx, args, kwargs - ) + return UserMethodVariable( + method, self, source_fn=source_fn, source=source + ).call_function(tx, args, kwargs) if method is list.__len__ and self.source and not (args or kwargs): install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) @@ -1380,7 +1381,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): self.value.__class__, name, NO_SUCH_SUBOBJ ) is_accessible_from_type_mro = ( - subobj_from_class is subobj and self.cls_source is not None + subobj_from_class is subobj + and self.cls_source is not None + and self.source is not None ) if isinstance(subobj, property): @@ -1389,9 +1392,13 @@ def var_getattr(self, tx: "InstructionTranslator", name): source = self.get_source_by_walking_mro(name) # Get the getter function source = AttrSource(source, "fget") - return variables.UserMethodVariable( - subobj.fget, self, source=source - ).call_function(tx, [], {}) + + # Avoid using UserMethodVariable here because there is no way to + # access the method object here. Direct inline by creating the + # UserFunctionVariable. + return variables.UserFunctionVariable( + subobj.fget, source=source + ).call_function(tx, [self], {}) elif isinstance(subobj, _collections._tuplegetter): # namedtuple fields are represented by _tuplegetter, and here we # emulate its `__get__`, which is implemented in C. @@ -1412,8 +1419,17 @@ def var_getattr(self, tx: "InstructionTranslator", name): func = subobj.__get__(self.value) return VariableTracker.build(tx, func, source) elif isinstance(subobj, classmethod): + source_fn = None + if is_accessible_from_type_mro: + # Accessing from __dict__ does not resolve the descriptor, it + # returns a classmethod object, so access the __func__ + # attribute to get to the actual function. + source_fn = AttrSource(self.get_source_by_walking_mro(name), "__func__") return variables.UserMethodVariable( - subobj.__func__, self.var_getattr(tx, "__class__"), source=source + subobj.__func__, + self.var_getattr(tx, "__class__"), + source_fn=source_fn, + source=source, ) elif isinstance(subobj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static({}, "fromkeys") @@ -1503,7 +1519,12 @@ def var_getattr(self, tx: "InstructionTranslator", name): func = subobj if inspect.ismethod(dynamic_subobj): - return variables.UserMethodVariable(func, self, source=source) + source_fn = None + if is_accessible_from_type_mro: + source_fn = self.get_source_by_walking_mro(name) + return variables.UserMethodVariable( + func, self, source_fn=source_fn, source=source + ) elif inspect.isfunction(dynamic_subobj): if is_utils_checkpoint(func): return build_checkpoint_variable(source=source) From bcf23ecc476df2bd7479f142567213e2623308ee Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sat, 9 Aug 2025 04:17:28 +0000 Subject: [PATCH 0179/1424] [vllm hash update] update the pinned vllm hash (#160235) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vllm hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160235 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vllm.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt index d5b7ebc020178..e5260797d2150 100644 --- a/.github/ci_commit_pins/vllm.txt +++ b/.github/ci_commit_pins/vllm.txt @@ -1 +1 @@ -7e3a8dc90670fd312ce1e0d4eba9bf11c571e3ad +35afe1b30b154114dc2ee8329e12f8cf3fe9f576 From fb887c3bb588cfe782615e67f6c26db636b8539b Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Sat, 9 Aug 2025 04:44:12 +0000 Subject: [PATCH 0180/1424] Add Sherlock and Zhengxu as codeowner for schema.py (#160233) Test Plan: CI Rollback Plan: Differential Revision: D79933462 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160233 Approved by: https://github.com/zhxchen17 --- CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/CODEOWNERS b/CODEOWNERS index 24ab4fd35be9d..1d91adacb0629 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -164,6 +164,7 @@ caffe2/utils/hip @jeffdaily @jithunnair-amd # torch.export /torch/export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi /torch/_export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi +/torch/_export/serde/schema.py @SherlockNoMad @zhxchen17 # Dynamic Shapes /torch/fx/experimental/symbolic_shapes.py @bobrenjc93 @laithsakka From 4183d4ff3dcc1d87400326a9a7998c3f9e966f60 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Fri, 8 Aug 2025 13:07:09 -0700 Subject: [PATCH 0181/1424] Make user defined Triton kernels serializable for fx_graph_runnable (#160002) Resolves issue https://github.com/pytorch/pytorch/issues/153475 where `fx_graph_runnable` didn't work with user defined triton kernels. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002 Approved by: https://github.com/eellison --- test/dynamo/test_fx_graph_runnable.py | 88 +++++++++++++++++++++++++++ torch/_dynamo/repro/after_aot.py | 66 ++++++++++++++++++++ 2 files changed, 154 insertions(+) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index d5ad0c160c4ba..47e9ee3cb888e 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -11,12 +11,65 @@ from torch._inductor.codecache import WritableTempFile from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE +from torch.utils._triton import has_triton if torch.distributed.is_available(): from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore +if has_triton(): + import triton + import triton.language as tl + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + @triton.jit + def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 1024}, + num_warps=4, + num_stages=2, + pre_hook=init_to_zero("output_ptr"), + ) + ], + pre_hook=init_to_zero("output_ptr"), + post_hook=init_to_zero("output_ptr"), + key=["n_elements"], + ) + @triton.jit + def add_kernel_autotune( + x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.triton_utils import requires_gpu + class FxGraphRunnableArtifactFilter(logging.Filter): def filter(self, record): @@ -100,6 +153,41 @@ def f(x): torch.compile(f)(torch.randn(4)) self._exec_and_verify_payload() + @unittest.skipUnless(has_triton(), "Triton not available") + def test_user_defined_triton_kernel_autotune(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = output.numel() + + def grid( + meta, + ): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + add_kernel_autotune[grid](x, y, output, n_elements) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + + @unittest.skipUnless(has_triton(), "Triton not available") + @requires_gpu + def test_user_defined_triton_kernel(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = x.numel() + add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + def test_two_inputs_matmul(self): def f(a, b): return (a @ b).relu() diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 71f552a83b4ab..6f68405e32fdb 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -34,6 +34,21 @@ from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack +from torch.utils._triton import has_triton + + +if has_triton(): + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction +else: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + import torch import torch.fx as fx import torch.nn as nn @@ -58,6 +73,7 @@ ) from torch._dynamo.utils import clone_inputs, counters, same from torch._environment import is_fbcode +from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.output_code import OutputCode from torch._library.fake_class_registry import FakeScriptObject @@ -302,6 +318,16 @@ def generate_compiler_repro_string( """ ).strip() + triton_imports = "" + + if len(kernel_side_table.id_to_kernel) > 0: + triton_imports = textwrap.dedent( + """ +import triton +import triton.language as tl + """ + ).strip() + model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -312,6 +338,7 @@ def generate_compiler_repro_string( from math import inf import torch._inductor.inductor_prims {distributed_imports} +{triton_imports} {generate_config_string(stable_output=stable_output)} @@ -330,6 +357,45 @@ def generate_compiler_repro_string( model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() + kernel_side_table_prefix = ( + "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" + ) + # Track which grid entry corresponds to the best config + for id in kernel_side_table.id_to_kernel: + kernel = kernel_side_table.get_kernel(id) + if isinstance(kernel, Autotuner): + config_strs = [] + for kernel_config in kernel.configs: + config_strs.append(f"""triton.Config( + {str(kernel_config.kwargs)}, + num_warps={kernel_config.num_warps}, + num_stages={kernel_config.num_stages}, + )""") + + config_str = ",".join(config_strs) + model_str += textwrap.dedent(f""" + @triton.autotune( + configs=[ + {config_str} + ], + key=[] + ) + """).strip() + + model_str += "\n@triton.jit\n" + src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src + fn_name = ( + kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name + ) + fn_name = fn_name.split(".")[-1] + + model_str += src_code + model_str += "\n" + model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n" + + if len(kernel_side_table.constant_args) > 0: + model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n" + model_str += NNModuleToString.convert(gm) writer = InputWriter(save_dir, stable_hash=stable_hash) From 8047421fbb607d70ede13b9cd5a60b7b8bdfe348 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Thu, 7 Aug 2025 22:19:11 -0700 Subject: [PATCH 0182/1424] [Linter] Expanding the scope of detecting device-bias code. (#159949) Currently, the device-bias linter only targets functions decorated with @requires_gpu. This PR adds support for two new detection scenarios: 1. Detect device-bias code in functions decorated with @requires_triton. 2. Detect device-bias code for entire test suites that are defined as shared across GPUs. For example: ``` if __name__ == "__main__": if HAS_GPU: run_tests() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159949 Approved by: https://github.com/EikanWang, https://github.com/jansel --- test/dynamo/test_aot_autograd_cache.py | 6 +- test/dynamo/test_reconstruct.py | 6 +- test/inductor/test_aot_inductor.py | 8 +- test/inductor/test_codecache.py | 4 +- test/inductor/test_inplace_padding.py | 4 +- test/inductor/test_max_autotune.py | 84 +++++++++++-------- test/inductor/test_memory.py | 4 +- test/inductor/test_op_dtype_prop.py | 8 +- test/inductor/test_triton_heuristics.py | 2 +- .../adapters/test_device_bias_linter.py | 81 +++++++++++++----- 10 files changed, 132 insertions(+), 75 deletions(-) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 0d4a1f01f9a30..d26e4b31917e0 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -447,8 +447,8 @@ def test_non_bundled_to_bundled_config_change(self): def fn(x, y): return (x * 2, y @ y) - a = torch.rand(25, device="cuda") - b = torch.rand(5, 5, device="cuda") + a = torch.rand(25, device=GPU_TYPE) + b = torch.rand(5, 5, device=GPU_TYPE) compiled_fn = torch.compile(fn, backend="inductor") self.assertEqual(fn(a, b), compiled_fn(a, b)) @@ -822,7 +822,7 @@ def backward(ctx, grad_output): def fn(a): return MyAutogradFunction.apply(a) - a = torch.randn(5, device="cuda", requires_grad=True) + a = torch.randn(5, device=GPU_TYPE, requires_grad=True) a2 = a.clone().detach_().requires_grad_(True) compiled_fn = torch.compile(fn, backend="inductor") result = compiled_fn(a) diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 0cafaf9878e60..9f3d41964195d 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -7,7 +7,7 @@ import torch import torch._dynamo.test_case from torch.testing._internal.common_utils import IS_FBCODE -from torch.testing._internal.inductor_utils import requires_triton +from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton from torch.utils._triton import ( has_triton_experimental_host_tma, has_triton_tensor_descriptor_host_tma, @@ -420,7 +420,7 @@ def create_tma(tensor): ) return tensor + 1, tma - x = torch.randn(128, 128, device="cuda") + x = torch.randn(128, 128, device=GPU_TYPE) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) @@ -441,7 +441,7 @@ def create_tma(tensor): ) return tensor + 1, tma - x = torch.randn(128, 128, device="cuda") + x = torch.randn(128, 128, device=GPU_TYPE) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index e0218cd9d8bec..9fa13dc180f93 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -552,7 +552,7 @@ def forward(self, a, b): triton.set_allocator( lambda size, align, stream: torch.empty( - size, dtype=torch.int8, device="cuda" + size, dtype=torch.int8, device=GPU_TYPE ) ) @@ -5235,9 +5235,9 @@ def forward(self, a, b, c): return z example_inputs = ( - torch.randn(10, 20, device="cuda"), - torch.randn(20, 30, device="cuda"), - torch.randn(10, 30, device="cuda"), + torch.randn(10, 20, device=GPU_TYPE), + torch.randn(20, 30, device=GPU_TYPE), + torch.randn(10, 30, device=GPU_TYPE), ) model = Model() kernel_calls = [ diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 8e53725dd159c..3597663431fde 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -2801,8 +2801,8 @@ def get_autotune_stats(): def fn(x, y): return (x + y).relu() - x = torch.randn(100, 100).cuda() - y = torch.randn(100, 100).cuda() + x = torch.randn(100, 100).to(GPU_TYPE) + y = torch.randn(100, 100).to(GPU_TYPE) with config.patch( { diff --git a/test/inductor/test_inplace_padding.py b/test/inductor/test_inplace_padding.py index 46d5cf61121e3..7ddd0dd4441b8 100644 --- a/test/inductor/test_inplace_padding.py +++ b/test/inductor/test_inplace_padding.py @@ -233,9 +233,9 @@ def f(x, y): loss.backward() return loss - x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() - y = torch.randint(0, V, (B * T,)).cuda() + y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) opt_f = torch.compile(f) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 93165fa2dcec8..ff1d8c3fb8756 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -142,8 +142,16 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() - b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) with config.patch( { @@ -166,8 +174,8 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -194,8 +202,8 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. All dims are @@ -261,9 +269,17 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() - b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) with config.patch( { @@ -286,9 +302,9 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -315,9 +331,9 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 - a = torch.randn(M, K).to(torch.float16).cuda() - b = torch.randn(K, N).to(torch.float16).cuda() - x = torch.randn(N).to(torch.float16).cuda() + a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) + b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. All dims are @@ -362,15 +378,15 @@ def scaled_mm( # Create large matrices to ensure we use all possible sms size = 2560 - a = torch.randn(size, size, device="cuda", dtype=torch.bfloat16) + a = torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) b = ( - torch.randn(size, size, device="cuda", dtype=torch.bfloat16) + torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) .transpose(0, 1) .contiguous() .transpose(0, 1) ) - scale_a = torch.tensor(1, dtype=torch.float32, device="cuda") - scale_b = torch.tensor(1, dtype=torch.float32, device="cuda") + scale_a = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) + scale_b = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) args = ( (a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn), scale_a, scale_b) @@ -949,9 +965,9 @@ def f(x, y): loss.backward() return loss - x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() - y = torch.randint(0, V, (B * T,)).cuda() + y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) import torch._inductor.utils as inductor_utils @@ -985,8 +1001,8 @@ def test_max_autotune_decompose_k(self, sizes, dtype, dynamic): M, N, K = sizes - a = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=dtype, device="cuda", requires_grad=True) + a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True) + b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True) possible_splits = range(2, min(K // M, K // N) + 1) @@ -1083,10 +1099,10 @@ def f(a, b): return (a_in @ b).relu() a = torch.randn( - 32, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + 32, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) b = torch.randn( - 32768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True + 32768, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) torch._dynamo.reset() @@ -1126,9 +1142,11 @@ def f(a, b): a_in = torch.cat([a for _ in range(256)], dim=0) return (a_in @ b).relu().sum() - a = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + a = torch.randn( + 8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True + ) b = torch.randn( - 64, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + 64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) torch._dynamo.reset() @@ -1168,8 +1186,8 @@ def f(a, b): a = a.transpose(0, 1) return a @ b - a = torch.randn((32768, 256), device="cuda", dtype=torch.bfloat16) - b = torch.randn((32768, 1152), device="cuda", dtype=torch.bfloat16) + a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16) b = b[:, :1096] @@ -1522,8 +1540,8 @@ def test_max_autotune_decompose_k_envvars( for M, N, K in shapes: get_k_splits.cache_clear() use_decompose_k_choice.cache_clear() - a = torch.randn(M, K, dtype=torch.float16, device="cuda") - b = torch.randn(K, N, dtype=torch.float16, device="cuda") + a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE) + b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE) with config.patch( { @@ -1560,8 +1578,8 @@ def f(a, b): M, N, K = (1024, 1024, 1024) - a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) + a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) + b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) with mock.patch( "torch._inductor.template_registry.get_template_heuristic" diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 2231b94316b36..81f7ea03d3bb4 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -379,8 +379,8 @@ def foo(inp, inp2): return out, out2, inp2 @ inp2 - inp = torch.rand([256, 256], device="cuda") - inp2 = torch.rand([256, 256], device="cuda") + inp = torch.rand([256, 256], device=GPU_TYPE) + inp2 = torch.rand([256, 256], device=GPU_TYPE) def replace_foreach(gm): nodes = gm.find_nodes( diff --git a/test/inductor/test_op_dtype_prop.py b/test/inductor/test_op_dtype_prop.py index 458d64aa41d5b..6f7eec601666b 100644 --- a/test/inductor/test_op_dtype_prop.py +++ b/test/inductor/test_op_dtype_prop.py @@ -260,7 +260,7 @@ def test_downcast_div_mod(self): def fn(x, y): return x % y, x / y - x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2)) + x, y = (torch.rand([8], dtype=torch.float16, device=GPU_TYPE) for _ in range(2)) out, code = run_and_get_code(torch.compile(fn), x, y) @@ -271,7 +271,7 @@ def fn(x, y): @config.patch("test_configs.runtime_triton_dtype_assert", True) def test_constant(self): def fn(): - return (torch.full((2, 3), 3.1416, device="cuda", dtype=torch.float16),) + return (torch.full((2, 3), 3.1416, device=GPU_TYPE, dtype=torch.float16),) out, code = run_and_get_code(torch.compile(fn)) FileCheck().check("static_assert").check_same(".dtype").run(code[0]) @@ -284,7 +284,7 @@ def test_any(self): def fn(x): return torch.any(x) - x = torch.rand([40], device="cuda").to(torch.bool) + x = torch.rand([40], device=GPU_TYPE).to(torch.bool) out, code = run_and_get_code(torch.compile(fn), x) self.assertEqual(fn(x), out) @@ -293,7 +293,7 @@ def fn(x): def test_assoc_scan(self): from torch._higher_order_ops.associative_scan import associative_scan - x = torch.randn(10, device="cuda") + x = torch.randn(10, device=GPU_TYPE) # dtype check correctly associative_scan( lambda acc, curr: acc + torch.abs(curr), x, dim=-1, combine_mode="pointwise" diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index a9f898a36af55..4c2a04678b889 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -257,7 +257,7 @@ def grid(meta): def fn(x): return triton_sqr(x) - x = torch.randn(32, device="cuda") + x = torch.randn(32, device=GPU_TYPE) ref = fn(x) res = torch.compile(fn)(x) self.assertEqual(ref, res) diff --git a/tools/linter/adapters/test_device_bias_linter.py b/tools/linter/adapters/test_device_bias_linter.py index 00786ef3df86c..a2079e4fe810a 100644 --- a/tools/linter/adapters/test_device_bias_linter.py +++ b/tools/linter/adapters/test_device_bias_linter.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """ This lint verifies that every Python test file (file that matches test_*.py or -*_test.py in the test folder) has a cuda hard code in `requires_gpu()` -decorated function to ensure that the test not fail on other GPU. - +*_test.py in the test folder) has a cuda hard code in `requires_gpu()` or +`requires_triton()` decorated function or `if HAS_GPU:` guarded main section, +to ensure that the test not fail on other GPU devices. """ from __future__ import annotations @@ -39,21 +39,59 @@ class LintMessage(NamedTuple): DEVICE_BIAS = ["cuda", "xpu", "mps"] +GPU_RELATED_DECORATORS = {"requires_gpu", "requires_triton"} + + +def is_main_has_gpu(tree: ast.AST) -> bool: + def _contains_has_gpu(node: ast.AST) -> bool: + if isinstance(node, ast.Name) and node.id in ["HAS_GPU", "RUN_GPU"]: + return True + elif isinstance(node, ast.BoolOp): + return any(_contains_has_gpu(value) for value in node.values) + elif isinstance(node, ast.UnaryOp): + return _contains_has_gpu(node.operand) + elif isinstance(node, ast.Compare): + return _contains_has_gpu(node.left) or any( + _contains_has_gpu(comp) for comp in node.comparators + ) + elif isinstance(node, (ast.IfExp, ast.Call)): + return False + return False + + for node in ast.walk(tree): + # Detect if __name__ == "__main__": + if isinstance(node, ast.If): + if ( + isinstance(node.test, ast.Compare) + and isinstance(node.test.left, ast.Name) + and node.test.left.id == "__name__" + ): + if any( + isinstance(comp, ast.Constant) and comp.value == "__main__" + for comp in node.test.comparators + ): + for inner_node in node.body: + if isinstance(inner_node, ast.If) and _contains_has_gpu( + inner_node.test + ): + return True + return False class DeviceBiasVisitor(ast.NodeVisitor): - def __init__(self, filename: str): + def __init__(self, filename: str, is_gpu_test_suite: bool) -> None: self.filename = filename self.lint_messages: list[LintMessage] = [] + self.is_gpu_test_suite = is_gpu_test_suite - def _has_requires_gpu_decorator(self, node: ast.FunctionDef) -> bool: + def _has_proper_decorator(self, node: ast.FunctionDef) -> bool: for d in node.decorator_list: - if isinstance(d, ast.Name) and d.id == "requires_gpu": + if isinstance(d, ast.Name) and d.id in GPU_RELATED_DECORATORS: return True if ( isinstance(d, ast.Call) and isinstance(d.func, ast.Name) - and d.func.id == "requires_gpu" + and d.func.id in GPU_RELATED_DECORATORS ): return True return False @@ -62,7 +100,6 @@ def _has_requires_gpu_decorator(self, node: ast.FunctionDef) -> bool: def _check_keyword_device(self, subnode: ast.keyword, msg_prefix: str) -> None: if subnode.arg != "device": return - val = subnode.value if isinstance(val, ast.Constant) and any( bias in val.value for bias in DEVICE_BIAS @@ -124,15 +161,7 @@ def _check_with_statement(self, node: ast.With, msg_prefix: str) -> None: f"{msg_prefix} `with torch.device('{ctx_expr.args[0].value}')`, suggest to use torch.device(GPU_TYPE)", ) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Check if the function is decorated with @requires_gpu, which indicates - # that the function is intended to run on GPU devices (e.g., CUDA or XPU), - # but ensure it does not hardcode the device to CUDA. - if not self._has_requires_gpu_decorator(node): - self.generic_visit(node) - return - - msg_prefix = "`@requires_gpu` function should not hardcode" + def _check_node(self, node: ast.AST, msg_prefix: str) -> None: for subnode in ast.walk(node): if isinstance(subnode, ast.keyword): self._check_keyword_device(subnode, msg_prefix) @@ -143,6 +172,16 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: elif isinstance(subnode, ast.With): self._check_with_statement(subnode, msg_prefix) + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if self._has_proper_decorator(node): + msg_prefix = ( + "`@requires_gpu` or `@requires_triton` function should not hardcode" + ) + self._check_node(node, msg_prefix) + elif self.is_gpu_test_suite: + # If the function is guarded by HAS_GPU in main(), we still need to check for device bias + msg_prefix = "The test suites is shared amount GPUS, should not hardcode" + self._check_node(node, msg_prefix) self.generic_visit(node) def record(self, node: ast.AST, message: str) -> None: @@ -165,16 +204,16 @@ def check_file(filename: str) -> list[LintMessage]: with open(filename) as f: source = f.read() tree = ast.parse(source, filename=filename) - checker = DeviceBiasVisitor(filename) + is_gpu_test_suite = is_main_has_gpu(tree) + checker = DeviceBiasVisitor(filename, is_gpu_test_suite) checker.visit(tree) - return checker.lint_messages def main() -> None: parser = argparse.ArgumentParser( - description="Detect Device bias in python functions decorated with [require_gpu]" - " that may potentially break support for other GPU devices.", + description="Detect Device bias in functions decorated with requires_gpu/requires_triton" + " or guarded by HAS_GPU block in main() that may break other GPU devices.", fromfile_prefix_chars="@", ) parser.add_argument( From 2f4c2226175512af787725c4d5ad7313c60d4db1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 9 Aug 2025 14:01:58 +0000 Subject: [PATCH 0183/1424] Revert "Make user defined Triton kernels serializable for fx_graph_runnable (#160002)" This reverts commit 4183d4ff3dcc1d87400326a9a7998c3f9e966f60. Reverted https://github.com/pytorch/pytorch/pull/160002 on behalf of https://github.com/albanD due to Breaks inductor tests in trunk ([comment](https://github.com/pytorch/pytorch/pull/160002#issuecomment-3170855866)) --- test/dynamo/test_fx_graph_runnable.py | 88 --------------------------- torch/_dynamo/repro/after_aot.py | 66 -------------------- 2 files changed, 154 deletions(-) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index 47e9ee3cb888e..d5ad0c160c4ba 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -11,65 +11,12 @@ from torch._inductor.codecache import WritableTempFile from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE -from torch.utils._triton import has_triton if torch.distributed.is_available(): from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore -if has_triton(): - import triton - import triton.language as tl - - def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - @triton.jit - def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.atomic_add(output_ptr + offsets, output, mask=mask) - - @triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE": 1024}, - num_warps=4, - num_stages=2, - pre_hook=init_to_zero("output_ptr"), - ) - ], - pre_hook=init_to_zero("output_ptr"), - post_hook=init_to_zero("output_ptr"), - key=["n_elements"], - ) - @triton.jit - def add_kernel_autotune( - x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): - pid = tl.program_id(axis=0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.atomic_add(output_ptr + offsets, output, mask=mask) - - -from torch.testing._internal.inductor_utils import GPU_TYPE -from torch.testing._internal.triton_utils import requires_gpu - class FxGraphRunnableArtifactFilter(logging.Filter): def filter(self, record): @@ -153,41 +100,6 @@ def f(x): torch.compile(f)(torch.randn(4)) self._exec_and_verify_payload() - @unittest.skipUnless(has_triton(), "Triton not available") - def test_user_defined_triton_kernel_autotune(self): - def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - output = torch.ones(x.shape, device=x.device, dtype=x.dtype) - n_elements = output.numel() - - def grid( - meta, - ): - return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - - add_kernel_autotune[grid](x, y, output, n_elements) - return output - - x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) - y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) - - torch.compile(add)(x, y) - self._exec_and_verify_payload() - - @unittest.skipUnless(has_triton(), "Triton not available") - @requires_gpu - def test_user_defined_triton_kernel(self): - def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - output = torch.ones(x.shape, device=x.device, dtype=x.dtype) - n_elements = x.numel() - add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4) - return output - - x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) - y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) - - torch.compile(add)(x, y) - self._exec_and_verify_payload() - def test_two_inputs_matmul(self): def f(a, b): return (a @ b).relu() diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 6f68405e32fdb..71f552a83b4ab 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -34,21 +34,6 @@ from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack -from torch.utils._triton import has_triton - - -if has_triton(): - from triton.runtime.autotuner import Autotuner - from triton.runtime.jit import JITFunction -else: - - class Autotuner: # type: ignore[no-redef] - pass - - class JITFunction: # type: ignore[no-redef] - pass - - import torch import torch.fx as fx import torch.nn as nn @@ -73,7 +58,6 @@ class JITFunction: # type: ignore[no-redef] ) from torch._dynamo.utils import clone_inputs, counters, same from torch._environment import is_fbcode -from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.output_code import OutputCode from torch._library.fake_class_registry import FakeScriptObject @@ -318,16 +302,6 @@ def generate_compiler_repro_string( """ ).strip() - triton_imports = "" - - if len(kernel_side_table.id_to_kernel) > 0: - triton_imports = textwrap.dedent( - """ -import triton -import triton.language as tl - """ - ).strip() - model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -338,7 +312,6 @@ def generate_compiler_repro_string( from math import inf import torch._inductor.inductor_prims {distributed_imports} -{triton_imports} {generate_config_string(stable_output=stable_output)} @@ -357,45 +330,6 @@ def generate_compiler_repro_string( model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() - kernel_side_table_prefix = ( - "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" - ) - # Track which grid entry corresponds to the best config - for id in kernel_side_table.id_to_kernel: - kernel = kernel_side_table.get_kernel(id) - if isinstance(kernel, Autotuner): - config_strs = [] - for kernel_config in kernel.configs: - config_strs.append(f"""triton.Config( - {str(kernel_config.kwargs)}, - num_warps={kernel_config.num_warps}, - num_stages={kernel_config.num_stages}, - )""") - - config_str = ",".join(config_strs) - model_str += textwrap.dedent(f""" - @triton.autotune( - configs=[ - {config_str} - ], - key=[] - ) - """).strip() - - model_str += "\n@triton.jit\n" - src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src - fn_name = ( - kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name - ) - fn_name = fn_name.split(".")[-1] - - model_str += src_code - model_str += "\n" - model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n" - - if len(kernel_side_table.constant_args) > 0: - model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n" - model_str += NNModuleToString.convert(gm) writer = InputWriter(save_dir, stable_hash=stable_hash) From 01f66d08d93365015f4af005a252f439c4d4013a Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 9 Aug 2025 14:23:17 +0000 Subject: [PATCH 0184/1424] Remove outdated CMAKE_CUDA_COMPILER_VERSION branch (#160075) Remove the condition `if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)` in cmake/Codegen.cmake, because we are now default to CUDA >=12.0 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160075 Approved by: https://github.com/Skylion007 --- cmake/Codegen.cmake | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 16ee19a91d487..e4973c849a18f 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -91,30 +91,28 @@ if(INTERN_BUILD_ATEN_OPS) torch_cuda_get_nvcc_gencode_flag(_existing_arch_flags) set(_file_compile_flags "") - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0) - foreach(_arch ${archs}) - if("${_arch}" STREQUAL "89") - if(_existing_arch_flags MATCHES ".*compute_86.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89") - endif() + foreach(_arch ${archs}) + if("${_arch}" STREQUAL "89") + if(_existing_arch_flags MATCHES ".*compute_86.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89") endif() - if("${_arch}" STREQUAL "90a") - if(_existing_arch_flags MATCHES ".*compute_90.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a") - endif() + endif() + if("${_arch}" STREQUAL "90a") + if(_existing_arch_flags MATCHES ".*compute_90.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a") endif() - if("${_arch}" STREQUAL "100a") - if(_existing_arch_flags MATCHES ".*compute_100.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a") - endif() + endif() + if("${_arch}" STREQUAL "100a") + if(_existing_arch_flags MATCHES ".*compute_100.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a") endif() - if("${_arch}" STREQUAL "120a") - if(_existing_arch_flags MATCHES ".*compute_120.*") - list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") - endif() + endif() + if("${_arch}" STREQUAL "120a") + if(_existing_arch_flags MATCHES ".*compute_120.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") endif() - endforeach() - endif() + endif() + endforeach() list(JOIN _file_compile_flags " " _file_compile_flags) set_source_files_properties(${file} PROPERTIES COMPILE_FLAGS "${_file_compile_flags}") From 29712314dd5cf500a8ea3d1c69483a3cb768ca72 Mon Sep 17 00:00:00 2001 From: thenumberouscode Date: Sat, 9 Aug 2025 15:13:13 +0000 Subject: [PATCH 0185/1424] [fx][pass] Support converting a float32 tensor to a scalar in FX trace. (#158216) Fixes https://github.com/pytorch/pytorch/issues/158083 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158216 Approved by: https://github.com/laithsakka --- test/dynamo/test_unspec.py | 34 ++++++++++++++++++++ torch/fx/passes/_tensorify_python_scalars.py | 6 +++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 70ba2a8bd1bd3..91862e6d3eb00 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -714,6 +714,40 @@ def fn(x, y): self.assertEqual(fn_opt(x, y3), fn(x, y3)) self.assertEqual(cnt.frame_count, 1) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_tensorfiy_python_scalars_1(self): + @torch.compile(backend="aot_eager") + def f(x): + y = x.sum() + return x + y.item() + + dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64] + for i, dtype in enumerate(dtypes): + x = torch.ones(3, 3, dtype=dtype) + self.assertEqual(f(x), x + x.sum().item()) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_tensorfiy_python_scalars_2(self): + @torch.compile(backend="aot_eager") + def f(x): + return x.item() * x.item() * torch.ones((), dtype=torch.float64) + + x = torch.tensor(1e20, dtype=torch.float32) + self.assertEqual( + f(x), x.item() * x.item() * torch.ones((), dtype=torch.float64) + ) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_tensorfiy_python_scalars_3(self): + @torch.compile(backend="aot_eager") + def f(x): + y = x.item() * 101 + return y * torch.tensor([1], dtype=torch.float32) + + finfo_float16 = torch.finfo(torch.float16) + x = torch.tensor([finfo_float16.max], dtype=torch.float16) + self.assertEqual(f(x), x.item() * 101 * torch.tensor([1], dtype=torch.float32)) + @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False) def test_unspec_float_input_f64(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index bc7537c23847f..dd8edb50e1612 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -203,7 +203,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: and node.target is torch.ops.aten._local_scalar_dense.default ): dtype = node.args[0].meta["val"].dtype - if dtype != torch.float64: + if not dtype.is_floating_point: continue assert isinstance(node.args[0], fx.Node), node.args[0] @@ -212,6 +212,10 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: expr_to_tensor_proxy[s] = MetaProxy( node.args[0], tracer=tracer, fake_mode=fake_mode ) + # Upcast the float tensor to torch.float64 to avoid precision problem + expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default( + expr_to_tensor_proxy[s], torch.float64 + ) expr_to_sym_proxy[s] = MetaProxy( node, tracer=tracer, fake_mode=fake_mode ) From db78943a1ca13a32a3d6045eb15e2b719ee13a2f Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Sat, 9 Aug 2025 18:15:46 +0000 Subject: [PATCH 0186/1424] Fix get_free_symbol_uses for several nodes. (#160134) get_free_symbol_uses is used to know what unbacked symbols are used by a given node. not having correct get_free_symbol_uses defined properly leads to : 1. eliminating of some nodes due to not detection of any users. (See the added unit test) 2. Incorrect topological sort. Fix get_free_symbol_uses , NopKernel , ConcarKernel, InputsKerenl, external kernel. for ComputedBuffer with NonOwningLayout its interesting case. when layout is NonOwningLayout we need to access the actual view op base layout and use detect symbols in it. Because when we codegen the ComputedBuffer we uses those symbols. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160134 Approved by: https://github.com/bobrenjc93 --- test/test_dynamic_shapes.py | 11 ++++++++++ torch/_inductor/ir.py | 44 +++++++++++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 6a721a079a635..dd8695ae4ac50 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3616,6 +3616,17 @@ def func3(x, y): def test_unbacked_select_index_cpp_wrapper(self): self.test_unbacked_select_index() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_unbacked_select2(self): + def f(idx, x): + x = x.select(0, idx.item()) + return x @ x + + x = torch.randn(3, 3, 3) + idx = torch.tensor(1, dtype=torch.int64) + out = torch.compile(f)(idx, x) + self.assertEqual(out, f(idx, x)) + instantiate_parametrized_tests(TestUnbacked) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4f9f2f1e0b59f..2cc68dcb37824 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4443,7 +4443,8 @@ def get_free_symbol_uses( # unusual reason: we only need accurate dependencies for item() call, # but it's impossible to end up with a reduction over i0 from an # item() call without a regular non-reduction buffer first. - return ( + + result = ( get_free_symbols(self.get_size(), unbacked_only) | get_free_symbols(self.get_stride(), unbacked_only) | get_free_symbols(self.get_offset(), unbacked_only) @@ -4451,6 +4452,21 @@ def get_free_symbol_uses( | self.get_read_writes().get_free_symbol_uses(unbacked_only) ) + if isinstance(self.layout, NonOwningLayout): + assert isinstance(self.layout.view, ReinterpretView) + box = self.layout.view.data + assert isinstance(box, StorageBox), type(box) + input_buffer = box.data + assert isinstance(input_buffer, Buffer), type(box) + result = ( + result + | get_free_symbols(input_buffer.get_size(), unbacked_only) + | get_free_symbols(input_buffer.get_stride(), unbacked_only) + | get_free_symbols(input_buffer.get_offset(), unbacked_only) + ) + + return result + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: if ( not self.get_reduction_type() @@ -5126,6 +5142,18 @@ def get_read_writes(self) -> dependencies.ReadWrites: def get_reads(self) -> OrderedSet[Dep]: return self.get_read_writes().reads + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + r = OrderedSet[sympy.Symbol]() + for inp in self.inputs: + if isinstance(inp, IRNode): + r |= inp.get_free_symbol_uses(unbacked_only) + else: + for inner_inp in inp: + r |= inner_inp.get_free_symbol_uses(unbacked_only) + return r + @classmethod def unwrap_storage_for_input(cls, x: IRNode) -> IRNode: if isinstance(x, TensorBox): @@ -5172,6 +5200,11 @@ def is_no_op(self) -> bool: def get_reads(self) -> OrderedSet[Dep]: return OrderedSet() + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return InputsKernel.get_free_symbol_uses(self, unbacked_only) + class ConcatKernel(NopKernel): """ @@ -5326,6 +5359,11 @@ def can_realize_into_without_copy( and not isinstance(src.data, ExternKernelAlloc) ) + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return NopKernel.get_free_symbol_uses(self, unbacked_only) + @classmethod def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode: # Attempt to turn this into a ReinterpretView rather than assert. @@ -6221,12 +6259,10 @@ def canonicalize(self) -> tuple[Expr, Sequence[Expr]]: def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: - # NB: It's not necessary to check regular inputs as we automatically - # have dependencies on them maybe_get_symbols = ( maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols ) - r = OrderedSet[sympy.Symbol]() + r = InputsKernel.get_free_symbol_uses(self, unbacked_only) for arg in self.constant_args: r |= maybe_get_symbols(arg) for arg in self.kwargs.values(): From f0980fc0bbd656d6c02d23ad97e945353b314f35 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sat, 9 Aug 2025 21:06:00 +0000 Subject: [PATCH 0187/1424] [inductor] turn on windows inductor UTs (#160161) With this PR, we can turn on the inductor UTs on Windows CPU. changes: 1. Turn on inductor UTs on Windows CPU. 2. Add a shard to balance added UTs, otherwise it should run timeout. 3. Fixed `test_invalid_artifact_flag_error_msg`. 4. Skiped `test_distributed_rank_logging` and `test_disable_recursive_false`. 5. Skiped whole UT `test_cpu_select_algorithm.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160161 Approved by: https://github.com/jansel --- .github/workflows/trunk.yml | 7 ++++--- test/dynamo/test_decorators.py | 4 ++++ test/dynamo/test_logging.py | 5 ++++- test/inductor/test_cpu_select_algorithm.py | 3 ++- torch/_dynamo/test_case.py | 8 +++----- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index c7cf4c84e1888..c428127dc6dd2 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -123,9 +123,10 @@ jobs: runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, ]} secrets: inherit diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 3b29e5e961192..9bf982c5b90ec 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -10,6 +10,7 @@ import torch._dynamo.testing from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import skipIfWindows def my_custom_function(x): @@ -892,6 +893,9 @@ def gn(x): self.assertEqual(gn(inp), inp + 3) self.assertEqual(cnts.frame_count, 1) + @skipIfWindows( + msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows." + ) def test_disable_recursive_false(self): def fn2(x): return x + 1 diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index bcea00cdc98f1..c3a37d17d8130 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -21,8 +21,10 @@ from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.common_utils import ( find_free_port, + IS_WINDOWS, munge_exc, skipIfTorchDynamo, + skipIfWindows, TEST_XPU, xfailIf, ) @@ -528,7 +530,7 @@ def test_invalid_artifact_flag_error_msg(self): "import torch", env=env, ) - lines = stderr.decode().split("\n") + lines = stderr.decode().split("\r\n" if IS_WINDOWS else "\n") # This is a sanity assert that our error is not spammy. # As of this test creation this was 18. # See this issue for the purpose o this test: @@ -544,6 +546,7 @@ def test_invalid_artifact_flag_error_msg(self): self.assertEqual(lines[-4], "Valid settings:") @requires_distributed() + @skipIfWindows(msg="TODO: (xuhancn), Can't reproduce locally") def test_distributed_rank_logging(self): env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 7e35c93ee0b79..75d091595cd8a 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,6 +26,7 @@ ) from torch.testing._internal.common_utils import ( IS_MACOS, + IS_WINDOWS, parametrize, skipIfWindows, TEST_MKL, @@ -3094,5 +3095,5 @@ def forward(self, x, weight): if __name__ == "__main__": from torch.testing._internal.inductor_utils import HAS_CPU - if HAS_CPU and not IS_MACOS: + if HAS_CPU and not (IS_MACOS or IS_WINDOWS): run_tests() diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 230aac4794f25..f8bde6222dbea 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -41,11 +41,9 @@ def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF: return # skip testing - if ( - not torch.xpu.is_available() - and IS_WINDOWS - and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0" - ): + # Enable Inductor UTs on Windows for CPU. + # CUDA on Windows is not verified, NVDA developer can continue to enable CUDA based on CPU path. + if torch.cuda.is_available() and IS_WINDOWS: return if isinstance(needs, str): From df55ec7d4b35f6d21691e9dd41c82f27de762948 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 8 Aug 2025 17:10:04 -0700 Subject: [PATCH 0188/1424] [OpInfo][BE] Better inputs for addmm (#160234) Right now alpha and betha are both less than zero, which makes them useless for all addmm samples for interal types Pull Request resolved: https://github.com/pytorch/pytorch/pull/160234 Approved by: https://github.com/Skylion007 ghstack dependencies: #160228 --- torch/testing/_internal/common_methods_invocations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 41bb2b96bd938..506bf5488f3c0 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1161,8 +1161,8 @@ def make_arg_conj(size): def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): - alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6) - beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2) + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6 if dtype.is_floating_point else 2) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2 if dtype.is_floating_point else 3) tests_list = [ ((2, 3), (2, 2), (2, 3), False), ((3, 3), (3, 3), (3, 3), False), From d3d359dbafa89173a371e2637f22b47398e94a24 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 10 Aug 2025 02:37:40 +0000 Subject: [PATCH 0189/1424] Revert "Fix get_free_symbol_uses for several nodes. (#160134)" This reverts commit db78943a1ca13a32a3d6045eb15e2b719ee13a2f. Reverted https://github.com/pytorch/pytorch/pull/160134 on behalf of https://github.com/malfet due to No, those are not pre-existing, see https://hud.pytorch.org/hud/pytorch/pytorch/df55ec7d4b35f6d21691e9dd41c82f27de762948/1?per_page=50&name_filter=lint&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/160134#issuecomment-3172314322)) --- test/test_dynamic_shapes.py | 11 ---------- torch/_inductor/ir.py | 44 ++++--------------------------------- 2 files changed, 4 insertions(+), 51 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index dd8695ae4ac50..6a721a079a635 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3616,17 +3616,6 @@ def func3(x, y): def test_unbacked_select_index_cpp_wrapper(self): self.test_unbacked_select_index() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_unbacked_select2(self): - def f(idx, x): - x = x.select(0, idx.item()) - return x @ x - - x = torch.randn(3, 3, 3) - idx = torch.tensor(1, dtype=torch.int64) - out = torch.compile(f)(idx, x) - self.assertEqual(out, f(idx, x)) - instantiate_parametrized_tests(TestUnbacked) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 2cc68dcb37824..4f9f2f1e0b59f 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4443,8 +4443,7 @@ def get_free_symbol_uses( # unusual reason: we only need accurate dependencies for item() call, # but it's impossible to end up with a reduction over i0 from an # item() call without a regular non-reduction buffer first. - - result = ( + return ( get_free_symbols(self.get_size(), unbacked_only) | get_free_symbols(self.get_stride(), unbacked_only) | get_free_symbols(self.get_offset(), unbacked_only) @@ -4452,21 +4451,6 @@ def get_free_symbol_uses( | self.get_read_writes().get_free_symbol_uses(unbacked_only) ) - if isinstance(self.layout, NonOwningLayout): - assert isinstance(self.layout.view, ReinterpretView) - box = self.layout.view.data - assert isinstance(box, StorageBox), type(box) - input_buffer = box.data - assert isinstance(input_buffer, Buffer), type(box) - result = ( - result - | get_free_symbols(input_buffer.get_size(), unbacked_only) - | get_free_symbols(input_buffer.get_stride(), unbacked_only) - | get_free_symbols(input_buffer.get_offset(), unbacked_only) - ) - - return result - def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: if ( not self.get_reduction_type() @@ -5142,18 +5126,6 @@ def get_read_writes(self) -> dependencies.ReadWrites: def get_reads(self) -> OrderedSet[Dep]: return self.get_read_writes().reads - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - r = OrderedSet[sympy.Symbol]() - for inp in self.inputs: - if isinstance(inp, IRNode): - r |= inp.get_free_symbol_uses(unbacked_only) - else: - for inner_inp in inp: - r |= inner_inp.get_free_symbol_uses(unbacked_only) - return r - @classmethod def unwrap_storage_for_input(cls, x: IRNode) -> IRNode: if isinstance(x, TensorBox): @@ -5200,11 +5172,6 @@ def is_no_op(self) -> bool: def get_reads(self) -> OrderedSet[Dep]: return OrderedSet() - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - return InputsKernel.get_free_symbol_uses(self, unbacked_only) - class ConcatKernel(NopKernel): """ @@ -5359,11 +5326,6 @@ def can_realize_into_without_copy( and not isinstance(src.data, ExternKernelAlloc) ) - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - return NopKernel.get_free_symbol_uses(self, unbacked_only) - @classmethod def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode: # Attempt to turn this into a ReinterpretView rather than assert. @@ -6259,10 +6221,12 @@ def canonicalize(self) -> tuple[Expr, Sequence[Expr]]: def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: + # NB: It's not necessary to check regular inputs as we automatically + # have dependencies on them maybe_get_symbols = ( maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols ) - r = InputsKernel.get_free_symbol_uses(self, unbacked_only) + r = OrderedSet[sympy.Symbol]() for arg in self.constant_args: r |= maybe_get_symbols(arg) for arg in self.kwargs.values(): From 5dddcd5b07c6644efca8d613f4eca1dc95daa87f Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 6 Aug 2025 07:18:42 -0700 Subject: [PATCH 0190/1424] Correctly copy self.module_stack in ModuleStackTracer (#159956) There is a bigger cluster of issues which this does not completely fix, but I think this is a matter of good hygiene, especially because we immediately mutate the dict after assigning it. Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/159956 Approved by: https://github.com/pianpwk --- torch/fx/experimental/proxy_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index a578723ea1cbb..9f2c40904634e 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1959,7 +1959,7 @@ def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: # nn_module_stack if node.op not in ["placeholder", "output"]: if "nn_module_stack" not in node.meta: - node.meta["nn_module_stack"] = self.module_stack + node.meta["nn_module_stack"] = self.module_stack.copy() # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): if isinstance(mod_cls, type): From af10f1f86cc4effc93142a447693d8be55966615 Mon Sep 17 00:00:00 2001 From: ghostspiders <15834128411@126.com> Date: Sun, 10 Aug 2025 07:05:52 +0000 Subject: [PATCH 0191/1424] Fix requires_cuda to requires_cuda_and_triton (#160222) Fixes ##159399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160222 Approved by: https://github.com/janeyx99 --- .../fsdp/test_fully_shard_logging.py | 2 - test/dynamo/test_activation_checkpointing.py | 61 +++++++-------- test/dynamo/test_aot_autograd_cache.py | 10 +-- test/dynamo/test_autograd_function.py | 9 ++- test/dynamo/test_backends.py | 7 +- test/dynamo/test_base_hop.py | 5 -- test/dynamo/test_callback.py | 4 +- test/dynamo/test_compiler_bisector.py | 6 +- test/dynamo/test_debug_utils.py | 4 - test/dynamo/test_higher_order_ops.py | 19 ++--- test/dynamo/test_logging.py | 6 +- test/dynamo/test_structured_trace.py | 13 ++-- test/dynamo/test_subclasses.py | 6 +- test/export/test_export.py | 10 +-- test/export/test_torchbind.py | 6 +- test/higher_order_ops/test_invoke_subgraph.py | 6 +- test/inductor/test_codecache.py | 27 ++++--- test/inductor/test_combo_kernels.py | 42 +++++----- test/inductor/test_compiled_autograd.py | 13 ++-- test/inductor/test_compiled_optimizers.py | 6 +- test/inductor/test_cudacodecache.py | 12 +-- test/inductor/test_cudagraph_trees.py | 4 +- test/inductor/test_cutlass_backend.py | 2 +- test/inductor/test_foreach.py | 78 +++++++++---------- test/inductor/test_inductor_annotations.py | 6 +- test/inductor/test_perf.py | 31 ++++---- test/inductor/test_provenance_tracing.py | 12 +-- .../inductor/test_split_cat_fx_aten_passes.py | 10 +-- test/inductor/test_static_cuda_launcher.py | 6 +- test/inductor/test_torchinductor.py | 8 +- test/inductor/test_torchinductor_opinfo.py | 5 +- test/inductor/test_triton_kernels.py | 2 +- test/test_foreach.py | 4 +- torch/testing/_internal/triton_utils.py | 4 +- 34 files changed, 212 insertions(+), 234 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_logging.py b/test/distributed/_composable/fsdp/test_fully_shard_logging.py index fac56ad0b8d42..c9450a2b8f475 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_logging.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_logging.py @@ -6,11 +6,9 @@ import torch.distributed as dist from torch._dynamo.test_case import run_tests from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.logging_utils import LoggingTestCase -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index ea0882744c546..6b7662cbe646c 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -19,7 +19,7 @@ from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor from torch.utils.checkpoint import ( checkpoint, @@ -28,7 +28,6 @@ ) -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) @@ -243,7 +242,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_function_via_global_checkpoint(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -262,7 +261,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_function_with_kwargs(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -282,7 +281,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_sequential_layers(self, device): def gn(x): x = x.cos() @@ -307,7 +306,7 @@ def fn(x): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) - @requires_cuda + @requires_cuda_and_triton def test_tags_multiple_checkpoints(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -329,7 +328,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton def test_tags_module(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -357,7 +356,7 @@ def fn(x): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) - @requires_cuda + @requires_cuda_and_triton def test_tags_decomps(self, device): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): @@ -392,7 +391,7 @@ def fn(x): ) self._validate(fn, backend, x) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch(fallback_random=True) def test_tags_recomputed_rand(self, device): def gn(x, y): @@ -416,7 +415,7 @@ def fn(x, y): backend = "inductor" self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch(fallback_random=True) def test_tags_rand(self, device): def gn(x, y): @@ -443,7 +442,7 @@ def fn(x, y): backend = "inductor" self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch(fallback_random=True) def test_tags_dropout(self, device): # Figure out a way to test the number of inductor_random calls @@ -551,7 +550,7 @@ def _factory_fn(): Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""", ) - @requires_cuda + @requires_cuda_and_triton def test_fallback(self, device): def gn(x, y): torch._dynamo.graph_break() @@ -579,7 +578,7 @@ def fn(x, y): self.assertEqual(cnt.op_count, 2) self.assertEqual(len(cnt.graphs), 2) - @requires_cuda + @requires_cuda_and_triton def test_kwargs(self, device): def gn(x, y, z=None): a = torch.matmul(x, y) @@ -613,7 +612,7 @@ def fn(x, y, z): body_function = getattr(cnt.graphs[0], wrap_node.args[0].name) self.assertEqual(op_count(body_function), 2) - @requires_cuda + @requires_cuda_and_triton def test_symints_location(self, device): def gn(x, y): return torch.matmul(x, torch.nn.functional.dropout(y, 0.5)) @@ -643,7 +642,7 @@ def fn(x, y): wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint) self.assertEqual(len(wrap_node.args), 3) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_must_recompute(self, device): def context_fn_must_recompute_mm(): @@ -710,7 +709,7 @@ def fn(x): ), ) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): def selective_checkpointing_context_fn(): @@ -757,7 +756,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_tensor_subclass(self, device): def selective_checkpointing_context_fn(): @@ -807,7 +806,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_custom_rule(self, device): def _get_custom_policy(meta): @@ -872,7 +871,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_partial_ctx_fn(self, device): def selective_checkpointing_context_fn(no_recompute_list): @@ -918,7 +917,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_outplace_op(self, device): def selective_checkpointing_context_fn(): @@ -963,7 +962,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_list_ops(self, device): def selective_checkpointing_context_fn(): @@ -1011,7 +1010,7 @@ def fn(x, y): "In-place op support in selective checkpointing + torch.compile " "requires TorchDispatchMode + torch.compile work to complete" ) - @requires_cuda + @requires_cuda_and_triton def test_compile_selective_checkpoint_inplace_op(self, device): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -1057,7 +1056,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) def test_compile_selective_checkpoint_random_op(self, device): @@ -1117,7 +1116,7 @@ def fn(x): self._validate(fn, backend, x, skip_check=not preserve_rng_state) self._compare_orig_and_checkpointed_fns(gn, fn, x) - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_invalid_context(self): def gn(x, y): @@ -1155,7 +1154,7 @@ def fn(x, y): ): self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_compile_selective_checkpoint_parametrization(self): def sac_policy(): @@ -1249,7 +1248,7 @@ def reset_parameters(self): self.assertEqual(input.grad, input_compiled.grad) @skipIfRocm - @requires_cuda + @requires_cuda_and_triton def test_autocast_flash_attention(self, device): def fn(primals_1, primals_2, primals_3): return torch.ops.aten._scaled_dot_product_efficient_attention.default( @@ -1273,7 +1272,7 @@ def gn(*args): res = opt_gn(*args) self.assertEqual(ref, res) - @requires_cuda + @requires_cuda_and_triton def test_error_msg(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -1297,7 +1296,7 @@ def fn(x): ): opt_fn(x) - @requires_cuda + @requires_cuda_and_triton def test_list_inputs(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -1322,7 +1321,7 @@ def fn(x, ys): res = opt_fn(x, [y, z]) self.assertEqual(ref, res) - @requires_cuda + @requires_cuda_and_triton def test_pattern_matcher(self, device): # Check that the sdpa op is recomputed in the backward graph # tests percolate_tags @@ -1402,7 +1401,7 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): ) @requires_distributed() - @requires_cuda + @requires_cuda_and_triton def test_distributed_utils_checkpoint_wrapper(self): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as dist_checkpoint_wrapper, @@ -1428,7 +1427,7 @@ def forward(self, x): self.assertEqual(ref, res) @requires_distributed() - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_dynamo_does_not_trace_getattr_as_top_frame(self): # inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do. diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index d26e4b31917e0..2895c8991c22c 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -37,7 +37,7 @@ skipIfWindows, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_triton -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor @@ -690,7 +690,7 @@ def fn(a, b): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) - @requires_cuda + @requires_cuda_and_triton @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -746,7 +746,7 @@ def backward(ctx, grad_output): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) - @requires_cuda + @requires_cuda_and_triton @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -788,7 +788,7 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) - @requires_cuda + @requires_cuda_and_triton @requires_triton() @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @@ -1260,7 +1260,7 @@ def f(): result = f() self.assertEqual(result[0].device, torch.device("cuda:1")) - @requires_cuda + @requires_cuda_and_triton @inductor_config.patch("fx_graph_cache", True) @inductor_config.patch("fx_graph_remote_cache", False) @functorch_config.patch({"enable_autograd_cache": True}) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index d93a00f8ae106..de5afce145984 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -8,7 +8,10 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils -from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_cuda +from torch.testing._internal.triton_utils import ( + HAS_CUDA_AND_TRITON, + requires_cuda_and_triton, +) if HAS_CUDA_AND_TRITON: @@ -1473,7 +1476,7 @@ def fn(): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_basic(self): class Add(torch.autograd.Function): @staticmethod @@ -1504,7 +1507,7 @@ def f(x, y): loss.backward() self.assertEqual(x + y, z) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_multiple_out(self): class Add(torch.autograd.Function): @staticmethod diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 2b927880cae31..be1470c08e794 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -16,10 +16,7 @@ onlyHPU, ) from torch.testing._internal.common_utils import skipIfHpu -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON - - -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") +from torch.testing._internal.triton_utils import requires_cuda_and_triton class Seq(torch.nn.Module): @@ -133,7 +130,7 @@ def test_aot_eager_decomp_partition(self, device): def test_aot_ts(self, device): self._check_backend_works("aot_ts", device) - @requires_cuda + @requires_cuda_and_triton def test_aot_cudagraphs(self, device): self._check_backend_works("cudagraphs", device) diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index 30252d88a3782..607b502351aaf 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import unittest import unittest.mock as mock import torch @@ -13,10 +12,6 @@ ) from torch._higher_order_ops.schema import find_hop_schema from torch.testing._internal.common_utils import instantiate_parametrized_tests -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON - - -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") def normalize_graph(gm): diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index c45fac7933c7d..e516364626314 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -8,7 +8,7 @@ from torch._dynamo.test_case import run_tests, TestCase from torch._guards import CompileId from torch.testing._internal.common_utils import TEST_WITH_ROCM -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton class CallbackTests(TestCase): @@ -61,7 +61,7 @@ def test_counter_assertion(self) -> None: @unittest.skipIf( TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs" ) - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires triton") + @requires_cuda_and_triton @torch._inductor.config.patch(force_disable_caches=True) def test_triggers(self) -> None: torch._dynamo.reset() diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index cce1b7bc9183f..161f9674cd4a1 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -1,6 +1,5 @@ # Owner(s): ["module: dynamo"] -import unittest from contextlib import contextmanager from importlib import import_module @@ -11,19 +10,18 @@ from torch._inductor.compiler_bisector import CompilerBisector from torch._inductor.test_case import TestCase from torch.library import _scoped_library, Library -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton aten = torch.ops.aten -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") f32 = torch.float32 i64 = torch.int64 i32 = torch.int32 -@requires_cuda +@requires_cuda_and_triton class TestCompilerBisector(TestCase): test_ns = "_test_bisector" diff --git a/test/dynamo/test_debug_utils.py b/test/dynamo/test_debug_utils.py index 1315fa8d9c51a..eae4d06d98904 100644 --- a/test/dynamo/test_debug_utils.py +++ b/test/dynamo/test_debug_utils.py @@ -1,7 +1,6 @@ # Owner(s): ["module: dynamo"] import os -import unittest from unittest.mock import patch import torch @@ -10,11 +9,8 @@ from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string from torch._dynamo.test_case import TestCase from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") - f32 = torch.float32 i64 = torch.int64 i32 = torch.int32 diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 441a10aeba43f..5844a13fcad00 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -38,11 +38,8 @@ xfailIfTorchDynamo, ) from torch.testing._internal.hop_db import hop_db -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test - - -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") +from torch.testing._internal.triton_utils import requires_cuda_and_triton def count_ops(gm, args, freq, op): @@ -6845,7 +6842,7 @@ def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True): for arg, cloned_arg in zip(args, cloned_args): self.assertEqual(arg.grad, cloned_arg.grad) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_function(self): def gn(x, y): @@ -6864,7 +6861,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_function_with_kwargs(self): def gn(x, y): @@ -6887,7 +6884,7 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_dropout(self): def gn(x, y): @@ -6913,7 +6910,7 @@ def fn(x, y): fn, backend, x, y, skip_check=True ) # dropout decomp is known to diverge with eager - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_dropout_inductor(self): def gn(x, y): @@ -6932,7 +6929,7 @@ def fn(x, y): fn, backend, x, y, skip_check=True ) # dropout decomp is known to diverge with eager - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_fallback(self): def gn(x, y): @@ -6963,7 +6960,7 @@ def fn(x, y): self.assertEqual(cnt.op_count, 2) self.assertEqual(len(backend.graphs), 2) - @requires_cuda + @requires_cuda_and_triton @torch._functorch.config.patch(functionalize_rng_ops=True) def test_module(self): class MockModule(torch.nn.Module): @@ -7216,7 +7213,7 @@ def false_branch(x): class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase): - @requires_cuda + @requires_cuda_and_triton @parametrize("backend", ("aot_eager", "inductor")) @ops( list(filter(lambda op: op.name not in xfail_hops_compile, hop_db)), diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index c3a37d17d8130..a5a6ee54aa74a 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -37,9 +37,9 @@ make_logging_test, make_settings_test, ) +from torch.testing._internal.triton_utils import requires_cuda_and_triton -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_gpu = unittest.skipUnless( HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON, "requires cuda or xpu with triton" ) @@ -139,7 +139,7 @@ def test_fusion(self, records): self.assertGreater(len(records), 0) self.assertLess(len(records), 8) - @requires_cuda + @requires_cuda_and_triton @make_logging_test(cudagraphs=True) def test_cudagraphs(self, records): fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) @@ -252,7 +252,7 @@ def throw(x): exitstack.close() @requires_distributed() - @requires_cuda + @requires_cuda_and_triton @make_logging_test(ddp_graphs=True) def test_ddp_graphs(self, records): class ToyModel(torch.nn.Module): diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index ece491d764ddf..a930fb0406dbd 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -22,7 +22,7 @@ from torch._logging._internal import TorchLogsFormatter from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_utils import find_free_port -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton if torch.distributed.is_available(): @@ -31,7 +31,6 @@ HAS_TLPARSE = shutil.which("tlparse") is not None requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) @@ -238,7 +237,7 @@ def test_compile_id_serialization_deserialization(self): with self.assertRaises(ValueError): torch._guards.CompileId.from_string(bad_cid) - @requires_cuda + @requires_cuda_and_triton def test_schedule(self): fn_opt = torch.compile(inductor_schedule_fn, backend="inductor") fn_opt(torch.ones(1000, 1000, device="cuda")) @@ -271,7 +270,7 @@ def test_schedule(self): self.assertParses() - @requires_cuda + @requires_cuda_and_triton def test_cudagraphs(self): fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) fn_opt(torch.ones(1000, 1000, device="cuda")) @@ -535,7 +534,7 @@ def throw(x): self.assertParses() @requires_distributed() - @requires_cuda + @requires_cuda_and_triton def test_ddp_graphs(self): class ToyModel(torch.nn.Module): def __init__(self) -> None: @@ -1226,7 +1225,7 @@ def _setup_runtime_estimates_capture(self): @requires_tlparse @requires_distributed() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("fx_graph_cache", False) @torch._inductor.config.patch("log_tlparse", True) def test_runtime_estimates_simple(self): @@ -1287,7 +1286,7 @@ def forward(self, x): @requires_tlparse @requires_distributed() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("fx_graph_cache", False) @torch._inductor.config.patch("log_tlparse", True) def test_runtime_estimates_mixed(self): diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index ef4158b4a65b6..9d60cbe81c970 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -31,7 +31,7 @@ parametrize, subtest, ) -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor from torch.utils._python_dispatch import return_and_correct_aliasing @@ -145,8 +145,6 @@ def mk_subclass_dense_subclass_dense(): VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()} -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") - compile_full_eager = torch.compile(backend="eager", fullgraph=True) @@ -3798,7 +3796,7 @@ def fn1(nt1, nt2): def test_basic_autograd(self): self._test_autograd("aot_eager") - @requires_cuda + @requires_cuda_and_triton def test_basic_autograd_inductor(self): self._test_autograd("inductor") diff --git a/test/export/test_export.py b/test/export/test_export.py index 848373aef6841..1c997b8e86beb 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -86,7 +86,7 @@ ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.torchbind_impls import load_torchbind_test_lib -from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu from torch.testing._internal.two_tensor import TwoTensor from torch.utils._pytree import ( LeafSpec, @@ -8382,7 +8382,7 @@ def forward(self, x): len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1 ) - @requires_cuda + @requires_cuda_and_triton @testing.expectedFailureCppRuntime def test_export_associative_scan_symbol_dim(self): device = torch.device("cuda") @@ -8407,7 +8407,7 @@ def forward(self, x): module_out = Foo()(xs) self.assertTrue(torch.allclose(ep.module()(xs), module_out)) - @requires_cuda + @requires_cuda_and_triton @testing.expectedFailureCppRuntime def test_export_associative_scan_symbol_scandim(self): device = torch.device("cuda") @@ -8432,7 +8432,7 @@ def forward(self, x): module_out = Foo()(xs) self.assertTrue(torch.allclose(ep.module()(xs), module_out)) - @requires_cuda + @requires_cuda_and_triton def test_export_associative_scan_lifted_buffers(self): if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") @@ -15917,7 +15917,7 @@ def forward(self, x): len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs) ) - @requires_cuda + @requires_cuda_and_triton def test_assert_tensor_metadata_device_index(self): class N(torch.nn.Module): def __init__(self): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index c6f770e19c85a..d24262dab2b1c 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -24,7 +24,7 @@ _empty_tensor_queue, init_torchbind_implementations, ) -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton def _assertEqualSkipScriptObject(test_case, exp, actual): @@ -1552,7 +1552,7 @@ def f(tq, x): self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) ) - @requires_cuda + @requires_cuda_and_triton @parametrize("device", ["cpu", "cuda"]) @parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_compile_obj_torchbind_op_with_autocast(self, backend, device): @@ -1570,7 +1570,7 @@ def f(tq, x): self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) ) - @requires_cuda + @requires_cuda_and_triton @parametrize("device", ["cpu", "cuda"]) def test_export_obj_torchbind_op_with_autocast(self, device): class Mod(torch.nn.Module): diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index c800eb78f905a..46d796f1dac37 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -34,7 +34,7 @@ TestCase, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU -from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu nested_compile_region = torch.compiler.nested_compile_region @@ -556,7 +556,7 @@ def fn(x): self.assertEqual(ref, res) self.assertEqual(x.grad, x_clone.grad) - @requires_cuda + @requires_cuda_and_triton def test_sdpa(self): @nested_compile_region def gn(q, k, v): @@ -1447,7 +1447,7 @@ def forward(self, l_x_: "f32[8, 8]"): """, ) - @requires_cuda + @requires_cuda_and_triton def test_return_none(self): from torch.nn import functional as F diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 3597663431fde..f75a867974671 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -59,7 +59,6 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, - HAS_CUDA_AND_TRITON, HAS_GPU, HAS_MULTIGPU, HAS_TRITON, @@ -67,7 +66,7 @@ requires_gpu, requires_triton, ) -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -872,7 +871,7 @@ def fn(x): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton def test_no_arguments_tensor_device_guards(self): """ Usually, when there are example inputs, the device index of the inputs @@ -902,7 +901,7 @@ def f(): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton def test_tensor_device_guards_cpu_tensor(self): """ CPU tensor arguments should still cache hit @@ -1006,7 +1005,7 @@ def fn(x, op): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1) - @requires_cuda + @requires_cuda_and_triton @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @with_tf32_off @@ -1464,7 +1463,7 @@ def f(x, val): self.assertNotEqual(a, b) @config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False}) - @requires_cuda + @requires_cuda_and_triton @unittest.expectedFailure # TODO: pass in optimize_mem at runtime def test_async_compile_cache(self): class SimpleFunction(torch.autograd.Function): @@ -2574,7 +2573,7 @@ def test_get_hash_for_files(self): class TestCudaCompileCommand(TestCase): - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton def test_cuda_compile_command(self): cmd_no_extra_args: str = cuda_compile_command( ["abc.cu", "def.cu"], "output", "so" @@ -2619,7 +2618,7 @@ def reset(self): torch._dynamo.reset() clear_caches() - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @unittest.skipIf( TEST_WITH_ROCM, "Requires static cuda launcher, which does not support ROCM" @@ -2670,7 +2669,7 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2711,7 +2710,7 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2772,7 +2771,7 @@ def f(a, b, c, d, e, f): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") @requires_triton() - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2836,7 +2835,7 @@ def fn(x, y): class TestRemoteAOTAutogradCache(TestCase): - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) @@ -2875,7 +2874,7 @@ def f(a, b): for k in global_stats.fx_graph.cache.keys(): self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+") - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "Requires CUDA") + @requires_cuda_and_triton @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) @@ -2950,7 +2949,7 @@ def fn(x, y): # This combination of settings exposed a bug where we cleared the # PyCodeCache disk artifacts while they were still needed: - @requires_cuda + @requires_cuda_and_triton @config.patch( { "coordinate_descent_tuning": True, diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index 480094dfb7481..90399546d26ea 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -11,7 +11,7 @@ TestCase, ) from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton aten = torch.ops.aten @@ -55,7 +55,7 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() - @requires_cuda + @requires_cuda_and_triton def test_activation_functions(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -75,7 +75,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_reduce_functions(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -98,7 +98,7 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2) - @requires_cuda + @requires_cuda_and_triton def test_mutated_args(self): def test_mutated(a, b, c, d): a.add_(1) @@ -121,7 +121,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_reduce_split(self): def fn(a, b): a1 = torch.linalg.vector_norm(a) @@ -137,7 +137,7 @@ def fn(a, b): self.assertEqual(out_eager, out_compiled) - @requires_cuda + @requires_cuda_and_triton def test_2d_blocking_partitioning(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -184,7 +184,7 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() - @requires_cuda + @requires_cuda_and_triton def test_activation_benchmark(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -204,7 +204,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton def test_reduce_benchmark(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -227,7 +227,7 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) - @requires_cuda + @requires_cuda_and_triton def test_mutated_benchmark(self): def test_mutated(a, b, c, d): a.add_(1) @@ -250,7 +250,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9]) - @requires_cuda + @requires_cuda_and_triton def test_round_robin_dispatch(self): # combo kernel dispatch strategy: round robin def test_mutated(a, b, c, d): @@ -274,7 +274,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) - @requires_cuda + @requires_cuda_and_triton def test_2d_blocking_benchmark(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -296,7 +296,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) - @requires_cuda + @requires_cuda_and_triton def test_persistent_reduction_no_x_dim(self): def fn(x, y): return x.sum(1), y.sum(1) @@ -346,7 +346,7 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_activations(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -366,7 +366,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_2d_blocking(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -388,7 +388,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_reduce(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -411,7 +411,7 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) - @requires_cuda + @requires_cuda_and_triton def test_dynamic_shapes_mutated(self): # combo kernel dispatch strategy: round robin def test_mutated(a, b, c, d): @@ -435,7 +435,7 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("combo_kernels_autotune", 0) def test_dynamic_shapes_activations_no_autotune(self): def test_activations(a, b, c): @@ -456,7 +456,7 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_persistent_reduction_no_x_dim(self): @@ -475,7 +475,7 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_persistent_reduction_no_x_dim_2(self): @@ -494,7 +494,7 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_2d_blocking_round_robin(self): @@ -533,7 +533,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(out_eager, out_compiled) self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) @torch._inductor.config.patch("triton.autotune_at_compile_time", True) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index c99ad7f2c95a9..241528b159cc1 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -54,6 +54,7 @@ HAS_GPU, ) from torch.testing._internal.logging_utils import logs_to_string +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.utils._python_dispatch import TorchDispatchMode @@ -2994,7 +2995,7 @@ def backward(ctx, grad): b = MyFunc.apply(a) b.sum().backward() - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_cpu_division(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -3034,7 +3035,7 @@ def test_cudagraphs_cpu_graph(self): self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_sdpa(self): query = torch.rand( 32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True @@ -3056,7 +3057,7 @@ def test_cudagraphs_sdpa(self): 2 if inductor_config.cpp_wrapper else 0, ) - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): class MyFn(torch.autograd.Function): @staticmethod @@ -3087,7 +3088,7 @@ def backward(ctx, gO): self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) @scoped_load_inline - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") + @requires_cuda_and_triton def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { @@ -3715,7 +3716,7 @@ def inner_compiler(gm_, example_inputs_): self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node)) self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node)) - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") + @requires_cuda_and_triton def test_flex_attention(self): def _squared(score, b, h, m, n): """Joint graph needed for correctness""" @@ -3883,7 +3884,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check), ) - @unittest.skipIf(not HAS_CUDA_AND_TRITON, "requires cuda") + @requires_cuda_and_triton def test_cpu_offloading(self): def fn(): def pack(x): diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 9751b3ca8f554..3b23e7a51f702 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -64,7 +64,7 @@ HAS_GPU, has_triton, ) -from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu def get_inputs(optim): @@ -916,7 +916,7 @@ def fn(xs, ys): self.assertLess(end - start, 90) - @requires_cuda + @requires_cuda_and_triton def test_S429861(self): # Just verify we can compile this function without error try: @@ -935,7 +935,7 @@ def test_S429861(self): kwargs = aot_graph_input_parser(forward) torch.compile(forward)(**kwargs) - @requires_cuda + @requires_cuda_and_triton def test_foreach_map_adam(self): params = [ torch.rand( diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 7a132ac2a0468..b6786130416bd 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -1,7 +1,6 @@ # Owner(s): ["module: inductor"] import ctypes -import unittest import torch from torch._inductor.async_compile import AsyncCompile @@ -10,10 +9,7 @@ from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import fresh_cache -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON - - -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") +from torch.testing._internal.triton_utils import requires_cuda_and_triton _SOURCE_CODE = r""" @@ -41,7 +37,7 @@ class TestCUDACodeCache(InductorTestCase): - @requires_cuda + @requires_cuda_and_triton def test_cuda_load(self): with fresh_cache(): # Test both .o and .so compilation. @@ -69,14 +65,14 @@ def test_cuda_load(self): ) torch.testing.assert_close(y, expected_y) - @requires_cuda + @requires_cuda_and_triton def test_compilation_error(self): with fresh_cache(): error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) with self.assertRaises(CUDACompileError): CUDACodeCache.compile(error_source_code, "o") - @requires_cuda + @requires_cuda_and_triton def test_async_compile(self): with fresh_cache(): async_compile = AsyncCompile() diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 4a7f9e6e92e03..1408a0208cf06 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -40,6 +40,7 @@ skipIfRocm, TEST_CUDA_GRAPH, ) +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode @@ -55,11 +56,8 @@ importlib.import_module("functorch") importlib.import_module("filelock") -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON - aten = torch.ops.aten -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") requires_multigpu = functools.partial( unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" ) diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 5889adb120ffa..2a944e4046696 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -159,7 +159,7 @@ def select_no_algorithm(*args, **kwargs): class TestCutlassBackend(TestCase): def setUp(self): if not HAS_CUDA_AND_TRITON: - self.skipTest("CUDA is not available") + self.skipTest("CUDA and triton are not available") if torch.version.hip: self.skipTest("CUTLASS backend is not supported on HIP") diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index f9cedf81f85b0..c51d0bba229ec 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -15,7 +15,7 @@ parametrize, ) from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.utils._pytree import tree_flatten @@ -269,29 +269,29 @@ def fn(a0, a1): ) # called in test_cuda_cpp_wrapper.py - @requires_cuda + @requires_cuda_and_triton def test_foreach_cpp_wrapper_cuda(self): self._test_single_list(op=torch._foreach_add) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_single_list(self, op): self._test_single_list(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_single_scalar(self, op): self._test_single_scalar(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_tensor_bin_ops def test_single_scalar_tensor(self, op): self._test_single_scalar_tensor(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_scheduler_fusion_list(self, op): if op in un_ops_under_test: @@ -319,7 +319,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_scheduler_fusion_scalar(self, op): def fn(a0, a1): @@ -336,7 +336,7 @@ def fn(a0, a1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_broadcasting(self, op): def fn(a0, a1, b0, b1): @@ -355,7 +355,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_singleton_lists(self, op): if op in un_ops_under_test: @@ -392,7 +392,7 @@ def fn(a0, b0, c0): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_type_promotion(self, op): def fn(a0, a1, b0, b1): @@ -413,7 +413,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_kernel_split_arg_limit_list(self, op): # NB: foeach_copy won't pass this test because it will dce one set of buffers @@ -435,7 +435,7 @@ def fn(a, b): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops @unittest.skip( "Triton recursion depth exceeded: https://github.com/triton-lang/triton/issues/1763" @@ -455,7 +455,7 @@ def fn(a): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_fusion_duplicate_buffer_list(self, op): def fn(a0, a1, b0, b1): @@ -479,7 +479,7 @@ def fn(a0, a1, b0, b1): kernel_count = 2 self.assertEqual(torch._inductor.metrics.generated_kernel_count, kernel_count) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_non_foreach_consumer_list(self, op): if op in un_ops_under_test: @@ -507,7 +507,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_non_foreach_consumer_scalar(self, op): def fn(a0, a1): @@ -524,7 +524,7 @@ def fn(a0, a1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_non_foreach_producer_list(self, op): if op in un_ops_under_test: @@ -554,7 +554,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_non_foreach_producer_scalar(self, op): def fn(a0, a1, b0, b1): @@ -574,7 +574,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @all_ops def test_non_foreach_consumer_producer_list(self, op): if op in un_ops_under_test: @@ -616,7 +616,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @scalar_bin_ops def test_non_foreach_consumer_producer_scalar(self, op): def fn(a0, a1, b0, b1): @@ -641,7 +641,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @@ -661,7 +661,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) @@ -680,7 +680,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) @@ -715,7 +715,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @decomp_ops def test_decomp(self, op): def fn(a0, a1, b0, b1, c0, c1): @@ -735,7 +735,7 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_fuse_concat(self): def fn(x1, x2, x3, w1, w2, w3): x = torch.stack([x1, x2, x3]) @@ -758,7 +758,7 @@ def fn(x1, x2, x3, w1, w2, w3): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton def test_zero_elems(self): def fn(a0, a1, b0, b1): return torch._foreach_add([a0, a1], [b0, b1]) @@ -775,7 +775,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_2d_blocking(self, op): def fn(a0, a1, b0, b1): @@ -793,7 +793,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_2d_blocking_partitioning(self, op): def fn(a0, a1, b0, b1): @@ -811,7 +811,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @bin_ops def test_2d_blocking_partitioning_elems(self, op): """2D blocking should be grouped by number of yelems""" @@ -833,7 +833,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @bin_ops @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2) def test_2d_blocking_partitioning_mixed_sizes(self, op): @@ -856,7 +856,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @inplace_bin_ops def test_reinplacing(self, op): def fn(a0, a1, b0, b1): @@ -874,7 +874,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @inplace_bin_ops def test_reinplacing_mut_before(self, op): def fn(a0, a1, b0, b1): @@ -893,7 +893,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @inplace_bin_ops def test_reinplacing_mut_after(self, op): def fn(a0, a1, b0, b1): @@ -912,7 +912,7 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton def test_multi_device(self): def test_foreach_add(a0, a1, b0, b1): return torch._foreach_add([a0, a1], [b0, b1]) @@ -930,7 +930,7 @@ def test_foreach_add(a0, a1, b0, b1): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton def test_aliasing(self): def test_foreach_add(a0, a1, a2, b0, b1, b2): return torch._foreach_add_([a0, a1, a2], [b0, b1, b2]) @@ -952,7 +952,7 @@ def test_foreach_add(a0, a1, a2, b0, b1, b2): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 1) def test_2d_block_no_mixed_sizes_no_mask(self): """2D blocking with no mixed sizes constant mask""" @@ -974,7 +974,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2) def test_2d_block_mixed_sizes_with_mask(self): """2D blocking with mixed sizes should have mask""" @@ -996,7 +996,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) - @requires_cuda + @requires_cuda_and_triton @foreach_map_bin_ops def test_foreach_map_backward_binary(self, op): from torch._dynamo.polyfills import foreach_map_fn @@ -1037,7 +1037,7 @@ def ref_fn(xs, ys): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) - @requires_cuda + @requires_cuda_and_triton def test_foreach_map_input_mutation(self): def fn(xs, ys): outs = foreach_map_add_inplace(xs, ys) @@ -1073,7 +1073,7 @@ def fn(xs, ys): ): _ = run_fw_bw_and_get_code(lambda: torch.compile(fn)(*inps)) - @requires_cuda + @requires_cuda_and_triton @foreach_map_un_ops def test_foreach_map_backward_unary(self, op): from torch._dynamo.polyfills import foreach_map_fn diff --git a/test/inductor/test_inductor_annotations.py b/test/inductor/test_inductor_annotations.py index 75f53f4dd9b81..bee7e0ad917da 100644 --- a/test/inductor/test_inductor_annotations.py +++ b/test/inductor/test_inductor_annotations.py @@ -3,7 +3,7 @@ import torch._inductor.config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton class InductorAnnotationTestCase(TestCase): @@ -18,7 +18,7 @@ def f(a, b): _, code = run_and_get_code(f_comp, a, b) return code[0] - @requires_cuda + @requires_cuda_and_triton def test_no_annotations(self): code = self.get_code() @@ -26,7 +26,7 @@ def test_no_annotations(self): self.assertTrue("training_annotation" not in code) @inductor_config.patch(annotate_training=True) - @requires_cuda + @requires_cuda_and_triton def test_training_annotation(self): code = self.get_code() diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 30a273ba17e31..83cd236875f45 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -28,7 +28,10 @@ # performance for that setting. # # Defines all the kernels for tests -from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_cuda +from torch.testing._internal.triton_utils import ( + HAS_CUDA_AND_TRITON, + requires_cuda_and_triton, +) # set so that metrics appear @@ -920,7 +923,7 @@ def f(a, b): inp = (T(10, 10), TI(2, mx=5)) self.assertExpectedInline(count_numel(f, *inp), """42""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_training(self): @triton.jit def sin_kernel( @@ -964,7 +967,7 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel_train(f, x), """9""") - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_not_fusable_with_users(self): @triton.jit def _sin_kernel( @@ -1017,7 +1020,7 @@ def f(x): # (it will cost an extra kernel) self.assertExpectedInline(count_numel_train(f, x), """27""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_training_two_mutated_inputs(self): @torch.library.custom_op( "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"} @@ -1037,7 +1040,7 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel(f, x), """21""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_training(self): @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) def sin(x: torch.Tensor, result: torch.Tensor) -> None: @@ -1066,7 +1069,7 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel_train(f, x), """9""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor x, Tensor(a!) out) -> ()") @@ -1096,7 +1099,7 @@ def f(x, out): self.assertExpectedInline(count_numel(f, x, out), """21""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_intermediate(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor x, Tensor(a!) out) -> ()") @@ -1127,7 +1130,7 @@ def f(x, out): self.assertExpectedInline(count_numel(f, x, out), """21""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_custom_op_two_mutated_inputs(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache) -> Tensor") @@ -1159,7 +1162,7 @@ def f(): self.assertExpectedInline(count_numel(f), """39""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v1(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1171,7 +1174,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """50""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v2(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1184,7 +1187,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """70""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v3(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1197,7 +1200,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """80""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v4(self): def f(x: torch.Tensor, y: torch.Tensor): x_view = x.view(-1) @@ -1211,7 +1214,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """70""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v5(self): def f(x: torch.Tensor, y: torch.Tensor): x_view = x.view(-1) @@ -1225,7 +1228,7 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """80""") - @requires_cuda + @requires_cuda_and_triton def test_inplace_triton_kernel_v6(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index 2dd9ca44eb687..77e099cf0cb93 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -19,7 +19,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.virtualized import V from torch.testing._internal.inductor_utils import HAS_GPU -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -229,7 +229,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): if filepath: shutil.rmtree(filepath) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_to_post_grad_tracing_cuda(self): self._test_triton_kernel_to_post_grad_tracing(device="cuda") @@ -237,7 +237,7 @@ def test_triton_kernel_to_post_grad_tracing_cuda(self): def test_triton_kernel_to_post_grad_tracing_cpu(self): self._test_triton_kernel_to_post_grad_tracing(device="cpu") - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): M = 8 N = 6 @@ -285,7 +285,7 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): if filepath: shutil.rmtree(filepath) - @requires_cuda + @requires_cuda_and_triton def _test_pt_tracing_combo_kernel(self, backend): """This test checks that generated provenance tracing artifact from triton combo kernel to post grad nodes""" a = torch.randn(10, 10, device="cuda") @@ -320,7 +320,7 @@ def _test_pt_tracing_combo_kernel(self, backend): expected_data = {"triton_poi_fused_0": ["relu", "sigmoid", "tanh"]} self._check_provenance_tracing_artifact(filepath, expected_data) - @requires_cuda + @requires_cuda_and_triton def test_triton_kernel_to_post_grad_tracing_combo_kernel(self): self._test_pt_tracing_combo_kernel(backend="inductor") self._test_pt_tracing_combo_kernel(backend="aot_inductor") @@ -437,7 +437,7 @@ def get_node_with_target(self, gm, target): """ return next(iter([node for node in gm.graph.nodes if node.target == target])) - @requires_cuda # test only works for cuda pattern matcher + @requires_cuda_and_triton # test only works for cuda pattern matcher def test_pattern_matcher_transfer_meta(self): """ Test that stack trace is transfered when node is decomposed in post_grad_passes diff --git a/test/inductor/test_split_cat_fx_aten_passes.py b/test/inductor/test_split_cat_fx_aten_passes.py index 354552c497d98..0ec7825df001c 100644 --- a/test/inductor/test_split_cat_fx_aten_passes.py +++ b/test/inductor/test_split_cat_fx_aten_passes.py @@ -5,7 +5,7 @@ from torch._dynamo.utils import counters from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -248,7 +248,7 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) ) - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -291,7 +291,7 @@ def test_split_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -317,7 +317,7 @@ def test_split_cat_post_grad_singular(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -342,7 +342,7 @@ def test_select_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 2ce294ed0ff55..654bfd269f761 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -13,10 +13,10 @@ from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import skipIfRocm -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton -@requires_cuda +@requires_cuda_and_triton class TestStaticCudaLauncher(TestCase): def setUp(self): super().setUp() @@ -396,7 +396,7 @@ def kernel_many_args(out_tensor, {decl}): self.assertEqual(buf0, buf1) -@requires_cuda +@requires_cuda_and_triton @torch._inductor.config.patch( {"use_static_cuda_launcher": True, "strict_static_cuda_launcher": True} ) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 98604366b842b..cdcedd5a1771e 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -138,7 +138,7 @@ skipCPUIf, skipCUDAIf, ) -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton _T = TypeVar("_T") @@ -13155,7 +13155,7 @@ def f(x): "assert_size_stride(buf2, (16, 32), (32, 1)" ).run(code) - @requires_cuda + @requires_cuda_and_triton @config.patch(use_fast_math=True) def test_prepare_softmax_with_fast_math(self): """ @@ -13654,7 +13654,7 @@ def forward(self, x): inputs = (torch.randn(4, device=self.device),) self.common(Model(), inputs) - @requires_cuda + @requires_cuda_and_triton @parametrize("use_cat", [True, False]) def test_copy_non_blocking_is_pinned(self, use_cat): def f(a_list): @@ -14071,7 +14071,7 @@ def forward( torch._inductor.aot_compile(traced, inputs) @skipCUDAIf(not SM90OrLater, "Requires sm90") - @requires_cuda + @requires_cuda_and_triton @unittest.skipIf(TEST_WITH_ROCM, "no grouped_mm support") @config.patch(implicit_fallbacks=True) def test_grouped_mm(self): diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index e8d6ce38d5af6..1ee24c74bb766 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -26,7 +26,6 @@ OpDTypes, ops, skipCPUIf, - skipCUDAIf, skipXPUIf, ) from torch.testing._internal.common_methods_invocations import op_db, skipOps @@ -46,11 +45,11 @@ from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, - HAS_CUDA_AND_TRITON, has_triton, HAS_XPU_AND_TRITON, maybe_skip_size_asserts, ) +from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.utils._dtype_abbrs import dtype_abbrs from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map @@ -1126,7 +1125,7 @@ def tearDown(self): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently - @skipCUDAIf(not HAS_CUDA_AND_TRITON, "Skipped! Triton not found") + @requires_cuda_and_triton @skipXPUIf( not HAS_XPU_AND_TRITON, "Skipped! Supported XPU compiler and Triton not found" ) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 87529c23dd7ad..6804a500fbddb 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -2200,7 +2200,7 @@ def f(x): self.assertEqual(compiled_out, eager_out) # TODO enable this test case on XPU. - @requires_cuda + @requires_cuda_and_triton @parametrize("cfg", ["normal", "cpp_wrapper"]) def test_triton_kernel_dtype_view(self, cfg): # https://github.com/pytorch/pytorch/issues/136159 diff --git a/test/test_foreach.py b/test/test_foreach.py index a5ca220dcb525..7ac128d6bac8a 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -43,7 +43,7 @@ TEST_WITH_ROCM, TestCase, ) -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_cuda_and_triton _BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator" @@ -1375,7 +1375,7 @@ def test_foreach_copy_with_multi_dtypes_large_input(self): ref_out = torch.empty_like(self_tensor).copy_(src_tensor) self.assertEqual(self_tensor, ref_out) - @requires_cuda + @requires_cuda_and_triton @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db)) def test_foreach_copy_with_different_device_inputs(self, device, dtype, op): if dtype in (torch.complex128, torch.complex64): diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 922bde7cc4b58..40687995470b4 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -6,7 +6,9 @@ from torch.utils._triton import has_triton -requires_cuda = unittest.skipUnless(HAS_CUDA_AND_TRITON, "requires cuda") +requires_cuda_and_triton = unittest.skipUnless( + HAS_CUDA_AND_TRITON, "requires cuda and triton" +) requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu") if has_triton(): From c9671dc865aa0fc1cb86df754e355b44d8e02bb4 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Sun, 10 Aug 2025 00:17:46 -0400 Subject: [PATCH 0192/1424] Delete Python reference implementation from torchdim, as it is untested (#160115) Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/160115 Approved by: https://github.com/albanD --- functorch/dim/__init__.py | 43 +- functorch/dim/batch_tensor.py | 26 -- functorch/dim/delayed_mul_tensor.py | 76 ---- functorch/dim/dim.py | 120 ------ functorch/dim/reference.py | 645 ---------------------------- functorch/dim/wrap_type.py | 14 +- 6 files changed, 15 insertions(+), 909 deletions(-) delete mode 100644 functorch/dim/batch_tensor.py delete mode 100644 functorch/dim/delayed_mul_tensor.py delete mode 100644 functorch/dim/dim.py delete mode 100644 functorch/dim/reference.py diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index f52d417d2ba27..95747181e848e 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -24,10 +24,6 @@ class DimensionBindError(Exception): # use dict to avoid writing C++ bindings for set pointwise = dict.fromkeys(op_properties.pointwise, True) -use_c = True -if not use_c: - from . import reference - class _Tensor: # fast path around slow wrapping/unwrapping logic for simply queries used @@ -40,12 +36,8 @@ def dims(self): def dim(self): return self.ndim - if use_c: - __torch_function__ = classmethod(_C.__torch_function__) - expand = _C._instancemethod(_C.expand) - else: - __torch_function__ = reference.__torch_function__ - expand = reference.expand + __torch_function__ = classmethod(_C.__torch_function__) + expand = _C._instancemethod(_C.expand) index = _C._instancemethod(_C.index) @@ -64,8 +56,6 @@ class Dim(_C.Dim, _Tensor): class Tensor(_Tensor, _C.Tensor): - if not use_c: - from_batched = staticmethod(_C.Tensor_from_batched) from_positional = staticmethod(_C.Tensor_from_positional) sum = _C._instancemethod(_C.Tensor_sum) @@ -75,21 +65,17 @@ def cat(tensors, dim, new_dim): return stack(tensors, n, dim).index([n, dim], new_dim) -if use_c: - _wrap = _C._wrap +_wrap = _C._wrap + + +def _def(name, *args, **kwargs): + orig = getattr(torch.Tensor, name) + setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) - def _def(name, *args, **kwargs): - orig = getattr(torch.Tensor, name) - setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) - t__getitem__ = _C._instancemethod(_C.__getitem__) - stack = _C.stack - split = _C._instancemethod(_C.split) -else: - _wrap, _def = reference._wrap, reference._def - t__getitem__ = reference.t__getitem__ - stack = reference.stack - split = reference.split +t__getitem__ = _C._instancemethod(_C.__getitem__) +stack = _C.stack +split = _C._instancemethod(_C.split) # note: there is no python reference t__setitem__ = _C._instancemethod(_C.__setitem__) @@ -105,13 +91,10 @@ def _def(name, *args, **kwargs): _Tensor.split = split torch.Tensor.expand = _C._instancemethod(_C.expand) torch.Tensor.index = _C._instancemethod(_C.index) -wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__) +wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__) del _Tensor.ndim -if use_c: - _Tensor.order = _C._instancemethod(_C.order) -else: - _Tensor.order = reference.positional +_Tensor.order = _C._instancemethod(_C.order) _def("mean") _def("sum") diff --git a/functorch/dim/batch_tensor.py b/functorch/dim/batch_tensor.py deleted file mode 100644 index dae9b270896e9..0000000000000 --- a/functorch/dim/batch_tensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from contextlib import contextmanager - -from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers - - -_enabled = False - - -@contextmanager -def _enable_layers(dims): - global _enabled - assert not _enabled - input = sorted((d._level, d.size) for d in dims if not isinstance(d, int)) - n = len(input) - try: - _vmap_add_layers(input) - _enabled = True - yield - finally: - _enabled = False - _vmap_remove_layers(n) diff --git a/functorch/dim/delayed_mul_tensor.py b/functorch/dim/delayed_mul_tensor.py deleted file mode 100644 index 3c136cfe1247d..0000000000000 --- a/functorch/dim/delayed_mul_tensor.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from . import _Tensor, Tensor -from .reference import _dims, _enable_layers, llist, ltuple - - -class DelayedMulTensor(_Tensor): - def __init__(self, lhs, rhs): - self._lhs, self._rhs = lhs, rhs - self._data = None - self._levels_data = None - self._has_device = lhs._has_device or rhs._has_device - self._batchtensor_data = None - self._tensor_data = None - - @property - def _levels(self): - if self._levels_data is None: - levels = llist(self._lhs._levels) - for l in self._rhs._levels: - if l not in levels: - levels.append(l) - self._levels_data = ltuple(levels) - return self._levels_data - - @property - def _batchtensor(self): - if self._batchtensor_data is None: - with _enable_layers(self._levels): - print("bt multiply fallback") - self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor - return self._batchtensor_data - - @property - def _tensor(self): - if self._tensor_data is None: - self._tensor_data = Tensor.from_batched( - self._batchtensor, self._has_device - )._tensor - return self._tensor_data - - @property - def ndim(self): - return self._batchtensor.ndim - - @property - def dims(self): - return ltuple(super().dims) - - def sum(self, dim): - dims = _dims(dim, 0, False, False) - n = ord("a") - all_levels = self._levels - - def to_char(d): - return chr(n + all_levels.index(d)) - - plhs, levelslhs = self._lhs._tensor, self._lhs._levels - prhs, levelsrhs = self._rhs._tensor, self._rhs._levels - new_levels = [l for l in self._levels if l not in dims] - fmt = "".join( - [ - *(to_char(d) for d in levelslhs), - ",", - *(to_char(d) for d in levelsrhs), - "->", - *(to_char(d) for d in new_levels), - ] - ) - result_data = torch.einsum(fmt, (plhs, prhs)) - return Tensor.from_positional(result_data, new_levels, True) diff --git a/functorch/dim/dim.py b/functorch/dim/dim.py deleted file mode 100644 index 9a4b568664849..0000000000000 --- a/functorch/dim/dim.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import dis -import inspect -from dataclasses import dataclass -from typing import Union - -from . import DimList - - -_vmap_levels = [] - - -@dataclass -class LevelInfo: - level: int - alive: bool = True - - -class Dim: - def __init__(self, name: str, size: Union[None, int] = None): - self.name = name - self._size = None - self._vmap_level = None - if size is not None: - self.size = size - - def __del__(self): - if self._vmap_level is not None: - _vmap_active_levels[self._vmap_stack].alive = False # noqa: F821 - while ( - not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level # noqa: F821 - ): - _vmap_decrement_nesting() # noqa: F821 - _vmap_levels.pop() - - @property - def size(self): - assert self.is_bound - return self._size - - @size.setter - def size(self, size: int): - from . import DimensionBindError - - if self._size is None: - self._size = size - self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821 - self._vmap_stack = len(_vmap_levels) - _vmap_levels.append(LevelInfo(self._vmap_level)) - - elif self._size != size: - raise DimensionBindError( - f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}" - ) - - @property - def is_bound(self): - return self._size is not None - - def __repr__(self): - return self.name - - -def extract_name(inst): - assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME" - return inst.argval - - -_cache = {} - - -def dims(lists=0): - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - code, lasti = calling_frame.f_code, calling_frame.f_lasti - key = (code, lasti) - if key not in _cache: - first = lasti // 2 + 1 - instructions = list(dis.get_instructions(calling_frame.f_code)) - unpack = instructions[first] - - if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME": - # just a single dim, not a list - name = unpack.argval - ctor = Dim if lists == 0 else DimList - _cache[key] = lambda: ctor(name=name) - else: - assert unpack.opname == "UNPACK_SEQUENCE" - ndims = unpack.argval - names = tuple( - extract_name(instructions[first + 1 + i]) for i in range(ndims) - ) - first_list = len(names) - lists - _cache[key] = lambda: tuple( - Dim(n) if i < first_list else DimList(name=n) - for i, n in enumerate(names) - ) - return _cache[key]() - - -def _dim_set(positional, arg): - def convert(a): - if isinstance(a, Dim): - return a - else: - assert isinstance(a, int) - return positional[a] - - if arg is None: - return positional - elif not isinstance(arg, (Dim, int)): - return tuple(convert(a) for a in arg) - else: - return (convert(arg),) diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py deleted file mode 100644 index fd934011d8238..0000000000000 --- a/functorch/dim/reference.py +++ /dev/null @@ -1,645 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# reference python implementations for C ops -import torch -from functorch._C import dim as _C - -from . import op_properties -from .batch_tensor import _enable_layers -from .tree_map import tree_flatten, tree_map - - -DimList = _C.DimList -import operator -from functools import reduce - - -# use dict to avoid writing C++ bindings for set -pointwise = set(op_properties.pointwise) - - -def prod(x): - return reduce(operator.mul, x, 1) - - -def _wrap_dim(d, N, keepdim): - from . import Dim - - if isinstance(d, Dim): - assert not keepdim, "cannot preserve first-class dimensions with keepdim=True" - return d - elif d >= 0: - return d - N - else: - return d - - -def _dims(d, N, keepdim, single_dim): - from . import Dim - - if isinstance(d, (Dim, int)): - return ltuple((_wrap_dim(d, N, keepdim),)) - assert not single_dim, f"expected a single dimension or int but found: {d}" - return ltuple(_wrap_dim(x, N, keepdim) for x in d) - - -def _bind_dims_to_size(lhs_size, rhs, lhs_debug): - from . import DimensionMismatchError - - not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound) - if len(not_bound) == 1: - idx, d = not_bound[0] - rhs_so_far = prod(r.size for r in rhs if r.is_bound) - if lhs_size % rhs_so_far != 0: - rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError( - f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}" - ) - new_size = lhs_size // rhs_so_far - d.size = new_size - elif len(not_bound) > 1: - rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError( - f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}" - ) - else: - rhs_size = prod(r.size for r in rhs) - if lhs_size != rhs_size: - raise DimensionMismatchError( - f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}" - ) - - -def _tensor_levels(inp): - from . import _Tensor - - if isinstance(inp, _Tensor): - return inp._tensor, llist(inp._levels), inp._has_device - else: - return inp, llist(range(-inp.ndim, 0)), True - - -def _match_levels(v, from_levels, to_levels): - view = [] - permute = [] - requires_view = False - size = v.size() - for t in to_levels: - try: - idx = from_levels.index(t) - permute.append(idx) - view.append(size[idx]) - except ValueError: - view.append(1) - requires_view = True - if permute != list(range(len(permute))): - v = v.permute(*permute) - if requires_view: - v = v.view(*view) - return v - - -# make a single dimension positional but do not permute it, -# used to do multi-tensor operators where the dim being acted on -# should not physically move if possible -def _positional_no_permute(self, dim, expand_dim=False): - from . import Tensor - - ptensor, levels = self._tensor, llist(self._levels) - try: - idx = levels.index(dim) - except ValueError: - if not expand_dim: - raise - idx = 0 - ptensor = ptensor.expand(dim.size, *ptensor.size()) - levels.insert(0, 0) - idx_batched = 0 - for i in range(idx): - if isinstance(levels[i], int): - levels[i] -= 1 - idx_batched += 1 - levels[idx] = -idx_batched - 1 - return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched - - -def seq(a, b): - from . import Dim - - if isinstance(a, Dim) != isinstance(b, Dim): - return False - if isinstance(a, Dim): - return a is b - else: - return a == b - - -class isin: - __slots__ = () - - def __contains__(self, item): - for x in self: - if seq(item, x): - return True - return False - - def index(self, item): - for i, x in enumerate(self): - if seq(item, x): - return i - raise ValueError - - -class llist(isin, list): - __slots__ = () - - -class ltuple(isin, tuple): - __slots__ = () - - -empty_dict = {} - - -@classmethod -def __torch_function__(self, orig, cls, args, kwargs=empty_dict): - from . import _Tensor, Tensor, TensorLike - from .delayed_mul_tensor import DelayedMulTensor - - if orig is torch.Tensor.__mul__: - lhs, rhs = args - if ( - isinstance(lhs, _Tensor) - and isinstance(rhs, _Tensor) - and lhs.ndim == 0 - and rhs.ndim == 0 - ): - return DelayedMulTensor(lhs, rhs) - all_dims = llist() - flat_args, unflatten = tree_flatten((args, kwargs)) - device_holding_tensor = None - for f in flat_args: - if isinstance(f, _Tensor): - if f._has_device: - device_holding_tensor = f._batchtensor - for d in f.dims: - if d not in all_dims: - all_dims.append(d) - - def unwrap(t): - if isinstance(t, _Tensor): - r = t._batchtensor - if device_holding_tensor is not None and not t._has_device: - r = r.to(device=device_holding_tensor.device) - return r - return t - - if orig in pointwise: - result_levels = llist() - to_expand = [] - for i, f in enumerate(flat_args): - if isinstance(f, TensorLike): - ptensor, levels, _ = _tensor_levels(f) - if ( - isinstance(f, _Tensor) - and not f._has_device - and device_holding_tensor is not None - ): - ptensor = ptensor.to(device=device_holding_tensor.device) - flat_args[i] = ptensor - for l in levels: - if l not in result_levels: - result_levels.append(l) - to_expand.append((i, levels)) - - for i, levels in to_expand: - flat_args[i] = _match_levels(flat_args[i], levels, result_levels) - args, kwargs = unflatten(flat_args) - result = orig(*args, **kwargs) - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_positional( - t, result_levels, device_holding_tensor is not None - ) - return t - - return tree_map(wrap, result) - else: - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_batched(t, device_holding_tensor is not None) - return t - - with _enable_layers(all_dims): - print(f"batch_tensor for {orig}") - args, kwargs = unflatten(unwrap(f) for f in flat_args) - result = orig(*args, **kwargs) - # print("END", orig) - return tree_map(wrap, result) - - -def positional(self, *dims): - from . import Dim, DimensionBindError, Tensor - - ptensor, levels = self._tensor, llist(self._levels) - flat_dims = llist() - view = [] - needs_view = False - ndim = self.ndim - for d in dims: - if isinstance(d, DimList): - flat_dims.extend(d) - view.extend(e.size for e in d) - elif isinstance(d, Dim): - flat_dims.append(d) - view.append(d.size) - elif isinstance(d, int): - d = _wrap_dim(d, ndim, False) - flat_dims.append(d) - view.append(ptensor.size(d)) - else: - flat_dims.extend(d) - view.append(prod(e.size for e in d)) - needs_view = True - - permute = list(range(len(levels))) - for i, d in enumerate(flat_dims): - try: - idx = levels.index(d) - except ValueError as e: - raise DimensionBindError( - f"tensor of dimensions {self.dims} does not contain dim {d}" - ) from e - p = permute[idx] - del levels[idx] - del permute[idx] - levels.insert(i, 0) - permute.insert(i, p) - ptensor = ptensor.permute(*permute) - seen = 0 - for i in range(len(levels) - 1, -1, -1): - if isinstance(levels[i], int): - seen += 1 - levels[i] = -seen - result = Tensor.from_positional(ptensor, levels, self._has_device) - if needs_view: - result = result.reshape(*view, *result.size()[len(flat_dims) :]) - return result - - -def _contains_dim(input): - from . import Dim - - for i in input: - if isinstance(i, Dim): - return True - - -def expand(self, *sizes): - if not _contains_dim(sizes): - return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes)) - dims = sizes - sizes = [d.size for d in dims] + [-1] * self.ndim - self = self.expand(*sizes) - return self[dims] - - -_not_present = object() - - -def _getarg(name, offset, args, kwargs, default): - if len(args) > offset: - return args[offset] - return kwargs.get(name, default) - - -def _patcharg(name, offset, args, kwargs, value): - if len(args) > offset: - args[offset] = value - else: - kwargs[name] = value - - -def _wrap( - orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True -): - from . import Dim, Tensor, TensorLike - - def fn(self, *args, **kwargs): - dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present) - if dim is _not_present or (single_dim and not isinstance(dim, Dim)): - with _enable_layers(self.dims): - print(f"dim fallback batch_tensor for {orig}") - return Tensor.from_batched( - orig(self._batchtensor, *args, **kwargs), self._has_device - ) - keepdim = ( - _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False - ) - t, levels = self._tensor, llist(self._levels) - dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim) - dim_indices = tuple(levels.index(d) for d in dims) - if reduce and not keepdim: - new_levels = [l for i, l in enumerate(levels) if i not in dim_indices] - else: - new_levels = levels - - if len(dim_indices) == 1: - dim_indices = dim_indices[ - 0 - ] # so that dims that really only take a single argument work... - args = list(args) - _patcharg(dim_name, dim_offset, args, kwargs, dim_indices) - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_positional(t, new_levels, self._has_device) - return t - - with _enable_layers(new_levels): - print(f"dim used batch_tensor for {orig}") - r = orig(t, *args, **kwargs) - return tree_map(wrap, r) - - return fn - - -def _def(name, *args, **kwargs): - from . import _Tensor - - orig = getattr(torch.Tensor, name) - setattr(_Tensor, name, _wrap(orig, *args, **kwargs)) - - -no_slice = slice(None) - -_orig_getitem = torch.Tensor.__getitem__ - - -class dim_tracker: - def __init__(self) -> None: - self.dims = llist() - self.count = [] - - def record(self, d): - if d not in self.dims: - self.dims.append(d) - self.count.append(1) - - def __getitem__(self, d): - return self.count[self.dims.index(d)] - - -def t__getitem__(self, input): - from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike - - # * bail to original example if we have a single non-Dim tensor, or a non-tensor - # * locate ... or an unbound tensor list, and determine its size, bind dim list - # (remember that None does not count to the total dim count) - # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim, - # produce the re-view if needed - # * for each single-use dim index, replace with no_slice and mark that it will be added - # (keep track of whether we have to call super) - # * call super if needed - # * if we have dims to bind, bind them (it will help if we eliminated ... and None before) - # this handles bool indexing handling, as well as some other simple cases. - - is_simple = ( - not isinstance(input, Dim) - and not isinstance(input, (tuple, list)) - and - # WAR for functorch bug where zero time tensors in getitem are not handled correctly. - not (isinstance(input, TensorLike) and input.ndim == 0) - ) - - if is_simple: - if isinstance(self, _Tensor): - return _Tensor.__torch_function__(_orig_getitem, None, (self, input)) - else: - return _orig_getitem(self, input) - - # can further optimize this case - if not isinstance(input, tuple): - input = [input] - else: - input = list(input) - - dims_indexed = 0 - expanding_object = None - dimlists = [] - for i, s in enumerate(input): - if s is ... or isinstance(s, DimList) and not s.is_bound: - if expanding_object is not None: - msg = ( - "at most one ... or unbound dimension list can exist in indexing list but" - f" found 2 at offsets {i} and {expanding_object}" - ) - raise DimensionBindError(msg) - expanding_object = i - - if isinstance(s, DimList): - dims_indexed += len(s) if s.is_bound else 0 - dimlists.append(i) - elif s is not None and s is not ...: - dims_indexed += 1 - - ndim = self.ndim - if dims_indexed > ndim: - raise IndexError( - f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions." - ) - if expanding_object is not None: - expanding_ndims = ndim - dims_indexed - obj = input[expanding_object] - if obj is ...: - input[expanding_object : expanding_object + 1] = [ - no_slice - ] * expanding_ndims - else: - obj.bind_len(expanding_ndims) - # flatten the dimslists into the indexing - for i in reversed(dimlists): - input[i : i + 1] = input[i] - dims_indexed = 0 - requires_view = False - size = self.size() - view_sizes = [] - dims_seen = dim_tracker() - - def add_dims(t): - if not isinstance(t, _Tensor): - return - for d in t.dims: - dims_seen.record(d) - - add_dims(self) - dim_packs = [] - for i, idx in enumerate(input): - if idx is None: - input[i] = no_slice - view_sizes.append(1) - requires_view = True - else: - sz = size[dims_indexed] - if isinstance(idx, Dim): - idx.size = sz - dims_seen.record(idx) - view_sizes.append(sz) - elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim): - for d in idx: - dims_seen.record(idx) - _bind_dims_to_size(sz, idx, f"offset {i}") - view_sizes.extend(d.size for d in idx) - requires_view = True - dim_packs.append(i) - else: - add_dims(idx) - view_sizes.append(sz) - dims_indexed += 1 - if requires_view: - self = self.view(*view_sizes) - for i in reversed(dim_packs): - input[i : i + 1] = input[i] - - # currently: - # input is flat, containing either Dim, or Tensor, or something valid for standard indexing - # self may have first-class dims as well. - - # to index: - # drop the first class dims from self, they just become direct indices of their positions - - # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index. - # these dimensions will appear and need to be bound at the first place tensor occurs - - if isinstance(self, _Tensor): - ptensor_self, levels = self._tensor, list(self._levels) - # indices to ptensor rather than self which has first-class dimensions - input_it = iter(input) - flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels] - has_device = self._has_device - to_pad = 0 - else: - ptensor_self, flat_inputs = self, input - to_pad = ptensor_self.ndim - len(flat_inputs) - has_device = True - - result_levels = [] - index_levels = [] - tensor_insert_point = None - to_expand = {} - requires_getindex = False - for i, inp in enumerate(flat_inputs): - if isinstance(inp, Dim) and dims_seen[inp] == 1: - flat_inputs[i] = no_slice - result_levels.append(inp) - elif isinstance(inp, TensorLike): - requires_getindex = True - if tensor_insert_point is None: - tensor_insert_point = len(result_levels) - ptensor, levels, _ = _tensor_levels(inp) - to_expand[i] = levels - flat_inputs[i] = ptensor - for l in levels: - if l not in index_levels: - index_levels.append(l) - else: - requires_getindex = True - result_levels.append(0) - - if tensor_insert_point is not None: - result_levels[tensor_insert_point:tensor_insert_point] = index_levels - - for i, levels in to_expand.items(): - flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels) - - if requires_getindex: - result = _orig_getitem(ptensor_self, flat_inputs) - else: - result = ptensor_self - - next_positional = -1 - if to_pad > 0: - result_levels.extend([0] * to_pad) - for i, r in enumerate(reversed(result_levels)): - if isinstance(r, int): - result_levels[-1 - i] = next_positional - next_positional -= 1 - - return Tensor.from_positional(result, result_levels, has_device) - - -# XXX - dim is optional and can be the outer-most dimension... -def stack(tensors, new_dim, dim=0, out=None): - if isinstance(dim, int): - return torch.stack(tensors, dim, out).index(dim, new_dim) - index = None - if out is not None: - out, index = _positional_no_permute(out, dim, expand_dim=True) - ptensors = [] - for t in tensors: - pt, pi = _positional_no_permute(t, dim, expand_dim=True) - if index is not None and pi != index: - pt = pt.move_dim(pi, index) - else: - index = pi - ptensors.append(pt) - pr = torch.stack(ptensors, index, out=out) - return pr.index((index, index + 1), (new_dim, dim)) - - -_orig_split = torch.Tensor.split - - -def split(self, split_size_or_sections, dim=0): - from . import _Tensor, Dim - - if isinstance(split_size_or_sections, int) or any( - isinstance(t, int) for t in split_size_or_sections - ): - if isinstance(dim, Dim): - raise ValueError( - "when dim is specified as a Dim object, split sizes must also be dimensions." - ) - return _orig_split(self, split_size_or_sections, dim=dim) - - if isinstance(dim, Dim): - assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}" - self, dim = _positional_no_permute(self, dim) - - size = self.size(dim) - total_bound_size = 0 - unbound = [] - sizes = [] - for i, d in enumerate(split_size_or_sections): - if d.is_bound: - sizes.append(d.size) - total_bound_size += d.size - else: - sizes.append(0) - unbound.append(i) - - if unbound: - assert total_bound_size <= size, ( - f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" - ) - remaining_size = size - total_bound_size - chunk_size = -(-remaining_size // len(unbound)) - for u in unbound: - sz = min(chunk_size, remaining_size) - split_size_or_sections[u].size = sz - sizes[u] = sz - remaining_size -= sz - else: - assert total_bound_size == size, ( - f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" - ) - return tuple( - t.index(dim, d) - for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) - ) diff --git a/functorch/dim/wrap_type.py b/functorch/dim/wrap_type.py index aae543b91a896..b9ebda47c4cfe 100644 --- a/functorch/dim/wrap_type.py +++ b/functorch/dim/wrap_type.py @@ -26,18 +26,8 @@ PROPERTY_TYPES = (GetSetDescriptorType, property) -def _py_wrap_method(orig, __torch_function__): - def impl(*args, **kwargs): - return __torch_function__(orig, None, args, kwargs) - - return impl - - -def wrap_type(use_c, to_patch, pattern, __torch_function__): - if use_c: - wrap_method = _wrap_method - else: - wrap_method = _py_wrap_method +def wrap_type(to_patch, pattern, __torch_function__): + wrap_method = _wrap_method all = {} for t in reversed(pattern.mro()[:-1]): # skip object From 3ac86e728dfaa7383ff7f865e9e7d33486188dae Mon Sep 17 00:00:00 2001 From: atalman Date: Sun, 10 Aug 2025 12:00:16 +0000 Subject: [PATCH 0193/1424] Add Alban and Piotr to list of maintainers (#160187) Add Alban and Piotr to list of maintainers Pull Request resolved: https://github.com/pytorch/pytorch/pull/160187 Approved by: https://github.com/albanD --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 65c0bb982bd96..3c67d36e74950 100644 --- a/README.md +++ b/README.md @@ -560,7 +560,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. -PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. +PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), [Alban Desmaison](https://github.com/albanD), [Piotr Bialecki](https://github.com/ptrblck) and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. From a84b60c0c4016785fd93b7b8a0c04f2d0770d332 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Sun, 10 Aug 2025 12:25:18 +0000 Subject: [PATCH 0194/1424] [MPS] Sparse coalesce more dtypes to match cpu (#160254) More dtypes to match the cpu Pull Request resolved: https://github.com/pytorch/pytorch/pull/160254 Approved by: https://github.com/malfet --- aten/src/ATen/native/sparse/mps/kernels/Sparse.metal | 7 ++++++- test/test_mps.py | 7 +++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal b/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal index ff76b9b6b5209..8b85950e393a1 100644 --- a/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal +++ b/aten/src/ATen/native/sparse/mps/kernels/Sparse.metal @@ -120,4 +120,9 @@ kernel void coalesce_with_positions_kernel( INSTANTIATE_COALESCE_WITH_POSITIONS(float); INSTANTIATE_COALESCE_WITH_POSITIONS(half); INSTANTIATE_COALESCE_WITH_POSITIONS(bfloat); -INSTANTIATE_COALESCE_WITH_POSITIONS(bool); \ No newline at end of file +INSTANTIATE_COALESCE_WITH_POSITIONS(bool); +INSTANTIATE_COALESCE_WITH_POSITIONS(long); +INSTANTIATE_COALESCE_WITH_POSITIONS(char); +INSTANTIATE_COALESCE_WITH_POSITIONS(uchar); +INSTANTIATE_COALESCE_WITH_POSITIONS(short); +INSTANTIATE_COALESCE_WITH_POSITIONS(int); \ No newline at end of file diff --git a/test/test_mps.py b/test/test_mps.py index 1deee80344404..6c55cb775f063 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12696,9 +12696,11 @@ def test_resize(self): sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0) self.assertEqual(sparse, sparse_cpu) - def test_coalesce(self): + @parametrize("dtype", [torch.int8, torch.int16, torch.uint8, torch.int32, torch.int64, + torch.float32, torch.float16, torch.bfloat16, torch.bool]) + def test_coalesce(self, dtype): indices = torch.tensor([[0, 0, 1, 1], [0, 0, 2, 2]], dtype=torch.int64, device="mps") - values = torch.tensor([1., 2., 3., 4.], dtype=torch.float32, device="mps") + values = torch.tensor([1., 2., 3., 4.], dtype=dtype, device="mps") size = (2, 3) indices_cpu = indices.cpu() values_cpu = values.cpu() @@ -12770,6 +12772,7 @@ def test_coalesce_large_tensor(self): instantiate_parametrized_tests(TestSDPA) instantiate_parametrized_tests(TestSmoothL1Loss) instantiate_parametrized_tests(TestMetalLibrary) +instantiate_parametrized_tests(TestSparseMPS) if __name__ == "__main__": run_tests() From 0e3e377bd5126cfcc69d70c4d77b352d3404cc11 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sun, 10 Aug 2025 14:22:49 +0000 Subject: [PATCH 0195/1424] [inductor] fix CompiledArtifact.load path on Windows. (#160268) fix CompiledArtifact.load path on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160268 Approved by: https://github.com/ezyang --- test/inductor/test_codecache.py | 5 ++++- torch/_inductor/standalone_compile.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index f75a867974671..757ea061c26f8 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -29,6 +29,7 @@ TensorMetadata, TensorMetadataAndValues, ) +from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.custom_graph_pass import ( CustomGraphModulePass, CustomGraphPass, @@ -1806,7 +1807,9 @@ def f(x): assert not kwargs with tempfile.TemporaryDirectory() as temp_dir: - path = os.path.join(temp_dir, "compiled_artifact.bin") + path = normalize_path_separator( + os.path.join(temp_dir, "compiled_artifact.bin") + ) with fresh_cache(): compiled_artifact = torch._inductor.standalone_compile(gm, args) diff --git a/torch/_inductor/standalone_compile.py b/torch/_inductor/standalone_compile.py index a26a578755f63..88f635426bfd9 100644 --- a/torch/_inductor/standalone_compile.py +++ b/torch/_inductor/standalone_compile.py @@ -10,6 +10,7 @@ import torch.fx from torch._dynamo.utils import dynamo_timed +from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.cudagraph_utils import BoxedDeviceIndex from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir from torch._inductor.utils import BoxedBool, InputType @@ -116,6 +117,7 @@ def save( def load( *, path: str, format: Literal["binary", "unpacked"] = "binary" ) -> CompiledArtifact: + path = normalize_path_separator(path) with dynamo_timed("CompiledArtifact.load"): if format == "binary": # can't assert that it is a file since it might not exist yet From 7ae0629d64b404e0ef5d9c931433ad25e65d6114 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 10 Aug 2025 17:33:19 +0000 Subject: [PATCH 0196/1424] Revert "[inductor] turn on windows inductor UTs (#160161)" This reverts commit f0980fc0bbd656d6c02d23ad97e945353b314f35. Reverted https://github.com/pytorch/pytorch/pull/160161 on behalf of https://github.com/clee2000 due to broke some inductor tests on windows inductor\test_codecache.py::TestStandaloneCompile::test_different_process [GH job link](https://github.com/pytorch/pytorch/actions/runs/16853706010/job/47748778757) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/f0980fc0bbd656d6c02d23ad97e945353b314f35). note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/160161#issuecomment-3172784292)) --- .github/workflows/trunk.yml | 7 +++---- test/dynamo/test_decorators.py | 4 ---- test/dynamo/test_logging.py | 5 +---- test/inductor/test_cpu_select_algorithm.py | 3 +-- torch/_dynamo/test_case.py | 8 +++++--- 5 files changed, 10 insertions(+), 17 deletions(-) diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index c428127dc6dd2..c7cf4c84e1888 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -123,10 +123,9 @@ jobs: runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, ]} secrets: inherit diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 9bf982c5b90ec..3b29e5e961192 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -10,7 +10,6 @@ import torch._dynamo.testing from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters -from torch.testing._internal.common_utils import skipIfWindows def my_custom_function(x): @@ -893,9 +892,6 @@ def gn(x): self.assertEqual(gn(inp), inp + 3) self.assertEqual(cnts.frame_count, 1) - @skipIfWindows( - msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows." - ) def test_disable_recursive_false(self): def fn2(x): return x + 1 diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index a5a6ee54aa74a..439b0361690b2 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -21,10 +21,8 @@ from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.common_utils import ( find_free_port, - IS_WINDOWS, munge_exc, skipIfTorchDynamo, - skipIfWindows, TEST_XPU, xfailIf, ) @@ -530,7 +528,7 @@ def test_invalid_artifact_flag_error_msg(self): "import torch", env=env, ) - lines = stderr.decode().split("\r\n" if IS_WINDOWS else "\n") + lines = stderr.decode().split("\n") # This is a sanity assert that our error is not spammy. # As of this test creation this was 18. # See this issue for the purpose o this test: @@ -546,7 +544,6 @@ def test_invalid_artifact_flag_error_msg(self): self.assertEqual(lines[-4], "Valid settings:") @requires_distributed() - @skipIfWindows(msg="TODO: (xuhancn), Can't reproduce locally") def test_distributed_rank_logging(self): env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 75d091595cd8a..7e35c93ee0b79 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,7 +26,6 @@ ) from torch.testing._internal.common_utils import ( IS_MACOS, - IS_WINDOWS, parametrize, skipIfWindows, TEST_MKL, @@ -3095,5 +3094,5 @@ def forward(self, x, weight): if __name__ == "__main__": from torch.testing._internal.inductor_utils import HAS_CPU - if HAS_CPU and not (IS_MACOS or IS_WINDOWS): + if HAS_CPU and not IS_MACOS: run_tests() diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index f8bde6222dbea..230aac4794f25 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -41,9 +41,11 @@ def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF: return # skip testing - # Enable Inductor UTs on Windows for CPU. - # CUDA on Windows is not verified, NVDA developer can continue to enable CUDA based on CPU path. - if torch.cuda.is_available() and IS_WINDOWS: + if ( + not torch.xpu.is_available() + and IS_WINDOWS + and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0" + ): return if isinstance(needs, str): From d6786741a77aba200c78002646cc069b7a1799b0 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sun, 10 Aug 2025 18:35:42 +0000 Subject: [PATCH 0197/1424] [inductor] slow test some Windows UTs. (#160267) When we enabled Windows inductor UTs since the PR: https://github.com/pytorch/pytorch/pull/160161/ The main branch CI occurred timeout issue, Let's move some UT to slow test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160267 Approved by: https://github.com/ezyang --- test/test_schema_check.py | 5 ++++- test/test_torch.py | 16 ++++++++++++++++ test/test_unary_ufuncs.py | 14 ++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/test/test_schema_check.py b/test/test_schema_check.py index 29ea36fd8a5f5..91d9a484d3c89 100644 --- a/test/test_schema_check.py +++ b/test/test_schema_check.py @@ -14,9 +14,12 @@ from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests +from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) + + def secretly_aliasing(x): return x.view(-1) @@ -493,9 +496,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with SchemaInfoBindTestMode(self) as schemaInfoCheck: x.add(x) - class TestSchemaCheckModeOpInfo(JitTestCase): @ops(op_db, dtypes=OpDTypes.supported) + @slowTestIf(IS_WINDOWS) def test_schema_correctness(self, device, dtype, op): # Currently torch.equal isn't supported with torch.complex32 # There's also errors with complex64 and complex128 diff --git a/test/test_torch.py b/test/test_torch.py index ef23f13e4376b..d55fd1aeb6e83 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -66,6 +66,7 @@ get_all_qint_dtypes, all_types_complex_float8_and, ) from torch.testing._internal.two_tensor import TwoTensor +from torch.testing._internal.common_utils import IS_WINDOWS if TEST_WITH_TORCHINDUCTOR: from torch._inductor.test_case import TestCase @@ -158,6 +159,7 @@ def test_constants(self, device): self.assertEqual(torch.inf, math.inf) @onlyNativeDeviceTypes + @slowTestIf(IS_WINDOWS) @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.uint16, torch.uint32, torch.uint64) @@ -190,6 +192,7 @@ def test_int64_upsample3d(self, device, dtype): @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.uint16, torch.uint32, torch.uint64) + @slowTestIf(IS_WINDOWS) def test_storage(self, device, dtype): v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9) self.assertEqual(v.storage()[0], v[0][0]) @@ -220,6 +223,7 @@ def test_storage(self, device, dtype): torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.quint8, torch.qint8, torch.qint32, torch.quint4x2) + @slowTestIf(IS_WINDOWS) def test_storage_setitem(self, device, dtype): # Skip quantized dtypes for CUDA, since they're not supported if torch.device(device).type == 'cuda': @@ -251,6 +255,7 @@ def test_storage_setitem(self, device, dtype): @skipIfTorchDynamo("Not a suitable test for TorchDynamo") @onlyNativeDeviceTypes + @slowTestIf(IS_WINDOWS) def test_storage_use_count(self, device): a = torch.randn(10, device=device) prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata) @@ -261,6 +266,7 @@ def test_storage_use_count(self, device): @xfailIfTorchDynamo @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_tensor_storage_type(self, device, dtype): a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) @@ -271,6 +277,7 @@ def test_tensor_storage_type(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64)) + @slowTestIf(IS_WINDOWS) def test_tensor_from_storage(self, device, dtype): a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) a_s = a.storage() @@ -288,6 +295,7 @@ def test_tensor_from_storage(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_set_storage(self, device, dtype): a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) a_s = a.storage() @@ -326,6 +334,7 @@ def _check_storage_meta(self, s, s_check): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_typed_storage_meta(self, device, dtype): args_list = [ [], @@ -339,6 +348,7 @@ def test_typed_storage_meta(self, device, dtype): self._check_storage_meta(s, s_check) @onlyNativeDeviceTypes + @slowTestIf(IS_WINDOWS) def test_untyped_storage_meta(self, device): args_list = [ [], @@ -353,6 +363,7 @@ def test_untyped_storage_meta(self, device): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_storage_meta_from_tensor(self, device, dtype): t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) t = t_check.to('meta') @@ -362,6 +373,7 @@ def test_storage_meta_from_tensor(self, device, dtype): self._check_storage_meta(s, s_check) @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_storage_meta_errors(self, device, dtype): s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) @@ -402,6 +414,7 @@ def test_storage_meta_errors(self, device, dtype): @onlyCPU @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_storage_meta_ok(self, device, dtype): s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) @@ -417,6 +430,7 @@ def test_module_share_memory(self): model.share_memory() @dtypes(torch.float32, torch.complex64) + @slowTestIf(IS_WINDOWS) def test_deepcopy(self, device, dtype): from copy import deepcopy a = torch.randn(5, 5, dtype=dtype, device=device) @@ -444,6 +458,7 @@ def test_deepcopy(self, device, dtype): self.assertEqual(deepcopy(a).foo, 3) @dtypes(torch.float32, torch.complex64) + @slowTestIf(IS_WINDOWS) def test_deepcopy_scalar(self, device, dtype): from copy import deepcopy a = torch.tensor(5, dtype=dtype, device=device) @@ -3696,6 +3711,7 @@ def ref_index_select(src, dim, idx): # FIXME: find a test suite for the take operator @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @slowTestIf(IS_WINDOWS) def test_take(self, device, dtype): idx_size = (4,) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index d7d9a2b1aab6d..9939e8e76ce94 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -54,6 +54,8 @@ ) from torch.utils import _pytree as pytree +from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf + if TEST_SCIPY: import scipy @@ -271,6 +273,7 @@ def _helper_reference_numerics( # and noncontiguities. @suppress_warnings @ops(reference_filtered_ops) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_normal(self, device, dtype, op): tensors = generate_elementwise_unary_tensors( op, device=device, dtype=dtype, requires_grad=False @@ -279,6 +282,7 @@ def test_reference_numerics_normal(self, device, dtype, op): @suppress_warnings @ops(reference_filtered_ops) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_small(self, device, dtype, op): if dtype in (torch.bool,): raise self.skipTest("bool has no small values") @@ -290,6 +294,7 @@ def test_reference_numerics_small(self, device, dtype, op): @suppress_warnings @ops(reference_filtered_ops) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_large(self, device, dtype, op): if dtype in (torch.bool, torch.uint8, torch.int8): raise self.skipTest("bool, uint8, and int8 dtypes have no large values") @@ -304,6 +309,7 @@ def test_reference_numerics_large(self, device, dtype, op): reference_filtered_ops, allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), ) + @slowTestIf(IS_WINDOWS) def test_reference_numerics_extremal(self, device, dtype, op): tensors = generate_elementwise_unary_extremal_value_tensors( op, device=device, dtype=dtype, requires_grad=False @@ -312,6 +318,7 @@ def test_reference_numerics_extremal(self, device, dtype, op): # Tests for testing (non)contiguity consistency @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_vs_every_other(self, device, dtype, op): contig = make_tensor( (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] @@ -328,6 +335,7 @@ def test_contig_vs_every_other(self, device, dtype, op): self.assertEqual(result, expected) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_vs_transposed(self, device, dtype, op): contig = make_tensor( (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] @@ -344,6 +352,7 @@ def test_contig_vs_transposed(self, device, dtype, op): self.assertEqual(result, expected) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_non_contig(self, device, dtype, op): shapes = [(5, 7), (1024,)] for shape in shapes: @@ -360,6 +369,7 @@ def test_non_contig(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs)) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_non_contig_index(self, device, dtype, op): contig = make_tensor( (2, 2, 1, 2), @@ -378,6 +388,7 @@ def test_non_contig_index(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs)) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_non_contig_expand(self, device, dtype, op): shapes = [(1, 3), (1, 7), (5, 7)] for shape in shapes: @@ -399,6 +410,7 @@ def test_non_contig_expand(self, device, dtype, op): ) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_size1(self, device, dtype, op): contig = make_tensor( (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] @@ -414,6 +426,7 @@ def test_contig_size1(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs)) @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_contig_size1_large_dim(self, device, dtype, op): contig = make_tensor( (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), @@ -435,6 +448,7 @@ def test_contig_size1_large_dim(self, device, dtype, op): # Tests that computation on a multiple batches is the same as # per-batch computation. @ops(unary_ufuncs) + @slowTestIf(IS_WINDOWS) def test_batch_vs_slicing(self, device, dtype, op): input = make_tensor( (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] From 05c19d1acecc01b0d2512364183058a6885b9869 Mon Sep 17 00:00:00 2001 From: "Andy (An) Wang" Date: Sun, 10 Aug 2025 19:20:27 +0000 Subject: [PATCH 0198/1424] [Inductor] Add back the revert part (#160054) Add back the reverted code(https://github.com/pytorch/pytorch/pull/159809) as we've figured out the actual root cause of the internal test failures. Mote details in the internal diff. Rollback Plan: Differential Revision: D79776691 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160054 Approved by: https://github.com/blaine-rister --- torch/_dynamo/device_interface.py | 4 ++++ torch/utils/_triton.py | 1 + 2 files changed, 5 insertions(+) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index ada43dd08393b..9ea53c900b054 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -590,6 +590,10 @@ def init_device_reg() -> None: for i in range(torch.xpu.device_count()): register_interface_for_device(f"xpu:{i}", XpuInterface) + register_interface_for_device("mtia", MtiaInterface) + for i in range(torch.mtia.device_count()): + register_interface_for_device(f"mtia:{i}", MtiaInterface) + register_interface_for_device("cpu", CpuInterface) register_interface_for_device("mps", MpsInterface) diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 55beae4baf18a..af1e5e0e6f42a 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -135,6 +135,7 @@ def _return_true(device_interface: Any) -> bool: "cuda": cuda_extra_check, "xpu": _return_true, "cpu": cpu_extra_check, + "mtia": _return_true, } def is_device_compatible_with_triton() -> bool: From 4416433c7c625127b7f975c92f8ec98ea4c67fd3 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sun, 10 Aug 2025 23:18:35 +0000 Subject: [PATCH 0199/1424] [inductor] turn on windows inductor UTs (#160161) With this PR, we can turn on the inductor UTs on Windows CPU. changes: 1. Turn on inductor UTs on Windows CPU. 2. Add a shard to balance added UTs, otherwise it should run timeout. 3. Fixed `test_invalid_artifact_flag_error_msg`. 4. Skiped `test_distributed_rank_logging` and `test_disable_recursive_false`. 5. Skiped whole UT `test_cpu_select_algorithm.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160161 Approved by: https://github.com/jansel --- .github/workflows/trunk.yml | 8 +++++--- test/dynamo/test_decorators.py | 4 ++++ test/dynamo/test_logging.py | 5 ++++- test/inductor/test_cpu_select_algorithm.py | 3 ++- torch/_dynamo/test_case.py | 8 +++----- 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index c7cf4c84e1888..a4d665c202d34 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -123,9 +123,11 @@ jobs: runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, ]} secrets: inherit diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 3b29e5e961192..9bf982c5b90ec 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -10,6 +10,7 @@ import torch._dynamo.testing from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import skipIfWindows def my_custom_function(x): @@ -892,6 +893,9 @@ def gn(x): self.assertEqual(gn(inp), inp + 3) self.assertEqual(cnts.frame_count, 1) + @skipIfWindows( + msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows." + ) def test_disable_recursive_false(self): def fn2(x): return x + 1 diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 439b0361690b2..a5a6ee54aa74a 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -21,8 +21,10 @@ from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.common_utils import ( find_free_port, + IS_WINDOWS, munge_exc, skipIfTorchDynamo, + skipIfWindows, TEST_XPU, xfailIf, ) @@ -528,7 +530,7 @@ def test_invalid_artifact_flag_error_msg(self): "import torch", env=env, ) - lines = stderr.decode().split("\n") + lines = stderr.decode().split("\r\n" if IS_WINDOWS else "\n") # This is a sanity assert that our error is not spammy. # As of this test creation this was 18. # See this issue for the purpose o this test: @@ -544,6 +546,7 @@ def test_invalid_artifact_flag_error_msg(self): self.assertEqual(lines[-4], "Valid settings:") @requires_distributed() + @skipIfWindows(msg="TODO: (xuhancn), Can't reproduce locally") def test_distributed_rank_logging(self): env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 7e35c93ee0b79..75d091595cd8a 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,6 +26,7 @@ ) from torch.testing._internal.common_utils import ( IS_MACOS, + IS_WINDOWS, parametrize, skipIfWindows, TEST_MKL, @@ -3094,5 +3095,5 @@ def forward(self, x, weight): if __name__ == "__main__": from torch.testing._internal.inductor_utils import HAS_CPU - if HAS_CPU and not IS_MACOS: + if HAS_CPU and not (IS_MACOS or IS_WINDOWS): run_tests() diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 230aac4794f25..f8bde6222dbea 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -41,11 +41,9 @@ def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF: return # skip testing - if ( - not torch.xpu.is_available() - and IS_WINDOWS - and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0" - ): + # Enable Inductor UTs on Windows for CPU. + # CUDA on Windows is not verified, NVDA developer can continue to enable CUDA based on CPU path. + if torch.cuda.is_available() and IS_WINDOWS: return if isinstance(needs, str): From b602ea9cab7d43a7ee7b4051227090f23fbd3dbf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 11 Aug 2025 00:04:25 +0000 Subject: [PATCH 0200/1424] Revert "[inductor] turn on windows inductor UTs (#160161)" This reverts commit 4416433c7c625127b7f975c92f8ec98ea4c67fd3. Reverted https://github.com/pytorch/pytorch/pull/160161 on behalf of https://github.com/xuhancn due to auto merged with two related issue ([comment](https://github.com/pytorch/pytorch/pull/160161#issuecomment-3172982125)) --- .github/workflows/trunk.yml | 8 +++----- test/dynamo/test_decorators.py | 4 ---- test/dynamo/test_logging.py | 5 +---- test/inductor/test_cpu_select_algorithm.py | 3 +-- torch/_dynamo/test_case.py | 8 +++++--- 5 files changed, 10 insertions(+), 18 deletions(-) diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index a4d665c202d34..c7cf4c84e1888 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -123,11 +123,9 @@ jobs: runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, ]} secrets: inherit diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 9bf982c5b90ec..3b29e5e961192 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -10,7 +10,6 @@ import torch._dynamo.testing from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters -from torch.testing._internal.common_utils import skipIfWindows def my_custom_function(x): @@ -893,9 +892,6 @@ def gn(x): self.assertEqual(gn(inp), inp + 3) self.assertEqual(cnts.frame_count, 1) - @skipIfWindows( - msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows." - ) def test_disable_recursive_false(self): def fn2(x): return x + 1 diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index a5a6ee54aa74a..439b0361690b2 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -21,10 +21,8 @@ from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.common_utils import ( find_free_port, - IS_WINDOWS, munge_exc, skipIfTorchDynamo, - skipIfWindows, TEST_XPU, xfailIf, ) @@ -530,7 +528,7 @@ def test_invalid_artifact_flag_error_msg(self): "import torch", env=env, ) - lines = stderr.decode().split("\r\n" if IS_WINDOWS else "\n") + lines = stderr.decode().split("\n") # This is a sanity assert that our error is not spammy. # As of this test creation this was 18. # See this issue for the purpose o this test: @@ -546,7 +544,6 @@ def test_invalid_artifact_flag_error_msg(self): self.assertEqual(lines[-4], "Valid settings:") @requires_distributed() - @skipIfWindows(msg="TODO: (xuhancn), Can't reproduce locally") def test_distributed_rank_logging(self): env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 75d091595cd8a..7e35c93ee0b79 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,7 +26,6 @@ ) from torch.testing._internal.common_utils import ( IS_MACOS, - IS_WINDOWS, parametrize, skipIfWindows, TEST_MKL, @@ -3095,5 +3094,5 @@ def forward(self, x, weight): if __name__ == "__main__": from torch.testing._internal.inductor_utils import HAS_CPU - if HAS_CPU and not (IS_MACOS or IS_WINDOWS): + if HAS_CPU and not IS_MACOS: run_tests() diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index f8bde6222dbea..230aac4794f25 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -41,9 +41,11 @@ def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF: return # skip testing - # Enable Inductor UTs on Windows for CPU. - # CUDA on Windows is not verified, NVDA developer can continue to enable CUDA based on CPU path. - if torch.cuda.is_available() and IS_WINDOWS: + if ( + not torch.xpu.is_available() + and IS_WINDOWS + and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0" + ): return if isinstance(needs, str): From 842cc77ab9aafd518593c2fce077d6abb42a5b7f Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sun, 10 Aug 2025 19:48:04 -0400 Subject: [PATCH 0201/1424] [MPS] Extend addmm to integral types (#160270) By adding `addmm` kernel, which is a logical continuation of `mm` one. The only tricking part are how alpha and beta constants are handled, which are passed as `optmath_t`, i.e. that it could be, int64, int32 or float Unified all MM flavors instantiations thru `INSTANTIATE_MM_OPS` and tested that `addmm` metal kernel works as expected for floating types as well by testing it via ``` PYTORCH_MPS_PREFER_METAL=1 python test/test_mps.py -v -k test_output_match_addmm_mps_ ``` Fixes https://github.com/pytorch/pytorch/issues/154901 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160270 Approved by: https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #160228, #160234 --- .../native/mps/kernels/LinearAlgebra.metal | 85 +++++++++++++------ .../native/mps/operations/LinearAlgebra.mm | 60 ++++++++++++- torch/testing/_internal/common_mps.py | 8 -- 3 files changed, 119 insertions(+), 34 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal index 92774f3ff2668..4ba2bca720db7 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal @@ -68,6 +68,37 @@ kernel void matmul( } } +template +kernel void addmm( + constant T* mat1Data [[buffer(0)]], + constant T* mat2Data [[buffer(1)]], + device T* outputData [[buffer(2)]], + constant T* biasData [[buffer(3)]], + constant array, 2>& alpha_beta [[buffer(4)]], + constant array& strides [[buffer(5)]], + constant uint3& sizes [[buffer(6)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 thread_id [[thread_position_in_grid]]) { + threadgroup T A_tile[TILE_DIM][TILE_DIM]; + threadgroup T B_tile[TILE_DIM][TILE_DIM]; + + auto sum = matmul_inner( + mat1Data, + mat2Data, + reinterpret_cast&>(strides), + sizes, + A_tile, + B_tile, + tid, + thread_id); + if (thread_id.y < sizes.x && thread_id.x < sizes.z) { + auto bias = + biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y]; + outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] = + static_cast(alpha_beta[0] * sum + alpha_beta[1] * bias); + } +} + template kernel void naive_bmm( constant T* mat1Data [[buffer(0)]], @@ -613,17 +644,15 @@ kernel void applyPivots( } } -#define INSTANTIATE_NAIVE_MM(DTYPE) \ - template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ - constant DTYPE * mat1Data [[buffer(0)]], \ - constant DTYPE * mat2Data [[buffer(1)]], \ - device DTYPE * outputData [[buffer(2)]], \ - constant array & strides [[buffer(3)]], \ - constant uint3 & sizes [[buffer(4)]], \ - uint2 tid [[thread_position_in_threadgroup]], \ - uint2 group_id [[threadgroup_position_in_grid]]) - -#define INSTANTIATE_NAIVE_BMM(DTYPE) \ +#define INSTANTIATE_MM_OPS(DTYPE) \ + template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ + constant DTYPE * mat1Data [[buffer(0)]], \ + constant DTYPE * mat2Data [[buffer(1)]], \ + device DTYPE * outputData [[buffer(2)]], \ + constant array & strides [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 tid [[thread_position_in_threadgroup]], \ + uint2 group_id [[threadgroup_position_in_grid]]); \ template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm( \ constant DTYPE * mat1Data [[buffer(0)]], \ constant DTYPE * mat2Data [[buffer(1)]], \ @@ -631,20 +660,26 @@ kernel void applyPivots( constant array & strides [[buffer(3)]], \ constant uint4 & sizes [[buffer(4)]], \ uint3 tid [[thread_position_in_threadgroup]], \ - uint3 group_id [[threadgroup_position_in_grid]]) + uint3 group_id [[threadgroup_position_in_grid]]); \ + template [[host_name("addmm_" #DTYPE)]] kernel void addmm( \ + constant DTYPE * mat1Data [[buffer(0)]], \ + constant DTYPE * mat2Data [[buffer(1)]], \ + device DTYPE * outputData [[buffer(2)]], \ + constant DTYPE * biasData [[buffer(3)]], \ + constant array, 2> & \ + alpha_beta [[buffer(4)]], \ + constant array & strides [[buffer(5)]], \ + constant uint3 & sizes [[buffer(6)]], \ + uint2 tid [[thread_position_in_threadgroup]], \ + uint2 group_id [[threadgroup_position_in_grid]]) -INSTANTIATE_NAIVE_MM(float); -INSTANTIATE_NAIVE_MM(half); -INSTANTIATE_NAIVE_MM(bfloat); +INSTANTIATE_MM_OPS(float); +INSTANTIATE_MM_OPS(half); +INSTANTIATE_MM_OPS(bfloat); // Integral MM -INSTANTIATE_NAIVE_MM(short); -INSTANTIATE_NAIVE_MM(int); -INSTANTIATE_NAIVE_MM(long); -INSTANTIATE_NAIVE_MM(char); -INSTANTIATE_NAIVE_MM(uchar); -INSTANTIATE_NAIVE_BMM(short); -INSTANTIATE_NAIVE_BMM(int); -INSTANTIATE_NAIVE_BMM(long); -INSTANTIATE_NAIVE_BMM(char); -INSTANTIATE_NAIVE_BMM(uchar); +INSTANTIATE_MM_OPS(long); +INSTANTIATE_MM_OPS(int); +INSTANTIATE_MM_OPS(short); +INSTANTIATE_MM_OPS(char); +INSTANTIATE_MM_OPS(uchar); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 3cdf0021e987f..7a3dde679c05f 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -112,6 +112,61 @@ return output; } +Tensor& do_metal_addmm(const Tensor& self, + const Tensor& other, + Tensor& output, + const Scalar& alpha, + const Scalar& beta, + const Tensor& bias) { + if (beta.toDouble() == 0 && alpha.toDouble() == 1) { + return do_metal_mm(self, other, output); + } + auto stream = getCurrentMPSStream(); + auto device = MPSDevice::getInstance()->device(); + auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output)); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other}); + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:matmulPSO]; + std::array sizes = {static_cast(self.size(0)), + static_cast(self.size(1)), + static_cast(output.size(1))}; + std::array strides = {self.stride(0), + self.stride(1), + other.stride(0), + other.stride(1), + output.stride(0), + output.stride(1), + bias.stride(0), + bias.stride(1)}; + union { + std::array i64; + std::array i32; + std::array f32; + } alpha_beta; + if (output.scalar_type() == kLong) { + alpha_beta.i64 = {alpha.toLong(), beta.toLong()}; + } else if (c10::isIntegralType(output.scalar_type(), true)) { + alpha_beta.i32 = {alpha.toInt(), beta.toInt()}; + } else { + TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type())); + alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()}; + } + constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs + uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM; + uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM; + + MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1); + mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + getMPSProfiler().endProfileKernel(matmulPSO); + } + }); + return output; +} + std::tuple do_mm(MPSGraph* graph, const Tensor& self, const Tensor& other) { @@ -644,7 +699,6 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const TORCH_CHECK(output.is_mps()); TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input"); TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}}; checkAllSameGPU(__func__, args); @@ -671,6 +725,10 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const return output; } + if (use_metal_mm(self, other, output)) { + return do_metal_addmm(self, other, output, alpha, beta, *bias_); + } + bool is_beta_non_zero = beta.toDouble() != 0.0; struct CachedGraph : public mps::MPSCachedGraph { diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index 2aefcce61b73c..0391a314568a3 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -428,15 +428,7 @@ def mps_ops_modifier( torch.uint8, torch.int8, ], - "addmmdecomposed": [ - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # returned output on CPU is float64 From e7152ff8a6a929a0db7f3f4a72a5b6d471769cd3 Mon Sep 17 00:00:00 2001 From: "Han, Xu" Date: Mon, 11 Aug 2025 02:55:37 +0000 Subject: [PATCH 0202/1424] [inductor] fix some windows inductor UTs (#160292) This PR is the UT part of https://github.com/pytorch/pytorch/pull/160161. As @malfet 's comments: https://github.com/pytorch/pytorch/pull/160161#pullrequestreview-3103812178 This PR will not land turn on change, and only land UT part. changes: 1. Fixed `test_invalid_artifact_flag_error_msg`. 2. Skiped `test_distributed_rank_logging` and `test_disable_recursive_false`. 3. Skiped whole UT `test_cpu_select_algorithm.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160292 Approved by: https://github.com/malfet --- test/dynamo/test_decorators.py | 4 ++++ test/dynamo/test_logging.py | 5 ++++- test/inductor/test_cpu_select_algorithm.py | 3 ++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 3b29e5e961192..9bf982c5b90ec 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -10,6 +10,7 @@ import torch._dynamo.testing from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import skipIfWindows def my_custom_function(x): @@ -892,6 +893,9 @@ def gn(x): self.assertEqual(gn(inp), inp + 3) self.assertEqual(cnts.frame_count, 1) + @skipIfWindows( + msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows." + ) def test_disable_recursive_false(self): def fn2(x): return x + 1 diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 439b0361690b2..a5a6ee54aa74a 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -21,8 +21,10 @@ from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.common_utils import ( find_free_port, + IS_WINDOWS, munge_exc, skipIfTorchDynamo, + skipIfWindows, TEST_XPU, xfailIf, ) @@ -528,7 +530,7 @@ def test_invalid_artifact_flag_error_msg(self): "import torch", env=env, ) - lines = stderr.decode().split("\n") + lines = stderr.decode().split("\r\n" if IS_WINDOWS else "\n") # This is a sanity assert that our error is not spammy. # As of this test creation this was 18. # See this issue for the purpose o this test: @@ -544,6 +546,7 @@ def test_invalid_artifact_flag_error_msg(self): self.assertEqual(lines[-4], "Valid settings:") @requires_distributed() + @skipIfWindows(msg="TODO: (xuhancn), Can't reproduce locally") def test_distributed_rank_logging(self): env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 7e35c93ee0b79..75d091595cd8a 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,6 +26,7 @@ ) from torch.testing._internal.common_utils import ( IS_MACOS, + IS_WINDOWS, parametrize, skipIfWindows, TEST_MKL, @@ -3094,5 +3095,5 @@ def forward(self, x, weight): if __name__ == "__main__": from torch.testing._internal.inductor_utils import HAS_CPU - if HAS_CPU and not IS_MACOS: + if HAS_CPU and not (IS_MACOS or IS_WINDOWS): run_tests() From d8cb3db5339b45e4b745b2b883ef3ecde9843e2c Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sun, 10 Aug 2025 20:07:40 -0400 Subject: [PATCH 0203/1424] Add unsigned support to `IValue` (#160102) - Moved repeated logic of saving int64/uint64 into a polymorphic container into `THPUtils_unpackInteger` - Added `TestPythonDispatch.test_dispatch_uint64` regression test Fixes https://github.com/pytorch/pytorch/issues/159168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160102 Approved by: https://github.com/ezyang --- aten/src/ATen/core/ivalue.cpp | 8 +++++ aten/src/ATen/core/ivalue.h | 41 ++++++++++++++++++++++++-- test/test_python_dispatch.py | 13 ++++++++ torch/csrc/jit/python/pybind_utils.cpp | 8 +++-- torch/csrc/utils/python_arg_parser.cpp | 16 +--------- torch/csrc/utils/python_numbers.h | 19 ++++++++++++ 6 files changed, 86 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index c6087f0a68ecf..72589436606ec 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -97,6 +97,8 @@ c10::TypePtr IValue::TagType::get(const IValue& v) { return ComplexType::get(); case Tag::Int: return IntType::get(); + case Tag::UInt: + return IntType::get(); case Tag::SymInt: return c10::SymIntType::get(); case Tag::SymFloat: @@ -320,6 +322,8 @@ IValue IValue::equals(const IValue& rhs) const { return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble(); case Tag::Int: return rhs.isInt() && lhs.toInt() == rhs.toInt(); + case Tag::UInt: + return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt(); case Tag::SymInt: return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt(); case Tag::SymFloat: @@ -379,6 +383,8 @@ size_t IValue::hash(const IValue& v) { case Tag::Int: return c10::get_hash(v.payload.u.as_int); // NB: these are technically strict aliasing violations + case Tag::UInt: + return c10::get_hash(v.payload.u.as_int); case Tag::SymInt: return c10::get_hash(v.payload.u.as_int); case Tag::SymFloat: @@ -806,6 +812,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { return printComplex(out, v); } case IValue::Tag::Int: return out << v.toInt(); + case IValue::Tag::UInt: + return out << v.toUInt(); case IValue::Tag::SymInt: return out << v.toSymInt(); case IValue::Tag::SymFloat: diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 175860dc99a7c..ab2039e058201 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -160,6 +161,7 @@ struct Capsule { _(Double) \ _(ComplexDouble) \ _(Int) \ + _(UInt) \ _(SymInt) \ _(SymFloat) \ _(SymBool) \ @@ -653,6 +655,29 @@ struct TORCH_API IValue final { } } + // Unsigned + IValue(uint64_t u) : tag( u <= std::numeric_limits::max() ? Tag::Int : Tag::UInt) { + payload.u.as_uint = u; + } + + + // See Note [Meaning of HAS_u] + // IValue type model closely follows that of c10::Scalar + // Where all integers are upcast to 64-bit representation, and `as_int` is used as default + // representation unless value could not be represented as signed int + bool isUnsigned() const { + return Tag::UInt == tag || (Tag::Int == tag && payload.u.as_int >= 0); + } + + uint64_t toUInt() const { + if (isUnsigned()) { + return payload.u.as_uint; + } else { + TORCH_INTERNAL_ASSERT(0, "expected unsigned int"); + } + } + + // Bool IValue(bool b) : tag(Tag::Bool) { #if defined(__clang__) && defined(__x86_64__) @@ -893,8 +918,14 @@ struct TORCH_API IValue final { } else { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( s.isIntegral(false), "Unknown type in Scalar"); - tag = Tag::Int; - payload.u.as_int = s.toLong(); + if (s.isUnsigned()) { + const auto val = s.toUInt64(); + payload.u.as_uint = val; + tag = val <= std::numeric_limits::max() ? Tag::Int : Tag::UInt; + } else { + payload.u.as_int = s.toLong(); + tag = Tag::Int; + } } } @@ -918,6 +949,8 @@ struct TORCH_API IValue final { return toSymFloat(); else if (isSymBool()) return toSymBool(); + else if (isUnsigned()) + return toUInt(); TORCH_CHECK(false, "IValue is not a Scalar"); } @@ -1247,6 +1280,8 @@ struct TORCH_API IValue final { return true; case Tag::Int: return false; + case Tag::UInt: + return false; case Tag::SymInt: return true; case Tag::SymFloat: @@ -1343,6 +1378,8 @@ struct TORCH_API IValue final { union TriviallyCopyablePayload { TriviallyCopyablePayload() : as_int(0) {} int64_t as_int; + // See Note [Meaning of HAS_u] + uint64_t as_uint; double as_double; bool as_bool; // Invariant: never nullptr; null state is represented as diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index e0480ba6a6842..71ebf5d784308 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -2513,6 +2513,19 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None): with Mode(): torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,)) + def test_dispatch_uint64(self): + class DummyMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args, kwargs): + self.last_args = args + return func(*args, **kwargs) + + # Value that could not be intepreted as signed int64 + uarg = 2**63 + 1 + with DummyMode() as m: + a = torch.full((3, 3), uarg, dtype=torch.uint64) + self.assertEqual(m.last_args[1], uarg) + self.assertTrue((a == uarg).all().item()) + class TestPythonDispatcher(TestCase): def test_basic(self): diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 3f2708619be86..e30648399c5ae 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -90,7 +90,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { if (PyBool_Check(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackBool(obj.ptr())); } else if (THPUtils_checkLong(obj.ptr())) { - scalar = at::Scalar(THPUtils_unpackLong(obj.ptr())); + scalar = THPUtils_unpackInteger(obj.ptr()); } else if (PyComplex_Check(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr())); } else if (THPUtils_checkDouble(obj.ptr())) { @@ -512,7 +512,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { if (py::isinstance(obj)) { return py::cast(obj); } else if (py::isinstance(obj)) { - return py::cast(obj); + return THPUtils_unpackInteger(obj.ptr()); } else if (py::isinstance(obj)) { return py::cast(obj); } else if (PyComplex_CheckExact(obj.ptr())) { @@ -598,6 +598,8 @@ py::object toPyObject(IValue ivalue) { return py::cast(*tensor.const_data_ptr()); case at::ScalarType::Long: return py::cast(*tensor.const_data_ptr()); + case at::ScalarType::UInt64: + return py::cast(*tensor.const_data_ptr()); case at::ScalarType::Double: return py::cast(*tensor.const_data_ptr()); case at::ScalarType::ComplexDouble: @@ -763,6 +765,8 @@ py::object toPyObject(IValue ivalue) { return py::cast(std::move(ivalue).toSymFloat()); } else if (ivalue.isSymBool()) { return py::cast(std::move(ivalue).toSymBool()); + } else if (ivalue.isUnsigned()) { + return py::cast(std::move(ivalue).toUInt()); } else { TORCH_CHECK( false, diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 7066b164a2280..1ae03f91f2180 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -1801,21 +1801,7 @@ at::Tensor PythonArgs::tensor_slow(int i) { if (PyBool_Check(obj)) { scalar = at::Scalar(THPUtils_unpackBool(obj)); } else if (THPUtils_checkLong(obj)) { - int overflow = -1; - long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); - if (value == -1 && PyErr_Occurred()) { - throw python_error(); - } - if (overflow != 0) { - // try unsigned - unsigned long long value = PyLong_AsUnsignedLongLong(obj); - if (value == static_cast(-1) && PyErr_Occurred()) { - throw python_error(); - } - scalar = at::Scalar(static_cast(value)); - } else { - scalar = at::Scalar(static_cast(value)); - } + scalar = THPUtils_unpackInteger(obj); } else if (PyComplex_Check(obj)) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj)); } else if (THPUtils_checkDouble(obj)) { diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index 25ca2692b3291..a8b9b8632a00b 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -208,3 +208,22 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) { } return (c10::DeviceIndex)value; } + +template +inline T THPUtils_unpackInteger(PyObject* obj) { + int overflow = -1; + const auto value = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (!overflow) { + return static_cast(value); + } + // try unsigned + const auto uvalue = PyLong_AsUnsignedLongLong(obj); + if (uvalue == static_cast>(-1) && + PyErr_Occurred()) { + throw python_error(); + } + return static_cast(uvalue); +} From 8088cfa592504a2897b4c78f8a46fe658ab5c2c2 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 10 Aug 2025 12:04:23 -0700 Subject: [PATCH 0204/1424] Add type assert for tensor_meta, based on real bug in autoparallel. (#157927) Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/157927 Approved by: https://github.com/albanD, https://github.com/Skylion007, https://github.com/wconstab --- torch/distributed/tensor/_dtensor_spec.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index eb528ee4f9af1..bffb399b2bca8 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -40,6 +40,16 @@ def __setattr__(self, attr: str, value: Any) -> None: # change (though we do not expect `mesh` or `placements` to change) if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): self._hash = None + # This assert was triggered by buggy handling for dict outputs in some + # FX passes, where you accidentally iterate over a dict and try to put + # keys into TensorMeta. See https://github.com/pytorch/pytorch/issues/157919 + if attr == "tensor_meta" and value is not None: + from torch.fx.passes.shape_prop import TensorMetadata + + # TODO: the TensorMetadata arises from + # test/distributed/tensor/experimental/test_tp_transform.py::TensorParallelTest::test_tp_transform_e2e + # but I actually can't reproduce it, maybe it is also a bug! + assert isinstance(value, (TensorMeta, TensorMetadata)), value def _hash_impl(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding From 8ae4d2652f64b8444b3d5314b9232bd2119bcde6 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 11 Aug 2025 04:50:35 +0000 Subject: [PATCH 0205/1424] Tidy torch/csrc/jit/passes/onnx code (#160262) Apply clang-tidy fixes to torch/csrc/jit/passes/onnx Pull Request resolved: https://github.com/pytorch/pytorch/pull/160262 Approved by: https://github.com/justinchuby --- torch/csrc/jit/passes/onnx/constant_fold.cpp | 4 +- .../jit/passes/onnx/function_extraction.cpp | 4 +- torch/csrc/jit/passes/onnx/peephole.cpp | 4 +- .../onnx/remove_inplace_ops_for_onnx.cpp | 4 +- .../jit/passes/onnx/shape_type_inference.cpp | 26 +++--- .../passes/onnx/unpack_quantized_weights.cpp | 81 +------------------ 6 files changed, 23 insertions(+), 100 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 9cf12ffde38a2..0ac07adf0d45c 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -76,8 +76,8 @@ static std::optional runTorchSlice_opset9( if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { return std::nullopt; } - auto startsAttr = node->is(attr::starts); - auto endsAttr = node->is(attr::ends); + auto const& startsAttr = node->is(attr::starts); + auto const& endsAttr = node->is(attr::ends); if (startsAttr.size() != endsAttr.size()) { return std::nullopt; } diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index ece03b19e961e..32c0e1b77c2cb 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -216,7 +216,7 @@ void FunctionExtractor::FunctionContext::SetAttrName( TORCH_INTERNAL_ASSERT( v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end()); auto* n_in_def = v_it->second->node(); - auto n_attr_it = node_attr_to_name_[n_in_def][attr.toUnqualString()] = name; + node_attr_to_name_[n_in_def][attr.toUnqualString()] = name; } std::optional FunctionExtractor::FunctionContext::FindAttrName( @@ -405,7 +405,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { auto common_ancestor = FindCommonAncestor(scopes); if (common_ancestor.has_value() && IsValidScope(common_ancestor.value())) { - return common_ancestor.value(); + return common_ancestor; } } } diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 73106ba0ef3c7..71595b769ac1c 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -35,8 +35,8 @@ static bool isRNN(const Node* node) { } static bool isNopTranspose(const std::vector& perm) { - for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) { - if (perm[i] != i) { + for (size_t i = 0, perm_size = perm.size(); i < perm_size; i++) { + if (perm[i] != static_cast(i)) { return false; } } diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index 7a28f1e41c1b5..966388278a32f 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -10,8 +10,6 @@ #include -#include - namespace torch::jit { namespace { @@ -344,7 +342,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { auto it = std::find(node->inputs().begin(), node->inputs().end(), input); if (it != node->inputs().end()) { - int index = std::distance(node->inputs().begin(), it); + auto index = std::distance(node->inputs().begin(), it); TORCH_WARN( "ONNX Preprocess - Removing mutation from node ", node->kind().toQualString(), diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 086e50ae6a7a3..452b18f3efc31 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -282,7 +282,7 @@ Value* CloneValueFromListConstruct( auto input = n_graph->addInput(); if (scalar_type) { auto v_type = TensorType::create( - scalar_type.value(), + scalar_type, at::kCPU, c10::SymbolicShape(), c10::VaryingShape{}, @@ -411,7 +411,9 @@ void ConvertGraphToONNXProto( } } -std::optional ComputeConstantFolding(Node* n, int opset_version) { +std::optional ComputeConstantFolding( + const Node* n, + int opset_version) { if (n->inputs().empty()) { return std::nullopt; } @@ -463,7 +465,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( auto it_0 = std::find_if(shape_vector.begin(), shape_vector.end(), is_zero); bool shape_has_zero = it_0 != shape_vector.end(); - int minus_one_pos = -1; + int64_t minus_one_pos = -1; for (auto i : c10::irange(shape_vector.size())) { if (shape_vector[i].value() == -1) { minus_one_pos = i; @@ -773,7 +775,7 @@ void ProcessBroadcastNode(Node* n) { } void ProcessShapeForConcatNode(Node* n) { - int axis = n->i(attr::axis); + auto axis = n->i(attr::axis); if (ConstantValueMap::HasRank(n->input(0)->debugName())) { auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value(); size_t axis_adjust = 0; @@ -1244,7 +1246,7 @@ void ProcessUnsqueezeNode(Node* n) { void ComputeConstant(Node* n, int opset_version) { if (n->kind() == ::c10::onnx::Constant) { if (n->kindOf(attr::value) == AttributeKind::t) { - at::Tensor const_val = n->t(attr::value); + const at::Tensor& const_val = n->t(attr::value); at::Tensor const_val_copy = at::empty(const_val.sizes(), const_val.options()); const_val_copy.copy_(const_val); @@ -1381,7 +1383,7 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { - auto input0_shape_value = input0_shape_size.value(); + const auto& input0_shape_value = input0_shape_size.value(); if (ConstantValueMap::HasValue(n->input(1)->debugName())) { // When value of `shape` is statically known, // output shape can be computed. @@ -1474,7 +1476,7 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { - auto input0_shape_value = input0_shape_size.value(); + const auto& input0_shape_value = input0_shape_size.value(); int64_t total_size = 1; auto is_full_static = true; for (const auto i : c10::irange(input0_shape_value.size())) { @@ -1510,7 +1512,7 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { - auto input0_shape_value = input0_shape_size.value(); + const auto& input0_shape_value = input0_shape_size.value(); if (ConstantValueMap::HasValue(n->input(1)->debugName())) { auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector( n->input(1)->debugName()); @@ -1659,10 +1661,10 @@ void SpecialPostProcess(Node* n) { }; auto find_sequence_empty = [](Value* input, - TensorTypePtr t_type) -> Node* { + const TensorTypePtr& t_type) -> Node* { auto find_sequence_empty_impl = [](Value* input, - TensorTypePtr t_type, + const TensorTypePtr& t_type, auto& find_sequence_empty_ref) -> Node* { auto input_node = input->node(); TORCH_INTERNAL_ASSERT(input_node); @@ -1708,7 +1710,7 @@ void SpecialPostProcess(Node* n) { return nullptr; }; return find_sequence_empty_impl( - input, std::move(t_type), find_sequence_empty_impl); + input, t_type, find_sequence_empty_impl); }; if (seq_node && t_type && t_type->scalarType()) { @@ -2255,7 +2257,7 @@ void ONNXSetDynamicInputShape( } } -static bool HasSequenceTypeOutput(Node* node) { +static bool HasSequenceTypeOutput(const Node* node) { if (node->kind() == ::c10::onnx::SplitToSequence || node->kind() == ::c10::onnx::SequenceInsert || node->kind() == ::c10::onnx::SequenceEmpty || diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 3116c0721a6c4..63e6804c97eb3 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -21,83 +21,6 @@ using namespace ::c10::onnx; } -// Get the scale of the input to quantized op. There are two cases here -// 1. For ops with output_scale specified in op signature, we get the output -// scale -// 2. For ops with no output scale in op signature (like quantized::relu) -// we traverse up the graph to get the scale from its input until we hit a node -// where scale is explicitly specified. -double getScaleFromInput(Node* input_node) { - std::optional scale; - std::string input_name = input_node->kind().toQualString(); - std::unordered_set noscale_ops = { - "quantized::max_pool2d", - "aten::max_pool2d", - "aten::relu", - "prim::ListUnpack", - "aten::split_with_sizes", - "quantized::nchw2nhwc", - "quantized::nhwc2nchw", - "aten::slice", - "aten::avg_pool2d", - "quantized::cat", - "prim::ListConstruct", - "aten::upsample_nearest2d", - "aten::sigmoid", - "aten::reshape"}; - if (input_name == "aten::quantize_per_tensor") { - TORCH_CHECK( - input_node->inputs().size() > 1, - "aten::quantize_per_tensor expected scale to be 2nd input"); - scale = toIValue(input_node->inputs()[1]); - return scale.value().toDouble(); - } else if (input_name == "quantized::linear") { - // %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::linear expected scale to be 3rd input"); - scale = toIValue(input_node->inputs()[2]); - return scale.value().toDouble(); - } else if (input_name == "quantized::conv2d") { - // %r = quantized::conv2d(%input, %packed_weight, %w_scale, %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::conv2d expected scale to be 3rd input"); - auto num_inputs = input_node->inputs().size(); - scale = toIValue(input_node->inputs()[num_inputs - 2]); - return scale.value().toDouble(); - } else if (input_name == "quantized::conv2d_relu") { - // %r = quantized::conv2d_relu(%input, %packed_weight, %w_scale, - // %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::conv2d_relu expected scale to be 3rd input"); - auto num_inputs = input_node->inputs().size(); - scale = toIValue(input_node->inputs()[num_inputs - 2]); - return scale.value().toDouble(); - } else if (input_name == "quantized::add") { - // %r = quantized::add(%input_a, %input_b, %w_scale, %w_zero_point) - TORCH_CHECK( - input_node->inputs().size() > 2, - "quantized::add expected scale to be 3rd input"); - scale = toIValue(input_node->inputs()[2]); - return scale.value().toDouble(); - } else if (input_name == "aten::sigmoid") { - // For the _caffe2::Int8Sigmoid op output scale is 1.0/256 - // And output zero_point is set to 0 (quint8 type). - return 1.0L / 256; - } - // For the ops below the scale is not part of the op signature, so we traverse - // up the graph to get the scale from its input when defined in the graph. - else if (noscale_ops.find(input_name) != noscale_ops.end()) { - return getScaleFromInput(input_node->inputs()[0]->node()); - } - TORCH_INTERNAL_ASSERT( - false, - "Unrecognized quantized operator while trying to compute q_scale for operator ", - input_name); -} - static std::vector CreateQuantizedWeights( std::shared_ptr& graph, const at::Tensor& weight, @@ -315,7 +238,7 @@ static void unpackQuantizedWeightsHelper( auto config_vals = elements[1].to>(); auto tensors = elements[2].to>>(); - std::optional weight = tensors[1]; + const std::optional& weight = tensors[1]; TORCH_INTERNAL_ASSERT( weight, "Weight should always be present in serialized qconv."); unpacked_weight = *weight; @@ -373,7 +296,7 @@ static void unpackQuantizedWeightsHelper( TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version"); std::vector non_optional = elements[1].toTensorVector(); - at::Tensor conv_params_packed = non_optional[0]; + const at::Tensor& conv_params_packed = non_optional[0]; unpacked_weight = non_optional[1]; const int64_t kSpatialDim = conv_params_packed[0].item(); From dc0d18e023d9b7e314ebba0f234b6cb1579dbcfd Mon Sep 17 00:00:00 2001 From: FFFrog Date: Sat, 9 Aug 2025 23:47:14 +0800 Subject: [PATCH 0206/1424] [CUDA] Remove the uncessary CUDA_GUARD (#160249) `CUDA_GUARD` is unnecessary in `initDeviceStreamState`, because the `initSingleStream` has already done it. https://github.com/pytorch/pytorch/blob/29712314dd5cf500a8ea3d1c69483a3cb768ca72/c10/cuda/CUDAStream.cpp#L202-L203 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160249 Approved by: https://github.com/Skylion007 --- c10/cuda/CUDAStream.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 0cde2d9de01cf..8eca673cd3a4d 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -216,9 +216,6 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) { // Creates the low and high priority stream pools for the specified device // Warning: only call once per device! static void initDeviceStreamState(DeviceIndex device_index) { - // Switches to the requested device so streams are properly associated - // with it. - CUDAGuard device_guard{device_index}; for (const auto i : c10::irange(kStreamsPerPool)) { for (const auto p : c10::irange(max_stream_priorities)) { initSingleStream(p, device_index, i); From 334b38ccc4427b1d14981c48a3a0b92180d58225 Mon Sep 17 00:00:00 2001 From: Jiaxi WANG <148853031+bjtuwjx@users.noreply.github.com> Date: Mon, 11 Aug 2025 05:09:57 +0000 Subject: [PATCH 0207/1424] Fix typo in README.md (#160160) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The "Get the PyTorch Source" section is now located before the "Install Dependencies/Common" section, so "... using the “Get the PyTorch Source“ section below" should be "... using the “Get the PyTorch Source“ section above". Pull Request resolved: https://github.com/pytorch/pytorch/pull/160160 Approved by: https://github.com/BoyuanFeng --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3c67d36e74950..16000850ae920 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,7 @@ git submodule update --init --recursive ```bash conda install cmake ninja -# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below +# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section above pip install -r requirements.txt ``` From ff0d56d03592aa03f3ced8359241d21df1783393 Mon Sep 17 00:00:00 2001 From: Nick Riasanovsky Date: Mon, 11 Aug 2025 05:27:51 +0000 Subject: [PATCH 0208/1424] [Inductor] [Triton] Enable Configuration warmup/rep iterations when benchmarking in inductor (#159982) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: When benchmarking on B200 Max Autotune, I discovered that the estimations from the autotune logs consistently produced a better ATEN result by > 20% on an example shape. Here is an example of the output: ``` Autotune Choices Stats: {"num_choices": 20, "num_triton_choices": 19, "best_kernel": "mm", "best_time": 0.3081120103597641, "best_triton_pos": 1, "best_triton_time": 0.6589759886264801, "best_triton_kernel": "triton_mm_16", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0"} AUTOTUNE mm(3840x1152, 1152x49136) strides: [1, 3840], [49152, 1] dtypes: torch.bfloat16, torch.bfloat16 mm 0.3081 ms 100.0% triton_mm_16 0.6590 ms 46.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_17 0.6830 ms 45.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_13 0.7015 ms 43.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_9 0.8487 ms 36.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_11 0.8695 ms 35.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_10 0.8797 ms 35.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_18 0.9089 ms 33.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_14 0.9718 ms 31.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_15 1.0169 ms 30.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 SingleProcess AUTOTUNE benchmarking takes 2.8574 seconds and 0.1032 seconds precompiling for 20 choices Removed 3483 outliers from 28645 samples 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:20<00:00, 20.00s/it] (M, N, K) pt2_matmul_maxautotune-latency pt2_matmul_maxautotune-speedup pt2_matmul_maxautotune-tflops ------------------- -------------------------------- -------------------------------- ------------------------------- (3840, 49136, 1152) 0.359392 (±8.27%) 1209.61 average 1209.61 ``` Based on my reading about B200 power usage, I believe this is due to the new for power aware benchmarking as a kernel may perform better in short bursts. This adds environment variables to expand autotuning iterations so we can get more consistent results between the estimation and the actual runtime. I did not update the default yet, even for B200 because I'm not sure how this is used in practice. This is the new output: ``` Autotune Choices Stats: {"num_choices": 20, "num_triton_choices": 19, "best_kernel": "mm", "best_time": 0.3848319947719574, "best_triton_pos": 1, "best_triton_time": 0.6287680268287659, "best_triton_kernel": "triton_mm_16", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0"} AUTOTUNE mm(3840x1152, 1152x49136) strides: [1, 3840], [49152, 1] dtypes: torch.bfloat16, torch.bfloat16 mm 0.3848 ms 100.0% triton_mm_16 0.6288 ms 61.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_13 0.6299 ms 61.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_17 0.6728 ms 57.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_9 0.7189 ms 53.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_18 0.8566 ms 44.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_11 0.8693 ms 44.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_14 0.9298 ms 41.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_10 0.9524 ms 40.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 triton_mm_15 1.0216 ms 37.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 SingleProcess AUTOTUNE benchmarking takes 3.9245 seconds and 0.0965 seconds precompiling for 20 choices Removed 3537 outliers from 29530 samples 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.70s/it] (M, N, K) pt2_matmul_maxautotune-latency pt2_matmul_maxautotune-speedup pt2_matmul_maxautotune-tflops ------------------- -------------------------------- -------------------------------- ------------------------------- (3840, 49136, 1152) 0.359328 (±9.71%) 1209.82 average 1209.82 ``` Test Plan: `TORCH_AUTOTUNE_REP=1000 CUDA_VISIBLE_DEVICES=2 ENABLE_MMA_V5_ATT_PIPELINE=1 TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -- --op gemm --iter $NUM_ITERS --input-loader /home/njriasan/parsed_shapes.json --only pt2_matmul_maxautotune` Rollback Plan: Reviewed By: NikhilAPatel Differential Revision: D79737929 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159982 Approved by: https://github.com/NikhilAPatel --- torch/_inductor/ir.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4f9f2f1e0b59f..a668cd41ebf1b 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6,6 +6,7 @@ import itertools import logging import operator +import os import textwrap import traceback from collections.abc import Container, Generator, Iterable, Iterator, Sequence @@ -156,6 +157,9 @@ indent = functools.partial(textwrap.indent, prefix=" ") aten = torch.ops.aten +autotune_warmup = int(os.getenv("TORCH_AUTOTUNE_WARMUP", 25)) +autotune_rep = int(os.getenv("TORCH_AUTOTUNE_REP", 100)) + """ [Note: Inductor IR] Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each @@ -4910,9 +4914,13 @@ def __init__( def benchmark(self, *args: Any, out: torch.Tensor) -> float: algo = self.to_callable() + benchmark_configs = { + "warmup": autotune_warmup, + "rep": autotune_rep, + } if config.profile_bandwidth_with_do_bench_using_profiling: - return do_bench_using_profiling(lambda: algo(*args)) - return benchmarker.benchmark(algo, args, {"out": out}) + return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs) + return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs) def call_name(self) -> str: raise NotImplementedError From 1c2cba17eab2b09d87142883da2bdbdbcf018613 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Fri, 8 Aug 2025 16:39:15 -0700 Subject: [PATCH 0209/1424] [FR] Add stack_id and an optional print of stack_id to stack_trace mapping (#160119) To better help users debug with FR, we want to add stack_id and print a map between stack_id and stack_trace (optional) Screenshot: image image Pull Request resolved: https://github.com/pytorch/pytorch/pull/160119 Approved by: https://github.com/H-Huang, https://github.com/wconstab --- tools/flight_recorder/components/builder.py | 8 ++++- .../components/config_manager.py | 1 + tools/flight_recorder/components/types.py | 2 ++ tools/flight_recorder/components/utils.py | 33 +++++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index 2a9cee36f7bc8..4bc268022e285 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -24,6 +24,7 @@ Traceback, ) from tools.flight_recorder.components.utils import ( + add_stack_id_in_entries, align_trace_from_beginning, check_current_entry_match, check_no_missing_dump_files, @@ -391,6 +392,9 @@ def build_db( # Ensure version is consistent across all ranks. check_version(version_by_ranks, version) entries = align_trace_from_beginning(entries) + stack_id_trace_map: dict[str, int] = {} + if args.just_print_entries: + entries, stack_id_trace_map = add_stack_id_in_entries(entries) # flattened database groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( @@ -402,7 +406,9 @@ def build_db( check_no_missing_dump_files(entries, memberships) if args.just_print_entries: - just_print_entries(entries, _groups, _memberships, _pg_guids, args) + just_print_entries( + entries, _groups, _memberships, _pg_guids, args, stack_id_trace_map + ) sys.exit(0) tracebacks, collectives, nccl_calls = build_collectives( diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index ea9b0cf3918cd..abd7f5372133c 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -67,6 +67,7 @@ def __init__(self: "JobConfig"): ) self.parser.add_argument("-j", "--just_print_entries", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") + self.parser.add_argument("--print_stack_trace", action="store_true") def parse_args( self: "JobConfig", args: Optional[Sequence[str]] diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index 597ee8e3cedaa..ded30fb077cda 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -417,6 +417,7 @@ def __init__( else: self.input_sizes, self.output_sizes = None, None self.collective_seq_id = event["collective_seq_id"] + self.stack_id = event.get("stack_id", -1) self.p2p_seq_id = event["p2p_seq_id"] self.input_dtypes = event["input_dtypes"] self.output_dtypes = event["output_dtypes"] @@ -456,6 +457,7 @@ def __repr__(self) -> str: f"pg_name={self.pg_name}", f"pg_description={self.pg_desc}", f"pg_size={self.pg_size}", + f"stack_id={self.stack_id}", f"state={self.state}", ) return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s) diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 73ec2a13d3be0..b68266c79b2c2 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -616,6 +616,7 @@ def just_print_entries( _memberships: dict[str, set[Any]], _pg_guids: dict[tuple[str, int], str], args: argparse.Namespace, + stack_id_trace_map: dict[str, int], ) -> None: rows = [] ranks = sorted(all_entries.keys()) @@ -650,6 +651,17 @@ def just_print_entries( logger.info(tabulate(rows, headers=headers)) + if stack_id_trace_map and args.print_stack_trace: + headers = ["stack_id", "frame_stack"] + rows = [] + + for frame, stack_id in sorted( + stack_id_trace_map.items(), key=lambda item: item[1] + ): + rows.append([str(stack_id), frame]) + + logger.info(tabulate(rows, headers=headers)) + def check_no_missing_dump_files( entries: dict[int, Any], memberships: list[Membership] @@ -677,6 +689,27 @@ def get_version_detail(version: str) -> tuple[int, int]: return major, minor +def add_stack_id_in_entries( + entries: dict[int, list[dict[str, Any]]], +) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]: + stack_id = 0 + stack_id_trace_map = {} + for rank in entries: + for dump in entries[rank]: + if dump.get("frames", []): + frames = str(dump["frames"]) + if frames not in stack_id_trace_map: + stack_id_trace_map[frames] = stack_id + dump["stack_id"] = stack_id + stack_id += 1 + else: + dump["stack_id"] = stack_id_trace_map[frames] + else: + dump["stack_id"] = -1 + + return entries, stack_id_trace_map + + def align_trace_from_beginning( entries: dict[int, list[dict[str, Any]]], ) -> dict[int, list[dict[str, Any]]]: From ecea81117b2fdc52907c97b3c32d779e07b5d55b Mon Sep 17 00:00:00 2001 From: Tanmay Sinha <46783696+tanmay-sinha@users.noreply.github.com> Date: Mon, 11 Aug 2025 09:03:14 +0000 Subject: [PATCH 0210/1424] Fix clang builds by adding headers (#160252) Clang compiler from llvm-14 fails to build full torch from source with the message ``` no template named 'unordered_map' in namespace 'std' std::unordered_map handlers_{}; ~~~~~^ ``` A similar issue here https://github.com/intel/llvm/issues/5264 Fix is to add the correct headers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160252 Approved by: https://github.com/Skylion007, https://github.com/cyyever --- torch/csrc/distributed/c10d/control_plane/Handlers.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 0b4a2f9568400..973197ded14fc 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -4,7 +4,10 @@ #include #include #include +#include +#include #include +#include namespace c10d::control_plane { From cf4964be68fa9f4ffc334f01cce42d7424b1cc81 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 11 Aug 2025 10:14:47 +0000 Subject: [PATCH 0211/1424] Remove unnecessary CMake checks for glog (#158185) With the updating to CMake 2.27, some old scripts can be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158185 Approved by: https://github.com/malfet, https://github.com/Skylion007 --- cmake/MiscCheck.cmake | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 871a23487f29d..9efb0b46c59dd 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -2,24 +2,6 @@ include(CheckCXXSourceCompiles) include(CheckCXXCompilerFlag) include(CMakePushCheckState) -# ---[ Check if we want to turn off deprecated warning due to glog. -if(USE_GLOG) - cmake_push_check_state(RESET) - set(CMAKE_REQUIRED_FLAGS "-std=c++17") - CHECK_CXX_SOURCE_COMPILES( - "#include - int main(int argc, char** argv) { - return 0; - }" CAFFE2_NEED_TO_TURN_OFF_DEPRECATION_WARNING - FAIL_REGEX ".*-Wno-deprecated.*") - - if(NOT CAFFE2_NEED_TO_TURN_OFF_DEPRECATION_WARNING AND NOT MSVC) - message(STATUS "Turning off deprecation warning due to glog.") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated") - endif() - cmake_pop_check_state() -endif() - # ---[ Check if the compiler has AVX/AVX2 support. We only check AVX2. if(NOT INTERN_BUILD_MOBILE) find_package(AVX) # checks AVX and AVX2 From 05029ad1c30865d3f7e7fd13384db9d826e563eb Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Mon, 11 Aug 2025 11:28:46 +0000 Subject: [PATCH 0212/1424] [xla hash update] update the pinned xla hash (#160306) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned xla hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160306 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/xla.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index ee8531ae65100..cf8eb1a1efceb 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -b6a5b82b9948b610fa4c304d0d869c82b8f17db1 +095faec1e7b6cc47220181e74ae9cde2605f9b00 From 2259dbed4e0d3f2a8174b5847fd0741aed42451d Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Mon, 11 Aug 2025 12:00:09 +0000 Subject: [PATCH 0213/1424] Update slow tests (#158222) This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml). Update the list of slow tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158222 Approved by: https://github.com/pytorchbot --- test/slow_tests.json | 495 +++++++++++++++++++++---------------------- 1 file changed, 237 insertions(+), 258 deletions(-) diff --git a/test/slow_tests.json b/test/slow_tests.json index 457701b46b611..579e69d7e4888 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,260 +1,239 @@ { - "EndToEndLSTM (__main__.RNNTest)": 200.1896718343099, - "MultiheadAttention (__main__.ModulesTest)": 141.92533365885416, - "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 210.3270060221354, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 105.85777706570096, - "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 115.53966522216797, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 62.45811038547092, - "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 177.51766967773438, - "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 74.74966557820638, - "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 68.23533376057942, - "test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 61.625999450683594, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 134.07366434733072, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 188.88899739583334, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.63599904378255, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.27233378092448, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 105.4979985555013, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 633.0828002929687, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 91.86733309427898, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 481.1977776421441, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 491.7155592176649, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 124.39833196004231, - "test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.104000091552734, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 81.22966766357422, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 69.64550145467122, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 175.67355600992838, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 125.82333374023438, - "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 369.5883280436198, - "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 418.0381130642361, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 312.76700168185766, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 84.68433380126953, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 86.41216786702473, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 60.670833587646484, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 84.44266510009766, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 86.69533284505208, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 63.40933354695638, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 375.11133829752606, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 64.89966583251953, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 386.1840108235677, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 66.45699818929036, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 227.58533223470053, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 236.75483194986978, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1000.12451171875, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 63.72516632080078, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 936.3953450520834, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 65.74933242797852, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 70.87016677856445, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.49433453877766, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.39149983723958, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.41349919637044, - "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.10983467102051, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 64.13150151570638, - "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 89.73133341471355, - "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 86.45633188883464, - "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 88.76399993896484, - "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.25218469125254, - "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.11777793036566, - "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 176.61566670735678, - "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 173.7596689860026, - "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 163.57832845052084, - "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 161.29700215657553, - "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 208.6990000406901, - "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 198.11366271972656, - "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 198.788330078125, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 121.93983332316081, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 119.3211669921875, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.11850102742513, - "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 121.52633412679036, - "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 114.41900126139323, - "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 120.74099985758464, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 92.1571667989095, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 93.97516759236653, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 93.90033213297527, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 102.24433135986328, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 237.9564997355143, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 263.09083048502606, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 70.44449869791667, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 78.58383433024089, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 66.97166633605957, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 81.04183451334636, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 89.63233439127605, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 94.67216491699219, - "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 168.28499857584634, - "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 171.91666666666666, - "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 166.12066650390625, - "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 1279.8836669921875, - "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 1132.968994140625, - "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 1118.725341796875, - "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 973.7703247070312, - "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 972.6750081380209, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1209.7756754557292, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1256.0619710286458, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1281.5216471354167, - "test_comprehensive_nn_functional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 917.3249918619791, - "test_comprehensive_nn_functional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 733.1909790039062, - "test_comprehensive_nn_functional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 724.7653401692709, - "test_comprehensive_nn_functional_max_pool3d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 726.2100219726562, - "test_comprehensive_nn_functional_max_pool3d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 705.0809936523438, - "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 517.8646697998047, - "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 521.0065002441406, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 130.64300028483072, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 124.43033345540364, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 128.03166707356772, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.71049880981445, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.55933380126953, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 65.66183217366536, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 69.40700022379558, - "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 74.34766642252605, - "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 112.48366800944011, - "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 116.27966562906902, - "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 117.50433603922527, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 106.86666615804036, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 94.00083287556966, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 62.15316645304362, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.82649993896484, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 61.87600072224935, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 69.6066665649414, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 68.90516599019368, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 102.65083312988281, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 85.81283442179362, - "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 70.68100102742513, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 98.76588948567708, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 229.82177903917102, - "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 81.8357684795673, - "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 135.92233530680338, - "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 141.42266845703125, - "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 74.59500092726488, - "test_conv3d_unary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 64.01784662099985, - "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 73.09766684638129, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 95.88766733805339, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 94.47416687011719, - "test_count_nonzero_all (__main__.TestBool)": 641.161878797743, - "test_custom_module_lstm (__main__.TestQuantizedOps)": 307.93677775065106, - "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 302.5940024058024, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 81.91116714477539, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDTensorOpsCPU)": 88.2913335164388, - "test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 67.36266835530598, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 60.49377780490451, - "test_fail_creation_ops.py (__main__.TestTyping)": 68.32106041185784, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 76.85566584269206, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 91.61366780598958, - "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 204.6830037434896, - "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 134.79716873168945, - "test_fuse_large_params_cpu (__main__.CpuTests)": 97.0917501449585, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.09088897705078, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 147.25677744547525, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 125.67216491699219, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 94.74416732788086, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 98.06850051879883, - "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 150.5540008544922, - "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 139.7729949951172, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 232.7606684366862, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 154.89383188883463, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 156.3326670328776, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 650.9168192545573, - "test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.89266459147134, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 273.2460021972656, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.99511040581598, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 101.2813351949056, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 154.23166741265192, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 116.40700022379558, - "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 123.70700073242188, - "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 95.7520014444987, - "test_linear (__main__.TestStaticQuantizedModule)": 62.20888815985786, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 102.4893315633138, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 127.22689056396484, - "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 431.17966715494794, - "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 133.41966756184897, - "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 360.4186706542969, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 60.48455513848199, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 63.52433310614692, - "test_proper_exit (__main__.TestDataLoader)": 234.38233439127603, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 242.4615020751953, - "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 65.31966749827068, - "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 150.28666602240668, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 65.1363112979465, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 63.50664397345649, - "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 62.56345471468839, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 73.45999908447266, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.02366638183594, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 85.85933430989583, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 74.7816670735677, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 88.31666564941406, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.21133422851562, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.58400217692058, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 85.65733337402344, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 94.56866709391277, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 80.31666564941406, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 95.52099863688152, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 92.52433522542317, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.57466634114583, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.05966695149739, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.94766743977864, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 77.00899759928386, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 95.18199920654297, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.22000122070312, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 69.10733286539714, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 84.89466603597005, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 85.52066548665364, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 93.1520004272461, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 91.66366831461589, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 370.8893330891927, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 733.5455017089844, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 605.9030151367188, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1136.014139811198, - "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 72.65350023905437, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 64.6456667582194, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 207.27167002360025, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 91.64166768391927, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 167.19299825032553, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 64.22866694132487, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 116.8476676940918, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 70.6433334350586, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 137.72866566975912, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 87.72266642252605, - "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 78.25366719563802, - "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.75999959309895, - "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 68.58633486429851, - "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.43899959988065, - "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 155.9663340250651, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 110.39933268229167, - "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 85.31637557347615, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 136.4769990709093, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 113.9978896247016, - "test_sort_stable_cpu (__main__.CpuTritonTests)": 76.96166737874348, - "test_split_cumsum_cpu (__main__.CpuTritonTests)": 89.43966674804688, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 149.7841674486796, - "test_tensor_split (__main__.TestVmapOperators)": 76.2336671680021, - "test_terminate_handler_on_crash (__main__.TestTorch)": 111.58677675988939, - "test_terminate_signal (__main__.ForkTest)": 136.8188896137807, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 136.99289169742002, - "test_terminate_signal (__main__.SpawnTest)": 140.61755683687, - "test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 69.51326649983724, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 68.61666615804036, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 65.95349820454915, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.64900016784668, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.68766657511394, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 120.926331837972, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 104.47883415222168, - "test_unary_ops (__main__.TestTEFuserDynamic)": 172.1952222188314, - "test_unary_ops (__main__.TestTEFuserStatic)": 158.92655531565347, - "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 96.95966339111328, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 90.34199778238933, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 69.39216740926106, - "test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 73.56816864013672, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 96.19633483886719, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 93.57866668701172, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 95.94100189208984, - "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 71.65300051371257, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 84.81466547648112, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 100.53633308410645, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 69.77733103434245, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 67.43849881490071, - "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 77.40583229064941, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 64.32900110880534, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 71.61133193969727, - "test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 60.90399932861328, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.39033381144206, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.00383377075195, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 143.61550013224283 + "EndToEndLSTM (__main__.RNNTest)": 192.05133056640625, + "MultiheadAttention (__main__.ModulesTest)": 139.78399658203125, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 87.68600040011935, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 65.84855567084418, + "test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 60.25300089518229, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.21100107828777, + "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 75.08200073242188, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 157.21666717529297, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 208.15966288248697, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 125.87799835205078, + "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 77.12099711100261, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 140.02066548665366, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1035.8856404622395, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 135.24966684977213, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 508.929680718316, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 505.31178114149304, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 136.39566548665366, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.21700286865234, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 75.41950098673503, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 223.36288791232639, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.77316665649414, + "test_cat_2k_args (__main__.TestTEFuserDynamic)": 115.93922015362315, + "test_cat_2k_args (__main__.TestTEFuserStatic)": 130.553553307222, + "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 345.87477620442706, + "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 444.5221184624566, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 320.5727776421441, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 113.46416600545247, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 112.7143325805664, + "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.17833370632596, + "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 74.29283396402995, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 112.0316670735677, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 100.49766794840495, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 461.6960042317708, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 456.4236653645833, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.10166422526044, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 282.37300364176434, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1475.5308430989583, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.82050069173177, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1480.9661661783855, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 76.27283477783203, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 77.9731674194336, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 75.6216672261556, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 78.13583374023438, + "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 79.3071657816569, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 73.1963342030843, + "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 73.24300003051758, + "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.95249938964844, + "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 60.023167292277016, + "test_comprehensive_logspace_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 60.90595825513204, + "test_comprehensive_logspace_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 60.20212459564209, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 146.75049845377603, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 134.19933319091797, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 131.4624989827474, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 63.848776499430336, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 63.11926663716634, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 63.54826672871908, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 128.72383244832358, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 125.754332224528, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 112.56066640218098, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 105.46999867757161, + "test_comprehensive_nn_functional_grid_sample_cuda_bfloat16 (__main__.TestDecompCUDA)": 62.39555570814345, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 319.47683970133465, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 318.15632883707684, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 104.06650034586589, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 87.9704984029134, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 88.85649871826172, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 91.08616511027019, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 145.80900065104166, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 144.81166712443033, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1361.4583333333333, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1364.7848307291667, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1371.0353393554688, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 567.3706563313802, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 562.332997639974, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 75.43950017293294, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 73.2380002339681, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.18633397420247, + "test_comprehensive_nn_functional_unfold_cuda_complex128 (__main__.TestDecompCUDA)": 64.52433310614691, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 135.42366409301758, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 135.88899993896484, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 73.0211664835612, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 75.32600021362305, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 76.17533365885417, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 78.49149958292644, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 80.97866566975911, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 143.84516398111978, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 139.04916763305664, + "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 107.44683329264323, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 349.12533315022785, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 713.3404405381945, + "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.65333302815755, + "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 147.33233133951822, + "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 65.11533101399739, + "test_conv_bn_folded_vs_unfolded (__main__.TestQuantizeEagerQATNumerics)": 60.53688989910815, + "test_conv_bn_fuse_cpu (__main__.CpuTests)": 82.8076680501302, + "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 79.54511260986328, + "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 86.01536305745442, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 118.80933380126953, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 103.28283437093098, + "test_count_nonzero_all (__main__.TestBool)": 636.5518866644966, + "test_custom_module_lstm (__main__.TestQuantizedOps)": 806.537343343099, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 86.1219991048177, + "test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 129.43338103521438, + "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 226.9676717122396, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 64.93344370524089, + "test_fail_random.py (__main__.TestTyping)": 69.7191998799642, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 89.57850011189778, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 91.1931660970052, + "test_fuse_large_params_cpu (__main__.CpuTests)": 68.59933344523112, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 157.28044637044272, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 155.77044677734375, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.154665629069, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 107.34999974568684, + "test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 75.96997397985214, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 98.00283304850261, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 125.0576680501302, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.84066518147786, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 227.8953374226888, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 121.02666727701823, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 128.9303321838379, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 607.3985087076823, + "test_group_norm (__main__.TestQuantizedOps)": 94.22445230773, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 322.7479960123698, + "test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 126.8058580671038, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 74.46766620212131, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 98.24650065104167, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 165.09344482421875, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 117.98733266194661, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 125.10833231608073, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 96.8866678873698, + "test_linear (__main__.TestStaticQuantizedModule)": 177.4332241482205, + "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 99.29573364257813, + "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 112.58993326822916, + "test_linear_relu (__main__.TestStaticQuantizedModule)": 70.74819436942602, + "test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 106.39933342403836, + "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.2489998227074, + "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 581.2816569010416, + "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 515.0809936523438, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 65.59099833170573, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 130.8411119249132, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.907222747802734, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 63.92422188652886, + "test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 80.63411996126175, + "test_optimize_for_inference_cpu_torchvision (__main__.TestFXExperimental)": 70.60716595252354, + "test_out_variant_custom_op_dynamic_shapes (__main__.DynamicShapesMiscTests)": 61.15033358619327, + "test_proper_exit (__main__.TestDataLoader)": 224.09533182779947, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 258.17566172281903, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 61.226499239603676, + "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 159.05066765679254, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 63.150904201325915, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 62.33847640809559, + "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 99.43811119927301, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 81.92866770426433, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 90.84566497802734, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.01099904378255, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.23799896240234, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 90.45733388264973, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.5086669921875, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 76.81433359781902, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 86.00199890136719, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 86.0836664835612, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 73.06933339436848, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 98.68933614095052, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.80333201090495, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 78.26366678873698, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.90333557128906, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.47400156656902, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 90.05833435058594, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.04699961344402, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 69.11566670735677, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.11000061035156, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 83.76499938964844, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 90.46166483561198, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 93.64866638183594, + "test_qrnncell (__main__.TestDynamicQuantizedOps)": 76.3342770516562, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 578.3420003255209, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1415.7366739908855, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 764.0906778971354, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1710.9246826171875, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 97.7066650390625, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 350.8980000813802, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.1796646118164, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 271.30833435058594, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 76.83166758219402, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 166.40349833170572, + "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.98755560980902, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 106.40633392333984, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 189.75599924723306, + "test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.40213343302409, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 119.15783309936523, + "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 122.17516708374023, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.66699981689453, + "test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 165.6238899230957, + "test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 155.86678059895834, + "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 76.51850128173828, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 77.36766730414496, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 163.50216674804688, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 135.39966328938803, + "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 161.2034437391493, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 145.5945544772678, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 122.7945556640625, + "test_softmax_view_reshape (__main__.HelionTests)": 174.26483281453451, + "test_std (__main__.TestQuantizedOps)": 91.47738643594978, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 150.35899583498636, + "test_terminate_handler_on_crash (__main__.TestTorch)": 110.8061129252116, + "test_terminate_signal (__main__.ForkTest)": 134.98833089901342, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 135.13266838259167, + "test_terminate_signal (__main__.SpawnTest)": 139.0918925603231, + "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 83.97499879201253, + "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 166.78876847487228, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 76.76449902852376, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 74.20233408610027, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 77.21166737874348, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 126.05833435058594, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 124.58566665649414, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 102.95399856567383, + "test_unary_ops (__main__.TestTEFuserDynamic)": 94.66122142473857, + "test_unary_ops (__main__.TestTEFuserStatic)": 97.9681122303009, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 94.58433278401692, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 80.96083323160808, + "test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 84.94333267211914, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 93.61533101399739, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 99.49200185139973, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 60.70061842600504, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 98.77016703287761, + "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 80.70883369445801, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 117.87966664632161, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.81652414231073, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 138.76616923014322, + "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 66.88895261855353, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 66.50699996948242, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 98.47683461507161, + "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 115.15083122253418, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 102.98050053914388, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 132.38116709391275, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 124.73283131917317, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 159.73250325520834 } \ No newline at end of file From c184cb3852f0ff2d16a489d61abc3739c309e6ca Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 11 Aug 2025 13:48:02 +0000 Subject: [PATCH 0214/1424] [submodule] Bump fbgemm to latest (#158210) Merge the recent commits of FBGEMM and remove unnecessary CMake code. Specifically, we 1. enable `fbgemm_autovec` since the target is now correctly handled. 2. remove option `USE_FAKELOWP` which is not used. 3. remove `CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS` check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158210 Approved by: https://github.com/q10 --- CMakeLists.txt | 10 ++++----- cmake/BLAS_ABI.cmake | 1 + cmake/Dependencies.cmake | 46 ++++------------------------------------ cmake/MiscCheck.cmake | 40 ---------------------------------- cmake/Summary.cmake | 1 - third_party/fbgemm | 2 +- 6 files changed, 11 insertions(+), 89 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 16fec0c80028c..48b9e2e8df3eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -253,7 +253,6 @@ cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF) option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) option(USE_KINETO "Use Kineto profiling library" ON) option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) -option(USE_FAKELOWP "Use FakeLowp operators" OFF) option(USE_GFLAGS "Use GFLAGS" OFF) option(USE_GLOG "Use GLOG" OFF) option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF) @@ -836,10 +835,11 @@ include(ExternalProject) # ---[ Dependencies ---[ FBGEMM doesn't work on x86 32bit and # CMAKE_SYSTEM_PROCESSOR thinks its 64bit -if(USE_FBGEMM - AND((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VOID_P EQUAL - 4) - OR CMAKE_SYSTEM_PROCESSOR STREQUAL "x86")) +if(USE_FBGEMM AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + message(WARNING + "x64 operating system is required for FBGEMM. " + "Not compiling with FBGEMM. " + "Turn this warning off by USE_FBGEMM=OFF.") set(USE_FBGEMM OFF) endif() diff --git a/cmake/BLAS_ABI.cmake b/cmake/BLAS_ABI.cmake index bb0b5949d73d2..45a15af1027a3 100644 --- a/cmake/BLAS_ABI.cmake +++ b/cmake/BLAS_ABI.cmake @@ -1,3 +1,4 @@ +include(CMakePushCheckState) # Push host architecture when cross-compiling otherwise check would fail # when cross-compiling for arm64 on x86_64 cmake_push_check_state(RESET) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 8836b66bc0360..26d882f2f7f18 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -664,55 +664,20 @@ if(USE_FBGEMM) if(NOT DEFINED FBGEMM_SOURCE_DIR) set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory") endif() - if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) - message(WARNING - "A compiler with AVX512 support is required for FBGEMM. " - "Not compiling with FBGEMM. " - "Turn this warning off by USE_FBGEMM=OFF.") - set(USE_FBGEMM OFF) - endif() - if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - message(WARNING - "x64 operating system is required for FBGEMM. " - "Not compiling with FBGEMM. " - "Turn this warning off by USE_FBGEMM=OFF.") - set(USE_FBGEMM OFF) - endif() if(USE_FBGEMM AND NOT TARGET fbgemm) set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "") set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "") - if(MSVC AND BUILD_SHARED_LIBS) - set(FBGEMM_LIBRARY_TYPE "shared" CACHE STRING "") - else() - set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "") - endif() - if(USE_ASAN) - set(USE_SANITIZER "address,undefined" CACHE STRING "-fsanitize options for FBGEMM") - endif() + set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "") add_subdirectory("${FBGEMM_SOURCE_DIR}") - set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET fbgemm PROPERTY POSITION_INDEPENDENT_CODE ON) - - # Disabling autovec in fbgemm due to large library size causing symbol relocation issues, which is only allowed in static builds. - # Long-term solution involves modularizing fbgemm targets. - target_compile_definitions(fbgemm_generic PUBLIC DISABLE_FBGEMM_AUTOVEC) - target_compile_definitions(fbgemm_avx2 PUBLIC DISABLE_FBGEMM_AUTOVEC) - target_compile_definitions(fbgemm_avx512 PUBLIC DISABLE_FBGEMM_AUTOVEC) - - if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0.0) - # See https://github.com/pytorch/pytorch/issues/74352 - target_compile_options_if_supported(asmjit -Wno-deprecated-copy) - target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable) - endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") target_compile_options_if_supported(asmjit -Wno-extra-semi) target_compile_options_if_supported(fbgemm -Wno-extra-semi) endif() + target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable) + target_compile_options_if_supported(asmjit -Wno-unused-variable) endif() if(USE_FBGEMM) - target_compile_definitions(fbgemm PUBLIC DISABLE_FBGEMM_AUTOVEC) list(APPEND Caffe2_DEPENDENCY_LIBS fbgemm) endif() endif() @@ -721,9 +686,6 @@ if(USE_FBGEMM) caffe2_update_option(USE_FBGEMM ON) else() caffe2_update_option(USE_FBGEMM OFF) - message(WARNING - "Turning USE_FAKELOWP off as it depends on USE_FBGEMM.") - caffe2_update_option(USE_FAKELOWP OFF) endif() if(USE_OPENCL) diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 9efb0b46c59dd..54126b1f130dc 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -12,46 +12,6 @@ if(NOT INTERN_BUILD_MOBILE) set(CAFFE2_PERF_WITH_AVX2 1) endif() endif() -# ---[ Check if the compiler has AVX512 support. -cmake_push_check_state(RESET) -if(MSVC AND NOT CMAKE_CXX_COMPILER_ID STREQUAL "Clang") - # We could've used MSVC's hidden option /arch:AVX512 that defines __AVX512F__, - # __AVX512DQ__, and __AVX512VL__, and /arch:AVX512F that defines __AVX512F__. - # But, we chose not to do that not to rely on hidden options. - set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__ /D__AVX512DQ__ /D__AVX512VL__") -else() - # We only consider the case where all of avx512f, avx512dq, and avx512vl are - # supported. - # Platforms where avx512f is supported by not avx512dq and avx512vl as of - # Jan 15 2019 : linux_manywheel_2.7mu_cpu_build and - # linux_conda_3.7_cu100_build - set(CMAKE_REQUIRED_FLAGS "-mavx512f -mavx512dq -mavx512vl") -endif() -CHECK_CXX_SOURCE_COMPILES( - "#if defined(_MSC_VER) - #include - #else - #include - #endif - // check avx512f - __m512 addConstant(__m512 arg) { - return _mm512_add_ps(arg, _mm512_set1_ps(1.f)); - } - // check avx512dq - __m512 andConstant(__m512 arg) { - return _mm512_and_ps(arg, _mm512_set1_ps(1.f)); - } - int main() { - __m512i a = _mm512_set1_epi32(1); - __m256i ymm = _mm512_extracti64x4_epi64(a, 0); - ymm = _mm256_abs_epi64(ymm); // check avx512vl - __mmask16 m = _mm512_cmp_epi32_mask(a, a, _MM_CMPINT_EQ); - __m512i r = _mm512_andnot_si512(a, a); - }" CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) -if(CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) - message(STATUS "Current compiler supports avx512f extension. Will build fbgemm.") -endif() -cmake_pop_check_state() # ---[ Checks if compiler supports -fvisibility=hidden check_cxx_compiler_flag("-fvisibility=hidden" COMPILER_SUPPORTS_HIDDEN_VISIBILITY) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 24cfaa7f217d7..63e501bcb5aba 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -136,7 +136,6 @@ function(caffe2_print_configuration_summary) message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") message(STATUS " USE_FBGEMM : ${USE_FBGEMM}") - message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}") message(STATUS " USE_KINETO : ${USE_KINETO}") message(STATUS " USE_GFLAGS : ${USE_GFLAGS}") message(STATUS " USE_GLOG : ${USE_GLOG}") diff --git a/third_party/fbgemm b/third_party/fbgemm index 0adf628317e0c..21c7d30c526c0 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 0adf628317e0cea414f66dcca901e0b85280fdb1 +Subproject commit 21c7d30c526c0f1ad873ecc632dca6cfa8a69067 From 515cb70367e84fcbad23fcc5b39eb1d7706df2aa Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 11 Aug 2025 13:50:16 +0000 Subject: [PATCH 0215/1424] [inductor] normalize_path_separator for test_different_file_paths_local_pgo (#160286) `normalize_path_separator` for test_different_file_paths_local_pgo Pull Request resolved: https://github.com/pytorch/pytorch/pull/160286 Approved by: https://github.com/ezyang --- test/dynamo/test_pgo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py index 93e5274431bec..e9bef4a7714b5 100644 --- a/test/dynamo/test_pgo.py +++ b/test/dynamo/test_pgo.py @@ -12,6 +12,7 @@ import torch.compiler.config import torch.nested from torch._dynamo.testing import CompileCounter +from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.utils import clear_caches, fresh_cache @@ -322,8 +323,9 @@ def func(x): temp_dir1 = tempfile.TemporaryDirectory() temp_dir2 = tempfile.TemporaryDirectory() - path1 = os.path.join(temp_dir1.name, "example.py") - path2 = os.path.join(temp_dir2.name, "example.py") + # We need normalize_path_separator for Windows file path. + path1 = normalize_path_separator(os.path.join(temp_dir1.name, "example.py")) + path2 = normalize_path_separator(os.path.join(temp_dir2.name, "example.py")) cnts = CompileCounter() assert path1 != path2 From 80cca8307943ba64168208b54028f55b2c71daff Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 11 Aug 2025 13:50:40 +0000 Subject: [PATCH 0216/1424] [inductor] Skip some AOTI UTs on Windows. (#160287) Skip some AOTI UTs on Windows, it is not fully ready. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160287 Approved by: https://github.com/ezyang --- test/inductor/test_torchbind.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index 631a4fce31fdd..201590d02ed52 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -13,6 +13,7 @@ from torch._inductor.codecache import WritableTempFile from torch._inductor.package import package_aoti from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu from torch.testing._internal.torchbind_impls import ( _empty_tensor_queue, @@ -158,6 +159,7 @@ def test_torchbind_hop_schema_no_output(self): "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method, Tensor _1) -> NoneType _0", ) + @skipIfWindows(msg="AOTI is not fully support on Windows") def test_torchbind_aot_compile(self): ep, inputs, _, _ = self.get_exported_model() aoti_files = aot_compile( @@ -302,6 +304,7 @@ def test_torchbind_aoti(self): self.assertEqual(result, orig_res) @torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True) + @skipIfWindows(msg="AOTI is not fully support on Windows") def test_torchbind_aot_compile_constant_folding(self): ep, inputs, orig_res, _ = self.get_exported_model() pt2_path = torch._inductor.aoti_compile_and_package(ep) From 68a4b4b2e336cfd4451ce6546d900568e5ddf96c Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Mon, 11 Aug 2025 16:09:24 +0000 Subject: [PATCH 0217/1424] [codemod] Fix unreachable-break issue in caffe2/c10/cuda/CUDAFunctions.cpp +2 (#160257) Summary: LLVM has a warning `-Wunreachable-code-break` which identifies `break` statements that cannot be reached. These compromise readability, are misleading, and may identify bugs. This diff removes such statements. For questions/comments, contact r-barnes. - If you approve of this diff, please use the "Accept & Ship" button :-) Test Plan: Sandcastle Rollback Plan: Differential Revision: D79835614 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160257 Approved by: https://github.com/Skylion007 --- c10/cuda/CUDAFunctions.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 0e8cabf618593..683ed9b768455 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -53,13 +53,12 @@ int device_count_impl(bool fail_if_no_driver) { "https://pytorch.org to install a PyTorch version that has been " "compiled with your version of the CUDA driver."); } - } break; + } case cudaErrorInitializationError: TORCH_CHECK( false, "CUDA driver initialization failed, you might not " "have a CUDA gpu."); - break; case cudaErrorUnknown: TORCH_CHECK( false, @@ -67,7 +66,6 @@ int device_count_impl(bool fail_if_no_driver) { "incorrectly set up environment, e.g. changing env " "variable CUDA_VISIBLE_DEVICES after program start. " "Setting the available devices to be zero."); - break; #if C10_ASAN_ENABLED case cudaErrorMemoryAllocation: // In ASAN mode, we know that a cudaErrorMemoryAllocation error will From ca7315c17162ea21b1ca5ba23f4bf6168766c7b9 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 11 Aug 2025 16:25:12 +0000 Subject: [PATCH 0218/1424] [Graph Partition] Pass all OSS unit tests (#154667) Graph partition leads to 6.2% speedup on vision_maskrcnn, 5.8% speedup on yolov3. [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315). Run the same diff on two days and both show speedup on average. [first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d) image [second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf) image Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667 Approved by: https://github.com/eellison --- test/inductor/test_compiled_autograd.py | 22 +- test/inductor/test_control_flow.py | 3 + test/inductor/test_cuda_repro.py | 6 +- test/inductor/test_cudagraph_trees.py | 330 +++++++++++++++++++-- test/inductor/test_inductor_annotations.py | 7 +- test/inductor/test_torchinductor.py | 296 ------------------ torch/_inductor/codegen/wrapper.py | 10 +- torch/_inductor/config.py | 6 +- torch/_inductor/cudagraph_utils.py | 5 +- torch/_inductor/scheduler.py | 11 +- torch/_inductor/utils.py | 7 + 11 files changed, 378 insertions(+), 325 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 241528b159cc1..dff94b4aa0927 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -3085,7 +3085,16 @@ def backward(ctx, gO): self.assertEqual(counters["compiled_autograd"]["captures"], 1) # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + if inductor_config.graph_partition: + # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops + # and cudagraphify the remaining computation. So there is no cudagraph skip. + expected_cudagraph_skips = 0 + else: + expected_cudagraph_skips = 1 + + self.assertEqual( + counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips + ) @scoped_load_inline @requires_cuda_and_triton @@ -3150,9 +3159,18 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. # In the future, we can consider having a cpu scalar movement pass sometime after we trace # into the custom C++ autograd::Function (like in AOTDispatcher) + if inductor_config.graph_partition: + # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops + # and cudagraphify the remaining computation. So there is no cudagraph skip. + expected_cudagraph_skips = 0 + elif inductor_config.cpp_wrapper: + expected_cudagraph_skips = 2 + else: + expected_cudagraph_skips = 1 + self.assertEqual( counters["inductor"]["cudagraph_skips"], - 2 if inductor_config.cpp_wrapper else 1, + expected_cudagraph_skips, ) def test_logs(self): diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 107a65d6fa1df..511b9cea5e14d 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -472,6 +472,9 @@ def false_fn(x): @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @torch._inductor.config.patch(size_asserts=False) + # TODO: graph partition does not support creating tensor + # with dynamic shape in conditional subgraph yet + @torch._inductor.config.patch(graph_partition=False) def test_cond_unbacked_symint_inner(self, device): class Model(torch.nn.Module): def forward(self, p, a): diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 00511c572239e..53506698297f1 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -189,9 +189,9 @@ def f(q, k, v, mask): # padded bias should have an expanded dim FileCheck().check("buf0 =").check_same(", 0, ").run(code[0]) # single fused padded kernel - FileCheck().check("def call").check_count( - "empty_strided_cuda", 1, exactly=True - ).check("return").run(code[0]) + FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check( + "return" + ).run(code[0]) self.assertEqual(out, f(*inputs)) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 1408a0208cf06..763384671eb52 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -279,10 +279,14 @@ def foo(x, y): with capture_stderr() as captured_output: foo(torch.ones([10], device="cuda"), torch.ones([20])) - FileCheck().check( - "skipping cudagraphs due to cpu device (arg1_1). Found from" - ).check("y + 2").run(captured_output[0]) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + if torch._inductor.config.graph_partition: + # graph partition splits on cpu ops + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + else: + FileCheck().check( + "skipping cudagraphs due to cpu device (arg1_1). Found from" + ).check("y + 2").run(captured_output[0]) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) with capture_stderr() as captured_output: foo( @@ -292,7 +296,10 @@ def foo(x, y): FileCheck().check("skipping cudagraphs due to multiple devices").run( captured_output[0] ) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 1 if torch._inductor.config.graph_partition else 2, + ) @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True) def test_skip_symbolic(self): @@ -807,10 +814,16 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + if torch._inductor.config.graph_partition: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 0), (0, 2)], + ) + else: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) @@ -1127,8 +1140,13 @@ def foo2(x): node = self.curr_node() first_node = next(node._path_from_root) - self.assertFalse(first_node.unaliased_in_all_paths[0]) - self.assertTrue(first_node.cached_tensor_outputs[0] is None) + if torch._inductor.config.graph_partition: + # graph partition may changed the order of outputs + self.assertFalse(first_node.unaliased_in_all_paths[1]) + self.assertTrue(first_node.cached_tensor_outputs[1] is None) + else: + self.assertFalse(first_node.unaliased_in_all_paths[0]) + self.assertTrue(first_node.cached_tensor_outputs[0] is None) @torch._inductor.config.patch("implicit_fallbacks", True) def test_multinomial(self): @@ -1631,10 +1649,16 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + if torch._inductor.config.graph_partition: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 0), (0, 2)], + ) + else: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) def test_separate_recordings(self): @@ -2137,8 +2161,8 @@ def forward(self, x) -> torch.Tensor: with self.assertRaisesRegex( Exception, r"(?s)static input data pointer changed.\n" - r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*" - r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*," + r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*" + r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*," r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n", ): self.curr_node().run( @@ -3551,6 +3575,278 @@ def run(padded_size, original_size): self.assertEqual(self.get_manager().new_graph_id().id, 2) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_simple(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + self.assertEqual(eager_out, compiled_out) + + _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) + + if not config.cpp_wrapper: + FileCheck().check("def partition_0(args):").check( + "recursively_apply_fns = runner.recursively_apply_fns" + ).run(code[0]) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_foreach_op(self): + def fn(a0, a1): + c = torch._foreach_abs([a0, a1]) + return torch.mul(c[0], a0) + + compiled_fn = torch.compile(fn) + + a0 = torch.randn(2, 3, device="cuda") + a1 = torch.randn(2, 3, device="cuda") + eager_out = fn(a0, a1) + compiled_out = compiled_fn(a0, a1) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_condition_op(self): + def f(p, b): + def true_fn(x): + return torch.cos(x) + + def false_fn(x): + return torch.sin(x) + + return torch.cond(p, true_fn, false_fn, [b]) + + compiled_f = torch.compile(f) + + # static shape + p = torch.tensor([True], device="cuda") + a = torch.ones([2, 3], device="cuda") + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + # dynamic shape with backed symint + p = torch.tensor([True], device="cuda") + a = torch.ones([4, 5], device="cuda") + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_unbacked_symint_multi_output_layout(self): + def f(p, size_tensor): + size_val = size_tensor.item() + b = torch.ones([size_val, 3], device="cuda") + + def true_fn(x): + return torch.cos(x), torch.cos(x) + 1 + + def false_fn(x): + return torch.sin(x), torch.sin(x) + 1 + + cond_out = torch.cond(p, true_fn, false_fn, [b]) + return cond_out[0] + cond_out[1] + + compiled_f = torch.compile(f) + p = torch.tensor([True], device="cuda") + size_tensor = torch.tensor(2, device="cuda") + eager_out = f(p, size_tensor) + compiled_out = compiled_f(p, size_tensor) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + x, y = ( + torch.ones(4, 4, device="cuda"), + torch.randn(4, 4, device="cuda"), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_cat_backward(self): + def f(x, w): + y = torch.cat((x, x), dim=0) + z = y @ w + return z @ z.T + + compiled_f = torch.compile(f) + + for shape in (2, 3): + torch.manual_seed(42) + eager_x = torch.randn(shape, 2, device="cuda") + eager_w = torch.randn(2, 2, device="cuda", requires_grad=True) + torch.manual_seed(42) + compiled_x = torch.randn(shape, 2, device="cuda") + compiled_w = torch.randn(2, 2, device="cuda", requires_grad=True) + + f(eager_x, eager_w).sum().backward() + compiled_f(compiled_x, compiled_w).sum().backward() + self.assertEqual(eager_w.grad, compiled_w.grad) + + @dynamo_config.patch("capture_dynamic_output_shape_ops", True) + @config.patch(implicit_fallbacks=True) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_nested_indirect_indexing(self): + def nested(x, repeats): + rank = torch.arange(repeats.numel(), device=x.device) + index = rank.repeat_interleave(repeats, dim=0) + return torch.index_select(x, index=index, dim=0) + + example_inputs = ( + torch.randn((32, 64), device="cuda"), + repeats := torch.tensor([5, 10, 15], device="cuda"), + ) + torch._dynamo.mark_dynamic(repeats, 0) # create backed symint + + nested_opt = torch.compile(nested, backend="inductor") + + expect = nested(*example_inputs) + actual = nested_opt(*example_inputs) + self.assertEqual(expect, actual) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_mutation_index(self): + x = torch.zeros(7, device="cuda") + + def fn(n, a): + a[n] = -1 + return a + + opt_fn = torch.compile(fn, fullgraph=True) + + for n in range(2, x.shape[0]): + opt_fn(n, x) + self.assertEqual(x[n], -1) + + # Negative index triggers new compilation. + opt_fn(-x.shape[0], x) + + self.assertEqual(x[0], -1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_unbacked_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y) + eager_out = f(x, y) + self.assertEqual(compiled_out, eager_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_dynamic_scalar_inputs(self): + def f(x, y, integer): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + z += integer + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y, 5) + self.assertEqual(compiled_out, f(x, y, 5)) + + compiled_out = f_compiled(x, y, 6) + self.assertEqual(compiled_out, f(x, y, 6)) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_item(self): + def f(x): + y = x + 1 + scalar = y.item() + return x + y + scalar + + compiled_f = torch.compile(f) + compiled_out = compiled_f(torch.tensor(1, device="cuda")) + self.assertEqual(compiled_out, f(torch.tensor(1, device="cuda"))) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_buffer_reuse(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x1 + y1 + x @ y + u = (y_cpu.to("cuda") + 2) @ y + 3 + u_cpu = u.cpu() + 2 + return z + u_cpu.to("cuda") + + x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_fused_scheduler_node(self): + def foo(x): + x = x * 20 + x_alias = x[0] + y = x * 10 + y_alias = y[0] + torch._dynamo.graph_break() + ind = torch.tensor(4, device="cuda") + x_alias2 = x[ind:] + y_alias2 = y[ind:] + return x, x_alias, x_alias2, y_alias, y_alias2 + + compiled_foo = torch.compile(foo) + x = torch.rand([20, 20], device="cuda") + + eager_out = foo(x) + compiled_out = compiled_foo(x) + self.assertEqual(eager_out, compiled_out) + def test_meta_tensor(self): def foobar(x, y): return x * 2, y * 3 diff --git a/test/inductor/test_inductor_annotations.py b/test/inductor/test_inductor_annotations.py index bee7e0ad917da..3824b25cdeaea 100644 --- a/test/inductor/test_inductor_annotations.py +++ b/test/inductor/test_inductor_annotations.py @@ -31,10 +31,11 @@ def test_training_annotation(self): code = self.get_code() self.assertTrue("from torch.cuda import nvtx" in code) - self.assertEqual( - code.count("training_annotation = nvtx._device_range_start('inference')"), 1 + self.assertTrue( + code.count("training_annotation = nvtx._device_range_start('inference')") + >= 1 ) - self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1) + self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1) if __name__ == "__main__": diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index cdcedd5a1771e..385a75d98f944 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -15044,302 +15044,6 @@ def fn(x): "'XBLOCK': 'constexpr'" ).run(code[0]) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - self.assertEqual(eager_out, compiled_out) - - _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) - - if not config.cpp_wrapper: - FileCheck().check("def partition_0(args):").check( - "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)" - ).check("recursively_apply_fns = runner.recursively_apply_fns").run( - code[0] - ) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_foreach_op(self): - def fn(a0, a1): - c = torch._foreach_abs([a0, a1]) - return torch.mul(c[0], a0) - - compiled_fn = torch.compile(fn) - - a0 = torch.randn(2, 3, device=self.device) - a1 = torch.randn(2, 3, device=self.device) - eager_out = fn(a0, a1) - compiled_out = compiled_fn(a0, a1) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_multiple_functions(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - def g(x): - return x + 1 - - x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = g(f(x, y)) - - f_compiled = torch.compile(f) - g_compiled = torch.compile(g) - compiled_out = g_compiled(f_compiled(x_cloned, y_cloned)) - - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_condition_op(self): - def f(p, b): - def true_fn(x): - return torch.cos(x) - - def false_fn(x): - return torch.sin(x) - - return torch.cond(p, true_fn, false_fn, [b]) - - compiled_f = torch.compile(f) - - # static shape - p = torch.tensor([True], device=self.device) - a = torch.ones([2, 3], device=self.device) - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - # dynamic shape with backed symint - p = torch.tensor([True], device=self.device) - a = torch.ones([4, 5], device=self.device) - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_unbacked_symint_multi_output_layout(self): - def f(p, size_tensor): - size_val = size_tensor.item() - b = torch.ones([size_val, 3], device=GPU_TYPE) - - def true_fn(x): - return torch.cos(x), torch.cos(x) + 1 - - def false_fn(x): - return torch.sin(x), torch.sin(x) + 1 - - cond_out = torch.cond(p, true_fn, false_fn, [b]) - return cond_out[0] + cond_out[1] - - compiled_f = torch.compile(f) - p = torch.tensor([True], device=GPU_TYPE) - size_tensor = torch.tensor(2, device=GPU_TYPE) - eager_out = f(p, size_tensor) - compiled_out = compiled_f(p, size_tensor) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - x, y = ( - torch.ones(4, 4, device=self.device), - torch.randn(4, 4, device=self.device), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_cat_backward(self): - def f(x, w): - y = torch.cat((x, x), dim=0) - z = y @ w - return z @ z.T - - compiled_f = torch.compile(f) - - for shape in (2, 3): - torch.manual_seed(42) - eager_x = torch.randn(shape, 2, device=self.device) - eager_w = torch.randn(2, 2, device=self.device, requires_grad=True) - torch.manual_seed(42) - compiled_x = torch.randn(shape, 2, device=self.device) - compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True) - - f(eager_x, eager_w).sum().backward() - compiled_f(compiled_x, compiled_w).sum().backward() - self.assertEqual(eager_w.grad, compiled_w.grad) - - @dynamo_config.patch("capture_dynamic_output_shape_ops", True) - @config.patch(implicit_fallbacks=True) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_nested_indirect_indexing(self): - def nested(x, repeats): - rank = torch.arange(repeats.numel(), device=x.device) - index = rank.repeat_interleave(repeats, dim=0) - return torch.index_select(x, index=index, dim=0) - - example_inputs = ( - torch.randn((32, 64), device=self.device), - repeats := torch.tensor([5, 10, 15], device=self.device), - ) - torch._dynamo.mark_dynamic(repeats, 0) # create backed symint - - nested_opt = torch.compile(nested, backend="inductor") - - expect = nested(*example_inputs) - actual = nested_opt(*example_inputs) - self.assertEqual(expect, actual) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_mutation_index(self): - x = torch.zeros(7, device=GPU_TYPE) - - def fn(n, a): - a[n] = -1 - return a - - opt_fn = torch.compile(fn, fullgraph=True) - - for n in range(2, x.shape[0]): - opt_fn(n, x) - self.assertEqual(x[n], -1) - - # Negative index triggers new compilation. - opt_fn(-x.shape[0], x) - - self.assertEqual(x[0], -1) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_unbacked_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y) - eager_out = f(x, y) - self.assertEqual(compiled_out, eager_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_dynamic_scalar_inputs(self): - def f(x, y, integer): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - z += integer - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y, 5) - self.assertEqual(compiled_out, f(x, y, 5)) - - compiled_out = f_compiled(x, y, 6) - self.assertEqual(compiled_out, f(x, y, 6)) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_item(self): - def f(x): - y = x + 1 - scalar = y.item() - return x + y + scalar - - compiled_f = torch.compile(f) - compiled_out = f(torch.tensor(1, device=GPU_TYPE)) - self.assertEqual(compiled_out, f(torch.tensor(1, device=GPU_TYPE))) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_buffer_reuse(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x1 + y1 + x @ y - u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3 - u_cpu = u.cpu() + 2 - return z + u_cpu.to(GPU_TYPE) - - x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_fused_scheduler_node(self): - def foo(x): - x = x * 20 - x_alias = x[0] - y = x * 10 - y_alias = y[0] - torch._dynamo.graph_break() - ind = torch.tensor(4, device=GPU_TYPE) - x_alias2 = x[ind:] - y_alias2 = y[ind:] - return x, x_alias, x_alias2, y_alias, y_alias2 - - foo = torch.compile(foo) - x = torch.rand([20, 20], device=GPU_TYPE) - _, code = run_and_get_code(foo, x) - - if not config.cpp_wrapper: - FileCheck().check("def partition_0(args):").run(code[0]) - @unittest.skipIf(TEST_WITH_ROCM or not IS_SM90, "no scaled_grouped_mm support") def test_respect_scaled_grouped_mm_layout_tag(self): # scaled_grouped_mm needs `mat2` to be column-major diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 49f8549170b6b..a5ff9bd7b754b 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -50,6 +50,7 @@ get_benchmark_name, IndentedBuffer, is_codegen_graph_partition_subgraph, + is_using_cudagraph_partition, LineContext, sympy_product, sympy_str, @@ -1197,7 +1198,14 @@ def write_prefix(self) -> None: self.write_args(graph_input_names) self.codegen_inputs() - self.codegen_input_size_and_nan_asserts() + + # avoid duplicating asserts for both partition functions and + # the call function when using cudagraph partition + if not ( + is_using_cudagraph_partition() + and (not is_codegen_graph_partition_subgraph(self)) + ): + self.codegen_input_size_and_nan_asserts() def codegen_input_size_and_nan_asserts(self) -> None: if config.size_asserts: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 8d3b4cd7ed492..770da725a9aad 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -437,7 +437,11 @@ def prologue_fusion_enabled() -> bool: ) # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph -graph_partition = False +graph_partition: bool = ( + os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0") + == "1" +) + # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 2686d1d2ddde2..7826c797d36be 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -10,6 +10,8 @@ from torch._inductor.utils import GraphPartitionMap, InputType from torch.utils._ordered_set import OrderedSet +from .utils import is_using_cudagraph_partition + if TYPE_CHECKING: from collections.abc import Sequence @@ -170,7 +172,8 @@ def check_multiple_devices_or_any_cpu_nodes( # meta tensors are supported since there is no compute device_node_mapping.pop(torch.device("meta"), None) - if torch._inductor.config.graph_partition: + # dynamo cudagraph does not support graph partition + if is_using_cudagraph_partition(): # graph partition supports splitting on cpu op. So we can ignore cpu nodes. device_node_mapping.pop(torch.device("cpu"), None) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index e0a0309d1c811..d8a96c573b320 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2179,7 +2179,10 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) self.process_grouped_nodes() - if torch._inductor.config.graph_partition: + if ( + torch._inductor.config.graph_partition + and torch._inductor.config.triton.cudagraphs + ): self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) @@ -4312,6 +4315,12 @@ def should_partition( ) -> bool: """Return True if we should partition the inductor graph on this node""" + # When not using cudagraphs, keep all kernels in the `call` function + # instead of graph partition functions, since graph partition only brings + # benefit to cudagraph + if not torch._inductor.config.triton.cudagraphs: + return True + # avoid duplicating logs when should_partition is called multiple times # on the same node def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f21905e16e9d7..0418edb2a1154 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3329,6 +3329,13 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool: ) +def is_using_cudagraph_partition() -> bool: + return ( + torch._inductor.config.triton.cudagraphs + and torch._inductor.config.graph_partition + ) + + def dtype_from_size(size: int) -> torch.dtype: from .virtualized import V From 9ccd0f5e31ea54fcf42101dfbaacc103494e34df Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Mon, 11 Aug 2025 17:16:15 +0000 Subject: [PATCH 0219/1424] Fix unbacked symint and memory leak in inductor memory planning (#159839) Summary: In memory planning, some allocation sizes involve unbacked symints. These unbacked symints are not known before they are computed in run time, so **allocation pools that involve unbacked symints cannot be allocated until we have the values of the unbacked symints** . So we add a notion of `earliest_available` to Allocation nodes. If an allocation node has unbacked symint, it is available at only when its live range begin. Then in AllocationPool, if a pool involves an Allocation node that has an earliest available time, we restrict its life range. If a block's earliest available time is later than a pool's life range's start time, we cannot allocate it from the pool. We also fix a memory leak that's caused by allocating tensor without wrapping it with RAIIAtenTensor. In python wrapper for JIT inductor, `codegen_alloc_from_pool` doesn't actually write the alloc lines to wrapper, it just returns the string to alloc. However, in cpp_wrapper, `codegen_alloc_from_pool` actually write to the wrapper. Specifically, it writes the following and returns string `RAIIAtenTensorHandle`. ``` AtenTensorHandle handle_name; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(....); ``` This is bug prune. **If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle as well**, otherwise you get memory leaks. We remove the alloc_from_pool call from codegen_create, because this doesn't work for AOTI. In python wrapper, we can generate the same alloc_from_pool variable name for the same block, but cpp_wrapper will generate a different variable name for each call to alloc_from_pool. Test Plan: ``` python test/inductor/test_memory_planning.py ``` Rollback Plan: Differential Revision: D79603119 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159839 Approved by: https://github.com/jansel --- test/inductor/test_memory_planning.py | 63 +++++++++++++++++++--- torch/_inductor/codegen/cpp_wrapper_cpu.py | 17 +++--- torch/_inductor/codegen/memory_planning.py | 51 ++++++++++++++++-- torch/_inductor/codegen/wrapper.py | 6 ++- 4 files changed, 117 insertions(+), 20 deletions(-) diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index d5f90e662697d..1bcdeaa08e955 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -24,6 +24,14 @@ from torch.export import Dim +try: + from .test_aot_inductor import AOTIRunnerUtil +except ImportError: + from test_aot_inductor import ( # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library + AOTIRunnerUtil, + ) + + @requires_gpu() @config.patch(memory_planning=True) class TestMemoryPlanning(TestCase): @@ -76,13 +84,6 @@ def test_cpp_wrapper(self): @skipIfXpu(msg="aoti doesn't work on XPU") def test_aoti(self): - try: - from .test_aot_inductor import AOTIRunnerUtil - except ImportError: - from test_aot_inductor import ( # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library - AOTIRunnerUtil, - ) - f, args = self._generate(device=GPU_TYPE) dim0_x = Dim("dim0_x", min=1, max=2048) dynamic_shapes = ({0: dim0_x}, None, None) @@ -103,6 +104,54 @@ def test_aoti(self): ).check_next("aoti_torch__alloc_from_pool(pool1, 0").run(code) self.assertTrue(same(f(*args), result)) + @config.patch({"triton.autotune_at_compile_time": False}) + def test_unbacked_symint(self): + # when allocation's size has unbacked symints + # the unbacked symints are only available after computed + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Repro(torch.nn.Module): + def forward(self, x, y): + x = x + 1 + u0 = x.item() + torch._check(u0 >= 1) + s0 = y.size(0) + expr = u0 * s0 + sevens = torch.empty_strided( + size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device + ).fill_(7) + return sevens * 3 + + example_inputs = ( + torch.scalar_tensor(2, dtype=torch.int, device=self.device), + torch.ones(8, device=self.device), + ) + model = Repro().to(self.device) + result, code = run_and_get_cpp_code( + lambda: AOTIRunnerUtil.run(model, example_inputs) + ) + self.assertTrue(same(model(*example_inputs), result)) + + # check allocation is done after the unbacked symint is computed + FileCheck().check("auto u0 = u0_raw;").check( + "const int64_t int_array_2[] = {10L, 8L*u0, 32L};" + ).check("AtenTensorHandle pool0_handle;").check( + "aoti_torch_empty_strided(3, int_array_2, int_array_3" + ).run(code) + + # all AtenTensorHandle allocated using aoti_torch__alloc_from_pool are wrapped with RAIIAtenTensorHandle + # otherwise we'll have memory leak + FileCheck().check_count( + "aoti_torch__alloc_from_pool(pool1", 1, exactly=True + ).check_count("aoti_torch__alloc_from_pool(pool0", 1, exactly=True).run(code) + + FileCheck().check( + "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_int32, 0, int_array_1, int_array_1, &tmp_tensor_handle_0));" # noqa: B950 + ).check("RAIIAtenTensorHandle(tmp_tensor_handle_0);").check( + "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool0, 0, cached_torch_dtype_float32, 3, int_array_4, int_array_5, &tmp_tensor_handle_1));" # noqa: B950 + ).check("RAIIAtenTensorHandle(tmp_tensor_handle_1);").run(code) + if __name__ == "__main__": if HAS_GPU: diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 0edeabccebbd8..794a971adf08e 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1651,7 +1651,9 @@ def make_allocation( return f"RAIIAtenTensorHandle {name}({handle_name});" - def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + def codegen_alloc_from_pool( + self, name, offset, dtype, shape, stride + ) -> tuple[str, list[str]]: size = self.codegen_shape_tuple(shape) stride = self.codegen_shape_tuple(stride) tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" @@ -1668,11 +1670,14 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: ), f"&{tmp_name}", ] - self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") - self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" - ) - return f"RAIIAtenTensorHandle({tmp_name})" + # We return the lines instead of writing here because writing here is bug prune. + # If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle + # as well, otherwise you get memory leaks + allocations_to_write = [ + f"AtenTensorHandle {tmp_name};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));", + ] + return f"RAIIAtenTensorHandle({tmp_name})", allocations_to_write def codegen_reinterpret_view( self, diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 8efec7eeca9f8..12d7500975e5b 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -10,6 +10,7 @@ import sympy import torch +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet from .. import config @@ -142,6 +143,17 @@ class Allocation(AllocationTreeNode): allocated: bool = False pool: Optional[AllocationPool] = None offset: Optional[sympy.Expr] = None + earliest_available: Optional[float] = None + + def __post_init__(self) -> None: + has_unbacked_sym = False + for s in self.node.get_layout().size: + if free_unbacked_symbols(s): + has_unbacked_sym = True + break + + if has_unbacked_sym: + self.earliest_available = self.get_live_ranges().begin @property def device(self): @@ -186,6 +198,9 @@ def __repr__(self): f"offset={self.offset})" ) + def get_earliest_available(self): + return self.earliest_available + @dataclasses.dataclass class Empty(AllocationTreeNode): @@ -377,14 +392,26 @@ class AllocationPool: names_to_del: list[str] = dataclasses.field(default_factory=list) creation_cache: dict[str, str] = dataclasses.field(default_factory=dict) + def __post_init__(self) -> None: + for block in self.root.allocations: + if isinstance(block, Allocation): + self.update_restrict_live_range(block) + def allocate(self, block: Allocation, is_last: bool): - if self.restrict_live_range and not self.restrict_live_range.contains( - block.live_range + if ( + self.restrict_live_range is not None + and not self.restrict_live_range.contains(block.live_range) ): return False + block_earliest_available = block.get_earliest_available() + pool_begin = self.root.get_live_ranges().begin + if block_earliest_available and block_earliest_available > pool_begin: + return False + is_last = self.can_expand and is_last if self.root.allocate(block, is_last): + self.update_restrict_live_range(block) return True if is_last: @@ -392,9 +419,22 @@ def allocate(self, block: Allocation, is_last: bool): return False + def update_restrict_live_range(self, block: Allocation): + if block_earliest_available := block.get_earliest_available(): + if self.restrict_live_range is None: + self.restrict_live_range = LiveRange( + block_earliest_available, float("inf") + ) + else: + self.restrict_live_range = LiveRange( + min(self.restrict_live_range.begin, block_earliest_available), + self.restrict_live_range.end, + ) + def allocate_at_end(self, block): block.mark_allocated() self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + self.update_restrict_live_range(block) return True def finalize(self, name): @@ -408,7 +448,6 @@ def codegen_create(self, wrapper, code: IndentedBuffer): nbytes = self.root.get_symbolic_size() for block in self.root.allocations: if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): - # optimization: fuse first allocation and pool creation node = block.node code.writeline( wrapper.make_allocation( @@ -419,7 +458,6 @@ def codegen_create(self, wrapper, code: IndentedBuffer): stride=tuple(node.get_stride()), ) ) - self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name return else: code.writeline( @@ -577,7 +615,10 @@ def codegen(self, code: IndentedBuffer): pool.codegen_create(self.wrapper, code) pool.names_to_del.extend(self.group.names) - alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) + alloc_from_pool, allocation_lines_to_write = allocation.codegen_alloc_from_pool( + self.wrapper + ) + code.writelines(allocation_lines_to_write) if alloc_from_pool in pool.creation_cache: code.writeline( self.wrapper.make_tensor_alias( diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index a5ff9bd7b754b..9394c0e4a16d6 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1765,7 +1765,9 @@ def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str: def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: return self.codegen_python_shape_tuple(shape) - def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + def codegen_alloc_from_pool( + self, name, offset, dtype, shape, stride + ) -> tuple[str, list[str]]: return "alloc_from_pool({})".format( ", ".join( [ @@ -1776,7 +1778,7 @@ def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: self.codegen_python_shape_tuple(stride), ] ) - ) + ), [] def codegen_reinterpret_view( self, From d0e2240f680ea2a553f7ee8188f52482e130bfd0 Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 11 Aug 2025 17:22:40 +0000 Subject: [PATCH 0220/1424] [triton_heuristics] Optimize the triton launcher in pt2 (#160000) Summary: (Original author: Xu Zhao. Commandeered by David to land this since it is relatively urgent) We observed ~10us PT2-Triton launch overhead regression after pin update. Before Triton pin-update: {F1980557238} After Triton pin-update: {F1980557240} The root cause is because https://github.com/pytorch/pytorch/pull/145051 adds `_get_args_with_constexprs` to the cubin launcher caller function, which is on the critical path. The motivation for `_get_args_with_constexprs` was that between triton 3.2 and triton 3.3, the convention for calling Triton kernels (at the level that non-static-cuda-launcher inductor integrates) changed. Previously, the callable did not take constexpr arguments as parameters; after 3.3, it does. With pointwise/reduction kernels, we don't know the constexpr values until after autotuning occurs; so `_get_args_with_constexprs` would inject constexprs into the arguments list before calling the Triton kernel. The fix (in this PR) is to instead inject the constexpr args into the launcher string - this avoids the cost of sorting/reordering arguments which previously occurred upon execution of each kernel. Note that the static_cuda_launcher.py does not require constants to be passed to the cubin launcher (https://github.com/pytorch/pytorch/blob/e96c7c4bb0f6aeae2ab3b6f040f7d67edbec199a/torch/_inductor/runtime/static_cuda_launcher.py#L220), there is no need to pass in constexprs to the generated launcher code. The new launcher code needs to work on three cases: - StaticallyLaunchedCudaKernel - triton.compile.CompiledKernel - AOTInductor Analysis: https://docs.google.com/document/d/1PHaSmx2w59K8qpjw5_qzKWShfEgptf_Zpv_DL7YxiWU/edit?tab=t.0 Test Plan: Before: ``` $ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs 1.893x ``` ``` $ buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency x_val nop_python_function-walltime nop_triton_kernel-walltime nop_triton_compiled_kernel_run-walltime nop_inductor_kernel-walltime nop_inductor_kernel_cudagraph-walltime ------- ------------------------------ ---------------------------- ----------------------------------------- ------------------------------ ---------------------------------------- 0 0.00760921 1.80298 0.623282 5.25024 0.203722 19 0.00799885 4.78223 1.00226 5.8213 0.239084 average 0.00780403 3.29261 0.812769 5.53577 0.221403 ``` After: ``` buck2 run mode/opt //pytorch/tritonbench:run -- --op launch_latency x_val nop_python_function-walltime nop_triton_kernel-walltime nop_triton_compiled_kernel_run-walltime nop_inductor_kernel-walltime nop_inductor_kernel_cudagraph-walltime ------- ------------------------------ ---------------------------- ----------------------------------------- ------------------------------ ---------------------------------------- 0 0.00747067 1.92589 0.726509 4.35459 0.204205 19 0.00747823 7.36852 1.26241 6.28208 0.239278 average 0.00747445 4.6472 0.994459 5.31834 0.221741 ``` ``` $ buck2 run mode/opt //pytorch/benchmark:pt2 -- --only BERT_pytorch --performance --backend=inductor --training --amp --disable-cudagraphs 1.985x ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160000 Approved by: https://github.com/jansel Co-authored-by: Xu Zhao --- torch/_inductor/ir.py | 3 + torch/_inductor/runtime/triton_heuristics.py | 65 +++++++++----------- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a668cd41ebf1b..47167b180f52e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6630,6 +6630,9 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: for name, arg in itertools.chain( named_args.items(), zip(itertools.repeat(""), extra_launch_args) ): + if name in constexpr_names and triton_version_uses_attrs_dict(): + # see #160000 - we don't pass in constexpr args to speed up runtime. + continue raw_keys_filtered.append(name) raw_args_filtered.append(arg) if isinstance(arg, IRNode): diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 8425cba55795a..47516a4a71c47 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -196,8 +196,7 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): call_kwargs[k] = v else: call_kwargs[k] = v - if not triton_version_uses_attrs_dict(): - call_kwargs.update(launcher.config.kwargs) + call_kwargs.update(launcher.config.kwargs) call_kwargs["num_warps"] = launcher.config.num_warps call_kwargs["num_stages"] = launcher.config.num_stages if HAS_WARP_SPEC: @@ -770,28 +769,6 @@ def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta) - def _get_args_with_constexprs(self, args, launcher): - """ - `args` is passed in with only the non-constexpr args (because the constexpr arg values - depend on the config). However, in later triton versions, the constexpr args need to be - added into the args list. - """ - if triton_version_uses_attrs_dict(): - # first: aggregate the constexpr args in (index, val) pairs - # so we can sort them by index. - constexpr_args: list[tuple[int, Any]] = [] - for arg_name, arg_val in launcher.config.kwargs.items(): - if arg_name in self.fn.arg_names: - constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) - - constexpr_args.sort() - new_args = [*args] - for arg_idx, arg_val in constexpr_args: - new_args.insert(arg_idx, arg_val) - - return new_args - return args - def bench(self, launcher, *args, with_profiler=False, **kwargs): """Measure the performance of a given launcher""" # we don't skip configs with spilled registers when auto-tuning custom @@ -820,23 +797,22 @@ def kernel_call(): ) # reset to zero before evaluating any config self.reset_to_zero_args(*args, **kwargs) - args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher) if autograd_profiler._is_profiler_enabled: profiler_kwargs = self.get_profiler_kwargs(stream, launcher) with torch._C._profiler._RecordFunctionFast( self.inductor_meta.get("kernel_name", "triton kernel"), - args_with_constexprs, + cloned_args, profiler_kwargs, ): launcher( - *args_with_constexprs, + *cloned_args, **cloned_kwargs, stream=stream, ) else: launcher( - *args_with_constexprs, + *cloned_args, **cloned_kwargs, stream=stream, ) @@ -1240,7 +1216,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): # so _RecordFunctionFast need to capture the args into CachingAutotuner::run() # make a copy here to avoid mutating the original args args_without_constexprs = tuple(args) - args = self._get_args_with_constexprs(args, launcher) if self.dump_launch_params: new_args, grid = self._interpret_args_grid(args, launcher.config) @@ -1296,6 +1271,10 @@ def __call__(self, _=None) -> str: class CompileResult(Generic[_T]): + """ + Base class representing compiled result. + """ + def __init__( self, kernel: _T, @@ -1359,21 +1338,30 @@ def _get_arg_lists( ) none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys())) + def _convert_constant(constant): + if isinstance(constant, str): + return "r'" + constant + "'" + else: + return repr(constant) + if triton_version_uses_attrs_dict(): call_args = arg_names def_args = arg_names - if ( - "num_warps" in compile_meta["constants"] - or "num_stages" in compile_meta["constants"] + implicit_constants = OrderedSet( + ( + "num_warps", + "num_stages", + ) + ).union(OrderedSet(k for k in known_constants)) + if implicit_constants := implicit_constants & OrderedSet( + compile_meta["constants"].keys() ): # num_warps/num_stages are special implicit args that are not in the signature # see test_triton_kernel_special_params - def_args = [ - arg for arg in def_args if arg not in ("num_warps", "num_stages") - ] + def_args = [arg for arg in def_args if arg not in implicit_constants] repl = { - k: str(compile_meta["constants"].get(k)) - for k in ("num_warps", "num_stages") + k: _convert_constant(compile_meta["constants"].get(k)) + for k in implicit_constants } call_args = [repl.get(arg, arg) for arg in call_args] else: @@ -1653,6 +1641,8 @@ def make_launcher(self) -> LauncherType: import math as math_lib + import triton as triton_lib + import torch as torch_lib scope = { @@ -1687,6 +1677,7 @@ def make_launcher(self) -> LauncherType: "runner": get_first_attr(binary, "run", "c_wrapper"), "math": math_lib, "torch": torch_lib, + "triton": triton_lib, } if not hasattr(binary, "launch_metadata"): From d25c4f954d599ea512e2f70cd6df101c21479d4c Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 11 Aug 2025 09:57:30 -0700 Subject: [PATCH 0221/1424] [MPS] Type-promote tensor-iterator common dtype (#160334) Otherwise, `torch.add(FloatTensor, IntTensor, alpha=2)` and `torch.add(FloatTensor, IntTensor, alpha=2)` were dispatched to different kernels Fixes https://github.com/pytorch/pytorch/issues/160208 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160334 Approved by: https://github.com/Skylion007, https://github.com/dcci --- aten/src/ATen/native/mps/operations/BinaryKernel.mm | 1 + test/test_mps.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 806eeb82e1d17..b2a1b2757b13a 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -53,6 +53,7 @@ void binary_op_kernel(const std::string func_name, .add_input(input) .add_input(other) .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(true) .build(); lib.exec_binary_kernel(iter, func_name, alpha); diff --git a/test/test_mps.py b/test/test_mps.py index 6c55cb775f063..bff55eec95ae1 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7736,6 +7736,8 @@ def helper(shape, alpha, op_name, inplace): y = torch.arange(32, device='mps', dtype=torch.int32) self.assertEqual(torch.add(x, y, alpha=2).cpu(), torch.add(x.cpu(), y.cpu(), alpha=2)) self.assertEqual(torch.add(x, 3, alpha=2).cpu(), torch.add(x.cpu(), 3, alpha=2)) + # Regression test for https://github.com/pytorch/pytorch/issues/160208 + self.assertEqual(torch.add(y, x, alpha=2).cpu(), torch.add(y.cpu(), x.cpu(), alpha=2)) # Test add def test_add_scalars(self): From c8205cb35435f39d2c26f6c94b45e4adeb6dcb23 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Sat, 9 Aug 2025 12:02:47 -0700 Subject: [PATCH 0222/1424] [autograd] match 0-dim gradients device type regardless of subclassness (#160165) Not sure if there some subclasses where the outer.dim() == 0 but you wouldn't want to move it? FIXES https://github.com/pytorch/pytorch/issues/160084 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160165 Approved by: https://github.com/ezyang, https://github.com/albanD --- test/dynamo/test_repros.py | 25 +++++++++++++++++++ test/test_autograd.py | 23 ++++++++++++++++++ test/test_python_dispatch.py | 44 ---------------------------------- torch/csrc/autograd/engine.cpp | 14 +++++------ 4 files changed, 55 insertions(+), 51 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 1da35106d54c8..fe16e4906ef39 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7673,6 +7673,31 @@ def forward(self, x): out2 = torch.compile(model, backend="eager")(input.clone()) self.assertEqual(out1, out2) + @requires_cuda + def test_zero_dim_param_mixed_device_grad(self): + # cpu 0-dim params with cuda grads + # https://github.com/pytorch/pytorch/issues/160084 + class RegressionModel(torch.nn.Module): + def __init__(self, a=0, b=0): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(a).float()) + self.b = torch.nn.Parameter(torch.tensor(b).float()) + + def forward(self, x): + return x * self.a + self.b + + model = RegressionModel() + model.forward = torch.compile( + model.forward, backend="aot_eager", fullgraph=True + ) + inputs = torch.randn(4, 10).to("cuda") + out = model(inputs) + out.sum().backward() + self.assertIsNotNone(model.a.grad) + self.assertIsNotNone(model.b.grad) + self.assertEqual(model.a.grad.device, torch.device("cpu")) + self.assertEqual(model.b.grad.device, torch.device("cpu")) + def test_filter_warnings(self): x = torch.ones(2, 2, requires_grad=True) diff --git a/test/test_autograd.py b/test/test_autograd.py index e26e193cc799a..01a2c54dc2774 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -12396,6 +12396,29 @@ def test_resize_version_bump(self, device): x.resize_as_(y) self.assertEqual(x._version, 2) + @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") + def test_zero_dim_param_mixed_device_grad(self, device): + # cpu 0-dim params with an accelerator device grad + # https://github.com/pytorch/pytorch/issues/160084 + class RegressionModel(torch.nn.Module): + def __init__(self, a=0, b=0): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(a).float()) + self.b = torch.nn.Parameter(torch.tensor(b).float()) + + def forward(self, x): + return x * self.a + self.b + + # Keep the model on cpu as we do want to test the mixed cpu/accelerator behavior here + model = RegressionModel() + inputs = torch.randn(4, 10, device=device) + out = model(inputs) + out.sum().backward() + self.assertIsNotNone(model.a.grad) + self.assertIsNotNone(model.b.grad) + self.assertEqual(model.a.grad.device, torch.device("cpu")) + self.assertEqual(model.b.grad.device, torch.device("cpu")) + class TestAllowMutationOnSaved(TestCase): def assertClonedLenEqual(self, ctx, n): diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 71ebf5d784308..9faa5ce4b8946 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -1,7 +1,6 @@ # Owner(s): ["module: __torch_dispatch__"] # ruff: noqa: F841 -import logging import pickle import sys import tempfile @@ -1718,49 +1717,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): self.assertEqual(s.device_index, 2) self.assertEqual(s.device_type, 3) - def test_subclass_autograd_device_check(self) -> None: - class NonWrapperSubclass(torch.Tensor): - elem: torch.Tensor - - __slots__ = ["elem"] - - @staticmethod - def __new__(cls, elem, *args, **kwargs): - # Wrong device here! - r = torch.Tensor._make_subclass( - cls, elem.to("meta"), elem.requires_grad - ) - # ...the real tensor is held as an element on the tensor. - r.elem = elem - return r - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(e): - return e.elem if isinstance(e, NonWrapperSubclass) else e - - def wrap(e): - return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e - - rs = tree_map( - wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) - ) - logging.getLogger("NonWrapperSubclass").info( - f"{func.__module__}.{func.__name__}", # noqa: G004 - args, - kwargs, - rs, - ) - return rs - - x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True)) - y = torch.randn(2, requires_grad=True) - z = x * y - self.assertIsInstance(z, NonWrapperSubclass) - z.sum().backward(torch.tensor(1)) - self.assertEqual(x.grad, y) - self.assertEqual(y.grad, x) - def test_none_wrapping(self): # A Tensor subclass that returns None when doing add # See LoggingTensor above for more details on the subclass diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 4e8cb2efca0e1..f0024f8f0b070 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -979,13 +979,13 @@ static void validate_outputs_impl( } if (grad.device() != metadata.device()) { - // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but - // should be eventually removed - if (!(metadata.is_tensor_subclass() || - grad.unsafeGetTensorImpl()->is_python_dispatch())) { - if (grad.dim() == 0) { - grad = grad.to(metadata.device()); - } else { + if (grad.dim() == 0) { + grad = grad.to(metadata.device()); + } else { + // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but + // should be eventually removed + if (!(metadata.is_tensor_subclass() || + grad.unsafeGetTensorImpl()->is_python_dispatch())) { std::stringstream ss; ss << "invalid gradient at index " << i << " - expected device "; ss << metadata.device() << " but got " << grad.device(); From 76a0609b6bddb2bc40f1eb4ade12885023653d59 Mon Sep 17 00:00:00 2001 From: "Liao, Wei" Date: Mon, 11 Aug 2025 19:43:11 +0000 Subject: [PATCH 0223/1424] port distributed pipeline test files for Intel GPU (#159033) In this PR we will port all distributed pipeline test files. We could enable Intel GPU with following methods and try the best to keep the original code styles: 1. instantiate_device_type_tests() 2. use "torch.accelerator.current_accelerator()" to determine the accelerator backend 3. use "requires_accelerator_dist_backend()" to replace requires_nccl() 4. use "get_default_backend_for_device()" to get backend 5. enabled XPU for some test path 6. add TEST_MULTIACCELERATOR in common_utils for all backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159033 Approved by: https://github.com/guangyey, https://github.com/d4l3k Co-authored-by: Daisy Deng --- test/distributed/pipelining/test_schedule.py | 10 +-- .../pipelining/test_schedule_multiproc.py | 89 ++++++++++++------- test/distributed/pipelining/test_stage.py | 51 ++++++----- .../pipelining/test_transformer.py | 4 +- test/distributed/pipelining/test_unflatten.py | 4 +- torch/testing/_internal/common_utils.py | 1 + 6 files changed, 102 insertions(+), 57 deletions(-) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index b1ad9b757a89b..6f5b4df82a4ad 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -38,7 +38,7 @@ W, ) from torch.distributed.pipelining.stage import _PipelineStageBase, PipelineStage -from torch.testing._internal.common_distributed import requires_nccl +from torch.testing._internal.common_distributed import requires_accelerator_dist_backend from torch.testing._internal.common_utils import ( check_leaked_tensors, instantiate_parametrized_tests, @@ -51,6 +51,8 @@ ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "artifacts") +device = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + logger = logging.getLogger(__name__) torch.manual_seed(0) @@ -657,7 +659,7 @@ def _dump_csv(pipeline_order_with_comms, filename: str): # print(_format_pipeline_order(simulated_schedule)) self.assertEqual(num_steps, 113) - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) def test_grad_with_v_schedule(self): """ We have a special case for V schedules where 2 adjacent stages are on the same rank. @@ -677,7 +679,6 @@ def test_grad_with_v_schedule(self): d_hid = 512 batch_size = 256 n_stages = 2 - device = "cuda" full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) @@ -776,7 +777,7 @@ def test_grad_with_v_schedule(self): torch.distributed.destroy_process_group() - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) def test_grad_with_split_b_w(self): """ Ensure that separate dInput and dWeight computations are correctly executed. @@ -789,7 +790,6 @@ def test_grad_with_split_b_w(self): d_hid = 512 batch_size = 256 n_stages = 1 - device = "cuda" full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index ae91911bc6a02..a87d924541513 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -26,10 +26,9 @@ ScheduleZBVZeroBubble, ) from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime -from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, - requires_nccl, + requires_accelerator_dist_backend, ) from torch.testing._internal.common_utils import ( check_leaked_tensors, @@ -37,6 +36,7 @@ parametrize, run_tests, skip_but_pass_in_sandcastle_if, + TEST_MULTIACCELERATOR, ) @@ -45,7 +45,8 @@ d_hid = 512 batch_size = 64 torch.manual_seed(0) -device_type = "cuda" +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +backend = dist.get_default_backend_for_device(device_type) class ScheduleTest(MultiProcContinousTest): @@ -53,8 +54,7 @@ class ScheduleTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: - # Testing with NCCL backend - return "nccl" + return backend @property def device(self) -> torch.device: @@ -180,8 +180,10 @@ def _zero_gradients(self, stage_modules): for stage_module in stage_modules: stage_module.zero_grad() - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [_ScheduleForwardOnly]) def test_forward_only(self, ScheduleClass): mod, mod_ref, x, _, _ = self._setup_models_and_data() @@ -210,8 +212,10 @@ def test_forward_only(self, ScheduleClass): x_clone = mod_ref(x_clone) torch.testing.assert_close(x_clone, out) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize( "ScheduleClass", [ @@ -283,8 +287,10 @@ def test_eval_inference_mode(self, ScheduleClass): if self.rank == self.world_size - 1: self.assertTrue(len(losses) > 0, "Losses should be computed during eval()") - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_multi_iter(self, ScheduleClass): mod, _, x, target, loss_fn = self._setup_models_and_data() @@ -302,8 +308,10 @@ def test_multi_iter(self, ScheduleClass): else: schedule.step() - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_kwargs_with_tracer(self, ScheduleClass): # Model has two stages only, thus limiting group size to 2 @@ -359,8 +367,10 @@ def test_kwargs_with_tracer(self, ScheduleClass): torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3) torch.testing.assert_close(pipe_loss, ref_loss) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_grad_with_tracer(self, ScheduleClass): mod, ref_mod, x, target, loss_fn = self._setup_models_and_data() @@ -398,8 +408,10 @@ def test_grad_with_tracer(self, ScheduleClass): # Check gradients using helper method self._check_gradients(stage_module, ref_mod) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @parametrize("shape_inference", [True, False]) def test_grad_with_manual(self, ScheduleClass, shape_inference): @@ -453,8 +465,10 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): # Check gradients using helper method self._check_gradients(stage_module, ref_mod) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize( "ScheduleClass", [ @@ -563,8 +577,10 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): stage_modules, ref_mod, submod_names, rtol=5e-3, atol=5e-3 ) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble]) def test_schedule_with_native_zero_bubble(self, ScheduleClass): print(ScheduleClass) @@ -621,9 +637,16 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - @parametrize("ScheduleClass", [ScheduleWithReorderedB]) + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) + @parametrize( + "ScheduleClass", + [ + ScheduleWithReorderedB, + ], + ) def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): n_stages = 2 stages_per_rank = 1 @@ -679,8 +702,10 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize( "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble] ) @@ -740,8 +765,10 @@ def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 @@ -820,8 +847,10 @@ def dw_runner(): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize( "ScheduleClass", [ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B], diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index a711cec64d72a..acb5bec7d84ee 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -14,11 +14,10 @@ ScheduleGPipe, ) from torch.distributed.pipelining._utils import PipeliningShapeError -from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, MultiProcessTestCase, - requires_nccl, + requires_accelerator_dist_backend, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -26,6 +25,7 @@ run_tests, skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, + TEST_MULTIACCELERATOR, ) from torch.utils._pytree import tree_map_only @@ -34,8 +34,8 @@ batch_size = 256 chunks = 4 -device_type = "cuda" - +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +backend = dist.get_default_backend_for_device(device_type) torch.manual_seed(0) @@ -66,8 +66,7 @@ def f(x): class StageTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: - # Testing with NCCL backend - return "nccl" + return backend @classmethod def device_type(cls) -> str: @@ -77,8 +76,10 @@ def device_type(cls) -> str: def device(self) -> torch.device: return torch.device(device_type, self.rank) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ModelClass", [ExampleCode, MultiMLP]) def test_tracer(self, ModelClass): mod = ModelClass(d_hid, self.world_size) @@ -121,8 +122,10 @@ def _run_step(x): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) @parametrize("ModelClass", [ModelWithKwargs]) def test_tracer_kwargs(self, ModelClass): mod = ModelClass(d_hid, self.world_size) @@ -170,8 +173,10 @@ def test_tracer_kwargs(self, ModelClass): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_manual(self): full_mod = MultiMLP(d_hid, n_layers=self.world_size) full_mod.to(self.device) @@ -202,8 +207,10 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_custom_dw_with_fb_schedule(self): """Tests that separate weight grad function 'dw_runner' gets run under a schedule that's only aware of F/B.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -262,8 +269,10 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_output_chunks_memory_usage(self): """Test that output_chunks doesn't store memory for non-first stages.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -347,14 +356,14 @@ def tearDown(self): def init_pg(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( - backend="nccl", + backend=backend, store=store, rank=self.rank, world_size=self.world_size, device_id=self.device, ) - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle("Flaky in CI") def test_shape_prop_mismatch(self): """Tests shape prop errors are raised""" @@ -402,8 +411,10 @@ def _run_step(x): with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): _run_step(x) - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @requires_accelerator_dist_backend(["nccl", "xccl"]) + @skip_but_pass_in_sandcastle_if( + not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" + ) def test_custom_dw_errors(self): """Tests expected errors are raised""" self.init_pg() diff --git a/test/distributed/pipelining/test_transformer.py b/test/distributed/pipelining/test_transformer.py index 7e58129186a69..20e830547de7b 100644 --- a/test/distributed/pipelining/test_transformer.py +++ b/test/distributed/pipelining/test_transformer.py @@ -73,7 +73,9 @@ def get_layers(module): devices = ["cpu", "cuda", "hpu", "xpu"] -instantiate_device_type_tests(TransformerTests, globals(), only_for=devices) +instantiate_device_type_tests( + TransformerTests, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index ae1e684d7c222..0493f39b16cb8 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -73,7 +73,9 @@ def test_unflatten(self, device): devices = ["cpu", "cuda", "hpu", "xpu"] -instantiate_device_type_tests(UnflattenTests, globals(), only_for=devices) +instantiate_device_type_tests( + UnflattenTests, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": run_tests() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index bfc568bc14645..f3c0648b46254 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1422,6 +1422,7 @@ def is_privateuse1_backend_available(): TEST_XPU = torch.xpu.is_available() TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False TEST_CUDA = torch.cuda.is_available() +TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2 custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) TEST_PRIVATEUSE1 = is_privateuse1_backend_available() TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name() From c3dc8dc4122977893004c49d10e4676cd0a97da4 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Sun, 10 Aug 2025 14:37:12 -0400 Subject: [PATCH 0224/1424] 159965 is merged, no need to patch it in (#160275) Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/160275 Approved by: https://github.com/albanD, https://github.com/ZainRizvi --- codex_setup.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/codex_setup.sh b/codex_setup.sh index f169a7b1f6936..85c7b93e89794 100755 --- a/codex_setup.sh +++ b/codex_setup.sh @@ -9,10 +9,6 @@ COMMIT=$(grep -oE '[0-9a-f]{40}' <<< "$NIGHTLY_PATCH" | head -1) COMMIT_DATE=$(echo "$NIGHTLY_PATCH" | grep '^Date:' | sed -E 's/Date: .*, ([0-9]+) ([A-Za-z]+) ([0-9]+) .*/\3 \2 \1/' | awk 'BEGIN{split("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec", months, " "); for(i=1;i<=12;i++) month[months[i]]=sprintf("%02d",i)} {print $1 month[$2] sprintf("%02d",$3)}') VERSION_STRING="2.9.0.dev${COMMIT_DATE}+cpu" git rev-parse HEAD > /tmp/orig_work.txt -cp AGENTS.md /tmp git reset --hard $COMMIT -cp /tmp/AGENTS.md . -curl https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/159965.diff | patch -p1 USE_NIGHTLY=$VERSION_STRING python setup.py develop -git commit -asm "Agents patch" echo "source $PWD/.venv/bin/activate" >> ~/.bashrc From 9eedd2a20b64302d0d116ea2802b50948d2ebb09 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 11 Aug 2025 20:13:22 +0000 Subject: [PATCH 0225/1424] [PGO] no counterfactual suggestions for dynamic allowlist (#160231) Being more conservative with whitelist suggestions as we roll out suggestions; now we only suggest sources that were dynamic in previous runs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160231 Approved by: https://github.com/bobrenjc93 --- test/dynamo/test_pgo.py | 20 +++++++++++++------- torch/_dynamo/variables/builder.py | 1 - 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py index e9bef4a7714b5..643d15eb2413d 100644 --- a/test/dynamo/test_pgo.py +++ b/test/dynamo/test_pgo.py @@ -56,6 +56,10 @@ def f(x): f(torch.randn(2, 6)) self.assertEqual(cnts.frame_count, 1) + @torch._dynamo.config.patch( + force_parameter_static_shapes=False, + force_nn_module_property_static_shapes=False, + ) def test_whitelist_suggestion(self): cnts = CompileCounter() @@ -195,14 +199,16 @@ def run(): self.assertEqual(cnts.frame_count, 3) # parameter static shapes are forced static, so we recompile once - run() - self.assertEqual(cnts.frame_count, 2) + with torch._dynamo.config.patch( + force_parameter_static_shapes=False, + force_nn_module_property_static_shapes=False, + ): + run() + self.assertEqual(cnts.frame_count, 2) - # flags are flipped, PGO records dynamism, so params are dynamically compiled to start - torch._dynamo.config.force_parameter_static_shapes = False - torch._dynamo.config.force_nn_module_property_static_shapes = False - run() - self.assertEqual(cnts.frame_count, 1) + # because flags were flipped, params were included in PGO + run() + self.assertEqual(cnts.frame_count, 1) def test_njt(self): cnts = CompileCounter() diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 481773860f8d5..d4aac8041452c 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3247,7 +3247,6 @@ def _automatic_dynamic( ) if static_shapes and not is_dynamic_source(name): - record_automatic_dynamic(tx, name, e) return StatefulSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * e.dim(), dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), From 09381f5dacda7bbbfa361f5df76bde5cd309adc1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 11 Aug 2025 20:34:27 +0000 Subject: [PATCH 0226/1424] Revert "[Graph Partition] Pass all OSS unit tests (#154667)" This reverts commit ca7315c17162ea21b1ca5ba23f4bf6168766c7b9. Reverted https://github.com/pytorch/pytorch/pull/154667 on behalf of https://github.com/clee2000 due to broke inductor/test_memory.py::TestOperatorReorderForPeakMemory::test_reorder_peak_memory_lpmf [GH job link](https://github.com/pytorch/pytorch/actions/runs/16885961204/job/47836769279) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/ca7315c17162ea21b1ca5ba23f4bf6168766c7b9) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/154667#issuecomment-3176805477)) --- test/inductor/test_compiled_autograd.py | 22 +- test/inductor/test_control_flow.py | 3 - test/inductor/test_cuda_repro.py | 6 +- test/inductor/test_cudagraph_trees.py | 330 ++------------------- test/inductor/test_inductor_annotations.py | 7 +- test/inductor/test_torchinductor.py | 296 ++++++++++++++++++ torch/_inductor/codegen/wrapper.py | 10 +- torch/_inductor/config.py | 6 +- torch/_inductor/cudagraph_utils.py | 5 +- torch/_inductor/scheduler.py | 11 +- torch/_inductor/utils.py | 7 - 11 files changed, 325 insertions(+), 378 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index dff94b4aa0927..241528b159cc1 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -3085,16 +3085,7 @@ def backward(ctx, gO): self.assertEqual(counters["compiled_autograd"]["captures"], 1) # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. - if inductor_config.graph_partition: - # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops - # and cudagraphify the remaining computation. So there is no cudagraph skip. - expected_cudagraph_skips = 0 - else: - expected_cudagraph_skips = 1 - - self.assertEqual( - counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips - ) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) @scoped_load_inline @requires_cuda_and_triton @@ -3159,18 +3150,9 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. # In the future, we can consider having a cpu scalar movement pass sometime after we trace # into the custom C++ autograd::Function (like in AOTDispatcher) - if inductor_config.graph_partition: - # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops - # and cudagraphify the remaining computation. So there is no cudagraph skip. - expected_cudagraph_skips = 0 - elif inductor_config.cpp_wrapper: - expected_cudagraph_skips = 2 - else: - expected_cudagraph_skips = 1 - self.assertEqual( counters["inductor"]["cudagraph_skips"], - expected_cudagraph_skips, + 2 if inductor_config.cpp_wrapper else 1, ) def test_logs(self): diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 511b9cea5e14d..107a65d6fa1df 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -472,9 +472,6 @@ def false_fn(x): @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @torch._inductor.config.patch(size_asserts=False) - # TODO: graph partition does not support creating tensor - # with dynamic shape in conditional subgraph yet - @torch._inductor.config.patch(graph_partition=False) def test_cond_unbacked_symint_inner(self, device): class Model(torch.nn.Module): def forward(self, p, a): diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 53506698297f1..00511c572239e 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -189,9 +189,9 @@ def f(q, k, v, mask): # padded bias should have an expanded dim FileCheck().check("buf0 =").check_same(", 0, ").run(code[0]) # single fused padded kernel - FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check( - "return" - ).run(code[0]) + FileCheck().check("def call").check_count( + "empty_strided_cuda", 1, exactly=True + ).check("return").run(code[0]) self.assertEqual(out, f(*inputs)) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 763384671eb52..1408a0208cf06 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -279,14 +279,10 @@ def foo(x, y): with capture_stderr() as captured_output: foo(torch.ones([10], device="cuda"), torch.ones([20])) - if torch._inductor.config.graph_partition: - # graph partition splits on cpu ops - self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) - else: - FileCheck().check( - "skipping cudagraphs due to cpu device (arg1_1). Found from" - ).check("y + 2").run(captured_output[0]) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + FileCheck().check( + "skipping cudagraphs due to cpu device (arg1_1). Found from" + ).check("y + 2").run(captured_output[0]) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) with capture_stderr() as captured_output: foo( @@ -296,10 +292,7 @@ def foo(x, y): FileCheck().check("skipping cudagraphs due to multiple devices").run( captured_output[0] ) - self.assertEqual( - counters["inductor"]["cudagraph_skips"], - 1 if torch._inductor.config.graph_partition else 2, - ) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True) def test_skip_symbolic(self): @@ -814,16 +807,10 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - if torch._inductor.config.graph_partition: - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 0), (0, 2)], - ) - else: - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) @@ -1140,13 +1127,8 @@ def foo2(x): node = self.curr_node() first_node = next(node._path_from_root) - if torch._inductor.config.graph_partition: - # graph partition may changed the order of outputs - self.assertFalse(first_node.unaliased_in_all_paths[1]) - self.assertTrue(first_node.cached_tensor_outputs[1] is None) - else: - self.assertFalse(first_node.unaliased_in_all_paths[0]) - self.assertTrue(first_node.cached_tensor_outputs[0] is None) + self.assertFalse(first_node.unaliased_in_all_paths[0]) + self.assertTrue(first_node.cached_tensor_outputs[0] is None) @torch._inductor.config.patch("implicit_fallbacks", True) def test_multinomial(self): @@ -1649,16 +1631,10 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - if torch._inductor.config.graph_partition: - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 0), (0, 2)], - ) - else: - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) def test_separate_recordings(self): @@ -2161,8 +2137,8 @@ def forward(self, x) -> torch.Tensor: with self.assertRaisesRegex( Exception, r"(?s)static input data pointer changed.\n" - r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*" - r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*," + r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*" + r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*," r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n", ): self.curr_node().run( @@ -3575,278 +3551,6 @@ def run(padded_size, original_size): self.assertEqual(self.get_manager().new_graph_id().id, 2) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_simple(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to("cuda") - - x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - self.assertEqual(eager_out, compiled_out) - - _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) - - if not config.cpp_wrapper: - FileCheck().check("def partition_0(args):").check( - "recursively_apply_fns = runner.recursively_apply_fns" - ).run(code[0]) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_foreach_op(self): - def fn(a0, a1): - c = torch._foreach_abs([a0, a1]) - return torch.mul(c[0], a0) - - compiled_fn = torch.compile(fn) - - a0 = torch.randn(2, 3, device="cuda") - a1 = torch.randn(2, 3, device="cuda") - eager_out = fn(a0, a1) - compiled_out = compiled_fn(a0, a1) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_condition_op(self): - def f(p, b): - def true_fn(x): - return torch.cos(x) - - def false_fn(x): - return torch.sin(x) - - return torch.cond(p, true_fn, false_fn, [b]) - - compiled_f = torch.compile(f) - - # static shape - p = torch.tensor([True], device="cuda") - a = torch.ones([2, 3], device="cuda") - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - # dynamic shape with backed symint - p = torch.tensor([True], device="cuda") - a = torch.ones([4, 5], device="cuda") - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_unbacked_symint_multi_output_layout(self): - def f(p, size_tensor): - size_val = size_tensor.item() - b = torch.ones([size_val, 3], device="cuda") - - def true_fn(x): - return torch.cos(x), torch.cos(x) + 1 - - def false_fn(x): - return torch.sin(x), torch.sin(x) + 1 - - cond_out = torch.cond(p, true_fn, false_fn, [b]) - return cond_out[0] + cond_out[1] - - compiled_f = torch.compile(f) - p = torch.tensor([True], device="cuda") - size_tensor = torch.tensor(2, device="cuda") - eager_out = f(p, size_tensor) - compiled_out = compiled_f(p, size_tensor) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to("cuda") - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device="cuda"), - torch.randn(3, 3, device="cuda"), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - x, y = ( - torch.ones(4, 4, device="cuda"), - torch.randn(4, 4, device="cuda"), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_cat_backward(self): - def f(x, w): - y = torch.cat((x, x), dim=0) - z = y @ w - return z @ z.T - - compiled_f = torch.compile(f) - - for shape in (2, 3): - torch.manual_seed(42) - eager_x = torch.randn(shape, 2, device="cuda") - eager_w = torch.randn(2, 2, device="cuda", requires_grad=True) - torch.manual_seed(42) - compiled_x = torch.randn(shape, 2, device="cuda") - compiled_w = torch.randn(2, 2, device="cuda", requires_grad=True) - - f(eager_x, eager_w).sum().backward() - compiled_f(compiled_x, compiled_w).sum().backward() - self.assertEqual(eager_w.grad, compiled_w.grad) - - @dynamo_config.patch("capture_dynamic_output_shape_ops", True) - @config.patch(implicit_fallbacks=True) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_nested_indirect_indexing(self): - def nested(x, repeats): - rank = torch.arange(repeats.numel(), device=x.device) - index = rank.repeat_interleave(repeats, dim=0) - return torch.index_select(x, index=index, dim=0) - - example_inputs = ( - torch.randn((32, 64), device="cuda"), - repeats := torch.tensor([5, 10, 15], device="cuda"), - ) - torch._dynamo.mark_dynamic(repeats, 0) # create backed symint - - nested_opt = torch.compile(nested, backend="inductor") - - expect = nested(*example_inputs) - actual = nested_opt(*example_inputs) - self.assertEqual(expect, actual) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_mutation_index(self): - x = torch.zeros(7, device="cuda") - - def fn(n, a): - a[n] = -1 - return a - - opt_fn = torch.compile(fn, fullgraph=True) - - for n in range(2, x.shape[0]): - opt_fn(n, x) - self.assertEqual(x[n], -1) - - # Negative index triggers new compilation. - opt_fn(-x.shape[0], x) - - self.assertEqual(x[0], -1) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_unbacked_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to("cuda") - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device="cuda"), - torch.randn(3, 3, device="cuda"), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y) - eager_out = f(x, y) - self.assertEqual(compiled_out, eager_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_dynamic_scalar_inputs(self): - def f(x, y, integer): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - z += integer - return x1 + y1 + z + y_cpu.to("cuda") - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device="cuda"), - torch.randn(3, 3, device="cuda"), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y, 5) - self.assertEqual(compiled_out, f(x, y, 5)) - - compiled_out = f_compiled(x, y, 6) - self.assertEqual(compiled_out, f(x, y, 6)) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_item(self): - def f(x): - y = x + 1 - scalar = y.item() - return x + y + scalar - - compiled_f = torch.compile(f) - compiled_out = compiled_f(torch.tensor(1, device="cuda")) - self.assertEqual(compiled_out, f(torch.tensor(1, device="cuda"))) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_buffer_reuse(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x1 + y1 + x @ y - u = (y_cpu.to("cuda") + 2) @ y + 3 - u_cpu = u.cpu() + 2 - return z + u_cpu.to("cuda") - - x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_fused_scheduler_node(self): - def foo(x): - x = x * 20 - x_alias = x[0] - y = x * 10 - y_alias = y[0] - torch._dynamo.graph_break() - ind = torch.tensor(4, device="cuda") - x_alias2 = x[ind:] - y_alias2 = y[ind:] - return x, x_alias, x_alias2, y_alias, y_alias2 - - compiled_foo = torch.compile(foo) - x = torch.rand([20, 20], device="cuda") - - eager_out = foo(x) - compiled_out = compiled_foo(x) - self.assertEqual(eager_out, compiled_out) - def test_meta_tensor(self): def foobar(x, y): return x * 2, y * 3 diff --git a/test/inductor/test_inductor_annotations.py b/test/inductor/test_inductor_annotations.py index 3824b25cdeaea..bee7e0ad917da 100644 --- a/test/inductor/test_inductor_annotations.py +++ b/test/inductor/test_inductor_annotations.py @@ -31,11 +31,10 @@ def test_training_annotation(self): code = self.get_code() self.assertTrue("from torch.cuda import nvtx" in code) - self.assertTrue( - code.count("training_annotation = nvtx._device_range_start('inference')") - >= 1 + self.assertEqual( + code.count("training_annotation = nvtx._device_range_start('inference')"), 1 ) - self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1) + self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1) if __name__ == "__main__": diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 385a75d98f944..cdcedd5a1771e 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -15044,6 +15044,302 @@ def fn(x): "'XBLOCK': 'constexpr'" ).run(code[0]) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + self.assertEqual(eager_out, compiled_out) + + _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) + + if not config.cpp_wrapper: + FileCheck().check("def partition_0(args):").check( + "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)" + ).check("recursively_apply_fns = runner.recursively_apply_fns").run( + code[0] + ) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_foreach_op(self): + def fn(a0, a1): + c = torch._foreach_abs([a0, a1]) + return torch.mul(c[0], a0) + + compiled_fn = torch.compile(fn) + + a0 = torch.randn(2, 3, device=self.device) + a1 = torch.randn(2, 3, device=self.device) + eager_out = fn(a0, a1) + compiled_out = compiled_fn(a0, a1) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_multiple_functions(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + def g(x): + return x + 1 + + x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = g(f(x, y)) + + f_compiled = torch.compile(f) + g_compiled = torch.compile(g) + compiled_out = g_compiled(f_compiled(x_cloned, y_cloned)) + + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_condition_op(self): + def f(p, b): + def true_fn(x): + return torch.cos(x) + + def false_fn(x): + return torch.sin(x) + + return torch.cond(p, true_fn, false_fn, [b]) + + compiled_f = torch.compile(f) + + # static shape + p = torch.tensor([True], device=self.device) + a = torch.ones([2, 3], device=self.device) + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + # dynamic shape with backed symint + p = torch.tensor([True], device=self.device) + a = torch.ones([4, 5], device=self.device) + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_unbacked_symint_multi_output_layout(self): + def f(p, size_tensor): + size_val = size_tensor.item() + b = torch.ones([size_val, 3], device=GPU_TYPE) + + def true_fn(x): + return torch.cos(x), torch.cos(x) + 1 + + def false_fn(x): + return torch.sin(x), torch.sin(x) + 1 + + cond_out = torch.cond(p, true_fn, false_fn, [b]) + return cond_out[0] + cond_out[1] + + compiled_f = torch.compile(f) + p = torch.tensor([True], device=GPU_TYPE) + size_tensor = torch.tensor(2, device=GPU_TYPE) + eager_out = f(p, size_tensor) + compiled_out = compiled_f(p, size_tensor) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + x, y = ( + torch.ones(4, 4, device=self.device), + torch.randn(4, 4, device=self.device), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_cat_backward(self): + def f(x, w): + y = torch.cat((x, x), dim=0) + z = y @ w + return z @ z.T + + compiled_f = torch.compile(f) + + for shape in (2, 3): + torch.manual_seed(42) + eager_x = torch.randn(shape, 2, device=self.device) + eager_w = torch.randn(2, 2, device=self.device, requires_grad=True) + torch.manual_seed(42) + compiled_x = torch.randn(shape, 2, device=self.device) + compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True) + + f(eager_x, eager_w).sum().backward() + compiled_f(compiled_x, compiled_w).sum().backward() + self.assertEqual(eager_w.grad, compiled_w.grad) + + @dynamo_config.patch("capture_dynamic_output_shape_ops", True) + @config.patch(implicit_fallbacks=True) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_nested_indirect_indexing(self): + def nested(x, repeats): + rank = torch.arange(repeats.numel(), device=x.device) + index = rank.repeat_interleave(repeats, dim=0) + return torch.index_select(x, index=index, dim=0) + + example_inputs = ( + torch.randn((32, 64), device=self.device), + repeats := torch.tensor([5, 10, 15], device=self.device), + ) + torch._dynamo.mark_dynamic(repeats, 0) # create backed symint + + nested_opt = torch.compile(nested, backend="inductor") + + expect = nested(*example_inputs) + actual = nested_opt(*example_inputs) + self.assertEqual(expect, actual) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_mutation_index(self): + x = torch.zeros(7, device=GPU_TYPE) + + def fn(n, a): + a[n] = -1 + return a + + opt_fn = torch.compile(fn, fullgraph=True) + + for n in range(2, x.shape[0]): + opt_fn(n, x) + self.assertEqual(x[n], -1) + + # Negative index triggers new compilation. + opt_fn(-x.shape[0], x) + + self.assertEqual(x[0], -1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_unbacked_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y) + eager_out = f(x, y) + self.assertEqual(compiled_out, eager_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_dynamic_scalar_inputs(self): + def f(x, y, integer): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + z += integer + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y, 5) + self.assertEqual(compiled_out, f(x, y, 5)) + + compiled_out = f_compiled(x, y, 6) + self.assertEqual(compiled_out, f(x, y, 6)) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_item(self): + def f(x): + y = x + 1 + scalar = y.item() + return x + y + scalar + + compiled_f = torch.compile(f) + compiled_out = f(torch.tensor(1, device=GPU_TYPE)) + self.assertEqual(compiled_out, f(torch.tensor(1, device=GPU_TYPE))) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_buffer_reuse(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x1 + y1 + x @ y + u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3 + u_cpu = u.cpu() + 2 + return z + u_cpu.to(GPU_TYPE) + + x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_fused_scheduler_node(self): + def foo(x): + x = x * 20 + x_alias = x[0] + y = x * 10 + y_alias = y[0] + torch._dynamo.graph_break() + ind = torch.tensor(4, device=GPU_TYPE) + x_alias2 = x[ind:] + y_alias2 = y[ind:] + return x, x_alias, x_alias2, y_alias, y_alias2 + + foo = torch.compile(foo) + x = torch.rand([20, 20], device=GPU_TYPE) + _, code = run_and_get_code(foo, x) + + if not config.cpp_wrapper: + FileCheck().check("def partition_0(args):").run(code[0]) + @unittest.skipIf(TEST_WITH_ROCM or not IS_SM90, "no scaled_grouped_mm support") def test_respect_scaled_grouped_mm_layout_tag(self): # scaled_grouped_mm needs `mat2` to be column-major diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 9394c0e4a16d6..8ac01ae791f72 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -50,7 +50,6 @@ get_benchmark_name, IndentedBuffer, is_codegen_graph_partition_subgraph, - is_using_cudagraph_partition, LineContext, sympy_product, sympy_str, @@ -1198,14 +1197,7 @@ def write_prefix(self) -> None: self.write_args(graph_input_names) self.codegen_inputs() - - # avoid duplicating asserts for both partition functions and - # the call function when using cudagraph partition - if not ( - is_using_cudagraph_partition() - and (not is_codegen_graph_partition_subgraph(self)) - ): - self.codegen_input_size_and_nan_asserts() + self.codegen_input_size_and_nan_asserts() def codegen_input_size_and_nan_asserts(self) -> None: if config.size_asserts: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 770da725a9aad..8d3b4cd7ed492 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -437,11 +437,7 @@ def prologue_fusion_enabled() -> bool: ) # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph -graph_partition: bool = ( - os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0") - == "1" -) - +graph_partition = False # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 7826c797d36be..2686d1d2ddde2 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -10,8 +10,6 @@ from torch._inductor.utils import GraphPartitionMap, InputType from torch.utils._ordered_set import OrderedSet -from .utils import is_using_cudagraph_partition - if TYPE_CHECKING: from collections.abc import Sequence @@ -172,8 +170,7 @@ def check_multiple_devices_or_any_cpu_nodes( # meta tensors are supported since there is no compute device_node_mapping.pop(torch.device("meta"), None) - # dynamo cudagraph does not support graph partition - if is_using_cudagraph_partition(): + if torch._inductor.config.graph_partition: # graph partition supports splitting on cpu op. So we can ignore cpu nodes. device_node_mapping.pop(torch.device("cpu"), None) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index d8a96c573b320..e0a0309d1c811 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2179,10 +2179,7 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) self.process_grouped_nodes() - if ( - torch._inductor.config.graph_partition - and torch._inductor.config.triton.cudagraphs - ): + if torch._inductor.config.graph_partition: self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) @@ -4315,12 +4312,6 @@ def should_partition( ) -> bool: """Return True if we should partition the inductor graph on this node""" - # When not using cudagraphs, keep all kernels in the `call` function - # instead of graph partition functions, since graph partition only brings - # benefit to cudagraph - if not torch._inductor.config.triton.cudagraphs: - return True - # avoid duplicating logs when should_partition is called multiple times # on the same node def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0418edb2a1154..f21905e16e9d7 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3329,13 +3329,6 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool: ) -def is_using_cudagraph_partition() -> bool: - return ( - torch._inductor.config.triton.cudagraphs - and torch._inductor.config.graph_partition - ) - - def dtype_from_size(size: int) -> torch.dtype: from .virtualized import V From b149c7204c218e7c4d6594a89dd74f72bd480ec5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 11 Aug 2025 20:44:45 +0000 Subject: [PATCH 0227/1424] Revert "port distributed pipeline test files for Intel GPU (#159033)" This reverts commit 76a0609b6bddb2bc40f1eb4ade12885023653d59. Reverted https://github.com/pytorch/pytorch/pull/159033 on behalf of https://github.com/clee2000 due to broke test_cpp_extensions_stream_and_event.py::TestCppExtensionStreamAndEvent::test_stream_event [GH job link](https://github.com/pytorch/pytorch/actions/runs/16890370216/job/47849586456) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/76a0609b6bddb2bc40f1eb4ade12885023653d59) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/159033#issuecomment-3176833314)) --- test/distributed/pipelining/test_schedule.py | 10 +-- .../pipelining/test_schedule_multiproc.py | 89 +++++++------------ test/distributed/pipelining/test_stage.py | 51 +++++------ .../pipelining/test_transformer.py | 4 +- test/distributed/pipelining/test_unflatten.py | 4 +- torch/testing/_internal/common_utils.py | 1 - 6 files changed, 57 insertions(+), 102 deletions(-) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 6f5b4df82a4ad..b1ad9b757a89b 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -38,7 +38,7 @@ W, ) from torch.distributed.pipelining.stage import _PipelineStageBase, PipelineStage -from torch.testing._internal.common_distributed import requires_accelerator_dist_backend +from torch.testing._internal.common_distributed import requires_nccl from torch.testing._internal.common_utils import ( check_leaked_tensors, instantiate_parametrized_tests, @@ -51,8 +51,6 @@ ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "artifacts") -device = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" - logger = logging.getLogger(__name__) torch.manual_seed(0) @@ -659,7 +657,7 @@ def _dump_csv(pipeline_order_with_comms, filename: str): # print(_format_pipeline_order(simulated_schedule)) self.assertEqual(num_steps, 113) - @requires_accelerator_dist_backend(["nccl", "xccl"]) + @requires_nccl() def test_grad_with_v_schedule(self): """ We have a special case for V schedules where 2 adjacent stages are on the same rank. @@ -679,6 +677,7 @@ def test_grad_with_v_schedule(self): d_hid = 512 batch_size = 256 n_stages = 2 + device = "cuda" full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) @@ -777,7 +776,7 @@ def test_grad_with_v_schedule(self): torch.distributed.destroy_process_group() - @requires_accelerator_dist_backend(["nccl", "xccl"]) + @requires_nccl() def test_grad_with_split_b_w(self): """ Ensure that separate dInput and dWeight computations are correctly executed. @@ -790,6 +789,7 @@ def test_grad_with_split_b_w(self): d_hid = 512 batch_size = 256 n_stages = 1 + device = "cuda" full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index a87d924541513..ae91911bc6a02 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -26,9 +26,10 @@ ScheduleZBVZeroBubble, ) from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime +from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, - requires_accelerator_dist_backend, + requires_nccl, ) from torch.testing._internal.common_utils import ( check_leaked_tensors, @@ -36,7 +37,6 @@ parametrize, run_tests, skip_but_pass_in_sandcastle_if, - TEST_MULTIACCELERATOR, ) @@ -45,8 +45,7 @@ d_hid = 512 batch_size = 64 torch.manual_seed(0) -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" -backend = dist.get_default_backend_for_device(device_type) +device_type = "cuda" class ScheduleTest(MultiProcContinousTest): @@ -54,7 +53,8 @@ class ScheduleTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: - return backend + # Testing with NCCL backend + return "nccl" @property def device(self) -> torch.device: @@ -180,10 +180,8 @@ def _zero_gradients(self, stage_modules): for stage_module in stage_modules: stage_module.zero_grad() - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [_ScheduleForwardOnly]) def test_forward_only(self, ScheduleClass): mod, mod_ref, x, _, _ = self._setup_models_and_data() @@ -212,10 +210,8 @@ def test_forward_only(self, ScheduleClass): x_clone = mod_ref(x_clone) torch.testing.assert_close(x_clone, out) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize( "ScheduleClass", [ @@ -287,10 +283,8 @@ def test_eval_inference_mode(self, ScheduleClass): if self.rank == self.world_size - 1: self.assertTrue(len(losses) > 0, "Losses should be computed during eval()") - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_multi_iter(self, ScheduleClass): mod, _, x, target, loss_fn = self._setup_models_and_data() @@ -308,10 +302,8 @@ def test_multi_iter(self, ScheduleClass): else: schedule.step() - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_kwargs_with_tracer(self, ScheduleClass): # Model has two stages only, thus limiting group size to 2 @@ -367,10 +359,8 @@ def test_kwargs_with_tracer(self, ScheduleClass): torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3) torch.testing.assert_close(pipe_loss, ref_loss) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_grad_with_tracer(self, ScheduleClass): mod, ref_mod, x, target, loss_fn = self._setup_models_and_data() @@ -408,10 +398,8 @@ def test_grad_with_tracer(self, ScheduleClass): # Check gradients using helper method self._check_gradients(stage_module, ref_mod) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @parametrize("shape_inference", [True, False]) def test_grad_with_manual(self, ScheduleClass, shape_inference): @@ -465,10 +453,8 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): # Check gradients using helper method self._check_gradients(stage_module, ref_mod) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize( "ScheduleClass", [ @@ -577,10 +563,8 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): stage_modules, ref_mod, submod_names, rtol=5e-3, atol=5e-3 ) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble]) def test_schedule_with_native_zero_bubble(self, ScheduleClass): print(ScheduleClass) @@ -637,16 +621,9 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) - @parametrize( - "ScheduleClass", - [ - ScheduleWithReorderedB, - ], - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleWithReorderedB]) def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): n_stages = 2 stages_per_rank = 1 @@ -702,10 +679,8 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize( "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble] ) @@ -765,10 +740,8 @@ def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 @@ -847,10 +820,8 @@ def dw_runner(): # Check gradients using helper method self._check_gradients(stage_modules, ref_mod, submod_names) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize( "ScheduleClass", [ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B], diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index acb5bec7d84ee..a711cec64d72a 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -14,10 +14,11 @@ ScheduleGPipe, ) from torch.distributed.pipelining._utils import PipeliningShapeError +from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, MultiProcessTestCase, - requires_accelerator_dist_backend, + requires_nccl, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -25,7 +26,6 @@ run_tests, skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, - TEST_MULTIACCELERATOR, ) from torch.utils._pytree import tree_map_only @@ -34,8 +34,8 @@ batch_size = 256 chunks = 4 -device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" -backend = dist.get_default_backend_for_device(device_type) +device_type = "cuda" + torch.manual_seed(0) @@ -66,7 +66,8 @@ def f(x): class StageTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: - return backend + # Testing with NCCL backend + return "nccl" @classmethod def device_type(cls) -> str: @@ -76,10 +77,8 @@ def device_type(cls) -> str: def device(self) -> torch.device: return torch.device(device_type, self.rank) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ModelClass", [ExampleCode, MultiMLP]) def test_tracer(self, ModelClass): mod = ModelClass(d_hid, self.world_size) @@ -122,10 +121,8 @@ def _run_step(x): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ModelClass", [ModelWithKwargs]) def test_tracer_kwargs(self, ModelClass): mod = ModelClass(d_hid, self.world_size) @@ -173,10 +170,8 @@ def test_tracer_kwargs(self, ModelClass): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_manual(self): full_mod = MultiMLP(d_hid, n_layers=self.world_size) full_mod.to(self.device) @@ -207,10 +202,8 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_custom_dw_with_fb_schedule(self): """Tests that separate weight grad function 'dw_runner' gets run under a schedule that's only aware of F/B.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -269,10 +262,8 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_output_chunks_memory_usage(self): """Test that output_chunks doesn't store memory for non-first stages.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -356,14 +347,14 @@ def tearDown(self): def init_pg(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( - backend=backend, + backend="nccl", store=store, rank=self.rank, world_size=self.world_size, device_id=self.device, ) - @requires_accelerator_dist_backend(["nccl", "xccl"]) + @requires_nccl() @skip_but_pass_in_sandcastle("Flaky in CI") def test_shape_prop_mismatch(self): """Tests shape prop errors are raised""" @@ -411,10 +402,8 @@ def _run_step(x): with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): _run_step(x) - @requires_accelerator_dist_backend(["nccl", "xccl"]) - @skip_but_pass_in_sandcastle_if( - not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" - ) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_custom_dw_errors(self): """Tests expected errors are raised""" self.init_pg() diff --git a/test/distributed/pipelining/test_transformer.py b/test/distributed/pipelining/test_transformer.py index 20e830547de7b..7e58129186a69 100644 --- a/test/distributed/pipelining/test_transformer.py +++ b/test/distributed/pipelining/test_transformer.py @@ -73,9 +73,7 @@ def get_layers(module): devices = ["cpu", "cuda", "hpu", "xpu"] -instantiate_device_type_tests( - TransformerTests, globals(), only_for=devices, allow_xpu=True -) +instantiate_device_type_tests(TransformerTests, globals(), only_for=devices) if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 0493f39b16cb8..ae1e684d7c222 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -73,9 +73,7 @@ def test_unflatten(self, device): devices = ["cpu", "cuda", "hpu", "xpu"] -instantiate_device_type_tests( - UnflattenTests, globals(), only_for=devices, allow_xpu=True -) +instantiate_device_type_tests(UnflattenTests, globals(), only_for=devices) if __name__ == "__main__": run_tests() diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index f3c0648b46254..bfc568bc14645 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1422,7 +1422,6 @@ def is_privateuse1_backend_available(): TEST_XPU = torch.xpu.is_available() TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False TEST_CUDA = torch.cuda.is_available() -TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2 custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) TEST_PRIVATEUSE1 = is_privateuse1_backend_available() TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name() From cf0a0dcb0afa5e84b95461cc542f862b51ca96bf Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Mon, 11 Aug 2025 04:23:23 -0700 Subject: [PATCH 0228/1424] Make user defined Triton kernels serializable for fx_graph_runnable (#160002) Resolves issue https://github.com/pytorch/pytorch/issues/153475 where `fx_graph_runnable` didn't work with user defined triton kernels. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002 Approved by: https://github.com/eellison --- test/dynamo/test_fx_graph_runnable.py | 88 +++++++++++++++++++++++++++ torch/_dynamo/repro/after_aot.py | 77 +++++++++++++++++++++++ 2 files changed, 165 insertions(+) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index d5ad0c160c4ba..47e9ee3cb888e 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -11,12 +11,65 @@ from torch._inductor.codecache import WritableTempFile from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE +from torch.utils._triton import has_triton if torch.distributed.is_available(): from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore +if has_triton(): + import triton + import triton.language as tl + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + @triton.jit + def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 1024}, + num_warps=4, + num_stages=2, + pre_hook=init_to_zero("output_ptr"), + ) + ], + pre_hook=init_to_zero("output_ptr"), + post_hook=init_to_zero("output_ptr"), + key=["n_elements"], + ) + @triton.jit + def add_kernel_autotune( + x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.triton_utils import requires_gpu + class FxGraphRunnableArtifactFilter(logging.Filter): def filter(self, record): @@ -100,6 +153,41 @@ def f(x): torch.compile(f)(torch.randn(4)) self._exec_and_verify_payload() + @unittest.skipUnless(has_triton(), "Triton not available") + def test_user_defined_triton_kernel_autotune(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = output.numel() + + def grid( + meta, + ): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + add_kernel_autotune[grid](x, y, output, n_elements) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + + @unittest.skipUnless(has_triton(), "Triton not available") + @requires_gpu + def test_user_defined_triton_kernel(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = x.numel() + add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + def test_two_inputs_matmul(self): def f(a, b): return (a @ b).relu() diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 71f552a83b4ab..136d2af1a6087 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -34,6 +34,24 @@ from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack +from torch.utils._triton import has_triton + + +if has_triton(): + from triton.runtime.autotuner import Autotuner, Heuristics + from triton.runtime.jit import JITFunction +else: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + class Heuristics: # type: ignore[no-redef] + pass + + import torch import torch.fx as fx import torch.nn as nn @@ -58,6 +76,7 @@ ) from torch._dynamo.utils import clone_inputs, counters, same from torch._environment import is_fbcode +from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.output_code import OutputCode from torch._library.fake_class_registry import FakeScriptObject @@ -302,6 +321,16 @@ def generate_compiler_repro_string( """ ).strip() + triton_imports = "" + + if len(kernel_side_table.id_to_kernel) > 0: + triton_imports = textwrap.dedent( + """ +import triton +import triton.language as tl + """ + ).strip() + model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -312,6 +341,7 @@ def generate_compiler_repro_string( from math import inf import torch._inductor.inductor_prims {distributed_imports} +{triton_imports} {generate_config_string(stable_output=stable_output)} @@ -330,6 +360,53 @@ def generate_compiler_repro_string( model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() + kernel_side_table_prefix = ( + "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" + ) + # Track which grid entry corresponds to the best config + for id in kernel_side_table.id_to_kernel: + kernel = kernel_side_table.get_kernel(id) + + if isinstance(kernel, Autotuner): + if isinstance(kernel.fn, Heuristics): + model_str += "ERROR: Repro will not work as intended, " + model_str += ( + "triton.runtime.autotuner.Heuristics is not currently supported\n" + ) + break + + config_strs = [] + for kernel_config in kernel.configs: + config_strs.append(f"""triton.Config( + {str(kernel_config.kwargs)}, + num_warps={kernel_config.num_warps}, + num_stages={kernel_config.num_stages}, + )""") + + config_str = ",".join(config_strs) + model_str += textwrap.dedent(f""" + @triton.autotune( + configs=[ + {config_str} + ], + key=[] + ) + """).strip() + + model_str += "\n@triton.jit\n" + src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src + fn_name = ( + kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name + ) + fn_name = fn_name.split(".")[-1] + + model_str += src_code + model_str += "\n" + model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n" + + if len(kernel_side_table.constant_args) > 0: + model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n" + model_str += NNModuleToString.convert(gm) writer = InputWriter(save_dir, stable_hash=stable_hash) From fc80f6859e0ccf66513a40f04b9e735e759d4ddb Mon Sep 17 00:00:00 2001 From: Sandeep Narendranath Karjala Date: Mon, 11 Aug 2025 10:40:43 -0700 Subject: [PATCH 0229/1424] Fix collective schedule logging and runtime tests (#160260) Summary: - Fix collective schedule logging so that only logs when collectives present - Fix runtime estimate test to check if each op has a number value Pull Request resolved: https://github.com/pytorch/pytorch/pull/160260 Approved by: https://github.com/Skylion007 --- test/dynamo/test_structured_trace.py | 37 +++++++--------------------- torch/_inductor/debug.py | 4 ++- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index a930fb0406dbd..5897c129b267f 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -260,7 +260,6 @@ def test_schedule(self): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -293,7 +292,6 @@ def test_cudagraphs(self): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -333,7 +331,6 @@ def fn(x, y): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -354,7 +351,6 @@ def fn(x, y): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} @@ -385,7 +381,6 @@ def test_example_fn(self): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -443,7 +438,6 @@ def test_example_training_fn(self): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} @@ -453,7 +447,6 @@ def test_example_training_fn(self): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -678,7 +671,6 @@ def forward(self, x): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -698,7 +690,6 @@ def forward(self, x): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -738,7 +729,6 @@ def fn(x): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0} @@ -898,7 +888,6 @@ def fn(a): {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -1159,9 +1148,9 @@ def test_collective_schedule_empty(self): log_collective_schedule([]) - self.assertIn('"inductor_collective_schedule"', self.buffer.getvalue()) - self.assertEqual(json.loads(payload_buffer.getvalue()), []) - self.assertParses() + # With no collectives, artifact should not be logged and payload should be empty + self.assertNotIn('"inductor_collective_schedule"', self.buffer.getvalue()) + self.assertEqual(payload_buffer.getvalue().strip(), "") @requires_tlparse @requires_distributed() @@ -1271,14 +1260,10 @@ def forward(self, x): self.assertTrue(len(compute_ops) > 0 or len(collective_ops) > 0) - # All ops should have runtime > 0 except wait_tensor can be 0 + # Just check each op has an estimated runtime value (any value, including 0) for op in ops: - if "wait_tensor" not in op["name"]: - self.assertGreater( - op["estimated_runtime_ns"], - 0, - f"Op {op['name']} should have runtime > 0", - ) + self.assertIn("estimated_runtime_ns", op) + self.assertIsNotNone(op["estimated_runtime_ns"]) self.assertParses() finally: @@ -1339,14 +1324,10 @@ def forward(self, x): self.assertIn("compute", op_types) self.assertIn("collective", op_types) - # All ops should have runtime > 0 except wait_tensor can be 0 + # Just check each op has an estimated runtime value (any value, including 0) for op in ops: - if "wait_tensor" not in op["name"]: - self.assertGreater( - op["estimated_runtime_ns"], - 0, - f"Op {op['name']} should have runtime > 0", - ) + self.assertIn("estimated_runtime_ns", op) + self.assertIsNotNone(op["estimated_runtime_ns"]) self.assertParses() finally: diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index f3be4a6b5506f..71df3429bb01c 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -719,7 +719,9 @@ def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None: if isinstance(op := getattr(node, "node", None), ir._CollectiveKernel) ] - _dump_collective_schedule(schedule) + # Only log when there is at least one collective op + if schedule: + _dump_collective_schedule(schedule) def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None: From 7d2ec704e47f4b740cdecda5534b305e8e1875ef Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 11 Aug 2025 21:01:52 +0000 Subject: [PATCH 0230/1424] Fix MPS autocast for ConvTranspose3d (#160345) ## Summary - ensure ConvTranspose3d uses fp32 under MPS autocast - add MPS autocast test for ConvTranspose3d Generated by Codex, see https://chatgpt.com/codex/tasks/task_e_689a360388288327a2cac6f55bbfc42c Fixes https://github.com/pytorch/pytorch/issues/160332 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160345 Approved by: https://github.com/dcci --- aten/src/ATen/autocast_mode.cpp | 1 + test/test_mps.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index afd0a6b67674a..2bf57a7ca5cb8 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -239,6 +239,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { KERNEL_MPS(scaled_dot_product_attention, lower_precision_fp) // fp32 + KERNEL_MPS(conv_transpose3d, input, fp32) KERNEL_MPS(acos, fp32) KERNEL_MPS(asin, fp32) KERNEL_MPS(cosh, fp32) diff --git a/test/test_mps.py b/test/test_mps.py index bff55eec95ae1..25e8836c761f5 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -199,6 +199,13 @@ def test_scaled_dot_product_attention_autocast(self, dtype): y = F.scaled_dot_product_attention(query, key, value.to(torch.float32)) self.assertEqual(y.to(y_autocast.dtype), y_autocast) + def test_conv_transpose3d_autocast_fp32(self): + m = nn.ConvTranspose3d(16, 33, 3, stride=2).to("mps") + x = torch.randn(20, 16, 10, 50, 100, device="mps") + with torch.amp.autocast(device_type="mps"): + y = m(x) + self.assertEqual(y.dtype, torch.float32) + def test_gradscaler_mps(self): # big model to force chunking/depth in the gradscaler dispatch class Model(nn.Module): From 5a40c5784482255b9baf14086cc4b9349fc6d512 Mon Sep 17 00:00:00 2001 From: Pat Vignola Date: Mon, 11 Aug 2025 21:45:09 +0000 Subject: [PATCH 0231/1424] [MTIA] Implement isAvailable() for MTIA hooks (#160304) Summary: MTIA is missing the `isAvailable()` override, which is necessary for some of the device agnostic methods. Test Plan: `torch._C._get_accelerator()` Rollback Plan: Differential Revision: D79981115 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160304 Approved by: https://github.com/nautsimon --- aten/src/ATen/detail/MTIAHooksInterface.cpp | 4 ++++ aten/src/ATen/detail/MTIAHooksInterface.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/aten/src/ATen/detail/MTIAHooksInterface.cpp b/aten/src/ATen/detail/MTIAHooksInterface.cpp index b6e260e59ec41..d2e331abb0c04 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.cpp +++ b/aten/src/ATen/detail/MTIAHooksInterface.cpp @@ -21,6 +21,10 @@ bool isMTIAHooksBuilt() { } // namespace detail +bool MTIAHooksInterface::isAvailable() const { + return detail::isMTIAHooksBuilt() && detail::getMTIAHooks().deviceCount() > 0; +} + C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs) } // namespace at diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index fb8ed6fb23226..b415862f29e7c 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -149,6 +149,8 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); return; } + + virtual bool isAvailable() const override; }; struct TORCH_API MTIAHooksArgs {}; From fc25c68f20f772290927a7031b998b92615259cf Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Mon, 11 Aug 2025 11:50:43 -0700 Subject: [PATCH 0232/1424] [hop][exc] make UncapturedHigherOrderOpError print user code and avoid re-raise (#159296) After the change, the error stacktrace is attached with user code stack and is suppressed into 1 (without the scrolling up mssage). For example: ```python class Test(torch.nn.Module): def forward(self, c, x): def cond_fn(c, x): return c > 0 and x.size(0) < 20 def body_fn(c, x): return c - 1, x.sin() return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x)) ``` Now gives the following error message: ```python Traceback (most recent call last): File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion self._run_test( ~~~~~~~~~~~~~~^ model=WhileLoopModels.SizeMismatchTensorExpansion(), ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ...<2 lines>... dynamic=dynamic, ^^^^^^^^^^^^^^^^ ) ^ File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test result = model(*inputs_with_counters) File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x)) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop return torch.compile( ~~~~~~~~~~~~~~ _while_loop_op_wrapper, backend=backend, fullgraph=True ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple()) ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__ result = self._torchdynamo_orig_backend( frame, cache_entry, self.hooks, frame_state, skip=1 ) File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__ result = self._inner_convert( frame, cache_entry, hooks, frame_state, skip=skip + 1 ) File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__ result = _compile( frame.f_code, ...<16 lines>... convert_frame_box=self._box, ) File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function return function(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner return _compile_inner(code, one_graph, hooks, transform) File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner out_code = transform_code_object(code, transform) File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object transformations(instructions, code_options) ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform tracer.run() ~~~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run super().run() ~~~~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run while self.step(): ~~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step self.dispatch_table[inst.opcode](self, inst) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper return inner_fn(self, inst) File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX self.call_function(fn, argsvars.items, kwargsvars) ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward return getattr(self.realize(), name)(*args, **kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error raise exc.with_traceback(sys.exc_info()[2]) from None File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error return fn(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function ) = speculate_subgraph( ~~~~~~~~~~~~~~~~~~^ tx, ^^^ ...<33 lines>... supports_aliasing=self.supports_aliasing, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ) ^ File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph raise ex File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph output = f.call_function(tx, args, sub_kwargs) File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function return super().call_function(tx, args, kwargs) ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call return tracer.inline_call_() ~~~~~~~~~~~~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_ self.run() ~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run while self.step(): ~~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step self.dispatch_table[inst.opcode](self, inst) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper return inner_fn(self, inst) File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX self.call_function(fn, argsvars.items, kwargsvars) ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward return getattr(self.realize(), name)(*args, **kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function return super().call_function(tx, args, kwargs) ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call return tracer.inline_call_() ~~~~~~~~~~~~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_ self.run() ~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run while self.step(): ~~~~~~~~~^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step self.dispatch_table[inst.opcode](self, inst) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^ File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner unimplemented_v2( ~~~~~~~~~~~~~~~~^ gb_type="Data-dependent branching", ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ...<5 lines>... ], ^^ ) ^ File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2 raise Unsupported(msg) torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Hint: Use `torch.cond` to express dynamic control flow. Developer debug context: attempted to jump with TensorVariable() For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html from user code: File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper return while_loop_op(*args, **kwargs) File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn return cond_fn(*carried, *additional) File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn return c > 0 and x.size(0) < 20 Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo" To execute this test, run the following from the base repo dir: python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159296 Approved by: https://github.com/zou3519 --- test/higher_order_ops/test_invoke_subgraph.py | 36 +++++-------------- torch/_dynamo/exc.py | 9 ++++- torch/_dynamo/variables/higher_order_ops.py | 15 ++++++-- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 46d796f1dac37..df1bd941d8857 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -1195,17 +1195,11 @@ def fn(x, y): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( - RuntimeError, - "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", - ) as cm: + torch._dynamo.exc.UncapturedHigherOrderOpError, + "Encountered aliasing during higher order op tracing", + ): opt_fn(x, y) - cause = cm.exception.__cause__ - self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) - self.assertTrue( - "Encountered aliasing during higher order op tracing" in str(cause) - ) - def test_input_input_aliasing(self): @nested_compile_region def gn(x, y): @@ -1219,17 +1213,11 @@ def fn(x): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( - RuntimeError, - "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", - ) as cm: + torch._dynamo.exc.UncapturedHigherOrderOpError, + "Encountered aliasing during higher order op tracing", + ): opt_fn(x) - cause = cm.exception.__cause__ - self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) - self.assertTrue( - "Encountered aliasing during higher order op tracing" in str(cause) - ) - def test_output_output_aliasing(self): @nested_compile_region def gn(x): @@ -1244,17 +1232,11 @@ def fn(x): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( - RuntimeError, - "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", - ) as cm: + torch._dynamo.exc.UncapturedHigherOrderOpError, + "Encountered aliasing during higher order op tracing", + ): opt_fn(x) - cause = cm.exception.__cause__ - self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) - self.assertTrue( - "Encountered aliasing during higher order op tracing" in str(cause) - ) - def test_mod_attr_aliasing(self): class MutateParam(torch.nn.Module): def __init__(self): diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index e1247917ef82e..0636170391319 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -264,7 +264,14 @@ class UnsafeScriptObjectError(TorchDynamoException): class UncapturedHigherOrderOpError(TorchDynamoException): - pass + def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None: + super().__init__(msg) + self.msg = msg + self.real_stack = ( + real_stack + if real_stack is not None + else torch._guards.TracingContext.extract_stack() + ) class IncorrectUsage(Exception): diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index ea935ae5f7afa..d3334424c5f45 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -77,8 +77,19 @@ def graph_break_as_hard_error(*args, **kwargs): try: return fn(*args, **kwargs) except (Unsupported, ObservedException) as e: - msg = " Scroll up to find out what causes the graph break." - raise UncapturedHigherOrderOpError(reason + msg) from e + import sys + + if isinstance(e, Unsupported): + exc = UncapturedHigherOrderOpError( + f"{reason} Got {e.msg}", e.real_stack + ) + else: + msg = e.msg if hasattr(e, "msg") else type(e) + real_stack = e.real_stack if hasattr(e, "real_stack") else None + exc = UncapturedHigherOrderOpError( + f"{reason} Got {msg}", real_stack + ) + raise exc.with_traceback(sys.exc_info()[2]) from None return graph_break_as_hard_error From 99bc2f94c1955657e950ebdad5f77e518785ccbd Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Mon, 11 Aug 2025 23:14:08 +0000 Subject: [PATCH 0233/1424] Update export/schema.py (#160220) Summary: Model could have multiple ExportedPrograms - for different methods. They can have different weights. - for different delegates. They can also have different weights. For this reason, we make weight per ExportedProgram. Also, we cleanup Model, and Program. IIUC, Model and Program are not used anywhere, so it's ok to make BC breaking change. Test Plan: CI Rollback Plan: Differential Revision: D79917395 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160220 Approved by: https://github.com/angelayi, https://github.com/dolpm, https://github.com/jingsh --- torch/_export/serde/export_schema.thrift | 14 ++- torch/_export/serde/schema.py | 37 ++++---- torch/_export/serde/schema.yaml | 25 +++--- .../utils/generated_serialization_types.h | 86 +++++++------------ 4 files changed, 64 insertions(+), 98 deletions(-) diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index 0b2f2b4fe7408..5eb5512cde638 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<0b6fec18525f05577f007055f774b5e6f143ca7499b931474d1f4cd4a5dc5004>> +// checksum<> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -330,18 +330,14 @@ struct ExportedProgram { 60: SchemaVersion schema_version; 70: list verifiers; 80: string torch_version; -} - -struct Program { - 200: map methods; + 90: map tensor_paths; + 100: map constant_paths; } struct Model { 10: string name; - 20: map tensorPaths; - 40: Program program; - 50: map delegates; - 70: map constantPaths; + 80: ExportedProgram program; + 90: map variants; } struct AOTInductorModelPickleData { diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 30bc119a54007..dba719a601558 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -9,7 +9,7 @@ # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 9) +SCHEMA_VERSION = (8, 10) TREESPEC_VERSION = 1 @@ -436,34 +436,35 @@ class ExportedProgram: verifiers: Annotated[list[str], 70] = field(default_factory=list) torch_version: Annotated[str, 80] = "<=2.4" + # key is the FQN of tensor in exported program + # value is the archive path of tensor payloads + # e.g. "L__self__linear.weight" : "/data/tensor/weight_1" + tensor_paths: Annotated[dict[str, str], 90] = field(default_factory=dict) + + # key is the FQN of constant in exported program (constant tensor or torchbind objs) + # value is the archive path of serialized constants + constant_paths: Annotated[dict[str, str], 100] = field(default_factory=dict) + ######################################################################### # Container types for inference tasks, not being used directly for export. ######################################################################### -@dataclass -class Program: - methods: Annotated[dict[str, ExportedProgram], 200] - - # This is the top-level model definition that be will serialized into the package @dataclass class Model: # unique identifier of the model in the package, e.g. local, remote, merge name: Annotated[str, 10] - # key is the FQN of tensor in exported program - # value is the archive path of tensor payloads - # e.g. "L__self__linear.weight" : "/data/tensor/L__self__linear.weight" - tensorPaths: Annotated[dict[str, str], 20] - # program exported from torch.export() - program: Annotated[Program, 40] - # Backend-specialized Lowered GraphModule - # e.g. "aotinductor-a100" : ExportedProgram_with_AOTInductor_delegate - delegates: Annotated[dict[str, Program], 50] - # key is the FQN of constant in exported program (constant tensor or torchbind objs) - # value is the archive path of serialized constants - constantPaths: Annotated[dict[str, str], 70] + + # the main program exported from torch.export() + program: Annotated[ExportedProgram, 80] + + # a collection of ExportedPrograms that are related to the same model + # They can be used for different purposes, e.g. + # - different methods such as "encode" and "decode" for the same model + # - different delegates such as "aoti_sm80" and "aoti_sm90" + variants: Annotated[dict[str, ExportedProgram], 90] # diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 56e40f309744e..bb087048a30c8 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<89a616d78254f20c027a2e0f882a3f8b096b4169c781d5dfd0254c8bce33cb35>> +# checksum<> AOTInductorModelPickleData: kind: struct fields: @@ -131,6 +131,12 @@ ExportedProgram: torch_version: type: str default: <=2.4 + tensor_paths: + type: Dict[str, str] + default: '{}' + constant_paths: + type: Dict[str, str] + default: '{}' ExternKernelNode: kind: struct fields: @@ -298,14 +304,10 @@ Model: fields: name: type: str - tensorPaths: - type: Dict[str, str] program: - type: Program - delegates: - type: Dict[str, Program] - constantPaths: - type: Dict[str, str] + type: ExportedProgram + variants: + type: Dict[str, ExportedProgram] ModuleCallEntry: kind: struct fields: @@ -386,11 +388,6 @@ OutputTokenSpec: fields: arg: type: TokenArgument -Program: - kind: struct - fields: - methods: - type: Dict[str, ExportedProgram] RangeConstraint: kind: struct fields: @@ -532,5 +529,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 9 +- 10 TREESPEC_VERSION: 1 diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index f93532ef9de23..62c8390f7c9b5 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<89a616d78254f20c027a2e0f882a3f8b096b4169c781d5dfd0254c8bce33cb35>> +// checksum<> // clang-format off #pragma once @@ -158,7 +158,6 @@ class Node; class OptionalTensorArgument; class OutputSpec; class OutputTokenSpec; -class Program; class RangeConstraint; class SchemaVersion; class SymBool; @@ -3014,6 +3013,8 @@ class ExportedProgram { SchemaVersion schema_version; std::vector verifiers = {}; std::string torch_version = "<=2.4"; + std::unordered_map tensor_paths = {}; + std::unordered_map constant_paths = {}; public: @@ -3065,35 +3066,31 @@ class ExportedProgram { torch_version = std::move(def); } - friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); - friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); -}; - -class Program { - private: - std::unordered_map methods; + const std::unordered_map& get_tensor_paths() const { + return tensor_paths; + } - public: + void set_tensor_paths(std::unordered_map def) { + tensor_paths = std::move(def); + } - const std::unordered_map& get_methods() const { - return methods; + const std::unordered_map& get_constant_paths() const { + return constant_paths; } - void set_methods(std::unordered_map def) { - methods = std::move(def); + void set_constant_paths(std::unordered_map def) { + constant_paths = std::move(def); } - friend void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t); - friend void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t); + friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); }; class Model { private: std::string name; - std::unordered_map tensorPaths; - Program program; - std::unordered_map delegates; - std::unordered_map constantPaths; + ExportedProgram program; + std::unordered_map variants; public: @@ -3105,36 +3102,20 @@ class Model { name = std::move(def); } - const std::unordered_map& get_tensorPaths() const { - return tensorPaths; - } - - void set_tensorPaths(std::unordered_map def) { - tensorPaths = std::move(def); - } - - const Program& get_program() const { + const ExportedProgram& get_program() const { return program; } - void set_program(Program def) { + void set_program(ExportedProgram def) { program = std::move(def); } - const std::unordered_map& get_delegates() const { - return delegates; + const std::unordered_map& get_variants() const { + return variants; } - void set_delegates(std::unordered_map def) { - delegates = std::move(def); - } - - const std::unordered_map& get_constantPaths() const { - return constantPaths; - } - - void set_constantPaths(std::unordered_map def) { - constantPaths = std::move(def); + void set_variants(std::unordered_map def) { + variants = std::move(def); } friend void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t); @@ -3308,6 +3289,8 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nloh nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version; nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers; nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version; + nlohmann_json_j["tensor_paths"] = nlohmann_json_t.tensor_paths; + nlohmann_json_j["constant_paths"] = nlohmann_json_t.constant_paths; } inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) { @@ -3318,6 +3301,8 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nl nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version); nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers); nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version); + nlohmann_json_t.tensor_paths = nlohmann_json_j.value("tensor_paths", nlohmann_json_default_obj.tensor_paths); + nlohmann_json_t.constant_paths = nlohmann_json_j.value("constant_paths", nlohmann_json_default_obj.constant_paths); } inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNode& nlohmann_json_t) { @@ -3503,19 +3488,15 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlo inline void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t) { nlohmann_json_j["name"] = nlohmann_json_t.name; - nlohmann_json_j["tensorPaths"] = nlohmann_json_t.tensorPaths; nlohmann_json_j["program"] = nlohmann_json_t.program; - nlohmann_json_j["delegates"] = nlohmann_json_t.delegates; - nlohmann_json_j["constantPaths"] = nlohmann_json_t.constantPaths; + nlohmann_json_j["variants"] = nlohmann_json_t.variants; } inline void from_json(const nlohmann::json& nlohmann_json_j, Model& nlohmann_json_t) { Model nlohmann_json_default_obj; nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); - nlohmann_json_t.tensorPaths = nlohmann_json_j.value("tensorPaths", nlohmann_json_default_obj.tensorPaths); nlohmann_json_t.program = nlohmann_json_j.value("program", nlohmann_json_default_obj.program); - nlohmann_json_t.delegates = nlohmann_json_j.value("delegates", nlohmann_json_default_obj.delegates); - nlohmann_json_t.constantPaths = nlohmann_json_j.value("constantPaths", nlohmann_json_default_obj.constantPaths); + nlohmann_json_t.variants = nlohmann_json_j.value("variants", nlohmann_json_default_obj.variants); } inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t) { @@ -3594,15 +3575,6 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nl nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); } -inline void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t) { - nlohmann_json_j["methods"] = nlohmann_json_t.methods; -} - -inline void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t) { - Program nlohmann_json_default_obj; - nlohmann_json_t.methods = nlohmann_json_j.value("methods", nlohmann_json_default_obj.methods); -} - inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) { nlohmann_json_j["min_val"] = nlohmann_json_t.min_val; nlohmann_json_j["max_val"] = nlohmann_json_t.max_val; From 3626ba711b34397d1fbf0a9b1979f85cbf68b919 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 11 Aug 2025 23:30:15 +0000 Subject: [PATCH 0234/1424] [FlexAttention] Swap from and to & for new triton (#160227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #158463 On B200 I am getting a bunch of error spew: ```Shell /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: error: Failures have been detected while processing an MLIR pass pipeline /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` Triton compilation failed: triton_tem_fused_zeros_1 def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): PRESCALE_QK : tl.constexpr = False ``` ```Shell 74 = arith.subi %170, %166 : i32 %175 = arith.muli %174, %c128_i32 : i32 %176 = arith.subi %175, %c64_i32 : i32 %177 = arith.extui %173 : i1 to i32 %178 = arith.muli %176, %177 : i32 %179 = arith.subi %c1_i32, %177 : i32 %180 = arith.muli %179, %c64_i32 : i32 %181 = arith.addi %178, %180 : i32 %182 = arith.muli %181, %c64_i32 : i32 %183 = tt.splat %182 : i32 -> tensor<64x64xi32> %184 = tt.addptr %arg19, %183 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %185 = tt.addptr %arg20, %183 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %186 = tt.splat %181 : i32 -> tensor<64xi32> %187 = arith.addi %arg21, %186 : tensor<64xi32> scf.yield %163, %184, %185, %187 : tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64xi32> } %114 = tt.expand_dims %113#3 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %115 = arith.cmpi slt, %114, %cst_7 : tensor<1x64xi32> %116 = tt.broadcast %115 : tensor<1x64xi1> -> tensor<64x64xi1> %117 = tt.load %113#1, %116, %cst_8 : tensor<64x64x!tt.ptr> %118 = tt.dot %46, %117, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %119 = arith.mulf %118, %cst_13 : tensor<64x64xf32> %120 = arith.mulf %119, %cst_3 : tensor<64x64xf32> %121 = arith.select %116, %120, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %122 = arith.select %115, %cst_4, %cst_5 : tensor<1x64xi1>, tensor<1x64xi1> %123 = tt.broadcast %122 : tensor<1x64xi1> -> tensor<64x64xi1> %124 = arith.select %123, %121, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %125 = arith.mulf %124, %cst_2 : tensor<64x64xf32> %126 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32> %127 = arith.subf %125, %126 : tensor<64x64xf32> %128 = math.exp2 %127 : tensor<64x64xf32> %129 = tt.load %113#2, %116, %cst_8 : tensor<64x64x!tt.ptr> %130 = tt.dot %51, %129, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %131 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> %132 = tt.broadcast %131 : tensor<64x1xf32> -> tensor<64x64xf32> %133 = arith.subf %130, %132 : tensor<64x64xf32> %134 = arith.mulf %128, %133 : tensor<64x64xf32> %135 = arith.mulf %134, %cst_3 : tensor<64x64xf32> %136 = arith.select %116, %135, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %137 = arith.select %115, %122, %cst_5 : tensor<1x64xi1>, tensor<1x64xi1> %138 = tt.broadcast %137 : tensor<1x64xi1> -> tensor<64x64xi1> %139 = arith.select %138, %136, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %140 = arith.truncf %139 : tensor<64x64xf32> to tensor<64x64xf16> %141 = tt.trans %117 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %142 = tt.dot %140, %141, %113#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %142 : tensor<64x64xf32> } else { scf.yield %cst_9 : tensor<64x64xf32> } %84 = tt.addptr %arg13, %22 : !tt.ptr, i32 %85 = tt.load %84 : !tt.ptr %86 = arith.muli %85, %c128_i32 : i32 %87 = tt.addptr %arg12, %21 : !tt.ptr, i32 %88 = tt.load %87 : !tt.ptr %89 = tt.splat %86 : i32 -> tensor<64xi32> %90 = arith.addi %89, %14 : tensor<64xi32> %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %92 = arith.muli %91, %cst_11 : tensor<1x64xi32> %93 = tt.addptr %71, %92 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> %94 = tt.broadcast %93 : tensor<1x64x!tt.ptr> -> tensor<64x64x!tt.ptr> %95 = tt.addptr %94, %74 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %96 = tt.addptr %76, %92 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> %97 = tt.broadcast %96 : tensor<1x64x!tt.ptr> -> tensor<64x64x!tt.ptr> %98 = tt.addptr %97, %74 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %99 = arith.muli %88, %c2_i32 : i32 %100 = arith.minsi %99, %c4_i32 : i32 %101 = arith.cmpi sge, %100, %c1_i32 : i32 %102 = scf.if %101 -> (tensor<64x64xf32>) { %112 = arith.subi %100, %c1_i32 : i32 %113:4 = scf.for %arg17 = %c0_i32 to %112 step %c1_i32 iter_args(%arg18 = %83, %arg19 = %95, %arg20 = %98, %arg21 = %90) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64xi32>) : i32 { %137 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %138 = arith.cmpi slt, %137, %cst_7 : tensor<1x64xi32> %139 = tt.broadcast %138 : tensor<1x64xi1> -> tensor<64x64xi1> %140 = tt.load %arg19, %139, %cst_8 : tensor<64x64x!tt.ptr> %141 = tt.dot %46, %140, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %142 = arith.mulf %141, %cst_13 : tensor<64x64xf32> %143 = arith.mulf %142, %cst_3 : tensor<64x64xf32> %144 = arith.mulf %143, %cst_2 : tensor<64x64xf32> %145 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32> %146 = arith.subf %144, %145 : tensor<64x64xf32> %147 = math.exp2 %146 : tensor<64x64xf32> %148 = tt.load %arg20, %139, %cst_8 : tensor<64x64x!tt.ptr> %149 = tt.dot %51, %148, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %150 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> %151 = tt.broadcast %150 : tensor<64x1xf32> -> tensor<64x64xf32> %152 = arith.subf %149, %151 : tensor<64x64xf32> %153 = arith.mulf %147, %152 : tensor<64x64xf32> %154 = arith.mulf %153, %cst_3 : tensor<64x64xf32> %155 = arith.truncf %154 : tensor<64x64xf32> to tensor<64x64xf16> %156 = tt.trans %140 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %157 = tt.dot %155, %156, %arg18, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %158 = arith.divsi %arg17, %c2_i32 : i32 %159 = tt.addptr %84, %158 : !tt.ptr, i32 %160 = tt.load %159 evictionPolicy = evict_last : !tt.ptr %161 = arith.addi %158, %c1_i32 : i32 %162 = arith.cmpi slt, %161, %88 : i32 %163 = tt.addptr %159, %c1_i32 : !tt.ptr, i32 %164 = tt.load %163, %162 evictionPolicy = evict_last : !tt.ptr %165 = arith.addi %arg17, %c1_i32 : i32 %166 = arith.remsi %165, %c2_i32 : i32 %167 = arith.cmpi eq, %166, %c0_i32 : i32 %168 = arith.subi %164, %160 : i32 %169 = arith.muli %168, %c128_i32 : i32 %170 = arith.subi %169, %c64_i32 : i32 %171 = arith.extui %167 : i1 to i32 %172 = arith.muli %170, %171 : i32 %173 = arith.subi %c1_i32, %171 : i32 %174 = arith.muli %173, %c64_i32 : i32 %175 = arith.addi %172, %174 : i32 %176 = arith.muli %175, %c64_i32 : i32 %177 = tt.splat %176 : i32 -> tensor<64x64xi32> %178 = tt.addptr %arg19, %177 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %179 = tt.addptr %arg20, %177 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %180 = tt.splat %175 : i32 -> tensor<64xi32> %181 = arith.addi %arg21, %180 : tensor<64xi32> scf.yield %157, %178, %179, %181 : tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64xi32> } %114 = tt.expand_dims %113#3 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %115 = arith.cmpi slt, %114, %cst_7 : tensor<1x64xi32> %116 = tt.broadcast %115 : tensor<1x64xi1> -> tensor<64x64xi1> %117 = tt.load %113#1, %116, %cst_8 : tensor<64x64x!tt.ptr> %118 = tt.dot %46, %117, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %119 = arith.mulf %118, %cst_13 : tensor<64x64xf32> %120 = arith.mulf %119, %cst_3 : tensor<64x64xf32> %121 = arith.select %116, %120, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %122 = arith.mulf %121, %cst_2 : tensor<64x64xf32> %123 = tt.broadcast %61 : tensor<64x1xf32> -> tensor<64x64xf32> %124 = arith.subf %122, %123 : tensor<64x64xf32> %125 = math.exp2 %124 : tensor<64x64xf32> %126 = tt.load %113#2, %116, %cst_8 : tensor<64x64x!tt.ptr> %127 = tt.dot %51, %126, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %128 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> %129 = tt.broadcast %128 : tensor<64x1xf32> -> tensor<64x64xf32> %130 = arith.subf %127, %129 : tensor<64x64xf32> %131 = arith.mulf %125, %130 : tensor<64x64xf32> %132 = arith.mulf %131, %cst_3 : tensor<64x64xf32> %133 = arith.select %116, %132, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %134 = arith.truncf %133 : tensor<64x64xf32> to tensor<64x64xf16> %135 = tt.trans %117 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %136 = tt.dot %134, %135, %113#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %136 : tensor<64x64xf32> } else { scf.yield %83 : tensor<64x64xf32> } %103 = tt.splat %33 : !tt.ptr -> tensor<64x1x!tt.ptr> %104 = tt.addptr %103, %37 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> %105 = tt.broadcast %104 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> %106 = tt.addptr %105, %42 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %107 = arith.mulf %102, %cst_13 : tensor<64x64xf32> %108 = arith.cmpi slt, %40, %cst_11 : tensor<1x64xi32> %109 = tt.broadcast %108 : tensor<1x64xi1> -> tensor<64x64xi1> %110 = arith.andi %45, %109 : tensor<64x64xi1> %111 = arith.truncf %107 : tensor<64x64xf32> to tensor<64x64xf16> tt.store %106, %111, %110 : tensor<64x64x!tt.ptr> } else { %16 = arith.divsi %0, %c2_i32 : i32 %17 = arith.muli %0, %c64_i32 : i32 %18 = tt.splat %17 : i32 -> tensor<64xi32> %19 = arith.addi %18, %14 : tensor<64xi32> %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %21 = arith.muli %20, %cst_14 : tensor<64x1xi32> %22 = tt.splat %11 : !tt.ptr -> tensor<64x1x!tt.ptr> %23 = tt.addptr %22, %21 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> %24 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %25 = tt.broadcast %23 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> %26 = tt.broadcast %24 : tensor<1x64xi32> -> tensor<64x64xi32> %27 = tt.addptr %25, %26 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %28 = arith.cmpi slt, %20, %cst_10 : tensor<64x1xi32> %29 = tt.broadcast %28 : tensor<64x1xi1> -> tensor<64x64xi1> %30 = tt.load %27, %29, %cst_8 : tensor<64x64x!tt.ptr> %31 = tt.splat %12 : !tt.ptr -> tensor<64x1x!tt.ptr> %32 = tt.addptr %31, %21 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> %33 = tt.broadcast %32 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> %34 = tt.addptr %33, %26 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %35 = tt.load %34, %29, %cst_8 : tensor<64x64x!tt.ptr> %36:2 = scf.for %arg17 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg18 = %cst_9, %arg19 = %cst_9) -> (tensor<64x64xf32>, tensor<64x64xf32>) : i32 { %55 = arith.muli %2, %c4_i32 : i32 %56 = arith.addi %55, %arg17 : i32 %57 = arith.muli %56, %c2048_i32 : i32 %58 = arith.muli %1, %c32768_i32 : i32 %59 = arith.addi %57, %58 : i32 %60 = arith.extsi %59 : i32 to i64 %61 = arith.muli %1, %c16_i32 : i32 %62 = arith.addi %61, %56 : i32 %63 = arith.muli %62, %c32_i32 : i32 %64 = arith.extsi %63 : i32 to i64 %65 = tt.addptr %arg0, %60 : !tt.ptr, i64 %66 = tt.addptr %arg5, %60 : !tt.ptr, i64 %67 = tt.addptr %arg3, %64 : !tt.ptr, i64 %68 = tt.addptr %arg4, %64 : !tt.ptr, i64 %69 = arith.remsi %56, %c16_i32 : i32 %70 = arith.muli %3, %c16_i32 : i32 %71 = arith.addi %70, %69 : i32 %72 = arith.muli %71, %c2_i32 : i32 %73 = arith.addi %72, %16 : i32 %74 = tt.addptr %arg11, %73 : !tt.ptr, i32 %75 = tt.load %74 : !tt.ptr %76 = arith.muli %75, %c128_i32 : i32 %77 = tt.addptr %arg10, %73 : !tt.ptr, i32 %78 = tt.load %77 : !tt.ptr %79 = tt.splat %76 : i32 -> tensor<64xi32> %80 = arith.addi %79, %14 : tensor<64xi32> %81 = tt.expand_dims %80 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %82 = arith.muli %81, %cst_11 : tensor<1x64xi32> %83 = tt.splat %65 : !tt.ptr -> tensor<1x64x!tt.ptr> %84 = tt.addptr %83, %82 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> %85 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %86 = tt.broadcast %84 : tensor<1x64x!tt.ptr> -> tensor<64x64x!tt.ptr> %87 = tt.broadcast %85 : tensor<64x1xi32> -> tensor<64x64xi32> %88 = tt.addptr %86, %87 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %89 = tt.expand_dims %80 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %90 = arith.muli %89, %cst_14 : tensor<64x1xi32> %91 = tt.splat %66 : !tt.ptr -> tensor<64x1x!tt.ptr> %92 = tt.addptr %91, %90 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> %93 = tt.broadcast %92 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> %94 = tt.addptr %93, %26 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %95 = arith.muli %78, %c2_i32 : i32 %96 = arith.minsi %95, %c1_i32 : i32 %97 = arith.cmpi sge, %96, %c1_i32 : i32 %98:2 = scf.if %97 -> (tensor<64x64xf32>, tensor<64x64xf32>) { %120 = arith.subi %96, %c1_i32 : i32 %121:5 = scf.for %arg20 = %c0_i32 to %120 step %c1_i32 iter_args(%arg21 = %arg18, %arg22 = %arg19, %arg23 = %88, %arg24 = %94, %arg25 = %80) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64xi32>) : i32 { %167 = tt.expand_dims %arg25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %168 = arith.cmpi slt, %167, %cst_1 : tensor<1x64xi32> %169 = tt.broadcast %168 : tensor<1x64xi1> -> tensor<64x64xi1> %170 = tt.load %arg23, %169, %cst_8 : tensor<64x64x!tt.ptr> %171 = arith.cmpi slt, %arg25, %cst_17 : tensor<64xi32> %172 = tt.splat %67 : !tt.ptr -> tensor<64x!tt.ptr> %173 = tt.addptr %172, %arg25 : tensor<64x!tt.ptr>, tensor<64xi32> %174 = tt.load %173, %171 : tensor<64x!tt.ptr> %175 = arith.cmpf oeq, %174, %cst_16 : tensor<64xf32> %176 = arith.select %175, %cst_15, %174 : tensor<64xi1>, tensor<64xf32> %177 = tt.dot %30, %170, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %178 = arith.mulf %177, %cst_13 : tensor<64x64xf32> %179 = arith.mulf %178, %cst_3 : tensor<64x64xf32> %180 = arith.mulf %179, %cst_2 : tensor<64x64xf32> %181 = tt.expand_dims %176 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %182 = tt.broadcast %181 : tensor<1x64xf32> -> tensor<64x64xf32> %183 = arith.subf %180, %182 : tensor<64x64xf32> %184 = math.exp2 %183 : tensor<64x64xf32> %185 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %186 = arith.cmpi slt, %185, %cst_12 : tensor<64x1xi32> %187 = tt.broadcast %186 : tensor<64x1xi1> -> tensor<64x64xi1> %188 = tt.load %arg24, %187, %cst_8 : tensor<64x64x!tt.ptr> %189 = arith.truncf %184 : tensor<64x64xf32> to tensor<64x64xf16> %190 = tt.dot %189, %188, %arg22, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %191 = tt.splat %68 : !tt.ptr -> tensor<64x!tt.ptr> %192 = tt.addptr %191, %arg25 : tensor<64x!tt.ptr>, tensor<64xi32> %193 = tt.load %192, %171 : tensor<64x!tt.ptr> %194 = tt.trans %188 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %195 = tt.dot %35, %194, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %196 = tt.expand_dims %193 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %197 = tt.broadcast %196 : tensor<1x64xf32> -> tensor<64x64xf32> %198 = arith.subf %195, %197 : tensor<64x64xf32> %199 = arith.mulf %184, %198 : tensor<64x64xf32> %200 = arith.mulf %199, %cst_3 : tensor<64x64xf32> %201 = arith.truncf %200 : tensor<64x64xf32> to tensor<64x64xf16> %202 = tt.trans %170 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %203 = tt.dot %201, %202, %arg21, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %204 = arith.divsi %arg20, %c2_i32 : i32 %205 = tt.addptr %74, %204 : !tt.ptr, i32 %206 = tt.load %205 evictionPolicy = evict_last : !tt.ptr %207 = arith.addi %204, %c1_i32 : i32 %208 = arith.cmpi slt, %207, %78 : i32 %209 = tt.addptr %205, %c1_i32 : !tt.ptr, i32 %210 = tt.load %209, %208 evictionPolicy = evict_last : !tt.ptr %211 = arith.addi %arg20, %c1_i32 : i32 %212 = arith.remsi %211, %c2_i32 : i32 %213 = arith.cmpi eq, %212, %c0_i32 : i32 %214 = arith.subi %210, %206 : i32 %215 = arith.muli %214, %c128_i32 : i32 %216 = arith.subi %215, %c64_i32 : i32 %217 = arith.extui %213 : i1 to i32 %218 = arith.muli %216, %217 : i32 %219 = arith.subi %c1_i32, %217 : i32 %220 = arith.muli %219, %c64_i32 : i32 %221 = arith.addi %218, %220 : i32 %222 = arith.muli %221, %c64_i32 : i32 %223 = tt.splat %222 : i32 -> tensor<64x64xi32> %224 = tt.addptr %arg23, %223 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %225 = tt.addptr %arg24, %223 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %226 = tt.splat %221 : i32 -> tensor<64xi32> %227 = arith.addi %arg25, %226 : tensor<64xi32> scf.yield %203, %190, %224, %225, %227 : tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64xi32> } %122 = tt.expand_dims %121#4 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %123 = arith.cmpi slt, %122, %cst_1 : tensor<1x64xi32> %124 = tt.broadcast %123 : tensor<1x64xi1> -> tensor<64x64xi1> %125 = tt.load %121#2, %124, %cst_8 : tensor<64x64x!tt.ptr> %126 = arith.cmpi slt, %121#4, %cst_17 : tensor<64xi32> %127 = tt.splat %67 : !tt.ptr -> tensor<64x!tt.ptr> %128 = tt.addptr %127, %121#4 : tensor<64x!tt.ptr>, tensor<64xi32> %129 = tt.load %128, %126 : tensor<64x!tt.ptr> %130 = arith.cmpf oeq, %129, %cst_16 : tensor<64xf32> %131 = arith.select %130, %cst_15, %129 : tensor<64xi1>, tensor<64xf32> %132 = tt.dot %30, %125, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %133 = arith.mulf %132, %cst_13 : tensor<64x64xf32> %134 = arith.mulf %133, %cst_3 : tensor<64x64xf32> %135 = arith.select %29, %134, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %136 = arith.select %28, %cst, %cst_0 : tensor<64x1xi1>, tensor<64x1xi1> %137 = tt.broadcast %136 : tensor<64x1xi1> -> tensor<64x64xi1> %138 = arith.select %137, %135, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %139 = arith.mulf %138, %cst_2 : tensor<64x64xf32> %140 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %141 = tt.broadcast %140 : tensor<1x64xf32> -> tensor<64x64xf32> %142 = arith.subf %139, %141 : tensor<64x64xf32> %143 = math.exp2 %142 : tensor<64x64xf32> %144 = tt.expand_dims %121#4 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %145 = arith.cmpi slt, %144, %cst_12 : tensor<64x1xi32> %146 = tt.broadcast %145 : tensor<64x1xi1> -> tensor<64x64xi1> %147 = tt.load %121#3, %146, %cst_8 : tensor<64x64x!tt.ptr> %148 = arith.truncf %143 : tensor<64x64xf32> to tensor<64x64xf16> %149 = tt.dot %148, %147, %121#1, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %150 = tt.splat %68 : !tt.ptr -> tensor<64x!tt.ptr> %151 = tt.addptr %150, %121#4 : tensor<64x!tt.ptr>, tensor<64xi32> %152 = tt.load %151, %126 : tensor<64x!tt.ptr> %153 = tt.trans %147 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %154 = tt.dot %35, %153, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %155 = tt.expand_dims %152 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %156 = tt.broadcast %155 : tensor<1x64xf32> -> tensor<64x64xf32> %157 = arith.subf %154, %156 : tensor<64x64xf32> %158 = arith.mulf %143, %157 : tensor<64x64xf32> %159 = arith.mulf %158, %cst_3 : tensor<64x64xf32> %160 = arith.select %29, %159, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %161 = arith.select %28, %136, %cst_0 : tensor<64x1xi1>, tensor<64x1xi1> %162 = tt.broadcast %161 : tensor<64x1xi1> -> tensor<64x64xi1> %163 = arith.select %162, %160, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %164 = arith.truncf %163 : tensor<64x64xf32> to tensor<64x64xf16> %165 = tt.trans %125 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %166 = tt.dot %164, %165, %121#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %166, %149 : tensor<64x64xf32>, tensor<64x64xf32> } else { scf.yield %arg18, %arg19 : tensor<64x64xf32>, tensor<64x64xf32> } %99 = tt.addptr %arg15, %73 : !tt.ptr, i32 %100 = tt.load %99 : !tt.ptr %101 = arith.muli %100, %c128_i32 : i32 %102 = tt.addptr %arg14, %73 : !tt.ptr, i32 %103 = tt.load %102 : !tt.ptr %104 = tt.splat %101 : i32 -> tensor<64xi32> %105 = arith.addi %104, %14 : tensor<64xi32> %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %107 = arith.muli %106, %cst_11 : tensor<1x64xi32> %108 = tt.addptr %83, %107 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> %109 = tt.broadcast %108 : tensor<1x64x!tt.ptr> -> tensor<64x64x!tt.ptr> %110 = tt.addptr %109, %87 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %111 = tt.expand_dims %105 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %112 = arith.muli %111, %cst_14 : tensor<64x1xi32> %113 = tt.addptr %91, %112 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> %114 = tt.broadcast %113 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> %115 = tt.addptr %114, %26 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %116 = arith.muli %103, %c2_i32 : i32 %117 = arith.minsi %116, %c1_i32 : i32 %118 = arith.cmpi sge, %117, %c1_i32 : i32 %119:2 = scf.if %118 -> (tensor<64x64xf32>, tensor<64x64xf32>) { %120 = arith.subi %117, %c1_i32 : i32 %121:5 = scf.for %arg20 = %c0_i32 to %120 step %c1_i32 iter_args(%arg21 = %98#0, %arg22 = %98#1, %arg23 = %110, %arg24 = %115, %arg25 = %105) -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64xi32>) : i32 { %161 = tt.expand_dims %arg25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %162 = arith.cmpi slt, %161, %cst_1 : tensor<1x64xi32> %163 = tt.broadcast %162 : tensor<1x64xi1> -> tensor<64x64xi1> %164 = tt.load %arg23, %163, %cst_8 : tensor<64x64x!tt.ptr> %165 = arith.cmpi slt, %arg25, %cst_17 : tensor<64xi32> %166 = tt.splat %67 : !tt.ptr -> tensor<64x!tt.ptr> %167 = tt.addptr %166, %arg25 : tensor<64x!tt.ptr>, tensor<64xi32> %168 = tt.load %167, %165 : tensor<64x!tt.ptr> %169 = arith.cmpf oeq, %168, %cst_16 : tensor<64xf32> %170 = arith.select %169, %cst_15, %168 : tensor<64xi1>, tensor<64xf32> %171 = tt.dot %30, %164, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %172 = arith.mulf %171, %cst_13 : tensor<64x64xf32> %173 = arith.mulf %172, %cst_3 : tensor<64x64xf32> %174 = arith.mulf %173, %cst_2 : tensor<64x64xf32> %175 = tt.expand_dims %170 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %176 = tt.broadcast %175 : tensor<1x64xf32> -> tensor<64x64xf32> %177 = arith.subf %174, %176 : tensor<64x64xf32> %178 = math.exp2 %177 : tensor<64x64xf32> %179 = tt.expand_dims %arg25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %180 = arith.cmpi slt, %179, %cst_12 : tensor<64x1xi32> %181 = tt.broadcast %180 : tensor<64x1xi1> -> tensor<64x64xi1> %182 = tt.load %arg24, %181, %cst_8 : tensor<64x64x!tt.ptr> %183 = arith.truncf %178 : tensor<64x64xf32> to tensor<64x64xf16> %184 = tt.dot %183, %182, %arg22, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %185 = tt.splat %68 : !tt.ptr -> tensor<64x!tt.ptr> %186 = tt.addptr %185, %arg25 : tensor<64x!tt.ptr>, tensor<64xi32> %187 = tt.load %186, %165 : tensor<64x!tt.ptr> %188 = tt.trans %182 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %189 = tt.dot %35, %188, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %190 = tt.expand_dims %187 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %191 = tt.broadcast %190 : tensor<1x64xf32> -> tensor<64x64xf32> %192 = arith.subf %189, %191 : tensor<64x64xf32> %193 = arith.mulf %178, %192 : tensor<64x64xf32> %194 = arith.mulf %193, %cst_3 : tensor<64x64xf32> %195 = arith.truncf %194 : tensor<64x64xf32> to tensor<64x64xf16> %196 = tt.trans %164 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %197 = tt.dot %195, %196, %arg21, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %198 = arith.divsi %arg20, %c2_i32 : i32 %199 = tt.addptr %99, %198 : !tt.ptr, i32 %200 = tt.load %199 evictionPolicy = evict_last : !tt.ptr %201 = arith.addi %198, %c1_i32 : i32 %202 = arith.cmpi slt, %201, %103 : i32 %203 = tt.addptr %199, %c1_i32 : !tt.ptr, i32 %204 = tt.load %203, %202 evictionPolicy = evict_last : !tt.ptr %205 = arith.addi %arg20, %c1_i32 : i32 %206 = arith.remsi %205, %c2_i32 : i32 %207 = arith.cmpi eq, %206, %c0_i32 : i32 %208 = arith.subi %204, %200 : i32 %209 = arith.muli %208, %c128_i32 : i32 %210 = arith.subi %209, %c64_i32 : i32 %211 = arith.extui %207 : i1 to i32 %212 = arith.muli %210, %211 : i32 %213 = arith.subi %c1_i32, %211 : i32 %214 = arith.muli %213, %c64_i32 : i32 %215 = arith.addi %212, %214 : i32 %216 = arith.muli %215, %c64_i32 : i32 %217 = tt.splat %216 : i32 -> tensor<64x64xi32> %218 = tt.addptr %arg23, %217 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %219 = tt.addptr %arg24, %217 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %220 = tt.splat %215 : i32 -> tensor<64xi32> %221 = arith.addi %arg25, %220 : tensor<64xi32> scf.yield %197, %184, %218, %219, %221 : tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64xi32> } %122 = tt.expand_dims %121#4 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> %123 = arith.cmpi slt, %122, %cst_1 : tensor<1x64xi32> %124 = tt.broadcast %123 : tensor<1x64xi1> -> tensor<64x64xi1> %125 = tt.load %121#2, %124, %cst_8 : tensor<64x64x!tt.ptr> %126 = arith.cmpi slt, %121#4, %cst_17 : tensor<64xi32> %127 = tt.splat %67 : !tt.ptr -> tensor<64x!tt.ptr> %128 = tt.addptr %127, %121#4 : tensor<64x!tt.ptr>, tensor<64xi32> %129 = tt.load %128, %126 : tensor<64x!tt.ptr> %130 = arith.cmpf oeq, %129, %cst_16 : tensor<64xf32> %131 = arith.select %130, %cst_15, %129 : tensor<64xi1>, tensor<64xf32> %132 = tt.dot %30, %125, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %133 = arith.mulf %132, %cst_13 : tensor<64x64xf32> %134 = arith.mulf %133, %cst_3 : tensor<64x64xf32> %135 = arith.select %29, %134, %cst_6 : tensor<64x64xi1>, tensor<64x64xf32> %136 = arith.mulf %135, %cst_2 : tensor<64x64xf32> %137 = tt.expand_dims %131 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %138 = tt.broadcast %137 : tensor<1x64xf32> -> tensor<64x64xf32> %139 = arith.subf %136, %138 : tensor<64x64xf32> %140 = math.exp2 %139 : tensor<64x64xf32> %141 = tt.expand_dims %121#4 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> %142 = arith.cmpi slt, %141, %cst_12 : tensor<64x1xi32> %143 = tt.broadcast %142 : tensor<64x1xi1> -> tensor<64x64xi1> %144 = tt.load %121#3, %143, %cst_8 : tensor<64x64x!tt.ptr> %145 = arith.truncf %140 : tensor<64x64xf32> to tensor<64x64xf16> %146 = tt.dot %145, %144, %121#1, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %147 = tt.splat %68 : !tt.ptr -> tensor<64x!tt.ptr> %148 = tt.addptr %147, %121#4 : tensor<64x!tt.ptr>, tensor<64xi32> %149 = tt.load %148, %126 : tensor<64x!tt.ptr> %150 = tt.trans %144 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %151 = tt.dot %35, %150, %cst_9, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> %152 = tt.expand_dims %149 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> %153 = tt.broadcast %152 : tensor<1x64xf32> -> tensor<64x64xf32> %154 = arith.subf %151, %153 : tensor<64x64xf32> %155 = arith.mulf %140, %154 : tensor<64x64xf32> %156 = arith.mulf %155, %cst_3 : tensor<64x64xf32> %157 = arith.select %29, %156, %cst_9 : tensor<64x64xi1>, tensor<64x64xf32> %158 = arith.truncf %157 : tensor<64x64xf32> to tensor<64x64xf16> %159 = tt.trans %125 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> %160 = tt.dot %158, %159, %121#0, inputPrecision = tf32 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> scf.yield %160, %146 : tensor<64x64xf32>, tensor<64x64xf32> } else { scf.yield %98#0, %98#1 : tensor<64x64xf32>, tensor<64x64xf32> } scf.yield %119#0, %119#1 : tensor<64x64xf32>, tensor<64x64xf32> } %37 = tt.splat %13 : !tt.ptr -> tensor<64x1x!tt.ptr> %38 = tt.addptr %37, %21 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> %39 = tt.broadcast %38 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> %40 = tt.addptr %39, %26 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %41 = arith.cmpi slt, %24, %cst_11 : tensor<1x64xi32> %42 = tt.broadcast %41 : tensor<1x64xi1> -> tensor<64x64xi1> %43 = arith.andi %29, %42 : tensor<64x64xi1> %44 = arith.truncf %36#1 : tensor<64x64xf32> to tensor<64x64xf16> tt.store %40, %44, %43 : tensor<64x64x!tt.ptr> %45 = arith.mulf %36#0, %cst_13 : tensor<64x64xf32> %46 = tt.broadcast %21 : tensor<64x1xi32> -> tensor<64x64xi32> %47 = arith.addi %26, %46 : tensor<64x64xi32> %48 = tt.splat %4 : i32 -> tensor<64x64xi32> %49 = arith.addi %47, %48 : tensor<64x64xi32> %50 = tt.splat %8 : i32 -> tensor<64x64xi32> %51 = arith.addi %49, %50 : tensor<64x64xi32> %52 = tt.splat %arg16 : !tt.ptr -> tensor<64x64x!tt.ptr> %53 = tt.addptr %52, %51 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> %54 = arith.truncf %45 : tensor<64x64xf32> to tensor<64x64xf16> tt.store %53, %54, %29 : tensor<64x64x!tt.ptr> } tt.return } } {-# external_resources: { mlir_reproducer: { pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=4 target=cuda:100 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, tritongpu-optimize-accumulator-init, tritongpu-hoist-tmem-alloc, tritongpu-promote-lhs-to-tmem, tritongpu-assign-latencies{num-stages=3}, tritongpu-schedule-loops, tritongpu-automatic-warp-specialization{num-stages=3}, tritongpu-pipeline{dump-intermediate-steps=false num-stages=3}, tritongpu-combine-tensor-select-and-if, triton-nvidia-gpu-remove-tmem-tokens, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, triton-nvidia-tma-lowering, triton-nvidia-gpu-fence-insertion{compute-capability=90}, sccp, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})", disable_threading: false, verify_each: true } } #-} /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: error: Failures have been detected while processing an MLIR pass pipeline /tmp/tmp0yiz3c94/p4/cp4ahrfnz4obsvzgftux7dg3aszopks2jljnoaz3eowlooi2scem.py:18:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` Triton compilation failed: triton_tem_fused_zeros_1 def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): PRESCALE_QK : tl.constexpr = False ROWS_GUARANTEED_SAFE : tl.constexpr = False BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False WRITE_DQ : tl.constexpr = True OUTPUT_LOGSUMEXP : tl.constexpr = True FLOAT32_PRECISION : tl.constexpr = 'tf32' IS_DIVISIBLE : tl.constexpr = False SM_SCALE : tl.constexpr = 0.125 GQA_SHARED_HEADS : tl.constexpr = 4 HAS_FULL_BLOCKS : tl.constexpr = True QK_HEAD_DIM : tl.constexpr = 64 QK_HEAD_DIM_ROUNDED : tl.constexpr = 64 V_HEAD_DIM : tl.constexpr = 64 V_HEAD_DIM_ROUNDED : tl.constexpr = 64 SAFE_HEAD_DIM : tl.constexpr = True BLOCK_M1 : tl.constexpr = 64 BLOCK_N1 : tl.constexpr = 64 BLOCK_M2 : tl.constexpr = 64 BLOCK_N2 : tl.constexpr = 64 SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 Q = arg_Q K = arg_K V = arg_V LSE = arg_LSE DELTA = arg_DELTA DO = arg_DO DQ = arg_DQ DV = arg_DV KV_NUM_BLKS = arg_KV_NUM_BLKS KV_IDX = arg_KV_IDX Q_NUM_BLKS = arg_Q_NUM_BLKS Q_IDX = arg_Q_IDX FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS FULL_KV_IDX = arg_FULL_KV_IDX FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS FULL_Q_IDX = arg_FULL_Q_IDX # Sub notation for this kernel: # # Q: Query, K: Key, V: Value # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) # DELTA: Precomputed sum(OUT*DO, axis=-1) # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value # DK: Derivative of Key, is the written to via the store_output call due to some limitations with # inductor codegen # M: Number of queries, N: Number of keys/values # QK_HEAD_DIM: The dimension of the query and key embeddings # V_HEAD_DIM: The dimension of the value embeddings # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. # (Modifiable) Performance tuning options # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. # # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. # The below are kernel options that can be applied for certain score_mods, # or involve a numerics vs. perf tradeoff # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has # about 20% more numerical error, but slightly faster. # Define strides of inputs stride_qz, stride_qh, stride_qm, stride_qd = 32768, 2048, 64, 1 stride_kz, stride_kh, stride_kn, stride_kd = 65536, 16384, 64, 1 stride_vz, stride_vh, stride_vn, stride_vd = 65536, 16384, 64, 1 stride_doz, stride_doh, stride_dom, stride_dod = 32768, 2048, 64, 1 stride_dqz, stride_dqh, stride_dqm, stride_dqd = 32768, 2048, 64, 1 stride_dvz, stride_dvh, stride_dvm, stride_dvd = 65536, 16384, 64, 1 ZQ = 2 HQ = 16 HKV = 4 Q_LEN = 32 ZKV = 2 KV_LEN = 256 MATMUL_PRECISION = Q.dtype.element_ty pid = tl.program_id(0) NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) off_zq = tl.program_id(1) # q batch idx off_hkv = tl.program_id(2) # kv head idx off_zkv = off_zq % ZKV # kv batch idx SPARSE_Z = 2 SPARSE_HQ = 16 sparse_idx_z = off_zq % SPARSE_Z k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) # offset K, V, DV pointers for batch/kv-head K += k_adj V += v_adj DV += dv_adj RCP_LN2 = 1.44269504 offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) if pid >= NUM_KV_BLOCKS: off_pid = pid - NUM_KV_BLOCKS # THIS BLOCK DOES DQ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS start_m2_block = off_pid % NUM_Q_BLOCKS off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE stride_kv_num_blks_h = 1 stride_kv_idx_h = 2 stride_kv_idx_m = 2 sparse_idx_hq2 = off_hq2 % SPARSE_HQ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) Q2 = Q + q_adj2 DO2 = DO + do_adj2 # TODO: This does not work if DQ is not the same layout as Q (for example, # if Q is broadcasted) DQ2 = DQ + dq_adj2 LSE2 = LSE + off_chz2 DELTA2 = DELTA + off_chz2 # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) start_m2 = start_m2_block * BLOCK_M2 offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) # load Q and do: they stay in SRAM throughout the inner loop. q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) if PRESCALE_QK: q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) if IS_DIVISIBLE: Di = tl.load(DELTA2 + offs_m2) lse = tl.load(LSE2 + offs_m2) else: Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) lse = tl.where(lse == -float("inf"), 0.0, lse) lse = lse[:, None] # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # KV_IDX and KV_NUM_BLKS are always contiguous. kv_indices = KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) offs_n2 = kv_start + tl.arange(0, BLOCK_N2) dq = bwd_dq_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, K, V, dq, q, do, Di, lse, off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=False, ) if HAS_FULL_BLOCKS: # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. kv_indices = FULL_KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) offs_n2 = kv_start + tl.arange(0, BLOCK_N2) dq = bwd_dq_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, K, V, dq, q, do, Di, lse, off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=True, ) # Write back dQ. dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd dq *= SM_SCALE if IS_DIVISIBLE and SAFE_HEAD_DIM: tl.store(dq_ptrs, dq) else: tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) else: # THIS BLOCK DOES DK & DV SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) pid_mask = pid // SPARSE_KV_MULTIPLE stride_q_num_blks_h = 2 stride_q_idx_h = 2 stride_q_idx_n = 1 dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) start_n1 = pid * BLOCK_N1 offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) # load K and V: they stay in SRAM throughout the inner loop. k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) if PRESCALE_QK: k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) for off_g in range(0, GQA_SHARED_HEADS): off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) Q1 = Q + q_adj1 DO1 = DO + do_adj1 # TODO: This does not work if DQ is not the same layout as Q (for example, # if Q is broadcasted) LSE1 = LSE + off_chz1 DELTA1 = DELTA + off_chz1 sparse_idx_hq1 = off_hq1 % SPARSE_HQ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Q_IDX and Q_NUM_BLKS are always contiguous. q_indices = Q_IDX + sparse_q_idx_offset q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) offs_m1 = q_start + tl.arange(0, BLOCK_M1) dk, dv = bwd_dkdv_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=False, ) if HAS_FULL_BLOCKS: # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. q_indices = FULL_Q_IDX + sparse_q_idx_offset q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) offs_m1 = q_start + tl.arange(0, BLOCK_M1) dk, dv = bwd_dkdv_inner( arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, IS_FULL_BLOCKS=True, ) # Write back dV and dK. dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd index_n = offs_n1[:, None] index_k = offs_k[None, :] index_v = offs_v[None, :] if IS_DIVISIBLE and SAFE_HEAD_DIM: tl.store(dv_ptrs, dv) else: tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) dk *= SM_SCALE if SAFE_HEAD_DIM: mask = index_n < KV_LEN else: mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] xindex = index_k + 64*index_n + 16384*off_hkv + 65536*off_zq tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) metadata: {'signature': {'arg_Q': '*fp16', 'arg_K': '*fp16', 'arg_V': '*fp16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*fp16', 'arg_DQ': '*fp16', 'arg_DV': '*fp16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*fp16'}, 'device': 0, 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}], 'device_type': 'cuda', 'num_warps': 4, 'num_stages': 3, 'debug': True, 'cc': 100} Traceback (most recent call last): File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 748, in _precompile_config binary = triton.compile(*compile_args, **compile_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/compiler/compiler.py", line 359, in compile next_module = compile_ir(module, metadata) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 456, in stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir pm.run(mod) RuntimeError: PassManager::run failed frames [('total', 3), ('ok', 3)] inline_call [] stats [('calls_captured', 8), ('unique_graphs', 3)] aot_autograd [('total', 1), ('autograd_cache_miss', 1), ('ok', 1)] inductor [('triton_bundler_save_kernel', 8), ('async_compile_cache_miss', 3), ('fxgraph_cache_miss', 1), ('triton_bundler_save_static_autotuner', 1), ('fxgraph_cache_bypass', 1)] graph_break [] F ==================================================== FAILURES ===================================================== _____________________________ TestFlexAttentionCUDA.test_GQA_score_mod1_cuda_float16 ______________________________ Traceback (most recent call last): File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 58, in testPartExecutor yield File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 634, in run self._callTestMethod(testMethod) File "/home/drisspg/.conda/envs/dev/lib/python3.12/unittest/case.py", line 589, in _callTestMethod if method() is not None: ^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper method(*args, **kwargs) File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper method(*args, **kwargs) File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 446, in instantiated_test raise rte File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 426, in instantiated_test result = test(self, **param_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 1349, in dep_fn return fn(self, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/testing/_internal/common_device_type.py", line 1215, in dep_fn return fn(slf, *args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/test/inductor/test_flex_attention.py", line 1430, in test_GQA self.run_test(*inputs) File "/home/drisspg/meta/pytorch/test/inductor/test_flex_attention.py", line 566, in run_test compiled_out.backward(backward_grad) File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 625, in backward torch.autograd.backward( File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 354, in backward _engine_run_backward( File "/home/drisspg/meta/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/autograd/function.py", line 315, in apply return user_fn(self, *args) ^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2303, in backward return impl_fn() ^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2289, in impl_fn out = CompiledFunction._backward_impl(ctx, all_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2394, in _backward_impl CompiledFunction.compiled_bw = aot_config.bw_compiler( ^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/schemas.py", line 1256, in __call__ return self.compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_dynamo/backends/common.py", line 76, in _wrapped_bw_compiler disable( File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 1005, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_utils_internal.py", line 92, in wrapper_function return function(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 2428, in bw_compiler return inner_compile( ^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 773, in compile_fx_inner return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper inner_compiled_fn = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 952, in _compile_fx_inner mb_compiled_graph = fx_codegen_and_compile( ^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1652, in fx_codegen_and_compile return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1506, in codegen_and_compile compiled_module = graph.compile_to_module() ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2318, in compile_to_module return self._compile_to_module() ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2328, in _compile_to_module mod = self._compile_to_module_lines(wrapper_code) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 2396, in _compile_to_module_lines mod = PyCodeCache.load_by_key_path( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/codecache.py", line 3466, in load_by_key_path mod = _reload_python_module(key, path, set_sys_modules=in_toplevel) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/compile_tasks.py", line 33, in _reload_python_module exec(code, mod.__dict__, mod.__dict__) File "/tmp/tmp0yiz3c94/az/caza2gzmsagyuusmf2ka3oat3na4xv6zudssk244xmlzsbv2knze.py", line 117, in File "/home/drisspg/meta/pytorch/torch/_inductor/async_compile.py", line 489, in triton kernel.precompile( File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 437, in precompile self._precompile_worker() File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 459, in _precompile_worker compile_results.append(self._precompile_config(c)) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 748, in _precompile_config binary = triton.compile(*compile_args, **compile_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/compiler/compiler.py", line 359, in compile next_module = compile_ir(module, metadata) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 456, in stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/drisspg/.conda/envs/dev/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 298, in make_ttgir pm.run(mod) RuntimeError: PassManager::run failed To execute this test, run the following from the base repo dir: python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_score_mod1_cuda_float16 This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ============================================= short test summary info ============================================= FAILED [5.1441s] test/inductor/test_flex_attention.py::TestFlexAttentionCUDA::test_GQA_score_mod1_cuda_float16 - RuntimeError: PassManager::run failed ================================== 1 failed, 1 passed, 1404 deselected in 18.10s ================================== ~/meta/pytorch flex-warning !1 ❯ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160227 Approved by: https://github.com/Skylion007, https://github.com/Chillee --- torch/_inductor/kernel/flex/flex_attention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index b6f5646bb57cb..429f8d05c8cd5 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -361,7 +361,6 @@ def flex_attention( score_mod_other_buffers, mask_mod_other_buffers, ) - # below is cuda path if device is not cpu # tl.dot does not support embedding size less than 16 small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16)) @@ -1138,7 +1137,7 @@ def bwd_dq_block_mn( # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ if WRITE_DQ: - scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) {{ modification( subgraph_number=3, output_name=None, @@ -1341,7 +1340,7 @@ def bwd_dkdv_block_mn( idx_h = off_hq idx_m = m idx_n = n - scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) {{ modification( subgraph_number=3, output_name=None, From e63c2b21c186a7d2ab8a8953b8aa1535f2e96e58 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 11 Aug 2025 10:59:16 -0700 Subject: [PATCH 0235/1424] [PP] Initialize P2P communicators on first step (#160210) Was hitting hangs in multi-node settings and initializing the NCCL communicators needed for batch p2p ops ahead of time fixes this. This change adds extra communication since it communicates a dummy tensor to next and previous stage ranks. However, this is only paid on the first step so it is negligible. Debug history: https://docs.google.com/document/d/1EKVJYmW2hj_VsvDvnSggXhZzJyvMu9dA0iDJWOZAtjY/edit?tab=t.0 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160210 Approved by: https://github.com/wconstab --- torch/distributed/pipelining/schedules.py | 15 +++++++ torch/distributed/pipelining/stage.py | 54 +++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index d0133ae1f19b1..1c0f4d27a638e 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -554,6 +554,13 @@ def __init__( ) def _initialize_stage(self, args, kwargs): + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + all_ops.extend(self._stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) if self._has_backward: self._stage._prepare_backward_infra(self._n_microbatches) @@ -1428,6 +1435,14 @@ def __init__( ) def _initialize_stages(self, args: tuple[Any, ...], kwargs): + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + for stage in self._stages: + all_ops.extend(stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) # or real value (if this stage and next stage are on the same device) next_stage_args: tuple[Any, ...] = tuple() diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index e4de0ddd03ab5..c1abebde5b853 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -935,6 +935,60 @@ def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs ) + def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: + """ + Get the operations to initialize the p2p communicators between previous and next stages. + This is done so by creating a dummy tensor and sending it to the next stage and receiving + from the previous stage. + """ + ops: list[dist.P2POp] = [] + next_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index + 1) + prev_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index - 1) + + recv_tensor = torch.zeros(1, device=self.device) + send_tensor = torch.tensor(self.stage_index, device=self.device) + # forward + if not self.is_first: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + # backward + if not self.is_first: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + return ops + class _PipelineStage(_PipelineStageBase): def __init__( From ee89cc7a0acd69de25f98fe4ef828546db7b444c Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 12 Aug 2025 00:18:15 +0000 Subject: [PATCH 0236/1424] [ROCm][Windows] Fix LoadHIP handling of environment variable paths on Windows. (#159080) See https://cmake.org/cmake/help/latest/command/file.html#path-conversion. Paths stored in environment variables may use `/` or `\` (e.g. on Windows), while cmake-style paths always use `/`. This fixes configure errors like: ``` CMake Error at D:/b/pytorch_main/build/CMakeFiles/CMakeScratch/TryCompile-srhq07/CMakeLists.txt:2 (set): Syntax error in cmake code at D:/b/pytorch_main/build/CMakeFiles/CMakeScratch/TryCompile-srhq07/CMakeLists.txt:2 when parsing string D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\_rocm_sdk_devel/cmake/;D:/b/pytorch_main/cmake/Modules Invalid character escape '\p'. CMake Error at D:/projects/TheRock/external-builds/pytorch/.venv/Lib/site-packages/cmake/data/share/cmake-3.31/Modules/Internal/CheckSourceCompiles.cmake:108 (try_compile): Failed to configure test project build system. ``` (note the mixed usage of `\` and `/` in that string) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159080 Approved by: https://github.com/jeffdaily --- cmake/public/LoadHIP.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 132f9670ff34f..018bca837a5a8 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -6,7 +6,7 @@ set(PYTORCH_FOUND_HIP FALSE) # In the latter case, if /opt/rocm does not exist emit status # message and return. if(DEFINED ENV{ROCM_PATH}) - set(ROCM_PATH $ENV{ROCM_PATH}) + file(TO_CMAKE_PATH "$ENV{ROCM_PATH}" ROCM_PATH) if(NOT EXISTS ${ROCM_PATH}) message(FATAL_ERROR "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n" @@ -31,7 +31,7 @@ if(NOT DEFINED ENV{MAGMA_HOME}) set(MAGMA_HOME ${ROCM_PATH}/magma) set(ENV{MAGMA_HOME} ${ROCM_PATH}/magma) else() - set(MAGMA_HOME $ENV{MAGMA_HOME}) + file(TO_CMAKE_PATH "$ENV{MAGMA_HOME}" MAGMA_HOME) endif() # MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different From cae2b5e3d223829bdc553fc8601df4b1c1554cff Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 12 Aug 2025 01:28:17 +0000 Subject: [PATCH 0237/1424] [ROCm][Windows] Enable USE_ROCM, disable USE_RCCL on Windows. (#159079) This allows setting `USE_ROCM` on Windows. A few other patches are still required to build (see https://github.com/ROCm/TheRock/issues/589), but we have instructions using open source code and rocm python packages available at https://github.com/ROCm/TheRock/tree/main/external-builds/pytorch#build-pytorch-with-rocm-support. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159079 Approved by: https://github.com/jeffdaily --- CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 48b9e2e8df3eb..cc9476bb001ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,7 +239,7 @@ option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) -cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF) +cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX OR WIN32" OFF) cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF) option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF) option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF) @@ -267,6 +267,7 @@ cmake_dependent_option(USE_NCCL "Use NCCL" ON cmake_dependent_option(USE_XCCL "Use XCCL" ON "USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) +cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL" OFF) From 0d40ff3b496e68193bc16d5391fa2e3623709f81 Mon Sep 17 00:00:00 2001 From: "Han, Xu" Date: Tue, 12 Aug 2025 01:35:39 +0000 Subject: [PATCH 0238/1424] [inductor] fix test_different_file_paths_local_pgo on Windows. (#160382) fix test_different_file_paths_local_pgo on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160382 Approved by: https://github.com/angelayi --- test/dynamo/test_pgo.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py index 643d15eb2413d..623143ae4dcb5 100644 --- a/test/dynamo/test_pgo.py +++ b/test/dynamo/test_pgo.py @@ -14,6 +14,7 @@ from torch._dynamo.testing import CompileCounter from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.utils import clear_caches, fresh_cache +from torch.testing._internal.common_utils import IS_WINDOWS class PgoTest(torch._dynamo.test_case.TestCase): @@ -349,7 +350,11 @@ def write_load_and_run(path): write_load_and_run(path1) self.assertEqual(cnts.frame_count, 2) state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) - self.assertTrue("hash(390fe689)" in state) + + # Windows can't create unification temp path: + # hash(a18a3259)C:/Users/Xuhan/AppData/Local/Temp/tmpx3hfkuqa/example.py + # Skip hash check + self.assertTrue("hash" if IS_WINDOWS else "hash(390fe689)" in state) self.assertTrue("/example.py:4:func:" in state) self.assertTrue(" L['x']: tensor size=[?] stride=[1]" in state) # We should compile this only once due to PGO. From b90feeac86bda00afc2789321bcd706015ff44e3 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Sun, 10 Aug 2025 20:37:44 -0700 Subject: [PATCH 0239/1424] [BE][cutlass backend] Fix subproc addmm tests (#160295) Differential Revision: [D79977421](https://our.internmc.facebook.com/intern/diff/D79977421/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160295 Approved by: https://github.com/jingsh --- test/inductor/test_cutlass_backend.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 2a944e4046696..8b0712dc810a9 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -294,20 +294,19 @@ def test_cutlass_backend_subproc_mm(self): Y = torch.mm(a, b) torch.testing.assert_close(Y_compiled, Y) - @unittest.skipIf( - True, "FIXME: Disabled temporarily since IMA or crashing in subprocess" - ) @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - def test_cutlass_backend_subproc_addmm(self, shape_combo): + @parametrize("dtype", (torch.float16, torch.bfloat16)) + def test_cutlass_backend_subproc_addmm(self, dtype): """ Test autotune_in_subproc works for addmm. """ M, N, K = 4096, 2048, 25728 + dtype = torch.float16 - a = torch.randn(M, K).cuda().half() - b = torch.randn(N, K).cuda().half().t() + a = torch.randn(M, K, dtype=dtype).cuda() + b = torch.randn(N, K, dtype=dtype).cuda().t() x_shapes = [ (M, N), @@ -329,7 +328,10 @@ def test_cutlass_backend_subproc_addmm(self, shape_combo): } ): for x_shape in x_shapes: - x = torch.randn(x_shape).cuda().half() + torch._dynamo.reset() + clear_caches() + + x = torch.randn(x_shape).cuda().to(dtype) Y_compiled = torch.compile(torch.addmm)(x, a, b, alpha=alpha, beta=beta) Y = torch.addmm(x, a, b, alpha=alpha, beta=beta) torch.testing.assert_close(Y_compiled, Y) From f3f159ff8c4bad2edec99c68a941c628e983d04c Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Sun, 10 Aug 2025 21:38:15 -0700 Subject: [PATCH 0240/1424] [BE][cutlass backend] Reduce severity of log message for no cutlass config found (#160148) This is not really a problem. Sometimes we cannot find a cutlass config due to shape, e.g. when k is odd. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160148 Approved by: https://github.com/mlazos, https://github.com/Skylion007 --- torch/_inductor/codegen/cuda/gemm_template.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index e74161deeb141..0e11bc100002e 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -594,11 +594,14 @@ def _add_cutlass_gemm_choices( ) if len(ops) == 0: - input_layouts = [node.get_layout() for node in input_nodes] - input_strides = [node.get_stride() for node in input_nodes] - output_layout = layout - warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 - log.warning(warning_msg) + log.info( + "No suitable Cutlass GEMM configs found, fallbacks used " + "( len(ops)=%d, output_layout=%s, input_layouts=%s, input_strides=%s )", + len(ops), + layout, + [node.get_layout() for node in input_nodes], + [node.get_stride() for node in input_nodes], + ) log.debug( "Added %d Cutlass gemm configs.", len(ops), From 7a974a88f2c529a614baeabe4debd00fc8a3b299 Mon Sep 17 00:00:00 2001 From: Ramya Ramineni <62723901+rraminen@users.noreply.github.com> Date: Tue, 12 Aug 2025 01:57:58 +0000 Subject: [PATCH 0241/1424] [ROCm] Fix resource_strings.h (#159996) This PR fixes the errors like below: ``` [rank7]: RuntimeError: /tmp/comgr-c3c81b/input/CompileSourceejOPx6:34:8: error: unknown type name 'uint64_t'; did you mean '__hip_internal::uint64_t'? [rank7]: 34 | if(((uint64_t) t0.data) % (4 * sizeof(half)) != 0) flag_vec4 = false; ``` The following datatypes needs to be defined in `torch/csrc/jit/codegen/fuser/cuda/resource_strings.h` for ROCm versions >= 7.0. ``` typedef unsigned char uint8_t; typedef signed char int8_t; typedef short int int16_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159996 Approved by: https://github.com/pruthvistony, https://github.com/Skylion007, https://github.com/jeffdaily --- torch/csrc/jit/codegen/fuser/cuda/resource_strings.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index ff2ef1f2377ce..9728d27d4d79b 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -12,7 +12,7 @@ cases*/ static constexpr auto bfloat16_type_string = "__nv_bfloat16"; -#if defined(USE_ROCM) +#if defined(USE_ROCM) && ROCM_VERSION < 70000 static auto type_declarations_template = at::jit::CodeTemplate(R"( ${HalfHeader} ${BFloat16Header} From 95210cc409dd578988c7116b47725c304dea54c7 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Tue, 12 Aug 2025 01:58:44 +0000 Subject: [PATCH 0242/1424] [BE] Isolate pre-push hook dependencies in dedicated virtual environment (#160048) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds two changes: - Isolates pre-push hook dependencies into an isolated venv, no longer affect your system environment - Lets you manually run the pre-push lintrunner (including with lintrunner -a) by invoking `python scripts/lintrunner.py [-a]` (it's ugly, but better than nothing...for now) This is a follow up to: - https://github.com/pytorch/pytorch/pull/158389 ## Problem The current pre-push hook setup installs lintrunner and related dependencies globally, which makes developers nervous about system pollution and can cause version conflicts with existing installations. Also, if the pre-push lintrunner found errors, you had to hope your normal lintrunner could fix them (which wasn't always the case, e.g. if those errors only manifested in certain python versions) ## Key Changes: - Isolated Environment: Creates .git/hooks/linter/.venv/ with Python 3.9 (the python used in CI) and an isolated lintrunner installation - User-Friendly CLI: New python scripts/lintrunner.py wrapper allows developers to run lintrunner (including -a auto-fix) from any environment - Simplified Architecture: Eliminates pre-commit dependency entirely - uses direct git hooks File Changes: - scripts/setup_hooks.py: Rewritten to create isolated uv-managed virtual environment - scripts/lintrunner.py: New wrapper script with shared hash management logic - scripts/run_lintrunner.py: Removed (functionality merged into lintrunner.py) - .pre-commit-config.yaml: Removed (no longer needed) ## Usage: ``` # Setup (run once) python scripts/setup_hooks.py # Manual linting (works from any environment) python scripts/lintrunner.py # Check mode python scripts/lintrunner.py -a # Auto-fix mode # Git hooks work automatically git push # Runs lintrunner in isolated environment # Need to skip the pre-push hook? git push --no-verify ``` ## Benefits: - ✅ Zero global dependency installation - ✅ Per-repository isolation prevents version conflicts - ✅ Full lintrunner functionality is now accessible ## Implementation Notes: - Virtual env is kept in a dedicated dir in .git, to keep per-repo mechanics - lintrunner.py does not need to be invoked from a specific venv. It'll invoke the right venv itself. A minor bug: It tends to garble the lintrunner output a bit, like the screenshot below shows, but I haven't found a workaround so far and it remains understandable to users: image ## What's next? Features that could be added: - Check for lintrunner updates, auto-update if needed - Depending on dev response, this could be enabled by default for all pytorch/pytorch environments Pull Request resolved: https://github.com/pytorch/pytorch/pull/160048 Approved by: https://github.com/seemethere --- .pre-commit-config.yaml | 12 --- scripts/lintrunner.py | 181 ++++++++++++++++++++++++++++++++++++++ scripts/run_lintrunner.py | 110 ----------------------- scripts/setup_hooks.py | 153 +++++++++++++++----------------- 4 files changed, 250 insertions(+), 206 deletions(-) delete mode 100644 .pre-commit-config.yaml create mode 100644 scripts/lintrunner.py delete mode 100644 scripts/run_lintrunner.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 2c67fb1981b71..0000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,12 +0,0 @@ -repos: - - repo: local - hooks: - - id: lintrunner - name: Run Lintrunner in an isolated venv before every push. The first run may be slow... - entry: python scripts/run_lintrunner.py # wrapper below - language: python # pre‑commit manages venv for the wrapper - additional_dependencies: [] # wrapper handles lintrunner install - always_run: true - stages: [pre-push] # fire only on pre‑push - pass_filenames: false # Lintrunner gets no per‑file args - verbose: true # stream output as it is produced...allegedly anyways diff --git a/scripts/lintrunner.py b/scripts/lintrunner.py new file mode 100644 index 0000000000000..2e3ad2bc219ab --- /dev/null +++ b/scripts/lintrunner.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +Wrapper script to run the isolated hook version of lintrunner. + +This allows developers to easily run lintrunner (including with -a for auto-fixes) +using the same isolated environment that the pre-push hook uses, without having +to manually activate/deactivate virtual environments. + +Usage: + python scripts/lintrunner.py # Check mode (same as git push) + python scripts/lintrunner.py -a # Auto-fix mode + python scripts/lintrunner.py --help # Show lintrunner help + +This module also provides shared functionality for lintrunner hash management. +""" + +from __future__ import annotations + +import hashlib +import os +import shlex +import shutil +import subprocess +import sys +from pathlib import Path + + +def find_repo_root() -> Path: + """Find repository root using git.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + capture_output=True, + text=True, + check=True, + ) + return Path(result.stdout.strip()) + except subprocess.CalledProcessError: + sys.exit("❌ Not in a git repository") + + +def compute_file_hash(path: Path) -> str: + """Returns SHA256 hash of a file's contents.""" + hasher = hashlib.sha256() + with path.open("rb") as f: + while chunk := f.read(8192): + hasher.update(chunk) + return hasher.hexdigest() + + +def read_stored_hash(path: Path) -> str | None: + if not path.exists(): + return None + try: + return path.read_text().strip() + except Exception: + return None + + +# Venv location - change this if the path changes +HOOK_VENV_PATH = ".git/hooks/linter/.venv" + + +def get_hook_venv_path() -> Path: + """Get the path to the hook virtual environment.""" + repo_root = find_repo_root() + return repo_root / HOOK_VENV_PATH + + +def find_hook_venv() -> Path: + """Locate the isolated hook virtual environment.""" + venv_dir = get_hook_venv_path() + + if not venv_dir.exists(): + sys.exit( + f"❌ Hook virtual environment not found at {venv_dir}\n" + " Please set this up by running: python scripts/setup_hooks.py" + ) + + return venv_dir + + +def check_lintrunner_installed(venv_dir: Path) -> None: + """Check if lintrunner is installed in the given venv, exit if not.""" + result = subprocess.run( + [ + "uv", + "pip", + "show", + "--python", + str(venv_dir / "bin" / "python"), + "lintrunner", + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + if result.returncode != 0: + sys.exit( + "❌ lintrunner is required but was not found in the hook environment. " + "Please run `python scripts/setup_hooks.py` to reinstall." + ) + print("✅ lintrunner is already installed") + + +def run_lintrunner(venv_dir: Path, args: list[str]) -> int: + """Run lintrunner command in the specified venv and return exit code.""" + # Run lintrunner directly from the venv's bin directory with environment setup + lintrunner_exe = venv_dir / "bin" / "lintrunner" + cmd = [str(lintrunner_exe)] + args + env = os.environ.copy() + + # PATH: Ensures lintrunner can find other tools in the venv (like python, pip, etc.) + env["PATH"] = str(venv_dir / "bin") + os.pathsep + env.get("PATH", "") + # VIRTUAL_ENV: Tells tools like pip_init.py that we're in a venv (prevents --user flag issues) + env["VIRTUAL_ENV"] = str(venv_dir) + + # Note: Progress tends to be slightly garbled due to terminal control sequences, + # but functionality and final results will be correct + return subprocess.call(cmd, env=env) + + +def initialize_lintrunner_if_needed(venv_dir: Path) -> None: + """Check if lintrunner needs initialization and run init if needed.""" + repo_root = find_repo_root() + lintrunner_toml_path = repo_root / ".lintrunner.toml" + initialized_hash_path = venv_dir / ".lintrunner_plugins_hash" + + if not lintrunner_toml_path.exists(): + print("⚠️ No .lintrunner.toml found. Skipping init.") + return + + current_hash = compute_file_hash(lintrunner_toml_path) + stored_hash = read_stored_hash(initialized_hash_path) + + if current_hash != stored_hash: + print("🔁 Running `lintrunner init` …", file=sys.stderr) + result = run_lintrunner(venv_dir, ["init"]) + if result != 0: + sys.exit(f"❌ lintrunner init failed") + initialized_hash_path.write_text(current_hash) + else: + print("✅ Lintrunner plugins already initialized and up to date.") + + +def main() -> None: + """Run lintrunner in the isolated hook environment.""" + venv_dir = find_hook_venv() + python_exe = venv_dir / "bin" / "python" + + if not python_exe.exists(): + sys.exit(f"❌ Python executable not found at {python_exe}") + + try: + print(f"🐍 Virtual env being used: {venv_dir}", file=sys.stderr) + + # 1. Ensure lintrunner binary is available in the venv + check_lintrunner_installed(venv_dir) + + # 2. Check for plugin updates and re-init if needed + initialize_lintrunner_if_needed(venv_dir) + + # 3. Run lintrunner with any passed arguments and propagate its exit code + args = sys.argv[1:] + result = run_lintrunner(venv_dir, args) + + # If lintrunner failed and we're not already in auto-fix mode, suggest the wrapper + if result != 0 and "-a" not in args: + print( + "\n💡 To auto-fix these issues, run: python scripts/lintrunner.py -a", + file=sys.stderr, + ) + + sys.exit(result) + + except KeyboardInterrupt: + print("\n Lintrunner interrupted by user (KeyboardInterrupt)", file=sys.stderr) + sys.exit(1) # Tell git push to fail + + +if __name__ == "__main__": + main() diff --git a/scripts/run_lintrunner.py b/scripts/run_lintrunner.py deleted file mode 100644 index 60d5b545cf917..0000000000000 --- a/scripts/run_lintrunner.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -""" -Pre‑push hook wrapper for Lintrunner. - -✓ Stores a hash of .lintrunner.toml in the venv -✓ Re-runs `lintrunner init` if that file's hash changes -""" - -from __future__ import annotations - -import hashlib -import os -import shutil -import subprocess -import sys -from pathlib import Path - - -REPO_ROOT = Path(__file__).resolve().parents[1] -LINTRUNNER_TOML_PATH = REPO_ROOT / ".lintrunner.toml" - -# This is the path to the pre-commit-managed venv -VENV_ROOT = Path(sys.executable).parent.parent -# Stores the hash of .lintrunner.toml from the last time we ran `lintrunner init` -INITIALIZED_LINTRUNNER_TOML_HASH_PATH = VENV_ROOT / ".lintrunner_plugins_hash" - - -def ensure_lintrunner() -> None: - """Fail if Lintrunner is not on PATH.""" - if shutil.which("lintrunner"): - print("✅ lintrunner is already installed") - return - sys.exit( - "❌ lintrunner is required but was not found on your PATH. Please run the `python scripts/setup_hooks.py` to install to configure lintrunner before using this script. If `git push` still fails, you may need to open an new terminal" - ) - - -def ensure_virtual_environment() -> None: - """Fail if not running within a virtual environment.""" - in_venv = ( - os.environ.get("VIRTUAL_ENV") is not None - or hasattr(sys, "real_prefix") - or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix) - ) - - if not in_venv: - sys.exit( - "❌ This script must be run from within a virtual environment. " - "Please activate your virtual environment before running this script." - ) - - -def compute_file_hash(path: Path) -> str: - """Returns SHA256 hash of a file's contents.""" - hasher = hashlib.sha256() - with path.open("rb") as f: - while chunk := f.read(8192): - hasher.update(chunk) - return hasher.hexdigest() - - -def read_stored_hash(path: Path) -> str | None: - if not path.exists(): - return None - try: - return path.read_text().strip() - except Exception: - return None - - -def initialize_lintrunner_if_needed() -> None: - """Runs lintrunner init if .lintrunner.toml changed since last run.""" - if not LINTRUNNER_TOML_PATH.exists(): - print("⚠️ No .lintrunner.toml found. Skipping init.") - return - - print( - f"INITIALIZED_LINTRUNNER_TOML_HASH_PATH = {INITIALIZED_LINTRUNNER_TOML_HASH_PATH}" - ) - current_hash = compute_file_hash(LINTRUNNER_TOML_PATH) - stored_hash = read_stored_hash(INITIALIZED_LINTRUNNER_TOML_HASH_PATH) - - if current_hash == stored_hash: - print("✅ Lintrunner plugins already initialized and up to date.") - return - - print("🔁 Running `lintrunner init` …", file=sys.stderr) - subprocess.check_call(["lintrunner", "init"]) - INITIALIZED_LINTRUNNER_TOML_HASH_PATH.write_text(current_hash) - - -def main() -> None: - # 0. Ensure we're running in a virtual environment - ensure_virtual_environment() - print(f"🐍 Virtual env being used: {VENV_ROOT}", file=sys.stderr) - - # 1. Ensure lintrunner binary is available - ensure_lintrunner() - - # 2. Check for plugin updates and re-init if needed - initialize_lintrunner_if_needed() - - # 3. Run lintrunner with any passed arguments and propagate its exit code - args = sys.argv[1:] # Forward all arguments to lintrunner - result = subprocess.call(["lintrunner"] + args) - sys.exit(result) - - -if __name__ == "__main__": - main() diff --git a/scripts/setup_hooks.py b/scripts/setup_hooks.py index 41f08d45e98b6..e8effe7f82325 100644 --- a/scripts/setup_hooks.py +++ b/scripts/setup_hooks.py @@ -1,31 +1,51 @@ #!/usr/bin/env python3 """ -Bootstrap Git pre‑push hook. +Bootstrap Git pre‑push hook with isolated virtual environment. ✓ Requires uv to be installed (fails if not available) -✓ Installs/updates pre‑commit with uv (global, venv‑proof) -✓ Registers the repo's pre‑push hook and freezes hook versions +✓ Creates isolated venv in .git/hooks/linter/.venv/ for hook dependencies +✓ Installs lintrunner only in the isolated environment +✓ Creates direct git hook that bypasses pre-commit Run this from the repo root (inside or outside any project venv): python scripts/setup_hooks.py + +IMPORTANT: The generated git hook references scripts/lintrunner.py. If users checkout +branches that don't have this file, git push will fail with "No such file or directory". +Users would need to either: +1. Re-run the old setup_hooks.py from that branch, or +2. Manually delete .git/hooks/pre-push to disable hooks temporarily, or +3. Switch back to a branch with the new scripts/lintrunner.py """ from __future__ import annotations +import shlex import shutil import subprocess import sys from pathlib import Path -from typing import Tuple + + +# Add scripts directory to Python path so we can import lintrunner module +scripts_dir = Path(__file__).parent +sys.path.insert(0, str(scripts_dir)) + +# Import shared functions from lintrunner module +from lintrunner import find_repo_root, get_hook_venv_path + + +# Restore sys.path to avoid affecting other imports +sys.path.pop(0) # ─────────────────────────────────────────── # Helper utilities # ─────────────────────────────────────────── -def run(cmd: list[str]) -> None: +def run(cmd: list[str], cwd: Path = None) -> None: print(f"$ {' '.join(cmd)}") - subprocess.check_call(cmd) + subprocess.check_call(cmd, cwd=cwd) def which(cmd: str) -> bool: @@ -34,28 +54,7 @@ def which(cmd: str) -> bool: def ensure_uv() -> None: if which("uv"): - # Ensure the path uv installs binaries to is part of the system path - print("$ uv tool update-shell") - result = subprocess.run( - ["uv", "tool", "update-shell"], capture_output=True, text=True - ) - if result.returncode == 0: - # Check if the output indicates changes were made - if ( - "Updated" in result.stdout - or "Added" in result.stdout - or "Modified" in result.stdout - ): - print( - "⚠️ Shell configuration updated. You may need to restart your terminal for changes to take effect." - ) - elif result.stdout.strip(): - print(result.stdout) - return - else: - sys.exit( - f"❌ Warning: uv tool update-shell failed: {result.stderr}. uv installed tools may not be available." - ) + return sys.exit( "\n❌ uv is required but was not found on your PATH.\n" @@ -65,29 +64,6 @@ def ensure_uv() -> None: ) -def ensure_tool_installed( - tool: str, force_update: bool = False, python_ver: Tuple[int, int] = None -) -> None: - """ - Checks to see if the tool is available and if not (or if force update requested) then - it reinstalls it. - - Returns: Whether or not the tool is available on PATH. If it's not, a new terminal - needs to be opened before git pushes work as expected. - """ - if force_update or not which(tool): - print(f"Ensuring latest {tool} via uv …") - command = ["uv", "tool", "install", "--force", tool] - if python_ver: - # Add the Python version to the command if specified - command.extend(["--python", f"{python_ver[0]}.{python_ver[1]}"]) - run(command) - if not which(tool): - print( - f"\n⚠️ {tool} installation succeed, but it's not on PATH. Launch a new terminal if your git pushes don't work.\n" - ) - - if sys.platform.startswith("win"): print( "\n⚠️ Lintrunner is not supported on Windows, so there are no pre-push hooks to add. Exiting setup.\n" @@ -95,52 +71,61 @@ def ensure_tool_installed( sys.exit(0) # ─────────────────────────────────────────── -# 1. Install dependencies +# 1. Setup isolated hook environment # ─────────────────────────────────────────── ensure_uv() -# Ensure pre-commit is installed globally via uv -ensure_tool_installed("pre-commit", force_update=True, python_ver=(3, 9)) +# Find repo root and setup hook directory +repo_root = find_repo_root() +venv_dir = get_hook_venv_path() +hooks_dir = venv_dir.parent.parent # Go from .git/hooks/linter/.venv to .git/hooks + -# Don't force a lintrunner update because it might break folks -# who already have it installed in a different way -ensure_tool_installed("lintrunner") +print(f"Setting up isolated hook environment in {venv_dir}") + +# Create isolated virtual environment for hooks +if venv_dir.exists(): + print("Removing existing hook venv...") + shutil.rmtree(venv_dir) + +run(["uv", "venv", str(venv_dir), "--python", "3.9"]) + +# Install lintrunner in the isolated environment +print("Installing lintrunner in isolated environment...") +run( + ["uv", "pip", "install", "--python", str(venv_dir / "bin" / "python"), "lintrunner"] +) # ─────────────────────────────────────────── -# 2. Activate (or refresh) the pre‑push hook +# 2. Create direct git pre-push hook # ─────────────────────────────────────────── -# ── Activate (or refresh) the repo’s pre‑push hook ────────────────────────── -# Creates/overwrites .git/hooks/pre‑push with a tiny shim that will call -# `pre-commit run --hook-stage pre-push` on every `git push`. -# This is why we need to install pre-commit globally. -# -# The --allow-missing-config flag lets pre-commit succeed if someone changes to -# a branch that doesn't have pre-commit installed -run( - [ - "uv", - "tool", - "run", - "pre-commit", - "install", - "--hook-type", - "pre-push", - "--allow-missing-config", - ] +pre_push_hook = hooks_dir / "pre-push" +python_exe = venv_dir / "bin" / "python" +lintrunner_script_path_quoted = shlex.quote( + str(repo_root / "scripts" / "lintrunner.py") ) -# ── Pin remote‑hook versions for reproducibility ──────────────────────────── -# (Note: we don't have remote hooks right now, but it future-proofs this script) -# 1. `autoupdate` bumps every remote hook’s `rev:` in .pre-commit-config.yaml -# to the latest commit on its default branch. -# 2. `--freeze` immediately rewrites each `rev:` to the exact commit SHA, -# ensuring all contributors and CI run identical hook code. -run(["uv", "tool", "run", "pre-commit", "autoupdate", "--freeze"]) +hook_script = f"""#!/bin/bash +set -e + +# Check if lintrunner script exists (user might be on older commit) +if [ ! -f {lintrunner_script_path_quoted} ]; then + echo "⚠️ {lintrunner_script_path_quoted} not found - skipping linting (likely on an older commit)" + exit 0 +fi + +# Run lintrunner wrapper using the isolated venv's Python +{shlex.quote(str(python_exe))} {lintrunner_script_path_quoted} +""" +print(f"Creating git pre-push hook at {pre_push_hook}") +pre_push_hook.write_text(hook_script) +pre_push_hook.chmod(0o755) # Make executable print( - "\n✅ pre‑commit is installed globally via uv and the pre‑push hook is active.\n" + "\n✅ Isolated hook environment created and pre‑push hook is active.\n" " Lintrunner will now run automatically on every `git push`.\n" + f" Hook dependencies are isolated in {venv_dir}\n" ) From be53f609aaf6f01e2863f490975ea9eaac3ee9ff Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 12 Aug 2025 02:03:15 +0000 Subject: [PATCH 0243/1424] fix retaining multimem in symmetric memory (#160343) fixes OOM in #160289 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160343 Approved by: https://github.com/eqy --- c10/cuda/driver_api.h | 3 ++- .../c10d/symm_mem/CUDASymmetricMemory.cu | 14 ++++++++++++-- .../c10d/symm_mem/CUDASymmetricMemory.hpp | 4 +++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 9800809d1e535..6702cb9b532d4 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -53,7 +53,8 @@ #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ _(cuMulticastAddDevice, 12030) \ _(cuMulticastBindMem, 12030) \ - _(cuMulticastCreate, 12030) + _(cuMulticastCreate, 12030) \ + _(cuMulticastUnbind, 12030) #else #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) #endif diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index e9fc7aefaf57e..b2f216335bb11 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -46,11 +46,13 @@ AllocationRef::AllocationRef( void* ptr, HandleType handle, size_t block_size, - int device_idx) + int device_idx, + bool is_multicast) : ptr(ptr), handle(handle), block_size(block_size), - device_idx(device_idx) {} + device_idx(device_idx), + is_multicast(is_multicast) {} AllocationRef::~AllocationRef() { if (is_finalizing()) { @@ -63,6 +65,10 @@ AllocationRef::~AllocationRef() { auto driver_api = c10::cuda::DriverAPI::get(); C10_CUDA_DRIVER_CHECK( driver_api->cuMemUnmap_(reinterpret_cast(ptr), block_size)); + if (is_multicast) { + C10_CUDA_DRIVER_CHECK( + driver_api->cuMulticastUnbind_(handle, device_idx, 0, block_size)); + } C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle)); #elif defined(USE_ROCM) C10_HIP_CHECK(hipMemUnmap(reinterpret_cast(ptr), block_size)); @@ -797,6 +803,10 @@ c10::intrusive_ptr make_symm_mem( for (int r = 0; r < world_size; ++r) { if (r == rank) { alloc_refs.emplace_back(block->alloc_ref); + if (mc_addr != nullptr) { + alloc_refs.push_back(c10::make_intrusive( + mc_addr, mc_handle, block->block_size, block->device_idx, true)); + } continue; } alloc_refs.push_back(c10::make_intrusive( diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp index a5340ffc9806e..f61d8f9622a7b 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp @@ -15,12 +15,14 @@ struct AllocationRef : public c10::intrusive_ptr_target { HandleType handle; size_t block_size; int device_idx; + bool is_multicast; AllocationRef( void* ptr, HandleType handle, size_t block_size, - int device_idx); + int device_idx, + bool is_multicast = false); ~AllocationRef(); }; From eed9dbf70f43ee529fec78ac00ed9a4fd74c6e76 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 12 Aug 2025 02:24:17 +0000 Subject: [PATCH 0244/1424] [ROCm] Add torch/_rocm_init.py to .gitignore. (#159806) Follow-up to https://github.com/pytorch/pytorch/pull/155285. Build scripts like https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py generate this file with contents like: ```python def initialize(): import rocm_sdk rocm_sdk.initialize_process( preload_shortnames=['amd_comgr', 'amdhip64', 'hiprtc', 'hipblas', 'hipfft', 'hiprand', 'hipsparse', 'hipsolver', 'hipblaslt', 'miopen'], check_version='7.0.0rc20250804') ``` We may also have https://github.com/pytorch/pytorch/blob/main/tools/amd_build/build_amd.py do the same thing as more of that build support moves here into the upstream PyTorch repository itself (see https://github.com/pytorch/pytorch/issues/159520). This file is then loaded if present here: https://github.com/pytorch/pytorch/blob/a7f3bdf550635c796e53442375477efe98fe5447/torch/__init__.py#L145-L157 Given that the file is generated by build scripts, I think adding it to `.gitignore` makes sense, as that will prevent accidental check-ins and keep local history cleaner. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159806 Approved by: https://github.com/jeffdaily --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index b4e78e642b245..ed7208e55aa00 100644 --- a/.gitignore +++ b/.gitignore @@ -146,6 +146,9 @@ merge_record.json torchgen/packaged/* !torchgen/packaged/README.md +# This file is injected by ROCm build scripts to bootstrap in torch/__init__.py. +torch/_rocm_init.py + # IPython notebook checkpoints .ipynb_checkpoints From bfc873d02ec413344717493e4175a902921359fd Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 12 Aug 2025 02:45:46 +0000 Subject: [PATCH 0245/1424] [ROCm][Windows] Revert copying hipblaslt and rocblas dirs. (#159083) This reverts the changes from https://github.com/pytorch/pytorch/commit/b367e5f6a6c5853d0206bfd43d8b4a7cb76704f1. This will also close https://github.com/pytorch/pytorch/pull/158922. Since https://github.com/pytorch/pytorch/commit/30387ab2e485384ab2e67084a1e2c5569190ba92, ROCm is bootstrapped using the 'rocm' Python module which contains these files (see https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md), so they do not need to be bundled into torch/lib. There was also a bug in here - if `ROCM_DIR` is unset, the code crashes: ``` File "D:\projects\TheRock\external-builds\pytorch\.venv\Lib\site-packages\setuptools\_distutils\dist.py", line 1002, in run_command cmd_obj.run() File "D:\b\pytorch_main\setup.py", line 853, in run rocm_dir_path = Path(os.environ["ROCM_DIR"]) ~~~~~~~~~~^^^^^^^^^^^^ File "", line 714, in __getitem__ KeyError: 'ROCM_DIR' ``` The code could have checked for `ROCM_PATH` too. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159083 Approved by: https://github.com/jeffdaily --- setup.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/setup.py b/setup.py index ad00317da0866..cd04f5313aa43 100644 --- a/setup.py +++ b/setup.py @@ -1226,23 +1226,6 @@ def run(self) -> None: target_dir.mkdir(parents=True, exist_ok=True) self.copy_file(export_lib, target_lib) - # In ROCm on Windows case copy rocblas and hipblaslt files into - # torch/lib/rocblas/library and torch/lib/hipblaslt/library - if str2bool(os.getenv("USE_ROCM")): - rocm_dir_path = Path(os.environ["ROCM_DIR"]) - rocm_bin_path = rocm_dir_path / "bin" - rocblas_dir = rocm_bin_path / "rocblas" - target_rocblas_dir = target_dir / "rocblas" - target_rocblas_dir.mkdir(parents=True, exist_ok=True) - self.copy_tree(rocblas_dir, str(target_rocblas_dir)) - - hipblaslt_dir = rocm_bin_path / "hipblaslt" - target_hipblaslt_dir = target_dir / "hipblaslt" - target_hipblaslt_dir.mkdir(parents=True, exist_ok=True) - self.copy_tree(hipblaslt_dir, str(target_hipblaslt_dir)) - else: - report("The specified environment variable does not exist.") - def build_extensions(self) -> None: self.create_compile_commands() From 32e5e2f596d55bb9441d5d53f3c58bcb55828047 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Tue, 12 Aug 2025 04:04:49 +0000 Subject: [PATCH 0246/1424] [vllm hash update] update the pinned vllm hash (#160259) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vllm hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160259 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vllm.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt index e5260797d2150..b86f3276765d4 100644 --- a/.github/ci_commit_pins/vllm.txt +++ b/.github/ci_commit_pins/vllm.txt @@ -1 +1 @@ -35afe1b30b154114dc2ee8329e12f8cf3fe9f576 +458e74eb907f96069e6d8a4f3c9f457001fef2ea From 10bc36fe840cb3510fab84d2ea22663b76702f1e Mon Sep 17 00:00:00 2001 From: rzou Date: Mon, 11 Aug 2025 17:57:31 -0700 Subject: [PATCH 0247/1424] Get tensor subclasses and torch.library.triton_op to dispatch correctly (#160341) Short-term fix for https://github.com/pytorch/pytorch/issues/160333 The problem is: 1) `triton_op` adds a decomposition for FunctionalTensorMode for this operation 2) Tensor Subclasses rely on FunctionalTensorMode's `__torch_dispatch__` returning NotImplemented. 3) `triton_op`'s FunctionalTensorMode decomposition takes precedence over FunctionalTensorMode's decomposition. The easy fix is to copy-paste the FunctionalTensorMode's NotImplemented return logic into the decomposition. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160341 Approved by: https://github.com/drisspg --- test/inductor/test_triton_kernels.py | 34 ++++++++++++++++++++++++++++ torch/_library/triton.py | 17 ++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 6804a500fbddb..fc9f92477c79d 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -3583,6 +3583,40 @@ def f(x, y): self.assertNotIn(libname, code) self.assertNotIn(opname, code) + @requires_gpu + def test_subclass(self): + libname = "my_cool_namespace" + opname = "my_triton_operator" + + @torch.library.triton_op(f"{libname}::{opname}", mutates_args={}) + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = output.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + capture_triton(add_kernel)[grid](x, y, output, n_elements, 16) + + return output + + def f(x, y): + return add(x, y) + + x0 = torch.randn(3, device=GPU_TYPE) + y0 = torch.randn(3, device=GPU_TYPE) + x1 = torch.randn(3, device=GPU_TYPE) + y1 = torch.randn(3, device=GPU_TYPE) + + from torch.testing._internal.two_tensor import TwoTensor + + x = TwoTensor(x0, x1) + y = TwoTensor(y0, y1) + + out = torch.compile(f, fullgraph=True)(x, y) + expected = f(x, y) + self.assertEqual(out.a, expected.a) + self.assertEqual(out.b, expected.b) + @requires_gpu @dynamo_config.patch("recompile_limit", 1) def test_triton_dynamic_grid_no_recompile(self): diff --git a/torch/_library/triton.py b/torch/_library/triton.py index 72805c765d86d..17d02a9945630 100644 --- a/torch/_library/triton.py +++ b/torch/_library/triton.py @@ -155,6 +155,23 @@ def functional_decomp( # type: ignore[no-untyped-def] if custom_triton_ops_decomposition_disabled(): return mode.__torch_dispatch__(op, types, args, kwargs) else: + # TODO: https://github.com/pytorch/pytorch/issues/160333 + # We should deduplicate the unrecognized_types logic. + import torch._subclasses + + unrecognized_types = [ + t + for t in types + if not issubclass(t, torch._subclasses.FakeTensor) + and t + not in [ + torch.Tensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ] + ] + + if unrecognized_types: + return NotImplemented with mode: return fn(*args, **kwargs) From edaa151d0d5a4e75fbec9843f49cc78770eb61fb Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 11 Aug 2025 16:25:13 -0700 Subject: [PATCH 0248/1424] [CI] Move CUDA tests to trunk workflow (#160379) Which is getting run before PR is merged anyway, but according to 3X less frequently than pull workflow according to [Flambeau](https://pytorchci.grafana.net/public-dashboards/1c571e79090443eaaa9811db71f8d23b) image I.e. that will probably results in some longer time to signal, but considering that frequency of changes to eager PyTorch-on-CUDA slowed down and Inductor changes are decorated with ciflow/inductor, this looks like an acceptable tradeoff to reduce costs Pull Request resolved: https://github.com/pytorch/pytorch/pull/160379 Approved by: https://github.com/izaitsevfb --- .github/workflows/pull.yml | 36 ------------------------------------ .github/workflows/trunk.yml | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index cc2c4e89664ba..3fe8ac15a3059 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -254,42 +254,6 @@ jobs: timeout-minutes: 600 secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc11-build: - name: linux-jammy-cuda12.8-py3.10-gcc11 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc11 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - cuda-arch-list: '7.5 8.9' - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc11-test: - name: linux-jammy-cuda12.8-py3.10-gcc11 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_8-py3_10-gcc11-build - - target-determination - with: - timeout-minutes: 360 - build-environment: linux-jammy-cuda12.8-py3.10-gcc11 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index c7cf4c84e1888..19b0e88b5921a 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -63,6 +63,43 @@ jobs: ]} secrets: inherit + linux-jammy-cuda12_8-py3_10-gcc11-build: + name: linux-jammy-cuda12.8-py3.10-gcc11 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '7.5 8.9' + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-test: + name: linux-jammy-cuda12.8-py3.10-gcc11 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-build + - target-determination + with: + timeout-minutes: 360 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} + secrets: inherit + + # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated linux-jammy-cuda12_8-py3_10-gcc11-no-ops-build: name: linux-jammy-cuda12.8-py3.10-gcc11-no-ops From 5f1010fbb3850d99c8fdf9a9de2f79260cdc586a Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 12 Aug 2025 04:37:58 +0000 Subject: [PATCH 0249/1424] [Graph Partition] Pass all OSS unit tests (#154667) Graph partition leads to 6.2% speedup on vision_maskrcnn, 5.8% speedup on yolov3. [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315). Run the same diff on two days and both show speedup on average. [first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d) image [second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf) image Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667 Approved by: https://github.com/eellison --- test/inductor/test_compiled_autograd.py | 22 +- test/inductor/test_control_flow.py | 3 + test/inductor/test_cuda_repro.py | 6 +- test/inductor/test_cudagraph_trees.py | 330 +++++++++++++++++++-- test/inductor/test_inductor_annotations.py | 7 +- test/inductor/test_memory.py | 34 ++- test/inductor/test_torchinductor.py | 296 ------------------ torch/_inductor/codegen/wrapper.py | 10 +- torch/_inductor/config.py | 6 +- torch/_inductor/cudagraph_utils.py | 5 +- torch/_inductor/scheduler.py | 11 +- torch/_inductor/utils.py | 7 + 12 files changed, 408 insertions(+), 329 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 241528b159cc1..dff94b4aa0927 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -3085,7 +3085,16 @@ def backward(ctx, gO): self.assertEqual(counters["compiled_autograd"]["captures"], 1) # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + if inductor_config.graph_partition: + # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops + # and cudagraphify the remaining computation. So there is no cudagraph skip. + expected_cudagraph_skips = 0 + else: + expected_cudagraph_skips = 1 + + self.assertEqual( + counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips + ) @scoped_load_inline @requires_cuda_and_triton @@ -3150,9 +3159,18 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. # In the future, we can consider having a cpu scalar movement pass sometime after we trace # into the custom C++ autograd::Function (like in AOTDispatcher) + if inductor_config.graph_partition: + # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops + # and cudagraphify the remaining computation. So there is no cudagraph skip. + expected_cudagraph_skips = 0 + elif inductor_config.cpp_wrapper: + expected_cudagraph_skips = 2 + else: + expected_cudagraph_skips = 1 + self.assertEqual( counters["inductor"]["cudagraph_skips"], - 2 if inductor_config.cpp_wrapper else 1, + expected_cudagraph_skips, ) def test_logs(self): diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 107a65d6fa1df..511b9cea5e14d 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -472,6 +472,9 @@ def false_fn(x): @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @torch._inductor.config.patch(size_asserts=False) + # TODO: graph partition does not support creating tensor + # with dynamic shape in conditional subgraph yet + @torch._inductor.config.patch(graph_partition=False) def test_cond_unbacked_symint_inner(self, device): class Model(torch.nn.Module): def forward(self, p, a): diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 00511c572239e..53506698297f1 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -189,9 +189,9 @@ def f(q, k, v, mask): # padded bias should have an expanded dim FileCheck().check("buf0 =").check_same(", 0, ").run(code[0]) # single fused padded kernel - FileCheck().check("def call").check_count( - "empty_strided_cuda", 1, exactly=True - ).check("return").run(code[0]) + FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check( + "return" + ).run(code[0]) self.assertEqual(out, f(*inputs)) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 1408a0208cf06..763384671eb52 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -279,10 +279,14 @@ def foo(x, y): with capture_stderr() as captured_output: foo(torch.ones([10], device="cuda"), torch.ones([20])) - FileCheck().check( - "skipping cudagraphs due to cpu device (arg1_1). Found from" - ).check("y + 2").run(captured_output[0]) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + if torch._inductor.config.graph_partition: + # graph partition splits on cpu ops + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + else: + FileCheck().check( + "skipping cudagraphs due to cpu device (arg1_1). Found from" + ).check("y + 2").run(captured_output[0]) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) with capture_stderr() as captured_output: foo( @@ -292,7 +296,10 @@ def foo(x, y): FileCheck().check("skipping cudagraphs due to multiple devices").run( captured_output[0] ) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 1 if torch._inductor.config.graph_partition else 2, + ) @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True) def test_skip_symbolic(self): @@ -807,10 +814,16 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + if torch._inductor.config.graph_partition: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 0), (0, 2)], + ) + else: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) @@ -1127,8 +1140,13 @@ def foo2(x): node = self.curr_node() first_node = next(node._path_from_root) - self.assertFalse(first_node.unaliased_in_all_paths[0]) - self.assertTrue(first_node.cached_tensor_outputs[0] is None) + if torch._inductor.config.graph_partition: + # graph partition may changed the order of outputs + self.assertFalse(first_node.unaliased_in_all_paths[1]) + self.assertTrue(first_node.cached_tensor_outputs[1] is None) + else: + self.assertFalse(first_node.unaliased_in_all_paths[0]) + self.assertTrue(first_node.cached_tensor_outputs[0] is None) @torch._inductor.config.patch("implicit_fallbacks", True) def test_multinomial(self): @@ -1631,10 +1649,16 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) - self.assertEqual( - self.curr_node().expected_dead_indices_after_graph, - [(0, 1), (0, 2)], - ) + if torch._inductor.config.graph_partition: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 0), (0, 2)], + ) + else: + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) self.assertFalse(self.get_manager().new_graph_id().id == 0) def test_separate_recordings(self): @@ -2137,8 +2161,8 @@ def forward(self, x) -> torch.Tensor: with self.assertRaisesRegex( Exception, r"(?s)static input data pointer changed.\n" - r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*" - r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*," + r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*" + r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*," r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n", ): self.curr_node().run( @@ -3551,6 +3575,278 @@ def run(padded_size, original_size): self.assertEqual(self.get_manager().new_graph_id().id, 2) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_simple(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + self.assertEqual(eager_out, compiled_out) + + _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) + + if not config.cpp_wrapper: + FileCheck().check("def partition_0(args):").check( + "recursively_apply_fns = runner.recursively_apply_fns" + ).run(code[0]) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_foreach_op(self): + def fn(a0, a1): + c = torch._foreach_abs([a0, a1]) + return torch.mul(c[0], a0) + + compiled_fn = torch.compile(fn) + + a0 = torch.randn(2, 3, device="cuda") + a1 = torch.randn(2, 3, device="cuda") + eager_out = fn(a0, a1) + compiled_out = compiled_fn(a0, a1) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_condition_op(self): + def f(p, b): + def true_fn(x): + return torch.cos(x) + + def false_fn(x): + return torch.sin(x) + + return torch.cond(p, true_fn, false_fn, [b]) + + compiled_f = torch.compile(f) + + # static shape + p = torch.tensor([True], device="cuda") + a = torch.ones([2, 3], device="cuda") + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + # dynamic shape with backed symint + p = torch.tensor([True], device="cuda") + a = torch.ones([4, 5], device="cuda") + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_unbacked_symint_multi_output_layout(self): + def f(p, size_tensor): + size_val = size_tensor.item() + b = torch.ones([size_val, 3], device="cuda") + + def true_fn(x): + return torch.cos(x), torch.cos(x) + 1 + + def false_fn(x): + return torch.sin(x), torch.sin(x) + 1 + + cond_out = torch.cond(p, true_fn, false_fn, [b]) + return cond_out[0] + cond_out[1] + + compiled_f = torch.compile(f) + p = torch.tensor([True], device="cuda") + size_tensor = torch.tensor(2, device="cuda") + eager_out = f(p, size_tensor) + compiled_out = compiled_f(p, size_tensor) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + x, y = ( + torch.ones(4, 4, device="cuda"), + torch.randn(4, 4, device="cuda"), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_cat_backward(self): + def f(x, w): + y = torch.cat((x, x), dim=0) + z = y @ w + return z @ z.T + + compiled_f = torch.compile(f) + + for shape in (2, 3): + torch.manual_seed(42) + eager_x = torch.randn(shape, 2, device="cuda") + eager_w = torch.randn(2, 2, device="cuda", requires_grad=True) + torch.manual_seed(42) + compiled_x = torch.randn(shape, 2, device="cuda") + compiled_w = torch.randn(2, 2, device="cuda", requires_grad=True) + + f(eager_x, eager_w).sum().backward() + compiled_f(compiled_x, compiled_w).sum().backward() + self.assertEqual(eager_w.grad, compiled_w.grad) + + @dynamo_config.patch("capture_dynamic_output_shape_ops", True) + @config.patch(implicit_fallbacks=True) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_nested_indirect_indexing(self): + def nested(x, repeats): + rank = torch.arange(repeats.numel(), device=x.device) + index = rank.repeat_interleave(repeats, dim=0) + return torch.index_select(x, index=index, dim=0) + + example_inputs = ( + torch.randn((32, 64), device="cuda"), + repeats := torch.tensor([5, 10, 15], device="cuda"), + ) + torch._dynamo.mark_dynamic(repeats, 0) # create backed symint + + nested_opt = torch.compile(nested, backend="inductor") + + expect = nested(*example_inputs) + actual = nested_opt(*example_inputs) + self.assertEqual(expect, actual) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_mutation_index(self): + x = torch.zeros(7, device="cuda") + + def fn(n, a): + a[n] = -1 + return a + + opt_fn = torch.compile(fn, fullgraph=True) + + for n in range(2, x.shape[0]): + opt_fn(n, x) + self.assertEqual(x[n], -1) + + # Negative index triggers new compilation. + opt_fn(-x.shape[0], x) + + self.assertEqual(x[0], -1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_unbacked_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y) + eager_out = f(x, y) + self.assertEqual(compiled_out, eager_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_dynamic_scalar_inputs(self): + def f(x, y, integer): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + z += integer + return x1 + y1 + z + y_cpu.to("cuda") + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device="cuda"), + torch.randn(3, 3, device="cuda"), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y, 5) + self.assertEqual(compiled_out, f(x, y, 5)) + + compiled_out = f_compiled(x, y, 6) + self.assertEqual(compiled_out, f(x, y, 6)) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_item(self): + def f(x): + y = x + 1 + scalar = y.item() + return x + y + scalar + + compiled_f = torch.compile(f) + compiled_out = compiled_f(torch.tensor(1, device="cuda")) + self.assertEqual(compiled_out, f(torch.tensor(1, device="cuda"))) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_buffer_reuse(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x1 + y1 + x @ y + u = (y_cpu.to("cuda") + 2) @ y + 3 + u_cpu = u.cpu() + 2 + return z + u_cpu.to("cuda") + + x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_fused_scheduler_node(self): + def foo(x): + x = x * 20 + x_alias = x[0] + y = x * 10 + y_alias = y[0] + torch._dynamo.graph_break() + ind = torch.tensor(4, device="cuda") + x_alias2 = x[ind:] + y_alias2 = y[ind:] + return x, x_alias, x_alias2, y_alias, y_alias2 + + compiled_foo = torch.compile(foo) + x = torch.rand([20, 20], device="cuda") + + eager_out = foo(x) + compiled_out = compiled_foo(x) + self.assertEqual(eager_out, compiled_out) + def test_meta_tensor(self): def foobar(x, y): return x * 2, y * 3 diff --git a/test/inductor/test_inductor_annotations.py b/test/inductor/test_inductor_annotations.py index bee7e0ad917da..3824b25cdeaea 100644 --- a/test/inductor/test_inductor_annotations.py +++ b/test/inductor/test_inductor_annotations.py @@ -31,10 +31,11 @@ def test_training_annotation(self): code = self.get_code() self.assertTrue("from torch.cuda import nvtx" in code) - self.assertEqual( - code.count("training_annotation = nvtx._device_range_start('inference')"), 1 + self.assertTrue( + code.count("training_annotation = nvtx._device_range_start('inference')") + >= 1 ) - self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1) + self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1) if __name__ == "__main__": diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 81f7ea03d3bb4..80372bca9fdca 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -68,9 +68,16 @@ def test_reorder_peak_memory(self): outp_corr = self.model(self.inputs) compiled_model = torch.compile(self.model) code = run_and_get_triton_code(compiled_model, self.inputs) + + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf1 = ") .check("buf0 = ") .check("buf2 = ") @@ -105,6 +112,12 @@ def reorder_with_only_lpmf( methods=[memory.topological_sort_lpmf], ) + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_lpmf ): @@ -113,7 +126,7 @@ def reorder_with_only_lpmf( code = run_and_get_triton_code(compiled_model, self.inputs) ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf1 = ") .check("buf0 = ") .check("buf2 = ") @@ -148,15 +161,22 @@ def reorder_with_only_bfs( methods=[memory.topological_sort_bfs], ) + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_bfs ): compiled_model = torch.compile(self.model) code = run_and_get_triton_code(compiled_model, self.inputs) + ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf0 = ") .check("buf1 = ") .check("buf2 = ") @@ -191,6 +211,12 @@ def reorder_with_only_dfs( methods=[memory.topological_sort_dfs], ) + call_str = ( + "def call(self, args):" + if torch._inductor.config.graph_partition + else "def call(args):" + ) + with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_dfs ): @@ -199,7 +225,7 @@ def reorder_with_only_dfs( code = run_and_get_triton_code(compiled_model, self.inputs) ( FileCheck() - .check("def call(args):") + .check(call_str) .check("buf0 = ") .check("buf2 = ") .check("buf4 = ") diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index cdcedd5a1771e..385a75d98f944 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -15044,302 +15044,6 @@ def fn(x): "'XBLOCK': 'constexpr'" ).run(code[0]) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - self.assertEqual(eager_out, compiled_out) - - _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) - - if not config.cpp_wrapper: - FileCheck().check("def partition_0(args):").check( - "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)" - ).check("recursively_apply_fns = runner.recursively_apply_fns").run( - code[0] - ) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_foreach_op(self): - def fn(a0, a1): - c = torch._foreach_abs([a0, a1]) - return torch.mul(c[0], a0) - - compiled_fn = torch.compile(fn) - - a0 = torch.randn(2, 3, device=self.device) - a1 = torch.randn(2, 3, device=self.device) - eager_out = fn(a0, a1) - compiled_out = compiled_fn(a0, a1) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_multiple_functions(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - def g(x): - return x + 1 - - x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = g(f(x, y)) - - f_compiled = torch.compile(f) - g_compiled = torch.compile(g) - compiled_out = g_compiled(f_compiled(x_cloned, y_cloned)) - - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_condition_op(self): - def f(p, b): - def true_fn(x): - return torch.cos(x) - - def false_fn(x): - return torch.sin(x) - - return torch.cond(p, true_fn, false_fn, [b]) - - compiled_f = torch.compile(f) - - # static shape - p = torch.tensor([True], device=self.device) - a = torch.ones([2, 3], device=self.device) - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - # dynamic shape with backed symint - p = torch.tensor([True], device=self.device) - a = torch.ones([4, 5], device=self.device) - eager_out = f(p, a) - compiled_out = compiled_f(p, a) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_unbacked_symint_multi_output_layout(self): - def f(p, size_tensor): - size_val = size_tensor.item() - b = torch.ones([size_val, 3], device=GPU_TYPE) - - def true_fn(x): - return torch.cos(x), torch.cos(x) + 1 - - def false_fn(x): - return torch.sin(x), torch.sin(x) + 1 - - cond_out = torch.cond(p, true_fn, false_fn, [b]) - return cond_out[0] + cond_out[1] - - compiled_f = torch.compile(f) - p = torch.tensor([True], device=GPU_TYPE) - size_tensor = torch.tensor(2, device=GPU_TYPE) - eager_out = f(p, size_tensor) - compiled_out = compiled_f(p, size_tensor) - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - x, y = ( - torch.ones(4, 4, device=self.device), - torch.randn(4, 4, device=self.device), - ) - compiled_out = f_compiled(x, y) - self.assertEqual(compiled_out, f(x, y)) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_cat_backward(self): - def f(x, w): - y = torch.cat((x, x), dim=0) - z = y @ w - return z @ z.T - - compiled_f = torch.compile(f) - - for shape in (2, 3): - torch.manual_seed(42) - eager_x = torch.randn(shape, 2, device=self.device) - eager_w = torch.randn(2, 2, device=self.device, requires_grad=True) - torch.manual_seed(42) - compiled_x = torch.randn(shape, 2, device=self.device) - compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True) - - f(eager_x, eager_w).sum().backward() - compiled_f(compiled_x, compiled_w).sum().backward() - self.assertEqual(eager_w.grad, compiled_w.grad) - - @dynamo_config.patch("capture_dynamic_output_shape_ops", True) - @config.patch(implicit_fallbacks=True) - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_nested_indirect_indexing(self): - def nested(x, repeats): - rank = torch.arange(repeats.numel(), device=x.device) - index = rank.repeat_interleave(repeats, dim=0) - return torch.index_select(x, index=index, dim=0) - - example_inputs = ( - torch.randn((32, 64), device=self.device), - repeats := torch.tensor([5, 10, 15], device=self.device), - ) - torch._dynamo.mark_dynamic(repeats, 0) # create backed symint - - nested_opt = torch.compile(nested, backend="inductor") - - expect = nested(*example_inputs) - actual = nested_opt(*example_inputs) - self.assertEqual(expect, actual) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_symint_from_mutation_index(self): - x = torch.zeros(7, device=GPU_TYPE) - - def fn(n, a): - a[n] = -1 - return a - - opt_fn = torch.compile(fn, fullgraph=True) - - for n in range(2, x.shape[0]): - opt_fn(n, x) - self.assertEqual(x[n], -1) - - # Negative index triggers new compilation. - opt_fn(-x.shape[0], x) - - self.assertEqual(x[0], -1) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_unbacked_symint(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y) - eager_out = f(x, y) - self.assertEqual(compiled_out, eager_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_dynamic_scalar_inputs(self): - def f(x, y, integer): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x @ y - z += integer - return x1 + y1 + z + y_cpu.to(GPU_TYPE) - - f_compiled = torch.compile(f) - x, y = ( - torch.ones(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) - - torch._dynamo.decorators.mark_unbacked(x, 0) - torch._dynamo.decorators.mark_unbacked(y, 1) - - compiled_out = f_compiled(x, y, 5) - self.assertEqual(compiled_out, f(x, y, 5)) - - compiled_out = f_compiled(x, y, 6) - self.assertEqual(compiled_out, f(x, y, 6)) - - @torch._inductor.config.patch("graph_partition", True) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_graph_partition_item(self): - def f(x): - y = x + 1 - scalar = y.item() - return x + y + scalar - - compiled_f = torch.compile(f) - compiled_out = f(torch.tensor(1, device=GPU_TYPE)) - self.assertEqual(compiled_out, f(torch.tensor(1, device=GPU_TYPE))) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_buffer_reuse(self): - def f(x, y): - x1 = x + 1 - y1 = y + 1 - y_cpu = y1.cpu() + 1 - z = x1 + y1 + x @ y - u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3 - u_cpu = u.cpu() + 2 - return z + u_cpu.to(GPU_TYPE) - - x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)] - x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] - eager_out = f(x, y) - - f_compiled = torch.compile(f) - compiled_out = f_compiled(x_cloned, y_cloned) - - self.assertEqual(eager_out, compiled_out) - - @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_fused_scheduler_node(self): - def foo(x): - x = x * 20 - x_alias = x[0] - y = x * 10 - y_alias = y[0] - torch._dynamo.graph_break() - ind = torch.tensor(4, device=GPU_TYPE) - x_alias2 = x[ind:] - y_alias2 = y[ind:] - return x, x_alias, x_alias2, y_alias, y_alias2 - - foo = torch.compile(foo) - x = torch.rand([20, 20], device=GPU_TYPE) - _, code = run_and_get_code(foo, x) - - if not config.cpp_wrapper: - FileCheck().check("def partition_0(args):").run(code[0]) - @unittest.skipIf(TEST_WITH_ROCM or not IS_SM90, "no scaled_grouped_mm support") def test_respect_scaled_grouped_mm_layout_tag(self): # scaled_grouped_mm needs `mat2` to be column-major diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 8ac01ae791f72..9394c0e4a16d6 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -50,6 +50,7 @@ get_benchmark_name, IndentedBuffer, is_codegen_graph_partition_subgraph, + is_using_cudagraph_partition, LineContext, sympy_product, sympy_str, @@ -1197,7 +1198,14 @@ def write_prefix(self) -> None: self.write_args(graph_input_names) self.codegen_inputs() - self.codegen_input_size_and_nan_asserts() + + # avoid duplicating asserts for both partition functions and + # the call function when using cudagraph partition + if not ( + is_using_cudagraph_partition() + and (not is_codegen_graph_partition_subgraph(self)) + ): + self.codegen_input_size_and_nan_asserts() def codegen_input_size_and_nan_asserts(self) -> None: if config.size_asserts: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 8d3b4cd7ed492..770da725a9aad 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -437,7 +437,11 @@ def prologue_fusion_enabled() -> bool: ) # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph -graph_partition = False +graph_partition: bool = ( + os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0") + == "1" +) + # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 2686d1d2ddde2..7826c797d36be 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -10,6 +10,8 @@ from torch._inductor.utils import GraphPartitionMap, InputType from torch.utils._ordered_set import OrderedSet +from .utils import is_using_cudagraph_partition + if TYPE_CHECKING: from collections.abc import Sequence @@ -170,7 +172,8 @@ def check_multiple_devices_or_any_cpu_nodes( # meta tensors are supported since there is no compute device_node_mapping.pop(torch.device("meta"), None) - if torch._inductor.config.graph_partition: + # dynamo cudagraph does not support graph partition + if is_using_cudagraph_partition(): # graph partition supports splitting on cpu op. So we can ignore cpu nodes. device_node_mapping.pop(torch.device("cpu"), None) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index e0a0309d1c811..d8a96c573b320 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2179,7 +2179,10 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) self.process_grouped_nodes() - if torch._inductor.config.graph_partition: + if ( + torch._inductor.config.graph_partition + and torch._inductor.config.triton.cudagraphs + ): self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) @@ -4312,6 +4315,12 @@ def should_partition( ) -> bool: """Return True if we should partition the inductor graph on this node""" + # When not using cudagraphs, keep all kernels in the `call` function + # instead of graph partition functions, since graph partition only brings + # benefit to cudagraph + if not torch._inductor.config.triton.cudagraphs: + return True + # avoid duplicating logs when should_partition is called multiple times # on the same node def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f21905e16e9d7..0418edb2a1154 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3329,6 +3329,13 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool: ) +def is_using_cudagraph_partition() -> bool: + return ( + torch._inductor.config.triton.cudagraphs + and torch._inductor.config.graph_partition + ) + + def dtype_from_size(size: int) -> torch.dtype: from .virtualized import V From 0f3b10b8eebe68e3c75d473d499b87dfe14a2eca Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Tue, 12 Aug 2025 04:37:58 +0000 Subject: [PATCH 0250/1424] [audio hash update] update the pinned audio hash (#160384) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160384 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 83860798279ad..9f7623cf35caf 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -e500f0cf88bc57ffd8b0029033da305eef24ae25 +bdb88e1d66f272cad72156c90ac8428ca61a601c From 8d3d1c844303cb1d46123a1caa76d4cf83973347 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 11 Aug 2025 17:27:19 -0700 Subject: [PATCH 0251/1424] [dynamo] fixes to propagate tag safeness (#159807) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159807 Approved by: https://github.com/jansel --- test/dynamo/test_functions.py | 1 + test/dynamo/test_guard_manager.py | 39 ++++----- torch/_C/_dynamo/guards.pyi | 6 ++ torch/_dynamo/config.py | 19 ++++ torch/_dynamo/guards.py | 110 ++++++++++++++++++++++-- torch/_dynamo/variables/functions.py | 12 +++ torch/_dynamo/variables/user_defined.py | 8 +- 7 files changed, 161 insertions(+), 34 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 6e28264d54669..31505b9445d40 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4136,6 +4136,7 @@ def func(): self.assertEqual(cnts.frame_count, 3) self.assertEqual(cnts.op_count, 6) + @torch._dynamo.config.patch(assume_dunder_attributes_remain_unchanged=False) def test_meth_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 8a66c847b52a1..27401f36e02f6 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -1,5 +1,7 @@ # Owner(s): ["module: dynamo"] +import abc import functools +import inspect import unittest import weakref @@ -1150,21 +1152,32 @@ def hook(guard_wrapper, f_locals, builder): def test_nn_module_tag_safe(self): class Foo(torch.nn.Module): + c = 2 + def __init__(self): super().__init__() self.a = 4 + def check(self, x): + return True + def forward(self, x): - return x + self.a + inspect.signature(self.check).parameters.items() + return x + self.a + self.c foo = Foo() - class Baz(torch.nn.Module): + class Env(metaclass=abc.ABCMeta): # noqa: B024 + pass + + class Baz(torch.nn.Module, Env): def __init__(self): super().__init__() self.foo = foo def forward(self, x): + if "Foo" in str(type(self).__mro__): + x = torch.sin(x) return self.foo(x) baz = Baz() @@ -1179,7 +1192,6 @@ def fn(x): from utils import install_guard_manager_testing_hook def hook(guard_wrapper, f_locals, builder): - from torch._C._dynamo.guards import GetGenericDictGuardAccessor from torch._dynamo.source import LocalSource baz_source = LocalSource("baz") @@ -1189,27 +1201,6 @@ def hook(guard_wrapper, f_locals, builder): self.assertTrue(baz_mgr.is_tag_safe()) self.assertTrue(baz_mgr.is_tag_safe_root()) - # Check tagness of baz.__dict__ - self.assertTrue(len(baz_mgr.get_accessors()) == 1) - dunder_dict_accessor = baz_mgr.get_accessors()[0] - self.assertTrue( - isinstance(dunder_dict_accessor, GetGenericDictGuardAccessor) - ) - - dunder_dict_mgr = baz_mgr.get_child_managers()[0] - self.assertTrue(dunder_dict_mgr.is_tag_safe()) - self.assertFalse(dunder_dict_mgr.is_tag_safe_root()) - - # Check tagness of baz.__dict__["_modules"] - modules_mgr = dunder_dict_mgr.get_child_managers()[0] - self.assertTrue(modules_mgr.is_tag_safe()) - self.assertFalse(modules_mgr.is_tag_safe_root()) - - # Check tagness of baz.__dict__["_modules"]["foo"] - modules_foo_mgr = modules_mgr.get_child_managers()[0] - self.assertTrue(modules_foo_mgr.is_tag_safe()) - self.assertFalse(modules_foo_mgr.is_tag_safe_root()) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) with install_guard_manager_testing_hook(hook): opt_fn(torch.randn(4, 4)) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 5e0a014e8f784..64800504f4795 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -354,6 +354,12 @@ class DictGetItemGuardAccessor(GuardAccessor): ... class GetGenericDictGuardAccessor(GuardAccessor): ... class TypeDictGuardAccessor(GuardAccessor): ... class TypeMROGuardAccessor(GuardAccessor): ... +class ClosureGuardAccessor(GuardAccessor): ... +class TupleGetItemGuardAccessor(GuardAccessor): ... +class TypeGuardAccessor(GuardAccessor): ... +class CodeGuardAccessor(GuardAccessor): ... +class FuncDefaultsGuardAccessor(GuardAccessor): ... +class FuncKwDefaultsGuardAccessor(GuardAccessor): ... class GetAttrGuardAccessor(GuardAccessor): def get_attr_name(self) -> str: ... diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 0d83b7078eae9..b8b7561dde16b 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -354,6 +354,25 @@ # Skips guards on func.__defaults__ if the element to be guarded is a constant skip_guards_on_constant_func_defaults = True + +# The recursive-dict-tag guard relies on the class/function identity staying +# stable. We therefore assume that the following function dunder attributes +# are **never rebound** to a different object: +# +# • __code__ • __closure__ +# • __defaults__ • __kwdefaults__ +# • __annotations__ • __mro__ +# +# It is fine to mutate the objects they already point to (e.g. tweak an element +# inside __defaults__), but assignments like +# +# foo.__defaults__ = (3, 4) # REBIND - NOT SUPPORTED +# +# would invalidate the optimization. This type of rebinding is rare, so we +# assume that the rebinding never happens for guard purposes. Set the flag +# below to False only in environments where such rebinding is known to occur. +assume_dunder_attributes_remain_unchanged = True + # Speedup guard execution of nested nn modules by recursively checking for dict # tags to avoid full guard execution. use_recursive_dict_tags_for_guards = True diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index a32b8d686dac7..445224319b970 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -48,10 +48,16 @@ from torch._C._dynamo.guards import ( check_obj_id, check_type_id, + ClosureGuardAccessor, + CodeGuardAccessor, dict_version, DictGetItemGuardAccessor, DictGuardManager, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, + GetAttrGuardAccessor, GetGenericDictGuardAccessor, + GuardAccessor, GuardDebugInfo, GuardManager, install_no_tensor_aliasing_guard, @@ -62,6 +68,10 @@ profile_guard_manager, RelationalGuard, RootGuardManager, + TupleGetItemGuardAccessor, + TypeDictGuardAccessor, + TypeGuardAccessor, + TypeMROGuardAccessor, ) from torch._dynamo.source import ( get_global_source_name, @@ -204,6 +214,17 @@ verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") +dunder_attrs_assumed_constants = ( + "__defaults__", + "__kwdefaults__", + "__code__", + "__closure__", + "__annotations__", + "__func__", + "__mro__", +) + + class IndentedBufferWithPrefix(IndentedBuffer): def prefix(self) -> str: return "| " * (self._indent * self.tabwidth) @@ -372,6 +393,16 @@ def find_tag_safe_roots(self) -> None: subset that are tag safe roots. """ + def check_tag_safety( + node: GuardManager, accepted_accessors: tuple[type[GuardAccessor], ...] + ) -> bool: + accessors = node.get_accessors() + child_mgrs = node.get_child_managers() + return all( + isinstance(accessor, accepted_accessors) and mgr.is_tag_safe() + for accessor, mgr in zip(accessors, child_mgrs) + ) + def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]: # Just recurse through the key and value dict managers and check if # all of them are tag safe nodes. @@ -429,12 +460,8 @@ def visit_manager(node: GuardManager) -> list[GuardManager]: if is_subtree_tag_safe: node.mark_tag_safe() elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module): - accessors = node.get_accessors() - child_mgrs = node.get_child_managers() - is_subtree_tag_safe = all( - isinstance(accessor, GetGenericDictGuardAccessor) - and mgr.is_tag_safe() - for accessor, mgr in zip(accessors, child_mgrs) + is_subtree_tag_safe = check_tag_safety( + node, (GetGenericDictGuardAccessor, TypeGuardAccessor) ) if is_subtree_tag_safe: node.mark_tag_safe() @@ -443,6 +470,77 @@ def visit_manager(node: GuardManager) -> list[GuardManager]: return [ node, ] + elif ( + node.get_type_of_guarded_value() + in ( + types.FunctionType, + types.MethodType, + staticmethod, + classmethod, + ) + and config.assume_dunder_attributes_remain_unchanged + ): + # Assumption: callers will not reassignthe attributes + # func.__code__, func.__closure__, func.__defaults__, or func.__kwdefaults__. + # Mutating the objects those attributes point to is fine; + # rebinding the attribute itself is not. + # Example ─ allowed: foo.__defaults__[0].bar = 99 + # forbidden: foo.__defaults__ = (3, 4) + is_subtree_tag_safe = check_tag_safety( + node, + ( + CodeGuardAccessor, + ClosureGuardAccessor, + FuncDefaultsGuardAccessor, + FuncKwDefaultsGuardAccessor, + GetAttrGuardAccessor, + ), + ) + + for accessor in node.get_accessors(): + if isinstance(accessor, GetAttrGuardAccessor): + is_subtree_tag_safe &= ( + accessor.get_attr_name() in dunder_attrs_assumed_constants + ) + + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), types.CellType): + is_subtree_tag_safe = check_tag_safety(node, (GetAttrGuardAccessor,)) + + is_subtree_tag_safe &= all( + isinstance(accessor, GetAttrGuardAccessor) + and accessor.get_attr_name() == "cell_contents" + for accessor in node.get_accessors() + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif ( + issubclass(node.get_type_of_guarded_value(), tuple) + and node.get_source().endswith(dunder_attrs_assumed_constants) + and config.assume_dunder_attributes_remain_unchanged + ): + # We trust tuples obtained from a function’s __closure__ or + # __defaults__. Any *other* tuple-valued attribute can be + # silently replaced—for example: + # + # foo.bar = (1, 2) # original + # foo.bar = (3, 4) # rebinding that our dict-tag optimisation won’t see + # + # Therefore only tuples from __closure__ / __defaults__ participate in the + # recursive-dict-tag optimization; all others are ignored. + is_subtree_tag_safe = check_tag_safety( + node, (TupleGetItemGuardAccessor,) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + elif issubclass(node.get_type_of_guarded_value(), type): + is_subtree_tag_safe = check_tag_safety( + node, (TypeDictGuardAccessor, TypeMROGuardAccessor) + ) + if is_subtree_tag_safe: + node.mark_tag_safe() + return tag_safe_roots def visit(node: GuardManager) -> list[GuardManager]: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 050f39f55895c..4bdcecf3b3c2c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1066,6 +1066,18 @@ def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: super().__init__(fn=fn, **kwargs) self.obj = obj self.source_fn = source_fn + # Note on source and source_fn + # Be careful with `source` when delegating to UserFunctionVariable + # (base-class) methods. In this __init__, `source` is a *bound method* + # object, but the base class expects the underlying *function* object. + # One way is to simplly use `__func__` to unwrap it. + # + # For recursive dict-tag optimizations, it can be faster to fetch the + # function directly from `cls.__dict__`; that’s why we pass on + # `source_fn`. Whenever it is possible to access the function from + # cls.__dict__, we pass that on to `source_fn`. Because bind_args + # operates on the unbound function, most guards should target + # `source_fn` rather than the original `source`. if source_fn is None and kwargs.get("source") is not None: self.source_fn = AttrSource(kwargs.get("source"), "__func__") diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 95b1a37b677fc..084a1e2149d04 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -253,6 +253,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke elif name == "__dict__": options = {"source": source} return variables.GetAttrVariable(self, name, **options) + elif name == "__mro__": + attr_source = self.source and TypeMROSource(self.source) + return VariableTracker.build(tx, self.value.__mro__, attr_source) # Special handling of collections.OrderedDict.fromkeys() # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with @@ -295,10 +298,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke func = obj.__get__(None, self.value) return VariableTracker.build(tx, func, source) elif source: - # __mro__ is a member in < 3.12, an attribute in >= 3.12 - if inspect.ismemberdescriptor(obj) or ( - sys.version_info >= (3, 12) and name == "__mro__" - ): + if inspect.ismemberdescriptor(obj): return VariableTracker.build(tx, obj.__get__(self.value), source) if ConstantVariable.is_literal(obj): From 01bcf9a40dea937637d2cdd530bed2652510943d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 12 Aug 2025 05:14:17 +0000 Subject: [PATCH 0252/1424] Bump transformers pin (#159291) Trying to update hf pin. Benchmarking run to figure out issues image Retrying - https://github.com/pytorch/pytorch/pull/156118 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159291 Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn Co-authored-by: Huy Do --- .ci/docker/ci_commit_pins/huggingface.txt | 2 +- .../common/install_inductor_benchmark_deps.sh | 6 +++--- .ci/pytorch/macos-test.sh | 3 +++ .ci/pytorch/test.sh | 1 - benchmarks/dynamo/check_accuracy.py | 1 + .../aot_eager_huggingface_inference.csv | 14 +------------- .../aot_eager_huggingface_training.csv | 16 ++-------------- .../aot_eager_torchbench_inference.csv | 4 ++-- .../aot_eager_torchbench_training.csv | 4 ++-- .../aot_inductor_huggingface_inference.csv | 14 +------------- ...t_inductor_freezing_huggingface_inference.csv | 14 +------------- ...ductor_amp_freezing_huggingface_inference.csv | 14 +------------- ...nductor_amp_freezing_torchbench_inference.csv | 4 ++-- ...u_inductor_freezing_huggingface_inference.csv | 14 +------------- ...pu_inductor_freezing_torchbench_inference.csv | 4 ++-- .../cpu_inductor_huggingface_inference.csv | 14 +------------- .../cpu_inductor_torchbench_inference.csv | 4 ++-- .../dynamic_aot_eager_huggingface_inference.csv | 14 +------------- .../dynamic_aot_eager_huggingface_training.csv | 16 ++-------------- .../dynamic_aot_eager_torchbench_inference.csv | 4 ++-- .../dynamic_aot_eager_torchbench_training.csv | 2 +- ...ynamic_cpu_inductor_huggingface_inference.csv | 14 +------------- ...dynamic_cpu_inductor_torchbench_inference.csv | 4 ++-- ...ductor_amp_freezing_huggingface_inference.csv | 14 +------------- ...nductor_amp_freezing_torchbench_inference.csv | 4 ++-- .../dynamic_inductor_huggingface_inference.csv | 14 +------------- .../dynamic_inductor_huggingface_training.csv | 16 ++-------------- .../dynamic_inductor_torchbench_inference.csv | 4 ++-- .../dynamic_inductor_torchbench_training.csv | 2 +- .../dynamo_eager_huggingface_inference.csv | 14 +------------- .../dynamo_eager_huggingface_training.csv | 16 ++-------------- .../dynamo_eager_torchbench_inference.csv | 4 ++-- .../dynamo_eager_torchbench_training.csv | 4 ++-- .../inductor_huggingface_inference.csv | 14 +------------- .../inductor_huggingface_training.csv | 16 ++-------------- .../inductor_torchbench_inference.csv | 4 ++-- .../inductor_torchbench_training.csv | 4 ++-- .../rocm/aot_eager_huggingface_inference.csv | 14 +------------- .../rocm/aot_eager_huggingface_training.csv | 16 ++-------------- .../rocm/aot_eager_torchbench_inference.csv | 4 ++-- .../rocm/aot_eager_torchbench_training.csv | 4 ++-- .../rocm/aot_inductor_huggingface_inference.csv | 14 +------------- .../dynamic_aot_eager_huggingface_inference.csv | 14 +------------- .../dynamic_aot_eager_huggingface_training.csv | 16 ++-------------- .../dynamic_aot_eager_torchbench_inference.csv | 4 ++-- .../dynamic_aot_eager_torchbench_training.csv | 4 ++-- .../dynamic_inductor_huggingface_inference.csv | 14 +------------- .../dynamic_inductor_huggingface_training.csv | 16 ++-------------- .../dynamic_inductor_torchbench_inference.csv | 4 ++-- .../dynamic_inductor_torchbench_training.csv | 4 ++-- .../rocm/dynamo_eager_huggingface_inference.csv | 14 +------------- .../rocm/dynamo_eager_huggingface_training.csv | 16 ++-------------- .../rocm/dynamo_eager_torchbench_inference.csv | 4 ++-- .../rocm/dynamo_eager_torchbench_training.csv | 4 ++-- .../rocm/inductor_huggingface_inference.csv | 14 +------------- .../rocm/inductor_huggingface_training.csv | 16 ++-------------- .../rocm/inductor_torchbench_inference.csv | 4 ++-- .../rocm/inductor_torchbench_training.csv | 4 ++-- benchmarks/dynamo/common.py | 1 - benchmarks/dynamo/huggingface.py | 6 ++++++ benchmarks/dynamo/huggingface.yaml | 3 --- benchmarks/dynamo/huggingface_models_list.txt | 3 --- .../dynamo/huggingface_models_list_cpu.txt | 3 --- benchmarks/dynamo/torchbench.py | 16 ++++++++++++++++ 64 files changed, 116 insertions(+), 437 deletions(-) diff --git a/.ci/docker/ci_commit_pins/huggingface.txt b/.ci/docker/ci_commit_pins/huggingface.txt index f00d6ca4f9ca7..4fc4729a25da1 100644 --- a/.ci/docker/ci_commit_pins/huggingface.txt +++ b/.ci/docker/ci_commit_pins/huggingface.txt @@ -1 +1 @@ -243e186efbf7fb93328dd6b34927a4e8c8f24395 +v4.54.0 diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index c2601adb67e32..21fced2e851d8 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -26,15 +26,15 @@ function install_torchbench() { python install.py --continue_on_fail - # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 - # is regressing speedup metric. This needs to be investigated further - pip install transformers==4.38.1 + # soxr comes from https://github.com/huggingface/transformers/pull/39429 + pip install transformers==4.54.0 soxr==0.5.0 echo "Print all dependencies after TorchBench is installed" python -mpip freeze popd chown -R jenkins torchbench + chown -R jenkins /opt/conda } # Pango is needed for weasyprint which is needed for doctr diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index c38448898cb4b..c9d926a5df37c 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -175,6 +175,9 @@ checkout_install_torchbench() { python install.py --continue_on_fail fi + # soxr comes from https://github.com/huggingface/transformers/pull/39429 + pip install transformers==4.54.0 soxr==0.5.0 + echo "Print all dependencies after TorchBench is installed" python -mpip freeze popd diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 473a125475c4e..daa258d283fa3 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1682,7 +1682,6 @@ elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchaudio install_torchvision - install_torchao id=$((SHARD_NUMBER-1)) # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 diff --git a/benchmarks/dynamo/check_accuracy.py b/benchmarks/dynamo/check_accuracy.py index 7fa24ae7346b1..5cd714fe02e93 100644 --- a/benchmarks/dynamo/check_accuracy.py +++ b/benchmarks/dynamo/check_accuracy.py @@ -14,6 +14,7 @@ "detectron2_maskrcnn_r_101_c4", "timm_efficientnet", # see https://github.com/pytorch/pytorch/issues/148699 "XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148 + "moondream", # discovered in https://github.com/pytorch/pytorch/pull/159291 } diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv index 66e088f334071..f65909f3a24ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index af605accecf6e..01762c5f5f290 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 33ede2b914b4f..54b7d63f3a4bc 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv index 1cafcbe55675d..ce334e22c698b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv @@ -42,14 +42,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -66,7 +58,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -154,10 +146,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv index 1cafcbe55675d..ce334e22c698b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv @@ -42,14 +42,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -66,7 +58,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -154,10 +146,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index faafea393ede5..9620a79f91a97 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index a2b7c1a7b15ca..aec659fdcd654 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index 697fe04cd91a5..4f2eec1493520 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv index 66e088f334071..f65909f3a24ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 7f11e13980273..f9874a7a4b900 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index cb8cead2ba034..81ed3080dd3e8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index 6f9e9e0ed5a7b..c8db4d5823203 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -122,7 +122,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -142,7 +142,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv index 4f7ca2b638c48..f4c9ffddd9974 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,24 +hf_BigBird,pass,25 @@ -158,7 +158,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv index 66e088f334071..f65909f3a24ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 7f11e13980273..f9874a7a4b900 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 05eb7e3546eef..188f3dd00cac3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv index 66e088f334071..f65909f3a24ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index af605accecf6e..01762c5f5f290 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 44983e8ecc214..0985e42fc5cb9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv index 66e088f334071..f65909f3a24ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 9a9a68629f875..fbd169539ab77 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 33ede2b914b4f..54b7d63f3a4bc 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv index 9fdb41506e3b2..08061de428d71 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv index b3a3265baa16f..6f316b219bb92 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv @@ -166,7 +166,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -181,7 +181,7 @@ hf_T5_base,pass,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv index d2300bdac05b8..48d0b111788f7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -114,7 +114,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv index 1cafcbe55675d..ce334e22c698b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv @@ -42,14 +42,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -66,7 +58,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -154,10 +146,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv index 9fdb41506e3b2..08061de428d71 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv index 624f295624783..4b5138ce9c367 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv @@ -166,7 +166,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -181,7 +181,7 @@ hf_T5_base,pass,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv index 1605a26b7ce5f..643a02fdca8fd 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -114,7 +114,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv index 66e088f334071..f65909f3a24ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv index 6776cc5f5d7a7..a3fc7cf192371 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -174,7 +174,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv index b43e38b7d822a..ced88884720b7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv index 9fdb41506e3b2..08061de428d71 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv index b3a3265baa16f..6f316b219bb92 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv @@ -166,7 +166,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -181,7 +181,7 @@ hf_T5_base,pass,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv index 754f5f718e436..d1606b622639e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -114,7 +114,7 @@ hf_Longformer,pass,4 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv index fd57a3b4cbf3c..0f088e7892d8f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv @@ -46,14 +46,6 @@ CamemBert,pass,0 -DebertaForMaskedLM,pass,0 - - - -DebertaForQuestionAnswering,pass,0 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,0 -DistillGPT2,pass,0 +DistillGPT2,pass,2 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,0 -Speech2Text2ForCausalLM,pass,0 - - - T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv index 66e088f334071..f65909f3a24ea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv @@ -46,14 +46,6 @@ CamemBert,pass,5 -DebertaForMaskedLM,pass,5 - - - -DebertaForQuestionAnswering,pass,5 - - - DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -70,7 +62,7 @@ DistilBertForQuestionAnswering,pass,5 -DistillGPT2,pass,5 +DistillGPT2,pass,7 @@ -130,7 +122,7 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,6 +OPTForCausalLM,pass,8 @@ -158,10 +150,6 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,6 - - - T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv index 3e4e9ee702aa3..8ccf95da9659e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv @@ -162,7 +162,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,5 +hf_Reformer,pass,8 @@ -174,7 +174,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,pass,3 +hf_T5_generate,pass,11 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv index 86ad955b5a2cb..e842ac7cb8e1f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv @@ -102,7 +102,7 @@ hf_DistilBert,pass,6 -hf_GPT2,pass,6 +hf_GPT2,pass,8 @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,23 +hf_Reformer,pass,25 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 651bc90ba194b..469ece2958df4 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -204,7 +204,6 @@ class CI(NamedTuple): "PLBartForCausalLM", "PLBartForConditionalGeneration", "PegasusForCausalLM", - "Speech2Text2ForCausalLM", "TrOCRForCausalLM", "XGLMForCausalLM", # TIMM diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 916a33276d996..aa81832a88315 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -459,6 +459,12 @@ def load_model( else: model.eval() + # Turning off kv cache for torchbench models. This is not the right + # thing to do, but the pt2 dashboard is outdated. Real transformers + # benchmarks will be added soon using a different infra. + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + self.validate_model(model, example_inputs) return device, model_name, model, example_inputs, batch_size diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index f0ee57a589657..5640776117096 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -31,8 +31,6 @@ batch_size: BlenderbotSmallForCausalLM: 4 BlenderbotSmallForConditionalGeneration: 2 CamemBert: 2 - DebertaForMaskedLM: 4 - DebertaForQuestionAnswering: 2 DebertaV2ForMaskedLM: 4 DebertaV2ForQuestionAnswering: 8 DistilBertForMaskedLM: 2 @@ -63,7 +61,6 @@ batch_size: PegasusForConditionalGeneration: 2 RobertaForCausalLM: 2 RobertaForQuestionAnswering: 2 - Speech2Text2ForCausalLM: 4 T5ForConditionalGeneration: 2 T5Small: 2 TrOCRForCausalLM: 2 diff --git a/benchmarks/dynamo/huggingface_models_list.txt b/benchmarks/dynamo/huggingface_models_list.txt index 6e3cf19a783d7..12ceedd5c4ccc 100644 --- a/benchmarks/dynamo/huggingface_models_list.txt +++ b/benchmarks/dynamo/huggingface_models_list.txt @@ -10,8 +10,6 @@ BlenderbotForConditionalGeneration,16 BlenderbotSmallForCausalLM,256 BlenderbotSmallForConditionalGeneration,128 CamemBert,32 -DebertaForMaskedLM,32 -DebertaForQuestionAnswering,32 DebertaV2ForMaskedLM,8 DebertaV2ForQuestionAnswering,8 DistilBertForMaskedLM,256 @@ -42,7 +40,6 @@ PegasusForCausalLM,128 PegasusForConditionalGeneration,64 RobertaForCausalLM,32 RobertaForQuestionAnswering,32 -Speech2Text2ForCausalLM,1024 T5ForConditionalGeneration,8 T5Small,8 TrOCRForCausalLM,64 diff --git a/benchmarks/dynamo/huggingface_models_list_cpu.txt b/benchmarks/dynamo/huggingface_models_list_cpu.txt index cabd79ac830f6..4078368a69c44 100644 --- a/benchmarks/dynamo/huggingface_models_list_cpu.txt +++ b/benchmarks/dynamo/huggingface_models_list_cpu.txt @@ -10,8 +10,6 @@ BlenderbotForCausalLM,32 BlenderbotSmallForCausalLM,64 BlenderbotSmallForConditionalGeneration,64 CamemBert,16 -DebertaForMaskedLM,32 -DebertaForQuestionAnswering,8 DebertaV2ForMaskedLM,16 DebertaV2ForQuestionAnswering,2 DistilBertForMaskedLM,128 @@ -38,7 +36,6 @@ PLBartForCausalLM,8 PLBartForConditionalGeneration,4 RobertaForCausalLM,16 RobertaForQuestionAnswering,16 -Speech2Text2ForCausalLM,32 T5ForConditionalGeneration,4 T5Small,1 TrOCRForCausalLM,32 diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index c2568aa1daa19..1f10ecc661d8e 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -382,6 +382,22 @@ def load_model( if self.args.trace_on_xla: # work around for: https://github.com/pytorch/xla/issues/4174 import torch_xla # noqa: F401 + + # Turning off kv cache for torchbench models. This is not the right + # thing to do, but the torchbench models are way outdated, and since we + # are using torchbench pt2 dashboard to track regressions (rather than + # improving performance), we are just setting the kv cache to false. + # Real transformers benchmarks will be added soon using a different + # infra. + if ( + model_name.startswith("hf") + and hasattr(model, "config") + and hasattr(model.config, "use_cache") + ): + model.config.use_cache = False + if model_name == "hf_T5_generate": + model.model.config.use_cache = False + self.validate_model(model, example_inputs) return device, benchmark.name, model, example_inputs, batch_size From 9a0f7a3bb01b235ea04581ee540970a098071b72 Mon Sep 17 00:00:00 2001 From: Jovian Anthony Jaison <38627145+jovianjaison@users.noreply.github.com> Date: Tue, 12 Aug 2025 06:24:54 +0000 Subject: [PATCH 0253/1424] [retry-land][pytorch][dynamo_compile] Log stack_trace to dynamo_compile (#160348) refer: https://github.com/pytorch/pytorch/pull/159655 Earlier pr failed on dynamo/test_utils.py::TestDynamoTimed::test_dynamo_timed. Updated test_dynamo_timed + re-ran locally to test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160348 Approved by: https://github.com/masnesral --- test/dynamo/test_utils.py | 31 ++++++++++++++++++++++++ torch/_dynamo/convert_frame.py | 44 +++++++++++++++++++--------------- torch/_dynamo/utils.py | 1 + 3 files changed, 57 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index d4206575d7b08..fdb34ab0b68e0 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -246,6 +246,32 @@ def add(x, y): utils.reset_frame_count() torch._logging._internal.structured_logging_overhead.clear() + @dynamo_config.patch({"log_compilation_metrics": True}) + @inductor_config.patch({"force_disable_caches": True}) + def test_stack_trace(self): + self.warmup() + + compilation_events = [] + with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event: + self.run_forward_backward() + compilation_events = [arg[0][0] for arg in log_event.call_args_list] + stack_trace_list = [] + for e in compilation_events: + stack_trace_list.append(e.stack_trace) + + self.assertGreater(len(stack_trace_list), 0) + result = "\n".join( + item + for sublist in stack_trace_list + if sublist + for item in (sublist if isinstance(sublist, list) else [sublist]) + ) + self.assertIn( + "test_stack_trace", + result, + "Log file does not contain the expected string: 'test_stack_trace'", + ) + @dynamo_config.patch( { "log_compilation_metrics": True, @@ -396,6 +422,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): e.cuda_version = None e.triton_version = None e.python_version = None + e.stack_trace = None # First event is for the forward. Formatting makes reading diffs # much easier. @@ -479,6 +506,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': 0, 'specialize_float': False, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, @@ -560,6 +588,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': 0, 'specialize_float': False, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, @@ -652,6 +681,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': None, 'specialize_float': None, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, @@ -733,6 +763,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': None, 'specialize_float': None, + 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index bba4d9c980869..fb27c29935439 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -225,30 +225,35 @@ def fx_forward_from_src_skip_result( return result -def log_dynamo_start(code: CodeType, skip: int = 0) -> None: +def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]: convert_frame_intern = structured.intern_string(__file__) + # Extract and filter the stack + stack = list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] # Initialize the ChromiumEventLogger on start torch._logging.trace_structured( "dynamo_start", - lambda: { - "stack": list( - itertools.takewhile( - lambda f: f["filename"] != convert_frame_intern, - structured.from_traceback( - CapturedTraceback.extract(skip=4 + skip).summary() - ), - ) - ) - + [ - { - "line": code.co_firstlineno, - "name": code.co_name, - "filename": structured.intern_string(code.co_filename), - } - ] - }, + lambda: {"stack": stack}, ) + stack_strings = [ + f"Line: {frame['line']}, Name: {frame['name']}, Filename: {frame['filename']}" + for frame in stack + ] + return stack_strings + def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ @@ -1160,7 +1165,7 @@ def format_func_info(code: CodeType) -> str: # # 2 extra here # torch/_logging/_internal.py:1064 in trace_structured # torch/_dynamo/convert_frame.py:780 in - log_dynamo_start(code, skip) + stack_trace = log_dynamo_start(code, skip) start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None @@ -1300,6 +1305,7 @@ def format_func_info(code: CodeType) -> str: "dynamo_compile_time_before_restart_us": to_int_us( dynamo_time_before_restart ), + "stack_trace": stack_trace, } # TODO: replace with CompileEventLogger.compilation_metrics # There are some columns here not in PT2 Compile Events diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 588f1ddb99a19..c6707fe12fbd0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1288,6 +1288,7 @@ class CompilationMetrics: compliant_custom_ops: Optional[set[str]] = None restart_reasons: Optional[set[str]] = None dynamo_time_before_restart_s: Optional[float] = None + stack_trace: Optional[list[str]] = None # Sometimes, we will finish analyzing a frame but conclude we don't want # to install any guarded code. True means we actually decided to install # a compiled frame From fea7e9dd37c02c334b130f6624af6163fde6b2ab Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 12 Aug 2025 08:38:15 +0000 Subject: [PATCH 0254/1424] extract shape in _view_has_unbacked_input (#160255) Summary: We were getting DDE on reshape still!! i looked deeper and found an issue in _view_has_unbacked_input namely when input is [[,,]] it need to be normalized to [..] Test Plan: existing tests. Rollback Plan: Differential Revision: D79951119 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160255 Approved by: https://github.com/bobrenjc93 --- torch/_subclasses/fake_impls.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 4d33280f7ac82..7ebd2ec92d124 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -514,6 +514,8 @@ def maybe_guard_or_true(x): def _view_has_unbacked_input(a, shape): from torch.fx.experimental.symbolic_shapes import has_hint + shape = utils.extract_shape_from_varargs(shape, validate=False) + return ( any(not has_hint(s) for s in a.size()) or any(not has_hint(s) for s in a.stride()) From b9003ed3d87699e81e436719625a21996a6654e5 Mon Sep 17 00:00:00 2001 From: morrison-turnansky Date: Tue, 12 Aug 2025 08:53:28 +0000 Subject: [PATCH 0255/1424] Dynamo Deep Dive Documentation Fix (#158860) changed SourceBuilder to VariableBuilder Fixes #158447 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158860 Approved by: https://github.com/mlazos --- docs/source/torch.compiler_dynamo_deepdive.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/torch.compiler_dynamo_deepdive.md b/docs/source/torch.compiler_dynamo_deepdive.md index 6bbb03170e549..9fa7654023ca5 100644 --- a/docs/source/torch.compiler_dynamo_deepdive.md +++ b/docs/source/torch.compiler_dynamo_deepdive.md @@ -285,7 +285,7 @@ appear in the errors, and the `VariableTracker` method that throws the exception when you encounter a Dynamo error. In particular, sometimes we find that an object is tracked as a `UserDefinedObjectVariable` (this is Dynamo’s catch-all class), when it should have been tracked as -something more specific. In these cases, the `SourceBuilder.__call__` +something more specific. In these cases, the `VariableBuilder` logic is often to blame. **Debugging tip**. When running a program with `TORCH_LOGS=dynamo`, From f990490a23815ea6ee27e487c70ba2cf513ba43d Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Tue, 12 Aug 2025 09:36:59 +0000 Subject: [PATCH 0256/1424] Add `label_smoothing` param in `nn.BCELoss` and `nn.BCEWithLogitsLoss` (#150282) Fixes #91545 ## Changes - Add `label_smoothing` param and docs - Add test case for `label_smoothing` - Remove duplicate description in `nn.BCELoss` and `nn.BCEWithLogitsLoss` ## Test Result ```bash pytest -s test/test_nn.py -k test_bce ``` ![image](https://github.com/user-attachments/assets/30c0b7fe-fe49-4aa0-9b05-4d70403a7b05) ![image](https://github.com/user-attachments/assets/4fe3fd1c-54b8-4012-afd9-133ce9fb4964) ![image](https://github.com/user-attachments/assets/5cad019a-3a4c-475a-9fde-9c1acad5792d) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150282 Approved by: https://github.com/cyyever, https://github.com/mikaylagawarecki --- torch/nn/functional.py | 30 ++++++++++++++++++++--- torch/nn/functional.pyi.in | 2 ++ torch/nn/modules/loss.py | 19 +++++++++++++- torch/overrides.py | 6 ++--- torch/testing/_internal/common_modules.py | 16 +++++++++--- 5 files changed, 62 insertions(+), 11 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 6b61c3a5799db..c3219644fee87 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3472,6 +3472,7 @@ def binary_cross_entropy( size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean", + label_smoothing: float = 0.0, ) -> Tensor: r"""Compute Binary Cross Entropy between the target and input probabilities. @@ -3490,9 +3491,11 @@ def binary_cross_entropy( elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` - + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Examples:: - >>> input = torch.randn(3, 2, requires_grad=True) >>> target = torch.rand(3, 2, requires_grad=False) >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target) @@ -3508,6 +3511,7 @@ def binary_cross_entropy( size_average=size_average, reduce=reduce, reduction=reduction, + label_smoothing=label_smoothing, ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) @@ -3523,6 +3527,13 @@ def binary_cross_entropy( new_size = _infer_size(target.size(), weight.size()) weight = weight.expand(new_size) + assert 0 <= label_smoothing <= 1, ( + f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + ) + + if label_smoothing > 0: + target = target * (1 - label_smoothing) + (1 - target) * label_smoothing + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) @@ -3534,6 +3545,7 @@ def binary_cross_entropy_with_logits( reduce: Optional[bool] = None, reduction: str = "mean", pos_weight: Optional[Tensor] = None, + label_smoothing: float = 0.0, ) -> Tensor: r"""Compute Binary Cross Entropy between target and input logits. @@ -3560,9 +3572,11 @@ def binary_cross_entropy_with_logits( [C, H, W] the same pos_weights across the batch. To apply the same positive weight along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. Default: ``None`` - + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Examples:: - >>> input = torch.randn(3, requires_grad=True) >>> target = torch.empty(3).random_(2) >>> loss = F.binary_cross_entropy_with_logits(input, target) @@ -3579,6 +3593,7 @@ def binary_cross_entropy_with_logits( reduce=reduce, reduction=reduction, pos_weight=pos_weight, + label_smoothing=label_smoothing, ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) @@ -3590,6 +3605,13 @@ def binary_cross_entropy_with_logits( f"Target size ({target.size()}) must be the same as input size ({input.size()})" ) + assert 0 <= label_smoothing <= 1, ( + f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + ) + + if label_smoothing > 0: + target = target * (1 - label_smoothing) + (1 - target) * label_smoothing + return torch.binary_cross_entropy_with_logits( input, target, weight, pos_weight, reduction_enum ) diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index d0b64447e900b..580a768e4d9f1 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -134,6 +134,7 @@ def binary_cross_entropy_with_logits( reduce: bool | None = ..., reduction: str = ..., pos_weight: Tensor | None = ..., + label_smoothing: float = ..., ) -> Tensor: ... __all__ += ["binary_cross_entropy_with_logits"] @@ -145,6 +146,7 @@ def binary_cross_entropy( size_average: bool | None = ..., reduce: bool | None = ..., reduction: str = ..., + label_smoothing: float = ..., ) -> Tensor: ... __all__ += ["binary_cross_entropy"] diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 6fa0d53c8a448..0b9468797d4c9 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -692,6 +692,10 @@ class BCELoss(_WeightedLoss): elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. @@ -717,15 +721,21 @@ def __init__( size_average=None, reduce=None, reduction: str = "mean", + label_smoothing: float = 0.0, ) -> None: super().__init__(weight, size_average, reduce, reduction) + self.label_smoothing = label_smoothing def forward(self, input: Tensor, target: Tensor) -> Tensor: """ Runs the forward pass. """ return F.binary_cross_entropy( - input, target, weight=self.weight, reduction=self.reduction + input, + target, + weight=self.weight, + reduction=self.reduction, + label_smoothing=self.label_smoothing, ) @@ -815,6 +825,10 @@ class BCEWithLogitsLoss(_Loss): [C, H, W] the same pos_weights across the batch. To apply the same positive weight along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. Default: ``None`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. @@ -838,12 +852,14 @@ def __init__( reduce=None, reduction: str = "mean", pos_weight: Optional[Tensor] = None, + label_smoothing: float = 0.0, ) -> None: super().__init__(size_average, reduce, reduction) self.register_buffer("weight", weight) self.register_buffer("pos_weight", pos_weight) self.weight: Optional[Tensor] self.pos_weight: Optional[Tensor] + self.label_smoothing = label_smoothing def forward(self, input: Tensor, target: Tensor) -> Tensor: """Runs the forward pass.""" @@ -853,6 +869,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: self.weight, pos_weight=self.pos_weight, reduction=self.reduction, + label_smoothing=self.label_smoothing, ) diff --git a/torch/overrides.py b/torch/overrides.py index fe7af6bc4ff0c..3304cfab5e19c 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -488,7 +488,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.bernoulli: lambda input, generator=None, out=None: -1, torch.bilinear: lambda input1, input2, weight, bias: -1, torch.binary_cross_entropy_with_logits: ( - lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None, label_smoothing=0.0: -1 # noqa: B950 ), torch.bincount: lambda input, weights=None, minlength=0: -1, torch.binomial: lambda count, prob, generator=None: -1, @@ -851,10 +851,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: ), torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1, torch.nn.functional.binary_cross_entropy: ( - lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1 + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", label_smoothing=0.0: -1 ), torch.nn.functional.binary_cross_entropy_with_logits: ( - lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 + lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None, label_smoothing=0.0: -1 # noqa: B950 ), torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1, torch.nn.functional.cosine_embedding_loss: ( diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index edb897b6f99a5..f42ae06e7b303 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -1463,9 +1463,14 @@ def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, tr ('reduction_mean', {'reduction': 'mean'}), ('reduction_none', {'reduction': 'none'}), ('weights', {'weight': make_weight((10,))}), + ('label_smoothing', {'label_smoothing': 0.15}), ] - def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None): + def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0): + assert 0 <= label_smoothing <= 1 + if label_smoothing > 0: + t = t * (1 - label_smoothing) + (1 - t) * label_smoothing + result = -(t * i.log() + (1 - t) * (1 - i).log()) if weight is not None: @@ -1511,10 +1516,15 @@ def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, require ('reduction_mean', {'reduction': 'mean'}), ('reduction_none', {'reduction': 'none'}), ('weights', {'weight': make_weight((10,))}), - ('scalar_weights', {'weight': make_weight(())}) + ('scalar_weights', {'weight': make_weight(())}), + ('label_smoothing', {'label_smoothing': 0.15}), ] - def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None): + def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0): + assert 0 <= label_smoothing <= 1 + if label_smoothing > 0: + t = t * (1 - label_smoothing) + (1 - t) * label_smoothing + # TODO: add pos_weight to the definition here and corresponding SampleInputs max_val = (-i).clamp(min=0) result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_()) From 4d5b3f2d5af7c8e4f41da4ffca53fafe8bb86235 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 11 Aug 2025 22:09:51 -0700 Subject: [PATCH 0257/1424] [dynamo][guards] Install dict watchers for recrusive dict tag optimization (#159796) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159796 Approved by: https://github.com/jansel --- torch/csrc/dynamo/guards.cpp | 158 ++++++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 2 deletions(-) diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 9e25d07b1e839..c8e0ae9c27360 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) { static std::unordered_map dict_version_map; static int dict_version_watcher_id; +static int dict_recursive_tag_watcher_id; static uint64_t global_dict_version_id = 1; static int dict_version_watch_callback( PyDict_WatchEvent event, @@ -1557,6 +1558,37 @@ class GuardManager; class RootGuardManager; class DictGuardManager; +// Global registry used by the *recursive-dict-tag* optimisation. +// +// Key : `PyObject*` pointing to a watched `dict` +// Value : list of `GuardManager*` instances that have recorded that dict +// +// Why is this global? +// ------------------- +// * CPython allows only a small, fixed number of dict-watcher IDs (≈64). +// All `GuardManager`s therefore share a single watcher callback. +// * Different guard managers (possibly across different frames) can end up +// watching the same dictionary pointer. Therefore, we have a list of guard +// managers for each dict pointer. +// +// When is watch registered? +// * During the recording phase of recursive dict tag matching in GuardManager. +// +// When are they watched? +// * In the dict_recursive_tag_watch_callback function. +// +// When are the dict pointers unwatched? +// * If a dict is mutated or the guard manager deallocates. +// * Read `unwatch_all_saved_dict_pointers` docstring for more details. +// +// Expected size +// ------------- +// Every compilation frame contributes its tag-safe dicts to this registry, so +// the container can grow large over the lifetime of the process. That’s +// acceptable: lookup is by pointer (hash/equals = identity) and each entry +// stores only lightweight pointers. +std::unordered_map> dict_to_guard_managers; + /** * Base class for the leaf guard in the GuardManager hierarchy. */ @@ -2625,6 +2657,7 @@ class GuardManager { virtual ~GuardManager() { cleanup_tag_safe_entries(); + disable_recursive_dict_tag_optimization(); } void cleanup_tag_safe_entries() { @@ -2727,6 +2760,11 @@ class GuardManager { _tensor_pointers[value] = tensor_pointers; } + void disable_recursive_dict_tag_optimization() { + unwatch_all_saved_dict_pointers(); + _disable_dict_tag_matching = true; + } + public: // For cloning GuardManager( @@ -2833,6 +2871,10 @@ class GuardManager { } bool check_dict_pointer_tags(PyObject* value) { + if (_dict_callback_installed) { + // This means that for 3.12+, there are callbacks watching dict pointers. + return true; + } for (auto& kv : _dict_pointers[value]) { PyObject* dict_pointer = kv.first; uint64_t old_tag = kv.second; @@ -2963,6 +3005,11 @@ class GuardManager { throw std::runtime_error( "Could not register a callback for recursive dict tag optimization"); } +#if IS_PYTHON_3_12_PLUS + // Ideally we don't need to even register a weakref callback for value. + // But it does not hurt to be more cautious + _dict_callback_installed = watch_dict_pointers(value); +#endif } } if (!result) { @@ -2979,8 +3026,9 @@ class GuardManager { } GuardManager* guard_manager = static_cast( PyCapsule_GetPointer(self_capsule, "GuardManager*")); - if (guard_manager) - guard_manager->_disable_dict_tag_matching = true; + if (guard_manager) { + guard_manager->disable_recursive_dict_tag_optimization(); + } Py_RETURN_NONE; } @@ -3031,6 +3079,81 @@ class GuardManager { return true; } + bool watch_dict_pointers(PyObject* value) { +#if IS_PYTHON_3_12_PLUS + // ----------------------------------------------------------------------------- + // CPython 3.12 dict-watcher integration + // ----------------------------------------------------------------------------- + // + // We register a single watcher on all every dictionary pointer recorded by + // a tag-safe root. The watcher callback fires *once* for any structural + // change to those dictionaries + // + // Fast-path benefit + // ----------------- + // In steady state we no longer need to iterate over the recorded + // dictionaries and compare their `ma_version_tag`s (the + // “are-tags-unchanged” loop that used to dominate the fast-path guard + // evaluation). The presence of an *active watcher* is itself a guarantee + // that none of the dicts has mutated; if one **does** mutate, the callback + // simply flips `_disable_dict_tag_matching = true`, causing the next guard + // evaluation to skip the recursive-dict-tag optimisation entirely. + for (auto& kv : _dict_pointers[value]) { + PyObject* dict_pointer = kv.first; + int rc = PyDict_Watch(dict_recursive_tag_watcher_id, dict_pointer); + if (rc != 0) { + PyErr_Clear(); + return false; + } + dict_to_guard_managers[dict_pointer].push_back(this); + } +#endif + return true; + } + + void unwatch_all_saved_dict_pointers() { + /* + We may have recorded hundreds/thousands of dict pointers for the recursive + dict-tag optimisation. If any of those dicts mutates, we want to disable the + optimisation and then unwatch as many dict pointers as we can. + + Be careful: the same dict pointer can be recorded by multiple GuardManagers. + So the flow is: + + 1) Remove *this* GuardManager from dict_to_guard_managers[dict_pointer]. + 2) If the list for that dict becomes empty, then: + - PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer) + - erase the dict_pointer entry from dict_to_guard_managers. + */ +#if IS_PYTHON_3_12_PLUS + if (!_disable_dict_tag_matching) { + for (auto& value_stashed_pointers : _dict_pointers) { + auto stashed_pointers = value_stashed_pointers.second; + + for (auto& stashed_pointer : stashed_pointers) { + PyObject* dict_pointer = stashed_pointer.first; + + // Delete the guard manager from the dict_to_guard_managers + auto it = std::find( + dict_to_guard_managers[dict_pointer].begin(), + dict_to_guard_managers[dict_pointer].end(), + this); + if (it != dict_to_guard_managers[dict_pointer].end()) { + dict_to_guard_managers[dict_pointer].erase(it); + } + + // Unwatch the dict pointer if this was the last guard manager + // watching it. + if (dict_to_guard_managers[dict_pointer].empty()) { + PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer); + dict_to_guard_managers.erase(dict_pointer); + } + } + } + } +#endif + } + virtual bool check_nopybind(FrameLocalsMapping* value) { return check_nopybind_template(value); } @@ -3270,6 +3393,9 @@ class GuardManager { std::unordered_map> _tensor_pointers; std::vector _tag_safe_entries; + // 3.12+ related helper + bool _dict_callback_installed = false; + protected: // weakref to the type of guarded value // protected because it is used for cloning by DictGuardManager @@ -3957,6 +4083,27 @@ void add_relational_guard_resetter_to_cloned_root( root->add_relational_guard_resetter(std::move(guard)); } +#if IS_PYTHON_3_12_PLUS +static int dict_recursive_tag_watch_callback( + PyDict_WatchEvent event, + PyObject* dict, + PyObject* key, + PyObject* new_value) noexcept { + if (event != PyDict_EVENT_CLONED) { + auto it = dict_to_guard_managers.find(dict); + if (it != dict_to_guard_managers.end()) { + auto guard_managers = it->second; + for (auto& guard_manager : guard_managers) { + if (guard_manager) { + guard_manager->disable_recursive_dict_tag_optimization(); + } + } + } + } + return 0; // keep watching +} +#endif + std::unique_ptr make_guard_manager( RootGuardManager* root, std::string source, @@ -7558,6 +7705,13 @@ PyObject* torch_c_dynamo_guards_init() { throw std::runtime_error("Failed to install dict_version_watch_callback"); } + dict_recursive_tag_watcher_id = + PyDict_AddWatcher(dict_recursive_tag_watch_callback); + if (dict_recursive_tag_watcher_id == -1) { + throw std::runtime_error( + "Failed to install dict_recursive_tag_watch_callback"); + } + #endif return m; From f33ce40bc062a281e1a1f57e8c1926d0a7d155cc Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Thu, 7 Aug 2025 03:15:48 -0700 Subject: [PATCH 0258/1424] [bucketing] Bucket only adjacent collectives to prevent reordering (#159983) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159983 Approved by: https://github.com/wconstab, https://github.com/eellison --- test/distributed/test_inductor_collectives.py | 35 +++++++++++++------ torch/_inductor/fx_passes/bucketing.py | 31 ++++++++++++---- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index d0b8c32497f04..f7cf7764df56e 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1524,39 +1524,49 @@ def _reorder_communication_preserving_peak_memory( @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not SM80OrLater, "bfloat16") def test_all_gather_bucket(self): - def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): + def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) - # cast the inputs - ag_0_cast = ag_0.to(torch.bfloat16) ag_1_cast = ag_1.to(torch.bfloat16) - # allgather group_name = ( torch.distributed.distributed_c10d._get_default_group().group_name ) + ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_2, group_size, group_name + ) + ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out) + + ag_0 = ag_2_out + ag_0 + ag_0_cast = ag_0.to(torch.bfloat16) + ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor( ag_0_cast, group_size, group_name ) ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out) ag_0_out = ag_0_out * 2 - ag_1_cast = ag_1_cast * 2 ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor( ag_1_cast, group_size, group_name ) - # wait op ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out) - return y, ag_0_out, ag_1_out + ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_3, group_size, group_name + ) + ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out) + return y, ag_0_out, ag_1_out, ag_2_out, ag_3_out x = torch.ones(4, 384, device="cuda", dtype=torch.float32) w = torch.ones(384, 512, device="cuda", dtype=torch.float32) ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32) - inputs = [x, w, ag_0, ag_1] + ag_2 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + ag_3 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1, ag_2, ag_3] + correct = func(*inputs, **self.get_world_trs()) with torch._inductor.config.patch( { @@ -1568,9 +1578,14 @@ def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) # NOTE: The first return value should be the output of the first wait_tensor. # We want to make sure no unnecessary copy is made. - (FileCheck().check("all_gather_into_tensor_out").run(code)) + ( + FileCheck() + .check("= torch.ops._c10d_functional.all_gather_into_tensor") + .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(") + .check("= torch.ops._c10d_functional.all_gather_into_tensor") + .run(code) + ) out = compiled(*inputs, **self.get_world_trs()) - correct = func(*inputs, **self.get_world_trs()) assert same(out, correct), f"{out} va {correct}" @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 75dd3678d51c7..3bf1ff9dab86e 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -93,6 +93,12 @@ def greedy_bucket_collective_by_mb( node_group_key: Callable[[torch.fx.Node], Any], filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None, ) -> list[list[torch.fx.Node]]: + """ + Bucketing adjacent collectives with equal node_group_key. + We can not bucket non adjacent collectives, + as this will effectively change the order of collectives. + Reordering can lead to different order on different ranks. + """ g = gm.graph found_candidates = False for node in g.nodes: @@ -102,10 +108,12 @@ def greedy_bucket_collective_by_mb( if not found_candidates: return [] - nodes_groups: dict[Any, list[torch.fx.Node]] = defaultdict(list) nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict( OrderedSet ) + nodes_groups: list[list[torch.fx.Node]] = [] + cur_group: list[torch.fx.Node] = [] + cur_group_key = None for node in g.nodes: for n, successors in nodes_successors.items(): @@ -115,10 +123,19 @@ def greedy_bucket_collective_by_mb( if (filter_wait_node is None) or filter_wait_node(node): coll_node = node.args[0] group_key = node_group_key(coll_node) - nodes_groups[group_key].append(coll_node) + if group_key == cur_group_key: + cur_group.append(coll_node) + else: + if len(cur_group) > 1: + nodes_groups.append(cur_group) + cur_group = [coll_node] + cur_group_key = group_key + + if len(cur_group) > 1: + nodes_groups.append(cur_group) buckets: list[list[torch.fx.Node]] = [] - for nodes in nodes_groups.values(): + for nodes in nodes_groups: cur_bucket: list[torch.fx.Node] = [] cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet() cur_bucket_size_bytes: int = 0 @@ -128,7 +145,7 @@ def greedy_bucket_collective_by_mb( ) for node in nodes: if node in cur_bucket_successors: - # We can not bucket successors with the node + # We cannot bucket successors with the node continue assert "val" in node.meta n_val = node.meta["val"] @@ -163,7 +180,7 @@ def bucket_all_gather_by_mb( Args: gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers. - bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket + bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow to specify different sizes of the buckets at the start, as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx @@ -201,14 +218,14 @@ def bucket_reduce_scatter_by_mb( Args: gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters. - bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket + bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow to specify different sizes of the buckets. filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified, only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed. Returns: - list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes. + list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes. """ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: From 7fbc22855c17741ae016992803b2e147a13aa22d Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Tue, 12 Aug 2025 14:02:36 +0000 Subject: [PATCH 0259/1424] Update triton xpu commit to support python 3.14 (#160183) Follow PR #159725 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160183 Approved by: https://github.com/EikanWang, https://github.com/atalman --- .ci/docker/ci_commit_pins/triton-xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index 80d7d7ed18af9..3c187be1bb649 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -ae324eeac8e102a2b40370e341460f3791353398 +0958dc9b2bb815e428f721f9da599dab0dc1c5d7 From a288b15ea9f87ddd665f249d492e0fb0861f5a69 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Tue, 12 Aug 2025 14:04:26 +0000 Subject: [PATCH 0260/1424] [CI] Reduce XPU Windows build time (#159763) Reduce the time cost from 2.5 hours to about 1.5 hours. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159763 Approved by: https://github.com/EikanWang, https://github.com/atalman --- .ci/pytorch/win-test-helpers/build_pytorch.bat | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 7ceb425ce2d1a..19d715b9d0b6d 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -61,9 +61,10 @@ if "%USE_XPU%"=="1" ( call "C:\Program Files (x86)\Intel\oneAPI\compiler\latest\env\vars.bat" call "C:\Program Files (x86)\Intel\oneAPI\ocloc\latest\env\vars.bat" if errorlevel 1 exit /b 1 - :: Reduce build time. Only have MTL self-hosted runner now - SET TORCH_XPU_ARCH_LIST=xe-lpg - SET USE_KINETO=0 + :: Reduce build time + SET TORCH_XPU_ARCH_LIST=bmg + :: Re-setup python env for build + call pip install -r requirements.txt ) @echo on From 9708fcf92db88b80b9010c68662d634434da3106 Mon Sep 17 00:00:00 2001 From: James Wu Date: Sun, 10 Aug 2025 15:38:35 -0700 Subject: [PATCH 0261/1424] Account for triton kernel source code hidden in custom ops properly in AOTAutogradCache (#160120) This PR fixes a bug where user defined triton kernels hidden behind `triton_op` do not register source code changes. If a user *only* changes a triton kernel source_code, because triton kernels are hidden under the custom op, dynamo hasn't traced into them yet. This means at AOTAutograd time, we don't know the list of triton kernels that are defined by custom ops. This is an initial fix for the issue by parsing the AST of the custom op looking for triton kernels. This won't catch more degenerate cases if the custom op calls other custom ops/functions that then call triton kernels, and then the toplevel compiled graph doesn't know about it. To handle that, we'd have to trace through the custom op at dynamo time. This should handle 99% of cases, though. I added an expectedFailure test to show the limitation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160120 Approved by: https://github.com/zou3519 --- test/dynamo/test_aot_autograd_cache.py | 209 +++++++++++++++++- .../_aot_autograd/autograd_cache.py | 37 ++++ torch/_library/custom_ops.py | 1 + torch/_library/triton.py | 77 +++++++ 4 files changed, 323 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 2895c8991c22c..7e6895ccde5cd 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -789,7 +789,6 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) @requires_cuda_and_triton - @requires_triton() @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -842,6 +841,214 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) + @requires_cuda_and_triton + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"autograd_cache_allow_custom_autograd_functions": True}) + def test_custom_autograd_function_with_custom_triton_kernel_cache_invalidation( + self, + ): + @triton.jit + def my_jit(x): + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + return y + + class MyAutogradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + y = torch.ops.test.my_triton_op(x) + ctx.save_for_backward(y) + ctx.foo = x.cos() + return y + + @staticmethod + def backward(ctx, grad_output): + result = ctx.saved_tensors[0] + return grad_output * result + ctx.foo * grad_output + + def fn(a): + return MyAutogradFunction.apply(a) + + a = torch.randn(5, device=GPU_TYPE, requires_grad=True) + a2 = a.clone().detach_().requires_grad_(True) + a3 = a.clone().detach_().requires_grad_(True) + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a) + self.assertEqual(fn(a), result) + result.sum().backward() + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Clear dynamo and run again. Should be a cache hit. + counters.clear() + self._clear_dynamo_and_codecache() + result = compiled_fn(a2) + self.assertEqual(fn(a2), result) + result.sum().backward() + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) + + # Now modify the source code of my_jit by redefining it + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 2) # Changed from +1 to +2 + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + return y + + # Clear dynamo and run again. Should be a cache miss due to modified source code. + counters.clear() + self._clear_dynamo_and_codecache() + compiled_fn = torch.compile(fn, backend="inductor") + + result = compiled_fn(a3) + # Assert that after changing the source code, the cache no longer hits + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(fn(a3), result) + + @requires_cuda_and_triton + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_triton_op_cache_invalidation(self): + from torch._library import capture_triton + + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + capture_triton(my_jit)[1,](y) + return y + + def fn(a): + return torch.ops.test.my_triton_op(a) + + a = torch.randn(5, device=GPU_TYPE) + a2 = a.clone().detach_() + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a) + self.assertEqual(fn(a), result) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + self._clear_dynamo_and_codecache() + + # Redefine the triton op + + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 2) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + return y + + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a2) + + # Second run should still miss + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) + + self.assertEqual(fn(a2), result) + + @requires_cuda_and_triton + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @unittest.expectedFailure # Currently ops that call other ops does not properly invalidate cache + def test_triton_op_cache_multiple_ops_invalidation(self): + @triton.jit + def my_jit(x): + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @triton.jit + def my_jit2(x): + arg_0 = tl.load(x) + tl.store(x, arg_0 + 1) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + torch._library.capture_triton(my_jit2)[1,](y) + return y + + @torch._library.triton_op("test::my_triton_op2", mutates_args=()) + def my_triton_op2(x: torch.Tensor) -> torch.Tensor: + y = x.clone().detach_().requires_grad_(True) + torch.ops.test.my_triton_op(y) + return y + + def fn(a): + return torch.ops.test.my_triton_op2(a) + + a = torch.randn(5, device=GPU_TYPE) + a2 = a.clone().detach_() + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a) + self.assertEqual(fn(a), result) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + self._clear_dynamo_and_codecache() + + # Redefine the triton op + + @triton.jit + def my_jit(x): # noqa: F811 + arg_0 = tl.load(x) + tl.store(x, arg_0 + 2) + + @torch._library.triton_op("test::my_triton_op", mutates_args=()) + def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch._library.capture_triton(my_jit)[1,](y) + torch._library.capture_triton(my_jit2)[1,](y) + return y + + @torch._library.triton_op("test::my_triton_op2", mutates_args=()) + def my_triton_op2(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + y = x.clone().detach_().requires_grad_(True) + torch.ops.test.my_triton_op(y) + return y + + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(a2) + + # Second run should still miss + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) + + self.assertEqual(fn(a2), result) + @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch({"fx_graph_cache": True}) @functorch_config.patch({"enable_autograd_cache": True}) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 7217a9c9b3903..248c3a0ae673e 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -302,6 +302,42 @@ class AOTAutogradCacheDetails(FxGraphHashDetails): a safe and stable cache key for AOTAutograd. """ + def get_triton_source_codes_from_gm( + self, + gm: torch.fx.GraphModule, + ): + triton_kernels = [] + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if isinstance(node.target, torch._ops.OpOverloadPacket): + attrs = node.target._dir + for attr in attrs: + if custom_op := getattr(node.target, attr, None): + kernels = torch._library.triton.get_triton_kernels_for_op( + custom_op._name + ) + triton_kernels.extend(kernels) + elif isinstance(node.target, torch._ops.OpOverload): + kernels = torch._library.triton.get_triton_kernels_for_op( + node.target._name + ) + triton_kernels.extend(kernels) + + triton_kernel_source_codes = [] + from torch._inductor.codegen.wrapper import ( + user_defined_triton_kernel_transitive_closure_source_code, + ) + + for kernel in triton_kernels: + source_codes = user_defined_triton_kernel_transitive_closure_source_code( + kernel + ) + triton_kernel_source_codes.append(source_codes) + + return triton_kernel_source_codes + def __init__( self, gm: torch.fx.GraphModule, @@ -319,6 +355,7 @@ def __init__( [], [], ) + self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm) if hasattr(gm, "saved_tensors_hooks_pack_0"): diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index bd8acb2789e16..251cdefe0f05d 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -210,6 +210,7 @@ def __init__( self._lib = get_library_allowing_overwrite(self._namespace, self._name) self._register_to_dispatcher(self._tags) self._disabled_kernel: set = set() + self._used_triton_kernels: list[Any] = list() OPDEFS[self._qualname] = self @property diff --git a/torch/_library/triton.py b/torch/_library/triton.py index 17d02a9945630..741b341f7e210 100644 --- a/torch/_library/triton.py +++ b/torch/_library/triton.py @@ -1,4 +1,6 @@ +import ast import contextlib +import inspect import threading from collections.abc import Generator, Iterable from typing import Any, Callable, Optional, Union @@ -9,6 +11,79 @@ from .infer_schema import infer_schema +triton_ops_to_kernels: dict[str, list[object]] = {} + + +def get_triton_kernels_for_op(name: str) -> list[object]: + return triton_ops_to_kernels.get(name, []) + + +def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]: + """ + Inspect the source of an arbitrary callable passed to torch._library.triton_op, + and grab all of the triton kernels that are wrapped inside of it. + + TODO: This check is best effort. It does *not* handle the case where the triton + kernel is hidden behind recursive function calls. + """ + + def find_triton_kernels(fn: Callable[..., Any]) -> list[object]: + try: + source = inspect.getsource(fn) + except (OSError, TypeError): + return [] # Source code not available + + from torch._inductor.utils import IndentedBuffer + + buffer = IndentedBuffer() + buffer.splice(source, strip=True) + tree = ast.parse(buffer.getrawvalue()) + + # Visitor to collect function calls and triton kernels + class Visitor(ast.NodeVisitor): + def __init__(self) -> None: + self.triton_kernels: list[Any] = [] + + def visit_Call(self, node: ast.Call) -> None: + triton_func_names = ("capture_triton", "wrap_triton") + if isinstance(node.func, ast.Attribute): + attr = node.func + if ( + isinstance(attr.value, ast.Attribute) + and isinstance(attr.value.value, ast.Name) + and attr.value.value.id == "torch" + and attr.value.attr == "_library" + and attr.attr in triton_func_names + ): + if node.args and isinstance(node.args[0], ast.Name): + self.triton_kernels.append(node.args[0].id) + + # Catch capture_triton, wrap_triton that's been + # imported directly + elif isinstance(node.func, ast.Name): + if node.func.id in triton_func_names: + if node.args and isinstance(node.args[0], ast.Name): + self.triton_kernels.append(node.args[0].id) + + self.generic_visit(node) + + collector = Visitor() + collector.visit(tree) + closure_vars = inspect.getclosurevars(fn) + resolved = [] + # First, resolve triton kernel names + for name in collector.triton_kernels: + if name in closure_vars.nonlocals: + resolved.append(closure_vars.nonlocals[name]) + elif name in closure_vars.globals: + resolved.append(closure_vars.globals[name]) + elif name in closure_vars.builtins: + resolved.append(closure_vars.builtins[name]) + return resolved + + return find_triton_kernels(fn) + + @exposed_in("torch.library") def triton_op( name: str, @@ -175,6 +250,8 @@ def functional_decomp( # type: ignore[no-untyped-def] with mode: return fn(*args, **kwargs) + triton_kernels = get_inner_triton_kernels(fn) + triton_ops_to_kernels[name] = triton_kernels result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) return result From b7db86600a2614adc71c92ca42d359a7ac534d78 Mon Sep 17 00:00:00 2001 From: atalman Date: Tue, 12 Aug 2025 15:15:12 +0000 Subject: [PATCH 0262/1424] Fix Tensor illustration, use permalinks for image embedding in Readme.md (#160416) Fixes Tensor illustration being broken on pypi.org. Also uses permalinks instead of links to images for embedding as per this suggestion of Alban: https://github.com/pytorch/pytorch/pull/160187#discussion_r2262978006 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160416 Approved by: https://github.com/malfet --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 16000850ae920..03f76893e3e8d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png) +![PyTorch Logo](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/pytorch-logo-dark.png) -------------------------------------------------------------------------------- @@ -72,7 +72,7 @@ Elaborating Further: If you use NumPy, then you have used Tensors (a.k.a. ndarray). -![Tensor illustration](./docs/source/_static/img/tensor_illustration.png) +![Tensor illustration](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/tensor_illustration.png) PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the computation by a huge amount. @@ -99,7 +99,7 @@ from several research papers on this topic, as well as current and past work suc While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research. -![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif) +![Dynamic graph](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/dynamic_graph.gif) ### Python First From b219ca2a00a305753c4f1ea4c9c5d23243d54753 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 12 Aug 2025 15:29:19 +0000 Subject: [PATCH 0263/1424] Revert "Update triton xpu commit to support python 3.14 (#160183)" This reverts commit 7fbc22855c17741ae016992803b2e147a13aa22d. Reverted https://github.com/pytorch/pytorch/pull/160183 on behalf of https://github.com/clee2000 due to I'm not sure how, but it seems to have broken inductor/test_extension_backend.py::ExtensionBackendTests::test_open_device_registration [GH job link](https://github.com/pytorch/pytorch/actions/runs/16911267995/job/47917091939) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/7fbc22855c17741ae016992803b2e147a13aa22d). Maybe because the docker build changed? Note to self: not bad TD ([comment](https://github.com/pytorch/pytorch/pull/160183#issuecomment-3179840160)) --- .ci/docker/ci_commit_pins/triton-xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index 3c187be1bb649..80d7d7ed18af9 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -0958dc9b2bb815e428f721f9da599dab0dc1c5d7 +ae324eeac8e102a2b40370e341460f3791353398 From 9d37c960a4fc44d5ac334ca8bf775f85b95d76fc Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 12 Aug 2025 16:07:19 +0000 Subject: [PATCH 0264/1424] [ROCm][CI] use new benchmark image for dynamo (#160421) Follow-up to #160047 that separated the rocm image into default CI and benchmarks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160421 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .github/workflows/inductor-periodic.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index db6a235b8c864..fdb54978e8082 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -77,7 +77,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-jammy-rocm-py3_10 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks sync-tag: rocm-build test-matrix: | { include: [ From f7b2f3314cf7aede67d5fa5c75e4243208484344 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 12 Aug 2025 16:33:02 +0000 Subject: [PATCH 0265/1424] Revert "[triton_heuristics] Optimize the triton launcher in pt2 (#160000)" This reverts commit d0e2240f680ea2a553f7ee8188f52482e130bfd0. Reverted https://github.com/pytorch/pytorch/pull/160000 on behalf of https://github.com/davidberard98 due to D80054972 failing with test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1_tdlp_1 ([comment](https://github.com/pytorch/pytorch/pull/160000#issuecomment-3180144676)) --- torch/_inductor/ir.py | 3 - torch/_inductor/runtime/triton_heuristics.py | 65 +++++++++++--------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 47167b180f52e..a668cd41ebf1b 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6630,9 +6630,6 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: for name, arg in itertools.chain( named_args.items(), zip(itertools.repeat(""), extra_launch_args) ): - if name in constexpr_names and triton_version_uses_attrs_dict(): - # see #160000 - we don't pass in constexpr args to speed up runtime. - continue raw_keys_filtered.append(name) raw_args_filtered.append(arg) if isinstance(arg, IRNode): diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 47516a4a71c47..8425cba55795a 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -196,7 +196,8 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): call_kwargs[k] = v else: call_kwargs[k] = v - call_kwargs.update(launcher.config.kwargs) + if not triton_version_uses_attrs_dict(): + call_kwargs.update(launcher.config.kwargs) call_kwargs["num_warps"] = launcher.config.num_warps call_kwargs["num_stages"] = launcher.config.num_stages if HAS_WARP_SPEC: @@ -769,6 +770,28 @@ def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta) + def _get_args_with_constexprs(self, args, launcher): + """ + `args` is passed in with only the non-constexpr args (because the constexpr arg values + depend on the config). However, in later triton versions, the constexpr args need to be + added into the args list. + """ + if triton_version_uses_attrs_dict(): + # first: aggregate the constexpr args in (index, val) pairs + # so we can sort them by index. + constexpr_args: list[tuple[int, Any]] = [] + for arg_name, arg_val in launcher.config.kwargs.items(): + if arg_name in self.fn.arg_names: + constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) + + constexpr_args.sort() + new_args = [*args] + for arg_idx, arg_val in constexpr_args: + new_args.insert(arg_idx, arg_val) + + return new_args + return args + def bench(self, launcher, *args, with_profiler=False, **kwargs): """Measure the performance of a given launcher""" # we don't skip configs with spilled registers when auto-tuning custom @@ -797,22 +820,23 @@ def kernel_call(): ) # reset to zero before evaluating any config self.reset_to_zero_args(*args, **kwargs) + args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher) if autograd_profiler._is_profiler_enabled: profiler_kwargs = self.get_profiler_kwargs(stream, launcher) with torch._C._profiler._RecordFunctionFast( self.inductor_meta.get("kernel_name", "triton kernel"), - cloned_args, + args_with_constexprs, profiler_kwargs, ): launcher( - *cloned_args, + *args_with_constexprs, **cloned_kwargs, stream=stream, ) else: launcher( - *cloned_args, + *args_with_constexprs, **cloned_kwargs, stream=stream, ) @@ -1216,6 +1240,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): # so _RecordFunctionFast need to capture the args into CachingAutotuner::run() # make a copy here to avoid mutating the original args args_without_constexprs = tuple(args) + args = self._get_args_with_constexprs(args, launcher) if self.dump_launch_params: new_args, grid = self._interpret_args_grid(args, launcher.config) @@ -1271,10 +1296,6 @@ def __call__(self, _=None) -> str: class CompileResult(Generic[_T]): - """ - Base class representing compiled result. - """ - def __init__( self, kernel: _T, @@ -1338,30 +1359,21 @@ def _get_arg_lists( ) none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys())) - def _convert_constant(constant): - if isinstance(constant, str): - return "r'" + constant + "'" - else: - return repr(constant) - if triton_version_uses_attrs_dict(): call_args = arg_names def_args = arg_names - implicit_constants = OrderedSet( - ( - "num_warps", - "num_stages", - ) - ).union(OrderedSet(k for k in known_constants)) - if implicit_constants := implicit_constants & OrderedSet( - compile_meta["constants"].keys() + if ( + "num_warps" in compile_meta["constants"] + or "num_stages" in compile_meta["constants"] ): # num_warps/num_stages are special implicit args that are not in the signature # see test_triton_kernel_special_params - def_args = [arg for arg in def_args if arg not in implicit_constants] + def_args = [ + arg for arg in def_args if arg not in ("num_warps", "num_stages") + ] repl = { - k: _convert_constant(compile_meta["constants"].get(k)) - for k in implicit_constants + k: str(compile_meta["constants"].get(k)) + for k in ("num_warps", "num_stages") } call_args = [repl.get(arg, arg) for arg in call_args] else: @@ -1641,8 +1653,6 @@ def make_launcher(self) -> LauncherType: import math as math_lib - import triton as triton_lib - import torch as torch_lib scope = { @@ -1677,7 +1687,6 @@ def make_launcher(self) -> LauncherType: "runner": get_first_attr(binary, "run", "c_wrapper"), "math": math_lib, "torch": torch_lib, - "triton": triton_lib, } if not hasattr(binary, "launch_metadata"): From a7abf57aabec0ce686092e2d66e53ba185dbc56b Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 12 Aug 2025 16:42:55 +0000 Subject: [PATCH 0266/1424] [ROCm] Support large inputs for coalesceValuesKernel (#158281) # Description `.coalesce` cannot handle large inputs on ROCM due to maximal grid size limit. This PR splits axis `X` into axes `X` and `Y`, and repurposes `Z` for original `Y` on ROCm to avoid such limitation. Confirmed the new approach can handle large inputs. Correctness needs validation. # Testing Command `python torch_spmv.py 22500000 272500000` ## Script `torch_spmv.py` ``` python import torch import argparse def parse_args(): parser = argparse.ArgumentParser( description="Sparse COO Matrix by Dense Vector Multiplication using PyTorch" ) parser.add_argument("n", type=int, help="Size of the NxN matrix") parser.add_argument("nnz", type=int, help="Number of non-zero entries") return parser.parse_args() def main(): args = parse_args() n = args.n nnz = args.nnz dtype = torch.float32 device = torch.device('cuda') # Generate random indices for the sparse matrix in COO format. torch.manual_seed(42) rows = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device) cols = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device) indices = torch.stack([rows, cols], dim=0) # Generate random values. values = torch.randn(nnz, dtype=torch.float32, device=device) # Create the sparse COO matrix and move it to the target device. sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=torch.float32, device=device) sparse_matrix = sparse_matrix.coalesce() # Generate a random dense vector. dense_vector = torch.randn(n, dtype=torch.float32, device=device) # Perform sparse matrix - dense vector multiplication. # Using torch.sparse.mm which expects a 2D tensor for the vector. result = torch.sparse.mm(sparse_matrix, dense_vector.unsqueeze(1)).squeeze() # result = torch.mv(sparse_matrix, dense_vector) # Print the result. print("Result of the multiplication:") print(torch.sum(result)) if __name__ == "__main__": main() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158281 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily --- .../sparse/cuda/SparseCUDAApplyUtils.cuh | 32 ++++++++++++++++--- .../native/sparse/cuda/SparseCUDATensor.cu | 10 ++++++ test/test_sparse.py | 15 ++++++++- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh index 693ca536a3198..c11588a32ba05 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh @@ -196,9 +196,17 @@ C10_LAUNCH_BOUNDS_1(num_threads()) __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, Dtype *values, Dtype *newValues, - int64_t nnz, int64_t newNnz, int64_t stride) { + int64_t nnz, int64_t newNnz, +#ifdef USE_ROCM + int64_t nsegments, +#endif + int64_t stride) { - int seg = blockIdx.x * 4 + threadIdx.y; +#ifdef USE_ROCM + int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y; +#else + int64_t seg = blockIdx.x * 4 + threadIdx.y; +#endif // Number of values processed by each thread (grain size) const int SZ = 4; @@ -207,7 +215,11 @@ __global__ void coalesceValuesKernel( const int newValueRow = seg * stride; const int begin = segment_offsets[seg]; const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz; +#ifdef USE_ROCM + const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ; +#else const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; +#endif Acctype tmp[SZ]; #pragma unroll for (int ii = 0; ii < SZ; ii++) { @@ -250,9 +262,17 @@ C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4) __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, bool *values, bool *newValues, - int64_t nnz, int64_t newNnz, int64_t stride) { + int64_t nnz, int64_t newNnz, +#ifdef USE_ROCM + int64_t nsegments, +#endif + int64_t stride) { - int seg = blockIdx.x * 4 + threadIdx.y; +#ifdef USE_ROCM + int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y; +#else + int64_t seg = blockIdx.x * 4 + threadIdx.y; +#endif // Number of values processed by each thread (grain size) const int SZ = 4; @@ -261,7 +281,11 @@ __global__ void coalesceValuesKernel( const int newValueRow = seg * stride; const int begin = segment_offsets[seg]; const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz; +#ifdef USE_ROCM + const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ; +#else const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; +#endif bool tmp[SZ]; #pragma unroll for (int ii = 0; ii < SZ; ii++) { diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index a36ec9b203fc3..2e84ca8982fb2 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -106,7 +106,14 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { values = values.contiguous(); int64_t stride = c10::multiply_integers(values.sizes().slice(1)); int warp_size = at::cuda::warp_size(); +#ifdef USE_ROCM + const int64_t BATCHING_SEGMENT = 4096; + int64_t nsegments = ceil_div(newNnz, (int64_t) SZ); + int64_t s_batch = ceil_div(nsegments, BATCHING_SEGMENT); + dim3 grid(s_batch, (s_batch == 1) ? nsegments : BATCHING_SEGMENT, ceil_div(stride, (int64_t) warp_size*SZ)); +#else dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) warp_size*SZ)); +#endif dim3 block(warp_size, SZ); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( at::ScalarType::ComplexHalf, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, @@ -119,6 +126,9 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { newValues.data_ptr(), nnz, newNnz, +#if USE_ROCM + nsegments, +#endif stride ); C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/test/test_sparse.py b/test/test_sparse.py index 608b5ef13c1be..cef3adb34721b 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -21,7 +21,7 @@ (SM53OrLater, SM80OrLater, TEST_MULTIGPU) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride, - deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes) + deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf) from torch.testing._internal.common_methods_invocations import \ (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs) from torch.testing._internal.common_dtype import ( @@ -367,6 +367,19 @@ def _test_coalesce(t): t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size, dtype, device, coalesced) _test_coalesce(t) # this tests correctness + @onlyCUDA + @skipCUDAIf(not SM80OrLater and not TEST_WITH_ROCM, "CUDA capability < SM80 and not ROCM") + @dtypes(torch.float) + def test_coalesce_accepts_large_tensor(self, device, dtype): + N = 22500000 + NNZ = 272500000 + rows = torch.randint(0, N, (NNZ,), dtype=torch.int64, device=device) + cols = torch.randint(0, N, (NNZ,), dtype=torch.int64, device=device) + indices = torch.stack([rows, cols], dim=0) + values = torch.randn(NNZ, dtype=dtype, device=device) + sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(N, N), dtype=torch.float32, device=device) + sparse_matrix = sparse_matrix.coalesce() + @dtypes(torch.double) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/89395") def test_coalesce_reference_cycle(self, device, dtype): From 94b91a876327820a4bb6f5d39d156f13f2553ab6 Mon Sep 17 00:00:00 2001 From: Jovian Anthony Jaison Date: Tue, 12 Aug 2025 16:49:05 +0000 Subject: [PATCH 0267/1424] [redone][pytorch] Moving torch.compile worker process logs to a dedicated rank based log directory (#160352) Summary: Writing torch.compile worked logs to dedicated_log_rank{RANK} if we're running on mast. ref: D79456310 (got reverted because of linter) Testing: Refer differential Revision: D79917440 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160352 Approved by: https://github.com/masnesral --- test/inductor/test_compile_worker.py | 14 ++++++++++++++ .../_inductor/compile_worker/subproc_pool.py | 19 +++++++++++++++---- torch/_inductor/config.py | 18 ++++++++++++++++++ 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index dcbf1b380934f..8fde26c6acf67 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import operator import os +import tempfile from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, @@ -66,6 +67,19 @@ def test_quiesce(self): finally: pool.shutdown() + @skipIfWindows(msg="pass_fds not supported on Windows.") + def test_logging(self): + os.environ["MAST_HPC_JOB_NAME"] = "test_job" + os.environ["ROLE_RANK"] = "0" + with tempfile.NamedTemporaryFile(delete=True) as temp_log: + os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name + pool = SubprocPool(2) + try: + pool.submit(operator.add, 100, 1) + self.assertEqual(os.path.exists(temp_log.name), True) + finally: + pool.shutdown() + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 0b670b268b37e..7c05b01f45d77 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -145,10 +145,19 @@ def __init__( f"--write-fd={str(subproc_write_fd)}", f"--torch-key={torch_key_str}", ] - local = False + log_path = None + self.log_file = None + if config.worker_suppress_logging: + log_path = os.devnull log.info("Suppressing compile worker output due to config") - local = True + else: + log_path = config.torchinductor_worker_logpath + if not log_path: + log_path = config.get_worker_log_path() + + if log_path: + self.log_file = open(log_path, "w") self.process = subprocess.Popen( cmd, @@ -164,8 +173,8 @@ def __init__( "LD_LIBRARY_PATH": get_ld_library_path(), }, pass_fds=(subproc_read_fd, subproc_write_fd), - stdout=subprocess.DEVNULL if local else None, - stderr=subprocess.DEVNULL if local else None, + stdout=self.log_file, + stderr=self.log_file, ) self.write_lock = threading.Lock() self.read_thread = threading.Thread( @@ -262,6 +271,8 @@ def shutdown(self) -> None: _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) self.write_pipe.close() self.process.wait(300) + if self.log_file: + self.log_file.close() except OSError as e: log.warning("Ignored OSError in pool shutdown: %s", e) finally: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 770da725a9aad..deebfa273ba14 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1020,6 +1020,24 @@ def decide_compile_threads() -> int: autotune_lookup_table: dict[str, dict[str, Any]] = {} +def get_worker_log_path() -> Optional[str]: + log_loc = None + if is_fbcode(): + mast_job_name = os.environ.get("MAST_HPC_JOB_NAME", None) + global_rank = os.environ.get("ROLE_RANK", "0") + + if mast_job_name is not None: + log_loc = f"/logs/dedicated_log_torch_compile_worker_rank{global_rank}" + + return log_loc + + +torchinductor_worker_logpath: str = Config( + env_name_force="TORCHINDUCTOR_WORKER_LOGPATH", + default="", +) + + # config specific to codegen/cpp.py class cpp: """ From 1f4057c11ac941fb324386ca594d0a6882185aad Mon Sep 17 00:00:00 2001 From: David Berard Date: Mon, 11 Aug 2025 17:03:20 +0000 Subject: [PATCH 0268/1424] [inductor] remove no_x_dim (#159810) no_x_dim is used to indicate that a reduction operates on a single row, and data loaded for the reduction is 1-dimensional. no_x_dim was introduced in https://github.com/pytorch/pytorch/pull/102444 - in which there was bad perf in some reductions, and using 1D tensors fixed the perf issue. However, it appears that this perf issue no longer exists in current Triton versions. https://github.com/pytorch/pytorch/pull/118822 checked this, and we can also check this on H100 benchmarks (linked below). And another motivation for removing this behavior is that it enables larger loads, which we observe is necessary for good performance on certain shapes on Blackwell. H100 inference benchmarks: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2004%20Aug%202025%2004%3A13%3A24%20GMT&stopTime=Mon%2C%2011%20Aug%202025%2004%3A13%3A24%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/396/orig&lCommit=a6bcd4692fb39fa2fad260f290bff545d4425829&rBranch=main&rCommit=e96c7c4bb0f6aeae2ab3b6f040f7d67edbec199a H100 training benchmarks: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2004%20Aug%202025%2004%3A13%3A24%20GMT&stopTime=Mon%2C%2011%20Aug%202025%2004%3A13%3A24%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/davidberard98/396/orig&lCommit=a6bcd4692fb39fa2fad260f290bff545d4425829&rBranch=main&rCommit=e96c7c4bb0f6aeae2ab3b6f040f7d67edbec199a Overall, the benchmarks show minimal change in performance. Differential Revision: [D79599286](https://our.internmc.facebook.com/intern/diff/D79599286) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159810 Approved by: https://github.com/ngimel, https://github.com/eellison --- .../test_torchinductor_strided_blocks.py | 25 ------------------- torch/_inductor/choices.py | 12 --------- torch/_inductor/codegen/triton.py | 10 +++----- 3 files changed, 4 insertions(+), 43 deletions(-) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index c203ea661fbe7..034f83096c1a6 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -746,31 +746,6 @@ def test_2d_reduction_odd_shapes( # Check the code for multiple Rn_BLOCK's self._assert_reduction_ndims(code, 2) - def test_2d_reduction_no_x_dim(self): - """ - Tests a 2D reduction without an "x" dimension. - """ - # We need a size to get no x dim. - view = self._discontiguous_tensor((2, 346), self.device) - - # Expect 1 block pointer for the input. - result, (code,) = self._run_and_compare( - torch.prod, - view, - expected_num_block_pointers=1, - expected_num_triton_kernels=1, - config_patches=tiled_reduction_config, - ) - - # Check that there's no X dimension in the signature. - (signature_line,) = ( - line for line in code.splitlines() if line.startswith("def triton") - ) - self.assertNotIn("BLOCK", signature_line) - - # Check for 2 reduction dimensions in the body. - self._assert_reduction_ndims(code, 2) - @parametrize( "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback", [ diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index d79db5f2a0539..aacb62c7a1234 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -196,18 +196,6 @@ def should_use_persistent_reduction( features.reduction_numel, threshold ) # type: ignore[arg-types] - @staticmethod - def want_no_x_dim(features: SIMDKernelFeatures) -> bool: - """ - Heuristic to decide if we should drop the X dimension from a persistent reduction kernel. - So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1. - Strangely this is faster than a [1, RBLOCK] block in some cases. - """ - return ( - features.get_reduction_hint() == ReductionHint.INNER - and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256) - ) - @staticmethod def reduction_split_factor( device: torch.device, diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 0f9139ae0611a..e34fe5010d089 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2001,14 +2001,12 @@ def should_use_persistent_reduction(self) -> bool: ) def want_no_x_dim(self): - if ( + return ( self.persistent_reduction and len(self.numels) == self.num_reduction_dims + 1 - ): - if self.fixed_config: - return self.fixed_config["XBLOCK"] == 1 - return V.choices.want_no_x_dim(self.features) - return False + and self.fixed_config + and self.fixed_config["XBLOCK"] == 1 + ) @property def assert_function(self) -> str: From ee9f8ba11d664b871a9e0c7933fdc8571635b78c Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:13:54 +0000 Subject: [PATCH 0269/1424] [ROCm] Use opportunistic fastatomics based on hueristics (#159430) * Opportunistic fast atomics works better with small sizes, since there is more chance of lanes doing atomics on the same address Co-author: @amd-hhashemi Reproducer: ``` import time import torch x = torch.randn((1_632_960, 128), device='cuda', dtype=torch.float) ind = torch.randint(0, x.size(0), size=(5_079_670,), device='cuda') src = torch.randn((5_079_670, 128), device='cuda', dtype=torch.float) for _ in range(20): x.index_add_(0, ind, src) start_time = time.time() for i in range(100): x.index_add_(0, ind, src) torch.cuda.synchronize() end_time = time.time() mean_time = (end_time - start_time)/100 print(f"Avg time for index_add_: {mean_time * 1e6:.2f} us") ``` Perf numbers: ``` Before: Avg time for index_add_: 25652.16 us After: Avg time for index_add_: 2675.15 us ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159430 Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily --- aten/src/ATen/native/cuda/KernelUtils.cuh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 1696ee64eac67..5bdb3f6cc67d4 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -282,6 +282,14 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd( } // not coalsced, so now let try to capture lane-matches... + + if (numel > 16 /*<-hueristic threshold*/ * 64 ) { + // well shucks, unlikely to capture same-dest atomics in a wave. + // fall back to direct fastAtomic... + fastAtomicAdd(self_ptr, index, numel, value, true); + return; + } + // __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd // __match_any_sync() -- returns bit mask of the threads that have same dest addr auto mask = __match_any_sync(__activemask(), (int64_t)dst); From 3cec82a7e9aea040a34dd7a2587ae6d3bd65dba0 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 12 Aug 2025 06:23:03 -0700 Subject: [PATCH 0270/1424] Ensure outer aliasing on DTensor matches inner aliasing (#158954) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158954 Approved by: https://github.com/albanD, https://github.com/wconstab --- torch/distributed/tensor/_dispatch.py | 10 ++++++++-- torch/distributed/tensor/_op_schema.py | 6 ++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index faa2a1ba4941f..b562153ad507f 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -23,6 +23,7 @@ ) from torch.distributed.tensor._utils import try_find_mesh_from_args from torch.distributed.tensor.placement_types import Partial, Placement, Replicate +from torch.utils._python_dispatch import return_and_correct_aliasing try: @@ -164,7 +165,8 @@ def dispatch( assert output_sharding is not None, "output sharding should not be None" mesh = op_info.compute_mesh - if mesh.get_coordinate() is not None: + participating = mesh.get_coordinate() is not None + if participating: # computation that happens in the current rank of the mesh, normal case if output_sharding.needs_redistribute: # If sharding propagation decision needs redistribute, perform redistribute @@ -299,7 +301,11 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: assert len(out_dts) >= 1, "out variant should have at least one out arg" return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] else: - return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + if participating and op_info.schema.is_view_op(): + return return_and_correct_aliasing(op_call, args, kwargs, ret) + else: + return ret @staticmethod def redistribute_local_args( diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index b892d8883527c..b60373ea6f834 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -450,6 +450,12 @@ def is_out_variant_op(self) -> bool: # be entirely correct, but it's good enough for now. return "out" in self.op._schema.overload_name + def is_view_op(self) -> bool: + return any( + a.alias_info is not None and not a.alias_info.is_write + for a in self.op._schema.arguments + ) + def __hash__(self) -> int: # Only hash args and kwargs that op indicates to hash if not self.schema_info: From f341077ce4710172da20cfad916ee37159bfe9fe Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 12 Aug 2025 17:57:57 +0000 Subject: [PATCH 0271/1424] Revert "[ROCm] Support large inputs for coalesceValuesKernel (#158281)" This reverts commit a7abf57aabec0ce686092e2d66e53ba185dbc56b. Reverted https://github.com/pytorch/pytorch/pull/158281 on behalf of https://github.com/clee2000 due to broke windows cuda build? [GH job link](https://github.com/pytorch/pytorch/actions/runs/16915172288/job/47927141460) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/a7abf57aabec0ce686092e2d66e53ba185dbc56b). Not caught b/c PR didn't have ciflow/trunk ([comment](https://github.com/pytorch/pytorch/pull/158281#issuecomment-3180408766)) --- .../sparse/cuda/SparseCUDAApplyUtils.cuh | 32 +++---------------- .../native/sparse/cuda/SparseCUDATensor.cu | 10 ------ test/test_sparse.py | 15 +-------- 3 files changed, 5 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh index c11588a32ba05..693ca536a3198 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh @@ -196,17 +196,9 @@ C10_LAUNCH_BOUNDS_1(num_threads()) __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, Dtype *values, Dtype *newValues, - int64_t nnz, int64_t newNnz, -#ifdef USE_ROCM - int64_t nsegments, -#endif - int64_t stride) { + int64_t nnz, int64_t newNnz, int64_t stride) { -#ifdef USE_ROCM - int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y; -#else - int64_t seg = blockIdx.x * 4 + threadIdx.y; -#endif + int seg = blockIdx.x * 4 + threadIdx.y; // Number of values processed by each thread (grain size) const int SZ = 4; @@ -215,11 +207,7 @@ __global__ void coalesceValuesKernel( const int newValueRow = seg * stride; const int begin = segment_offsets[seg]; const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz; -#ifdef USE_ROCM - const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ; -#else const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; -#endif Acctype tmp[SZ]; #pragma unroll for (int ii = 0; ii < SZ; ii++) { @@ -262,17 +250,9 @@ C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4) __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, bool *values, bool *newValues, - int64_t nnz, int64_t newNnz, -#ifdef USE_ROCM - int64_t nsegments, -#endif - int64_t stride) { + int64_t nnz, int64_t newNnz, int64_t stride) { -#ifdef USE_ROCM - int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y; -#else - int64_t seg = blockIdx.x * 4 + threadIdx.y; -#endif + int seg = blockIdx.x * 4 + threadIdx.y; // Number of values processed by each thread (grain size) const int SZ = 4; @@ -281,11 +261,7 @@ __global__ void coalesceValuesKernel( const int newValueRow = seg * stride; const int begin = segment_offsets[seg]; const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz; -#ifdef USE_ROCM - const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ; -#else const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; -#endif bool tmp[SZ]; #pragma unroll for (int ii = 0; ii < SZ; ii++) { diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index 2e84ca8982fb2..a36ec9b203fc3 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -106,14 +106,7 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { values = values.contiguous(); int64_t stride = c10::multiply_integers(values.sizes().slice(1)); int warp_size = at::cuda::warp_size(); -#ifdef USE_ROCM - const int64_t BATCHING_SEGMENT = 4096; - int64_t nsegments = ceil_div(newNnz, (int64_t) SZ); - int64_t s_batch = ceil_div(nsegments, BATCHING_SEGMENT); - dim3 grid(s_batch, (s_batch == 1) ? nsegments : BATCHING_SEGMENT, ceil_div(stride, (int64_t) warp_size*SZ)); -#else dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) warp_size*SZ)); -#endif dim3 block(warp_size, SZ); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( at::ScalarType::ComplexHalf, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, @@ -126,9 +119,6 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { newValues.data_ptr(), nnz, newNnz, -#if USE_ROCM - nsegments, -#endif stride ); C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/test/test_sparse.py b/test/test_sparse.py index cef3adb34721b..608b5ef13c1be 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -21,7 +21,7 @@ (SM53OrLater, SM80OrLater, TEST_MULTIGPU) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride, - deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf) + deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes) from torch.testing._internal.common_methods_invocations import \ (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs) from torch.testing._internal.common_dtype import ( @@ -367,19 +367,6 @@ def _test_coalesce(t): t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size, dtype, device, coalesced) _test_coalesce(t) # this tests correctness - @onlyCUDA - @skipCUDAIf(not SM80OrLater and not TEST_WITH_ROCM, "CUDA capability < SM80 and not ROCM") - @dtypes(torch.float) - def test_coalesce_accepts_large_tensor(self, device, dtype): - N = 22500000 - NNZ = 272500000 - rows = torch.randint(0, N, (NNZ,), dtype=torch.int64, device=device) - cols = torch.randint(0, N, (NNZ,), dtype=torch.int64, device=device) - indices = torch.stack([rows, cols], dim=0) - values = torch.randn(NNZ, dtype=dtype, device=device) - sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(N, N), dtype=torch.float32, device=device) - sparse_matrix = sparse_matrix.coalesce() - @dtypes(torch.double) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/89395") def test_coalesce_reference_cycle(self, device, dtype): From 9903ca4f70bdc1653016256f5b4fd74fdfc609f8 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 12 Aug 2025 18:07:41 +0000 Subject: [PATCH 0272/1424] [cuDNN][64-bit indexing] update conv depthwise 64bit indexing dispatch condition to match native kernel (#156140) The native kernel doesn't support batch splitting so the previous check wasn't aggressive enough in dispatching to cuDNN https://github.com/pytorch/pytorch/issues/155225 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156140 Approved by: https://github.com/ngimel, https://github.com/atalman --- aten/src/ATen/native/Convolution.cpp | 3 ++- test/nn/test_convolution.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 7932e32b428b6..5bcb4fe55fd20 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -463,7 +464,7 @@ struct ConvParams { return true; } // native kernel doesn't support 64-bit non-splittable case - if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) { + if (cudnn_enabled && !(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) { static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1; if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index df3a3f5766c14..64e6349e0364c 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -4058,13 +4058,22 @@ def test_conv3d_64bit_indexing(self, device): @largeTensorTest("20GB") @largeTensorTest("64GB", "cpu") def test_depthwise_conv_64bit_indexing(self, device): - x = torch.randn(1, 2, 32800, 32800, dtype=torch.half) + x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to( + memory_format=torch.channels_last + ) c = nn.Conv2d( 2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.half - ) + ).to(memory_format=torch.channels_last) + yref = c(x) + y = c.to(device=device)(x.to(device=device)) + self.assertEqual(yref, y, atol=1e-3, rtol=1e-4) + del y, yref + + # try a batch-splittable case + x = x.reshape(100, 2, 3280, 3280).contiguous(memory_format=torch.channels_last) yref = c(x) y = c.to(device=device)(x.to(device=device)) - self.assertEqual(yref, y, atol=5e-3, rtol=1e-4) + self.assertEqual(yref, y, atol=1e-3, rtol=1e-4) instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True) From 2d0cdee394bccadcd0abe19dd4623ed978a331ad Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 12 Aug 2025 19:25:04 +0000 Subject: [PATCH 0273/1424] move thread-local capture mode guard to include work.isStarted (#160398) Per title, should fix capture errors that happen because nccl watchdog races with capture start. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160398 Approved by: https://github.com/aorenste --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3cb6aee8b9df8..3e9802d855e7c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2284,6 +2284,10 @@ void ProcessGroupNCCL::Watchdog::runLoop() { // Work status logging for desync debug desyncDebugger_.logWorkStart(work); + // allow watchdog to do an event query on a side thread + at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); + at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; + // a work could be started but not completed, so we should not update // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start @@ -2295,10 +2299,6 @@ void ProcessGroupNCCL::Watchdog::runLoop() { pg_->pgStatus_->lastStartedNumelOut = work.numelOut_; } - // allow watchdog to do an event query on a side thread - at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); - at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; - // Clean up completed work if (work.isCompleted()) { // In case user didn't call `work.wait()` with async collectives, From 89654db1abccf7e5f261989a150db4d1619ea2aa Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Tue, 12 Aug 2025 09:25:08 -0700 Subject: [PATCH 0274/1424] [inductor] fix triton bucketize mask propagation (#159961) See https://hud.pytorch.org/pytorch/pytorch/commit/6b414f56a4a133a428af618d8ed1553849341497 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159961 Approved by: https://github.com/eellison --- test/inductor/test_torchinductor.py | 22 ++++++++++++++++++++++ torch/_inductor/codegen/triton.py | 15 +++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 385a75d98f944..0e76ca4892841 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -13697,6 +13697,28 @@ def f(a_list): print(profile_output) self.assertFalse("Pageable" in profile_output) + @unittest.skipIf( + config.cpp_wrapper, + "cpp_wrapper samples will lead to invalid indexing", + ) + def test_inductor_triton_bucketize_respects_masking(self): + def fn(inp, repeats, output_size): + # return torch.repeat_interleave(inp, repeats, dim=0, output_size=output_size) + idx = torch.searchsorted( + repeats.cumsum(0), + torch.arange(0, output_size, device=repeats.device), + right=True, + ) + return torch.index_select(inp, 0, idx) + + inp = torch.arange(0, 4, device=self.device) + repeats = torch.tensor([1, 2, 3, 4], device=self.device) + output_size = repeats.sum().item() + args = (inp, repeats, output_size) + self.assertEqual(fn(*args), torch.compile(fn)(*args)) + + # end of class CommonTemplate - add new tests here + @dataclasses.dataclass class TestFailure: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e34fe5010d089..8e0831e3726f7 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2669,6 +2669,18 @@ def guard_cooperative_store(self, name, buffer): buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):")) return buffer.indent() + def _combine_masks(self, *variables: Optional[CSEVariable]): + masks = None + for elem in variables: + if elem is None: + continue + if hasattr(elem, "mask_vars"): + if masks is None: + masks = elem.mask_vars + else: + masks = masks | elem.mask_vars + return masks + def bucketize( self, values: CSEVariable, @@ -2718,6 +2730,9 @@ def bucketize( dtype=indexing_dtype, # type: ignore[attr-defined] ) + masks = self._combine_masks(values, boundary_indices, sorter_indices) + result.mask_vars = masks # type: ignore[attr-defined] + return result def reduction_resize(self, value) -> str: From 7e91394955721c77645fcdb75a5d47a255d65020 Mon Sep 17 00:00:00 2001 From: Paul de Supinski Date: Tue, 12 Aug 2025 20:08:45 +0000 Subject: [PATCH 0275/1424] Support NUMA Binding for Callable Entrypoints (#160163) # Context This is an extension of #149334. # This PR Add support for NUMA bindings with Callable entrypoints, such as `do_train` instead of `/usr/local/bin/python`. Most notably, we utilize a hack in order to force `Process.start()` to use custom NUMA bindings for each subprocess. Please search for `HACK:` in the code to see a description of the implementation we chose, and #160006 for discussion of alternatives and why this is necessary. Other changes: * Remove unnecessary `--preferred` option from all binding strategies. By default, Linux already allocates memory to the NUMA node local to the CPU which triggered the allocation. (See [MPOL_LOCAL](https://man7.org/linux/man-pages/man2/set_mempolicy.2.html).) * Refactor so that the main API is `maybe_wrap_command_with_numa_bindings`, which computes bindings for a single rank at a time, rather than `maybe_wrap_with_numa_bindings` which computed bindings for all ranks at once. This allowed for more code sharing between `Callable` and `str` entrypoints. # Test Plan ## Automated `$ pytest test/test_numa_binding.py` ## Manual Using [this benchmark,](https://gist.github.com/pdesupinski/bbe01ade455d86e989794f2c612e2d91), ran ``` $ PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -m torch.distributed.run --standalone --nproc-per-node=8 --numa-binding=node --run-path mlp_train.py 2>&1 | tee node_callable.txt && PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -u -m torch.distributed.run --standalone --nproc-per-node=8 --run-path mlp_train.py 2>&1 | tee none_callable.txt ``` and observed * 6.6% remote memory accesses with 'node' bindings * 11.6% remote without bindings I also ran similar with `str` entrypoints as before just to be sure it's still working. NOTE: [--run-path triggers the code to be run inside a `Callable`.](https://github.com/pytorch/pytorch/blob/017259f9c65b6fad55fb9597d7077e2543eaae46/torch/distributed/run.py#L870) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160163 Approved by: https://github.com/d4l3k --- docs/source/elastic/numa.rst | 4 +- test/test_numa_binding.py | 250 ++++++++++------ torch/distributed/elastic/agent/server/api.py | 9 +- .../elastic/multiprocessing/__init__.py | 3 +- .../elastic/multiprocessing/api.py | 12 +- .../subprocess_handler/handlers.py | 4 + .../subprocess_handler/subprocess_handler.py | 12 + torch/distributed/launcher/api.py | 10 +- torch/distributed/run.py | 2 +- torch/multiprocessing/spawn.py | 58 +++- torch/{distributed => }/numa/__init__.py | 0 torch/{distributed => }/numa/binding.py | 275 +++++++++++------- 12 files changed, 424 insertions(+), 215 deletions(-) rename torch/{distributed => }/numa/__init__.py (100%) rename torch/{distributed => }/numa/binding.py (74%) diff --git a/docs/source/elastic/numa.rst b/docs/source/elastic/numa.rst index b6caa8a94c0e7..d56c99cf422e3 100644 --- a/docs/source/elastic/numa.rst +++ b/docs/source/elastic/numa.rst @@ -3,8 +3,8 @@ NUMA Binding Utilities ====================== -.. automodule:: torch.distributed.numa +.. automodule:: torch.numa :members: -.. automodule:: torch.distributed.numa.binding +.. automodule:: torch.numa.binding :members: diff --git a/test/test_numa_binding.py b/test/test_numa_binding.py index e1637b2aad967..e89d06174f385 100644 --- a/test/test_numa_binding.py +++ b/test/test_numa_binding.py @@ -2,16 +2,19 @@ from __future__ import annotations +import multiprocessing.spawn as spawn +import os import subprocess import sys +import tempfile from dataclasses import dataclass from typing import Any, Optional -from unittest import skipIf, skipUnless +from unittest import skipUnless from unittest.mock import mock_open, patch import torch from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes -from torch.distributed.numa.binding import ( +from torch.numa.binding import ( _get_ranges_str_from_ints, _get_set_of_int_from_ranges_str, AffinityMode, @@ -35,12 +38,10 @@ class MockDeviceProperties: _real_open = open +_real_mkstemp = tempfile.mkstemp -@skipIf( - sys.platform == "win32", - "Windows is missing various os module attributes like sched_getaffinity", -) +@skipUnless(sys.platform == "linux", "Only linux currently supported") @skipUnless( torch.distributed.is_available(), "Need access to some distributed submodules" ) @@ -53,26 +54,44 @@ def setUp(self) -> None: self._mock_num_logical_cpus = 0 self._mock_num_numa_nodes = 0 self._mock_num_sockets = 0 + self._temp_file_paths = [] self._context_managers_to_apply_to_all_tests = [ patch("torch.cuda.device_count", self._mock_device_count), patch("torch.cuda.get_device_properties", self._mock_get_device_properties), patch("torch.cuda.is_available", self._mock_is_available), + # Implicitly used by dynamo + patch("torch.cuda.get_rng_state"), patch("builtins.open", new=self._mock_open), patch("os.listdir", new=self._mock_listdir), patch("os.sched_getaffinity", new=self._mock_sched_getaffinity), patch("shutil.which", return_value="/usr/bin/numactl"), - patch("subprocess.run"), + patch("torch.numa.binding.run"), + patch("torch.numa.binding.mkstemp", self._mock_mkstemp), ] for context_manager in self._context_managers_to_apply_to_all_tests: context_manager.__enter__() def tearDown(self) -> None: + # Clean up temporary files + for temp_file_path in self._temp_file_paths: + try: + os.unlink(temp_file_path) + except FileNotFoundError: + # File may have already been deleted or doesn't exist + pass + for context_manager in self._context_managers_to_apply_to_all_tests: context_manager.__exit__(None, None, None) super().tearDown() + def _mock_mkstemp(self, *args, **kwargs): + # Just keep track of temp files so we can delete them + fd, path = _real_mkstemp(*args, **kwargs) + self._temp_file_paths.append(path) + return fd, path + def _add_mock_hardware( self, *, @@ -204,7 +223,7 @@ def _mock_get_device_properties(self, index: int) -> MockDeviceProperties: def _mock_open(self, path: str, *args, **kwargs) -> Any: if path in self._mock_file_path_to_contents: return mock_open(read_data=self._mock_file_path_to_contents[path])() - if path.startswith("/sys/"): + if isinstance(path, str) and path.startswith("/sys/"): raise FileNotFoundError(f"File {path} was not mocked.") # Looks like CI is calling open and intending to open an actual file in some places. # Need this to make the CI pass. @@ -222,8 +241,8 @@ def _mock_listdir(self, target_path: str) -> set[str]: def _mock_sched_getaffinity(self, pid: int) -> set[int]: return set(range(self._mock_num_logical_cpus)) - def _start_test_processes_and_get_command_args_for_local_rank( - self, *, numa_options: Optional[NumaOptions], local_rank: int + def _start_processes_for_str_entrypoint_and_get_Popen_args( + self, *, numa_options: Optional[NumaOptions], target_local_rank: int ) -> tuple[str, ...]: """ Calls start_processes like elastic_launch ultimately would @@ -250,10 +269,58 @@ def _start_test_processes_and_get_command_args_for_local_rank( call_args = next( call_args for call_args in mock_popen.call_args_list - if call_args.kwargs.get("env", {}).get("LOCAL_RANK") == str(local_rank) + if call_args.kwargs.get("env", {}).get("LOCAL_RANK") + == str(target_local_rank) ) return call_args.kwargs["args"] + def _start_processes_for_callable_entrypoint_and_get_executable_contents( + self, *, numa_options: Optional[NumaOptions], target_local_rank: int + ) -> str: + active_local_rank = None + executable_path = None + + def _mock_process_start(self: Any) -> None: + nonlocal active_local_rank + active_local_rank = self._args[1] + spawn.get_command_line() + self._target(*self._args) + + original_get_command_line = spawn.get_command_line + + def _mock_get_command_line(*args, **kwargs) -> list[str]: + nonlocal executable_path + result = original_get_command_line(*args, **kwargs) + if active_local_rank == target_local_rank: + executable_path = result[0] + + return result + + with ( + patch("multiprocessing.context.SpawnProcess.start", _mock_process_start), + patch("multiprocessing.spawn.get_command_line", _mock_get_command_line), + patch("multiprocessing.process.BaseProcess.sentinel", 1), + # Prevent hanging + patch( + "multiprocessing.synchronize.Event.wait", + lambda self, timeout=None: None, + ), + ): + start_processes( + name="test_process", + entrypoint=lambda x: x, + args=dict.fromkeys(range(self._mock_device_count()), (0,)), + envs={ + i: {"LOCAL_RANK": str(i)} for i in range(self._mock_device_count()) + }, + logs_specs=DefaultLogsSpecs(), + numa_options=numa_options, + ) + + assert executable_path is not None + with open(executable_path) as executable_file: + return executable_file.read() + def test_node_numa_binding(self) -> None: self._add_mock_hardware( num_sockets=4, @@ -263,8 +330,9 @@ def test_node_numa_binding(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=11 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=11, ) self.assertEqual( command_args, @@ -273,7 +341,6 @@ def test_node_numa_binding(self) -> None: ( "numactl", "--cpunodebind=5", - "--preferred=5", "echo", "Hello, world!", ), @@ -288,8 +355,8 @@ def test_no_numa_binding_if_numa_options_not_provided(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=None, local_rank=11 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=None, target_local_rank=11 ) self.assertEqual( command_args, @@ -340,20 +407,18 @@ def test_fallback(self) -> None: ) with ( - patch("torch.distributed.numa.binding.signpost_event") as signpost_patch, + patch("torch.numa.binding.signpost_event") as signpost_patch, patch( - "subprocess.run", + "torch.numa.binding.run", side_effect=subprocess.CalledProcessError(1, "numactl"), ), ): - command_args = ( - self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions( - affinity_mode=AffinityMode.NODE, - should_fall_back_if_binding_fails=True, - ), - local_rank=0, - ) + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions( + affinity_mode=AffinityMode.NODE, + should_fall_back_if_binding_fails=True, + ), + target_local_rank=0, ) self.assertIn( "subprocess.CalledProcessError", @@ -387,6 +452,25 @@ def test_explicit_numa_options_overrides_default(self) -> None: NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), ) + def test_fork_start_method_does_not_call_get_default_numa_options(self) -> None: + # Inner import to avoid crashing if not torch.distributed.is_available() + from torch.distributed.launcher.api import LaunchConfig + + with patch( + "torch.distributed.launcher.api.get_default_numa_options" + ) as mock_get_default_numa_options: + launch_config = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=1, + start_method="fork", + # Don't provide numa_options + ) + # Verify get_default_numa_options was not called + mock_get_default_numa_options.assert_not_called() + # Verify numa_options is None when start_method is fork + self.assertIsNone(launch_config.numa_options) + def test_socket_numa_binding_with_multiple_numa_per_socket(self) -> None: self._add_mock_hardware( num_sockets=4, @@ -396,15 +480,15 @@ def test_socket_numa_binding_with_multiple_numa_per_socket(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), local_rank=15 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), + target_local_rank=15, ) self.assertEqual( command_args, ( "numactl", "--cpunodebind=6-7", - "--preferred-many=6-7", "echo", "Hello, world!", ), @@ -419,15 +503,15 @@ def test_socket_numa_binding_with_single_numa_per_socket(self) -> None: num_physical_core_per_l3_cache=2, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), local_rank=7 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), + target_local_rank=7, ) self.assertEqual( command_args, ( "numactl", "--cpunodebind=3", - "--preferred=3", "echo", "Hello, world!", ), @@ -442,8 +526,9 @@ def test_exclusive_numa_binding(self) -> None: num_physical_core_per_l3_cache=3, ) - command_args_0 = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), local_rank=0 + command_args_0 = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), + target_local_rank=0, ) self.assertEqual( command_args_0, @@ -451,14 +536,14 @@ def test_exclusive_numa_binding(self) -> None: "numactl", # Gets an extra physical core due to odd number of physical cores on numa node "--physcpubind=0-3", - "--preferred=0", "echo", "Hello, world!", ), ) - command_args_1 = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), local_rank=1 + command_args_1 = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), + target_local_rank=1, ) self.assertEqual( command_args_1, @@ -466,7 +551,6 @@ def test_exclusive_numa_binding(self) -> None: "numactl", # Does not get an extra physical core, since the 1st GPU already took the extra. "--physcpubind=4-5", - "--preferred=0", "echo", "Hello, world!", ), @@ -485,9 +569,9 @@ def test_exclusive_raises_if_too_few_physical_cores(self) -> None: RuntimeError, "There are only 1 physical cores on numa_node_index=0, but there are 2 GPUs associated with this NUMA node.", ): - self._start_test_processes_and_get_command_args_for_local_rank( + self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), - local_rank=1, + target_local_rank=1, ) def test_core_complex_numa_binding_with_extra_l3(self) -> None: @@ -499,9 +583,9 @@ def test_core_complex_numa_binding_with_extra_l3(self) -> None: num_physical_core_per_l3_cache=3, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=3, + target_local_rank=3, ) self.assertEqual( command_args, @@ -509,7 +593,6 @@ def test_core_complex_numa_binding_with_extra_l3(self) -> None: "numactl", # The second L3 on the second numa node "--physcpubind=24-29", - "--preferred=1", "echo", "Hello, world!", ), @@ -524,9 +607,9 @@ def test_core_complex_numa_binding_with_fewer_l3_than_gpu(self) -> None: num_physical_core_per_l3_cache=3, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=3, + target_local_rank=3, ) self.assertEqual( command_args, @@ -535,7 +618,6 @@ def test_core_complex_numa_binding_with_fewer_l3_than_gpu(self) -> None: # There are only 2 L3 caches, so the 4th GPU shares the same # cores as the 3rd GPU. "--physcpubind=6-11", - "--preferred=1", "echo", "Hello, world!", ), @@ -552,11 +634,9 @@ def test_core_complex_prefers_caches_with_more_cpus(self) -> None: # Only some subset of the CPUs are available this time. with patch("os.sched_getaffinity", return_value={0, 4, 6, 7, 9}): - command_args = ( - self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=0, - ) + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), + target_local_rank=0, ) self.assertEqual( @@ -565,7 +645,6 @@ def test_core_complex_prefers_caches_with_more_cpus(self) -> None: "numactl", # Binds to the second L3 because it has the most available CPUs "--physcpubind=6-7,9", - "--preferred=0", "echo", "Hello, world!", ), @@ -584,42 +663,20 @@ def test_core_complex_tiebreak_prefers_lower_cache_key(self) -> None: num_physical_core_per_l3_cache=1, ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX), - local_rank=0, + target_local_rank=0, ) self.assertEqual( command_args, ( "numactl", "--physcpubind=0-1", - "--preferred=0", "echo", "Hello, world!", ), ) - def test_raises_error_if_numa_options_provided_for_callable_entrypoint( - self, - ) -> None: - # Inner import to avoid crashing if not torch.distributed.is_available() - from torch.distributed.elastic.agent.server.api import WorkerSpec - - def mock_entrypoint() -> None: - pass - - with self.assertRaisesRegex(ValueError, r".*numa_options.*"): - # not relevant to test, just pass in an arbitrary value - mock_rdzv_handler: Any = 0 - WorkerSpec( - role="trainer", - # Only str entrypoint (e.g. "echo") is currently supported - entrypoint=mock_entrypoint, - local_world_size=8, - rdzv_handler=mock_rdzv_handler, - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), - ) - def test_raises_error_if_numactl_unavailable(self) -> None: self._add_mock_hardware( num_sockets=1, @@ -632,8 +689,9 @@ def test_raises_error_if_numactl_unavailable(self) -> None: patch("shutil.which", return_value=None), self.assertRaisesRegex(RuntimeError, r".*numactl.*"), ): - self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=0 + self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=0, ) def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None: @@ -654,20 +712,50 @@ def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None: contents="-1", ) - command_args = self._start_test_processes_and_get_command_args_for_local_rank( - numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=0 + command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=0, ) self.assertEqual( command_args, ( "numactl", "--cpunodebind=0", - "--preferred=0", "echo", "Hello, world!", ), ) + def test_callable_entrypoint_basic(self) -> None: + self._add_mock_hardware( + num_sockets=4, + num_numa_nodes_per_socket=2, + num_gpus_per_numa_node=2, + num_l3_caches_per_numa_node=4, + num_physical_core_per_l3_cache=2, + ) + + executable_contents = ( + self._start_processes_for_callable_entrypoint_and_get_executable_contents( + numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), + target_local_rank=11, + ) + ) + self.assertEqual( + executable_contents, + # There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be + # on numa node 11 // 2 = 5. + f"""#!/bin/bash + +# If this file is more than a few minutes old and still exists on your machine, +# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such +# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163. + +rm -- "$0" +numactl --cpunodebind=5 {sys.executable} "$@" +""", + ) + def test_get_set_of_int_from_ranges_str(self) -> None: self.assertEqual( _get_set_of_int_from_ranges_str("0-2,4,6-7"), {0, 1, 2, 4, 6, 7} diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 2759f20bd2778..1175da3b91b7c 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -27,7 +27,7 @@ from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.numa.binding import NumaOptions +from torch.numa.binding import NumaOptions __all__ = [ @@ -104,13 +104,6 @@ def __post_init__(self): self.entrypoint = self.fn assert self.entrypoint - if ( - self.numa_options is not None - and not self.numa_options.should_fall_back_if_binding_fails - and not isinstance(self.entrypoint, str) - ): - raise ValueError("numa_options is only supported for str entrypoints.") - def get_entrypoint_name(self): """Get the entry point name. diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index d283e0129f0ac..7e293ce47cb7b 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -80,7 +80,7 @@ def trainer(a, b, c): to_map, ) from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.numa.binding import NumaOptions +from torch.numa.binding import NumaOptions __all__ = [ @@ -227,6 +227,7 @@ def start_processes( log_line_prefixes=log_line_prefixes, start_method=start_method, logs_specs=logs_specs, + numa_options=numa_options, ) try: diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 6cd8d2a12f351..ed3ea86b0f2aa 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -37,7 +37,7 @@ SubprocessHandler, ) from torch.distributed.elastic.multiprocessing.tail_log import TailLog -from torch.distributed.numa.binding import maybe_wrap_with_numa_bindings, NumaOptions +from torch.numa.binding import NumaOptions IS_WINDOWS = sys.platform == "win32" @@ -631,6 +631,7 @@ def __init__( start_method: str, logs_specs: LogsSpecs, log_line_prefixes: Optional[dict[int, str]] = None, + numa_options: Optional[NumaOptions] = None, ): super().__init__( name, @@ -655,6 +656,8 @@ def __init__( # successfully. If any process died on event.wait() calling set() method will deadlock. self._worker_finished_event = mp.get_context(self.start_method).Event() + self._numa_options: Optional[NumaOptions] = numa_options + def _start(self): if self._pc: raise ValueError( @@ -676,6 +679,7 @@ def _start(self): join=False, daemon=False, start_method=self.start_method, + numa_options=self._numa_options, ) def _is_done(self) -> bool: @@ -814,10 +818,6 @@ def __init__( log_line_prefixes: Optional[dict[int, str]] = None, numa_options: Optional[NumaOptions] = None, ): - entrypoint, args = maybe_wrap_with_numa_bindings( - entrypoint=entrypoint, local_rank_to_args=args, numa_options=numa_options - ) - super().__init__( name, entrypoint, @@ -831,6 +831,7 @@ def __init__( self._running_local_ranks: set[int] = set(range(self.nprocs)) self._failures: dict[int, ProcessFailure] = {} self.subprocess_handlers: dict[int, SubprocessHandler] = {} + self._numa_options: Optional[NumaOptions] = numa_options def _start(self): if self.subprocess_handlers: @@ -845,6 +846,7 @@ def _start(self): stdout=self.stdouts[local_rank], stderr=self.stderrs[local_rank], local_rank_id=local_rank, + numa_options=self._numa_options, ) for local_rank in range(self.nprocs) } diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index fea707a3c3ab2..947ce7b001ef7 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -3,10 +3,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( SubprocessHandler, ) +from torch.numa.binding import NumaOptions __all__ = ["get_subprocess_handler"] @@ -19,6 +21,7 @@ def get_subprocess_handler( stdout: str, stderr: str, local_rank_id: int, + numa_options: Optional[NumaOptions] = None, ) -> SubprocessHandler: return SubprocessHandler( entrypoint=entrypoint, @@ -27,4 +30,5 @@ def get_subprocess_handler( stdout=stdout, stderr=stderr, local_rank_id=local_rank_id, + numa_options=numa_options, ) diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index 6b927fcd6a670..c2327e1cd3cf3 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -11,6 +11,8 @@ from subprocess import Popen from typing import Any, Optional +from torch.numa.binding import maybe_wrap_command_with_numa_bindings, NumaOptions + __all__ = ["SubprocessHandler"] @@ -39,6 +41,7 @@ def __init__( stdout: Optional[str], stderr: Optional[str], local_rank_id: int, + numa_options: Optional[NumaOptions], ): self._stdout = open(stdout, "w") if stdout else None self._stderr = open(stderr, "w") if stderr else None @@ -47,6 +50,15 @@ def __init__( env_vars.update(env) args_str = (entrypoint, *[str(e) for e in args]) + args_str = ( + maybe_wrap_command_with_numa_bindings( + command_args=args_str, + gpu_index=local_rank_id, + numa_options=numa_options, + ) + or args_str + ) + self.local_rank_id = local_rank_id self.proc: Popen = self._popen(args_str, env_vars) diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index d788ad568bd5c..ef6e75c8dde36 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -26,7 +26,7 @@ from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.numa.binding import NumaOptions +from torch.numa.binding import NumaOptions __all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] @@ -107,7 +107,13 @@ def __post_init__(self): if self.logs_specs is None: self.logs_specs = DefaultLogsSpecs() - if self.numa_options is None and torch.cuda.is_available(): + if ( + self.numa_options is None + # NOTE: This filter isn't relevant for str entrypoints, + # but it's the default anyway. + and self.start_method == "spawn" + and torch.cuda.is_available() + ): self.numa_options = get_default_numa_options() logger.info("Using default numa options = %r", self.numa_options) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index c37ecd8f72d86..2738191f0e379 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -382,7 +382,7 @@ def main(): from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.launcher.api import elastic_launch, LaunchConfig -from torch.distributed.numa.binding import ( +from torch.numa.binding import ( AffinityMode as _AffinityMode, # Signify as private with _ NumaOptions as _NumaOptions, ) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 4cef60948ad98..eb5f885acc194 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -2,6 +2,7 @@ import logging import multiprocessing import multiprocessing.connection +import multiprocessing.spawn as mp_spawn import os import pickle import signal @@ -12,6 +13,11 @@ from concurrent.futures import as_completed, ThreadPoolExecutor from typing import Optional +from torch.numa.binding import ( + maybe_get_temporary_python_executable_with_numa_bindings, + NumaOptions, +) + from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] @@ -236,6 +242,7 @@ def start_processes( join=True, daemon=False, start_method="spawn", + numa_options: Optional[NumaOptions] = None, ): # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), # this func will start processes in parallel if start_method is 'forkserver'. @@ -251,11 +258,43 @@ def start_processes( # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start start_parallel = False + if numa_options is not None and start_method != "spawn": + raise ValueError("NUMA binding is only compatible with spawn") + + if numa_options is not None and start_parallel: + raise ValueError("NUMA binding is not compatible with parallel start") + mp = multiprocessing.get_context(start_method) error_files = [None] * nprocs processes = [None] * nprocs + original_executable = mp_spawn.get_executable() def start_process(i): + # HACK: We want to force Process.start() to kick off the subprocess + # using a custom numactl command per rank. However, the API exposed + # by multiprocessing only allows us to override the executable for + # the entire context, and only with a single str rather than a tuple. + # Furthermore, there is no API for passing additional options, e.g. + # to make LOCAL_RANK available to the executable. + # + # In order to get around these limitations, we pre-compute + # the appropriate command containing NUMA bindings and store it in a + # temporary executable which passes Python args on to the original + # executable. Then, we call set_executable before and after each + # Process.start() call. + # + # This assumes that, under the hood, Process.start() for rank n + # will not call get_executable after start_process for rank n+1 + # calls set_executable again. We guarantee this by + # raising an exception if `start_parallel`, above. (Not clear + # if there would be a race condition otherwise, but we want to be safe.) + temporary_executable_path = ( + maybe_get_temporary_python_executable_with_numa_bindings( + python_executable_path=original_executable, + gpu_index=i, + numa_options=numa_options, + ) + ) # Each process is assigned a file to write tracebacks to. We # use the file being non-empty to indicate an exception # occurred (vs an expected shutdown). Note: this previously @@ -267,12 +306,19 @@ def start_process(i): ) tf.close() os.unlink(tf.name) - process = mp.Process( - target=_wrap, - args=(fn, i, args, tf.name), - daemon=daemon, - ) - process.start() + + try: + if temporary_executable_path is not None: + mp.set_executable(temporary_executable_path) + process = mp.Process( + target=_wrap, + args=(fn, i, args, tf.name), + daemon=daemon, + ) + process.start() + finally: + if temporary_executable_path is not None: + mp.set_executable(original_executable) return i, process, tf.name if not start_parallel: diff --git a/torch/distributed/numa/__init__.py b/torch/numa/__init__.py similarity index 100% rename from torch/distributed/numa/__init__.py rename to torch/numa/__init__.py diff --git a/torch/distributed/numa/binding.py b/torch/numa/binding.py similarity index 74% rename from torch/distributed/numa/binding.py rename to torch/numa/binding.py index 51876583ec56c..7e4cc40aad5b3 100644 --- a/torch/distributed/numa/binding.py +++ b/torch/numa/binding.py @@ -1,28 +1,31 @@ import os import shutil +import stat import subprocess import traceback from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass from enum import Enum +from logging import getLogger +from subprocess import run +from tempfile import mkstemp from typing import Callable, Optional, TypeVar import torch from torch._utils_internal import signpost_event -from torch.distributed.elastic.utils.logging import get_logger __all__ = [ - "maybe_wrap_with_numa_bindings", "AffinityMode", + "maybe_get_temporary_python_executable_with_numa_bindings", + "maybe_wrap_command_with_numa_bindings", "NumaOptions", ] - _NUMACTL_COMMAND = "numactl" -logger = get_logger(__file__) +logger = getLogger(__name__) class AffinityMode(str, Enum): @@ -40,10 +43,10 @@ class AffinityMode(str, Enum): @dataclass(frozen=True) class NumaOptions: affinity_mode: AffinityMode + """ - If true, we will silently return the original command if any of the following occur: - - An exception is raised as we compute the wrapped command. - - During a dry run of the wrapped command, numactl fails for any reason. + If true, we will fall back to using the original command/entrypoint if we fail to compute + or apply NUMA bindings. You should avoid using this option! It is only intended as a safety mechanism for facilitating mass rollouts of numa binding. @@ -51,135 +54,201 @@ class NumaOptions: should_fall_back_if_binding_fails: bool = False -def maybe_wrap_with_numa_bindings( - *, - entrypoint: str, - local_rank_to_args: dict[int, tuple], - numa_options: Optional[NumaOptions], -) -> tuple[str, dict[int, tuple]]: +def maybe_get_temporary_python_executable_with_numa_bindings( + *, python_executable_path: str, gpu_index: int, numa_options: Optional[NumaOptions] +) -> Optional[str]: """ Args: - entrypoint: The entrypoint to the program, such as might be input to Popen. - Example: "python" - local_rank_to_args: A mapping from local rank to args for the entrypoint. - Example: {0: ("trainer.py",)} - numa_options: See NumaOptions for details. - + python_executable_path: E.g., "/usr/local/bin/python" Returns: - A tuple of (entrypoint, local_rank_to_args), basically transforming the inputs, - where the entrypoint and args may now involve numa binding. - Example: ("numactl", {"0": ("--cpunodebind=0", "--preferred=0", "python", "trainer.py")}) + Path to a temporary file. This file can be executed just like the original python + executable, except it will first apply NUMA bindings. """ if numa_options is None: - return (entrypoint, local_rank_to_args) - - wrapped_local_rank_to_args = {} - for local_rank, args in local_rank_to_args.items(): - try: - numactl_command_options = _maybe_get_numactl_options( - command_args=(entrypoint, *[str(arg) for arg in args]), - gpu_index=local_rank, - numa_options=numa_options, - ) - except Exception: - if numa_options.should_fall_back_if_binding_fails: - # NOTE: If any element of the batch fails to apply NUMA bindings - # for any reason, we do not apply NUMA bindings to any element of the batch, - # for maximum safety. This only applies if fallback is enabled. - return (entrypoint, local_rank_to_args) - raise - wrapped_local_rank_to_args[local_rank] = ( - *numactl_command_options, - entrypoint, - *args, - ) - return (_NUMACTL_COMMAND, wrapped_local_rank_to_args) + logger.info("Received numa_options=None, not creating numa executable.") + return None + + if isinstance(python_executable_path, bytes): + python_executable_path = python_executable_path.decode() + + full_numactl_command = maybe_wrap_command_with_numa_bindings( + # "$@", i.e. pass through any args the python executable would have + # received. + command_args=(python_executable_path, '"$@"'), + gpu_index=gpu_index, + numa_options=numa_options, + ) + if full_numactl_command is None: + return None + + executable_path = _get_temporary_executable_for_command( + command_args=full_numactl_command + ) + logger.info("Returning python executable with NUMA bindings %s", executable_path) -def _maybe_get_numactl_options( + return executable_path + + +def maybe_wrap_command_with_numa_bindings( *, command_args: tuple[str, ...], gpu_index: int, - numa_options: NumaOptions, -) -> tuple[str, ...]: + numa_options: Optional[NumaOptions], +) -> Optional[tuple[str, ...]]: """ Args: - command_args: The args for a command, such as might be input to Popen. - Example: ("python", "trainer.py") - gpu_index: The index of the GPU that will be used by the subprocess which executes command_args. - Example: 0 - numa_options: See NumaOptions for details. + command_args: Full shell command, like ("/usr/local/bin/python", "train.py") + gpu_index: The index of the GPU which command_args should bind to Returns: - Depending on numa_options, something like - ("--cpunodebind=0", "--preferred=0") + command_args, but wrapped so that it runs with NUMA bindings corresponding to + gpu_index and numa_options. + E.g., ("numactl", "--cpunodebind=0", "/usr/local/bin/python", "train.py") """ + if not numa_options: + logger.info("Received numa_options=None, not applying bindings.") + return None + + kwargs = { + "command_args": command_args, + "gpu_index": gpu_index, + "numa_options": numa_options, + } + logger.info("Attempting to wrap command with NUMA bindings, given input %r", kwargs) + try: _raise_if_numactl_not_available() - if numa_options.affinity_mode == AffinityMode.NODE: - numactl_command_options = _get_node_numactl_options(gpu_index=gpu_index) - elif numa_options.affinity_mode == AffinityMode.SOCKET: - numactl_command_options = _get_socket_numactl_options(gpu_index=gpu_index) - elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE: - numactl_command_options = _get_exclusive_numactl_options( - gpu_index=gpu_index - ) - elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX: - numactl_command_options = _get_core_complex_numactl_options( - gpu_index=gpu_index - ) - else: - raise ValueError( - f"Affinity mode {numa_options.affinity_mode} not supported." - ) - if numa_options.should_fall_back_if_binding_fails: - _raise_if_numactl_fails_dry_run(numactl_options=numactl_command_options) + numactl_options = _get_numactl_cli_options( + command_args=command_args, gpu_index=gpu_index, numa_options=numa_options + ) + logger.info("Computed numactl_options=%r", numactl_options) + + _raise_if_numactl_fails_dry_run(numactl_options=numactl_options) + logger.info("Validated numactl_options=%r", numactl_options) + + full_numactl_command = _get_assembled_command_from_pieces( + command_args=command_args, numactl_options=numactl_options + ) + logger.info( + "Successfully wrapped command with numa_bindings. Returning %r", + full_numactl_command, + ) signpost_event( category="numa_binding", name="wrap_command_success", - parameters={ - "original_command_args": command_args, - "gpu_index": gpu_index, - "numa_options": numa_options, - "numactl_command_options": numactl_command_options, - }, + parameters={**kwargs, "result": full_numactl_command}, ) - return numactl_command_options + return full_numactl_command except Exception: signpost_event( category="numa_binding", name="wrap_command_exception", parameters={ + **kwargs, "traceback": traceback.format_exc(), - "original_command_args": command_args, - "gpu_index": gpu_index, - "numa_options": numa_options, }, ) logger.exception( - """Failed to wrap command with NUMA bindings. - Input: - command_args=%r, - gpu_index=%d, - numa_options=%r, - """, - command_args, - gpu_index, - numa_options, + "Failed to wrap command with NUMA bindings for input = %r", kwargs ) + if numa_options.should_fall_back_if_binding_fails: + logger.warning("Falling back to original command without NUMA bindings.") + return None raise +def _get_temporary_executable_for_command( + *, + command_args: tuple[str, ...], +) -> str: + """ + Returns: + Path to a temporary file which executes the specified command. The executable + deletes itself the first time it runs, so do not try to run it multiple times. + """ + fd, path = mkstemp( + prefix="pytorch-numa-bind", + suffix=".sh", + ) + + # We do rm first to guarantee the file deletes itself. The rest of the file + # will still run as intended. + contents = f"""#!/bin/bash + +# If this file is more than a few minutes old and still exists on your machine, +# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such +# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163. + +rm -- "$0" +{" ".join(command_args)} +""" + + with os.fdopen(fd, "w") as file: + file.write(contents) + + # Ensure the file is fully synced, in order to avoid race condition + # from trying to execute it too early. + file.flush() + os.fsync(fd) + + # Make the script executable + os.chmod(path, stat.S_IRWXU) + + logger.info( + "Created temporary executable at path %s, with contents\n%s", path, contents + ) + + return path + + +def _get_numactl_cli_options( + *, + command_args: tuple[str, ...], + gpu_index: int, + numa_options: NumaOptions, +) -> tuple[str, ...]: + """ + Args: + command_args: The args for a command, such as might be input to Popen. + Example: ("python", "trainer.py") + gpu_index: The index of the GPU that will be used by the subprocess which executes command_args. + Example: 0 + numa_options: See NumaOptions for details. + + Returns: + Depending on numa_options, something like + ("--cpunodebind=0") + """ + if numa_options.affinity_mode == AffinityMode.NODE: + numactl_command_options = _get_node_numactl_options(gpu_index=gpu_index) + elif numa_options.affinity_mode == AffinityMode.SOCKET: + numactl_command_options = _get_socket_numactl_options(gpu_index=gpu_index) + elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE: + numactl_command_options = _get_exclusive_numactl_options(gpu_index=gpu_index) + elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX: + numactl_command_options = _get_core_complex_numactl_options(gpu_index=gpu_index) + else: + raise ValueError(f"Affinity mode {numa_options.affinity_mode} not supported.") + + return numactl_command_options + + def _raise_if_numactl_fails_dry_run(*, numactl_options: tuple[str, ...]) -> None: noop_args = _get_assembled_command_from_pieces( # Execute arbitrary noop command_args=("true",), numactl_options=numactl_options, ) + + temporary_executable_path = _get_temporary_executable_for_command( + command_args=noop_args + ) + try: - subprocess.run( - noop_args, + run( + (temporary_executable_path,), stdout=subprocess.DEVNULL, # These allow us to capture the stderr as text stderr=subprocess.PIPE, @@ -219,14 +288,11 @@ def _get_node_numactl_options(*, gpu_index: int) -> tuple[str, ...]: Core logic of 'node' numa strategy. Returns options to be used with numactl. E.g., - ("--cpunodebind=0", "--preferred=0"). + ("--cpunodebind=0"). """ numa_node_index = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index) - return ( - f"--cpunodebind={numa_node_index}", - f"--preferred={numa_node_index}", - ) + return (f"--cpunodebind={numa_node_index}",) def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]: @@ -242,14 +308,7 @@ def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]: ) numa_node_indices_str = _get_ranges_str_from_ints(numa_node_indices) - return ( - f"--cpunodebind={numa_node_indices_str}", - ( - f"--preferred-many={numa_node_indices_str}" - if len(numa_node_indices) > 1 - else f"--preferred={numa_node_indices_str}" - ), - ) + return (f"--cpunodebind={numa_node_indices_str}",) def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]: @@ -321,7 +380,6 @@ def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]: return ( f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}", - f"--preferred={numa_node_index}", ) @@ -371,7 +429,6 @@ def _get_core_complex_numactl_options(*, gpu_index: int) -> tuple[str, ...]: return ( f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}", - f"--preferred={numa_node_index}", ) From 8e6a3138581152ab827a0997f34c470271399f5e Mon Sep 17 00:00:00 2001 From: albanD Date: Tue, 12 Aug 2025 20:14:18 +0000 Subject: [PATCH 0276/1424] Add ownership token when needed on GradientEdge (#160098) We can avoid the token by introducing PyObject preservation for THPFunction. But I think it will be too much complexity given that this kind of issue is very rare. Happy to be talked into doing it though if someone really wants to. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160098 Approved by: https://github.com/ezyang, https://github.com/soulitzer --- test/test_autograd.py | 27 +++++++++++++++++++++++++++ torch/autograd/graph.py | 14 +++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index 01a2c54dc2774..7ce40e59dd4b5 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -1196,6 +1196,33 @@ def fn(x, reduce=True): tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0]) ) + def test_gradient_edge_graph_ownership(self): + # Ensure we own the graph properly + class Clone(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, gX): + return gX.clone() + + inp = torch.rand(1, requires_grad=True).clone() + + # C++ Node + out = inp.clone() + edge = torch.autograd.graph.get_gradient_edge(out) + torch.autograd.backward(edge) + del out + torch.autograd.backward(edge) + + # python Node + out = Clone.apply(inp) + edge = torch.autograd.graph.get_gradient_edge(out) + torch.autograd.backward(edge) + del out + torch.autograd.backward(edge) + def test_grad_nonleaf(self): x_init = torch.randn(2, 2, requires_grad=True) x = x_init diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index bf643a97f60f6..4b2707b65d0f1 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -194,6 +194,9 @@ class GradientEdge(NamedTuple): node: Node output_nr: int + # This token can be used to ensure the graph stays alive when it cannot be + # done via the node field + ownership_token: Optional[Node] = None def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: @@ -209,9 +212,18 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: ) grad_fn = _get_grad_fn_or_grad_acc(tensor) + # Python-based Node are owned by the C++ side meaning the python grad_fn + # object we hold here does NOT keep the C++ graph alive. + # Create an ownership token by creating a new C++ node that own the graph + # we care about here. + token = None + if isinstance(grad_fn, torch._C._FunctionBase): + with torch.enable_grad(): + token = tensor.view_as(tensor).grad_fn + # Note that output_nr default to 0 which is the right value # for the AccumulateGrad node. - return GradientEdge(grad_fn, tensor.output_nr) + return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token) def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None: From f95b58c2844b3444cd8446fed8570729dc4216eb Mon Sep 17 00:00:00 2001 From: Ankita George Date: Tue, 12 Aug 2025 11:01:41 -0700 Subject: [PATCH 0277/1424] Remove usage of fsspec in HF consolidation script (#159392) Moving towards just supporting local storage to take advantage of HF apis such as safe_open. This was already done in Storage component in https://github.com/pytorch/pytorch/pull/159405. This PR removes fsspec usages in consolidation script and relies on local storage only Differential Revision: [D78997975](https://our.internmc.facebook.com/intern/diff/D78997975/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159392 Approved by: https://github.com/sibuachu --- .../checkpoint/_consolidate_hf_safetensors.py | 132 +++++------------- torch/distributed/checkpoint/hf_storage.py | 8 +- 2 files changed, 34 insertions(+), 106 deletions(-) diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index 8577180e9f893..a0d205f808213 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -1,33 +1,26 @@ # pyre-strict import concurrent.futures +import glob import json import logging import math import mmap import os -import shutil import struct -import tempfile import time from dataclasses import dataclass, field from typing import Any, Optional -import fsspec # type: ignore[import-untyped] -from fsspec.core import url_to_fs # type: ignore[import-untyped] -from fsspec.implementations.local import LocalFileSystem # type: ignore[import-untyped] - import torch from torch.distributed.checkpoint._hf_utils import ( _gen_file_name, _get_dcp_custom_metadata, - _get_dtype, _get_safetensors_file_metadata, _metadata_fn, DATA_OFFSETS_KEY, DEFAULT_EXTRA_METADATA_KEY, DTYPE_KEY, - FILE_NAME, SAVED_OFFSETS_KEY, SHAPE_KEY, SUFFIX, @@ -100,6 +93,9 @@ def _parse_input_metadata( Raises: ValueError: If no DCP custom metadata is found in a safetensors file """ + + from safetensors.torch import _getdtype # type: ignore[import] + # Dictionary to track the full size of each tensor across all shards fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {} @@ -138,14 +134,13 @@ def _parse_input_metadata( if fqn in output_data.fqn_data or len(output_files_data) == 1: output_data.fqn_data[fqn] = _FqnData( shape_in_file=tensor_size, - dtype_size=torch.finfo(_get_dtype(dtype_str)).bits + dtype_size=torch.finfo(_getdtype(dtype_str)).bits // 8, # Convert bits to bytes dtype_str=dtype_str, ) def _write_metadata( - fs: fsspec.AbstractFileSystem, output_files_data: dict[str, _OutputFileData], ) -> None: """ @@ -156,12 +151,11 @@ def _write_metadata( field for each tensor in the output_files_data. Args: - fs: Filesystem interface for file operations output_files_data: Dictionary mapping output file paths to their metadata """ # Process each output file for file_path, output_data in output_files_data.items(): - with fs.open(file_path, "wb") as f: + with open(file_path, "wb") as f: metadata = {} curr_offset = 0 @@ -205,7 +199,6 @@ def _write_metadata( def _read_tensor_data_mmap( - input_fs: fsspec.AbstractFileSystem, file_path: str, start_offset: int, end_offset: int, @@ -215,7 +208,6 @@ def _read_tensor_data_mmap( Read tensor data from a safetensors file using memory mapping for efficiency. Args: - input_fs: Filesystem interface for input file operations file_path: Path to the safetensors file start_offset: Start offset of tensor data within the data section end_offset: End offset of tensor data within the data section @@ -224,24 +216,15 @@ def _read_tensor_data_mmap( Returns: Raw tensor data as bytes """ - # For local files, use mmap for efficient access - if isinstance(input_fs, LocalFileSystem): - # Local file - use mmap - with open(file_path, "rb") as f: - with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: - absolute_start = metadata_size + start_offset - absolute_end = metadata_size + end_offset - return bytes(mm[absolute_start:absolute_end]) - else: - # Remote file - fall back to regular read - with input_fs.open(file_path, "rb") as f: - f.seek(metadata_size + start_offset) - return f.read(end_offset - start_offset) + # Use mmap for efficient access + with open(file_path, "rb") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + absolute_start = metadata_size + start_offset + absolute_end = metadata_size + end_offset + return bytes(mm[absolute_start:absolute_end]) def _process_output_file( - input_fs: fsspec.AbstractFileSystem, - output_fs: fsspec.AbstractFileSystem, output_file: str, output_data: _OutputFileData, input_files_data: dict[str, _InputFileData], @@ -252,8 +235,6 @@ def _process_output_file( This function is designed to be run in parallel for different output files. Args: - input_fs: Filesystem interface for input file operations - output_fs: Filesystem interface for output file operations output_file: Path to the output file output_data: Metadata for the output file input_files_data: Dictionary mapping input file paths to their metadata @@ -275,7 +256,6 @@ def _process_output_file( # Use memory mapping to read tensor data efficiently data_to_write = _read_tensor_data_mmap( - input_fs, safetensors_file, data_offsets[0], data_offsets[1], @@ -291,7 +271,6 @@ def _process_output_file( # Write this tensor shard to the appropriate position in the output file _write_sub_tensor_to_file_optimized( - output_fs, data_to_write, fqn_data.dtype_size, # Size of each element in bytes fqn_data.shape_in_file, # Full tensor shape @@ -304,8 +283,6 @@ def _process_output_file( def _write_data( - input_fs: fsspec.AbstractFileSystem, - output_fs: fsspec.AbstractFileSystem, input_files_data: dict[str, _InputFileData], output_files_data: dict[str, _OutputFileData], num_threads: int = 1, @@ -318,8 +295,6 @@ def _write_data( the work is split across threads with each thread handling a different output file. Args: - input_fs: Filesystem interface for input file operations - output_fs: Filesystem interface for output file operations input_files_data: Dictionary mapping input file paths to their metadata output_files_data: Dictionary mapping output file paths to their metadata num_threads: Number of threads to use for parallel processing @@ -327,9 +302,7 @@ def _write_data( if num_threads <= 1 or len(output_files_data) <= 1: # Sequential processing for output_file, output_data in output_files_data.items(): - _process_output_file( - input_fs, output_fs, output_file, output_data, input_files_data - ) + _process_output_file(output_file, output_data, input_files_data) else: # Parallel processing with ThreadPoolExecutor with concurrent.futures.ThreadPoolExecutor( @@ -340,8 +313,6 @@ def _write_data( futures.append( executor.submit( _process_output_file, - input_fs, - output_fs, output_file, output_data, input_files_data, @@ -359,7 +330,6 @@ def _write_data( def _write_sub_tensor_to_file_optimized( - fs: fsspec.AbstractFileSystem, sub_tensor_bytes: bytes, element_size: int, tensor_shape: list[int], @@ -379,7 +349,6 @@ def _write_sub_tensor_to_file_optimized( - Optimized chunks for other patterns Args: - fs: Filesystem interface for file operations sub_tensor_bytes: Raw tensor data as bytes element_size: Size of each element in bytes tensor_shape: Shape of the full tensor @@ -403,7 +372,7 @@ def _write_sub_tensor_to_file_optimized( total_elements = math.prod(sub_tensor_shape) - with fs.open(output_file_path, "r+b") as out_f: + with open(output_file_path, "r+b") as out_f: elements_written = 0 while elements_written < total_elements: @@ -524,10 +493,19 @@ def _calculate_max_contiguous_elements( def _write_overall_metadata_file( - fs: fsspec.AbstractFileSystem, output_dir: str, output_files_data: dict[str, _OutputFileData], ) -> None: + """ + Write the overall metadata file that maps tensor names to their file locations. + + This creates a model.safetensors.index.json file that HuggingFace models use + to locate tensors across multiple files. + + Args: + output_dir: Directory where the metadata file will be written + output_files_data: Dictionary mapping output file paths to their metadata + """ total_size = 0 weight_map = {} for output_path, value in output_files_data.items(): @@ -540,32 +518,10 @@ def _write_overall_metadata_file( metadata_to_write["weight_map"] = weight_map metadata_path = os.path.join(output_dir, f"{_metadata_fn}") - with fs.open(metadata_path, "w") as metadata_file: + with open(metadata_path, "w") as metadata_file: json.dump(metadata_to_write, metadata_file, indent=2) -def _upload_files_to_remote_fs( - local_fs: fsspec.AbstractFileSystem, - local_dir: str, - output_fs: fsspec.AbstractFileSystem, - output_dir: str, -) -> None: - """ - Uploads the consolidated files to the remote filesystem. - """ - for path in local_fs.ls(local_dir, detail=False): - file = os.path.basename(path) - model_str = FILE_NAME.split("-")[0] - # Upload only the consolidated files with full tensors or the metadata file. - # The check for file.startwith(model_str) is to ensure that we only upload - # the consolidated files in the format "model-0000n-of-0000m.safetensors" - # and not the files with sharded tensors. - if file.endswith(SUFFIX) and file.startswith(model_str) or file == _metadata_fn: - local_path = os.path.join(local_dir, file) - remote_path = os.path.join(output_dir, file) - output_fs.put_file(local_path, remote_path) - - def consolidate_safetensors_files( input_dir: str, output_dir: str, @@ -597,17 +553,6 @@ def consolidate_safetensors_files( output_dir, start_time, ) - # Create filesystem using fsspec for file operations - input_fs, _ = url_to_fs(input_dir) - output_fs, _ = url_to_fs(output_dir) - - if not isinstance(output_fs, LocalFileSystem): - local_output_dir = tempfile.mkdtemp() - logger.info("Created temporary directory %s", local_output_dir) - local_output_fs, _ = url_to_fs(local_output_dir) - else: - local_output_fs = output_fs - local_output_dir = output_dir # Initialize the output file structure output_files_data: dict[str, _OutputFileData] = {} @@ -616,7 +561,7 @@ def consolidate_safetensors_files( for fqn, index in fqn_to_index_mapping.items(): # Generate names like "model-00001-of-00005.safetensors" file_name = _gen_file_name(index, max(fqn_to_index_mapping.values())) - output_path = os.path.join(local_output_dir, file_name) + output_path = os.path.join(output_dir, file_name) if output_path not in output_files_data: output_files_data[output_path] = _OutputFileData( @@ -627,19 +572,16 @@ def consolidate_safetensors_files( else: # If no mapping is provided, create a single output file file_name = _gen_file_name(1, 1) - output_path = os.path.join(local_output_dir, file_name) + output_path = os.path.join(output_dir, file_name) output_files_data[output_path] = _OutputFileData() # Find all safetensors files in the input directory - safetensors_files = [] - for file in input_fs.ls(input_dir, detail=False): - if file.endswith(SUFFIX): - safetensors_files.append(file) + safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}")) # Read metadata from all input files input_files_data: dict[str, _InputFileData] = {} for safetensor_file in safetensors_files: - with input_fs.open(safetensor_file, "rb") as f: + with open(safetensor_file, "rb") as f: metadata, size = _get_safetensors_file_metadata(f) input_files_data[safetensor_file] = _InputFileData( metadata_size=size, metadata=metadata @@ -649,22 +591,12 @@ def consolidate_safetensors_files( _parse_input_metadata(input_files_data, output_files_data) # Step 2: Write metadata headers to output files - _write_metadata(local_output_fs, output_files_data) + _write_metadata(output_files_data) # Step 3: Write actual tensor data from input files to output files - _write_data( - input_fs, local_output_fs, input_files_data, output_files_data, num_threads - ) + _write_data(input_files_data, output_files_data, num_threads) # Step 4: Write overall model.index.safetensors.json file with weight map - _write_overall_metadata_file(local_output_fs, local_output_dir, output_files_data) + _write_overall_metadata_file(output_dir, output_files_data) logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time) - - if local_output_dir != output_dir: - logger.info("Copying consolidated files to remote storage %s", output_dir) - _upload_files_to_remote_fs( - local_output_fs, local_output_dir, output_fs, output_dir - ) - shutil.rmtree(local_output_dir) - logger.info("Deleting temporary directory %s", local_output_dir) diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 6b36e619f7ced..542203ed82cf7 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -47,9 +47,7 @@ class HuggingFaceStorageWriter(FileSystemWriter): """ - A writer that writes to a huggingface repository in the huggingface format. - Uses Fsspec back-end to communicate with back-end storage. - Fsspec registration of the storage solution is required. + A writer that writes to storage in the huggingface safetensors format. """ def __init__( @@ -196,9 +194,7 @@ def metadata_path(self) -> str: class HuggingFaceStorageReader(FileSystemReader): """ - A reader that reads from a huggingface repository in the huggingface format. - Uses in Fsspec back-end to communicate with storage. - Fsspec registration of the storage solution is required. + A reader that reads a checkpoint in the huggingface safetensors format. """ def __init__(self, path: str) -> None: From a354fa91e26b376d96385a2206c5ff5b42aa4600 Mon Sep 17 00:00:00 2001 From: Chien-Lin Chen Date: Tue, 12 Aug 2025 20:52:21 +0000 Subject: [PATCH 0278/1424] added class or module info for functions blocked by weight-only load (#159935) Fixes #152985 In #152985, users are confused why weights-only load failed even though functions were registered in safe_globals. Because the error message doesn't make the critical failure reason clear, they couldn't figure out only some functions are missing from safe_globals registration. This fix is to make that point more clear. Here's the new errror message, the blocked function information will be following the warning message with a line breaker to make it stand out. ``` _pickle.UnpicklingError: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Trying to call reduce for unrecognized function which belongs to Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html. To execute this test, run the following from the base repo dir: python test/test_serialization.py TestSerialization.test_weights_only_with_safe_zoneinfo_unpickle_registration_success This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159935 Approved by: https://github.com/mikaylagawarecki --- test/test_serialization.py | 34 ++++++++++++++++++++++++++++++++ torch/_weights_only_unpickler.py | 5 ++++- torch/serialization.py | 2 +- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index 3413366608f4e..8fa78cb5da4b5 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -61,6 +61,7 @@ ) from torch.testing._internal.two_tensor import TwoTensor # noqa: F401 from torch.utils._import_utils import import_dill +from pickle import UnpicklingError if not IS_WINDOWS: @@ -1356,6 +1357,39 @@ def test_weights_only_error(self, unsafe_global): "file an issue with the following so that we can make `weights_only=True`"): torch.load(f, weights_only=True) + def test_weights_only_blocked_func_error_msg(self): + import datetime + import zoneinfo + + data = { + "a": torch.tensor([1, 2, 3]), + "b": datetime.datetime(2025, 1, 1, 12, 0, tzinfo=zoneinfo.ZoneInfo(key="UTC")), + } + with tempfile.NamedTemporaryFile() as f: + torch.save(data, f) + f.seek(0) + + with torch.serialization.safe_globals([datetime.datetime, getattr, zoneinfo.ZoneInfo]): + with self.assertRaisesRegex(UnpicklingError, ".*_unpickle.*zoneinfo.ZoneInfo.*"): + torch.load(f) + + + def test_weights_only_with_zoneinfo_unpickle_registration_success(self): + import datetime + import zoneinfo + + data = { + "a": torch.tensor([1, 2, 3]), + "b": datetime.datetime(2025, 1, 1, 12, 0, tzinfo=zoneinfo.ZoneInfo(key="UTC")), + } + with tempfile.NamedTemporaryFile() as f: + torch.save(data, f) + f.seek(0) + + with torch.serialization.safe_globals([datetime.datetime, getattr, zoneinfo.ZoneInfo, zoneinfo.ZoneInfo._unpickle]): + loaded_data = torch.load(f) + self.assertEqual(loaded_data, data) + @parametrize('weights_only', (False, True)) def test_serialization_math_bits(self, weights_only): t = torch.randn(1, dtype=torch.cfloat) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 2352bb836a9d2..745cdd315a634 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -403,9 +403,12 @@ def load(self): func not in _get_allowed_globals().values() and func not in _get_user_allowed_globals().values() ): - raise UnpicklingError( + error_msg = ( f"Trying to call reduce for unrecognized function {func}" ) + if hasattr(func, "__self__"): + error_msg += f" which belongs to {func.__self__}" + raise UnpicklingError(error_msg) result = func(*args) if func in torch._tensor_classes and "sparse" in func.__module__: _sparse_tensors_to_validate.append(result) diff --git a/torch/serialization.py b/torch/serialization.py index 61a4acf684152..a6eb314fc1a82 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1426,7 +1426,7 @@ def _get_wo_message(message: str) -> str: "Please file an issue with the following so that we can make " "`weights_only=True` compatible with your use case: WeightsUnpickler error: " ) - updated_message += message + updated_message += "\n\n" + message return updated_message + DOCS_MESSAGE weights_only_not_set = weights_only is None From 5a9c4cfce42b9eb87da0de40c5633f083115c307 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Tue, 12 Aug 2025 00:38:40 -0700 Subject: [PATCH 0279/1424] [Fix XPU CI][Inductor UT] Fix test cases broken by community. (#160403) Fixes #160243, Fixes #160244, Fixes #160245 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160403 Approved by: https://github.com/janeyx99 --- test/inductor/test_torchinductor_opinfo.py | 13 +++++++++++-- torch/_inductor/ir.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 1ee24c74bb766..c3a6662f1bf3c 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -26,6 +26,7 @@ OpDTypes, ops, skipCPUIf, + skipCUDAIf, skipXPUIf, ) from torch.testing._internal.common_methods_invocations import op_db, skipOps @@ -45,11 +46,11 @@ from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, + HAS_CUDA_AND_TRITON, has_triton, HAS_XPU_AND_TRITON, maybe_skip_size_asserts, ) -from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.utils._dtype_abbrs import dtype_abbrs from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map @@ -682,6 +683,14 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("nn.functional.unfold", f16): { "reference_in_float": True, }, + # Reference crash on Intel LTS2 driver. + ("nn.functional.interpolate.trilinear", f32): { + "check_gradient": False, + }, + # Reference crash on Intel LTS2 driver. + ("nn.functional.interpolate.trilinear", f64): { + "check_gradient": False, + }, } if TEST_WITH_ROCM: inductor_override_kwargs["cuda"].update( @@ -1125,7 +1134,7 @@ def tearDown(self): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently - @requires_cuda_and_triton + @skipCUDAIf(not HAS_CUDA_AND_TRITON, "Skipped! Triton not found") @skipXPUIf( not HAS_XPU_AND_TRITON, "Skipped! Supported XPU compiler and Triton not found" ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a668cd41ebf1b..9859ca8a1b132 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7078,10 +7078,10 @@ def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: # x.get_stride() may be unimplemented if x's size is empty stride = x.get_stride() is_destination_pinned = ( - x_device.type == "cuda" and device.type == "cpu" and non_blocking + is_gpu(x_device.type) and device.type == "cpu" and non_blocking ) is_source_pinned = ( - x_device.type == "cpu" and device.type == "cuda" and non_blocking + x_device.type == "cpu" and is_gpu(device.type) and non_blocking ) if is_source_pinned and is_storage_and_layout(x): x.get_layout().is_pinned = True From b4596895b9d85a686c2cb978938b0a7797b3690a Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Tue, 12 Aug 2025 21:05:24 +0000 Subject: [PATCH 0280/1424] [DTensor] Registers sharding rule for rms_norm (#159692) Reduces collective calls in the forward pass from 2 to 1 In #158716 I added the sharding rule for the backward pass but didn't add the forward pass as it didn't get dispatched. After #159324 this should get properly dispatched hence I am adding it now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159692 Approved by: https://github.com/tianyu-l --- test/distributed/tensor/test_math_ops.py | 178 +++++++-------------- torch/distributed/tensor/_ops/_math_ops.py | 65 +++++--- 2 files changed, 103 insertions(+), 140 deletions(-) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 93ce80f18ee15..2419720256ded 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -271,14 +271,22 @@ def test_layer_norm_fwd(self): norm_shape_idx_list = list(range(x.ndim)) shard_dims = [-1, 0, 1, 2] elementwise_affine_list = [False, True] + + # Test RMSNorm as well if CUDA + norm_types = [torch.nn.LayerNorm] + if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"): + norm_types.append(torch.nn.RMSNorm) + test_config_list = list( - itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + itertools.product( + norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list + ) ) # normalized shape is a torch.Size object - for shard_dim, norm_idx, elementwise_affine in test_config_list: + for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list: normalized_shape = x.shape[norm_idx:] - layer_norm = torch.nn.LayerNorm( + layer_norm = norm_type( normalized_shape, elementwise_affine=elementwise_affine, device=self.device_type, @@ -287,6 +295,7 @@ def test_layer_norm_fwd(self): def _replicate_fn(name, module, device_mesh): for name, param in module.named_parameters(): + # RMSNorm only has weight, LayerNorm has both weight and bias if name in ["weight", "bias"]: param_dist = torch.nn.Parameter( distribute_tensor(param, device_mesh, [Replicate()]) @@ -307,7 +316,7 @@ def _replicate_fn(name, module, device_mesh): self.assertLessEqual( comm_mode.get_total_counts(), 1, # TODO: This should be 0! - f"comm count={comm_mode.get_total_counts()}, " + f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -329,12 +338,20 @@ def test_layer_norm_bwd(self): norm_shape_idx_list = list(range(3)) shard_dims = [0, 1, 2] elementwise_affine_list = [False, True] + + # Test both LayerNorm and RMSNorm (if CUDA) + norm_types = [torch.nn.LayerNorm] + if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"): + norm_types.append(torch.nn.RMSNorm) + test_config_list = list( - itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + itertools.product( + norm_types, shard_dims, norm_shape_idx_list, elementwise_affine_list + ) ) # normalized shape is a torch.Size object - for shard_dim, norm_idx, elementwise_affine in test_config_list: + for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list: x = torch.rand( batch, sentence_length, @@ -343,7 +360,7 @@ def test_layer_norm_bwd(self): requires_grad=True, ) normalized_shape = x.shape[norm_idx:] - layer_norm = torch.nn.LayerNorm( + layer_norm = norm_type( normalized_shape, elementwise_affine=elementwise_affine, device=self.device_type, @@ -364,9 +381,11 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( layer_norm_local.weight, layer_norm_dist.weight.full_tensor() ) - self.assertEqual( - layer_norm_local.bias, layer_norm_dist.bias.full_tensor() - ) + # RMSNorm doesn't have bias + if hasattr(layer_norm_local, "bias"): + self.assertEqual( + layer_norm_local.bias, layer_norm_dist.bias.full_tensor() + ) x_local = x.detach().clone().requires_grad_(True) x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) @@ -384,7 +403,7 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( sum(comm_mode.comm_module_counts["Global"]["forward"].values()), expected_fwd_comm, - f"comm count={comm_mode.get_total_counts()}, " + f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -398,7 +417,7 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( sum(comm_mode.comm_module_counts["Global"]["backward"].values()), expected_bwd_comm, - f"comm count={comm_mode.get_total_counts()}, " + f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -412,18 +431,22 @@ def _replicate_fn(name, module, device_mesh): is_tensor_partial(layer_norm_dist.weight.grad._spec), needs_reduction, ) - self.assertEqual( - is_tensor_partial(layer_norm_dist.bias.grad._spec), - needs_reduction, - ) + # RMSNorm doesn't have bias + if hasattr(layer_norm_dist, "bias"): + self.assertEqual( + is_tensor_partial(layer_norm_dist.bias.grad._spec), + needs_reduction, + ) self.assertEqual( layer_norm_local.weight.grad, layer_norm_dist.weight.grad.full_tensor(), ) - self.assertEqual( - layer_norm_local.bias.grad, - layer_norm_dist.bias.grad.full_tensor(), - ) + # RMSNorm doesn't have bias + if hasattr(layer_norm_local, "bias"): + self.assertEqual( + layer_norm_local.bias.grad, + layer_norm_dist.bias.grad.full_tensor(), + ) self.assertEqual(x_local.grad, x_dist.grad.full_tensor()) @@ -432,8 +455,14 @@ def test_layer_norm_bwd_req_grad(self): device_mesh = self.build_device_mesh() batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32 + # Test both LayerNorm and RMSNorm (if CUDA) + norm_types = [torch.nn.LayerNorm] + if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"): + norm_types.append(torch.nn.RMSNorm) + # build our subtest configurations and filter out invalid ones class SubTest(NamedTuple): + norm_type: type multidim_norm: bool elementwise_affine: bool emb_req_grad: bool @@ -443,19 +472,24 @@ class SubTest(NamedTuple): subtest_fails = {} valid_filter = ( # noqa: E731 lambda cfg: ( - not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:]) + not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[3:]) ) ) subtest_cfgs = list( filter( valid_filter, - [SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))], + [ + SubTest(norm_type, *cfg) + for norm_type in norm_types + for cfg in itertools.product(*(((False, True),) * 5)) + ], ) ) for subtest_cfg in subtest_cfgs: try: ( + norm_type, multidim_norm, elementwise_affine, emb_req_grad, @@ -473,7 +507,7 @@ def __init__(self): self.preln_embeddings = torch.nn.Embedding( vocab_size, embedding_dim ) - self.layer_norm = torch.nn.LayerNorm( + self.layer_norm = norm_type( normalized_shape, elementwise_affine=elementwise_affine ) self.postln_linear = torch.nn.Linear( @@ -572,104 +606,6 @@ def forward(self, tokens): f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" ) - @with_comms - def test_rms_norm_bwd(self): - device_mesh = self.build_device_mesh() - - # NLP example from pytorch docs - batch, sentence_length, embedding_dim = 20, 5, 10 - norm_shape_idx_list = list(range(3)) - shard_dims = [0] # non-first dimensional sharding is not supported - elementwise_affine_list = [False, True] - test_config_list = list( - itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) - ) - - # normalized shape is a torch.Size object - for shard_dim, norm_idx, elementwise_affine in test_config_list: - x = torch.rand( - batch, - sentence_length, - embedding_dim, - device=self.device_type, - requires_grad=True, - ) - normalized_shape = x.shape[norm_idx:] - rms_norm = torch.nn.RMSNorm( - normalized_shape, - elementwise_affine=elementwise_affine, - device=self.device_type, - ) - rms_norm_local = copy.deepcopy(rms_norm).to(self.device_type) - - def _replicate_fn(name, module, device_mesh): - for name, param in module.named_parameters(): - if name == "weight": - param_dist = torch.nn.Parameter( - distribute_tensor(param, device_mesh, [Replicate()]) - ) - module.register_parameter(name, param_dist) - - rms_norm_dist = distribute_module(rms_norm, device_mesh, _replicate_fn) - - if elementwise_affine: - self.assertEqual( - rms_norm_local.weight, rms_norm_dist.weight.full_tensor() - ) - - x_local = x.detach().clone().requires_grad_(True) - x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) - self.assertEqual(x_local, x_dist.full_tensor()) - - y_local = rms_norm_local(x_local) - # make sure that backward rms norm does not introduce extra collectives - comm_mode = CommDebugMode() - with comm_mode: - y_dist = rms_norm_dist(x_dist) - y_dist.sum().backward() - - # TODO: forward pass is sharding strategy is generated from composite, hence 1 more collective than layer_norm - # see: https://github.com/pytorch/pytorch/pull/158716#issuecomment-3096012679 - expected_fwd_comm = 0 if shard_dim < norm_idx else 2 - - self.assertEqual( - sum(comm_mode.comm_module_counts["Global"]["forward"].values()), - expected_fwd_comm, - f"comm count={comm_mode.get_total_counts()}, " - f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", - ) - - self.assertEqual(y_local, y_dist.full_tensor()) - - # backward step - y_local.sum().backward() - - expected_bwd_comm = 0 if shard_dim < norm_idx else 1 - - self.assertEqual( - sum(comm_mode.comm_module_counts["Global"]["backward"].values()), - expected_bwd_comm, - f"comm count={comm_mode.get_total_counts()}, " - f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", - ) - - if elementwise_affine: - # if input is sharded on any outer dimension, the gradient of weight - # should be Partial - dim_map = x_dist._spec.dim_map - outer_dims = range(norm_idx) - needs_reduction = any(dim_map[d] >= 0 for d in outer_dims) - self.assertEqual( - is_tensor_partial(rms_norm_dist.weight.grad._spec), - needs_reduction, - ) - self.assertEqual( - rms_norm_local.weight.grad, - rms_norm_dist.weight.grad.full_tensor(), - ) - - self.assertEqual(x_local.grad, x_dist.grad.full_tensor()) - @with_comms def test_topk(self): device_mesh = self.build_device_mesh() diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 78d2ac3e4b137..1e6eb40939e4a 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -818,27 +818,38 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: return grad_in_strategy -@register_op_strategy( - [aten.native_layer_norm.default], - schema_info=RuntimeSchemaInfo(1), -) -def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: +def _common_norm_forward_strategy( + op_schema: OpSchema, + rms_norm: bool = False, +) -> OpStrategy: + """Common forward strategy logic for layer_norm and rms_norm.""" mesh = op_schema.get_mesh_from_args() - # args must be: input, normalized_shape, weight, bias, eps - # for None weight and bias, their corresponding objects will - # be None as well. layer_norm_strategy returns one OpStrategy - # for the triple return values (out, mean, rstd). - assert len(op_schema.args_schema) == 5 - ( - input_strategy, - normalized_shape, - weight_strategy, - bias_strategy, - _, - ) = op_schema.args_schema + if not rms_norm: + # layer_norm args: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). + assert len(op_schema.args_schema) == 5 + ( + input_strategy, + normalized_shape, + weight_strategy, + bias_strategy, + _, + ) = op_schema.args_schema + else: + # rms_norm args: input, normalized_shape, weight, eps + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + normalized_shape, + weight_strategy, + _, + ) = op_schema.args_schema + bias_strategy = None - # the current layer norm implementation requires that all + # the current norm implementation requires that all # input DTensor's sharding must be in form of OpStrategy assert isinstance(input_strategy, OpStrategy) assert isinstance(normalized_shape, (int, Sequence, torch.Size)) @@ -847,7 +858,7 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: input_ndim = input_strategy.ndim axis = input_ndim - len(normalized_size) - # we use OpStrategy because the output (out, mean, rstd) + # we use OpStrategy because the output values (out, mean, rstd) # should have the same placements output_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): @@ -915,6 +926,22 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: return output_strategy +@register_op_strategy( + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_forward_strategy(op_schema) + + +@register_op_strategy( + [aten._fused_rms_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_forward_strategy(op_schema, rms_norm=True) + + def _common_norm_backward_strategy( op_schema: OpSchema, rms_norm: bool = False, From c24ca7f4bf79f62fd623d76346ca27e53f731431 Mon Sep 17 00:00:00 2001 From: Anshul Sinha Date: Tue, 12 Aug 2025 10:06:12 -0700 Subject: [PATCH 0281/1424] [FSDP][Collectives] skipping allgather when world size is 1 (#160135) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary:** In its current state, FSDP collectives uses cuda synchronizations and communication ops regardless of what the world size is. However, now that replicate will use FSDP, there will be instances where group size = 1 and these synchronizations and ops will be used needlessly. I have updated fsdp_params group to skip the foreach_all_gather and foreach_all_gather_copy_out APIs when world_size ‎ = 1. I have created a test that uses CommDebugMode to verify that the all gather comm has been removed. I also edited an affected test which used 1-way FSDP by verifying and changing its assert statements for CommDebugMode. Below, I have included the link to the profile trace verifying these two APIs were skipped and two test commands. https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/anshulsi_f846ac3b-9467-4060-8e36-8cc3bc4449c3_devgpu263.prn2.facebook.com_652183.1753822140871934814.pt.trace.json Pull Request resolved: https://github.com/pytorch/pytorch/pull/160135 Approved by: https://github.com/weifengpy --- .../fsdp/test_fully_shard_compile.py | 42 ++++++++---- .../fsdp/test_fully_shard_training.py | 65 ++++++++++++++++++ .../test_2d_composability.py | 8 +-- .../fsdp/_fully_shard/_fsdp_param_group.py | 68 ++++++++++++++++--- 4 files changed, 159 insertions(+), 24 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index c8e98c5c3e1f3..b64d4107ee0ca 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -299,12 +299,20 @@ def _check_count(copy_count, resize_count): def _reinplace_all_gather_with_optional_checks(self, fwd_fullgraph): def _run_with_checks(graph, orig_fn): - self.assertGreater( - _count_op_in_graph( - graph, torch.ops._c10d_functional.all_gather_into_tensor.default - ), - 0, - ) + if self.world_size > 1: + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, + ) + elif self.world_size == 1: + self.assertEqual( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, + ) orig_fn(graph) @@ -315,12 +323,22 @@ def _run_with_checks(graph, orig_fn): 0, ) - self.assertGreater( - _count_op_in_graph( - graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default - ), - 0, - ) + if self.world_size > 1: + self.assertGreater( + _count_op_in_graph( + graph, + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + ), + 0, + ) + else: + self.assertEqual( + _count_op_in_graph( + graph, + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + ), + 0, + ) if fwd_fullgraph: return mock.patch.object( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index cf8b86cc8e06d..6ff022f46d192 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -1467,5 +1467,70 @@ def forward(self, imgs: torch.Tensor) -> torch.Tensor: check_sharded_parity(self, ref_model, model) +class TestFullyShardWorldSize1(FSDPTest): + @property + def world_size(self) -> int: + return 1 + + @skip_if_lt_x_gpu(1) + def test_train_parity_single_worldsize1(self): + """ + Tests train parity with DDP for a single FSDP group when sharding + parameters on dim-0. + """ + self.run_subtests( + { + "lin_shapes": [ + [(16, 15), (15, 8)], + [(7, 15), (15, 3)], + [(16, 17), (17, 8)], + ], + "use_shard_placement_fn": [False], + }, + self._test_train_parity_single_group, + ) + + def _test_train_parity_single_group( + self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool + ): + torch.manual_seed(42) + model = nn.Sequential( + nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1]) + ) + ref_model = copy.deepcopy(model).to(device_type) + replicate(ref_model, device_ids=[self.rank]) + ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) + + def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + return Shard(param.shape.index(max(param.shape))) + + shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None + fully_shard(model, shard_placement_fn=shard_placement_fn) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + torch.manual_seed(42 + self.rank + 1) + inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),) + + for iter_idx in range(10): + losses: list[torch.Tensor] = [] + + ref_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + losses.append(ref_model(*inp).sum()) + losses[-1].backward() + ref_optim.step() + + optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + comm_mode = CommDebugMode() + with comm_mode: + losses.append(model(*inp).sum()) + losses[-1].backward() + + # Before there was 1 all-gather and 1 reduce-scatter + # Now therre is 1 reduce-scatter + self.assertEqual(comm_mode.get_total_counts(), 1) + optim.step() + + self.assertEqual(losses[0], losses[1]) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 3ab0b6269b2da..bcaf06ea947a0 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -277,19 +277,19 @@ def test_tp_with_fsdp_offloading(self): loss = model(inp).sum() fwd_comm_counts = fwd_comm_mode.get_comm_counts() - self.assertEqual(len(fwd_comm_counts), 2) + self.assertEqual(len(fwd_comm_counts), 1) self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps) - self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps) + self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], 0) ref_loss = ref_model(inp).sum() self.assertEqual(loss, ref_loss) with CommDebugMode() as bwd_comm_mode: loss.backward() bwd_comm_counts = bwd_comm_mode.get_comm_counts() - self.assertEqual(len(bwd_comm_counts), 3) + self.assertEqual(len(bwd_comm_counts), 2) # First MLP's input gradient does not need to be all-reduced self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1) - self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps) + self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0) self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps) ref_loss.backward() diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index 121f3d4c13885..554367e8705c8 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -32,7 +32,7 @@ HSDPMeshInfo, TrainingState, ) -from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState +from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState logger = logging.getLogger("torch.distributed.fsdp.fully_shard") @@ -166,6 +166,7 @@ def __init__( self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {} self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None self._all_gather_comm: AllGather = DefaultAllGather() + self._all_gather_output = torch.empty(0, device=self.device) self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter() # Optional stream to run the user-defined all-reduce hook in # Saved here and not in the comm. context because we allow the user to @@ -310,6 +311,22 @@ def unshard(self, async_op: bool = False): # used in the all-gather streams self._wait_all_gather_streams_on_event(self._reshard_after_forward_event) self._reshard_after_forward_event = None + + world_size = self._all_gather_process_group.size() + if world_size == 1: + # can't skip due to early return in wait_for_unshard if + # no self._all_gather_result + self._all_gather_result = AllGatherResult( + all_gather_output=self._all_gather_output, + all_gather_event=self.device_handle.Event().record(), + all_gather_work=None, + param_all_gather_input_dtypes=[], + param_all_gather_input_numels=[], + all_gather_input_split_sizes=[], + ) + + return + with record_function(self._with_fqn("FSDP::all_gather")): self._all_gather_result = foreach_all_gather( self.fsdp_params, @@ -336,18 +353,52 @@ def wait_for_unshard(self): if prev_all_gather_state := self.comm_ctx.all_gather_state: self._wait_all_gather_streams_on_event(prev_all_gather_state.event) self.comm_ctx.all_gather_state = None # free the all-gather result - with record_function(self._with_fqn("FSDP::all_gather_copy_out")): - foreach_all_gather_copy_out( - self._all_gather_result, - self.fsdp_params, - self._all_gather_process_group, - ) + world_size = self._all_gather_process_group.size() + if world_size == 1: + # directly initialize unsharded parameters from sharded parameters + + for fsdp_param in self.fsdp_params: + # Use all_gather_inputs which already handles conversion to param_dtype + # This is consistent with the world_size > 1 path + all_gather_input = fsdp_param.all_gather_inputs[0] + + # Make sure the all_gather_outputs has proper storage size before using it + # First ensure we have at least one tensor in all_gather_outputs + fsdp_param.init_all_gather_outputs( + [all_gather_input.numel()], + [all_gather_input.dtype], + world_size, + self.device, + force_recreate=False, + ) + + tensor = fsdp_param.all_gather_outputs[0] + alloc_storage(tensor) + + # find alternative way to check if tensor.is_inference + with torch.autograd._unsafe_preserve_version_counter(tensor): + tensor.copy_(all_gather_input) + + else: + with record_function(self._with_fqn("FSDP::all_gather_copy_out")): + foreach_all_gather_copy_out( + self._all_gather_result, + self.fsdp_params, + self._all_gather_process_group, + ) + for fsdp_param in self.fsdp_params: fsdp_param.init_unsharded_param() + self._to_unsharded() all_gather_copy_out_event = self.device_handle.Event() all_gather_copy_out_event.record() - if not async_op and self._training_state == TrainingState.FORWARD: + + if ( + not async_op + and self._training_state == TrainingState.FORWARD + and world_size > 1 + ): # Defer free to allow for overlap of this copy-out with next # all-gather collective self.comm_ctx.all_gather_state = AllGatherState( @@ -355,6 +406,7 @@ def wait_for_unshard(self): ) else: self._wait_all_gather_streams_on_event(all_gather_copy_out_event) + self._all_gather_result = None # free unless saved in `all_gather_state` def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]): From f27232a2134150cb5e55d26a74d8c36c6a961ca5 Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Tue, 12 Aug 2025 21:15:52 +0000 Subject: [PATCH 0282/1424] [ROCm] Limit number of values per thread for reductions on three dimensions (#159652) In the current implementation of reductions in three dimensions for AMD GPUs the number of values per thread is unbounded and can end up being in the hundreds of thousands for certain tensors. This of course is bad for performance. This patch fixes this issue by increasing the parallelism and thus lowering the number of value per thread to reasonable limits i.e. less than 2048 values per thread. The performance gains can be between 10x-17x for certain examples where the number of values per thread was originally very high. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159652 Approved by: https://github.com/jeffdaily --- aten/src/ATen/native/cuda/Reduce.cuh | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 15a572804af5f..521b467480900 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -209,6 +209,10 @@ struct ReduceConfig { int values_per_thread() const { return div_up(num_inputs, step_input); } + + int mock_values_per_thread(int parallelism) { + return div_up(num_inputs, step_input * parallelism); + } }; std::ostream& operator<<(std::ostream& out, const ReduceConfig& config); @@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ else if (config.ctas_per_output < 16) config.ctas_per_output = 1; bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); - if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) + if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) { config.ctas_per_output = 4; + int vpt = config.values_per_thread(); + // Capping the number of values per thread to 2048 for now + // based on known use cases. + while (vpt >= 2048) { + config.ctas_per_output *= 2; + // Computes the new values per thread without side effects + vpt = config.mock_values_per_thread(config.ctas_per_output); + } + } #endif if (config.ctas_per_output > 1) { config.input_mult[2] = config.split_input(config.ctas_per_output); From 655137b6782a3ada290c8276c3ff0cffe09d02c7 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 12 Aug 2025 17:17:47 +0000 Subject: [PATCH 0283/1424] Update torch::stable::Tensor() default constructor (#159507) Allows things like ```cpp Tensor cu_seqlens_q; if (...) { cu_seqlens_q = ... } ... ``` Also adds `torch::stable::Tensor.defined()` Pull Request resolved: https://github.com/pytorch/pytorch/pull/159507 Approved by: https://github.com/janeyx99 --- .../libtorch_agnostic/csrc/kernel.cpp | 35 +++++++++++++++++++ .../libtorch_agnostic/ops.py | 12 +++++++ .../test/test_libtorch_agnostic.py | 14 ++++++++ torch/csrc/inductor/aoti_torch/c/shim.h | 3 ++ .../csrc/inductor/aoti_torch/shim_common.cpp | 12 +++++-- torch/csrc/stable/tensor.h | 16 ++++++++- 6 files changed, 89 insertions(+), 3 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 63e9eb77dd34e..34f4729d98e99 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -320,3 +320,38 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { m.impl("my_zero_", &boxed_my_zero_); } + +bool test_default_constructor(bool defined) { + Tensor out; + if (defined) { + AtenTensorHandle defined_ath; + int64_t sizes[] = {2, 3}; + int64_t strides[] = {3, 1}; + aoti_torch_empty_strided( + 2, + sizes, + strides, + aoti_torch_dtype_float32(), + aoti_torch_device_type_cpu(), + 0, + &defined_ath); + out = Tensor(defined_ath); + } + return out.defined(); +} + +void boxed_test_default_constructor( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + bool res = test_default_constructor(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("test_default_constructor(bool undefined) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("test_default_constructor", &boxed_test_default_constructor); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 1694bfa1b3965..04488e7d91834 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -164,3 +164,15 @@ def fill_infinity(t) -> Tensor: Returns: The modified tensor (same as input) """ return torch.ops.libtorch_agnostic.fill_infinity.default(t) + + +def test_default_constructor(defined) -> bool: + """ + Tests the default constructor for torch::stable::Tensor. + + Args: + defined: bool - if True, tests defined tensor; if False, tests undefined tensor + + Returns: bool - result of calling .defined() on the tensor + """ + return torch.ops.libtorch_agnostic.test_default_constructor.default(defined) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index bd409a0eb5a69..e197904e8ae2b 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -218,6 +218,20 @@ def test_fill_infinity(self, device): expected = torch.full_like(t, math.inf) self.assertEqual(out, expected) + @onlyCPU + def test_default_constructor(self): + import libtorch_agnostic + + defined_tensor_is_defined = libtorch_agnostic.ops.test_default_constructor( + True + ) + self.assertTrue(defined_tensor_is_defined) + + undefined_tensor_is_defined = ( + libtorch_agnostic.ops.test_default_constructor(False) + ) + self.assertFalse(undefined_tensor_is_defined) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index d6f32358cdcc5..b1446318dd34f 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -227,6 +227,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( AOTI_TORCH_EXPORT AOTITorchError aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index eff8276315a20..868da9831e767 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -402,6 +402,15 @@ AOTITorchError aoti_torch_is_contiguous( }); } +AOTITorchError aoti_torch_is_defined( + AtenTensorHandle tensor, + bool* ret_is_defined) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); + *ret_is_defined = t->defined(); + }); +} + AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle) { @@ -1204,8 +1213,7 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) { if (msg) { std::cout << " " << msg; } - std::cout << " " - << "]:" << '\n'; + std::cout << " " << "]:" << '\n'; // Print exact tensor values for small size tensors const int64_t numel = t->numel(); diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h index 741da7e62e409..d02763923a5f8 100644 --- a/torch/csrc/stable/tensor.h +++ b/torch/csrc/stable/tensor.h @@ -29,7 +29,15 @@ class Tensor { std::shared_ptr ath_; public: - Tensor() = delete; + // Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH) + // Steals ownership from the ATH + Tensor() { + AtenTensorHandle ret; + TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret)); + ath_ = std::shared_ptr(ret, [](AtenTensorHandle ath) { + TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); + }); + } // Construct a stable::Tensor from an AtenTensorHandle (ATH) // Steals ownership from the ATH @@ -115,6 +123,12 @@ class Tensor { return size; } + bool defined() const { + bool defined; + TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined)); + return defined; + } + // ============================================================================= // END of C-shimified TensorBase APIs // ============================================================================= From 4d419a74610c32b1372f8802dcc61893740a23cf Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 12 Aug 2025 17:17:47 +0000 Subject: [PATCH 0284/1424] Add pad and narrow to torch/csrc/stable/ops.h (#159328) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159328 Approved by: https://github.com/janeyx99 ghstack dependencies: #159507 --- .../libtorch_agnostic/csrc/kernel.cpp | 37 +++++++++++++++++++ .../libtorch_agnostic/ops.py | 27 ++++++++++++++ .../test/test_libtorch_agnostic.py | 20 ++++++++++ .../aoti_torch/generated/c_shim_aten.h | 2 + torch/csrc/stable/ops.h | 34 +++++++++++++++++ torchgen/aoti/fallback_ops.py | 2 + 6 files changed, 122 insertions(+) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 34f4729d98e99..e3dfc581179ac 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -291,10 +291,43 @@ void boxed_fill_infinity( stack[0] = from(res); } +Tensor my_pad(Tensor t) { + std::vector padding = {1, 2, 2, 1}; + std::string mode = "constant"; + double value = 0.0; + return pad(t, padding, mode, value); +} + +void boxed_my_pad( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto res = my_pad(to(stack[0])); + stack[0] = from(res); +} + +Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) { + return narrow(t, dim, start, length); +} + +void boxed_my_narrow( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + auto res = my_narrow( + to(stack[0]), + to(stack[1]), + to(stack[2]), + to(stack[3])); + stack[0] = from(res); +} + STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor"); m.def("my_empty_like(Tensor t) -> Tensor"); m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)"); + m.def("my_pad(Tensor t) -> Tensor"); + m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { @@ -303,6 +336,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("fill_infinity", &boxed_fill_infinity); } +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) { + m.impl("my_pad", &boxed_my_pad); + m.impl("my_narrow", &boxed_my_narrow); +} Tensor my_zero_(Tensor t) { return zero_(t); diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 04488e7d91834..817732371060d 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -176,3 +176,30 @@ def test_default_constructor(defined) -> bool: Returns: bool - result of calling .defined() on the tensor """ return torch.ops.libtorch_agnostic.test_default_constructor.default(defined) + + +def my_pad(t) -> Tensor: + """ + Pads the input tensor with hardcoded padding parameters. + + Args: + t: Input tensor + + Returns: Padded tensor with padding [1, 2, 2, 1], mode "constant", value 0.0 + """ + return torch.ops.libtorch_agnostic.my_pad.default(t) + + +def my_narrow(t, dim, start, length) -> Tensor: + """ + Returns a new tensor that is a narrowed version of the input tensor. + + Args: + t: Input tensor + dim: Dimension along which to narrow + start: Starting position + length: Length of the narrowed section + + Returns: Narrowed tensor + """ + return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index e197904e8ae2b..ae3c2767627fc 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -232,6 +232,26 @@ def test_default_constructor(self): ) self.assertFalse(undefined_tensor_is_defined) + def test_my_pad(self, device): + import libtorch_agnostic + + t = torch.rand(2, 3, device=device) + out = libtorch_agnostic.ops.my_pad(t) + expected = torch.nn.functional.pad(t, [1, 2, 2, 1], "constant", 0.0) + self.assertEqual(out, expected) + + def test_my_narrow(self, device): + import libtorch_agnostic + + t = torch.randn(2, 5, device=device) + + dim0 = 0 + start0 = 0 + length0 = 1 + out0 = libtorch_agnostic.ops.my_narrow(t, dim0, start0, length0) + expected0 = torch.narrow(t, dim0, start0, length0) + self.assertEqual(out0, expected0) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h index cc2dcdf4c75e0..d5bc50750fc7f 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h @@ -15,6 +15,8 @@ extern "C" { #endif AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); #ifdef __cplusplus } // extern "C" diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index c4a8a99848055..7ce25af14d3f4 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -4,11 +4,15 @@ #include #include #include +#include +#include #include using torch::stable::Tensor; +namespace torch::stable { + // We expect this to be the stable version of the empty_like op that takes in // no kwargs (device, dtype, layout, memory_format). We will add kwargs // support in the future. @@ -36,6 +40,34 @@ inline Tensor fill_(const Tensor& self, double value) { return self; } +// We expect this to be the stable version of the narrow.default op. +// narrow takes in a SymInt for start and length, but these are typed as +// int64_t as SymInt is not yet header-only. +inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) { + AtenTensorHandle ret0 = nullptr; + + TORCH_ERROR_CODE_CHECK( + aoti_torch_aten_narrow(self.get(), dim, start, length, &ret0)); + return Tensor(ret0); +} + +// We expect this to be the stable version of the pad.default op. +// pad.default takes in a SymInt[] as the pad argument however pad is typed as +// use std::vector because +// (1) IntArrayRef is not yet header-only +// (2) SymInt is not yet header-only +inline Tensor pad( + const Tensor& self, + std::vector pad, + const std::string& mode = "constant", + double value = 0.0) { + AtenTensorHandle ret0 = nullptr; + + TORCH_ERROR_CODE_CHECK(aoti_torch_aten_pad( + self.get(), pad.data(), pad.size(), mode.c_str(), &value, &ret0)); + return Tensor(ret0); +} + // We expect this to be the stable version of the transpose op with identical // semantics to the existing transpose.int op. inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { @@ -56,3 +88,5 @@ inline Tensor zero_(Tensor& self) { aoti_torch_call_dispatcher("aten::zero_", "", stack.data())); return to(stack[0]); } + +} // namespace torch::stable diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 3ff40412898ab..be00c49d7b1f1 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -183,4 +183,6 @@ # The same BC rules apply as inductor_fallback_ops. aten_shimified_ops: dict[str, dict[str, list[str]]] = { "aten.fill_.Scalar": {}, + "aten.pad.default": {}, + "aten.narrow.default": {}, } From f8f0414a5983ff481a2188e0c18594150430c8c5 Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev Date: Tue, 12 Aug 2025 21:36:19 +0000 Subject: [PATCH 0285/1424] fix cpp builder to avoid missing-source compile error (#160354) Summary: the condition ``` if config.is_fbcode() and (not self._aot_mode or self._use_relative_path): sources = [os.path.basename(i) for i in sources] ``` unintentionally (?) stripped paths even when use_relative_path was False (as long as aot_mode was False), breaking local tests that rely on absolute temp-file paths. Fixes internal issue: ``` FAILED (errors=1) CppCompileError: C++ compile error Command: /mnt/gvfs/third-party2/llvm-fb/0f1f083aa5508772f3db24bf4f697bc118ba0958/17/platform010/72a2ff8/bin/clang-17 czyi3nhzin5b3mc3376vmfnlbjobvjcghbvv4tatuazs3syqubay.cpp -shared -fPIC -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -Werror=ignored-optimization-argument -g -o /re_tmp/tmpsp58ya2h/zy/test_symbol.so Output: clang-17: error: no such file or directory: 'czyi3nhzin5b3mc3376vmfnlbjobvjcghbvv4tatuazs3syqubay.cpp' clang-17: error: no input files ``` Reviewed By: clee2000 Differential Revision: D80025417 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160354 Approved by: https://github.com/benjaminglass1, https://github.com/clee2000 --- torch/_inductor/cpp_builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 45e655d1dfa8e..c58849f9bf5ac 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1631,7 +1631,8 @@ def __init__( if isinstance(sources, str): sources = [sources] - if config.is_fbcode() and (not self._aot_mode or self._use_relative_path): + # Use relative paths only when requested (typically for remote builds) + if config.is_fbcode() and self._use_relative_path: # Will create another temp directory for building, so do NOT use the # absolute path. self._orig_source_paths = list(sources) From 78a2fe1d42edeaa2ef7020b0fa0ac82ee4a640e4 Mon Sep 17 00:00:00 2001 From: David Berard Date: Tue, 12 Aug 2025 11:47:04 -0700 Subject: [PATCH 0286/1424] [TorchScript] thread-safe ErrorReport::CallStack (#160386) Context: During jit.script, the TorchScript frontend maintains a callstack of Python frames, which is used to present the corresponding user code in case TorchScript errors. The callstack is maintained via ErrorReport::CallStack RAII guards. Before recursing into a function, an ErrorReport::CallStack guard is created and the CallStack guard pushes the frame information onto a thread_local callstack (a list of calls); and after exiting, the frame information is popped off the callstack. Note that the CallStack guards are also sometimes used in python via pybindings. The problem is that sometimes another thread can obtain a reference to the CallStack guard (if it's a Python CallStack guard). **This means that the destructor for a CallStack guard can be called from a different thread than the constructor was called**. When this happens, it causes a segfault. This PR makes the callstack vector thread-safe to access, and each CallStack guard will store a reference to the callstack vector onto which it pushed. When the CallStack guard is destructed, it pops off the appropriate callstack vector. Although this could potentially lead to mangled callstacks, it should prevent segfaults. Added a test `test_thread_safe_error_stacks` which segfaults prior to these changes, and no longer segfaults. Differential Revision: [D80054972](https://our.internmc.facebook.com/intern/diff/D80054972) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160386 Approved by: https://github.com/eellison --- test/jit/test_recursive_script.py | 20 +++++++++++ torch/csrc/jit/frontend/error_report.cpp | 42 ++++++++++++++++++++---- torch/csrc/jit/frontend/error_report.h | 36 ++++++++++++++++++++ 3 files changed, 92 insertions(+), 6 deletions(-) diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index d595c793e79b6..d6addfddca1a7 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -4,6 +4,7 @@ import os import re import sys +import threading import types import typing import typing_extensions @@ -773,6 +774,25 @@ def forward(self, x): mod.foo = None self.checkModule(mod, (torch.rand(2, 2),)) + def test_thread_safe_error_stacks(self): + # prior to #160386, this causes a segfault. See [Note: Thread-safe CallStack] + callstacks = [] + + def callstack_creator(): + factory = torch._C._jit_tree_views.SourceRangeFactory( + "source code", "a.py", 1, 0 + ) + x = torch._C.CallStack("a", factory.make_range(1, 0, 1)) + callstacks.append(x) + del x + + t = threading.Thread(target=callstack_creator) + t.start() + t.join() + del t + del callstacks[0] + self.assertTrue(len(callstacks) == 0) + def test_override_instance_method_ignore(self): class M(torch.nn.Module): @torch.jit.ignore diff --git a/torch/csrc/jit/frontend/error_report.cpp b/torch/csrc/jit/frontend/error_report.cpp index d642746abaaa5..d5a8408e971c0 100644 --- a/torch/csrc/jit/frontend/error_report.cpp +++ b/torch/csrc/jit/frontend/error_report.cpp @@ -6,7 +6,34 @@ namespace torch::jit { // Avoid storing objects with destructor in thread_local for mobile build. #ifndef C10_MOBILE -static thread_local std::vector calls; +// [NOTE: Thread-safe CallStack] +// `calls` maintains a stack of Python calls that resulted in the +// currently compiled TorchScript code. RAII ErrorReport::CallStack +// push and pop from the `calls` object during compilation to track +// these stacks so that they can be used to report compilation errors +// +// Q: Why can't this just be a thread_local vector (as it was previously)? +// +// A: Sometimes a CallStack RAII guard is created in Python in a given +// thread (say, thread A). Then later, someone can call +// sys._current_frames() from another thread (thread B), which causes +// thread B to hold references to the CallStack guard. e.g. +// 1. CallStack RAII guard created by thread A +// 2. CallStack guard now has a reference from thread B +// 3. thread A releases guard, but thread B still holds a reference +// 4. thread B releases guard, refcount goes to 0, and we +// call the destructor +// under this situation, **we pop an element off the wrong `call` +// object (from the wrong thread!) +// +// To fix this: +// * in CallStack, store a reference to which thread's `calls` +// the CallStack corresponds to, so you can pop from the correct +// `calls` object. +// * make it a shared_ptr and add a mutex to make this thread safe +// (since now multiple threads access a given thread_local calls object) +static thread_local std::shared_ptr calls = + std::make_shared(); #endif // C10_MOBILE ErrorReport::ErrorReport(const ErrorReport& e) @@ -17,20 +44,23 @@ ErrorReport::ErrorReport(const ErrorReport& e) #ifndef C10_MOBILE ErrorReport::ErrorReport(const SourceRange& r) - : context(r), error_stack(calls.begin(), calls.end()) {} + : context(r), error_stack(calls->get_stack()) {} void ErrorReport::CallStack::update_pending_range(const SourceRange& range) { - calls.back().caller_range = range; + calls->update_pending_range(range); } ErrorReport::CallStack::CallStack( const std::string& name, const SourceRange& range) { - calls.push_back({name, range}); + source_callstack_ = calls; + source_callstack_->push_back({name, range}); } ErrorReport::CallStack::~CallStack() { - calls.pop_back(); + if (source_callstack_) { + source_callstack_->pop_back(); + } } #else // defined C10_MOBILE ErrorReport::ErrorReport(const SourceRange& r) : context(r) {} @@ -61,7 +91,7 @@ static std::string get_stacked_errors(const std::vector& error_stack) { std::string ErrorReport::current_call_stack() { #ifndef C10_MOBILE - return get_stacked_errors(calls); + return get_stacked_errors(calls->get_stack()); #else TORCH_CHECK(false, "Call stack not supported on mobile"); #endif // C10_MOBILE diff --git a/torch/csrc/jit/frontend/error_report.h b/torch/csrc/jit/frontend/error_report.h index 635dd35468e3b..9f5ad9bf3bb68 100644 --- a/torch/csrc/jit/frontend/error_report.h +++ b/torch/csrc/jit/frontend/error_report.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace torch::jit { @@ -18,6 +19,38 @@ struct TORCH_API ErrorReport : public std::exception { const char* what() const noexcept override; + class TORCH_API Calls { + private: + std::vector calls_; + mutable std::mutex mutex_; + + public: + void push_back(Call call) { + std::lock_guard lock(mutex_); + calls_.push_back(std::move(call)); + } + + void pop_back() { + std::lock_guard lock(mutex_); + calls_.pop_back(); + } + + bool empty() const { + std::lock_guard lock(mutex_); + return calls_.empty(); + } + + void update_pending_range(const SourceRange& range) { + std::lock_guard lock(mutex_); + calls_.back().caller_range = range; + } + + std::vector get_stack() const { + std::lock_guard lock(mutex_); + return calls_; + } + }; + struct TORCH_API CallStack { // These functions are used to report why a function was being compiled // (i.e. what was the call stack of user functions at compilation time that @@ -28,6 +61,9 @@ struct TORCH_API ErrorReport : public std::exception { // Change the range that is relevant for the current function (i.e. after // each successful expression compilation, change it to the next expression) static void update_pending_range(const SourceRange& range); + + private: + std::shared_ptr source_callstack_; }; static std::string current_call_stack(); From cbffde774557752cf20447d42d99ec6102673c31 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 12 Aug 2025 21:59:50 +0000 Subject: [PATCH 0287/1424] Factor out the strings to templates for better editor integration (#160357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Summary More code motion, tldr is that install 'Better Jinja' in vscode and now you can get highlighting Before Screenshot 2025-08-11 at 2 41 08 PM After: Screenshot 2025-08-11 at 2 40 27 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/160357 Approved by: https://github.com/eellison --- setup.py | 1 + torch/_inductor/kernel/flex/common.py | 267 +---- torch/_inductor/kernel/flex/flex_attention.py | 956 +----------------- torch/_inductor/kernel/flex/flex_decoding.py | 270 +---- .../kernel/flex/templates/common.py.jinja | 193 ++++ .../flex/templates/flex_attention.py.jinja | 248 +++++ .../flex/templates/flex_backwards.py.jinja | 682 +++++++++++++ .../flex/templates/flex_decode.py.jinja | 252 +++++ .../kernel/flex/templates/utilities.py.jinja | 59 ++ 9 files changed, 1451 insertions(+), 1477 deletions(-) create mode 100644 torch/_inductor/kernel/flex/templates/common.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/flex_attention.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/flex_decode.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/utilities.py.jinja diff --git a/setup.py b/setup.py index cd04f5313aa43..23ef581241396 100644 --- a/setup.py +++ b/setup.py @@ -1669,6 +1669,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.h", "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", + "_inductor/kernel/flex/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/torch/_inductor/kernel/flex/common.py b/torch/_inductor/kernel/flex/common.py index 8ee50753439eb..6cc197a35b9cf 100644 --- a/torch/_inductor/kernel/flex/common.py +++ b/torch/_inductor/kernel/flex/common.py @@ -3,6 +3,7 @@ import math from collections.abc import Sequence +from pathlib import Path from typing import Any, Optional, Union import sympy @@ -323,267 +324,13 @@ def next_power_of_two(n): return 2 ** math.ceil(math.log2(n)) -# ---- Common Template Strings ---- -compute_forward_block_mn = r""" -@triton.jit -def forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +_TEMPLATE_DIR = Path(__file__).parent / "templates" -): - # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through - {{gen_defines() | indent_except_first(1)}} - - # -- load k -- - # NB reversed order to since K is transposed - {%- if USE_TMA %} - k = tl.load_tensor_descriptor( - desc_k, - [kv_start + kv_offset, 0], - ) - {%- else %} - k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) - {%- endif %} - - if USE_TMA: - k = tl.trans(k) - # -- compute qk --- - qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. - if not PRESCALE_QK: - qk *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, - # which is larger than the actual number of elements. To avoid access memory out of bound, - # we need to mask out the elements that are out of Q_LEN & KV_LEN. - m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) - n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) - - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qk", - b="off_z", - h="off_h", - m="m", - n="n", - out="qk" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. - post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=1, - output_name="mask_mod_output", - score="qk", - b="off_z", - h="off_h", - m="m", - n="n", - ) | indent_except_first(2) }} - - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) - # apply mask for partially unmasked blocks - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - # -- compute scaling constant --- - m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) - if not ROWS_GUARANTEED_SAFE: - masked_out_rows = (m_ij == float("-inf")) - m_ij_masked = tl.where(masked_out_rows, 0, m_ij) - else: - m_ij_masked = m_ij - - alpha = tl.math.exp2(m_i - m_ij_masked) - p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) - - # NB: l_i update is pulled up here since it's a bit faster - # NB: For headdim=256, it's faster to move it back down to after m_i = - # m_ij - l_i = l_i * alpha + tl.sum(p, 1) - # # -- scale and update acc -- - acc = acc * alpha[:, None] - {%- if USE_TMA %} - v = tl.load_tensor_descriptor( - desc_v, - [kv_start + kv_offset, 0], - ) - {%- else %} - v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) - {%- endif %} - acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) - - # -- update m_i - m_i = m_ij - - return acc, l_i, m_i - -""" - -compute_forward_inner = r""" -@triton.jit -def forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets used as inputs to score_mod & mask_mod - # of size [BLOCK_M, BLOCK_N] or scalar. - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - # blocksparse data - kv_indices, kv_num_blocks, - # start kv and end kv block - block_n_start, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through - {{gen_defines() | indent_except_first(1)}} - - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - RCP_LN2: tl.constexpr = 1.44269504 - - if PRESCALE_QK: - q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - kv_offset = 0 - - # loop over k, v and update accumulator until block_n_end - for start_n in range(block_n_start, block_n_end): - # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. - if IS_DIVISIBLE: - acc, l_i, m_i = forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - else: - # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, - # it's on par or slightly faster than only applying to the last block in fwd. - # However, we choose different strategy for bwd, where we only apply mod & mask - # to the last block because it's faster a lot. - acc, l_i, m_i = forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - - - - offset = get_offset_for_next_block( - start_n, kv_indices, kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS - ) - offs_n = offs_n + offset - kv_offset += offset - if not USE_TMA: - K_block_ptr = tl.advance(K_block_ptr, (0, offset)) - V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) +def load_template(name: str) -> str: + """Load a template file and return its content.""" + with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f: + return f.read() - return acc, l_i, m_i - -""" - -# Inner Triton functions shared by flex_attention & split-k decoding kernels. -compute_next_offset_func = r""" -@triton.jit -def get_offset_for_next_block( - loop_iter, col_indices, total_blocks, - SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, - BLOCKS_ARE_CONTIGUOUS: tl.constexpr -): - if BLOCKS_ARE_CONTIGUOUS: - return BLOCK - cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE - cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") - next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) - needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 - jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK - offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK - return offset -""" - -get_bounded_indices_func = r""" -@triton.jit -def get_bounded_indices(indices, max_len=None): - return indices % max_len if max_len is not None else indices -""" - - -load_checked_block = r""" -@triton.jit -def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): - if IS_DIVISIBLE and SAFE_HEAD_DIM: - return tl.load(block_ptr) - elif IS_DIVISIBLE and not SAFE_HEAD_DIM: - return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") - elif not IS_DIVISIBLE and SAFE_HEAD_DIM: - return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") - else: - return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") -""" - -load_checked_2d = r""" -@triton.jit -def load_checked_2d( - ptr, - offs_m, - offs_n, - stride_m, - stride_n, - IS_DIVISIBLE_M: tl.constexpr, - IS_DIVISIBLE_N: tl.constexpr, - M_LEN: tl.constexpr, - N_DIM: tl.constexpr, -): - # Calculate final pointer if strides are provided - if stride_m is not None and stride_n is not None: - ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n - - # Handle all masking cases - if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) - elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) - elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) - else: # Both divisible - return tl.load(ptr) -""" +# Template strings have been moved to templates/common.py.jinja diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index 429f8d05c8cd5..a3e441d033b3f 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -22,17 +22,12 @@ ) from .common import ( build_subgraph_buffer, - compute_forward_block_mn, - compute_forward_inner, - compute_next_offset_func, create_indices_fake, create_num_blocks_fake_generator, create_placeholder, - get_bounded_indices_func, get_fwd_subgraph_outputs, infer_dense_strides, - load_checked_2d, - load_checked_block, + load_template, maybe_realize, set_head_dim_values, SubgraphResults, @@ -67,267 +62,12 @@ def get_float32_precision(): return "'tf32'" -compute_flex_attention = r""" -{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} - # Sub notation for this kernel: - # - # Q: Query, K: Key, V: Value - # M: Number of queries, N: Number of keys/values, D: Model dimension - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # - # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. - # - # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad - # - # (Modifiable) Performance tuning options - # BLOCK_M: The thread block size across the seqlen dim of Q. - # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. - - # The below are kernel options that can be applied for certain score_mods, - # or involve a numerics vs. perf tradeoff - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has - # about 20% more numerical error, but slightly faster. - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are - # contiguous? If so, we don't need to do an indirect jump for every block - - tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) - tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) - - # Define strides of inputs - stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - - ZQ = {{size("Q", 0)}} - HQ = {{size("Q", 1)}} - Q_LEN = {{size("Q", 2)}} - ZKV = {{size("K", 0)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - q_start = tl.program_id(0) - off_zq = tl.program_id(1) - off_hq = tl.program_id(2) - - # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. - # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. - off_zkv = off_zq % ZKV - off_hkv = off_hq // GQA_SHARED_HEADS - off_g = off_hq % GQA_SHARED_HEADS - - q_offset = off_zq * stride_qz + off_hq * stride_qh - k_offset = off_zkv * stride_kz + off_hkv * stride_kh - v_offset = off_zkv * stride_vz + off_hkv * stride_vh - - Q = Q + q_offset - K = K + k_offset - V = V + v_offset - - # Setting up the TMA descriptors for Q, K, V - desc_q = None - desc_k = None - desc_v = None - {%- if USE_TMA %} - desc_q = tl.make_tensor_descriptor( - base=Q, - shape=[Q_LEN, QK_HEAD_DIM], - strides=[stride_qm, 1], - block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], - ) - - desc_k = tl.make_tensor_descriptor( - base=K, - shape=[KV_LEN, QK_HEAD_DIM], - strides=[stride_kn, 1], - block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], - ) - - desc_v = tl.make_tensor_descriptor( - base=V, - shape=[KV_LEN, V_HEAD_DIM], - strides=[stride_vn, 1], - block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], - ) - {%- endif %} - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_zq % SPARSE_Z - sparse_idx_hq = off_hq % SPARSE_HQ - - SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - - stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} - stride_kv_idx_h = {{stride("KV_IDX", 1)}} - stride_kv_idx_m = {{stride("KV_IDX", 2)}} - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - - offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) - - # KV_IDX and KV_NUM_BLKS are always contiguous. - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq - sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE - sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 - K_block_ptr = None - V_block_ptr = None - Q_block_ptr = None - - if not USE_TMA: - Q_block_ptr = tl.make_block_ptr( - base=Q , - shape=(Q_LEN, QK_HEAD_DIM), - strides=(stride_qm, stride_qk), - offsets=(q_start * BLOCK_M, 0), - block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - - {%- if USE_TMA %} - q = tl.load_tensor_descriptor( - desc_q, - [(q_start * BLOCK_M).to(tl.int32), 0], - ) - {%- else %} - q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) - {%- endif %} - - # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We don't know anything "special" about these blocks, so we need to apply - # both score_mod and mask_mod to it - kv_indices = KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - - if not USE_TMA: - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(QK_HEAD_DIM, KV_LEN), - strides=(stride_kk, stride_kn), - offsets=(0, kv_start), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(kv_start, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - - offs_n = kv_start + tl.arange(0, BLOCK_N) - - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - acc, l_i, m_i, - off_zq, off_hq, offs_m[:, None], offs_n[None, :], - kv_start, - kv_indices, kv_num_blocks, - 0, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We know these blocks are guaranteed to be "full", so we don't need to - # apply mask_mod to them - only score_mod - if HAS_FULL_BLOCKS: - # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. - kv_indices = FULL_KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - if not USE_TMA: - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(QK_HEAD_DIM, KV_LEN), - strides=(stride_kk, stride_kn), - offsets=(0, kv_start), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(kv_start, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = kv_start + tl.arange(0, BLOCK_N) - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - acc, l_i, m_i, - off_zq, off_hq, offs_m[:, None], offs_n[None, :], - kv_start, - kv_indices, kv_num_blocks, - 0, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - - # [Note] Handle fully masked out rows: - # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. - # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step - l_i = tl.where(l_i == 0.0, 1, l_i) - - acc = acc / l_i[:, None] - idx_zq = tl.program_id(1) - idx_hq = tl.program_id(2) - idx_m = offs_m[:, None] - idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] - - mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) - - {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - - if OUTPUT_LOGSUMEXP: - off_hz = off_zq * HQ + off_hq - l_ptrs = LSE + off_hz * Q_LEN + offs_m - lse = m_i + tl.math.log2(l_i) - if IS_DIVISIBLE: - tl.store(l_ptrs, lse) - else: - tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) - """ - - flex_attention_template = TritonTemplate( name="flex_attention", grid=flex_attention_grid, - source=compute_flex_attention - + compute_forward_inner - + compute_next_offset_func - + compute_forward_block_mn - + load_checked_block - + get_bounded_indices_func, + source=load_template("flex_attention") + + load_template("utilities") + + load_template("common"), ) @@ -684,693 +424,7 @@ def flex_attention_backward_grid( flex_attention_backward_template = TritonTemplate( name="flex_attention_backward", grid=flex_attention_backward_grid, - source=r""" -{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} - # Sub notation for this kernel: - # - # Q: Query, K: Key, V: Value - # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) - # DELTA: Precomputed sum(OUT*DO, axis=-1) - # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value - # DK: Derivative of Key, is the written to via the store_output call due to some limitations with - # inductor codegen - # M: Number of queries, N: Number of keys/values - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # (Modifiable) Performance tuning options - # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. - # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. - # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. - # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. - # - # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. - # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. - # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. - # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. - - # The below are kernel options that can be applied for certain score_mods, - # or involve a numerics vs. perf tradeoff - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has - # about 20% more numerical error, but slightly faster. - - # Define strides of inputs - stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} - stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} - - stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} - stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} - - ZQ = {{size("Q", 0)}} - HQ = {{size("Q", 1)}} - HKV = {{size("K", 1)}} - Q_LEN = {{size("Q", 2)}} - ZKV = {{size("K", 0)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - pid = tl.program_id(0) - NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) - NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) - - off_zq = tl.program_id(1) # q batch idx - off_hkv = tl.program_id(2) # kv head idx - off_zkv = off_zq % ZKV # kv batch idx - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_zq % SPARSE_Z - - k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) - v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) - # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] - # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] - dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) - - # offset K, V, DV pointers for batch/kv-head - K += k_adj - V += v_adj - DV += dv_adj - - RCP_LN2 = 1.44269504 - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - if pid >= NUM_KV_BLOCKS: - off_pid = pid - NUM_KV_BLOCKS - # THIS BLOCK DOES DQ - SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) - SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) - off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS - start_m2_block = off_pid % NUM_Q_BLOCKS - off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE - stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} - stride_kv_idx_h = {{stride("KV_IDX", 1)}} - stride_kv_idx_m = {{stride("KV_IDX", 2)}} - - sparse_idx_hq2 = off_hq2 % SPARSE_HQ - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 - - sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask - sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 - - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. - q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) - do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) - dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) - off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) - - Q2 = Q + q_adj2 - DO2 = DO + do_adj2 - # TODO: This does not work if DQ is not the same layout as Q (for example, - # if Q is broadcasted) - DQ2 = DQ + dq_adj2 - LSE2 = LSE + off_chz2 - DELTA2 = DELTA + off_chz2 - - # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) - dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) - - start_m2 = start_m2_block * BLOCK_M2 - offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - - # load Q and do: they stay in SRAM throughout the inner loop. - q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) - do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) - - if PRESCALE_QK: - q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - if IS_DIVISIBLE: - Di = tl.load(DELTA2 + offs_m2) - lse = tl.load(LSE2 + offs_m2) - else: - Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) - lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) - lse = tl.where(lse == -float("inf"), 0.0, lse) - lse = lse[:, None] - - # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # KV_IDX and KV_NUM_BLKS are always contiguous. - kv_indices = KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - - offs_n2 = kv_start + tl.arange(0, BLOCK_N2) - dq = bwd_dq_inner( - {{gen_argdefs()}}, - K, V, - dq, q, do, Di, lse, - off_zq, off_hq2, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - if HAS_FULL_BLOCKS: - # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. - kv_indices = FULL_KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - - offs_n2 = kv_start + tl.arange(0, BLOCK_N2) - dq = bwd_dq_inner( - {{gen_argdefs()}}, - K, V, - dq, q, do, Di, lse, - off_zq, off_hq2, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - # Write back dQ. - dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd - dq *= SM_SCALE - if IS_DIVISIBLE and SAFE_HEAD_DIM: - tl.store(dq_ptrs, dq) - else: - tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) - else: - # THIS BLOCK DOES DK & DV - SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) - SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) - - pid_mask = pid // SPARSE_KV_MULTIPLE - - stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} - stride_q_idx_h = {{stride("Q_IDX", 1)}} - stride_q_idx_n = {{stride("Q_IDX", 2)}} - - - dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) - - start_n1 = pid * BLOCK_N1 - offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - - # load K and V: they stay in SRAM throughout the inner loop. - k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) - v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) - - if PRESCALE_QK: - k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - for off_g in range(0, GQA_SHARED_HEADS): - off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g - - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. - q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) - do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) - dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) - off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) - - Q1 = Q + q_adj1 - DO1 = DO + do_adj1 - # TODO: This does not work if DQ is not the same layout as Q (for example, - # if Q is broadcasted) - LSE1 = LSE + off_chz1 - DELTA1 = DELTA + off_chz1 - - sparse_idx_hq1 = off_hq1 % SPARSE_HQ - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 - - sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask - sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 - - # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Q_IDX and Q_NUM_BLKS are always contiguous. - q_indices = Q_IDX + sparse_q_idx_offset - q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading - sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) - - offs_m1 = q_start + tl.arange(0, BLOCK_M1) - dk, dv = bwd_dkdv_inner( - {{gen_argdefs()}}, - Q1, DO1, DELTA1, LSE1, - dk, dv, k, v, - off_zq, off_hq1, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - - if HAS_FULL_BLOCKS: - # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. - q_indices = FULL_Q_IDX + sparse_q_idx_offset - q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading - sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) - - offs_m1 = q_start + tl.arange(0, BLOCK_M1) - dk, dv = bwd_dkdv_inner( - {{gen_argdefs()}}, - Q1, DO1, DELTA1, LSE1, - dk, dv, k, v, - off_zq, off_hq1, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - # Write back dV and dK. - dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd - - index_n = offs_n1[:, None] - index_k = offs_k[None, :] - index_v = offs_v[None, :] - - if IS_DIVISIBLE and SAFE_HEAD_DIM: - tl.store(dv_ptrs, dv) - else: - tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) - - dk *= SM_SCALE - - if SAFE_HEAD_DIM: - mask = index_n < KV_LEN - else: - mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) - - # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] - # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] - {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} - -@triton.jit -def bwd_dq_inner( - {{gen_argdefs()}}, - K, V, # pointers - dq, q, do, Di, lse, - off_z, off_hq, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - {{gen_defines() | indent_except_first(1) }} - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) - RCP_LN2: tl.constexpr = 1.44269504 - Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} - - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd - vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. - tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - - hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) - if not IS_DIVISIBLE: - if hi >= 1: - for start_n in range(0, hi - 1): - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - - # Increment pointers. - offset = get_offset_for_next_block( - start_n, kv_indices, sparse_kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS - ) - - kT_ptrs += offset * stride_kn - vT_ptrs += offset * stride_vn - - offs_n2 += offset - - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - else: - for start_n in range(0, hi): - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - - # Increment pointers. - offset = get_offset_for_next_block( - start_n, kv_indices, sparse_kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS - ) - - kT_ptrs += offset * stride_kn - vT_ptrs += offset * stride_vn - - offs_n2 += offset - - return dq - - -@triton.jit -def bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, -): - {{gen_defines() | indent_except_first(1)}} - - # NB reversed order to since K is transposed - kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) - qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) - if not PRESCALE_QK: - qk *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - pre_mod_scores = qk - n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) - # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim - # that the M reads out of bounds prior to the last loop - m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) - - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qk", - b="off_z", - h="off_hq", - m="m", - n="n", - out="qk" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. - post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=2, - output_name="mask_mod_output", - score="qk", - b="off_z", - h="off_hq", - m="m", - n="n", - ) | indent_except_first(2) }} - - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) - # apply mask for partial masked block - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - p = tl.math.exp2(post_mod_scores - lse) - # Compute dP and dS. - # NB reversed order to since V is transposed - vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) - - dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) - ds = p * (dp - Di[:, None]) - # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - {{ modification( - subgraph_number=1, - output_name = "grad_scores", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="ds" - ) | indent_except_first(1) }} - if CHECK_BLOCK_BOUNDARY: - grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) - - # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ - if WRITE_DQ: - scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) - {{ modification( - subgraph_number=3, - output_name=None, - mask="scatter_mask", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="ds" - ) | indent_except_first(2) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - ds = grad_scores - - if not IS_FULL_BLOCKS: - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for partially unmasked block - ds = tl.where(mask_mod_output, ds, 0.0) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - ds = ds.to(MATMUL_PRECISION) - # Compute dQ. - dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) - - return dq - - -@triton.jit -def bwd_dkdv_inner( - {{gen_argdefs()}}, - Q, DO, DELTA, LSE, # pointers - dk, dv, k, v, - off_z, off_hq, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - {{gen_defines() | indent_except_first(1) }} - SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) - RCP_LN2: tl.constexpr = 1.44269504 - Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} - - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd - do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod - # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) - - if not IS_DIVISIBLE: - if hi >= 1: - for start_m in range(0, hi - 1): - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - # Increment pointers. - offset = get_offset_for_next_block( - start_m, q_indices, sparse_q_num_blocks, - SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS - ) - - qT_ptrs += offset * stride_qm - do_ptrs += offset * stride_dom - - offs_m1 += offset - - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - else: - for start_m in range(0, hi): - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - # Increment pointers. - offset = get_offset_for_next_block( - start_m, q_indices, sparse_q_num_blocks, - SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS - ) - - qT_ptrs += offset * stride_qm - do_ptrs += offset * stride_dom - - offs_m1 += offset - - return dk, dv - - -@triton.jit -def bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, -): - {{gen_defines() | indent_except_first(1) }} - - # NB reversed order since Q is transposed - qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) - # Load LSE before computing qk to reduce pipeline stall. - if IS_DIVISIBLE: - lse = tl.load(LSE + offs_m1) - else: - lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) - lse = tl.where(lse == -float("inf"), 0.0, lse) - qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) - if not PRESCALE_QK: - qkT *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) - # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim - # that the n reads out of bounds prior to the last loop - n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) - - pre_mod_scores = qkT - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qkT", - b="off_z", - h="off_hq", - m="m", - n="n", - out="qkT" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. - post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=2, - output_name="mask_mod_output", - score="qkT", - b="off_z", - h="off_hq", - m="m", - n="n", - ) | indent_except_first(2) }} - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for fully masked block - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - pT = tl.math.exp2(post_mod_scores - lse[None, :]) - do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) - # Compute dV. - ppT = pT - dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) - if IS_DIVISIBLE: - Di = tl.load(DELTA + offs_m1) - else: - Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) - # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) - dsT = pT * (dpT - Di[None, :]) - # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - {{ modification( - subgraph_number=1, - output_name = "grad_scores", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="dsT" - ) | indent_except_first(1) }} - - # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ - if not WRITE_DQ: - idx_b = off_z - idx_h = off_hq - idx_m = m - idx_n = n - scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) - {{ modification( - subgraph_number=3, - output_name=None, - mask="scatter_mask", - score="pre_mod_scores", - b="idx_b", - h="idx_h", - m="idx_m", - n="idx_n", - grad_score_mod="dsT" - ) | indent_except_first(2) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - if CHECK_BLOCK_BOUNDARY: - grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) - - dsT = grad_scores - if not IS_FULL_BLOCKS: - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for partially unmasked block - dsT = tl.where(mask_mod_output, dsT, 0.0) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) - - return dk, dv - """ - + compute_next_offset_func - + get_bounded_indices_func - + load_checked_2d, + source=load_template("flex_backwards") + load_template("utilities"), ) diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py index 7f92fbc705a59..361729d44b992 100644 --- a/torch/_inductor/kernel/flex/flex_decoding.py +++ b/torch/_inductor/kernel/flex/flex_decoding.py @@ -18,15 +18,10 @@ TritonTemplate, ) from .common import ( - compute_forward_block_mn, - compute_forward_inner, - compute_next_offset_func, create_indices_fake, create_num_blocks_fake_generator, - get_bounded_indices_func, get_fwd_subgraph_outputs, - load_checked_2d, - load_checked_block, + load_template, maybe_realize, set_head_dim_values, ) @@ -90,266 +85,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me flex_decoding_template = TritonTemplate( name="flex_decoding", grid=flex_decoding_grid, - source=r""" - {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} - # Sub notation for this kernel: - # Q: Query, K: Key, V: Value - # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split - # M: Number of queries, N: Number of keys/values - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits - # (Modifiable) Config options: - # SPLIT_KV: number of blocks K & V are split into - # TILE_KV: length of each local KV split - # BLOCK_M: block size that Q is padded along seqlen dim. - # BLOCK_N: block size of K & V along N dimension. - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # - # change of base out of the loop - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. - # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. - - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. - # - # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # - # - # Output: ACC output accumulated across local KV split. - - tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) - - # Define Q Strides - stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} - stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} - - - Z = {{size("Q", 0)}} - ZKV = {{size("K", 0)}} - HKV = {{size("Q", 1)}} - G: tl.constexpr = GQA_SHARED_HEADS - HQ = HKV * G - Q_LEN = {{size("Q", 3)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - # Make sure each split is a multiple of BLOCK_N - TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) - TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N - TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) - - off_z = tl.program_id(0) // HKV - off_zkv = off_z % ZKV - off_hkv = tl.program_id(0) % HKV - off_t = tl.program_id(1) - - q_offset = off_z * stride_qz + off_hkv * stride_qh - k_offset = off_zkv * stride_kz + off_hkv * stride_kh - v_offset = off_zkv * stride_vz + off_hkv * stride_vh - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_z % SPARSE_Z - sparse_idx_h = off_hkv % SPARSE_HQ - - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - - # initialize offsets - tl.device_assert(BLOCK_M % G == 0) - BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G - off_g = tl.arange(0, G) # [G] - offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] - offs_hq = offs_g + off_hkv * G - off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] - offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] - offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) - - # Get HZ offsets for KV_NUM_BLKS and KV_IDX - stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} - sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h - stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} - sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h - - # Calculate KV blocks that belong this CTA. - block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block - block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N - - q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] - - if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) - elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) - elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) - else: - q = tl.load(Q + q_offset + q_range) - - q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) - - - # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Apply both score_mod and mask_mod - - # find first kv block we are loading and the number of blocks we are loading - # Offset the kv_indices tensor by the correct batch and head - kv_indices = KV_IDX + sparse_idx_hz_offset - kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) - indices_idx = block_n_start // SPARSE_KV_MULTIPLE - off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE - off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N - # first kv block we're loading - - # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(off_n, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = tl.arange(0, BLOCK_N) + off_n - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, - # accumulatd values - acc, l_i, m_i, - #offsets - off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], - None, - #block sparse data - kv_indices, kv_num_blocks, - block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - - # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We know these blocks are guaranteed to be "full", so we don't need to - # apply mask_mod to them - only score_mod - if HAS_FULL_BLOCKS: - kv_indices = FULL_KV_IDX + sparse_idx_hz_offset - kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) - # Assign full block in a reverse order for off_t. Prioritize the last CTA. - block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE - block_n_end = block_n_start + TILE_KV_MULTIPLE - indices_idx = block_n_start // SPARSE_KV_MULTIPLE - off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE - off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N - - # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(off_n, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = tl.arange(0, BLOCK_N) + off_n - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, - # accumulatd values - acc, l_i, m_i, - #offsets - off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], - None, - #block sparse data - kv_indices, kv_num_blocks, - block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - m_offset = off_t * stride_mt + off_z * stride_mz - l_offset = off_t * stride_lt + off_z * stride_lz - - M_block_ptr = tl.make_block_ptr( - base=M + m_offset, - shape=(G, Q_LEN), # (G, M) - strides=(stride_mh, stride_mm), - offsets=(off_hkv*G, 0), - block_shape=(G, BLOCK_M_PER_HQ), - order=(1, 0) - ) - L_block_ptr = tl.make_block_ptr( - base=L + l_offset, - shape=(G, Q_LEN), # (G, M) - strides=(stride_lh, stride_lm), - offsets=(off_hkv*G, 0), - block_shape=(G, BLOCK_M_PER_HQ), - order=(1, 0) - ) - - # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) - m_i = m_i.reshape(G, BLOCK_M_PER_HQ) - l_i = l_i.reshape(G, BLOCK_M_PER_HQ) - if SAFE_M_BOUNDARY: - tl.store(M_block_ptr, m_i) - tl.store(L_block_ptr, l_i) - else: - tl.store(M_block_ptr, m_i, boundary_check=(1,)) - tl.store(L_block_ptr, l_i, boundary_check=(1,)) - - # -- store output - idx_z = off_z - idx_t = off_t - idx_hq = off_hkv*G + off_g[:, None, None] - idx_m = off_m[None, :, None] - idx_d = offs_vd[None, None, :] - - mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) - acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) - {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - """ - + compute_forward_inner - + get_bounded_indices_func - + load_checked_block - + load_checked_2d - + compute_next_offset_func - + compute_forward_block_mn, + source=load_template("flex_decode") + + load_template("utilities") + + load_template("common"), ) diff --git a/torch/_inductor/kernel/flex/templates/common.py.jinja b/torch/_inductor/kernel/flex/templates/common.py.jinja new file mode 100644 index 0000000000000..0e967570127d4 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/common.py.jinja @@ -0,0 +1,193 @@ + + +# Common Imports +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + # NB reversed order to since K is transposed + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( + desc_k, + [kv_start + kv_offset, 0], + ) + {%- else %} + k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- endif %} + + if USE_TMA: + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_start + kv_offset, 0], + ) + {%- else %} + v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + if not USE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + + + return acc, l_i, m_i diff --git a/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja b/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja new file mode 100644 index 0000000000000..79410fb500460 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja @@ -0,0 +1,248 @@ +{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0) + off_zq = tl.program_id(1) + off_hq = tl.program_id(2) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN, QK_HEAD_DIM], + strides=[stride_qm, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + + desc_k = tl.make_tensor_descriptor( + base=K, + shape=[KV_LEN, QK_HEAD_DIM], + strides=[stride_kn, 1], + block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], + ) + + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN, V_HEAD_DIM], + strides=[stride_vn, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + + if not USE_TMA: + Q_block_ptr = tl.make_block_ptr( + base=Q , + shape=(Q_LEN, QK_HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(q_start * BLOCK_M, 0), + block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1) + idx_hq = tl.program_id(2) + idx_m = offs_m[:, None] + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) diff --git a/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja b/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja new file mode 100644 index 0000000000000..1775833b8e68f --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja @@ -0,0 +1,682 @@ +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1) # q batch idx + off_hkv = tl.program_id(2) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + if not IS_DIVISIBLE: + if hi >= 1: + for start_n in range(0, hi - 1): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds prior to the last loop + m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + if not IS_DIVISIBLE: + if hi >= 1: + for start_m in range(0, hi - 1): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds prior to the last loop + n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) + + dsT = grad_scores + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv \ No newline at end of file diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja new file mode 100644 index 0000000000000..f4596070c833e --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja @@ -0,0 +1,252 @@ + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0) % HKV + off_t = tl.program_id(1) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Apply both score_mod and mask_mod + + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} \ No newline at end of file diff --git a/torch/_inductor/kernel/flex/templates/utilities.py.jinja b/torch/_inductor/kernel/flex/templates/utilities.py.jinja new file mode 100644 index 0000000000000..7e2367e4f2692 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/utilities.py.jinja @@ -0,0 +1,59 @@ + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_DIM: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) From 101276f81b4d2a8c31bfd6796b986d4c1bfdf483 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Tue, 12 Aug 2025 10:20:26 -0700 Subject: [PATCH 0288/1424] [BE] Save attributes for CppCompileError for pickleing (#160294) Differential Revision: [D79977408](https://our.internmc.facebook.com/intern/diff/D79977408/) Context: When testing cutlass backend and used autotune with subproc, sometimes I would see C++ compilation error (expected) followed by ``` Traceback (most recent call last): File "/torch/_inductor/autotune_process.py", line 175, in get result = TuningProcess.recv(self.read_pipe) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/torch/_inductor/autotune_process.py", line 99, in recv return pickle.load(read_pipe) ^^^^^^^^^^^^^^^^^^^^^^ TypeError: CppCompileError.__init__() missing 1 required positional argument: 'output' ``` which is unexpected. After asking claude, it seems > Now I can see the issue. The `CppCompileError` class requires two arguments: `cmd` (a list of strings) and `output` (a string). However, when exceptions are being pickled and unpickled across process boundaries, the pickling process might not be preserving the constructor arguments correctly. > > The problem is likely that when a `CppCompileError` is raised in the subprocess and then pickled/unpickled through the `recv` function, the unpickling process is trying to reconstruct the exception but doesn't have the required constructor arguments. > > The issue is clear now. The `CppCompileError` class doesn't have custom pickle methods (`__reduce__`, `__getstate__`, `__setstate__`), so when it's pickled and unpickled across process boundaries, Python's default pickling mechanism tries to reconstruct it but fails because it doesn't preserve the constructor arguments properly. > > The solution is to add a `__reduce__` method to the `CppCompileError` class to ensure it can be properly pickled and unpickled. Let me implement this fix: Adding these seem to help. fbcode repro: [D79977541](https://www.internalfb.com/diff/D79977541) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160294 Approved by: https://github.com/masnesral --- torch/_inductor/exc.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index ac321c9974ae8..a46663ed8f8c0 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -92,6 +92,9 @@ def __init__(self, cmd: list[str], output: str) -> None: if isinstance(output, bytes): output = output.decode("utf-8") + self.cmd = cmd + self.output = output + super().__init__( textwrap.dedent( """ @@ -108,6 +111,9 @@ def __init__(self, cmd: list[str], output: str) -> None: .format(cmd=" ".join(cmd), output=output) ) + def __reduce__(self) -> tuple[type, tuple[list[str], str]]: + return (self.__class__, (self.cmd, self.output)) + class CUDACompileError(CppCompileError): pass From 16d15445f8bd8740095b23de4af89d757af793ca Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Tue, 12 Aug 2025 22:06:18 +0000 Subject: [PATCH 0289/1424] Fullgraph graph capture with dynamo. (#159749) Summary: Following up on Avik's doc https://docs.google.com/document/d/11RW0Bbkp1QwFbEu8rCNW5d7wUFaEkxbL0uLyqcc2jTk/edit?tab=t.0 We are experimenting with a new API which utilizes torch.compile(fullgraph=True) and intend to use it to replace the old dynamo.export() API. This PR adds a prototype for the API described in the doc. Test Plan: test_misc -- -k test_aot_capture Rollback Plan: Differential Revision: D79534608 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159749 Approved by: https://github.com/tugsbayasgalan --- test/dynamo/test_misc.py | 46 +++++++++++++++++++++ torch/_dynamo/eval_frame.py | 82 ++++++++++++++++++++++++++++++++++++- torch/_dynamo/package.py | 14 +++++-- 3 files changed, 138 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index d34670c357bf4..624f0603678af 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -16,11 +16,13 @@ import math import operator import os +import pickle import random import sys import tempfile import threading import traceback +import types import typing import unittest import unittest.mock as mock @@ -8520,6 +8522,50 @@ def global_context_capture_fn(frame_summary): self.assertEqual(seen_frames[0].name, "fn") self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)") + def test_fullgraph_capture(self): + def foo(x): + return x + x.shape[0] + + compiled_foo = torch._dynamo.eval_frame.fullgraph_capture(foo) + compiled_foo(torch.randn(3, 2)) + compiled_foo(torch.randn(4)) + artifacts = compiled_foo.get_artifacts() + + guarded_codes = artifacts.dynamo_artifacts.guarded_codes + backend_ids = list(artifacts.backend_inputs.keys()) + gms = [b.graph_module for b in artifacts.backend_inputs.values()] + + def _convert_to_ep_demo(code, backend_id, gm, args): + # Inject compiled function as the original gm + new_globals = copy.copy(globals()) + new_globals[backend_id] = gm + # Minimal boilerplate to setup a callable. + SerializedCode = type(code.dynamo_code) + dynamo_bytecode = SerializedCode.to_code_object(code.dynamo_code) + guards_state = pickle.loads(code.guards_state) + guard_manager = torch._dynamo.guards.CheckFunctionManager( + foo.__code__, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + runtime_global_scope=new_globals, + ).guard_manager + + class ModuleForExport(torch.nn.Module): + def forward(self, x): + return types.FunctionType(dynamo_bytecode, new_globals)(x) + + m = ModuleForExport() + return guard_manager, torch.export.export(m, args) + + guards0, ep0 = _convert_to_ep_demo( + guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),) + ) + self.assertTrue(guards0.check({"x": torch.randn(3, 2)})) + self.assertFalse(guards0.check({"x": torch.randn(4)})) + input0 = torch.randn(3, 2) + self.assertEqual(ep0.module()(input0), foo(input0)) + def test_torch_guards_stack_frame_register_inlining_deep(self): x = torch.tensor([0.5, 0.5]) y = torch.tensor([0.75, 0.75, 0.75, 0.75]) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index fd85b5d28e03c..63c2ed9e9bad7 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -113,7 +113,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from torch._dynamo.package import CompilePackage + from torch._dynamo.package import CompilePackage, DynamoCaptureOutput from torch._dynamo.repro.after_dynamo import WrapBackendDebug from torch._subclasses import fake_tensor from torch.fx.node import Argument, Node, Target @@ -2288,3 +2288,83 @@ def skip_code(code: types.CodeType) -> None: set_code_exec_strategy( code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) ) + + +@dataclass +class BackendInput: + graph_module: torch.fx.GraphModule + example_inputs: tuple[Any, ...] + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode + + +@dataclass +class CaptureOutput: + """ + Core data structure that contains the all the information dynamo generates + from fullgraph=True. Ideally, this is should be the "return" type if dynamo + has a standard API to return compilation artifacts. + """ + + dynamo_artifacts: DynamoCaptureOutput + backend_inputs: dict[str, BackendInput] + + +def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]: + """ + A helper function which wraps a model and returns a callable like optimize(). + The callable can be called with normal inputs like torch.compile()-ed functions + and user can dump dynamo compilation artifacts through `get_artifacts()` call. + + The CaptureOutput is separated into two parts: + 1. Dynamo specific information from DynamoCaptureOutput, which includes: + - guards + - generated bytecode + - python source information + 2. Backend specific information (indexed by unique backend id) such as: + - fx graph + - example inputs + + Example: + def fn(*args): + ... + + compiled_fn = fullgraph_capture(fn) + compiled_fn(args) + compiled_fn(another_args) + artifacts = compiled_fn.get_artifacts() + """ + from torch._dynamo.package import CompilePackage + + package = CompilePackage(model) + + backend_inputs: dict[str, BackendInput] = {} + + def _backend( + gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...] + ) -> torch.fx.GraphModule: + from torch._guards import TracingContext + + fake_mode = TracingContext.get().fake_mode + assert fake_mode is not None + backend_id = gm._backend_id + assert isinstance(backend_id, str) + backend_inputs[backend_id] = BackendInput(gm, example_inputs, fake_mode) + return gm + + # TODO For now we use eval_frame to give us the frame. This is can be simplified to + # a manual frame creation helper. + optimized_model = optimize(nopython=True, backend=_backend, package=package)(model) + + @functools.wraps(model) + def capture_context(*args: Any, **kwargs: Any) -> Any: + return optimized_model(*args, **kwargs) + + def get_artifacts() -> CaptureOutput: + cache_entry = package.cache_entry() + assert len(cache_entry.codes) == 1 + return CaptureOutput( + dynamo_artifacts=cache_entry.codes[0], backend_inputs=backend_inputs + ) + + capture_context.get_artifacts = get_artifacts # type: ignore[attr-defined] + return capture_context diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index b15dc0b2fdf69..311a702dfa38a 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -112,7 +112,17 @@ class InlinedSource: @dataclasses.dataclass -class _DynamoCodeCacheEntry: +class DynamoCaptureOutput: + """ + Core information generated from Dynamo for fullgraph=True. + """ + + guarded_codes: list[_GuardedCodeCacheEntry] + backend_ids: list[_BackendId] + + +@dataclasses.dataclass +class _DynamoCodeCacheEntry(DynamoCaptureOutput): """ Contains the serializable information associated with a single code object in dynamo. To restore an execution of compiled code, we will need the following @@ -135,9 +145,7 @@ class _DynamoCodeCacheEntry: python_code: SerializedCode python_module: str function_names: list[_FunctionId] - guarded_codes: list[_GuardedCodeCacheEntry] import_sources: dict[str, str] - backend_ids: list[_BackendId] code_source: Optional[str] install_to_global: bool has_compile_id: bool = False From 2e4e5ab4be9e0aeffd9c49b5b2f9f820bd0895b1 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Tue, 12 Aug 2025 22:08:44 +0000 Subject: [PATCH 0290/1424] [MPS] Add mps keys to `indices` and `values` ops (#160223) enable indices and values on sparse mps Pull Request resolved: https://github.com/pytorch/pytorch/pull/160223 Approved by: https://github.com/malfet --- aten/src/ATen/native/native_functions.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e7492f4c379af..1bb8fe52512ca 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7462,7 +7462,7 @@ - func: indices(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: indices_sparse + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: indices_sparse CompositeExplicitAutograd: indices_default device_check: NoCheck device_guard: False @@ -7470,7 +7470,7 @@ - func: values(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: values_sparse + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: values_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: values_nested CompositeExplicitAutograd: values_default From 5737372862253a0ac0292407a5844796f02380ad Mon Sep 17 00:00:00 2001 From: deedongala Date: Tue, 12 Aug 2025 22:42:40 +0000 Subject: [PATCH 0291/1424] [CI] Switch ROCm MI300 GitHub Actions workflows from 2-GPU to 1-GPU runners (#158882) Updated .github/actionlint.yaml to replace linux.rocm.gpu.mi300.2 with linux.rocm.gpu.mi300.1 in the supported runner list Modified all affected workflows (inductor-perf-test-nightly-rocm.yml, inductor-periodic.yml, inductor-rocm-mi300.yml, and rocm-mi300.yml) to run jobs on 1-GPU MI300 runners instead of 2-GPU runners This should help increase available runners even with same number of CI nodes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158882 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .github/actionlint.yaml | 1 + .github/actions/setup-rocm/action.yml | 5 --- .github/workflows/_rocm-test.yml | 10 ++++++ .../inductor-perf-test-nightly-rocm.yml | 34 +++++++++---------- .github/workflows/inductor-periodic.yml | 30 ++++++++-------- .github/workflows/inductor-rocm-mi300.yml | 4 +-- .github/workflows/rocm-mi300.yml | 12 +++---- 7 files changed, 51 insertions(+), 45 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 647671e8c83d2..85c7999c1857e 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -54,6 +54,7 @@ self-hosted-runner: - linux.rocm.gpu.2 - linux.rocm.gpu.4 # gfx942 runners + - linux.rocm.gpu.gfx942.1 - linux.rocm.gpu.gfx942.2 - linux.rocm.gpu.gfx942.4 - rocm-docker diff --git a/.github/actions/setup-rocm/action.yml b/.github/actions/setup-rocm/action.yml index d3644c52fbcd8..a58db801b1cf8 100644 --- a/.github/actions/setup-rocm/action.yml +++ b/.github/actions/setup-rocm/action.yml @@ -59,11 +59,6 @@ runs: echo "$msg" exit 1 fi - if [[ $ngpu -eq 1 ]]; then - echo "Error: only 1 GPU detected, at least 2 GPUs are needed for distributed jobs" - echo "$msg" - exit 1 - fi - name: Runner diskspace health check uses: pytorch/pytorch/.github/actions/diskspace-cleanup@main diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 2d660d98905e9..f73972942b5f9 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -88,6 +88,16 @@ jobs: - name: Setup ROCm uses: ./.github/actions/setup-rocm + - name: Runner check GPU count (distributed jobs) + if: ${{ contains(matrix.config, 'distributed') }} + shell: bash + run: | + ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') + if [[ $ngpu -lt 4 ]]; then + echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs" + exit 1 + fi + - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 diff --git a/.github/workflows/inductor-perf-test-nightly-rocm.yml b/.github/workflows/inductor-perf-test-nightly-rocm.yml index 1ec494ace6577..f329fe74e6b64 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm.yml @@ -88,23 +88,23 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks test-matrix: | { include: [ - { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index fdb54978e8082..436cf95c156d0 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -81,21 +81,21 @@ jobs: sync-tag: rocm-build test-matrix: | { include: [ - { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, + { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index f4c81ce7d7b8d..732ec7eb85f3e 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -47,8 +47,8 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index c51d89e5c955d..7e3ba43bf9845 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -48,12 +48,12 @@ jobs: sync-tag: rocm-build test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, - { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" }, + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, ]} secrets: inherit From 0d71ca2c46753bb268bfdcf815c14415c122a289 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Tue, 12 Aug 2025 22:44:22 +0000 Subject: [PATCH 0292/1424] [EZ] Replace `pytorch-labs` with `meta-pytorch` (#160459) This PR replaces all instances of 'pytorch-labs' with 'meta-pytorch' in this repository now that the 'pytorch-labs' org has been renamed to 'meta-pytorch' ## Changes Made - Replaced all occurrences of 'pytorch-labs' with 'meta-pytorch' - Only modified files with extensions: .py, .md, .sh, .rst, .cpp, .h, .txt, .yml - Skipped binary files and files larger than 1MB due to GitHub api payload limits in the script to cover all repos in this org. Will do a more manual second pass later to cover any larger files ## Files Modified This PR updates files that contained the target text. Generated by automated script on 2025-08-12T20:41:29.888681+00:00Z Pull Request resolved: https://github.com/pytorch/pytorch/pull/160459 Approved by: https://github.com/huydhn, https://github.com/clee2000, https://github.com/atalman, https://github.com/malfet --- android/README.md | 2 +- aten/src/ATen/native/cuda/int4mm.cu | 2 +- torch/testing/_internal/common_quantization.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/android/README.md b/android/README.md index 6b8000c13fccc..f0c74750522de 100644 --- a/android/README.md +++ b/android/README.md @@ -2,7 +2,7 @@ ## Demo applications and tutorials -Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch). +Please refer to [meta-pytorch/executorch-examples](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch). Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions. diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index 272eb9b9c564f..5444bb57eba7c 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -1304,7 +1304,7 @@ at::Tensor _convert_weight_to_int4pack_cuda( constexpr int32_t kKTileSize = 16; // GPT-FAST assumes nTileSize of 8 for quantized weight tensor. - // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 + // See https://github.com/meta-pytorch/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 // Torch dynamo also requires the torch ops has the same output shape for each device. // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263 constexpr int32_t kNTileSizeTensor = 8; diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 211b282c4fc4a..f8671379950ec 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -611,7 +611,7 @@ def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): - # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py + # source: https://github.com/meta-pytorch/gpt-fast/blob/main/quantize.py # default setup for affine quantization of activations x_dtype = x.dtype x = x.float() From b1f43548cad8fc0e30bda250f6e196310fa7a4bc Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 12 Aug 2025 20:13:16 +0000 Subject: [PATCH 0293/1424] [c10d] Error out the case when registering symmetric memory without eager init (#160145) Instead of implicitly creating nccl comm inside mem pool registration for symmetric memory, we decide to error it out so that we only support eager init case when the nccl comm is already initiated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160145 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_nccl.py | 55 +++++++++++-------- .../distributed/c10d/ProcessGroupNCCL.cpp | 9 +-- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index fd9e7594828d6..a1e8d30fef6c4 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3172,35 +3172,42 @@ def test_nccl_user_buffer_registration(self): @requires_multicast_support() def test_nccl_window_registration(self): store = c10d.FileStore(self.file_name, self.world_size) - c10d.init_process_group( - backend="nccl", rank=self.rank, world_size=self.world_size, store=store - ) device = torch.device(f"cuda:{self.rank}") - torch.cuda.set_device(self.rank) - pg = c10d.distributed_c10d._get_default_group() - backend = pg._get_backend(torch.device(device)) - - # Use NCCL memory allocator - # enable symmetric memory usage in NCCL - pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True) - - # allocate memory with ncclMemAlloc - # note: symmetric kernels are not available for dtypes like torch.int64 - with torch.cuda.use_mem_pool(pool): - tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32) + with torch.cuda.device(device): + # Eager init the nccl comm so that we don't implicitly create one during register_mem_pool + c10d.init_process_group( + backend="nccl", + rank=self.rank, + world_size=self.world_size, + store=store, + device_id=device, + ) + pg = c10d.distributed_c10d._get_default_group() + backend = pg._get_backend(torch.device(device)) + + # Use NCCL memory allocator + # enable symmetric memory usage in NCCL + pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True) + + # allocate memory with ncclMemAlloc + # note: symmetric kernels are not available for dtypes like torch.int64 + with torch.cuda.use_mem_pool(pool): + tensor = torch.arange( + 1024 * 1024 * 2, device=device, dtype=torch.float32 + ) - # register buffers to NCCL - backend.register_mem_pool(pool) + # register buffers to NCCL + backend.register_mem_pool(pool) - # allreduce now should use NVIDIA Switches - pg.allreduce(tensor).wait() - torch.cuda.synchronize(device=device) + # allreduce now should use NVIDIA Switches + pg.allreduce(tensor).wait() + torch.cuda.synchronize(device=device) - # de-register buffers from NCCL - backend.deregister_mem_pool(pool) + # de-register buffers from NCCL + backend.deregister_mem_pool(pool) - # clean up memory - del tensor, pool + # clean up memory + del tensor, pool with open(os.environ["NCCL_DEBUG_FILE"]) as f: nccl_debug_file_content = f.read() diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3e9802d855e7c..655193e8f3186 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1091,18 +1091,15 @@ ErrorType ProcessGroupNCCL::getError() { void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { const auto key = std::to_string(pool->device()); - auto device = at::Device(at::DeviceType::CUDA, pool->device()); LOG(INFO) << logPrefix() << "Performing NCCL user buffer registration for all buffers in " << "MemPool: " << pool->id() << ", device index: " << key << ", i am " << this; auto ncclComm = getNCCLComm(key); if (ncclComm == nullptr) { - // HACK: currently we are using this function for NVLS - // reductions, and that's why using OpType::ALLREDUCE. - // If we end up using this API for zero-copy P2P, we might - // need to refactor and account for different OpType. - ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE); + C10_THROW_ERROR( + DistBackendError, + "NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue"); } TORCH_INTERNAL_ASSERT(ncclComm != nullptr); { From 8d1cf529229dce7cd5ea04abb0faac83b87ca6d1 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 12 Aug 2025 16:19:27 -0700 Subject: [PATCH 0294/1424] [EZ][BE] Remove unused `conda-env-macOS-ARM64` (#160477) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160477 Approved by: https://github.com/atalman --- .github/requirements/conda-env-macOS-ARM64 | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 .github/requirements/conda-env-macOS-ARM64 diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64 deleted file mode 100644 index b6e9a6ce9f3e5..0000000000000 --- a/.github/requirements/conda-env-macOS-ARM64 +++ /dev/null @@ -1,5 +0,0 @@ -# Not pinning certifi so that we can always get the latest certificates -certifi -pip=23.2.1 -pkg-config=0.29.2 -wheel=0.37.1 From 32099961d588fc19ead8afe805d6b5108de75669 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 12 Aug 2025 16:19:33 -0700 Subject: [PATCH 0295/1424] [EZ] Delete CircleCI case (#160479) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160479 Approved by: https://github.com/izaitsevfb ghstack dependencies: #160477 --- .ci/manywheel/build.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.ci/manywheel/build.sh b/.ci/manywheel/build.sh index 4c4d51134715a..6b2a60bc5ca28 100755 --- a/.ci/manywheel/build.sh +++ b/.ci/manywheel/build.sh @@ -5,10 +5,6 @@ set -ex SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" case "${GPU_ARCH_TYPE:-BLANK}" in - BLANK) - # Legacy behavior for CircleCI - bash "${SCRIPTPATH}/build_cuda.sh" - ;; cuda) bash "${SCRIPTPATH}/build_cuda.sh" ;; From 69a0a9aa7f5e320a02e97fa789d2f72baff1554f Mon Sep 17 00:00:00 2001 From: Nikhil Patel Date: Wed, 13 Aug 2025 01:27:57 +0000 Subject: [PATCH 0296/1424] [Inductor][Triton] Pass GPUTarget param to updated make_ir function (#160422) Summary: A recent Triton commit changed `ASTSource.make_ir` to a 5-arg signature that includes a `GPUTarget`. We need to pass in this new argument. Test Plan: `buck2 test 'fbcode//mode/opt' -m ovr_config//triton:trunk fbcode//caffe2/test/inductor:test_inductor_cuda -- triton_kernel` Rollback Plan: Reviewed By: davidberard98 Differential Revision: D80069909 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160422 Approved by: https://github.com/davidberard98, https://github.com/mlazos --- torch/_higher_order_ops/triton_kernel_wrap.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 34a9c5915254d..4dd2bd145a90a 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -461,11 +461,16 @@ def get_signature_value(idx: int, arg: Any) -> str: elif make_ir_sig_params == 3: codegen_fns = backend.get_codegen_implementation() ttir_module = src.make_ir(options, codegen_fns, context) - else: + elif make_ir_sig_params == 4: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() ttir_module = src.make_ir(options, codegen_fns, module_map, context) + else: + codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] + codegen_fns = backend.get_codegen_implementation(*codegen_args) + module_map = backend.get_module_map() + ttir_module = src.make_ir(target, options, codegen_fns, module_map, context) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") From f15ada5c6fad97a7dcbfa4673f067b6942dda640 Mon Sep 17 00:00:00 2001 From: nandesuka <11392812+nandesuka@users.noreply.github.com> Date: Wed, 13 Aug 2025 01:28:19 +0000 Subject: [PATCH 0297/1424] Enable output padding when only outermost dim is dynamic (#159404) Summary: When the shape of the output tensor has a dynamic outer most dim, the stride can still be padded to conform to configured alignment if required. Test Plan: CI Rollback Plan: Differential Revision: D79146886 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159404 Approved by: https://github.com/blaine-rister, https://github.com/eellison --- test/inductor/test_padding.py | 105 ++++++++++++++++++++++++++++++---- torch/_inductor/ir.py | 16 +++--- 2 files changed, 102 insertions(+), 19 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 15c1abdf32db2..41944a9169239 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -49,6 +49,18 @@ def geninp(): return input_dict +def get_padded_stride(shape, alignment_bytes, pad_output, itemsize): + align = alignment_bytes // itemsize + new_strides = [0 for _ in range(len(shape))] + new_strides[len(shape) - 1] = 1 + for i in range(len(shape) - 1, 0, -1): + stride = shape[i] * new_strides[i] + if pad_output and stride % align != 0: + stride = (stride + align - 1) // align * align + new_strides[i - 1] = stride + return tuple(new_strides) + + class LinearAndSoftmax(nn.Module): """ It's very common that a transformer model will do a matmul and then @@ -745,20 +757,11 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor: input_tensors = [get_input(shape, alignment_bytes) for _ in range(num_inputs)] config_patches = { - "compile_threads": 1, "comprehensive_padding": pad_output, "cpu_backend": "triton", - "disable_padding_cpu": False, - "implicit_fallbacks": False, - "inplace_buffers": False, "padding_alignment_bytes": alignment_bytes, - "pad_channels_last": True, "pad_outputs": True, "padding_stride_threshold": 0, - "triton.prefer_nd_tiling": True, - "triton.use_block_ptr": True, - "triton.codegen_upcast_to_fp32": False, - "unroll_reductions_threshold": 1, } with config.patch(config_patches): compiled = torch.compile(torch.cat) @@ -767,7 +770,89 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor: output_shape = (shape[0] * num_inputs, shape[1]) output_stride = input_tensors[0].stride() output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)" - self.assertTrue(any(output_line in line for line in code)) + self.assertTrue(output_line in code[0]) + + @parametrize( + "shape,alignment_bytes,pad_output", + [ + ((512, 1), 32, False), + ((512, 1), 32, True), + ((32, 30), 64, False), + ((32, 30), 64, True), + ((512, 100, 1), 32, False), + ((512, 100, 1), 32, True), + ((32, 50, 30), 64, False), + ((32, 50, 30), 64, True), + ], + ) + def test_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output): + """ + When only the outermost dim is dynamic shape, the output can still be padded up + based on padding configuration. + """ + num_inputs = 2 + input_tensors = [ + torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs) + ] + + config_patches = { + "comprehensive_padding": pad_output, + "cpu_backend": "triton", + "padding_alignment_bytes": alignment_bytes, + "pad_outputs": True, + "padding_stride_threshold": 0, + } + with config.patch(config_patches): + torch._dynamo.mark_dynamic(input_tensors[0], 0) + torch._dynamo.mark_dynamic(input_tensors[1], 0) + compiled = torch.compile(torch.add) + result, _ = run_and_get_code(compiled, *input_tensors) + + expected_stride = get_padded_stride( + result.shape, alignment_bytes, pad_output, result.dtype.itemsize + ) + self.assertEqual(result.stride(), expected_stride) + + @parametrize( + "shape,alignment_bytes,pad_output", + [ + ((500, 10, 1), 32, False), + ((500, 20, 1), 32, True), + ((30, 10, 20), 64, True), + ((30, 10, 20), 64, False), + ], + ) + def test_perm_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output): + """ + When only the outermost dim is dynamic shape, the output can still be padded up + based on padding configuration. Test when this occurs after a permute op. + """ + + def permute_contig(x): + return torch.transpose(x, 0, 2).contiguous() + + num_inputs = 1 + input_tensors = [ + torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs) + ] + + config_patches = { + "comprehensive_padding": pad_output, + "cpu_backend": "triton", + "padding_alignment_bytes": alignment_bytes, + "pad_outputs": True, + "padding_stride_threshold": 0, + "triton.use_block_ptr": True, + } + with config.patch(config_patches): + torch._dynamo.mark_dynamic(input_tensors[0], 2) + compiled = torch.compile(permute_contig) + result, _ = run_and_get_code(compiled, *input_tensors) + + expected_stride = get_padded_stride( + result.shape, alignment_bytes, pad_output, result.dtype.itemsize + ) + self.assertEqual(result.stride(), expected_stride) if __name__ == "__main__": diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 9859ca8a1b132..db62af3616334 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3733,10 +3733,8 @@ def _pad_strides( # do for dynamic shape. # # Skip padding the strides for dynamic shape for now. - if not all( - isinstance(s, (int, sympy.Integer)) - for s in itertools.chain(in_strides, size) - ): + # If outermost dim is dynamic, stride still can be fully static + if not all(isinstance(s, (int, sympy.Integer)) for s in in_strides): return in_strides stride_order = get_stride_order(in_strides) @@ -3751,11 +3749,11 @@ def _pad_strides( for rank, idx in enumerate(fill_order[1:], start=1): prev_idx = fill_order[rank - 1] stride = new_strides[prev_idx] * size[prev_idx] - - if stride > config.padding_stride_threshold and stride % align != 0: - stride = ceildiv(stride, align) * align - padded = True - new_strides[idx] = stride + if isinstance(stride, (int, sympy.Integer)): + if stride > config.padding_stride_threshold and stride % align != 0: + stride = ceildiv(stride, align) * align + padded = True + new_strides[idx] = stride if not padded: # Consider a tensor with shape [256, 1, 5, 5] From 6be6d06295c870c77a6eb69f96b3170d983520d5 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Mon, 11 Aug 2025 09:55:37 +0000 Subject: [PATCH 0298/1424] Avoid potential deadlocks in host allocator (#159352) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Motivation This PR fixes a potential deadlock in the host allocator. When calling `event->record(stream)`, the `record_stream` implementation may acquire the Python GIL. In places such as https://github.com/pytorch/pytorch/blob/842cc77ab9aafd518593c2fce077d6abb42a5b7f/aten/src/ATen/cuda/CachingHostAllocator.cpp#L145-L151, and https://github.com/pytorch/pytorch/blob/842cc77ab9aafd518593c2fce077d6abb42a5b7f/aten/src/ATen/xpu/CachingHostAllocator.cpp#L22-L28 `record_stream` is invoked while holding the allocator lock. To prevent deadlocks, we must ensure the locking order is: **GIL → Allocator Lock**. Reversing the order (**Allocator Lock → GIL**) can cause a deadlock. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159352 Approved by: https://github.com/cyyever, https://github.com/ezyang --- aten/src/ATen/core/CachingHostAllocator.h | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/core/CachingHostAllocator.h b/aten/src/ATen/core/CachingHostAllocator.h index 5049018d731e1..a8f5f2fd79973 100644 --- a/aten/src/ATen/core/CachingHostAllocator.h +++ b/aten/src/ATen/core/CachingHostAllocator.h @@ -251,6 +251,7 @@ struct CachingHostAllocatorImpl { auto* block = reinterpret_cast(ctx); std::optional> events; + ska::flat_hash_set streams; { std::lock_guard g(block->mutex_); block->allocated_ = false; @@ -259,14 +260,19 @@ struct CachingHostAllocatorImpl { } else { events = std::vector(); events->reserve(block->streams_.size()); - for (auto stream : block->streams_) { - record_stream(events, stream); - } - block->event_count_ += events->size(); + block->event_count_ += block->streams_.size(); + // Move out streams to avoid holding the mutex during event recording + streams = std::move(block->streams_); block->streams_.clear(); } } + // Event recording must be done outside the mutex to avoid potential + // deadlocks (e.g., when Python GIL is involved) + for (auto stream : streams) { + record_stream(events, stream); + } + if (!events) { auto index = size_index(block->size_); std::lock_guard g(free_list_[index].mutex_); From 41673110cd7c5960824cc74a6fcaeda1a8bc7a23 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 13 Aug 2025 02:36:19 +0000 Subject: [PATCH 0299/1424] [inductor] Windows inductor use intel-openmp. (#160258) After some debug work, I found PyTorch torch_cpu.dll is using intel-openmp, but not MSVC openmp. So, switch Windows inductor to intel-openmp. It fixed: https://github.com/pytorch/pytorch/blob/c8205cb35435f39d2c26f6c94b45e4adeb6dcb23/test/inductor/test_aot_inductor.py#L2405-L2408 image Pull Request resolved: https://github.com/pytorch/pytorch/pull/160258 Approved by: https://github.com/ezyang --- setup.py | 1 + torch/_inductor/cpp_builder.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 23ef581241396..fc03de4298018 100644 --- a/setup.py +++ b/setup.py @@ -1598,6 +1598,7 @@ def main() -> None: "networkx>=2.5.1", "jinja2", "fsspec>=0.8.5", + 'intel-openmp==2025.1.1 ;platform_system == "Windows" ', # for Windows inductor ] if BUILD_PYTHON_ONLY: install_requires += [f"{LIBTORCH_PKG_NAME}=={TORCH_VERSION}"] diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index c58849f9bf5ac..74f45583ccda0 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -910,8 +910,15 @@ def _get_python_related_args() -> tuple[list[str], list[str]]: str( ( Path(sysconfig.get_path("include", scheme="nt")).parent / "libs" - ).absolute() - ) + ).absolute() # python[ver].lib + ), + str( + ( + Path(sysconfig.get_path("include", scheme="nt")).parent + / "Library" + / "lib" + ).absolute() # install python librarys location, such as intel-openmp + ), ] else: python_lib_path = [sysconfig.get_config_var("LIBDIR")] @@ -1077,11 +1084,10 @@ def _get_openmp_args( libs.append("libiomp5md") perload_icx_libomp_win(cpp_compiler) else: - # /openmp, /openmp:llvm - # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ - # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 cflags.append("openmp") - cflags.append("openmp:experimental") # MSVC CL + cflags.append("openmp:experimental") + libs.append("libiomp5md") # intel-openmp + ldflags.append("nodefaultlib:vcomp") else: if config.is_fbcode(): include_dir_paths.append(build_paths.openmp_include) From 355462e1278d818deb9ef4a184073d5b66074816 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 12 Aug 2025 13:52:59 -0700 Subject: [PATCH 0300/1424] Add stable Tensor get_device_index, use more stable DeviceIndex (#160143) Pull Request resolved: https://github.com/pytorch/pytorch/pull/160143 Approved by: https://github.com/mikaylagawarecki --- .../libtorch_agnostic/csrc/kernel.cpp | 5 +++ torch/csrc/stable/tensor.h | 31 ++++++++++++++++--- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index e3dfc581179ac..8f31a680c6d21 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -36,6 +36,11 @@ Tensor sgd_out_of_place( const bool maximize) { STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); + // these test the get_device() and get_device_index() methods + // while ascertaining that we are still on CPU + STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1"); + STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1"); + int64_t *param_sizes; int64_t *param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h index d02763923a5f8..8d1323c543e66 100644 --- a/torch/csrc/stable/tensor.h +++ b/torch/csrc/stable/tensor.h @@ -1,13 +1,15 @@ #pragma once #include +#include #include +#include #include - namespace torch::stable { -using DeviceIndex = - int8_t; // this is from c10/core/Device.h and can be header only +// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we +// can converge on in this world as DeviceIndex in libtorch is not stable. +using DeviceIndex = int32_t; // The torch::stable::Tensor class is a highlevel C++ wrapper around // the C shim Tensor APIs. We've modeled this class after TensorBase, as custom @@ -103,11 +105,30 @@ class Tensor { return stride; } - DeviceIndex get_device() const { + // This is almost the same API as the one in TensorBase.h, except + // we add a check that the returned device_index is within the + // range of int8_t. + int8_t get_device() const { + int32_t device_index; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(ath_.get(), &device_index)); + STD_TORCH_CHECK( + device_index >= std::numeric_limits::min() && + device_index <= std::numeric_limits::max(), + "Device index is out of range of return type int8_t, please use get_device_index() instead."); + return static_cast(device_index); + } + + // The same as get_device but with two differences: + // 1. it has a more suiting name + // 2. it returns a DeviceIndex, which is int32_t in this world + // that should be more stable than the likely shifting + // DeviceIndex in libtorch (it is int8_t that might become int16_t) + DeviceIndex get_device_index() const { int32_t device_index; TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_index(ath_.get(), &device_index)); - return static_cast(device_index); + return device_index; } bool is_cuda() const { From 2c5e10a5fceb208b11c3d569ae02e348b5893b31 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Tue, 12 Aug 2025 15:59:32 -0700 Subject: [PATCH 0301/1424] Add new function consolidate_safetensors_files_on_every_rank for HF consolidation (#159393) Currently we are only using rank-0 for HF consolidation. But we should be able to use every rank to consolidate the sharded files, which will speed up the consolidation by Nx (where N is the number of ranks). Adding a new method consolidate_safetensors_files_on_every_rank to do this. Differential Revision: [D79000720](https://our.internmc.facebook.com/intern/diff/D79000720/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/159393 Approved by: https://github.com/saumishr ghstack dependencies: #159392 --- .../test_consolidate_hf_safetensors.py | 42 +++- .../checkpoint/_consolidate_hf_safetensors.py | 202 ++++++++++++++---- torch/distributed/checkpoint/hf_storage.py | 8 +- 3 files changed, 209 insertions(+), 43 deletions(-) diff --git a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py index ad74c34c4e2ef..731a2c4d6546e 100644 --- a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py +++ b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py @@ -10,6 +10,7 @@ from torch.distributed.checkpoint._consolidate_hf_safetensors import ( _calculate_max_contiguous_elements, consolidate_safetensors_files, + consolidate_safetensors_files_on_every_rank, ) from torch.distributed.checkpoint._hf_utils import _metadata_fn from torch.distributed.device_mesh import init_device_mesh @@ -87,7 +88,11 @@ def test_consolidate_to_one_file(self) -> None: global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) if self.rank == 0: - consolidate_safetensors_files(checkpoint_dir, output_dir) + consolidate_safetensors_files( + checkpoint_dir, + output_dir, + fqn_to_index_mapping={"dtensor": 1, "dtensor_col": 1}, + ) file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors") loaded_dict = safetensors.torch.load_file(file_path) @@ -224,6 +229,41 @@ def test_calculate_max_contiguous_elements_valid_cases(self) -> None: result, 3 ) # Only 3 elements (width of sub-tensor) can be written contiguously + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(2) + def test_consolidate_with_two_ranks(self): + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + import safetensors + + checkpoint_dir = self.temp_dir + output_dir = os.path.join(checkpoint_dir, "consolidated") + os.makedirs(output_dir, exist_ok=True) + + self._create_d_tensors() + + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + + fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2} + consolidate_safetensors_files_on_every_rank( + checkpoint_dir, output_dir, fqn_to_index_mapping=fqn_to_index_mapping + ) + + file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors") + file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors") + + loaded_dict = safetensors.torch.load_file(file1_path) + self.assertEqual(loaded_dict.keys(), {"dtensor"}) + self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) + + loaded_dict_col = safetensors.torch.load_file(file2_path) + self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"}) + self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor)) + + dist.barrier() + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index a0d205f808213..c8eeed784c883 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -13,6 +13,7 @@ from typing import Any, Optional import torch +from torch import distributed as dist from torch.distributed.checkpoint._hf_utils import ( _gen_file_name, _get_dcp_custom_metadata, @@ -130,8 +131,8 @@ def _parse_input_metadata( tensor_size = tensor_info[0] dtype_str = tensor_info[1] for output_data in output_files_data.values(): - # Add this tensor to the output file if it's already assigned there or if we're using a single output file - if fqn in output_data.fqn_data or len(output_files_data) == 1: + # Add this tensor to the output file if it's already assigned there + if fqn in output_data.fqn_data: output_data.fqn_data[fqn] = _FqnData( shape_in_file=tensor_size, dtype_size=torch.finfo(_getdtype(dtype_str)).bits @@ -522,10 +523,48 @@ def _write_overall_metadata_file( json.dump(metadata_to_write, metadata_file, indent=2) +def _consolidate_safetensors_files( + input_dir: str, + output_dir: str, + fqn_to_file_mapping: dict[str, str], + num_threads: int, +) -> dict[str, _OutputFileData]: + output_files_data: dict[str, _OutputFileData] = {} + # Create multiple output files based on the provided mapping + for fqn, filename in fqn_to_file_mapping.items(): + output_path = os.path.join(output_dir, filename) + + if output_path not in output_files_data: + output_files_data[output_path] = _OutputFileData(fqn_data={fqn: _FqnData()}) + else: + output_files_data[output_path].fqn_data[fqn] = _FqnData() + + # Find all safetensors files in the input directory + safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}")) + + # Read metadata from all input files + input_files_data: dict[str, _InputFileData] = {} + for safetensor_file in safetensors_files: + with open(safetensor_file, "rb") as f: + metadata, size = _get_safetensors_file_metadata(f) + input_files_data[safetensor_file] = _InputFileData( + metadata_size=size, metadata=metadata + ) + # Step 1: Parse metadata to determine tensor shapes and types + _parse_input_metadata(input_files_data, output_files_data) + + # Step 2: Write metadata headers to output files + _write_metadata(output_files_data) + # Step 3: Write actual tensor data from input files to output files + _write_data(input_files_data, output_files_data, num_threads) + + return output_files_data + + def consolidate_safetensors_files( input_dir: str, output_dir: str, - fqn_to_index_mapping: Optional[dict[str, int]] = None, + fqn_to_index_mapping: dict[str, int], num_threads: int = 1, ) -> None: """ @@ -554,49 +593,130 @@ def consolidate_safetensors_files( start_time, ) - # Initialize the output file structure - output_files_data: dict[str, _OutputFileData] = {} - if fqn_to_index_mapping is not None: - # Create multiple output files based on the provided mapping - for fqn, index in fqn_to_index_mapping.items(): - # Generate names like "model-00001-of-00005.safetensors" - file_name = _gen_file_name(index, max(fqn_to_index_mapping.values())) - output_path = os.path.join(output_dir, file_name) - - if output_path not in output_files_data: - output_files_data[output_path] = _OutputFileData( - fqn_data={fqn: _FqnData()} - ) - else: - output_files_data[output_path].fqn_data[fqn] = _FqnData() - else: - # If no mapping is provided, create a single output file - file_name = _gen_file_name(1, 1) - output_path = os.path.join(output_dir, file_name) - output_files_data[output_path] = _OutputFileData() + max_index = max(fqn_to_index_mapping.values()) + fqn_to_file_mapping = { + fqn: _gen_file_name(idx, max_index) for fqn, idx in fqn_to_index_mapping.items() + } - # Find all safetensors files in the input directory - safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}")) + output_files_data = _consolidate_safetensors_files( + input_dir, output_dir, fqn_to_file_mapping, num_threads + ) - # Read metadata from all input files - input_files_data: dict[str, _InputFileData] = {} - for safetensor_file in safetensors_files: - with open(safetensor_file, "rb") as f: - metadata, size = _get_safetensors_file_metadata(f) - input_files_data[safetensor_file] = _InputFileData( - metadata_size=size, metadata=metadata + # Step 4: Write overall model.index.safetensors.json file with weight map + _write_overall_metadata_file(output_dir, output_files_data) + + logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time) + + +def consolidate_safetensors_files_on_every_rank( + input_dir: str, + output_dir: str, + fqn_to_index_mapping: dict[str, int], + num_threads: int = 1, + rank: Optional[int] = None, + world_size: Optional[int] = None, +) -> None: + """ + Consolidate sharded safetensors files across multiple ranks, with each rank handling a subset of output files. + + This function distributes the consolidation work by assigning output files to different ranks. + All tensors with the same index in fqn_to_index_mapping are processed by the same rank, + as they belong to the same output file. + + If rank and world_size are not provided, they will be automatically detected from the + distributed environment if available. + + Args: + input_dir: Directory containing sharded safetensors files + output_dir: Directory where consolidated files will be written + fqn_to_index_mapping: Mapping of tensor names to output file indices + num_threads: Number of threads to use for parallel processing on each rank + rank: Current process rank (default: None, will be auto-detected) + world_size: Total number of ranks/processes (default: None, will be auto-detected) + """ + + start_time = time.time() + # Auto-detect rank and world_size if not provided + if rank is None or world_size is None: + if dist.is_available() and dist.is_initialized(): + if rank is None: + rank = dist.get_rank() + if world_size is None: + world_size = dist.get_world_size() + else: + # Default to single process mode if distributed is not initialized + rank = 0 + world_size = 1 + logger.warning( + "Distributed environment not initialized. Running in single process mode." ) - # Step 1: Parse metadata to determine tensor shapes and types - _parse_input_metadata(input_files_data, output_files_data) + start_time = time.time() + logger.info( + "Rank %d/%d: Consolidating safetensors files from %s to %s", + rank, + world_size, + input_dir, + output_dir, + ) - # Step 2: Write metadata headers to output files - _write_metadata(output_files_data) + # Find all unique indices in the mapping + unique_indices = set(fqn_to_index_mapping.values()) - # Step 3: Write actual tensor data from input files to output files - _write_data(input_files_data, output_files_data, num_threads) + # Distribute indices across ranks + indices_for_this_rank = [] + for idx in unique_indices: + # Simple distribution: index % world_size == rank + if idx % world_size == rank: + indices_for_this_rank.append(idx) - # Step 4: Write overall model.index.safetensors.json file with weight map - _write_overall_metadata_file(output_dir, output_files_data) + logger.info( + "Rank %d: Assigned %d output files out of %d total files", + rank, + len(indices_for_this_rank), + len(unique_indices), + ) - logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time) + # Filter the fqn_to_index_mapping to only include tensors for this rank + filtered_mapping = { + fqn: idx + for fqn, idx in fqn_to_index_mapping.items() + if idx in indices_for_this_rank + } + + if not filtered_mapping: + logger.info("Rank %d: No files to process, exiting early", rank) + # Wait for all ranks to complete + if dist.is_available() and dist.is_initialized(): + dist.barrier() + return + + # Convert index mapping to filename mapping + max_index = max(unique_indices) + filtered_filename_mapping = {} + for fqn, idx in filtered_mapping.items(): + filename = _gen_file_name(idx, max_index) + filtered_filename_mapping[fqn] = filename + + # Call the existing consolidation function with the filtered mapping + _consolidate_safetensors_files( + input_dir=input_dir, + output_dir=output_dir, + fqn_to_file_mapping=filtered_filename_mapping, + num_threads=num_threads, + ) + + logger.info( + "Rank %d: Done consolidating. Processed %d unique indices in %.2f secs.", + rank, + len(indices_for_this_rank), + time.time() - start_time, + ) + + # Wait for all ranks to complete + if dist.is_available() and dist.is_initialized(): + logger.info("Rank %d: Waiting for all ranks to complete...", rank) + dist.barrier() + logger.info("Rank %d: All ranks have completed.", rank) + if rank == 0: + logger.info("Total time taken: %.2f secs.", time.time() - start_time) diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 542203ed82cf7..23a4cc1f877ab 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -144,11 +144,17 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: logger.info("Not consolidating sharded checkpoint in finish step.") return if self.save_distributed: + fqn_to_index_mapping: dict[str, int] = ( + self.fqn_to_index_mapping + if self.fqn_to_index_mapping is not None + else dict.fromkeys(metadata.state_dict_metadata.keys(), 1) + ) + return consolidate_safetensors_files( input_dir=str(self.path), output_dir=self.consolidated_output_path, # type: ignore[arg-type] num_threads=self.thread_count_consolidation, - fqn_to_index_mapping=self.fqn_to_index_mapping, + fqn_to_index_mapping=fqn_to_index_mapping, ) # writing a model.index.safetensors.json file with fqn to file mapping From ba47821f524eee50a214ed39fa2e7765d54aabf4 Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Wed, 13 Aug 2025 03:41:21 +0000 Subject: [PATCH 0302/1424] [ROCm] Set thread_work_size to 16 for vectorized elementwise kernels for MI300X (#160444) * thread_work_size of 16 is giving better perf with many workloads for MI300X cherry-pick of https://github.com/ROCm/pytorch/commit/fb81400d34a8fdf301394b8197bef0fbcdb40f00 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160444 Approved by: https://github.com/jeffdaily --- aten/src/ATen/native/cuda/CUDALoops.cuh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index 9b104a7966363..16acbe0b8bf2d 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -226,8 +226,9 @@ C10_LAUNCH_BOUNDS_1(num_threads()) __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { using traits = function_traits; constexpr auto io_size = calc_io_size(); -#ifdef __gfx942__ - constexpr int tws = (io_size >= 2) ? 8 : 16; +#if defined(USE_ROCM) && defined(__gfx942__) + // Similar check in launch_vectorized_kernel() as well. Both should be in sync. + constexpr int tws = 16; #else constexpr int tws = elems_per_thread(); #endif @@ -296,7 +297,8 @@ static inline void launch_vectorized_kernel( int vec_size = memory::can_vectorize_up_to(data); c10::DeviceIndex curDevice = -1; AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice)); - int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread(); + // Similar check in vectorized_elementwise_kernel() as well. Both should be in sync. + int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? 16 : elems_per_thread(); #else using cpp_type = typename function_traits::result_type; const uint16_t max_vec_size = memory::can_vectorize_up_to(data); From d0f9785af34f49825f6cf33e8ef4d6cb111b1e1b Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 12 Aug 2025 19:47:35 -0700 Subject: [PATCH 0303/1424] [CI] Prevent accidental gql_mocks updates by test_trymerge (#160490) As they could not longer be fetched from GitHub, see https://github.com/pytorch/pytorch/issues/160489 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160490 Approved by: https://github.com/huydhn --- .github/scripts/test_trymerge.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index e4a8cb2bc8df1..58f3ca50baa1a 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -70,6 +70,9 @@ def save_mocked_queries(obj: Any) -> None: if key in mocked_queries: return mocked_queries[key] + # TODO: Remove me once https://github.com/pytorch/pytorch/issues/160489 is resolved + raise ValueError(f"Key {key} could not be found in gql_mocks") + try: rc = fallback_function(*args) except HTTPError as err: From 1151b40cbf4c26c6c749cd26a093077fdf15ca34 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 12 Aug 2025 19:47:36 -0700 Subject: [PATCH 0304/1424] [BE] Filter unused mocks (#160492) Somebody checked in twice the number of mocks into the archive Filter them out by running following script ```python import json with open("gql_mocks-orig.json") as f: mocks = json.load(f) keys = list(mocks.keys()) good_shas = {'a32a7ca3a2f6e2c9de07aef821b0111539758b4ac254f8a3432af32314f94876', '157add81c519f614388f3a67e287bdf4fbb1791e6d0bffe312e169d02ac2813f', '4715ed05b382e572135c049664939f22f9b1249bc0c499ae278d655ad8cb598b', 'a91ab398f97fb43cbe6e0899980dad8ff7447457ea5a71bbc59f7702a9280eb5', 'e5130469b5373479776bfbccade8039ce4741b97873bb3bec4e279fed08602be', '5dc32efeb8306f03744f6804ef4b500882f2759f7ac17fdc9f123669bfe4805a', '0a34acb829d8aca9dd28a8ba388dfa52f6ecdde7e903ace1caabdcfaba87de98', '8b50878b010492fe64005cc4b4ed34ac5f6695ce093f06b0d8d5403b7787c2c0', '2877b3b1e8630ca4ae797b9d85d5673d25ca8488c01141e11ff55f4a1359fca7'} for k in keys: if any(sha in k for sha in good_shas): continue del mocks[k] with open("gql_mocks.json","w") as f: json.dump(mocks, f, indent=2) f.write("\n") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160492 Approved by: https://github.com/huydhn ghstack dependencies: #160490 --- .github/scripts/gql_mocks.json.gz | Bin 692987 -> 281579 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/.github/scripts/gql_mocks.json.gz b/.github/scripts/gql_mocks.json.gz index 07628227a18a8c78e981d04941f1de74869afafa..1974b2d06ec14ec5aedf6f8c382b17bba43921de 100644 GIT binary patch literal 281579 zcmZ6y1ymftwl$1vfZ*=#1P|`+?(Q}~aJK-#-Q8V-yThQt-QC^gC%Nyv@2-EXnKgCh zRCQO+u6@p~={|(vFff)6=@ua1<_W=;(EDQ0P!LzZRI9cD=mbE6*W(Er5}N5S4-=qI4ib znOi$7Zfz}VEqnHTk2d6aevpx4kgsF8D79237AEmJXQH+dYk>(p(n>mVh*so(KH4UfkL<@CDc%{impYW~ zcmiC$h(yiM1&1(T^8K6~s^7W{48oi&lGg857v#bjm^lRwU5ds9qwv2UOb-nv4q+^7 zSdx6OM+!e4#|{+IzxnmGO9_>r*LK6^GZ}S`c(a6MOd0wgtVAQ0;lNIq*uOhq$nPC1 z0PwD#a~YdH43jDV-hF&_$E6lfHG`VigCnQblhHF|h#ouA>&3q@zjW-*vLPb4EtOtz zwg-jGfCrZPq6=nw<9oy_KNUFECF#@225iv#HzBK@rSPR!6Wymc_kUl71qMQXF+%zlXN zw*0{SgwtZ?e=XK)$%DV0LD7B6!Iv>EMk-H_>h-!@{PFNS*ch^`DX1m3Zo|pTbm-rc z{-wb*y|D#-Hmpusp?*-<7T#65#Q(E>GbnuK*hxMsumoBXD4Ax8icJl>emK7DNVR5H zua=P6+jcS7}F>)D08wQpcsoFyM0I1IyQ}lnJfed+oMp$pbrkWMF z&o<+7Q0JaqIG=(G+}jmXhmP7UE>fp4V9S}2N*jg`Qo9r7l`~z=$d^5F>$)Xv#sblv z%GijuL0{P=5nLmkvl~qtZwM_xcMoKp0!xA?K?Wc0wmsf z+PxaXr$>{)fPM!o-L@Es3OL~i3c$u}-p3f?gETd7>9#?~ObJ?Eu?!BYtJfjY8id?; znJ|=f_tmixH> zo&wM3_O7Lc!Lk=UYALMX7u6TK6eALmuWp8D5Xf{twk6eDt4=%ous`YeuVky$*bel9R^k>(3G_?53oL@#0etg*fb@pSj<~wOO%9fQV}!@5 zjx!)M5HVn5EGF9O(pJps#Y-BP0dw;{uiW57KsP0)_RSOpe zsu`~BZ+5V4uG|@ONCDOSeP@IK3c+4%^9-egMcPV@5GUgt4(qqFtEP9xET41O>>yga zAZ5K7GOyn7dcHi)!s>QvTwRK-^Zqh5(S6FcxZ<@kfRs?#`(6~;A$X@SZ*VqWov!eo z1G&^M;c}A)wQA$HTKt`^t1-Z={|yIN`rJ2R}<6{VKJ#b(P<%p9@O*4 z5d#k6nywzVQ)7V*?7DNICriC$8MdF2=O|=!0==Q%IA;6YEu{STbM5utD_pnd04VTQ zuo(biY1EZmKu>isTH3@PaLFTV+G0sflXL7|%goH)aInbt-UPkJn_QE(?XPEIPq>h% z5nZ#kyQ7L+9#03+Lio+Z=NYY)@%(;i$8LWO4uBQol~=F~#&96rXz(6o2^>aoLE~hQ zVb(`IUjKy_0%^Kx*U7-pY^-SXBg9(J4sOK|EvG88Pfw}QWU})eCk6_NPt_=r#@Uv` zLvo3_0vGM24$2y6tq-B3m%@p$wz^ucM?QrD|sZ;us+#6#M*KCJv zCgt9rd_{gUGfP#=n}jLNj`%ylHwtr-O3-Yr25Z=a1bT?Badi=uAAZVRDhiTIs)!b5 zy+9m9`x8tg8tOJ6==c|_>U|I?jkAQpxVoQsQ}?taWzF4=HcD-)5|W?1wm3&ep3~W^ zIa~Gn%io$8T9R`r9v=kEX0JAUK}DA&_%u&Z5pvAt{&hJPm4wW_fWi5%`hAeUeWHFy z{)gnQ$^1Q0%||>s$PN!rsO=Bs)?J<7d;>bA*FUbKou5#_uv!qFP&*@1NJV$W;qt<8QV6AnZhzT+)5JfzisAp${t9cT<;LD;;jn;&dP!KVpIT60v*5Cx_xek! z!d9_M!u~H<=?VObkqzTb&07sQr)cxC!0eL2<=Up^p8A|qu=(+^>=K9-;=D%2s=si_ zhW)OBN>9~ZozhYGTz(Y$@+!i%Y1%p+S_&?($NSY*TeDGBH2(m*rYx&*p?+;bpAW0k z)o9}KHA43ehhHi4HBnRggdq5HX#R!9ip|mYzi3&!i;zDfE&*vKsoO$*gl{Oz{v6y1 zn|&^Wyp5^XUskt}YlAac%Aa6~9muN$0_tshpn(_x8~d-nXb@{kc5IsNXn|FZ0-qyk zzWXV*M((f|pVC0Hvijzwfor^R`$_>Z90gWBz0KWZRXN%%AgrIUHmBtYbxx;7l+W}p zbc^2a{oanC%gUMS#!4 z(^2FT>i)`2tZOHOZ*LKXrSLn%o!3UacXoIvT)4x2aF7EskSUM7K^G-x)<~~mdm;4| z7W!waF7O7;4cb@7<;mwH7eC6r#d3vY(jPU7)Ijs&YjjPeESC(+>i{Je8d7CX|IL*lC?=ceXq=0gUYXws) zRu%4V9_*jTsu3gqa(QQd#iY`88*tTJ{H4ImpIYQixbeHGbx#mBk90i6wq^bD16)Ev zxqn(8h0w7bG89L#4Q#=gv~x+f7R=D{wwJNo5<`V{=COi*XCK{48Q}#}@hQnO zf~<3i$Oc7$PufhzUs|%x<>Gjo+3ms3CFqKD<(4AO9j4E#nWYr{PhYVOm!|1UHq+uM zfK!<;+hbK*BCwiz1{HclFh_6R5jvNjN; z;NMa~ZcGDjlnlyLUMGasxj9=as^Nl_I)k9$&uxpcC^BcWHf8o*hkfajnLX^5N~+gP0i zrv#e_t!)5=&9OBlOJQIVNEIsP3LZcovQrqprxENk6n=d*3tp)P-DqrDUSTtZ?%8no z{e};zUIDsM>6`Nz8$bO}E9!w4%&p{)5=(-Y>Dbm?P+rP5d%yoA=z?{PsmfmWq zaWVzBnV@iP#ly9p8pr)d(10z4&DT9OhJL}j#!%zwXCW`SrMNc_-dThWE+hQ;I{+*- zwRW^~{za5VV6z5kTe4@B5FV9ww43qF_Pn3%g6q%(bsMBh{PcEN&-&mQ*C(46$J=;~ z5K`DHeKS5y`V@p_UL*W0Z8EA!YVB}y7Q1K-@Y24MkoGZZZ7|D|*ECfHmQj~R?T)ZI zRfQ(L_R0WIo5z)9pxk@{zNueuaFZ$vNVV^a^1mP4!fgcy%D!nCjMel7E2r`(LUoe` zIP9{mulR{;a1o1I;H2gmu87n>DK;87?|)&}bV2*qnsOSnx|Fr1rOYTgC3-}VHLrlu zzqB3r)W)!{V#O?K#sJ&O3r+l~hh-NB1$5MA5g5<%8&=sRRSF9`25B}_jeWJ}?4nZ| zR922;Qv}r6E{V_V@y{wugSDwml*{={r4Zxerln?JAvUIH7Qh$d0PD_$bBy`XlU0)TH7hZjo0|P_%|uGA2F2;jaAfEh6>`6d?>Z0i z9|j-Fb|~7?H>YNsHR4k#P^+l=n=RquQH|!J!!wl|O&9@&IyP*JHt=WOa!Xt_Ywedz z);+e!s&45?9Y04tBL?2OY%5(J4q%}|-_uTR>3e&+le!4`I+kKR{<_3JQ&og*eYFbk zW0IOX#5LY7tYmP7y@MVB>~Ii4AmFH*~*P@$7;}j z)clZ&5=EPqhyF4RVDZ(K?J=r>{!!qQSGYJL+OVCnW0iF5vj1Zu@YOde9sIwp6dM>` zIj1O(k%)PfgX>0jH#-LjXMLUHk~6Abh*CZi;3K-JICf61vo^fh!AP9nN zT9p~dM5FR6`rk!x=8je_W1u){q`eC98LR(75}lU@8??<}i?!+@@&+jVw8PZDNUQ<;NeT`e`H6?EPWu9ht2-)8lFzjMGibU>=<@$O^9G!>m zjwX{*Wy!^Zxs!B~Lbc>>J5yC)GorWN%M)f*Bl-rh_$l}eni%BD;jBgr?j~(M#k#=S zL_#WDyKO{aMJCdx8b-XJC>&GSv7JL|q5lVul3CeR;|XDZ;y-k#S?1i@?v#n7ffHM6 z1qrEA?UnS1%l~UxwQIhWXZ_t~cu(f4^7DDV?azd$KXiliqV(rA7@zzLXu`tfW!T+I zz?&pxmjqB894*xmkc1d|Y!K>|ArM>H!g8?BOeT@~d3G2dVdbN;9@ore^8@mKXw> zKO*MZ;L6*|{RKEpN*#gkM>~ue=1b_3rgM>(I~3hAWBLq|dLf zNCVZ*&bX?DhHhf?(#Z5_5#EXl)GB&b&>Nzir2hW;IVx#Z?SBZn{;~?Cy&C!4_Qzr6 z7TCHBW0SK6WwhiQ%Apg94pI(E7uH~YFzd3%sv?&-K zugS&9S8@r3LE|?Rh)0_e?rvlwn9jak=?Af?xxHEB(9-JsvM_LBa|}1SzQh;nGhY!L zunx;z4wY=xpZ_MjEShI5?@vQf`jD!Nt5eqCcpLkPqV++O_PI53pX=SQ%Je;Y+*Xyp z(T{`014{jNy#|p5G!tQ5_ zL2-4?27y&Y%19#cQg7;@=SqUd`aquIPw63BcBhZfEZn5M74|7mI7%0N$u}cHmmB=0 zKQ+pFfz`*HU;T3nIJxNq2tcXj!>W7ynOgWiR>!AR`K+4#$wv^M$-DQ9SWvw4y0d#H z*}pa$5?90pNWoQ;X^x+x7zpP zb5HAWx}ru<5dbS=mflz&nf#Ajr=47jy1KWTBM<@e)z`X8exBKL_?u5T>2wDQ&}Cfp+?7UWu@XFRBUhL`9(5Tpeor6h;J9*N<_%}rAt6AFbYsmb&@cGL8?OZidAS0y2$Uc- zcH;XI55XtZS|l8GR^!<>n|lN)qrEwh`fp@(&K`?9lpC^LQou?|}zWGTR* z$kijnEtd9$5h;R+jiDe`0vl&W#nlY0!3{BjlbONP#V;B{&&bV?^P^6`mW%)r^+?wA zQKo~NboNXU>Vp88E9cy=W}`f`ed$)FvK#0XXXt@Liyc<~$aX(nFJ6w<**IsenFLkK zF}`=%_Ha=)RxQpf!ly~wpY$`oMkYjKtV)8-i*-fAo)nXS#^?P`styaA1c^kpJ8|If0qg%h*)A<^1-)KZFM=DA3TSoPFHzvfnS>j2m zpV&SpI?APo!6BC2?%lRhhkfs!`^_F17JylIV~}ns1IBXT9%HimqhUMB6$@NLjHX^LVZqwUgfGFRo_!}Tm1szYLW{EP;Q;a>}&QDvs- zSXAuPKRBPZSTp6`enI!%hseUz{>7|4o}N}I0PCY8sfU5Yn0rHA!lD(`R}62zd3+Zq z-2m8xVo4Y#^VnCZ2eEo@O_z2Siwfms8V99p(v03?aQmR0c1hWb3!@s!tQc;;TM54# z`J$3LLSFMLIm7(NQpUq9M;5)83luzt8+U^7ZX^Ueb*~;GkKY(>IWMRR3sN%(|3Yg& z6Q6o}IN=~sU%0o>Tnx1fxSRe!u9T=h^;nHBzZ@DxO@{KJ513aKF2ij&yhbaLs98;w z*pY};CPcENJw~4qEA!Noz2ykV`i`+j&Tqt=1>q%=dV>OSi|8;?uxBWy;w8kyKK7pj zu$3^~tB?SU$_u$YCLi#hXx-A&mOZxL8Sf)GEIp!LcV&NcLJ|nPv2oLqr5zw8FM<(j z8z7ML&0sTMQe4xR!=8jLgKr_qoT$;oR(ERK_lG|K)_FYN-}~zG9|&E#u>+PO!g-I7 zMNit|wT@a)1c8e^x=(28>_qql@d}pF)XaS%wgXopRxn%kChB(}-g^M#KD4|$ZepS0d{P=G@s@b* zpr?Wh?8iP}HPT+r2yqe;2c%wdcuUBVi<%FhCPKZvqZ+g5BPVxn^|R z!7t}k{W4(UsN>%mx#hZ*zC|M9Omjv=&TAZn@8yD=Y6hbqcw(pWLJd1F&+Sm)!hL&(8 z{yq9gVf-t5h{l#qA(Jna{?@?nX}w2Lw9+gl0}9fIJQJpzZ1{t~xV!gLR_-yg;^Gpn zZdSv5PRQsT6B^2299~KhN2nTKQysV@&5bSg2AB{h5T@`7{K-feQ%jpthqt|N1>kkf zg?|xw%<_;#DZ$tkRtCrG)u@Z!J)aw9N)WnCX6fzm@@Lwyr;uA3eQUoPt)-o~aUqi3 zue`7ujWz$^#3xU#yXCK^sS<(wxoS`N!@-hX9_+kRj;`z@ zyP(TLGmzxW4`L%Ij!fxnyMxv0ctuiZeR$HW2a}G4PIe0=^dP6HGEb6|@o2o|8>Yv` zh5Z3Mk))3vWHTa0<~7>=WwE0d0wOeU4F_mZ(Mwn{F+`_65rU(4-RuDr4)@&wgh+zo z6W$ixzYi6Z8$B&@)H&7Q;zBN9jNhKFCB9Q_rqBH1n683(G(C@0h}`%5DFa7%jw3CrzPp!$lS??yF4R6O^@=>%$it zFINL1^50?Nu))RRen2l28fRKLvouuFLSw2Kk2arAriHFtd3%mOE^agsSp-jvh|j+o`8ER{F+oL~I9XDY z0n%W=hAY)_1N{^`M-aikFpds#AAB5}BRt$G&kKc!f5DHktTOv!sMQfE zZ4Z!~}$f6qjC>^;ULMyvllY}re zLPB>Fb0LTxvdR|L#!D4i*{W%I)cTew(O|fxSUgP|@(e1R$wEEiUy&Gqj!C($fDeL4 z<*gjMMrUqyi6mN|*;{1PB$~z1NU#j!q24U^Ct(@X-N z@Xih-hYkz(*wg|yLoBgKwYUEzY$4kCX2&#NdGU;=m3izYU9*2pXkU#`HiILMJJDpF z_VT?owW+i3?E?TanjNuV@hac(!$}cg=62QX{%UE`*2XWN&+tze zw0CV^`ZIQ{zJ25EJA)U4e@S)>Px$8UkQ72|B!d=Hf;z%U)t67PE|EugB*3k4z*`P4 zOSxR0O%}ox`l666gig7-gf$Pzx5lA|8w-)B*|0TQdhNTdc-R7$kRoR`aKvz1nW3_- z5&~|-ct|V7@U89}qRD!z%Ub*6lk&SKAt}NYnhTC^Y~N_hjLoj>1m?Q(cjv1JogchUvV3m5iL+(JGt?D(G88$G8`JK@#z6? zjlKKkn<8f)BAF;Q;hk5)jyQo1?gp=y(mZ#;$^5`e6hf{knb#ZW-DX1q8z|gKCy!EN zyL?G58i--9IioL3RYE#w1H{Axp#^E+;W-j(z6icak+Gq@*eOlYK9BpiM9MJQ8ZXzD zpQ&Uj&_P9mTK$2huzxJx?90V$xFCo>3X6uCo=ZhNr+C7b6z5T=V1R@-tSTWcrx&e%(e#1lQ`$EsW9y z2wTpgq>`1JZ~&E+=7BnOiQ*=G_ot{%x|SHoz{dnQP<`9KXI%;6H6_TC)Bj9Y%xKK8&U8;{fM6V> zXa?cP4B6OsNL-mlD~!CR1W)ry1z;HXZ5$yhgsUm~Kuk*gfjSskUQOcC(mPkuvY7U!L89=Q}Bvj#yJKXY| zXbA8OCy+a%-upyk8uo72Zkl@edaA|2c2aa*wC==WKCN80O)xwXy zF!vng()*$aZFSAeUf#B)L=y2 zROv=G8U!4AkN6Qo>q2oRJ`-3Ol1QkWs;^jIgO;hUMY6>`eo_{ttdRS`RTnyv_=oQt z%FvcH-+ZkgaEo0(!q=m&|EJZGd(BTb!K@tZiZrk#qq}yuO8vHoZU%2jD@hQTCC)kJ zCT5HQ%VidX^0&4c`;GCRf=5MZrLwTfwk*FgSO$m&b6E+Z+oi{%el}ha?@{(Is2eH7 z27kL2tXd^3i*H%OVjpjh5|D|mtj_$5N6bL`!J2D8b+jNGRad~ zV+!Oz`q7``ml`iX0X{AV}{#*uP}~-NmAQLPc54KNfSoG`s#cA&0WG>g0&xW;fa@o)nc{DR>aPxDr%DrvIs z==z`%hC!mBB4IR)5IgoC@D_en%ERJiV*R3RSmb_1D{7^h^5@Z1SxKSXFl(9WWG0kW z<+UFGc_c~}s|CWSuREs>J&@~1ZJx2L3Sy`qsBjPYAK7OtYCx?F@SP~QsX_27j5OxaaTLlLUKP4E(iFclZc9EoF^X?sQ zDH-HvaRvWub(jyCCc?&tvF-+-M--v&0EWjLQt0m!4)_4xu4Zy}>LB$W+GhZdjbS4~ zQlD@hLTd3AG~mSkMC&l_K;L3I8s!(QSb;O>B(>a*t#mYEYa)tmvY2vxgYW^q=v)M> z;&9DVH}Pvql_86C!-#O7t)BSt}6{<(Cc ztb<~PLeMw2htmoEM^jdGhhh5QnM#p0fthIm=S3;FXQ(DD-&0qt-4js?v#oyQjZv!} z?X>9Tp0QOvV!FEnK{%zHTbmJWIlHUu;SxZ`5>JJl!ZQ+D4HFRiykLyc<+U}qfZOm+ z)Xp=OVi}(wDAyM$b^MS;bP^FQi`Ku{W`(bUZBStTuH-KQ&&a|k&zfbBJ4*E6jz)fl zp(QXcD z3cPzYb#G_F2C(*e;8%2-MVJcqtvV-mC5HJk;C@em_N3D?F=c`N8TchB(Z$Sg8Imh- z4kg1JcoLcoTbJzdKdX8v`%qkd@@Xg7wC;W|2bjOAliG6#t`%BE?%@YHFt z<}~()rZhK^N~1ErHQ6I^rhs3E93k^ebH^J;WztV+t%v~L#gufIjwInmthp1E^4m7K z61Zli0mF*l5~ws1it^s46-}U9r(!wA=aMH|x3NR_=u%99>0FcX)9GlZ10Io>IA9WZ zecYJ_$dxS~HDuuRAoq`s3rcply37r)3J+POF%zqTB_v(H3tDHH`FqYDwqM*Qn=zrS zU~VdQ+>Hmm2Is_j$xtp^^ZLd^$|64ZKc$F=8nf>om0q`%EcovQdW374OQR28m#Xo7 z2dr{LbPed6!~qJiZd60h&=G6{mVQ~Sht>JG(-TAVjUE`OlBY}ae#uJ@P%5^bq#`FJ zeAVQd*pT!nJ{(L@&htxQ5%($c18rQ!_gA7wll`4vQvwV@hmCgNi^H1MrL6 zW&2w^Qq~b&EMfm}G-zK+9v; z&Wc<-eqt6F${0b42oRzpn`@(BU&s^?f6$rJCo&}bp z1j@h4dn&*p1_M>kL>E(_gS8yFc-#8(lQ>oNA~*$2-=ujuc&Jo_kWzk-B}Q+4VV(Ih z>B87Bj^Icm7$Q>QBsd`M4z_w7VQ$QGG1F+-Nwcq`Zz8dsHl=lEJ&t{|8D<0;?Szxb zt|E~B-s)gkVZ`xm$*+NK2m_*7cHrGfcJ3Jw9HjxA#+Be`CtIn4w%}lHC)`;pXvEr8 zi+JGRozg4)g6hx+mmul9TP@z*0>P@72%W%nyyIb92=IW zG11)fIo9FIE@1=uI4{%y8?=x6VJ1dC9b7L))b9nN;ouI!>_8vlVrO&{u^KI0%hadS zlp@oM${30DXK-=5vDihGto1Bi3V)ws_ewWiYvsfRx`$&fG@! zG$g2-n%>-Eqv51406M2^^jfi>NmjpZb)p!S{cBa@CmkSf@HthDCLxSr@TIN0uvqF};#p%sf z5F^1444i@*tuuRW;-dUHoG_Wl>g^1hCsdP|Q6H9CkIc}=S65vwCaITDjZ~D=J+pb~ zEb9GyaO8{uDlm1!$LifIe3F>G&`6#wGi)Bv_E?9#uSj;~GK2(rKpwZ-5N}ML+gz|YG&DNf_u=Oc5aOI^h6A9Gq7 zAoPyQG{0@<7`>|T00C2z?tZW|{S4Vl+=2T^lirs5Zv54UUS+J&YZ{^aG3rWV4Z><;Nl#P#f_A@yOlNiXMc$q6)gW&T6kM@9R|_D=>XeFC zNVFi2Xnay|CO2af!!H?@;t2-Kf-`k%zTkj#>_je~X(;3}V+{*vQFi`KZL6wvf2|$I z0io06klIjfh2}wHtWGg-5!G|QawCcNB9b4&b(bVqIK*gI+R$>Qt4M8|fQ($-nM1|^ zm?R;!?`U*3O0M|%N01s1%`Typ0z5WJ34#9_D#I*FOt||ZO|b@TB|JwC8u}dc(8x^Y z>vuUHulwi733yIem9>6UtS3n(_>Ops6c1%9>}Ywp%k%yCcppXGtGR`vaCydT-tTfA z+)qC;@MJGNm>D^TPA+BxlrTLFu8Xw{-HtL1OSGe-2=SvRuQ10-_`pv#Ycw1}Z}1kG zm2vmaE(@Bu&K2|(9hlPxg~wbteVlUaMwo0?_`0$3!(rF!zV%FesbTQ*>>KjE`n-8S z^?r>R;lMwh_1lT8^?YaK){zc#s86PBFp&6Gk2zBDcZLC}hlxA(U~$+_0^?svf$agi zuwRgU2Jg>Ja(Ydya*nwpb@Vw&VjA|2$y1|L*WMJSA;aO7l?J;*8HQt}$a zN0l&Q#4U+Jiui_AcR3&EE2Mm@Kc%TRWb*FMV9)%anmmv78c&Gzs$Y`7p6IoID#*t_ zd;tCQN|fE@{xr5EcD+PQyPeu<6X_OgKGH;&MSmm6!ykG8eYFa5WmAbcITXFU8_p4N z6Oa_RBR;6V)%W@FHffo`GttGXxS=L7YI+=-oJr~E!CVA>8hR5*o6V`tZ{Yi_+DBw* zPJXPbf>mTKF~siGZ%Lhhi?*^RF$+ucNu(metkBf;uA{j@sF4|;mi{JDoKx4*>`qp_ zRV!8bDmgKFQA*#b0lzM607#~E!&rVZtp2=1%QC9I;fr3MNL=YSb30F8=YH zxItbud*6pq+3?e=T`&&`Q5Yo+ID2w^VN93f2;;mNmPPJ6|Ff@FXTAxaWP2=+55%;$ zs8D`|zMIExpRvv9}(a%(sm1?Z9;)XiMw~*)qnVW)%Z(s7kDY^?RfjyfH;*NVTAuRDs z*7E)zpLb-sRy0)KuYRB50!v#%ac#Hq{K$f#23j!^U&BnfZg||=f5Vnb?6W0>&AO<* zNe#dnj>VTP<)1+N+e@n!TF`ioQ86U7)PC_i%NKkJHR27`HiPjN77wIo`d<6s^vCj# zhx!xT!1Iqp2lIBd(kCt$ndB0+Xo8Y$%MnKoPN-6Nw^BzKgEdpCA9DzS=Icz)e2n+u zkfkb`TB7V-*V33bgU61d4h7dnFR!`wH!2vr?&jTVpw(>auu5i4H3G_6z-O8@%^Ah&rEe=@bW`dmzaHhwnQ9wgUpm)p;Ipnz6-d9uiCQ$pFf7wIP{1zjL^SL0NMmPGi7@mb!HkL&|4FlI!C#Ss-#Tmrt-jDlC1*m z8cw)*jf|Y=K}~O6l23iOLAY0m_dBP_!sZ{aIJ=NLxj^w7AQ~Jnu52{s#WdADgrjaU z7qN=`X}+goZqF)8JPjZ-Pfo&^V#pYEg$ikv<&{{C^~!W7r`uI^Yq0|nXF4i{o+@pt z*thzPoaBEUe8RfJq}n5yyrt?)T@MC)daJu*LU;4l(e8(o;V?GR$8LhO8{Pj{u73C7yh zV-ejFiM0~gL<9s%-kG&8eGTg7W<=!1Rb z;EJD(FzVyKU%2+KUFLTM)<^Pkoi7`{ulpU5yMmUU3d;lujnIW>|JxL!wZ$L#<@Nvn z`d<(0#_9Q3nWO*e{a?*DD%-Z{BtExahWn+uMDO1(xK|YmhZHkC@21^!5(>Pm_m`fyuKoaClSlnpk}u*n7&fAPJwBFkyF%r%`E*lL{?@v%P|y|BpPE|`hLXxa>U-EY zXzWe?r?#c{_)nL)qoaFVrn`Kq@e1EVsS}+X*yL@&*>Wc?7YP~@G%>3GK}r1QEir%d9bUgnA^8mg-ZmyC%@<3)`lpUrdE0QMAqhs5$L}9jeBa@wewKO! z^(0|mckfL@@ChaGNvN7V0>}7-y0p!R@g?|x#-0UkTtZlaL)l`Mx zdTMR#WH>>EtteRY6Pqc=dR zb{9yzi%dB2c;5g{eyCnq?!gcieq*>0cOj$eg}fvCR%q@>oiQ9OH0cmTC=#KSHB>=V zfi*Y>YJ&o@`gQgOoi#3A`KWkCJHWTFfzX~xlV8#h8DPl6`%OBa$?b_3BXq1{d4#t2hb5dd+;5JT77bN%A&N725q>F&T1sf>= z$_xS>chE;h43j!UKfLe^g%5;2)Mp9Jj3~{XJ(juUwJ=qe9|AjZ(Gq{Y_p1Zyp&ml5 z&NS`XU}qxb=nW(}8b$?q=F~n$ldzn(zDEMA8Op|eD!!^kG!l&3mrH`LUu;&$0wj_l zd7KIis`eykyu-tpcBk2NrG6b%>%eX(KyaWo7uVk^`D9gQz95KJS>+=|zWJShY!%f# z0P1HS&sxu2Av$)rRQxK?SM!i-ON~Wy2>m(l3uQ@hxW?`{zg5Z!e(-d}jQn8iAiwqT z*z!C=%I8ce(8Q7D7K*XfGSWJo8yd6^iQs^upFis#(q>06@K& zGsY(CuO-GA>}5yhER4@iHN3;aM0{kOVLK`>=E84h+=OZs9f;-13dutd)PF}3ru9oa z|8@X+y~i&b%)_jE)Ca|g<*!Q+)H;q;$col`3zOe#oJ3o}WCgpMZe^bTs55w!%vov- z-ZfywU5!51sYsecTWan>;@GnpSzxVbDdu03@};@W@Y+cy-GM^jcZ;F53@@HJTBJrrc93@y4oB4Ar1vNZ8lsvQ-y^wc zpwprV>T(k8*VRg=VYJK!i!6%dP(3GnNkSAUY@~p_cGix+`@YC3c?Ee5fTp~@-Yt%z z*jcTth{F%8*0J2}I-0bCznZ_HUfpNk3usSmugC zM9|19(9rv{>g_D*mLn}4<{C$vg}V~z{1|(jy5YdCQu<#D7|}asy>}P!x8WST(SVb4 zPWjs0M;=)p&VfmG(GUC>KNa?Oy{U^-ST7EUNUU4Ejo&0SuDU(E8#{`-sFLIdStlqH zWr$hYV-U{_m)eegb_IjIz6Y{`wbv0{#JVmN3pblpQFz@T%!PZ7+KV|_p1Dlqg<1Wm zKKgKkWPSOMZ_;HXt(jp^>K2aWXo?}^p{hY!Y*=_?;L@oB-3;S|pf^X5Sd%O?9oyqO(l{aX( z{#Wx3StPI?OF6oTxdt#xxSB`Kqm&XtNWI`Ao~B5P z7|2ukL-_+MXMvWIZ#KNNW)ly?RY`W6$7E`G&4C&_arqL_4DaFf*ev*5+MgO+$V zQp!d*4$!t?qXzi~H6_2a#uA$4b8?=H+Bd@9FSxeTcC1r0iVF*oObDkp%aK(mIM$qa zX`ErDU@;w~u4z47Y%^ZJ9ySuJS*jgGyk1;OXJCmhb;=J1f;2i|^ae!fECO3>l|rBm zhEoMW)G{t0!7z9&rFHtu{Nx2_;X6k5Y`vRrTm)>8>ODF$aEEHjHHt!8`Vqi}*lj~N z8;#v(YX$$@(1rXWV)|(+S~OE>oH1{HEZS&_B1JcIl`@Kjq>G)_NefZCbr2)Ucc*FM zGWwJIIyko7Hb5^fb3C|MFEcRxktfSzP+YVLpuFyAh;-14NW!#%*JIKzJW9Np8{jxn z5sEx`mI$v&W3T_*A*a+J)5~1=eD-IhI`L@cVt3(aLb(`UKPU9Y|4uep6a@oR5c?pV zd6(kUIo7m=B&L548FYZdsr!%|FVcY90F+v(I2pBPBfKd)TemZEWS|Xm_j---TPFD(c|i#)axGl^<(^LZzZ{!aRZN+6cD91qp#Qb{QQ2S{rRqke~j4q_;`GMf`lKL zai(61xcEk{tzYHZMt6!3juJuHeU?9k>n?fvo%X?BmHBeols{cX(SXuNV)>D=c^Mb`;J-dtepq`dPCTYhWV=HFm^o=Q;Uhmhg{ zPS^YJ;$tIeo~C;eUF|f$66uZ0WAtbf;|%}$TddM!2axo~)99%BDeqLbZ{F>2zwrxT z0zdhI#g@iIB4`qy;^#OOCg?ZaQD)t9u?W#B`P8FlP*|CTJF@!BDU@FL%x=u89o(MX zf@#^inaRn^=(l;}{P1G27~PkYr-_eiuJU^^MG^tXb>E-e+AaB2u3VUsbWxbWe33Ge zsXLO#O7P{n4u&&+-)3rs31H6C6Ma_2j~g$e>)mJ`b~@C~d{f9zv^Q#^(>-X?MtG-b zVp=eZ1kLbp-suQ}kCXLSrNJOwU4g4zus;Slt&W+DB$XW=?b&f7 zc^dF8udK5>_s2j#W`l0)IdF9E2b0?`Zsh{~l)5pkmvEwmvEA#pZo6}!q14$X%=KSR zPo~l{fRy84%M?!1IcH^_=ha^up)BUza`3h2lRwLZz;LMAz*zx2Q&lBOjMeu#kgN)g zgcgfwa}7jNZZl%Swn|#2*m`FzU{(b-ytCOmMckqd1>MM{If{K|SvV zMTPmb3t@AgC7ESr5=$Qaeg9Lko-35bhHJBRN#6qnwDTXwL28sOC&8%tJQ2G>Bg82c zLc1Xn<7GbDrPwUB;Nn_?v{zNi{|B%@Pro}F^?vdt4r8x>I-WI$0bJs^!gi{T#IN>z zg#il6jxqpi69c&R780PK>?i~97BPTD#g-BPDci>YtWT>MV8PQW3h>QFgYVjuIaLuf zS9@LoUsHAveAXuL!8+R#e4y+g_`FTvgZ{Q9_(0h~@Xa=XZ*Ox1Q&4sge9=|KOE_|eHJ@&n9T1n1ET^H~@3Xt}=mW853Zk-|d(HDDRI6uKD&~*L8GV zj>45Em+D@u3C$}aFk5irorLk7!D|`m#JKR1sWDBB)A)-t zBKQX5r3~nIdM1q*ufI2bi2)Avo1ab%oNIW`%(#fpW28@K@Lv?gfi%wGXPlwrT{@Km zJz?0i8qY;M92#MqjqzFyFEQj~0@Sl)8V^h3EO~yLek&?#k;(X$4>^w$)|T8 zkDfh$^UCNMN6>wg4u|PQoSYgUgh3X>LulTgO<{gA*j3Ku@X`op5??18dgpN>v&%hoT=;soa3AP8~U;<(^n9dZh z7Hb>EqXbuQhF7H3;7Aj*ML8dc>#{{}WQMcZWq4Lm9(ZGBeEaI{^H;aeJq6fp2mTWk zO5KjwDsTU`Z1qFl7j6^bB#1XhOy%F3d_bWT6IjQw5 z5UdK!5L6RT%s{RhDhC!+77z56$FGi_8Kh?#KcU9<1azx6(wl(BlA<8?)0yT46#21M z1X0_A$NqupH=3^B8$af%;`heSLy6iI@CfMqNUJ2b1pW0$nE@~5m;W}7q&IpNe(_*9 zA$;P`POEN zp*yHRu$D7b`@7*xtv??pAAmE`WMI6|Rs$FrQ=|97I6O;761^Xwl^!DES4WE6I+2*7%c0zVK|7=YzUfC;CwNdhdEq)G1&IY|xq zHAYP(lm-Pe0kq;?!P?U_9cG8fXW8LwJWRdtME!CyjD1jKDxXGuNd0w^Q;g(v5(k-j z{>A%uKfZbW;>dU$q{$TK^~u1{^8vp{!#J3#<#cmH$~GDR)oh4wGBv#M7&N1WS?Krs z##`@P8WJ?Hg4T#f<8-2I17-W*S_D%HM-N<_a_AjiaHxefuKPW|e zQi?QB+5_RVOr9X;a98Z0jO=>O#P;{26jAiGz#FBLOU;OnvnTmOeM37X<4HC(B6)!x z66HL60~?fah%UjYR-g1Qm?4+?YF~;^jZn%ld;|Lrx`c~;J<`6?484_Svv8X&>^S)` z_XKIaoXmz2l{;7pmdCg+4!q}uL`F;)BbyU|(bvoYJW(Yatl zn0vGSi}oKjy8Uf*7M?|r>*}&LQ=Yh{5LLw*cE_36{x@R%Poz4pOD8cvaySARag$;H z0V$AUAJaqnq^_Y~$8fLjFM!{)rJu>6)R7<8S^1XdgpyI$W#%odeB<$-HQY}M%l3n^ z^mCC=*!if2sH2xfgWtNve7NT^PY-L5&WeNDVZKoLsoJ&MS0?Y0g z1KYo&GRZ%^|IM{Nv(u+{T;h0N;e7viHhYs@tO@5s;Ygbl&G+K9@gpw1kGV~8JYz6r z$FoS1$jHBhwK@a+1+<5r@r#^b2=Ku8LKi>&F-B)c1z!edv*dE-CB2I|Yx={~KeSC3 z;K1@cL=fsy;ffIUl#p@Xac$YF+H{H$kL-{I`xGajo^TC`3fn7Zmc&^S^w(Y#< zdwzUESshyJ(3<;A#rfW z11dPx?s*yV)~(nTT8O7H*xlJ}2a_hr*INz-o7j37z_X%n?qc=U>vYX3-D{VBI|H)(2S0Ls>b@oc zh4g&ogc=t^0rlZrPBh@v3+cut;xy|~))PJ2^Cly!dWi5)DfPV|t&p>uw%-fVi5w00 z-;G`i>jca$ySjFX?Xsg0So8awxv!4Z+baj?4))Iic|ZFCW*GYw0$Pb5(r2}F-=e8) zCBQX9sjG(Hekj@$OT5V?y3<2%-lzI7mgN)|9~|b-ZP@DkZG*30DXa=V-k^(usD3Zw ziPSyqqZdqNxj2(@xIyFjb2qoF48bKvbT>wSNu!3p>LRZ$4;s(1vzohH7M#g&rktBM zb#aZMcbzi1{ylGeH6KJail&W1MqpQ%!Kj$MbCK85lU>rYM zHbaje>F6kLc;Q`QiZ=F+`Pqw^(#OeSmx~RqcA-ZEjl_69<_&YVJapU}8vjvD4t{om z=O`*R;^Hh0&Wt!Sjr8-?^iiy@v|YG zb{r=^rV}0U6(>Atgr`f52{; zBv^3rc#?h-8RB)uv3!%PWx9kp)((!u)j8bq!RNg7+a8El7Qh(h1u&LNU>wf~@yP88 zVW?L*PPVB_QB{7TER5lJ$hY&z=e&?jKE4rvt{5Y`Qa{4xjzdjqJXY%?ukKTHomWr5 zv*ITMEzOvF$lS6QjLlBN8HPHl_{pO9#^T6ATjU3jK$*ykl6nGYo+j}hI&NU>O+iz| z5iS?0SbVE6Ey9$fH^C_b(UzC0n2qs*VQ*yk@b@_k8AYXm>@$xWuQ=MOthKjtuL_3R zF;|@?$}t9peNr)TnpB;5WFS;C5WA1jO8~;;)cC&!O!;Y&W-uqky#8OKoE1W~49lxp z=nJ5Ek5NqPab4UB)jnoN?^ zVE;1SrGP$&UJO%1_6lQ|Xgrr**AFLzTFm*~_-oOV(Hj}#cw7W>>9DS+PeEmH`;=9; zBBnWbx3*&{i1($0KIAUm(I46ZjO*aM$$3 zF}axkZs5UsjQkBN$QuIk;iZuw9#JgA!Ig0%kLrj_8L%D>?65z?;v}dUALg&#w-1*` z^f;$}vSuhI<)&a84Fz3Q|3UkV_sPM3!l)SAepIU5&GP#ppXxT(j#W6eY2P^3j87+C zl0|azAEa-o!IHk{lbsF~TUk49@fyAP2=Hz@O;AA?WcRw z_h5ZzJ;PYG;_?>6gw)2*m_5H5q!;zsnh zpPrrk`u@Y;e|+=)*OUKx|NP|j%cs>uHoPdI2-49ME9> zGq9)pZph@zHIKl;{T+1K*Y8eVydGd)z`6%*;;IKfp1_=5`=sNjCja*O=;Yf^~9}v)FQ5Mu1xTsjTjUq&_CtZF>E`iBY=i_{74^2 z8sp69{TpT`9K}ZLGkcytE76Y zCUdECv-*6GnqEFmle+BYIsJC^r2I83hdDPdr*BeW)o)-?+p;#WVVa=nY2goM8ArqU zwwj*n_>#GfknlUBoYdZH?nKke2LTqI#W&XUaw0HabMi z5kM$U(&R)P;OLulfCli(ijwKDGWlJqdwf9YK-hgDIEC%xcjLw2_&8^-p3V|JKHiDK zr+Z?35S92bg8dR07hkSO(?r?TD?}C^l_hB+dtP$czc?Ejj~&>s-ada))hq`JPyt=V zmhh*y;U5FqJk;$PATllZ|6yjp6dGhe@ZU9^S*B@ipw6t_vx16QR`Lw~hq=qg$66E~ zA7kbHK0iMGYMIB!(Ae?uUiPx$%^0H&*u208=}_fM12os{Icpl>jtgSsD#W^-lcJrT z(h2}))?dY%nYMAh4r^wE&~-&EtPN-p_Ehz^@%QJ34M1!LHn%O<(r(rV7q|Caz>O1y z+T}vJwj%8nM$}a)G4HvjRpqNg(>%5?mY%Xw*2qBSS`HPr*LuJ{49(CdE*G{LnU#EX zgUDTGmY=JbrBD>*S6QJjVWX?f&Kp+bDxPoo$NYWq$NZJ?$K7dL<;#P*h3@oCw5^iJ zv2Dt3qixNbvR%R{-B#OzUAt1(T5R6otWwunc!pA9R_R)ck8p8~uC@3eY^&GkT8mG* z+^wT)Es!#kDY2`0?+q9&Qsw6N!nXJ0Q=9ig`!In778nBcN!4WuJd8z|r_ZWtA$>iFvoymC9#kwq zF5nMs|IiH%BSCFjNMG7vBs|;mOlb$UD?`hsekfg2QepZb$V@NroX`@!&rHwvNZI=5!zoC|tI{is%=>gqjf`m-EHOtJM$$>5OiLt-%r8lf3T)Au4Oy@1-DE# zk?UT;avWPMP2600+o`^_rhnlQ__OsChVG&84q-8JU@Tw4t3=F$sdHgjI%W{DFc2&R zG3pDCP&^26Xz0V|B5-L43`!5-b$ePBtTS&G81#6ed4qabYJ)j%AK>&Jtg4NhDT@mu z5Db_=xOb~88s%ge%Zqi5lqZ}ysHNOWjl^$Z1h{4OyiwSr)U3U!?Adv?JP22qn>f#w zh(Hx#=GN-mVpH6~0vtPbo-KFaDEGouh3?8Tf+^8>;lZYMZXwQ@8LSI%lr&)Ga<)^Hg+^)a2L3zm<JVaY ze6YHs8FL+i9@Qp#>&jPf80GxVy>*2cAjGnPWg6qHE643P7KYE?;H{hc=t{TUExmQc z2!umj6&Lp{dh6y=-nLDf+B^2vS)}%!vVUjY!XO+=u+5xx1(8h{CpN#0vu@rL1>45i zgR^e21DNJYXWim67U4HK>lUAP@Pu{MS-0>ACv;s7#FYmbGp}*hEk0?ATF$yfQX6%^ z8=ZAvYSwbrUCCbf98XV%^3)41izj z&!$ta#g+`x2PD?D=ir;`_C6KkP}}CB ztOWYO+6G+9F2bvprT6^3V7XsyZcS?@*gIOit)8EX;GhLs<_f;S)aY9tl-%d$n#RB; zVlJzgm-$A7#Oo$SxUQ*V+OD0HrMA{^CgsG%bAs*m#p{Ea=A8&Zgu8a#V5JR>NqPFNd>z$gnRy|&56m#-klZe3y)U@7*%X}&fVth5kBA` zaIo5eMfRApr~4tqwX`F2yPQ4U3Imf`;Ar=nvzPZ}TENH7{pRfDLlKlS>*jo?d(7D@ z25M5VK_!9v%Y!S%DNJQkeZ$9ns=wvs~_BVRIR=$RVX-q9R zotg`Ohk+qSZ?EuqrYB%Oc8Rn=^$U6Jz2+`*_KHDpu#Y#+*(->^V6(ROWGpphyKJR+ zz)o`Z$_`9at1EN%%4aN58ArbOh>EsfpR-p!=<490wK;pmlMbQ6u9ve{l5&e&pR-py z>SBD`J?89P2_!Gl^&G9~^^-3#sBQT%m3x=RRAgY%0Tq40VS>(5{mT>DP?%8f%9O7S?%tciO-`ZrZh2CK46UMDA^4pdg0rN?p z;qETKZLxiuF-mWj--fR64X4C0xu7=(YBVI7apBE?tt?7z*7_Toxwr79j?Fp8Jz(FN zdkcbx+omvacJC>3Z(%5=>)6z~ZOAp&D+s(Ga7@SZ10IHaO?TId+*W4pExo+W!TR6q z^k`0G5)d-CIz5^t#z)-%)RC2DRnC{_pUr6c;#8p;vsk#bctJ9>S0hbZ?_NPg@H4D@?cFp`$pC%QwoH_2T zPm^?GSYwV?&C3m^NjWujO~P+^nxr|CnI^Nh-iB{HP0HDy64Ts4=xZP>$F!)&Y%4T9 zd-XKorqy;Qr%8Gc4rPpQSo*OfVw{LurKp!p375H74?A{qnl$e~VE$3{Y0~@|%c^uP zEI#5;``Xi_`GYD+>$=mV{7H*Y)2w-#G$-Zk+R)eXQI}gLyW7%_3%7z+`WSe{j(5kU zzA=7;^g*G3*Z*IhE`>`yGI4nO{|1%{9`6@&dI3AatLzJ@^26{3|L*sBd8Xu}_ zG7cSL3)73-BD@f^f1AT{afxd(HK;?ax8x5%6aPPZ@80INjpPgezn_BZIcMYST51Bk z>m2V+$95*p#Fw$1naPav=zt_Bu|<&zFWs%)XW!5M7C=(BrqmKashd_2T0k#AEn5jX}wOACOTZ-XCyJw|0jWt17#L zSG!}%_DXkrrd1m46tugwQ;KHDWv7(YPFb?aIyIY5i60IbV8!(y2lnS(#IDRpE*Klu z7ksZ|+_Ep0hReRFyZe&2C)O2E#oJ{^RTWhBMcdvN_Lo>M6jk3oBH)uO=8?YI3&-Ii zGyFIof~j)*u7Pe?TZDx9;Rq>ZHw@^8;eoB=%CcuF9`D95@NH)Zsw+;5A<6qOVxRXu z8{enem*N-P$I6A1WEY@v*gyP3PRSwJGJRMlG-WXs=5j17qG^27DsNtTei6;m{FXj; z9o_OI=Rf3Ig8h^MkT{64l9ld$_Gu|E%o5-+S*r3@`us_ujutzurF|C|O%r!3FCT+OM$Oe2u<-9ew>e`}#ElXZGvY z_g}w$;6Ge`{rcChU;n;?t15FDuGwL{v&X8eqnbVb`W5eDj1Y^et=VY}Jt@|< zYhS;j2RhFx>K)$s)VA|ZizJv=m)%Aut-Fl}vObWNKCC2$svIwdr)&pQKIWBnYG#<{ zT(HUZLh+F^%=5MsuJ3*%cFM=EFe{kC(Mw_LJpHJ>mM`i;l8mIrXWh?3aB7 z+gyHF4xeZQ%C8Dj=1~%2#~tj{ukOt%Wp8CeRWUy99eRtlSP@fmKP%u!7T&Tex{CFu zZIG0GMY7E9Onb7{=(3;4*e7)O(X7#BM^z;6QStlfnMR8@K&GW?4~I8DQ5$vf8Z1wE zpZkrTyoI+={s>vB+CYtQSH>>UWzR*XgRyr%E$^YV=0pdnyc5XaC8zg3NaaNEsb&cu zq;giO$PY8ikf7gI$@HI7?Ea?zHqAoXz#z-3n~OVgk5*4d(?hGNLd!zpEtW&JJt zsi~^EUxCoavZxn*HIb_9ad+<%HJ_HR!Orp8c(q>n29~U`!g?Jr02j(%VHd?WP?y10 zRZlfMR?u2$Sk~5q;)+h4Feq-(7u_<}gW|UKp^jft(Yvp`~*R93*ChcV?8L2 zkAOruOvpNE5OYUj*&W0P~R{?ud~$1quoi&TLNcS)39c+$_6V{0jl>)w$ZbT zL+=#?*hq2SS_pYAndmGf(7JcyvRJu-RsP`6^4fnw<8as?}IW@1)};o;rSmX<46c~gt0xuTO?G+b&ET{keV{PB1f4KEZM%MD;# z>cU}Ou52NVI+SWIR`IORG;*N7}w4rIX#OgJkOS(=U;zi4=c zy^dJ%;Ogs37SAXwKBV&clEowXt^!f351Pi-as}&88c09rqTy9iJx}gxx?I6}D<;3r zQm9NlRZNcBc!{*uKAJu$Zl*)d9`h(Gb#QU9MdRH0NSq1-xVZ$aO%Awp7_N z_6UeSY8}wBE1Ibq4;LqTyz-i5KaITJ%2PcEU3L_l$R}S1wCJm*n2O%oTk4jtq3dXm zL-MD0z_@q=&<4DkQR@#F7s?-GE5C|Y)VRd9VymueV^#M&Ob%XaPEiN6?1Yxrx7@D| zXh~$qsHoOnyAEi<%0h;$Y+j<)NFC5+1K6mz(PZNpRjE8+Tzy0Bo*PE81dSh4QEgq-9rBSvC2F_Y*b#mi>egd*W9qFZznGDH{*N z`-l;HWk+=t8!e75D_#RW)KNR5%ZfJuBn+Nnkr-B zZPDI4)|yj9mzAABmWlV=;ad_ZdU*$3?dYGqg~Nu%LlEz#^ToKvc;2D86viKblH-$h;`d1x@<|L8#|-R_O9oSAFk()PFK*K z?@MaxKvw%0rXkOKIpVRWrmzEM*)*-9^!n<-;O5y!ML~CyAJ^SG-O1VNEPNUrb{^R4 zNx<_W=PMxz@4q$;XbH&55ElsQTy@bu2kQO>nlQpx6g6@i~ zdN$P+o0`;9?Mk_NK;?G#gtUBl3(4l*C!_@t7+sbuctV=9QWVfp@)1r*7wk=q-DX;O zLb`m$+E#vJ@rZ?uZI}CQJgCdNtf%FZrYvLSuoKcHscxa8osceCLD;x=Gvobjvt1V@ zh7ke$wiVpKC#~j=2eLI#z_^eNmAkITE3yXdD1BYgHBCYFi>z@juzqVgvW{J3EpH1t zjI@UjrjA==E$@mUtBUpTV*iLm){1^2c8UCx-XmM~l@-?B!+?6kf^bDgO-z2jh;oiRg{=L3mfURNVhCzJN-0Y(%>p_Mqi~oo&BV+`P()jL zU>OScW{&al$j4)EI&-!+&GX6xUcPLh-phlPj};aGa*UTB-Ks7zHuo-8npR|1-ALf& z%NFXrJc|$1jdkSNjD{BaLq7!d7LuUPQeWAYu&H5EAMTh#=Alck#kZ#7zMw#WN1M5Y>LrUi}uo!eSkUSEA+xqtXmp7l@U;ZRr ze*XOC(--Oc%eOzhdCg|R*JnF2KXco`VY#0xFMSciT#jSv4V?&v<~C*h`}%C4n8phU zj28k8#uTOCNpd-)Q}C1LO4R6ZLXK$9JVa1(6N!UT4wRGPa7& za@ybAI&{Xp_?Y+!G|eqyumqP~U@D=JVPJzGJ)cH^0jL`$5)4kTCF3Ohw(1J$yrU}5 zl=WgVrP3c{uw#)t7`$N^7ymrxpL~w8*rDZy)C*YG-N9fubZ&1E8Vqcb_$voi zXH4Sa0ruk%oK8PY_{aZQd=4FBoT51YRQk_v!G@J@bqoB*Tf;-9?0U$eh*u+6$y?sx zZ4^Vro?w~N;Mylq)urnBTWg4I;1q9-inq=UkwHEQUyZ3V`5u0M&bkOj6sX#Ne<7{? zl7aO#3ISG_r;QIu-paq_Bje)Lk;Hr^%TPUT=1pL7CpdniO>X_rD z($(=Yh(V7vUb@aai(DqfMD_e4SQ zqPAMTDMF@g*t%NxP5CcFXO``JUxQz+R=lse>wWVD;Z?DwypQr7YCtYwA|dknhf{DV z1`&+21QxO3ECsRQA6=6mql3ZsS(>p!wKZv9KA$pACmj}H2E!2Cm?2x}s+e$X#@pop z8S)A~V2*k6aq%&AR$->VN82Btx8VE_2E~G&L-mQ9jqHK#7=4;Kuk!LI-@)ShZ^h)i z=JswrI-gTmE2sBo?ee>g5p&GNh=B{BU5{Q%RhLIPkJ?zJdj~14e&RSrkV0TfXeec_n(UXf6!Ets56eB4{ttw>i7HL zO;%;DFY@im|Hm94 zu;Zw8ya`1DqA)YofNL!zxm(dmUz5`)ChZg;kk8R7h2A6OWId z-u(6Lho3(WUw?S@=Qp2*UoQXj@_#))44`lKXxL(VrR#$F)`Ov!{3+v)1f5)c`8d_Da_w+#9mMuMIq>nID%hU4ZF1g$cN46nxn1;+}2HN zb6ghPEUr&yTN}j5p(^%pU@ah1wgVHxn!siQ`!0(8kq_=@-rHoK#&x_Y*Za>2U6tj- z=LEYBd7X;C>|l`4vk>K#}mHMD}y+NIm!|Vu32)#`uBs-}3)#XKb?K(^XaN7mt{-yd&pu^+BZVDr&m= zB)_uhkgE1eQhsxvBr)wVWb^9t+Yj#tlCG8;EP z2;n0}M6I-iuxx01r0O3lB5JKGhJn=Om@N%ATP5QOIvfHs* zcPY{88|boy(Zf>l6l1#9UPCd|JznF9M?|eEe^ZeWs~apD5#(U zG2b>GRJWIm*mzPm(cb3UI;kR?>K<8x8$`H|XBT{>_1j{#76bt-d>m*r%rQ%UF zLpStCt)sfofx3z;wqqb_w-ArGx@9W-#KV*C6%X7)<#shoaPjh%tm$aaJA9Tz;IU$% zOB_p9I#w*6O4V$Z;F3KThAjoQme26}YpV)uEg#XeZPwDpgJynh$kqZ|izg9?i(`)x0egv6$ceYHc2^j%wKzUcCS5?%}NU6r$Wgrr4_TS4_i@8z`k^nx;)%gjCCM2xzN= z*P2t*Q7t=R8azFCKO?UsvJBM}Dh07%rKw0kN2#N_Y(P^hHpMoc(b)~1Rc7Au5mVby zN^A8&E5DI#tC_cW5+mcFGtE^}MZZ@StNKyRFz?-jdHgiNit?$|1@rMqZyG`aMIC5H zU)I?zV8^?WJ8uUT0Z<8MzUOtBvaBy|q-9}Yo~qIV5`}oK=9;o?A4!i$W{ zzOV}j)-N*N+Lye&BV1&>=*yzQWqFZtWnYeTk@2DzmZ`5_WL)&(G#43{YD2^7#zn?5 zVTS|ODK0Wzbiy$3`bEZkEHq2}BIAX6l?;c)?q1_#)$iu^}V*2p1VI zxbtLy#xF81*<)Rm5AS^CMaGBU)Pf(nsRf_(rWQOI)GCxkR}8it_;CHZ$E!kF_7mLMJ;Kx) zQByXJJ+Xbfa+zgcbyMHt#-Jy9IJ0;SOIFN1&X}IwDOvdj7$L>jC-dRSA1^JGKh_PE zImV)UZpfp^FdZ<~RAeZIcJNwriYk;vCsg>5IWhZHp)84zsTe|gZcA2R=HSgOp+>4u zE*mh7>f@!AXOL`dxs7@Gh_W^2dgDP&*Uj&NU@xs5j=cT%PNtA_XrBw zPKrue4XNdBVEvP#f|;gSCN>W~DJrQA6)!@*Yn~Jpl8Izi zq3u{FMMWblXPljl3)fCI49n3h7aNAFV0!<$?uTaENiLTyI-!B_wEyL@1(Bg(RWK-` zUj)AN4T50<8s->gNCK-zFcr(GC3=m&<+)zp&~jY_pshtTzN!%q)$jpk5qa2xY|TaPlg!_FQ@$_oY{F zrRTiCyewz2N_`8W@o}wYrJaqEizFM3Sp84Ho4Mkw#KH5j4(PWfyPp$Q^M$_`aCsH@ z&V(1FkQm97gc&HD<{ebV$X6;B&Et#%VfB1{W83JFJN7#pVoWX48dk zB82|nCrL&NhPMf1hD2GFk~=3OkF;XB$V7mB>~h1tNW3~ z;1m;36;2k%&;aNOoAMkZw!z1BRsue@#5<8NpPPaFiiv?2lNY>73?nB&5=oK6xG-wI z?nUruEPecx>?;kJs~BydS>RuYNJGL3&#?N}DGdbV;njzqe!Beb!>7wHA9ib)z2w4z zr4uQ#bsV`Uxi>8rSs2rRS2sQcL@Wvn4jGAIv*wT&!W5Hb%#vl9_Y(WkuydGRxY3t*sfBhQ1j`(L&AM`Iv?qW)p zNPOcYmdF%5-jvz0yrVV+1h}ToElc)nm`!bFH2#339cNWHZK!mvWYJ#|*3ib{@`c&8sZQYjs`w+^jArZ%P9i))sJH z`_do(97oqAod6TQg1Ek$uQ$BUQxZ>DXW>PN1iq^S&3#AXq9amnK{MYTS|Hb8+KyPM z!~YNI|659Cl)Bk0Ux>Ip^YPPOq7SYENWb|7nFEXmiYg-dlm>whqj_BOb{~O9GSYO@ zQmo@xB}?|=nhV4|VWJ%6l_^>EIeG(9@Iv}yIiAwLtq2ZGhQHBV;BH9BBodq+@V@!{ zU=>hvoA!o&17gdz4Y)+*cFdZT{Tm2N5U_gu5-U5)*f17MG$ohJU$A|I^d*tTJ zFTcEzp8p8b2BU8W(%<;~2jx8E0?bR)eaB~BF8nW8-9Xn*99E%|7GlnJ3ZjJAe&DC` z!pK@0Rb_9XCp1hlUg8d3Wz8hdr&PYBS*8ewO&V^AOb zk_bB8d@Yr|7X9X$xuZ!Gu)7}Eh!u8VF;>~kUAU7UKS^wX0d5hDR$t1I{MWtb9F|!* z4+RZm)TZU8>E3KkTXx+LI^MJ`*bA>zFf*g&1bxBhBzGXVY^+R%yeK4}+DvFk%%1NT zKht3lj7UEj^A*YFQ)=_kui^Lb!P?B7Mj+NSr0MWl8S*tS*ZX4p8bNIE6o&MzDBCv&}K)%@;Pz0op+Sf|Y@W28#K+Hdtr_j1Ae+ka8q#P_h@* z2E{%mtT1$CLkSj$qHL~Rl#7K`&SNXa9}J|a`U_TaH~*1c4HAxMz5{iC?v*8zpO<{O zWWt0hN=&al^nzFT_0w-wDk!zl+Ddzk*>MT1cB7cMbZNzM z(;%VHGqwtT{PbK!KP3gV(}3~;L!zWvNj%!2Y|^Ei&ylD+W0=<4VdUKK01?H7fR z{cqU$GW+Sj_QX{IXn9Rq1EB4S6QOA)QY`a0o7_wG!e}lW1yw9e7ls=A$9sB$ah<$xZ5C{AP8EjWf z6%vP&pPuX~8SH>jJ<^8Cl6~XuWlAQ+)(MR1ii`QLV6(z%aGL|qk9WBPm+O8S^4jO$ zu?sBnn=3f8HM3`b!7!Nqmz9C4wMqXU+0@YyW}5vt8?8(B-BdN!7i0~vIzUPXRgJD- zL)YbFxx^)VVab@1z@zC?k_(KOTxixn6|yv+s9YGTqG^V4ob$CMdr`T#1S4|9R)}Ft zd5%QPa=+La4G&@j&;0d=&)}`!1b#@9e*YkN!Qj+jA(2Jk&E@uKAgaQD8a^7LIeYQV zuetO`XPkxejD)?L{@3uAC}TUPxd*zK_a8r%&M3cuDc^x*TP`qx#xVK6DR*a>^va)mw=Z~<6Z3~(NNbnE)71HPad%8b$E%y2FprBh`-zPXwDz10 z4o^G+ZNqo}VMa(ENIk!qk=^DC*4YbbMq`hu*EqW7+uwNy@^@#KATy@dB=Ft5^`zX_ z%g>D#$6M*Pp(6L5T79KrGjh3G0fWU9dH&Eq+IWd?7>E2kHjLnfXl^zVX|3e~;De?; z9w0o>gtU60%!}Z{$Bv@mv(rF^W*%?ybM}pS6ls)iRr9Nl*aj?k6NL%zZ>xUA;9I_* z3aZ`@ohW9v!lyK+S`GGN$Eul~Fe`x_>U?^he0!)!sAzzHcPZ8+%EF3%EFZOGe@~*; ze*bm7^=C6OqHjYls(AfpJj4A}7AJgkp8cZyb?#tm{vrHj{$cjRhY!EJ8_Iuq`SRYF zJ_DQcV{vWQN8tagxPE=@vDI!Ze=dK@|7Cykfqe#I7|(~vn7m~9DAbiTTQM;OzZ@x= zBdeCKtCnJU7<-nDuxi*~AlaFR`@gwF`}_|B&*6z}SvMp3<= zHHJ2`#@YnAfx)A!1v#he2;@twhh`47*#j`!IEfW3!*r~JXinLAMDqy@KUdz8*rv&= zU`=Q3GRn>)dVlrR+B|_3WXh-(+??{<;Kr?o z+fwA3a0BJJ!L76%Zj5B59^8!b+~8JQ4>vaI8Qz@o+~C$)4>vovs|R=Gf#L>XH{kZN z(BUK6`mGX$J54nK2g;5BUO3i8MNA{#ACt?PO0Dy_RIFoDt%GV#*?m-d$<)-_39YH) zT6>#0W%r@2bg0!GHEi32z9JiHt@SylJV*3#m(VxsoonQj=ZL=2CG@R&ws}r@j_9i$ zLLb3@_0VUO=Y;-h*d$uHILK_f1fY5wtvThnN&xPX1Yo7s(P&P2t`eYhNdgp9&#~u} zXCwhz_UKw~1t>;M8<0_+5&kY^$zB};Up4Bgf1vC>_~owNj{TpJ)`S#%fBE*OpFdS) zV{DKeLkd(@m!XyvsFejHHVC^a1-Nl3cvM}WBCW0z0Oi?AfzmNufEBZ*6aeMfN`cxj zDNt}tqmWUan-nBLbhCalP8XH~#i|uzn^T^r2;jy=;BmEpYS!EyWR&M70$cV6x}aoz zV|Zpg*lq1jZQJJ5wl%eF+qS2+ZQHiZscpCCdB1}{=U;L~dneh+&bn9b1;g5WOhi;@ z-Hp5TYH-8V&)ncSLu&*wH^Paad)B44XCk*Qi%DbG8bq;uN^W7)Ra}yhqG7c?Lp2Lx zvgMmhSG5e&Fh$iFecx@;w#`|oIlVqbNq}&IxBJ~unXjYD&Yn}0-`RX1rjd8R;Mcsf zUE5LLZ!VMg^2ole>1VY2XY*P1>7-Z9+w*69TAp1h`BQqCLq4B*Z1>B(c7m`6=Svss z9~@0k74bedCrB<)Ld;|*FWeqh@yPMTH zuZ~R^N^?M=>A5K$MFqguLhMU7i4^3VO^=doMr>k|E)TWT$@cm%$4MDJP z{h4d&;qxDUS_@dgIzPunLJKh4$vNVSEg6UB`dgk88=kD5OfAA2e904A+f}N|U;s+v zk&i(f4XMdyA#Aq1D^)G7yj3(&zIKSN&{oAy>K7CJXTs7GPJJ5@g`^SXt@uQ1a&>g= zHOP4TEo!mLk`?^-923u89l25=~F3g(gFWB-;lXZeRsyuD4 zYCK~%pqgWikM`d2r6=SBb%rq0Xy!N7QYXArp#3e%D;#b$Ruk`K!1J<~d0U|T^Kr9A znu~apfz8bv;^;tt{FYIHUMsyo{6ESxfn2r!&MDo}t+Srn3?3z$LHQ03LtlaKHdtc$ z{XW79)N_atvg&WG^R*w?NxX;auWeJw@7}Wt+jRxhD11v(qi)XdIS{Vf zKod@AAkA7rd542E-V9W;s!&aUT%NBn+R_uFD1mspIkZLxos+BMxuf?tSST;dPP6O*>!{50?1h=i&8Uln+!|U5UIQhewp_W{F!gD{rJfusD()36WtqR7 zp(gw=7my^sthgS9SKZ7V7x3E#smMxw&br1IVdgCdQKZaj_<0=`DY6PHW%(u;A3_Pt z(hP*Mdxpiaj90FWqVJNoa$0zQoX;(!1Rxh*jjyIDCFM8P5Q|JObr#oPBBb+(C%>cu~lN-NO zseUh1wwV2A?8yb)C&Oo=C7g4)$6ad4a<8Or;&F@$zsQ*}O;CH3@5s!iu>Q&zJ&+Wj zGY|U2q(@+Chw3yMpEGe`>9-A*;TzClN&4Ox1Bxp!k2#(Uqpej zm(;FX>e{}LbQNDTZ|w-^sPSYl;14?3Bucrna$~X>$+W40f5uROb#UY5C2r`RiuM)x zQ$jzjmnw%)h9EKXok`6c(qqo(a|RNU+OIGa+_Yu3-s-rK=)E4nf8)KH@ZTtc4mL_l zH<cQmt1i-uGeQ` z?poX;2(ln&hVd@}dCy?^X^+*lsUm;Kr{Vf@+5Wa#Bl53?=EJU4GBX6>|A zFM~BXg2?dsierE zti_ZVZm6*`#U~&kQSA>cpoSEm0BGnUld5hxjYJZohG&W@22T#k*7#Y@ZACn=$yzQ)d+5o32Wi6S0BgUaRxR#UV52FX_Ph%~^)u^5 zBjTpgzS&1|X!+C#oIsn|QWCIp9@6L6YP!rrPB@9^ogu&9ddWTp+~dyw=?v~EU(tTS zI{KZz`|X*(jGQ;-C<~jbm{SSX)Tn=o4lSUm0}c3$|N6VC!EF-*PwfBUp{1IY0BOPX zZ`p=XP$}Hu<0<7K`P|eIw1$1{P+R;mN{)m4j^F~ejgM`t&&v@*#Le^Hcp@Gqho z`%doTNiJ}Iw1-3H5g8!XwC07ODmL!G@?|Ziosgvcqzo71%bl^6`64Yxd2i8o;2=pr z1oh;g*p~_a!&FNJch z#Oykx%zRI(XCOEKGfx?ANIYVkHN9Pr8I?SvoY-zSD5XtZCT7>z$rM3bPef)Q1WJTH zvzA+|feRzo<))bR&gf-k$NnKuj*V zjRZpk4muo#75?)*9~2ZK=p=;@)=^kaDhI5lC?EN-x;b+PNh9+%gqG-YR*BJH##1}i z#O6)1P*cK+I2*ZQ`FltRYs(5L?|fDG_$>I@rbaArod^~1;@1N#|Z=ry403n@3yB= zass9cC5AvD-%AK|^&y9mBg)uv#7uvJ=067g)^htXw+>kc&L>9~ zC>}&i?{{``&*Rb8*+zyjTUHO{<)=pkg-1Je1KN> zi6v<%K8zbd5d{5-^wgC(NeZSr{{!GXz{m!35SNnPmaPR|&L@=dW=|XX(O+I0N2JlG z#~G|gY53UWgRdcoC3B{)#dDHkL*wVmd?|_rQB4W*v$7JYv0%qC^LHU>u_xZ)35BK! zwZnXVWUGIC#b-cYoMc#J%}>u<@(6#~zr8=@mK3Hkta_=mMhhvM#cq18w>XVi2pv#9 zc-M!lgTrp4Wi65t5`7{$(TwVQjR$_5`U-_7g`b6RySDYWm=9-3F#<_K@xi8z8BUIZ zy%LHdDZ{k_uCG)$4+=un(~0tj{CX=d0|rM16}u;uS`XEwWEUaqVU^?6n}e-4ATBi~ zJr;;gv!dkWtic;LWFNJ1|GUS1uqkrsPB7vkqg20rNf`vLH+TQSh_01i3fvwwl)Sa| z(tx9Q0@~!qMP?*eiE2!XvJ>zMc{eEkz1_U#3)4p6#6OPQUykwnFd}WByMRu$+y zxFi-2g(L5wKJCx*OLVXw63)7D*4Idu&Yf~3O`TODjo`&&H-T;dI5~R@2dN}D3D-#L zWs}03Tu*@mPcIIMj`ixxNW=gsnPEO2KTy}7z-vT~1A*4~HMFfUU(u59-A}^Xag{mp zu4f7v`SwFv^q^06GDB|l<6Nq0)e75T@+nmK2&~dp<+?*!VaB=vham{3N1RMPAzJzq z<=gj~Ay&Tn(Qi_8Uai9p2e1zG9gCt1!$f|FdVuh^B#_uYtNyQq1PswLRZ@rrZe(uX zMi)fsO8P#LlZ(SarOGSgf?kDZQqba-Jg!4zzQ}_-v z2nHXs?><1n=Z~p}&+9SA>+*YN>!8%9@h2wgQ=eMxn2xKj)}_twjSSuK3YH!}ob1He zF$|Xg!8ah5G$%nlA3Z~4W}plc!4DB)$6dW?TkAqTdgjCqN4s&relkAy*CKHzMhk2| z)kyN7ZJu{JhBKJP7boACMZnVnQxM*=tfNLN!ynI!SPPVKKipgo>5kS*?ui!)*vBy_ zfFlTxSSJPJ{eyOROW;h}$pyHO*$gHvCPmF0Nb(R!@*p*j^Vmwqh+dC5)<2K5?UrNF zzsuY!7Zf`7>ws;FAU_nP&NI?XTewg zQ>CoTHb2o=`E@+q^69xd8w@ ztsaiYb6wH9tv4-Bf~P(O>^t?*>#E)AvEHya!4g- z`PqFND6wFJRFozfR0WO~JI01YvHW6+iQkAOEqf&oUdSIrUyW%`LW4NxM*UZrl}Ih| zm{Au0-cjb4!U%po(?ieJFjz)F?4^dluH>2U0&s4rkz<-aI|+>Ucfn9F45=}{lSDN3 zO5F-59{2tZS(h4%h&UmOlZF>2N4{ARNbdC_Yx{iC9E7PSfE$t3Dl;O)@5RD4exfiV zZMjAuBTto?6Yp2c7dRZhU~`t4jV?rnBrqehIRA)!Y5ipHk>f>>xZy@ji`O}1J6uv2 zYWtpk#Ye5#n<;$Yh|bb^K4M2mz`Ck05LfAn1| zC9D-}@6CDabyb;a7gg#_%GOX}vp*9!6TaTALv_~!SKLk;neor2T)?EGqKzgTg~3S( zgdJtexsiFc+~tp#1_R@NfF?Kq>L{>GZ?U=Pz6>1{SB;=He3lJGsa9(t#~ z@BhW|MCrC%jV1l%2If$@O-mT5qST;lA(qv+v&O{gA12L0hU|)znES-1a7ZC7sQ2JY zQ{@$B#fZ{=qY7VCzXdusC2jF@bfU;0e&4*(p0jQzCd%H`G{pK<=%4)ptUN2`a z03i+L!9RKoKBAljX^J(Femj!tugB2$i3~VLmtL=B-qOfY1QYrcN93tHeOeYfnSmH+ zSTH@34@t^7O9m`HetmT{PC?}Kg&wo3Iu(1%-e+t%aunSSEt(%`R1DN>)g3F$P?c*P z(@2E;VEwIP4Cn;GWZ(_zt?!uOJS@QWd6fo*jtN#SZ(B!;X_ABdlOZY+apmZJKWhA* z4l$rAETqt?$b2`PVu2p(3@#uG?siaH;BXo*z=*brS9)m9g%XqThO7hmaMO#jwKSog zI46Nb%QVthEG|LIme^*DEgQjr*_#8lE+X^eSHeC?NU={w=}6{5~4 zfHVi}<{Vv5D5MWDKR$+n@{>-G7_pmaih`{NjsG>pqHw@Vfvn8(%mWVjsyVmN1$6_*G=Df-Y43<4I}}|ni5_(< zsJkA)@~c3E)yZDQJ4^$=05Jv^Qy1HasXW`PkQSNNasQFnz66|i)&cB0z0C7K6jAUL z%qei>6o@rBL%*`dlq_915aFW&8QUAa7GsaOuQ-OiNVkBLu~FA;kwT`BO+Z7 zAO3{LTijm$CG@?|k`*Od8P|+G{F;`Wc2082a17V)ny`O)Io)5FBn?_1F`m4BPlE!S zU4%A2$0ibdU0W8#!2G9)tfeL!|%)y zm&{(I&mbpf8pV?T-A$eBvKC!nj;i3QQ^(1W89I|DN*C3bF|3Hao=!0{9+#NILWV*; zbN=-amsw@xU3vkc zjk9sej9UG(`uGeJt^tc|#(mMTY>)wFaMQHyQ3-p1xL%{M^;aj@Rx&wxwO8$c0Umq) zA=qav#4&BQT+pd$%d244kzU3E1yPJpmUT*1_278p`UP7bMaR={dy(t3OX6SlDDM~I z+bgi0GfStSVmY(#9aPPorT~wW*|l_jDCF8L#*BAm*Xr{T^y!Mq`r|k27q1oGjF(!5TpM>8G~M|L&@fKciU8L0iHz#BG)oAv z$m6k!;){x6b^MVW6Ot;;7833X6QoG^#bvgbwh|+&6borNo~qd0PxtX}IDoSEYEo)q z(xvy*eKP`qpKtp$eh607gm;87>~}bqbR#VkpQ5{iJurJ08AXQNmkRAW%ZcKC|Gk}{ zGz%%bs1xrfIlZg))VVXBYSXSvf-g zJ_Fph`=)I_*Ni!H-Ao@#?Ow*BAcu+K zP+%}c?}xC77sE78{we}G^$gs`qJ)isMx5-ze!o@$Y^(Ig`vm-u4WF~JM25qNWs zm+TGCM2))f`PiYDf6WFcum}L_a6f>i{wz3|HsY@uUV$vCoCiWD;{iK;bame^8VHCF zrMu}LXf_zE3}z_0;#*~MCgC7=)=h{j*Tr=BbJoOgUc(){sJhhmF7nyfl$+}C1T>Ji zU)ME!lE(OOOP5cY(W7T~?dtV>ot?1p}mzQw~Ilw*Z9ByyOYeN7+dM)D~9(?(azodq0g<4W?(q@k*}r=-HU7 zoI~{5CcSr|*_}R1HH+w7EWI}v#?}ux!2+gth4_kjn85Soul)UZ9tF4O<4*pXVV4YnjMMvt z;wtPisEe#0*aGHSBrvGT>sO*`HW$7t&mX~@5cmp`KYvhVmu0~Q?KUi%&-{9zf6c;r_9J)8b!#7zBEqlAI0R2o5}_ zb<6q-Ax~S$6>PSqI9*l9W`en*H0-y@NxrRx=>mFHy`%l-5TLAjE8CI}Y*B@){o?~5 z8t=-wZVh|T;dtvYme}wR&XJS#*ruNuMnOSGH4A%*ZqlF<6!ULT7%wGR2b6o(NU%-Z zt~}|eQ=ja5GL9>6E+K!#sfE)-T>r*b!U_B}vdY4f4qQ(GFho~y7CRt44(>T`|3Ub6 zW)ctRw{+F?uL2&Sd#x-*?bzOU`9X!4DgVL$+`<{H{&CeoLF@n79(*$aQTA#hyv5d|q7W18 z`}ZhVNeJ={h7|7jstk|UeE_Y)+kd1Qa-*XDtRatwAW!SBsWy011+FjK zY+}QQKr3rfJ+dVzK~H8+-X+wM%@k!Vh_Y~4>YD~&vP}IdDONMAUWN+QR&WKoNKr#5 z8(LhlqBay;*WlSq1XNTWyFB<(iII#K(9~!#@;7|-cfPuUEB@Vc3D%!Ln}dvN#zG#3 zy8mgjCE$F_`1&8VfX>v^Rzq?Y^6>ZlT>-7B2s1;brJO~uZ+`N}F_i_7DFET>EB}i!vCzK2 z`%j=%Erd4B8+|T-oq2CN3b+TdxnL}%WI25k+@v(`+L!v(2Y60KwDu?gAIo<3UuWkK`=+h!na1wG{UT00;0$eul8 zI^5>Eakl0qcyp!*;0@5##yQ2f%hr4Mtb$JuK_~=!OuA3M=z#4aY~4yC%(!LR8W}H z&fO^BY$K*TD(DjqBdG`VU%_58v;clmE7Vi?)k~Y&?B9>1mDQ4xQ&^vB?9hNGu9IhQ zlM9S1coih^=4SN@HW)j4%Fa+STaFZ$2`YyFg4DFz(pPiG(*ZwEM3@}U?*w}sNH~-l}9{z3&WLN;gljKaKtNNU-f-9U#BS~5Y+w9NeQC3e070%#&n1&}HlH#tKU@+0p;1n{i~a>1?f7q3H*F8SpCYIC++lYEEn>$#`JR zM;U)u&adB@Cu;U1wBkx*r!u%DqZ^FeGq&g2JJc?o(*bXswiDHE$JSuFeI*F^B#ad= znrj25Joodm8ZzBojY_a9)SS%X61nGjU#rN=1lHLUCMEQ`|4_jFd}RV0(sz8G zA^Ec!w-_j#^b7|!bPB9wCdI)Nldm#Iz)aK#Mr?9Z$PQODHYzB%VvM2BNG`@-T>e8( z#tT}oyh?6r@^rI9P0k*!_hDIBSkhE{0t{X}>n0aTLP44%9j?>F%)wd*Rs^al`?>WO zYBHX%V?yUr4^_Fw6FWG}#5Krz&+qcdGPN8f!8dn>>(}RCld+P=+EtF;RHDv5p)%x= zQn)zeomR+U>NQcah8rSGva2N6&(7y<^lW8hG(-7gL_@B42*y)@+v({_irSaksgs!# z#D<}<@g76ihB*W($wyp8fosu4RnoItW0LfVlu>w|cUoctOT8r4YR^XIE>5=`;1=>d z5hx#xtnGCy{Ptm$)o_B#?%$MWv&Y0S^=RjTV1mvluaaDk1@4@amwj&Z{5@BCWO<_^ z=>U`4@ZVf?JJx7P&H~|QS%MymXzmXaa}+ZWr%sO%T>&NT2OeyWS+8>nxd-h=)Rm;~ zHl~twYgClfkwGe%WWsTr_7Wzo^D=JduM2~Nc>f_rHKc%Ad%qM2pINV9{MUV2pzViX z&x5eI?YL1J{*qY!C=H!83ZMxoGkh=YNI(2ySSok}BHi#)qi3QCeCW`EFVjtkbwe>R zZ{R@m;p@3cA=_$>f`5OmXFByEG%j^-d-p@`Q^)n3z8cA`gCj!$S8D@4BMkYMt1!0I4#J#?uO@Yxhhhu{g=nJH%7 zQ_Y2!x9MMc4LX3?9fuW)y_GPT7n+Xdser}6^|3Y>IG#j<@)r+oi4c@LXpk%*|7VTP z6U%-)fc3sL5HkY7ms^<4Bz2lIH$asSBg`@2_iIR3^8lSM=q}L>ZKTs~C;ABfK{MN% zkio%fP;EU86wy}%-TB1WdCXY-x6QsoOzF|Ff?PIW?wV+tvDjuM>j7@Qm$Y9)ss zuo^!tP7qw8Q6oD6-#UZ2GGHj=4P6vt3LMFL~_rD zzqg+97=L5>XkoI|2x=Kwi`yoWP~_@;y*~NlcYSgaDO>+|iIgoD(k7}Hubuq70q$=3 z%{Dg2RyO@R9Rj`qf8M@+fB4zQ+jBAeuw@5^n;O5SM-*l20VU50`D5spf2q%EUms_J zw1Eu$*i^6igl^C1S~@hh@QZr~U)k^yRun_EY^}d3%2238a>X86YPz78&IobY@=re+L1+# z!@&y?11Y-@_)fMQdY1w!1g^b14r@V!n&bCT_*N+Pe#8_ybJH~*vy$jg-$qO4aZcG# zL1$2DgF&G|t+D1s!@@(TLHxx81bj~(+SSbcVD`!m50>d!wbxn;UlZqHMoE^=^?K8- z8@kSEEYr&Ae@g3Pp_AA*{@D~96A8HhJ<0m54F&{3r>DOJ=9he=8XrC+Fm`{owmpls zJ1ZwMy^vgRn4If^Ys2QMTfQg*EdNv^sA#1Idbb8q$z3ba78NKeZd#N6bf1Zh7`GIB zV?^DJoJ_Z*2Rb^24xcTH;s?7;p|43-`z;O4ObpWD4=2wQE@bRszUprdV+@h8ykcnH z8bOLo5ghB`mc{^%L=2Ioq`amqBns%G|8g^Bqt51ym7!%D*X?D3u&hoyb9cWVJM5CU zH~Q3uxvUV$OPWvPB~$EIHAKE{mn#ykUVQ4uR(YiSxDGd|e@WgWMvHk$6~HRE?p%R; z!}Wy>nwqNidtQ0Bu)aC6x!s*jA6>ZFy(_<*(y<-iT+XWrXX1J_p{_&6)Z#SRzL>ML zx)?byxW~1`QiKy;iK9X9_1@86Zn<-k=Wq5IZ05+)7ZA#ATp}$)Hi6stw>sdSfj+|( zliZu4G|x4NFF+5_P~sOWMU|KBw`h?Ga_s>gCR`AkLva z(Zjc4s|d~}b*0}CwMAFhUpplm6xHixpxbgIoszJRBp5IuaYO_hc2M#-@#Pi<=)3V+Yw4HOEEl8PKjH z&?sV_ywDG0WO7&-*>F6g4L`=H-ARD{8kb)v%H4?s7pT3}gjlW+5N>Y6sXJM=JW)h0>SX8t_CVL5ZSIX{AkSBmB<<>Me@Fsn88^jg1> zCVv%9w)HQ_)f#wrK{M$!If-)DL;EISs+aT1 zHyAyRf>6|}{lg`agGQ5U2d5gX1#p=j;3U6roqV@hBB?|F;%5#^*=%7?tKNKyr+zy- z%orwe!p;%qPlKJMJ?8@})u7>eTynGkjt42!TG(HPhG(xYiHS~jF{K?B!bC@m&zFPw z%%p$bdfm5U5gbKVNBd z4_To^fbHD?9{)_$<9BUA6Vwq{V7QCxhgTzF{El%C}>N8D+uBT6)f-R^>Q)pQv zE^ISz+8H&L=yrBXCGaEMptLZ)I@wxNR(YepL7c6Jkij|h+D>ND(;+G&tkVbNSSdJE z?ew94>v+6e&YBUh^Xh~Bui)9L_*>ckZ67q{r`=TB^eOSjZh`Uj#^Nq)8Fn=HA0JDZ zpm_fYI%1JE5Hyhudnp}8k0l}t$*2SCkJfncS|$Cj)d8E*?P`24iFQ(zVa zhLed5EEjKjECi_i;F-m;OU&x*@tq>J?AceWTr0j9YaE}GUtSsbUBgA!Bl(tuF&z!2 zG~^{20SN`a8!XQ{EU8k$>dn@nDpRh;Vl6d-4F{cFWJzky)-p^s5>XJMx(6$yHR(;* zKH|6kH>ui*!}sBEGxe`loJ5KXJb}Ya0u=&3+NRHxhl>AXZ;ZFA%PfmA;(9^H(O837 zxx|Z%CDs@*j8ILn;p$+^@o&o~)s)%Sk#94#<0s!g1ef>EHMI{v$swEVH|rz9bbqZQ zVY5N_r)<60Pr-H#+_;{m@n;X3i(n(DsLy`$V}uoplFSI>LWN!3xq|;MzhdtH_Y+M4T!8Yxi z@s%q=vsn_YaIG{A+Bs+ePTp2=Xsi248x#5LfOGM_qp5P7vEW4iiged9`+&`OhZ2IV zn1`+X;mROOyFOzd*&p$o}z~%8-l`qe=e%hhIDni6dWAOl%!- z+UvC~t;@&gi$G3$`C+-)Q~v$A8YzFn-~Ki(c0AIn;`0EeMD=)gLGXQVPaX2|23hc( zBX#FREt)z|LuYJ=cNT&^+7MY6@HE)`8u%afCO%k0v)7{Xa9ei)4(R;tnz{?>HnsBHsFj4!cCAn(GF`blaPZ$S#eBdvB1097Nyqzm4+u8vwRZA?Z-zMU>e6^UBFjRFi z)hWU`|MgkZ?~(H!)U*a|^S~K?dvt80k?^*uSf*ZEV=N_T|4Yr-lf(_a` z+sl_DV2boMOSYn&k_206`bGc0v`B5`2bPIOq243OR@J@AG62B0E_mCSt)@r>T%vRn z%mF|?pv`SGMgM}8t?aXLp#L{~)v}dki9pj8ZiGVsYy;#!Y@-Rr=bp*^!AoO#9f{jFcyPe@(&`i&8pgM40@0U1P}MS&KXH>v6>?ZMZ)D2g3@?YVR>@wm&p681;ttA&ug5H^V(eU z3+nqgy-;d_dIymKb#VggeQo)_evZo?J!owCQ5Dap=iZ{;cc3^_lX_H-oErrTIDQP! zAb5?W9r1;iyk9Kt7H93WKEVmMgY&I=gOgwqMLZ;;1NdXfVMQm4lqE{}f|IVgqdw-+ zhSd9;{j0a|z2vb^)BypeV6eEa(pqp~fAxF~S3RW)DIJSR13mD5Dj!YY%z?yIVAd%_ zUyfO?q_y7a#Cq!TslnhzAPA=x%%zlei3Ax_#ZZGYbRa0#9~xF?AGEGW5yTqZYwbSt z2IGLiMd<)Gd}RTnWNp%Q^C5bKz+(SyT(5_dBQl}gluKVZb2tVR@pZViE*NkBpk6qV z^hCx(yau$_X zr`AaZ8F6d2Dq(hY8$0C(!PMX)xAIvg?=OZ-pR7<;(|Yg{UM2hba;X z