@@ -21,6 +21,7 @@ class OptFlags:
21
21
split_k : int
22
22
fused_scatter : bool
23
23
is_persistent : bool
24
+ idle_sms : int
24
25
epilogue_subtile : int | None
25
26
arch : str
26
27
target_kernel_kwargs : dict
@@ -116,6 +117,7 @@ def make_default_opt_flags_amd(
116
117
split_k = split_k ,
117
118
fused_scatter = constraints .get ('fused_scatter' , False ),
118
119
is_persistent = is_persistent ,
120
+ idle_sms = 0 ,
119
121
epilogue_subtile = constraints .get ('epilogue_subtile' , None ),
120
122
arch = None ,
121
123
target_kernel_kwargs = target_kernel_kwargs ,
@@ -140,7 +142,7 @@ def make_default_opt_flags_nvidia(
140
142
epilogue_effective_itemsize ,
141
143
constraints ,
142
144
):
143
- constraints_supported = ["block_m" , "block_k" , "split_k" , "fused_scatter" , "is_persistent" , "epilogue_subtile" , "num_stages" ]
145
+ constraints_supported = ["block_m" , "block_k" , "split_k" , "fused_scatter" , "is_persistent" , "epilogue_subtile" , "num_stages" , "idle_sms" ]
144
146
assert not any ([c not in constraints_supported for c in constraints ]), constraints .keys ()
145
147
# tokens per expert
146
148
if routing_data is None :
@@ -236,6 +238,7 @@ def make_default_opt_flags_nvidia(
236
238
epilogue_subtile = epilogue_subtile ,
237
239
arch = arch ,
238
240
target_kernel_kwargs = dict (),
241
+ idle_sms = constraints .get ("idle_sms" , 0 ),
239
242
)
240
243
# check constraints
241
244
assert all (getattr (ret , ck ) == cv for ck , cv in constraints .items () if cv is not None ), f"{ ret } != { constraints } "
@@ -283,7 +286,8 @@ def make_opt_flags(
283
286
return _opt_flags
284
287
args = [out_dtype , lhs_dtype , rhs_dtype , precision_config , microscaling_ctx , m , n , k ,
285
288
routing_data , can_use_persistent_tma , can_use_fused_scatter ,
286
- enforce_bitwise_invariance , epilogue_effective_itemsize , _opt_flags_constraints ]
289
+ enforce_bitwise_invariance , epilogue_effective_itemsize ,
290
+ _opt_flags_constraints ]
287
291
backend = triton .runtime .driver .active .get_current_target ().backend
288
292
if backend == "hip" :
289
293
return make_default_opt_flags_amd (* args )
0 commit comments