@@ -2441,7 +2441,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
24412441negative_config = [('cumsum' , 'float32' , (32 , 32 ), - 1 , False , 4 )]
24422442
24432443
2444- def test_sum_dtype ():
2444+ def test_sum_dtype (device ):
24452445
24462446 @triton .jit
24472447 def kernel_dtype (out_ptr , init , in_dtype : tl .constexpr , out_dtype : tl .constexpr ):
@@ -2461,7 +2461,7 @@ def kernel_default_float(out_ptr):
24612461 x = tl .sum (x )
24622462 tl .store (out_ptr , x )
24632463
2464- out = torch .empty (1 , dtype = torch .int32 , device = 'cuda' )
2464+ out = torch .empty (1 , dtype = torch .int32 , device = device )
24652465 kernel_dtype [(1 , )](out , init = 1 , in_dtype = tl .int1 , out_dtype = None )
24662466 assert out [0 ] == 32 * 32
24672467
@@ -2477,9 +2477,9 @@ def kernel_default_float(out_ptr):
24772477 kernel_default_int [(1 , )](out )
24782478 assert out [0 ] == 32 * 32
24792479
2480- out = torch .empty (1 , dtype = torch .bfloat16 , device = 'cuda' )
2480+ out = torch .empty (1 , dtype = torch .bfloat16 , device = device )
24812481 kernel_default_float [(1 , )](out )
2482- torch .testing .assert_close (out [0 ], torch .tensor (32 * 32 , dtype = torch .bfloat16 , device = 'cuda' ))
2482+ torch .testing .assert_close (out [0 ], torch .tensor (32 * 32 , dtype = torch .bfloat16 , device = device ))
24832483
24842484
24852485@triton .jit
@@ -2675,16 +2675,16 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
26752675
26762676
26772677@pytest .mark .parametrize ("M, N" , [(1 , 64 ), (2 , 32 ), (4 , 16 ), (8 , 8 ), (16 , 4 ), (32 , 2 ), (64 , 1 )])
2678- def test_scan_1d (M , N ):
2678+ def test_scan_1d (M , N , device ):
26792679
26802680 @triton .jit
26812681 def scan_kernel (out_ptr , in_ptr , M : tl .constexpr , N : tl .constexpr ):
26822682 input = tl .load (in_ptr + tl .arange (0 , M ))
26832683 output = tl .cumsum (input ).reshape ([1 , M ]).broadcast_to ([N , M ])
26842684 tl .store (out_ptr + tl .arange (0 , M * N ), output .reshape ([M * N ]))
26852685
2686- x = torch .randint (- 100 , 100 , (M , ), dtype = torch .int32 , device = 'cuda' )
2687- output = torch .empty (M * N , dtype = torch .int32 , device = 'cuda' )
2686+ x = torch .randint (- 100 , 100 , (M , ), dtype = torch .int32 , device = device )
2687+ output = torch .empty (M * N , dtype = torch .int32 , device = device )
26882688
26892689 scan_kernel [(1 , )](output , x , M , N )
26902690
@@ -4813,14 +4813,14 @@ def kernel():
48134813
48144814
48154815@pytest .mark .interpreter
4816- def test_tma_load_block_shape_err ():
4816+ def test_tma_load_block_shape_err (device ):
48174817
48184818 @triton .jit
48194819 def kernel (ptr ):
48204820 desc = tl ._experimental_make_tensor_descriptor (ptr , [128 , 128 ], [128 , 1 ], [1 , 32 ])
48214821 desc .load ([0 , 0 ])
48224822
4823- input = torch .empty ((128 , 128 ), dtype = torch .int32 , device = 'cuda' )
4823+ input = torch .empty ((128 , 128 ), dtype = torch .int32 , device = device )
48244824 errc = triton .CompilationError if not is_interpreter () else InterpreterError
48254825 with pytest .raises (errc ) as e :
48264826 kernel [(1 , )](input )
@@ -4829,14 +4829,14 @@ def kernel(ptr):
48294829
48304830
48314831@pytest .mark .interpreter
4832- def test_tma_store_block_shape_err ():
4832+ def test_tma_store_block_shape_err (device ):
48334833
48344834 @triton .jit
48354835 def kernel (ptr ):
48364836 desc = tl ._experimental_make_tensor_descriptor (ptr , [128 , 128 ], [128 , 1 ], [8 , 8 ])
48374837 desc .store ([0 , 0 ], tl .zeros ((1 , 32 ), dtype = tl .int16 ))
48384838
4839- input = torch .empty ((128 , 128 ), dtype = torch .int16 , device = 'cuda' )
4839+ input = torch .empty ((128 , 128 ), dtype = torch .int16 , device = device )
48404840 errc = triton .CompilationError if not is_interpreter () else InterpreterError
48414841 with pytest .raises (errc ) as e :
48424842 kernel [(1 , )](input )
0 commit comments