@@ -2163,6 +2163,9 @@ def triton_config(
21632163 num_stages = 1 ,
21642164 num_elements_per_warp = 256 ,
21652165 min_elem_per_thread = 0 ,
2166+ num_warps = None ,
2167+ matrix_instr = None ,
2168+ waves_per_eu = None ,
21662169) -> Config :
21672170 """
21682171 Construct a pointwise triton config with some adjustment heuristics
@@ -2219,9 +2222,11 @@ def triton_config(
22192222 ):
22202223 z *= 2
22212224
2222- num_warps = _num_warps (
2223- conditional_product (x , y , z ) // num_elements_per_warp , min_num_warps = 1
2224- )
2225+ # Calculate num_warps if they are not hard passed to config
2226+ if num_warps is None :
2227+ num_warps = _num_warps (
2228+ conditional_product (x , y , z ) // num_elements_per_warp , min_num_warps = 1
2229+ )
22252230 # we are going to arrive at 2 warps only if bs was too small due to
22262231 # numel being too small. However to workaround some ptx bugs we still
22272232 # want at least 4 warps if there's enough elements per thread
@@ -2251,7 +2256,15 @@ def triton_config(
22512256 cfg ["ZBLOCK" ] = z
22522257 check_max_block (cfg )
22532258 check_config (cfg , xnumel = xnumel , ynumel = ynumel , znumel = znumel )
2254- return Config (cfg , num_warps = num_warps , num_stages = num_stages )
2259+ config = Config (cfg , num_warps = num_warps , num_stages = num_stages )
2260+
2261+ if torch .version .hip :
2262+ if matrix_instr is not None :
2263+ config .kwargs ["matrix_instr_nonkdim" ] = matrix_instr
2264+ if waves_per_eu is not None :
2265+ config .kwargs ["waves_per_eu" ] = waves_per_eu
2266+
2267+ return config
22552268
22562269
22572270def _get_nd_reduction_numels (r : int , size_hints : dict [str , int ]) -> dict [str , int ]:
0 commit comments