 from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import skipIfRocm
-from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON
 
 
 class PadMMTest(TestCase):
@@ -38,15 +38,15 @@ class Model(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
                 self.w = rand_strided(
-                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
+                    (K2, N), (1, K2), device=GPU_TYPE, dtype=torch.float32
                 )
 
             def forward(self, a):
                 a1 = torch.narrow(a, 1, 0, K2)
                 return torch.mm(a1, self.w)
 
-        fn = Model().cuda()
-        a = rand_strided((M, K1), (K1, 1), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K1), (K1, 1), device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
         torch._dynamo.mark_dynamic(a, 0)
         with unittest.mock.patch(
@@ -72,17 +72,17 @@ class Model(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
                 self.w = rand_strided(
-                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
+                    (K2, N), (1, K2), device=GPU_TYPE, dtype=torch.float32
                 )
 
             def forward(self, a, b):
                 c = torch.cat([a, b], dim=0)
                 a1 = torch.narrow(c, 1, 0, K2)
                 return torch.mm(a1, self.w)
 
-        fn = Model().cuda()
-        a = rand_strided((M1, K1), (K1, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((M2, K1), (K1, 1), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M1, K1), (K1, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((M2, K1), (K1, 1), device=GPU_TYPE, dtype=torch.float32)
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(b, 0)
         aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
@@ -110,9 +110,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.mm(a, b)
 
-        fn = Model().cuda()
-        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K), (K, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((K, N), (1, K), device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K, get_alignment_size(a)) + K
         torch._dynamo.mark_dynamic(b, 1)
         with unittest.mock.patch(
@@ -139,9 +139,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.mm(a, b)
 
-        fn = Model().cuda()
-        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K), (K, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((K, N), (1, K), device=GPU_TYPE, dtype=torch.float32)
         # TODO: Getting the alignment right requires pattern matcher to
         # run on newly added nodes
         aligned_m = get_padded_length(M, get_alignment_size(a)) + M
@@ -168,9 +168,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.mm(a, b)
 
-        fn = Model().cuda()
-        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
-        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = rand_strided((M, K), (K, 1), device=GPU_TYPE, dtype=torch.float32)
+        b = rand_strided((K, N), (1, K), device=GPU_TYPE, dtype=torch.float32)
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(a, 1)
         torch._dynamo.mark_dynamic(b, 0)
@@ -188,9 +188,9 @@ def test_zero_dim(self):
         def addmm(x, a, b):
             return torch.addmm(x, a, b)
 
-        x = torch.randn(100).cuda()
-        a = torch.randn(0, 10).cuda()
-        b = torch.randn(10, 100).cuda()
+        x = torch.randn(100).to(GPU_TYPE)
+        a = torch.randn(0, 10).to(GPU_TYPE)
+        b = torch.randn(10, 100).to(GPU_TYPE)
         self.assertEqual(torch.compile(addmm)(x, a, b), addmm(x, a, b))
 
     @inductor_config.patch(
@@ -209,9 +209,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.bmm(a, b)
 
-        fn = Model().cuda()
-        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
-        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(B, M, K, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(B, K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K, get_alignment_size(a)) + K
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(b, 0)
@@ -240,9 +240,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.bmm(a, b)
 
-        fn = Model().cuda()
-        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
-        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(B, M, K, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(B, K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_n = get_padded_length(N, get_alignment_size(b)) + N
         torch._dynamo.mark_dynamic(a, 2)
         torch._dynamo.mark_dynamic(b, 1)
@@ -271,9 +271,9 @@ def __init__(self) -> None:
             def forward(self, a, b):
                 return torch.bmm(a, b)
 
-        fn = Model().cuda()
-        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
-        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(B, M, K, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(B, K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_n = get_padded_length(N, get_alignment_size(b)) + N
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(a, 1)
@@ -302,10 +302,10 @@ def __init__(self) -> None:
             def forward(self, a, b, c):
                 return torch.addmm(a, b, c)
 
-        fn = Model().cuda()
-        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
-        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
-        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(M, N, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(M, K, device=GPU_TYPE, dtype=torch.float32)
+        c = torch.randn(K, N, device=GPU_TYPE, dtype=torch.float32)
         aligned_k = get_padded_length(K, get_alignment_size(b)) + K
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(b, 0)
@@ -333,10 +333,10 @@ def __init__(self) -> None:
             def forward(self, a, b, c):
                 return torch.addmm(a, b, c)
 
-        fn = Model().cuda()
-        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
-        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
-        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
+        fn = Model().to(GPU_TYPE)
+        a = torch.randn(M, N, device=GPU_TYPE, dtype=torch.float32)
+        b = torch.randn(M, K, device=GPU_TYPE, dtype=torch.float32)
+        c = torch.randn(K, N, device=GPU_TYPE, dtype=torch.float32)
         torch._dynamo.mark_dynamic(a, 0)
         torch._dynamo.mark_dynamic(a, 1)
         torch._dynamo.mark_dynamic(b, 0)
@@ -357,7 +357,7 @@ def test_pad_single_cat(self):
         def foo(x, y):
             return x @ y
 
-        inps = [torch.rand([5, 5], device="cuda") for _ in range(2)]
+        inps = [torch.rand([5, 5], device=GPU_TYPE) for _ in range(2)]
         out = foo(*inps)
         self.assertEqual(out, inps[0] @ inps[1])
 
@@ -371,19 +371,19 @@ def foo(input, x, y):
         for a in [1, 4]:
             for b in [1, 6]:
                 inps = (
-                    torch.rand([a, b], device="cuda"),
-                    torch.rand([4, 5], device="cuda"),
-                    torch.rand([5, 6], device="cuda"),
+                    torch.rand([a, b], device=GPU_TYPE),
+                    torch.rand([4, 5], device=GPU_TYPE),
+                    torch.rand([5, 6], device=GPU_TYPE),
                 )
                 out = foo(*inps)
                 out_eager = torch.ops.aten.addmm(*inps)
                 self.assertEqual(out, out_eager)
 
         for a in [1, 6]:
             inps = (
-                torch.rand([a], device="cuda"),
-                torch.rand([4, 5], device="cuda"),
-                torch.rand([5, 6], device="cuda"),
+                torch.rand([a], device=GPU_TYPE),
+                torch.rand([4, 5], device=GPU_TYPE),
+                torch.rand([5, 6], device=GPU_TYPE),
             )
             out = foo(*inps)
             out_eager = torch.ops.aten.addmm(*inps)
@@ -395,8 +395,8 @@ def test_pad_batch(self):
         n = 9
         k = 11
         batch_size = 3
-        mat1 = torch.ones((batch_size, m, k), device="cuda", dtype=torch.float16)
-        mat2 = torch.ones((batch_size, k, n), device="cuda", dtype=torch.float16)
+        mat1 = torch.ones((batch_size, m, k), device=GPU_TYPE, dtype=torch.float16)
+        mat2 = torch.ones((batch_size, k, n), device=GPU_TYPE, dtype=torch.float16)
         expected_alignment = get_alignment_size(mat1)
 
         assert expected_alignment == 8, "Alignment for float16 should be 8"
@@ -413,7 +413,7 @@ def bmm(mat1, mat2):
         # in call code, expect to see a single pad per input, and then we should see padded allocation for output
         FileCheck().check("del async_compile").check_count(
             ".run(", 2, exactly=True
-        ).check("empty_strided_cuda((3, 8, 16)").run(code)
+        ).check(f"empty_strided_{GPU_TYPE}((3, 8, 16)").run(code)
 
         assert torch.allclose(res2, bmm_expected_result), (
             "BMM results are not identical"
@@ -425,7 +425,7 @@ def test_exclude_padding(self):
         def mm(a, b):
            return a @ b
 
-        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
+        mm(torch.rand([25, 25], device=GPU_TYPE), torch.rand([25, 25], device=GPU_TYPE))
         local_cache = get_pad_cache().get_local_cache()
         self.assertTrue(len(local_cache) == 2)
         FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
@@ -436,7 +436,7 @@ def mm(a, b):
         def mm(a, b):
             return (a + 1) @ b
 
-        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
+        mm(torch.rand([25, 25], device=GPU_TYPE), torch.rand([25, 25], device=GPU_TYPE))
         local_cache = get_pad_cache().get_local_cache()
         # reuse original base timing
         self.assertTrue(len(local_cache) == 3)
@@ -455,8 +455,8 @@ def test_exclude_cat_padding(self):
         def mm(inps, b):
             return torch.cat(inps) @ b
 
-        inp = torch.rand([2046, 2046], device="cuda")
-        inp2 = torch.rand([2046, 2046], device="cuda")
+        inp = torch.rand([2046, 2046], device=GPU_TYPE)
+        inp2 = torch.rand([2046, 2046], device=GPU_TYPE)
 
         inps = inp.chunk(3)
         mm(inps, inp2)
@@ -471,7 +471,8 @@ def mm(inps, b):
         )
 
     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() >= (9, 0),
+        (not torch.cuda.is_available() or torch.cuda.get_device_capability() >= (9, 0))
+        and (not torch.xpu.is_available()),
         "No perf regression on H100+ with BF16",
     )
     @skipIfRocm
@@ -483,8 +484,8 @@ def test_pad_mm_bf16(self):
         m = 2
         n = 13
         k = 15691904
-        mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16)
-        mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16)
+        mat1 = torch.ones((m, k), device=GPU_TYPE, dtype=torch.bfloat16)
+        mat2 = torch.ones((k, n), device=GPU_TYPE, dtype=torch.bfloat16)
         expected_alignment = get_alignment_size(mat1)
 
         assert expected_alignment == 8, "Alignment for bfloat16 should be 8"
@@ -504,7 +505,7 @@ def mm(mat1, mat2):
         # in call code, expect to see a single pad per input, and then we should see padded allocation for output
         FileCheck().check("del async_compile").check_count(
             ".run(", 2, exactly=True
-        ).check("empty_strided_cuda((8, 16)").run(code)
+        ).check(f"empty_strided_{GPU_TYPE}((8, 16)").run(code)
 
         assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
 
@@ -521,8 +522,8 @@ def fn(x, y):
             return x @ y
 
         args = [
-            torch.randn(2**4, 2**8 - 1, device="cuda", dtype=torch.float16),
-            torch.randn(2**8 - 1, 2**4, device="cuda", dtype=torch.float16),
+            torch.randn(2**4, 2**8 - 1, device=GPU_TYPE, dtype=torch.float16),
+            torch.randn(2**8 - 1, 2**4, device=GPU_TYPE, dtype=torch.float16),
         ]
 
         counters.clear()
@@ -615,7 +616,7 @@ def test_masked_mha(B, H, S, D, device, dtype):
             ):
                 mha = torch.compile(mha, fullgraph=True, backend="inductor")
                 with torch.autocast(
-                    device_type="cuda", dtype=dtype, cache_enabled=False
+                    device_type=GPU_TYPE, dtype=dtype, cache_enabled=False
                 ):
                     out_vid = mha(x1, x2, attn_mask)
                     target_vid = torch.randn_like(out_vid)
@@ -624,7 +625,7 @@ def test_masked_mha(B, H, S, D, device, dtype):
                     loss = loss_vid
                     loss.backward()
 
-            torch.cuda.synchronize()
+            torch.accelerator.synchronize()
 
             # Check if any bmm operations had dtype changes
             for node_name_pre, node_name_post in zip(
@@ -642,13 +643,13 @@ def test_masked_mha(B, H, S, D, device, dtype):
            self.assertFalse(torch.any(x2.grad.isnan()).item())
 
        B, H, S, D = 2, 32, 549, 128
-       device = "cuda"
+       device = GPU_TYPE
        dtype = torch.bfloat16
        torch.compiler.reset()
        torch.manual_seed(42)
        test_masked_mha(B, H, S, D, device, dtype)
 
 
 if __name__ == "__main__":
-    if HAS_CUDA_AND_TRITON:
+    if HAS_GPU_AND_TRITON:
         run_tests()
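
The pattern applied throughout the diff is to replace hard-coded "cuda" strings and CUDA-only calls with the device-agnostic helpers from torch.testing._internal.inductor_utils. Below is a minimal, self-contained sketch of that pattern, assuming a recent PyTorch build where GPU_TYPE resolves to "cuda" or "xpu" and torch.accelerator.synchronize() is available; the helper function name is illustrative and not part of the patch.

import torch
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON


def _device_agnostic_mm():
    # Allocate on whichever GPU backend the test environment provides,
    # instead of hard-coding device="cuda".
    a = torch.randn(8, 16, device=GPU_TYPE, dtype=torch.float32)
    b = torch.randn(16, 8, device=GPU_TYPE, dtype=torch.float32)
    out = torch.compile(torch.mm)(a, b)
    # torch.accelerator.synchronize() covers both CUDA and XPU,
    # replacing the CUDA-only torch.cuda.synchronize().
    torch.accelerator.synchronize()
    return out


if HAS_GPU_AND_TRITON:
    _device_agnostic_mm()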