@@ -21,6 +21,7 @@ class OptFlags:
2121 split_k : int
2222 fused_scatter : bool
2323 is_persistent : bool
24+ idle_sms : int
2425 epilogue_subtile : int | None
2526 arch : str
2627 target_kernel_kwargs : dict
@@ -116,6 +117,7 @@ def make_default_opt_flags_amd(
116117 split_k = split_k ,
117118 fused_scatter = constraints .get ('fused_scatter' , False ),
118119 is_persistent = is_persistent ,
120+ idle_sms = 0 ,
119121 epilogue_subtile = constraints .get ('epilogue_subtile' , None ),
120122 arch = None ,
121123 target_kernel_kwargs = target_kernel_kwargs ,
@@ -140,7 +142,7 @@ def make_default_opt_flags_nvidia(
140142 epilogue_effective_itemsize ,
141143 constraints ,
142144):
143- constraints_supported = ["block_m" , "block_k" , "split_k" , "fused_scatter" , "is_persistent" , "epilogue_subtile" , "num_stages" ]
145+ constraints_supported = ["block_m" , "block_k" , "split_k" , "fused_scatter" , "is_persistent" , "epilogue_subtile" , "num_stages" , "idle_sms" ]
144146 assert not any ([c not in constraints_supported for c in constraints ]), constraints .keys ()
145147 # tokens per expert
146148 if routing_data is None :
@@ -236,6 +238,7 @@ def make_default_opt_flags_nvidia(
236238 epilogue_subtile = epilogue_subtile ,
237239 arch = arch ,
238240 target_kernel_kwargs = dict (),
241+ idle_sms = constraints .get ("idle_sms" , 0 ),
239242 )
240243 # check constraints
241244 assert all (getattr (ret , ck ) == cv for ck , cv in constraints .items () if cv is not None ), f"{ ret } != { constraints } "
@@ -283,7 +286,8 @@ def make_opt_flags(
283286 return _opt_flags
284287 args = [out_dtype , lhs_dtype , rhs_dtype , precision_config , microscaling_ctx , m , n , k ,
285288 routing_data , can_use_persistent_tma , can_use_fused_scatter ,
286- enforce_bitwise_invariance , epilogue_effective_itemsize , _opt_flags_constraints ]
289+ enforce_bitwise_invariance , epilogue_effective_itemsize ,
290+ _opt_flags_constraints ]
287291 backend = triton .runtime .driver .active .get_current_target ().backend
288292 if backend == "hip" :
289293 return make_default_opt_flags_amd (* args )
0 commit comments