@@ -1987,6 +1987,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -2043,9 +2046,11 @@ def triton_config(
     ):
         z *= 2

-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not explicitly passed to the config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -2075,7 +2080,15 @@ def triton_config(
     cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
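For context, a minimal sketch of what the new ROCm branch produces (not part of the patch; the block size and hint values below are illustrative, while triton.Config and torch.version.hip are the real APIs the hunk relies on):

    # Illustrative only: how matrix_instr and waves_per_eu surface as extra
    # kwargs on a triton.Config under a ROCm build of PyTorch.
    import torch
    from triton import Config

    cfg = Config({"XBLOCK": 1024}, num_warps=8, num_stages=1)
    if torch.version.hip:  # a version string on ROCm builds, None otherwise
        cfg.kwargs["matrix_instr_nonkdim"] = 0  # MFMA instruction shape hint
        cfg.kwargs["waves_per_eu"] = 2          # occupancy hint per execution unit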
@@ -2123,6 +2136,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -2166,7 +2180,13 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -2259,6 +2279,12 @@ def pointwise(
         triton_config_with_settings(
             size_hints, bs // 2, num_elements_per_warp=64
         ),
+        # triton_config_with_settings(
+        #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+        # ),
+        triton_config_with_settings(
+            size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+        ),
         *hinted_configs,
     ]
     if len(size_hints) == 2:
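Roughly, the new (uncommented) entry adds a max-X-block pointwise candidate tagged for ROCm. A hand-built approximation for illustration (the XBLOCK and num_warps values are assumptions, since the real path goes through triton_config's block-size heuristics and the TRITON_MAX_BLOCK table):

    # Illustrative approximation of the new pointwise candidate on ROCm.
    from triton import Config

    cfg = Config({"XBLOCK": 2048}, num_warps=8, num_stages=1)  # 2048 assumes TRITON_MAX_BLOCK["X"]
    cfg.kwargs["waves_per_eu"] = 2  # the ROCm occupancy hint added by this entry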
@@ -2491,6 +2517,24 @@ def reduction(
     assert triton_meta is not None

     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
+
+    # Additional tuning configs for ROCm builds
+    # TODO: add checks for reduction autotuning bools before enabling
+    # if torch.version.hip and inductor_meta.get("max_autotune"):
+    #     configs = [
+    #         triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
+    #         triton_config_with_settings(
+    #             size_hints, bs // 2, num_elements_per_warp=64
+    #         ),
+    #         # triton_config_with_settings(
+    #         #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+    #         # ),
+    #         triton_config_with_settings(
+    #             size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+    #         ),
+    #         *hinted_configs,
+    #     ]
+
     return cached_autotune(
         size_hints,
         configs=configs,
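The commented-out block gates a wider ROCm search space behind max-autotune. A minimal sketch of that gating pattern (hedged: pick_reduction_configs is a hypothetical helper, not in the patch; torch.version.hip and inductor_meta.get("max_autotune") are the real checks used above):

    # Hypothetical helper showing the intended gating: only swap in the
    # ROCm-specific candidates when running a ROCm build with max-autotune.
    import torch

    def pick_reduction_configs(default_configs, rocm_configs, inductor_meta):
        if torch.version.hip and inductor_meta.get("max_autotune"):
            return rocm_configs
        return default_configs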