55from triton_kernels .target_info import get_cdna_version
66import torch
77from .opt_flags_details import opt_flags_amd , opt_flags_nvidia
8+ from triton_kernels .tensor import bitwidth
89
910
1011@dataclass
@@ -80,15 +81,10 @@ def make_default_opt_flags_amd(
8081 num_xcds = 8
8182 xcd_swizzle = num_xcds
8283 # block_nk:
84+ # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
8385 block_n , block_k = opt_flags_amd .compute_block_nk (
8486 n , block_m , grid_m , num_xcds , lhs_dtype , rhs_dtype , precision_config
8587 )
86- # Replace block_k if provided in constraints.
87- # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
88- if constraints .get ("block_k" , None ) is not None :
89- block_k = constraints ["block_k" ]
90- if constraints .get ("block_n" , None ) is not None :
91- block_n = constraints ["block_n" ]
9288 is_persistent = constraints .get ("is_persistent" , False )
9389 # split_k:
9490 if constraints .get ("split_k" , None ) is not None :
@@ -109,10 +105,33 @@ def make_default_opt_flags_amd(
109105 epilogue_subtile = constraints .get ('epilogue_subtile' , None )
110106 if epilogue_subtile is None :
111107 epilogue_subtile = 1
108+
109+ # specific configs for F16 x MXFP4 on CDNA4
110+ # Note that these configs will exceed LDS usage with async copy enabled
111+ if is_cdna4 and bitwidth (lhs_dtype ) == 16 and bitwidth (rhs_dtype ) == 4 and precision_config .weight_scale is not None :
112+ split_k = 1
113+ if m <= 1024 :
114+ target_kernel_kwargs ["waves_per_eu" ] = 3
115+ block_n = 128
116+ block_k = 256
117+ num_warps = 4
118+ else :
119+ target_kernel_kwargs ["waves_per_eu" ] = 0
120+ block_m = 64
121+ block_n = 512
122+ block_k = 256
123+ num_warps = 8
124+
125+ def replace_with_valid_constraint (k : str , v ):
126+ if constraints .get (k , None ) is not None :
127+ return constraints [k ]
128+ else :
129+ return v
130+
112131 ret = OptFlags (
113- block_m = block_m ,
114- block_n = block_n ,
115- block_k = block_k ,
132+ block_m = replace_with_valid_constraint ( ' block_m' , block_m ) ,
133+ block_n = replace_with_valid_constraint ( ' block_n' , block_n ) ,
134+ block_k = replace_with_valid_constraint ( ' block_k' , block_k ) ,
116135 num_warps = num_warps ,
117136 num_stages = num_stages ,
118137 group_m = group_m ,
0 commit comments