|
| 1 | +From c849ccbd342b6067d19d5805c6614a21a4f0b49f Mon Sep 17 00:00:00 2001 |
| 2 | +From: Sam Larsen <[email protected]>
| 3 | +Date: Fri, 25 Jul 2025 09:31:15 -0700 |
| 4 | +Subject: [PATCH] Fix full_like decomposition to preserve strides (#158898) |
| 5 | + |
| 6 | +Summary: |
| 7 | +See original PR at: https://github.com/pytorch/pytorch/pull/144765, which landed internally but was reverted due to test failures. Addressing reviewer comments and trying again. |
| 8 | + |
| 9 | +Upstream Patch Reference: https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/159294.patch & https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/158898.patch |
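
A minimal sketch of the stride behavior this change targets (an illustration, assuming
the eager preserve_format semantics that the new tests below check; shapes and values
are arbitrary):

    import torch

    x = torch.randn(4, 5, 6).transpose(1, -1)  # dense but non-contiguous: stride (30, 1, 6)
    y = torch.full_like(x, 3)                   # eager full_like keeps the input strides
    assert y.stride() == x.stride()
    # Before this change, the Inductor decomposition routed through
    # suggest_memory_format and produced a contiguous (or channels-last) output,
    # so compiled code could return different strides than eager for transposed
    # or sliced inputs.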
| 10 | +--- |
| 11 | + test/inductor/test_torchinductor.py | 51 ++++++- |
| 12 | + ...st_torchinductor_codegen_dynamic_shapes.py | 1 + |
| 13 | + test/test_decomp.py | 11 +- |
| 14 | + torch/_inductor/decomposition.py | 139 +++++++++++------- |
| 15 | + torch/_inductor/lowering.py | 1 - |
| 16 | + 5 files changed, 145 insertions(+), 58 deletions(-) |
| 17 | + |
| 18 | +diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py |
| 19 | +index f1cfb90c..e2daa183 100644 |
| 20 | +--- a/test/inductor/test_torchinductor.py |
| 21 | ++++ b/test/inductor/test_torchinductor.py |
| 22 | +@@ -264,6 +264,8 @@ def check_model( |
| 23 | + check_gradient=False, |
| 24 | + check_has_compiled=True, |
| 25 | + output_process_fn_grad=lambda x: x, |
| 26 | ++ # TODO: enable this for all tests |
| 27 | ++ exact_stride=False, |
| 28 | + ): |
| 29 | + kwargs = kwargs or {} |
| 30 | + torch._dynamo.reset() |
| 31 | +@@ -282,6 +284,11 @@ def check_model( |
| 32 | + ): |
| 33 | + has_lowp_args = True |
| 34 | +- return x.float() |
| 35 | ++ # Preserve strides when casting |
| 36 | ++ result = torch.empty_strided( |
| 37 | ++ x.size(), x.stride(), device=x.device, dtype=torch.float |
| 38 | ++ ) |
| 39 | ++ result.copy_(x) |
| 40 | ++ return result |
| 41 | + else: |
| 42 | + return x |
| 43 | + |
| 44 | +@@ -353,6 +361,7 @@ def check_model( |
| 45 | + rtol=rtol, |
| 46 | + equal_nan=True, |
| 47 | + exact_dtype=exact_dtype, |
| 48 | ++ exact_stride=exact_stride, |
| 49 | + ) |
| 50 | + # In case of input mutations, check that inputs are the same |
| 51 | + self.assertEqual( |
| 52 | +@@ -363,6 +372,7 @@ def check_model( |
| 53 | + equal_nan=True, |
| 54 | + # our testing sometimes uses higher precision inputs for the reference |
| 55 | + exact_dtype=False, |
| 56 | ++ exact_stride=exact_stride, |
| 57 | + ) |
| 58 | + else: |
| 59 | + for correct_val, actual_val in zip(correct_flat, actual_flat): |
| 60 | +@@ -376,6 +386,8 @@ def check_model( |
| 61 | + assert correct_val.layout == actual_val.layout |
| 62 | + if exact_dtype: |
| 63 | + assert correct_val.dtype == actual_val.dtype |
| 64 | ++ if exact_stride: |
| 65 | ++ assert correct_val.stride() == actual_val.stride() |
| 66 | + |
| 67 | + if check_gradient: |
| 68 | + actual = output_process_fn_grad(actual) |
| 69 | +@@ -423,6 +435,7 @@ def check_model( |
| 70 | + rtol=rtol, |
| 71 | + equal_nan=True, |
| 72 | + exact_dtype=exact_dtype, |
| 73 | ++ exact_stride=exact_stride, |
| 74 | + ) |
| 75 | + |
| 76 | + torch._dynamo.reset() |
| 77 | +@@ -446,6 +459,8 @@ def check_model_cuda( |
| 78 | + check_gradient=False, |
| 79 | + check_has_compiled=True, |
| 80 | + output_process_fn_grad=lambda x: x, |
| 81 | ++ # TODO: enable this for all tests |
| 82 | ++ exact_stride=False, |
| 83 | + ): |
| 84 | + kwargs = kwargs or {} |
| 85 | + if hasattr(model, "to"): |
| 86 | +@@ -470,6 +485,7 @@ def check_model_cuda( |
| 87 | + check_gradient=check_gradient, |
| 88 | + check_has_compiled=check_has_compiled, |
| 89 | + output_process_fn_grad=output_process_fn_grad, |
| 90 | ++ exact_stride=exact_stride, |
| 91 | + ) |
| 92 | + |
| 93 | + if check_lowp: |
| 94 | +@@ -500,6 +516,7 @@ def check_model_cuda( |
| 95 | + check_gradient=check_gradient, |
| 96 | + check_has_compiled=check_has_compiled, |
| 97 | + output_process_fn_grad=output_process_fn_grad, |
| 98 | ++ exact_stride=exact_stride, |
| 99 | + ) |
| 100 | + |
| 101 | + |
| 102 | +@@ -4194,6 +4211,18 @@ class CommonTemplate: |
| 103 | + |
| 104 | + self.common(fn, (torch.randn(8),)) |
| 105 | + |
| 106 | ++ def test_full_like_transposed(self): |
| 107 | ++ def fn(a): |
| 108 | ++ return torch.full_like(a, 3) |
| 109 | ++ |
| 110 | ++ self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True) |
| 111 | ++ |
| 112 | ++ def test_full_like_sliced(self): |
| 113 | ++ def fn(a): |
| 114 | ++ return torch.full_like(a, 3) |
| 115 | ++ |
| 116 | ++ self.common(fn, (torch.rand(3, 4)[:, ::2],), exact_stride=True) |
| 117 | ++ |
| 118 | + def test_full_truncation(self): |
| 119 | + def fn(a): |
| 120 | + return a + torch.full_like(a, 7.777) |
| 121 | +@@ -5872,7 +5901,7 @@ class CommonTemplate: |
| 122 | + model = Model() |
| 123 | + x = torch.rand(10, 3, 0) |
| 124 | + |
| 125 | +- self.common(model, (x,)) |
| 126 | ++ self.common(model, (x,), exact_stride=True) |
| 127 | + |
| 128 | + def test_randint(self): |
| 129 | + @torch.compile(fullgraph=True) |
| 130 | +@@ -5907,9 +5936,21 @@ class CommonTemplate: |
| 131 | + @config.patch(fallback_random=True) |
| 132 | + def test_like_rands(self): |
| 133 | + def fn(x): |
| 134 | +- return torch.rand_like(x), torch.randn_like(x) |
| 135 | ++ return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11) |
| 136 | ++ |
| 137 | ++ self.common(fn, [torch.zeros([20, 20])], exact_stride=True) |
| 138 | ++ |
| 139 | ++ @config.patch(fallback_random=True) |
| 140 | ++ @xfail_if_mps # 100% are not close |
| 141 | ++ def test_like_rands_sliced(self): |
| 142 | ++ def fn(x): |
| 143 | ++ return ( |
| 144 | ++ torch.randn_like(x), |
| 145 | ++ torch.randn_like(x), |
| 146 | ++ torch.randint_like(x, 1, 11), |
| 147 | ++ ) |
| 148 | + |
| 149 | +- self.common(fn, [torch.zeros([20, 20])]) |
| 150 | ++ self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True) |
| 151 | + |
| 152 | + def test_like_rands2(self): |
| 153 | + # rand_like with kwargs `device` of str type |
| 154 | +@@ -5924,6 +5965,8 @@ class CommonTemplate: |
| 155 | + a0 = fn(x).clone() |
| 156 | + a1 = fn(x).clone() |
| 157 | + self.assertFalse(torch.allclose(a0, a1)) |
| 158 | ++ self.assertEqual(a0.shape, a1.shape) |
| 159 | ++ self.assertEqual(a0.stride(), a1.stride()) |
| 160 | + |
| 161 | + @requires_cuda() |
| 162 | + def test_like_rands3(self): |
| 163 | +@@ -5940,6 +5983,8 @@ class CommonTemplate: |
| 164 | + a1 = test_like_rands_on_different_device("cuda", "cpu") |
| 165 | + self.assertTrue(a0.device.type == "cuda") |
| 166 | + self.assertTrue(a1.device.type == "cpu") |
| 167 | ++ self.assertEqual(a0.shape, a1.shape) |
| 168 | ++ self.assertEqual(a0.stride(), a1.stride()) |
| 169 | + |
| 170 | + def test_max_pool2d_with_indices_backward(self): |
| 171 | + def fn(a, b, c): |
| 172 | +diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py |
| 173 | +index fa4b8040..ae52a802 100644 |
| 174 | +--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py |
| 175 | ++++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py |
| 176 | +@@ -162,6 +162,7 @@ test_failures = { |
| 177 | + "test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"), |
| 178 | + "test_bucketize_int_dynamic_shapes": TestFailure("cpu"), |
| 179 | + "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 180 | ++ "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), |
| 181 | + "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 182 | + "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 183 | + "test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 184 | +diff --git a/test/test_decomp.py b/test/test_decomp.py |
| 185 | +index 10df8b8b..9ad20995 100644 |
| 186 | +--- a/test/test_decomp.py |
| 187 | ++++ b/test/test_decomp.py |
| 188 | +@@ -693,7 +693,16 @@ class TestDecomp(TestCase): |
| 189 | + assert len(real_out) == len(decomp_out) |
| 190 | + |
| 191 | + if do_relative_check: |
| 192 | +- upcast = partial(upcast_tensor, dtype=torch.float64) |
| 193 | ++ device_arg = kwargs.get("device", None) |
| 194 | ++ |
| 195 | ++ def upcast(x): |
| 196 | ++ if (isinstance(x, Tensor) and x.device.type == "mps") or ( |
| 197 | ++ device_arg and torch.device(device_arg).type == "mps" |
| 198 | ++ ): |
| 199 | ++ return upcast_tensor(x, dtype=torch.float32) |
| 200 | ++ else: |
| 201 | ++ return upcast_tensor(x, dtype=torch.float64) |
| 202 | ++ |
| 203 | + real_out_double, _ = tree_flatten( |
| 204 | + func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) |
| 205 | + ) |
| 206 | +diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py |
| 207 | +index 88a56dea..6f396f50 100644 |
| 208 | +--- a/torch/_inductor/decomposition.py |
| 209 | ++++ b/torch/_inductor/decomposition.py |
| 210 | +@@ -343,35 +343,19 @@ def view_copy_default(self, size): |
| 211 | + def view_copy_dtype(self, dtype): |
| 212 | + return self.to(dtype).clone() |
| 213 | + |
| 214 | ++def _get_shape_permutation_like( |
| 215 | ++ self: torch.Tensor, layout: torch.layout |
| 216 | ++) -> tuple[utils.ShapeType, utils.StrideType]: |
| 217 | ++ assert layout == torch.strided |
| 218 | + |
| 219 | +-def get_like_layout( |
| 220 | +- tensor: torch.Tensor, memory_format: Optional[torch.memory_format] |
| 221 | +-) -> torch.memory_format: |
| 222 | +- # TODO: _to_copy tensor to stride permutation |
| 223 | +- if memory_format is torch.preserve_format or memory_format is None: |
| 224 | +- return utils.suggest_memory_format(tensor) |
| 225 | +- else: |
| 226 | +- return memory_format |
| 227 | +- |
| 228 | +- |
| 229 | +-@register_decomposition(aten.rand_like) |
| 230 | +-def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs): |
| 231 | +- return torch.rand( |
| 232 | +- [*self.size()], |
| 233 | +- dtype=dtype or self.dtype, |
| 234 | +- device=device or self.device, |
| 235 | +- **kwargs, |
| 236 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 237 | ++ physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self) |
| 238 | ++ shape = [self.shape[l] for l in physical_layout] |
| 239 | + |
| 240 | ++ permutation = [0] * len(shape) |
| 241 | ++ for p, l in enumerate(physical_layout): |
| 242 | ++ permutation[l] = p |
| 243 | + |
| 244 | +-@register_decomposition(aten.randn_like) |
| 245 | +-def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs): |
| 246 | +- return torch.randn( |
| 247 | +- [*self.size()], |
| 248 | +- dtype=dtype or self.dtype, |
| 249 | +- device=device or self.device, |
| 250 | +- **kwargs, |
| 251 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 252 | ++ return (shape, permutation) |
| 253 | + |
| 254 | + |
| 255 | + @register_decomposition(aten.full_like) |
| 256 | +@@ -386,40 +370,89 @@ def full_like( |
| 257 | + requires_grad=False, |
| 258 | + memory_format=torch.preserve_format, |
| 259 | + ): |
| 260 | +- return torch.full( |
| 261 | +- [*self.size()], |
| 262 | +- fill_value, |
| 263 | +- dtype=dtype or self.dtype, |
| 264 | +- layout=layout or self.layout, |
| 265 | +- device=device or self.device, |
| 266 | +- requires_grad=requires_grad, |
| 267 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 268 | ++ dtype = self.dtype if dtype is None else dtype |
| 269 | ++ layout = self.layout if layout is None else layout |
| 270 | ++ device = self.device if device is None else device |
| 271 | ++ |
| 272 | ++ if memory_format != torch.preserve_format: |
| 273 | ++ result = torch.full( |
| 274 | ++ self.shape, |
| 275 | ++ fill_value, |
| 276 | ++ dtype=dtype, |
| 277 | ++ layout=layout, |
| 278 | ++ device=device, |
| 279 | ++ pin_memory=pin_memory, |
| 280 | ++ requires_grad=requires_grad, |
| 281 | ++ ) |
| 282 | ++ return result.to(memory_format=memory_format) |
| 283 | + |
| 284 | ++ else: |
| 285 | ++ shape, permutation = _get_shape_permutation_like(self, layout) |
| 286 | ++ result = torch.full( |
| 287 | ++ shape, |
| 288 | ++ fill_value, |
| 289 | ++ dtype=dtype, |
| 290 | ++ layout=layout, |
| 291 | ++ device=device, |
| 292 | ++ pin_memory=pin_memory, |
| 293 | ++ requires_grad=requires_grad, |
| 294 | ++ ) |
| 295 | ++ if permutation == list(range(len(permutation))): |
| 296 | ++ return result |
| 297 | ++ return result.permute(permutation).clone() |
| 298 | + |
| 299 | +-@register_decomposition(aten.randint_like.default) |
| 300 | +-def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs): |
| 301 | +- return aten.randint.low( |
| 302 | +- 0, |
| 303 | +- high, |
| 304 | +- [*self.size()], |
| 305 | +- dtype=dtype or self.dtype, |
| 306 | +- device=device or self.device, |
| 307 | ++ |
| 308 | ++def _rand_like( |
| 309 | ++ rand_fn: Callable[..., torch.Tensor], |
| 310 | ++ self: torch.Tensor, |
| 311 | ++ *, |
| 312 | ++ dtype: Optional[torch.dtype] = None, |
| 313 | ++ device: Optional[torch.device] = None, |
| 314 | ++ memory_format: torch.memory_format = torch.preserve_format, |
| 315 | ++ **kwargs: Any, |
| 316 | ++) -> torch.Tensor: |
| 317 | ++ dtype = self.dtype if dtype is None else dtype |
| 318 | ++ device = self.device if device is None else device |
| 319 | ++ |
| 320 | ++ if memory_format != torch.preserve_format: |
| 321 | ++ return rand_fn( |
| 322 | ++ self.shape, |
| 323 | ++ dtype=dtype, |
| 324 | ++ device=device, |
| 325 | ++ **kwargs, |
| 326 | ++ ).to(memory_format=memory_format) |
| 327 | ++ |
| 328 | ++ shape, permutation = _get_shape_permutation_like(self, torch.strided)
| 329 | ++ result = rand_fn( |
| 330 | ++ shape, |
| 331 | ++ dtype=dtype, |
| 332 | ++ device=device, |
| 333 | + **kwargs, |
| 334 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 335 | ++ ) |
| 336 | ++ if permutation == list(range(len(permutation))): |
| 337 | ++ return result |
| 338 | ++ return result.permute(permutation).clone() |
| 339 | + |
| 340 | + |
| 341 | ++@register_decomposition(aten.rand_like) |
| 342 | ++def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| 343 | ++ return _rand_like(torch.rand, self, **kwargs) |
| 344 | ++ |
| 345 | ++ |
| 346 | ++@register_decomposition(aten.randn_like) |
| 347 | ++def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| 348 | ++ return _rand_like(torch.randn, self, **kwargs) |
| 349 | ++ |
| 350 | ++ |
| 351 | ++@register_decomposition(aten.randint_like.default) |
| 352 | ++def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor: |
| 353 | ++ return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs) |
| 354 | ++ |
| 355 | + @register_decomposition(aten.randint_like.low_dtype) |
| 356 | + def randint_like_low( |
| 357 | +- self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs |
| 358 | +-): |
| 359 | +- return aten.randint.low( |
| 360 | +- low, |
| 361 | +- high, |
| 362 | +- [*self.size()], |
| 363 | +- dtype=dtype or self.dtype, |
| 364 | +- device=device or self.device, |
| 365 | +- **kwargs, |
| 366 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 367 | ++ self: torch.Tensor, low: int, high: int, **kwargs: Any |
| 368 | ++) -> torch.Tensor: |
| 369 | ++ return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs) |
| 370 | + |
| 371 | + |
| 372 | + @register_decomposition(aten.randint.default) |
| 373 | +diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py |
| 374 | +index e6f2e8d0..dbd9aa28 100644 |
| 375 | +--- a/torch/_inductor/lowering.py |
| 376 | ++++ b/torch/_inductor/lowering.py |
| 377 | +@@ -2550,7 +2550,6 @@ def _full(fill_value, device, dtype, size): |
| 378 | + ) |
| 379 | + |
| 380 | + |
| 381 | +-@register_lowering(aten.full_like, type_promotion_kind=None) |
| 382 | + def full_like(x, fill_value, **kwargs): |
| 383 | + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) |
| 384 | + |
| 385 | +-- |
| 386 | +2.45.4 |
| 387 | + |