@@ -2071,6 +2071,9 @@ def triton_config(
20712071 num_stages = 1 ,
20722072 num_elements_per_warp = 256 ,
20732073 min_elem_per_thread = 0 ,
2074+ num_warps = None ,
2075+ matrix_instr = None ,
2076+ waves_per_eu = None
20742077) -> Config :
20752078 """
20762079 Construct a pointwise triton config with some adjustment heuristics
@@ -2127,9 +2130,11 @@ def triton_config(
21272130 ):
21282131 z *= 2
21292132
2130- num_warps = _num_warps (
2131- conditional_product (x , y , z ) // num_elements_per_warp , min_num_warps = 1
2132- )
2133+ # Calculate num_warps if it was not explicitly passed to the config
2134+ if num_warps == None :
2135+ num_warps = _num_warps (
2136+ conditional_product (x , y , z ) // num_elements_per_warp , min_num_warps = 1
2137+ )
21332138 # we are going to arrive at 2 warps only if bs was too small due to
21342139 # numel being too small. However to workaround some ptx bugs we still
21352140 # want at least 4 warps if there's enough elements per thread
@@ -2159,7 +2164,15 @@ def triton_config(
21592164 cfg ["ZBLOCK" ] = z
21602165 check_max_block (cfg )
21612166 check_config (cfg , xnumel = xnumel , ynumel = ynumel , znumel = znumel )
2162- return Config (cfg , num_warps = num_warps , num_stages = num_stages )
2167+ config = Config (cfg , num_warps = num_warps , num_stages = num_stages )
2168+
2169+ if torch .version .hip :
2170+ if matrix_instr is not None :
2171+ config .kwargs ["matrix_instr_nonkdim" ] = matrix_instr
2172+ if waves_per_eu is not None :
2173+ config .kwargs ["waves_per_eu" ] = waves_per_eu
2174+
2175+ return config
21632176
21642177
21652178def _get_nd_reduction_numels (r : int , size_hints : dict [str , int ]) -> dict [str , int ]:
@@ -2207,6 +2220,7 @@ def triton_config_reduction(
22072220 num_stages = 1 ,
22082221 num_warps = None ,
22092222 register_intensive = False ,
2223+ waves_per_eu = None
22102224) -> Config :
22112225 """
22122226 Construct a reduction triton config with some adjustment heuristics
@@ -2250,7 +2264,13 @@ def total_numel() -> int:
22502264 cfg = _get_config ({"x" : x , ** rnumels })
22512265 check_max_block (cfg )
22522266 check_config (cfg , xnumel = size_hints ["x" ])
2253- return Config (cfg , num_warps = num_warps , num_stages = num_stages )
2267+ config = Config (cfg , num_warps = num_warps , num_stages = num_stages )
2268+
2269+ if torch .version .hip :
2270+ if waves_per_eu is not None :
2271+ config .kwargs ["waves_per_eu" ] = waves_per_eu
2272+
2273+ return config
22542274
22552275
22562276def _get_config (numels : dict [str , int ]) -> dict [str , int ]:
@@ -2388,6 +2408,12 @@ def pointwise(
23882408 triton_config_with_settings (
23892409 size_hints , bs // 2 , num_elements_per_warp = 64
23902410 ),
2411+ # triton_config_with_settings(
2412+ # size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
2413+ # ),
2414+ triton_config_with_settings (
2415+ size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
2416+ ),
23912417 * hinted_configs ,
23922418 ]
23932419 if len (size_hints ) == 2 :
@@ -2624,6 +2650,7 @@ def reduction(
26242650
26252651 configs = _reduction_configs (size_hints = size_hints , inductor_meta = inductor_meta )
26262652 configs = _maybe_filter_configs_for_tma_restrictions (inductor_meta , configs )
2653+
26272654 return cached_autotune (
26282655 size_hints ,
26292656 configs = configs ,
0 commit comments