@@ -7134,11 +7134,13 @@ def mul_add(data):
7134
7134
# -----------------------
7135
7135
7136
7136
7137
- @pytest .mark .parametrize ("arch" , ["sm70" , "sm80" , "sm90" ])
7137
+ @pytest .mark .parametrize ("arch" , ["sm70" , "sm80" , "sm90" , "gfx942" , "gfx950" , "gfx1200" ])
7138
7138
@pytest .mark .parametrize ("env_var_override" , [False , True ])
7139
7139
def test_override_arch (arch , env_var_override , device ):
7140
- if not is_cuda ():
7141
- pytest .skip ('arch only for CUDA' )
7140
+ if arch .startswith ("sm" ) and not is_cuda ():
7141
+ pytest .skip (f"{ arch } arch only for CUDA" )
7142
+ elif arch .startswith ("gfx" ) and not is_hip ():
7143
+ pytest .skip (f"{ arch } arch only for HIP" )
7142
7144
7143
7145
@triton .jit
7144
7146
def simple (data , out ):
@@ -7149,15 +7151,31 @@ def simple(data, out):
7149
7151
data = torch .randn ((128 , ), device = device , dtype = torch .float32 )
7150
7152
out = torch .empty_like (data )
7151
7153
7152
- if env_var_override :
7153
- os .environ ["TRITON_OVERRIDE_ARCH" ] = str (arch )
7154
- h = simple [(1 , )](data , out )
7155
- os .environ .pop ("TRITON_OVERRIDE_ARCH" )
7156
- else :
7157
- h = simple [(1 , )](data , out , arch = arch )
7158
- torch .testing .assert_close (data * 1.5 + 1.0 , out )
7159
- ttgir_cc = re .search (r'cuda:(\d+)' , h .asm ["ttgir" ])
7160
- assert ttgir_cc .group (1 ) == arch [2 :]
7154
+ if is_cuda ():
7155
+ if env_var_override :
7156
+ os .environ ["TRITON_OVERRIDE_ARCH" ] = str (arch )
7157
+ h = simple [(1 , )](data , out )
7158
+ os .environ .pop ("TRITON_OVERRIDE_ARCH" )
7159
+ else :
7160
+ h = simple [(1 , )](data , out , arch = arch )
7161
+ torch .testing .assert_close (data * 1.5 + 1.0 , out )
7162
+ ttgir_cc = re .search (r'cuda:(\d+)' , h .asm ["ttgir" ])
7163
+ assert ttgir_cc .group (1 ) == arch [2 :]
7164
+ elif is_hip ():
7165
+ # For HIP, the generated kernel is a binary containing the final ISA. So we cannot run
7166
+ # them like CUDA side if the chip doesn't match. Here we just check generated ISA.
7167
+ if env_var_override :
7168
+ os .environ ["TRITON_OVERRIDE_ARCH" ] = str (arch )
7169
+ h = simple .warmup (data , out , grid = (1 , ))
7170
+ os .environ .pop ("TRITON_OVERRIDE_ARCH" )
7171
+ else :
7172
+ h = simple .warmup (data , out , arch = arch , grid = (1 , ))
7173
+ ttgir_gfx = re .search (r'hip:(\w+)' , h .asm ["ttgir" ])
7174
+ ttgir_warp = re .search (r'"ttg.threads-per-warp" = (\d+)' , h .asm ["ttgir" ])
7175
+ amdgcn_gfx = re .search (r'.amdgcn_target "amdgcn-amd-amdhsa--(\w+)"' , h .asm ["amdgcn" ])
7176
+ assert ttgir_gfx .group (1 ) == arch
7177
+ assert int (ttgir_warp .group (1 )) == (32 if arch == "gfx1200" else 64 )
7178
+ assert amdgcn_gfx .group (1 ) == arch
7161
7179
7162
7180
7163
7181
# -----------------------
0 commit comments