@@ -76,7 +76,7 @@ def test_debug_mode_mm(self):
7676 _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
7777 _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
7878 aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]
79- <method 'sum' of 'torch._C.TensorBase' objects>(dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P
79+ <method 'sum' of 'torch._C.TensorBase' objects>(dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P(sum)
8080 aten::sum(dt$6: f32[8, 32]| S(0))
8181 aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""" ,
8282 )
@@ -179,8 +179,8 @@ def test_debug_mode_backward(self):
179179 <method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8]| S(0))
180180 aten::sum(dt: f32[8, 8]| S(0))
181181 aten::sum(t: f32[1, 8])
182- torch._tensor.backward(dt: f32[]| P, gradient=None, retain_graph=None, create_graph=False, inputs=None)
183- aten::ones_like(dt: f32[]| P, pin_memory=False, memory_format=torch.preserve_format)
182+ torch._tensor.backward(dt: f32[]| P(sum) , gradient=None, retain_graph=None, create_graph=False, inputs=None)
183+ aten::ones_like(dt: f32[]| P(sum) , pin_memory=False, memory_format=torch.preserve_format)
184184 aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
185185 aten::expand(dt: f32[]| R, [8, 8])
186186 aten::expand(t: f32[], [8, 8])
@@ -189,9 +189,9 @@ def test_debug_mode_backward(self):
189189 aten::clone(t: f32[8, 1])
190190 aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
191191 redistribute_input(t: f32[8, 8], trace: R->S(0))
192- aten::detach(t: f32[8, 1])
193192 aten::split.Tensor(t: f32[8, 8], 1)
194193 aten::clone(t: f32[1, 8])
194+ aten::detach(t: f32[8, 1])
195195 aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
196196 aten::detach(t: f32[1, 8])""" ,
197197 )
@@ -253,50 +253,50 @@ def test_debug_mode_einsum(self):
253253 self .assertExpectedInline (
254254 debug_mode .debug_string (),
255255 """\
256- torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| PR , dt: f32[8, 4, 4]| RP)
257- aten::unsqueeze(dt: f32[16, 6, 8]| PR , 3)
256+ torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| P(sum)R , dt: f32[8, 4, 4]| RP(sum) )
257+ aten::unsqueeze(dt: f32[16, 6, 8]| P(sum)R , 3)
258258 aten::unsqueeze(t: f32[16, 6, 8], 3)
259- aten::unsqueeze(dt: f32[16, 6, 8, 1]| PR , 4)
259+ aten::unsqueeze(dt: f32[16, 6, 8, 1]| P(sum)R , 4)
260260 aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
261- aten::permute(dt: f32[16, 6, 8, 1, 1]| PR , [0, 1, 3, 4, 2])
261+ aten::permute(dt: f32[16, 6, 8, 1, 1]| P(sum)R , [0, 1, 3, 4, 2])
262262 aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
263- aten::unsqueeze(dt: f32[8, 4, 4]| RP, 3)
263+ aten::unsqueeze(dt: f32[8, 4, 4]| RP(sum) , 3)
264264 aten::unsqueeze(t: f32[8, 4, 4], 3)
265- aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP, 4)
265+ aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP(sum) , 4)
266266 aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
267- aten::permute(dt: f32[8, 4, 4, 1, 1]| RP, [3, 4, 1, 2, 0])
267+ aten::permute(dt: f32[8, 4, 4, 1, 1]| RP(sum) , [3, 4, 1, 2, 0])
268268 aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
269- aten::permute(dt: f32[16, 6, 1, 1, 8]| PR , [0, 1, 4, 2, 3])
269+ aten::permute(dt: f32[16, 6, 1, 1, 8]| P(sum)R , [0, 1, 4, 2, 3])
270270 aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
271- aten::view(dt: f32[16, 6, 8, 1, 1]| PR , [1, 96, 8])
271+ aten::view(dt: f32[16, 6, 8, 1, 1]| P(sum)R , [1, 96, 8])
272272 aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
273- aten::permute(dt: f32[1, 1, 4, 4, 8]| RP, [4, 2, 3, 0, 1])
273+ aten::permute(dt: f32[1, 1, 4, 4, 8]| RP(sum) , [4, 2, 3, 0, 1])
274274 aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
275- aten::view(dt: f32[8, 4, 4, 1, 1]| RP, [1, 8, 16])
275+ aten::view(dt: f32[8, 4, 4, 1, 1]| RP(sum) , [1, 8, 16])
276276 aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
277- aten::bmm(dt: f32[1, 96, 8]| PR , dt: f32[1, 8, 16]| RP)
278- redistribute_input(0, PR -> S(2)[0]S(2)[1])
279- redistribute_input(t: f32[1, 96, 8], trace: PR ->S(2)R->S(2)[0]S(2)[1])
277+ aten::bmm(dt: f32[1, 96, 8]| P(sum)R , dt: f32[1, 8, 16]| RP(sum) )
278+ redistribute_input(0, P(sum)R -> S(2)[0]S(2)[1])
279+ redistribute_input(t: f32[1, 96, 8], trace: P(sum)R ->S(2)R->S(2)[0]S(2)[1])
280280 aten::chunk(t: f32[1, 96, 8], 4, 2)
281281 aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
282282 _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
283283 _c10d_functional::wait_tensor(t: f32[1, 96, 2])
284284 aten::chunk(t: f32[1, 96, 2], 2, 2)
285285 aten::clone(t: f32[1, 96, 1])
286- redistribute_input(1, RP -> S(1)[0]S(1)[1])
287- redistribute_input(t: f32[1, 8, 16], trace: RP->S(1)P->S(1)[0]S(1)[1])
286+ redistribute_input(1, RP(sum) -> S(1)[0]S(1)[1])
287+ redistribute_input(t: f32[1, 8, 16], trace: RP(sum) ->S(1)P(sum) ->S(1)[0]S(1)[1])
288288 aten::chunk(t: f32[1, 8, 16], 4, 1)
289289 aten::clone(t: f32[1, 2, 16])
290290 aten::chunk(t: f32[1, 2, 16], 2, 1)
291291 aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
292292 _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
293293 _c10d_functional::wait_tensor(t: f32[1, 1, 16])
294294 aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
295- aten::view(dt: f32[1, 96, 16]| PP , [16, 6, 1, 4, 4])
295+ aten::view(dt: f32[1, 96, 16]| P(sum)P(sum) , [16, 6, 1, 4, 4])
296296 aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
297- aten::permute(dt: f32[16, 6, 1, 4, 4]| PP , [0, 1, 3, 4, 2])
297+ aten::permute(dt: f32[16, 6, 1, 4, 4]| P(sum)P(sum) , [0, 1, 3, 4, 2])
298298 aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
299- aten::view(dt: f32[16, 6, 4, 4, 1]| PP , [16, 6, 4, 4])
299+ aten::view(dt: f32[16, 6, 4, 4, 1]| P(sum)P(sum) , [16, 6, 4, 4])
300300 aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])""" ,
301301 )
302302
0 commit comments