@@ -1630,6 +1630,9 @@ def triton_config(
16301630 num_stages = 1 ,
16311631 num_elements_per_warp = 256 ,
16321632 min_elem_per_thread = 0 ,
1633+ num_warps = None ,
1634+ matrix_instr = None ,
1635+ waves_per_eu = None ,
16331636) -> Config :
16341637 """
16351638 Construct a pointwise triton config with some adjustment heuristics
@@ -1686,9 +1689,11 @@ def triton_config(
16861689 ):
16871690 z *= 2
16881691
1689- num_warps = _num_warps (
1690- conditional_product (x , y , z ) // num_elements_per_warp , min_num_warps = 1
1691- )
1692+ # Compute num_warps heuristically only when the caller did not pass an explicit value
1693+ if num_warps is None :
1694+ num_warps = _num_warps (
1695+ conditional_product (x , y , z ) // num_elements_per_warp , min_num_warps = 1
1696+ )
16921697 # we are going to arrive at 2 warps only if bs was too small due to
16931698 # numel being too small. However to workaround some ptx bugs we still
16941699 # want at least 4 warps if there's enough elements per thread
@@ -1718,7 +1723,15 @@ def triton_config(
17181723 cfg ["ZBLOCK" ] = z
17191724 check_max_block (cfg )
17201725 check_config (cfg , xnumel = xnumel , ynumel = ynumel , znumel = znumel )
1721- return Config (cfg , num_warps = num_warps , num_stages = num_stages )
1726+ config = Config (cfg , num_warps = num_warps , num_stages = num_stages )
1727+
1728+ if torch .version .hip :
1729+ if matrix_instr is not None :
1730+ config .kwargs ["matrix_instr_nonkdim" ] = matrix_instr
1731+ if waves_per_eu is not None :
1732+ config .kwargs ["waves_per_eu" ] = waves_per_eu
1733+
1734+ return config
17221735
17231736
17241737def _get_nd_reduction_numels (r : int , size_hints : dict [str , int ]) -> dict [str , int ]:
0 commit comments