Commit b449ade
[HIGH] Patch pytorch for CVE-2025-55552 (microsoft#15166)
Co-authored-by: jslobodzian <[email protected]>
Parent: b9cf434

File tree: 2 files changed (+392, -1 lines)

SPECS/pytorch/CVE-2025-55552.patch

Lines changed: 387 additions & 0 deletions
@@ -0,0 +1,387 @@
From c849ccbd342b6067d19d5805c6614a21a4f0b49f Mon Sep 17 00:00:00 2001
From: Sam Larsen <[email protected]>
Date: Fri, 25 Jul 2025 09:31:15 -0700
Subject: [PATCH] Fix full_like decomposition to preserve strides (#158898)

Summary:
See original PR at: https://github.com/pytorch/pytorch/pull/144765, which landed internally but was reverted due to test failures. Addressing reviewer comments and trying again.

Upstream Patch Reference: https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/159294.patch & https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/158898.patch
---
 test/inductor/test_torchinductor.py | 51 ++++++-
 ...st_torchinductor_codegen_dynamic_shapes.py | 1 +
 test/test_decomp.py | 11 +-
 torch/_inductor/decomposition.py | 139 +++++++++++-------
 torch/_inductor/lowering.py | 1 -
 5 files changed, 145 insertions(+), 58 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f1cfb90c..e2daa183 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -264,6 +264,8 @@ def check_model(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
+ # TODO: enable this for all tests
+ exact_stride=False,
):
kwargs = kwargs or {}
torch._dynamo.reset()
@@ -282,6 +284,11 @@ def check_model(
):
has_lowp_args = True
- return x.float()
+ # Preserve strides when casting
+ result = torch.empty_strided(
+ x.size(), x.stride(), device=x.device, dtype=torch.float
+ )
+ result.copy_(x)
+ return result
else:
return x

@@ -353,6 +361,7 @@ def check_model(
rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
+ exact_stride=exact_stride,
)
# In case of input mutations, check that inputs are the same
self.assertEqual(
@@ -363,6 +372,7 @@ def check_model(
equal_nan=True,
# our testing sometimes uses higher precision inputs for the reference
exact_dtype=False,
+ exact_stride=exact_stride,
)
else:
for correct_val, actual_val in zip(correct_flat, actual_flat):
@@ -376,6 +386,8 @@ def check_model(
assert correct_val.layout == actual_val.layout
if exact_dtype:
assert correct_val.dtype == actual_val.dtype
+ if exact_stride:
+ assert correct_val.stride() == actual_val.stride()

if check_gradient:
actual = output_process_fn_grad(actual)
@@ -423,6 +435,7 @@ def check_model(
rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
+ exact_stride=exact_stride,
)

torch._dynamo.reset()
@@ -446,6 +459,8 @@ def check_model_cuda(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
+ # TODO: enable this for all tests
+ exact_stride=False,
):
kwargs = kwargs or {}
if hasattr(model, "to"):
@@ -470,6 +485,7 @@ def check_model_cuda(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
+ exact_stride=exact_stride,
)

if check_lowp:
@@ -500,6 +516,7 @@ def check_model_cuda(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
+ exact_stride=exact_stride,
)


@@ -4194,6 +4211,18 @@ class CommonTemplate:

self.common(fn, (torch.randn(8),))

+ def test_full_like_transposed(self):
+ def fn(a):
+ return torch.full_like(a, 3)
+
+ self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
+
+ def test_full_like_sliced(self):
+ def fn(a):
+ return torch.full_like(a, 3)
+
+ self.common(fn, (torch.rand(3, 4)[:, ::2],), exact_stride=True)
+
def test_full_truncation(self):
def fn(a):
return a + torch.full_like(a, 7.777)
@@ -5872,7 +5901,7 @@ class CommonTemplate:
model = Model()
x = torch.rand(10, 3, 0)

- self.common(model, (x,))
+ self.common(model, (x,), exact_stride=True)

def test_randint(self):
@torch.compile(fullgraph=True)
@@ -5907,9 +5936,21 @@ class CommonTemplate:
@config.patch(fallback_random=True)
def test_like_rands(self):
def fn(x):
- return torch.rand_like(x), torch.randn_like(x)
+ return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11)
+
+ self.common(fn, [torch.zeros([20, 20])], exact_stride=True)
+
+ @config.patch(fallback_random=True)
+ @xfail_if_mps # 100% are not close
+ def test_like_rands_sliced(self):
+ def fn(x):
+ return (
+ torch.randn_like(x),
+ torch.randn_like(x),
+ torch.randint_like(x, 1, 11),
+ )

- self.common(fn, [torch.zeros([20, 20])])
+ self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True)

def test_like_rands2(self):
# rand_like with kwargs `device` of str type
@@ -5924,6 +5965,8 @@ class CommonTemplate:
a0 = fn(x).clone()
a1 = fn(x).clone()
self.assertFalse(torch.allclose(a0, a1))
+ self.assertEqual(a0.shape, a1.shape)
+ self.assertEqual(a0.stride(), a1.stride())

@requires_cuda()
def test_like_rands3(self):
@@ -5940,6 +5983,8 @@ class CommonTemplate:
a1 = test_like_rands_on_different_device("cuda", "cpu")
self.assertTrue(a0.device.type == "cuda")
self.assertTrue(a1.device.type == "cpu")
+ self.assertEqual(a0.shape, a1.shape)
+ self.assertEqual(a0.stride(), a1.stride())

def test_max_pool2d_with_indices_backward(self):
def fn(a, b, c):
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index fa4b8040..ae52a802 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -162,6 +162,7 @@ test_failures = {
"test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"),
"test_bucketize_int_dynamic_shapes": TestFailure("cpu"),
"test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")),
+ "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda")),
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 10df8b8b..9ad20995 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -693,7 +693,16 @@ class TestDecomp(TestCase):
assert len(real_out) == len(decomp_out)

if do_relative_check:
- upcast = partial(upcast_tensor, dtype=torch.float64)
+ device_arg = kwargs.get("device", None)
+
+ def upcast(x):
+ if (isinstance(x, Tensor) and x.device.type == "mps") or (
+ device_arg and torch.device(device_arg).type == "mps"
+ ):
+ return upcast_tensor(x, dtype=torch.float32)
+ else:
+ return upcast_tensor(x, dtype=torch.float64)
+
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 88a56dea..6f396f50 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -343,35 +343,19 @@ def view_copy_default(self, size):
def view_copy_dtype(self, dtype):
return self.to(dtype).clone()

+def _get_shape_permutation_like(
+ self: torch.Tensor, layout: torch.layout
+) -> tuple[utils.ShapeType, utils.StrideType]:
+ assert layout == torch.strided

-def get_like_layout(
- tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
-) -> torch.memory_format:
- # TODO: _to_copy tensor to stride permutation
- if memory_format is torch.preserve_format or memory_format is None:
- return utils.suggest_memory_format(tensor)
- else:
- return memory_format
-
-
-@register_decomposition(aten.rand_like)
-def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
- return torch.rand(
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
+ shape = [self.shape[l] for l in physical_layout]

+ permutation = [0] * len(shape)
+ for p, l in enumerate(physical_layout):
+ permutation[l] = p

-@register_decomposition(aten.randn_like)
-def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
- return torch.randn(
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ return (shape, permutation)


@register_decomposition(aten.full_like)
@@ -386,40 +370,89 @@ def full_like(
requires_grad=False,
memory_format=torch.preserve_format,
):
- return torch.full(
- [*self.size()],
- fill_value,
- dtype=dtype or self.dtype,
- layout=layout or self.layout,
- device=device or self.device,
- requires_grad=requires_grad,
- ).to(memory_format=get_like_layout(self, memory_format))
+ dtype = self.dtype if dtype is None else dtype
+ layout = self.layout if layout is None else layout
+ device = self.device if device is None else device
+
+ if memory_format != torch.preserve_format:
+ result = torch.full(
+ self.shape,
+ fill_value,
+ dtype=dtype,
+ layout=layout,
+ device=device,
+ pin_memory=pin_memory,
+ requires_grad=requires_grad,
+ )
+ return result.to(memory_format=memory_format)

+ else:
+ shape, permutation = _get_shape_permutation_like(self, layout)
+ result = torch.full(
+ shape,
+ fill_value,
+ dtype=dtype,
+ layout=layout,
+ device=device,
+ pin_memory=pin_memory,
+ requires_grad=requires_grad,
+ )
+ if permutation == list(range(len(permutation))):
+ return result
+ return result.permute(permutation).clone()

-@register_decomposition(aten.randint_like.default)
-def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
- return aten.randint.low(
- 0,
- high,
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
+
+def _rand_like(
+ rand_fn: Callable[..., torch.Tensor],
+ self: torch.Tensor,
+ *,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ memory_format: torch.memory_format = torch.preserve_format,
+ **kwargs: Any,
+) -> torch.Tensor:
+ dtype = self.dtype if dtype is None else dtype
+ device = self.device if device is None else device
+
+ if memory_format != torch.preserve_format:
+ return rand_fn(
+ self.shape,
+ dtype=dtype,
+ device=device,
+ **kwargs,
+ ).to(memory_format=memory_format)
+
+ shape, permutation = _get_shape_permutation_like(self)
+ result = rand_fn(
+ shape,
+ dtype=dtype,
+ device=device,
**kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ )
+ if permutation == list(range(len(permutation))):
+ return result
+ return result.permute(permutation).clone()


+@register_decomposition(aten.rand_like)
+def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(torch.rand, self, **kwargs)
+
+
+@register_decomposition(aten.randn_like)
+def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(torch.randn, self, **kwargs)
+
+
+@register_decomposition(aten.randint_like.default)
+def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
+
@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
- self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
-):
- return aten.randint.low(
- low,
- high,
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ self: torch.Tensor, low: int, high: int, **kwargs: Any
+) -> torch.Tensor:
+ return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)


@register_decomposition(aten.randint.default)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index e6f2e8d0..dbd9aa28 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2550,7 +2550,6 @@ def _full(fill_value, device, dtype, size):
)


-@register_lowering(aten.full_like, type_promotion_kind=None)
def full_like(x, fill_value, **kwargs):
return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)

--
2.45.4
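
For reference, the behavior exercised by the new test_full_like_transposed case above can be reproduced standalone. The sketch below is illustrative only and not part of the patch; it assumes a torch build that already carries this fix (before the fix, the Inductor-compiled call returns a contiguous tensor and the last assertion fails).

    import torch

    def fn(a):
        return torch.full_like(a, 3)

    # Non-contiguous but dense input: shape (4, 6, 5), strides (30, 1, 6)
    x = torch.randn(4, 5, 6).transpose(1, -1)

    eager = fn(x)
    compiled = torch.compile(fn)(x)

    # Eager full_like preserves the strides of a dense input under
    # torch.preserve_format; with this patch the decomposition matches it.
    assert eager.stride() == x.stride()
    assert compiled.stride() == eager.stride()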
