@@ -35,6 +35,7 @@ def make_default_opt_flags_amd(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -133,6 +134,7 @@ def make_default_opt_flags_nvidia(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -146,7 +148,7 @@ def make_default_opt_flags_nvidia(
     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
-    if routing_data is None:
+    if routing_data is None or batch_size > 1:
         tokens_per_expt = m
     elif routing_data.expected_tokens_per_expt is None:
         tokens_per_expt = max(1, m // routing_data.n_expts_tot)
@@ -164,11 +166,11 @@ def make_default_opt_flags_nvidia(
         block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
     # block n
     arch = None
-    block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
+    block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
     # is_persistent
-    grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
+    grid_size_tma = opt_flags_nvidia.compute_grid_size(routing_data, batch_size, m, n, block_m, block_n_tma)
     n_sms = torch.cuda.get_device_properties(0).multi_processor_count
-    tiles_per_sm = grid_size / n_sms
+    tiles_per_sm = grid_size_tma / n_sms
     supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
     if constraints.get("is_persistent", None) is not None:
         is_persistent = constraints["is_persistent"]
@@ -178,6 +180,10 @@ def make_default_opt_flags_nvidia(
     # TEMP CHANGE
     if precision_config.act_scale is not None or precision_config.out_scale is not None:
         is_persistent = False
+    # TMA is slower for batched matmuls with small m/n/k.
+    if m * n * k < 131072:
+        is_persistent = False
+    block_n = block_n_tma if is_persistent else block_n
     # block k
     if constraints.get("block_k", None) is not None:
         block_k = constraints["block_k"]
@@ -189,7 +195,7 @@ def make_default_opt_flags_nvidia(
     elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
         split_k = 1
     else:
-        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
+        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, batch_size, m, n, block_m, block_n)
         split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
     if split_k > 1:
         # With split_k, results are written in f32. Use that for the following computations.
@@ -224,7 +230,7 @@ def make_default_opt_flags_nvidia(
     else:
         fused_scatter = can_use_fused_scatter and split_k == 1
     # Handshake with the HBM swizzling
-    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
+    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, is_persistent, precision_config)
     ret = OptFlags(
         block_m=block_m,
         block_n=block_n,
@@ -275,6 +281,7 @@ def make_opt_flags(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -291,7 +298,7 @@ def make_opt_flags(
     if _opt_flags is not None:
         assert not _opt_flags_constraints
        return _opt_flags
-    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
+    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k,
            routing_data, can_use_persistent_tma, can_use_fused_scatter,
            enforce_bitwise_invariance, epilogue_effective_itemsize,
            _opt_flags_constraints]
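For illustration, here is a minimal standalone sketch of the persistence heuristic this diff adds: persistent TMA is skipped when the per-matmul problem is small, and `block_n` falls back to the non-TMA choice in that case. The variable names and the 131072 threshold mirror the diff; `choose_block_n_and_persistence` is a stand-in helper written for this note, not the real `opt_flags_nvidia` code, and the block-size values are placeholders.

```python
def choose_block_n_and_persistence(m, n, k, block_n, block_n_tma, supports_persistent):
    """Stand-in for the heuristic added above: disable persistent TMA for
    small batched matmuls and pick block_n accordingly."""
    is_persistent = supports_persistent
    # TMA is slower for batched matmuls with small m/n/k (threshold from the diff).
    if m * n * k < 131072:
        is_persistent = False
    # Use the TMA-friendly block_n only when the persistent TMA path is taken.
    block_n = block_n_tma if is_persistent else block_n
    return block_n, is_persistent

# A small per-matmul problem (64*64*16 = 65536 < 131072) stays non-persistent,
# while a larger one keeps persistence and the TMA block_n.
print(choose_block_n_and_persistence(64, 64, 16, 128, 256, True))    # -> (128, False)
print(choose_block_n_and_persistence(512, 512, 64, 128, 256, True))  # -> (256, True)
```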