@@ -1587,6 +1587,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
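These three keywords let a caller pin the warp count and pass AMD-specific compiler hints through, instead of relying on the heuristics below. A minimal usage sketch (the size hints and values are illustrative, not tuned recommendations):

    # Sketch: pin 8 warps and request 2 waves per EU (the latter only takes
    # effect on ROCm builds, see the torch.version.hip guard further down)
    cfg = triton_config(
        {"x": 8192},    # size_hints
        8192,           # x block size
        num_warps=8,    # skips the _num_warps() heuristic below
        waves_per_eu=2,
    )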
@@ -1643,9 +1646,11 @@ def triton_config(
     ):
         z *= 2

-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not explicitly passed in
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
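For scale, the fallback divides the block volume by num_elements_per_warp and clamps at min_num_warps; illustrative arithmetic:

    # Fallback only runs when num_warps is None:
    # conditional_product(1024, None, None) == 1024; 1024 // 256 == 4 -> 4 warps
    # conditional_product(64, None, None)   == 64;   64 // 256 == 0   -> clamped to min_num_warps == 1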
@@ -1675,7 +1680,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
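On a HIP build the extra fields ride along in Config.kwargs next to the block sizes; matrix_instr_nonkdim and waves_per_eu are AMD-backend tuning knobs that the NVIDIA backend does not accept, which is what the torch.version.hip guard protects against. A quick way to inspect what was attached (sketch; the printed output is illustrative):

    cfg = triton_config({"x": 4096}, 4096, matrix_instr=16, waves_per_eu=2)
    print(cfg.kwargs)
    # On ROCm, e.g.: {"XBLOCK": 4096, "matrix_instr_nonkdim": 16, "waves_per_eu": 2}
    # On CUDA the two extra keys are never attached.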
@@ -1723,6 +1736,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -1766,7 +1780,13 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config


 def _get_config(numels: dict[str, int]) -> dict[str, int]:
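The reduction variant forwards only waves_per_eu. A sketch, assuming the (size_hints, x, r0_, ...) calling convention suggested by _get_nd_reduction_numels above (hints and block sizes are hypothetical):

    cfg = triton_config_reduction(
        {"x": 1024, "r0_": 2048},  # hypothetical size hints
        64,                        # x block
        64,                        # r0_ block
        waves_per_eu=2,            # lands in cfg.kwargs only under torch.version.hip
    )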
@@ -1854,6 +1874,12 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, bs // 2, num_elements_per_warp=64
             ),
+            # triton_config_with_settings(
+            #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+            # ),
+            triton_config_with_settings(
+                size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+            ),
             *hinted_configs,
         ]
     if len(size_hints) == 2:
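Design note on the new pointwise candidate: because triton_config attaches waves_per_eu only under torch.version.hip, this entry degenerates to a plain max-X-block config on CUDA, so adding it unconditionally only widens the autotune search there rather than changing behavior. Illustrative effect on the candidate list:

    # CUDA: Config({"XBLOCK": TRITON_MAX_BLOCK["X"]}, num_warps=..., num_stages=1)
    # ROCm: same, plus kwargs {"waves_per_eu": 2}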
@@ -1985,6 +2011,24 @@ def reduction(
     assert triton_meta is not None

     configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
+
+    # Additional tuning configs for ROCm builds.
+    # TODO: add checks for the reduction autotuning bools before enabling this.
+    # if torch.version.hip and inductor_meta.get("max_autotune"):
+    #     configs = [
+    #         triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
+    #         triton_config_with_settings(
+    #             size_hints, bs // 2, num_elements_per_warp=64
+    #         ),
+    #         # triton_config_with_settings(
+    #         #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
+    #         # ),
+    #         triton_config_with_settings(
+    #             size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+    #         ),
+    #         *hinted_configs,
+    #     ]
+
     return cached_autotune(
         size_hints,
         configs=configs,
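The disabled block references bs and hinted_configs, which appear to be locals of pointwise() rather than reduction(), and triton_config_with_settings builds pointwise configs; both would need reworking before the block could be enabled, which is presumably why it is left commented out. A guarded sketch using only names in scope here (block sizes hypothetical, not tuned):

    # Sketch: append (rather than replace) ROCm-only reduction candidates
    if torch.version.hip and inductor_meta.get("max_autotune"):
        configs += [
            triton_config_reduction(size_hints, 64, 64, waves_per_eu=2),
        ]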