@@ -345,29 +345,31 @@ def test_example_tensor_creation(self):
345
345
from torch ._inductor .codegen .cuda .cutlass_lib_extensions .evt_extensions import (
346
346
create_example_tensors ,
347
347
)
348
+ from torch ._inductor .virtualized import V
348
349
349
- row_major_buf0 = MockComputedBuffer (
350
- "buf0" , None , torch .float32 , (3 , 4 , 1 ), (4 , 1 , 0 )
351
- )
352
- col_major_buf1 = MockComputedBuffer (
353
- "buf1" , None , torch .float32 , (3 , 2 , 1 ), (1 , 3 , 0 )
354
- )
355
- buffer_renames = {"buf0" : "buf0" , "buf1" : "buf1" , "acc" : "buf0" }
356
- name_to_buffer = {"buf0" : row_major_buf0 , "buf1" : col_major_buf1 }
357
- result = create_example_tensors (
358
- buffer_renames , name_to_buffer , lambda x : int (x )
359
- )
360
- self .assertEqual (result ["acc" ].shape , (3 , 4 , 1 ))
361
- self .assertEqual (result ["acc" ].stride , (4 , 1 , 0 ))
362
- self .assertEqual (
363
- result ["acc" ].element , torch_dtype_to_cutlass_type (torch .float32 )
364
- )
350
+ with V .set_graph_handler (MockGraphHandler ({})):
351
+ row_major_buf0 = MockComputedBuffer (
352
+ "buf0" , None , torch .float32 , (3 , 4 , 1 ), (4 , 1 , 0 )
353
+ )
354
+ col_major_buf1 = MockComputedBuffer (
355
+ "buf1" , None , torch .float32 , (3 , 2 , 1 ), (1 , 3 , 0 )
356
+ )
357
+ buffer_renames = {"buf0" : "buf0" , "buf1" : "buf1" , "acc" : "buf0" }
358
+ name_to_buffer = {"buf0" : row_major_buf0 , "buf1" : col_major_buf1 }
359
+ result = create_example_tensors (
360
+ buffer_renames , name_to_buffer , lambda x : int (x )
361
+ )
362
+ self .assertEqual (result ["acc" ].shape , (3 , 4 , 1 ))
363
+ self .assertEqual (result ["acc" ].stride , (4 , 1 , 0 ))
364
+ self .assertEqual (
365
+ result ["acc" ].element , torch_dtype_to_cutlass_type (torch .float32 )
366
+ )
365
367
366
- self .assertEqual (result ["buf1" ].shape , (3 , 2 , 1 ))
367
- self .assertEqual (result ["buf1" ].stride , (1 , 3 , 0 ))
368
- self .assertEqual (
369
- result ["buf1" ].element , torch_dtype_to_cutlass_type (torch .float32 )
370
- )
368
+ self .assertEqual (result ["buf1" ].shape , (3 , 2 , 1 ))
369
+ self .assertEqual (result ["buf1" ].stride , (1 , 3 , 0 ))
370
+ self .assertEqual (
371
+ result ["buf1" ].element , torch_dtype_to_cutlass_type (torch .float32 )
372
+ )
371
373
372
374
@unittest .skipIf (not SM90OrLater , "need sm_90" )
373
375
@unittest .skipIf (not try_import_cutlass (), "requires cutlass" )
0 commit comments