1010 for env_var in [True , False ]\
1111])
1212@pytest .mark .forked
13- def test_device_assert (cond , opt_flag , env_var , device = "cuda" ):
13+ def test_device_assert (cond , opt_flag , env_var , device ):
1414 os .environ ['TRITON_DEBUG' ] = str (int (env_var ))
1515 torch .zeros ([1 ], dtype = torch .int32 , device = device )
1616
@@ -21,11 +21,11 @@ def _kernel(COND: tl.constexpr):
2121 if not cond and (opt_flag or env_var ):
2222 with pytest .raises (RuntimeError ):
2323 _kernel [(1 , )](cond , debug = opt_flag )
24- torch . cuda .synchronize ()
24+ getattr ( torch , device ) .synchronize ()
2525 return
2626
2727 _kernel [(1 , )](cond , debug = opt_flag )
28- torch . cuda .synchronize ()
28+ getattr ( torch , device ) .synchronize ()
2929
3030
3131@pytest .mark .parametrize ("cond" , [False , True ])
@@ -43,19 +43,18 @@ def _kernel(COND: tl.constexpr):
4343 _kernel [(1 , )](cond )
4444
4545
def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func, device):
    """Launch a binary integer kernel and check the overflow-sanitizer outcome.

    Args:
        x, y: Python scalars; each is wrapped in a one-element tensor.
        x_dtype, y_dtype: Names of ``torch`` dtypes (looked up with
            ``getattr(torch, ...)``) used for the two input tensors.
        debug: Forwarded to the kernel launch as ``debug=debug``.
        should_overflow: Whether the operation is expected to overflow.
        tri_func: Kernel launched on a ``(1, )`` grid with ``(x, y, z)``.
        ref_func: Reference implementation applied to the same input tensors.
        device: Device string for tensor creation. It must also name a
            ``torch`` submodule exposing ``synchronize()`` (e.g. ``"cuda"``),
            because it is resolved with ``getattr(torch, device)``.

    When both ``should_overflow`` and ``debug`` are true, the launch plus
    synchronize must raise ``RuntimeError`` whose message contains
    ``"device-side assert"``. Otherwise the launch must succeed and the
    output must equal ``ref_func(x, y)`` after conversion to ``int``.
    """
    x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device)
    y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device)
    z = torch.empty_like(x)
    if should_overflow and debug:
        with pytest.raises(RuntimeError) as exc_info:
            tri_func[(1, )](x, y, z, debug=debug)
            # Kept inside the `raises` context: presumably the device-side
            # failure is only reported once the device is synchronized.
            getattr(torch, device).synchronize()
        assert "device-side assert" in str(exc_info.value)
    else:
        tri_func[(1, )](x, y, z, debug=debug)
        getattr(torch, device).synchronize()
        assert int(z) == int(ref_func(x, y))
6059
6160
@@ -74,13 +73,13 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref
7473 (2 ** 15 - 1 , 1 , 'int16' , 'int16' , True , True ),
7574])
7675@pytest .mark .forked
def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):
    """Exercise the overflow check for integer addition.

    Defines a kernel that stores ``load(X) + load(Y)`` into ``Z`` and hands
    it to ``_test_overflow`` together with the reference ``x + y``. All
    parameters are forwarded unchanged; ``device`` selects where the tensors
    live and which backend is synchronized.
    """

    @triton.jit
    def _kernel_add(X, Y, Z):
        # Store the sum of the values loaded from X and Y into Z.
        tl.store(Z, tl.load(X) + tl.load(Y))

    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y, device)
8483
8584
8685# mul overflow
@@ -95,13 +94,13 @@ def _kernel_add(X, Y, Z):
9594 (- 2 ** 30 , 2 , 'int32' , 'int32' , True , False ),
9695])
9796@pytest .mark .forked
def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):
    """Exercise the overflow check for integer multiplication.

    Defines a kernel that stores ``load(X) * load(Y)`` into ``Z`` and hands
    it to ``_test_overflow`` together with the reference ``x * y``. All
    parameters are forwarded unchanged; ``device`` selects where the tensors
    live and which backend is synchronized.
    """

    @triton.jit
    def _kernel_mul(X, Y, Z):
        # Store the product of the values loaded from X and Y into Z.
        tl.store(Z, tl.load(X) * tl.load(Y))

    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y, device)
105104
106105
107106# sub overflow
@@ -115,10 +114,10 @@ def _kernel_mul(X, Y, Z):
115114 (- 2 ** 31 , - 1 , 'int32' , 'int32' , True , False ),
116115])
117116@pytest .mark .forked
def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):
    """Exercise the overflow check for integer subtraction.

    Defines a kernel that stores ``load(X) - load(Y)`` into ``Z`` and hands
    it to ``_test_overflow`` together with the reference ``x - y``. All
    parameters are forwarded unchanged; ``device`` selects where the tensors
    live and which backend is synchronized.
    """

    @triton.jit
    def _kernel_sub(X, Y, Z):
        # Store the difference of the values loaded from X and Y into Z.
        tl.store(Z, tl.load(X) - tl.load(Y))

    # `_test_overflow` takes (..., debug, should_overflow, ...) positionally.
    # Passing them in the opposite order would launch the kernel with
    # `debug=should_overflow`, so keep the same order as the add/mul tests.
    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_sub, lambda x, y: x - y, device)
0 commit comments