Skip to content

Commit d57c79e

Browse files
mlazospytorchmergebot
authored andcommitted
[Cutlass] Fix regression from f7ad69f (pytorch#161398)
Pull Request resolved: pytorch#161398 Approved by: https://github.com/henrylhtsang
1 parent 1a566c4 commit d57c79e

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

test/inductor/test_cutlass_evt.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -345,29 +345,31 @@ def test_example_tensor_creation(self):
345345
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
346346
create_example_tensors,
347347
)
348+
from torch._inductor.virtualized import V
348349

349-
row_major_buf0 = MockComputedBuffer(
350-
"buf0", None, torch.float32, (3, 4, 1), (4, 1, 0)
351-
)
352-
col_major_buf1 = MockComputedBuffer(
353-
"buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
354-
)
355-
buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
356-
name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
357-
result = create_example_tensors(
358-
buffer_renames, name_to_buffer, lambda x: int(x)
359-
)
360-
self.assertEqual(result["acc"].shape, (3, 4, 1))
361-
self.assertEqual(result["acc"].stride, (4, 1, 0))
362-
self.assertEqual(
363-
result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
364-
)
350+
with V.set_graph_handler(MockGraphHandler({})):
351+
row_major_buf0 = MockComputedBuffer(
352+
"buf0", None, torch.float32, (3, 4, 1), (4, 1, 0)
353+
)
354+
col_major_buf1 = MockComputedBuffer(
355+
"buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
356+
)
357+
buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
358+
name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
359+
result = create_example_tensors(
360+
buffer_renames, name_to_buffer, lambda x: int(x)
361+
)
362+
self.assertEqual(result["acc"].shape, (3, 4, 1))
363+
self.assertEqual(result["acc"].stride, (4, 1, 0))
364+
self.assertEqual(
365+
result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
366+
)
365367

366-
self.assertEqual(result["buf1"].shape, (3, 2, 1))
367-
self.assertEqual(result["buf1"].stride, (1, 3, 0))
368-
self.assertEqual(
369-
result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
370-
)
368+
self.assertEqual(result["buf1"].shape, (3, 2, 1))
369+
self.assertEqual(result["buf1"].stride, (1, 3, 0))
370+
self.assertEqual(
371+
result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
372+
)
371373

372374
@unittest.skipIf(not SM90OrLater, "need sm_90")
373375
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")

0 commit comments

Comments
 (0)