@@ -244,6 +244,11 @@ def has_latent_keyframes(self):
244244 def has_mask_hint (self ):
245245 return self .mask_hint_orig is not None
246246
247+ def get_effective_guarantee_steps (self , max_sigma : torch .Tensor ):
248+ '''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
249+ if self .start_t > max_sigma :
250+ return 0
251+ return self .guarantee_steps
247252
248253 @staticmethod
249254 def default () -> 'TimestepKeyframe' :
@@ -549,16 +554,17 @@ def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
549554 self .weights = None
550555 self .latent_keyframes = None
551556
552- def prepare_current_timestep (self , t : Tensor , batched_number : int = 1 ):
557+ def prepare_current_timestep (self , t : Tensor , transformer_options : dict [ str , torch . Tensor ] ):
553558 self .t = float (t [0 ])
554559 # check if t has changed (otherwise do nothing, as step already accounted for)
555560 if self .t == self .prev_t :
556561 return
557562 # get current step percent
558563 curr_t : float = self .t
559564 prev_index = self ._current_timestep_index
565+ max_sigma = torch .max (transformer_options .get ("sigmas" , BIGMAX ))
560566 # if met guaranteed steps (or no current keyframe), look for next keyframe in case need to switch
561- if self ._current_timestep_keyframe is None or self ._current_used_steps >= self ._current_timestep_keyframe .guarantee_steps :
567+ if self ._current_timestep_keyframe is None or self ._current_used_steps >= self ._current_timestep_keyframe .get_effective_guarantee_steps ( max_sigma ) :
562568 # if has next index, loop through and see if need to switch
563569 if self .timestep_keyframes .has_index (self ._current_timestep_index + 1 ):
564570 for i in range (self ._current_timestep_index + 1 , len (self .timestep_keyframes )):
@@ -584,7 +590,7 @@ def prepare_current_timestep(self, t: Tensor, batched_number: int=1):
584590 del self .tk_mask_cond_hint_original
585591 self .tk_mask_cond_hint_original = None
586592 # if guarantee_steps greater than zero, stop searching for other keyframes
587- if self ._current_timestep_keyframe .guarantee_steps > 0 :
593+ if self ._current_timestep_keyframe .get_effective_guarantee_steps ( max_sigma ) > 0 :
588594 break
589595 # if eval_tk is outside of percent range, stop looking further
590596 else :
@@ -673,7 +679,7 @@ def get_control_inject(self, x_noisy, t, cond, batched_number, transformer_optio
673679 self .batch_size = len (t )
674680 self .cond_or_uncond = transformer_options .get ("cond_or_uncond" , None )
675681 # prepare timestep and everything related
676- self .prepare_current_timestep (t = t , batched_number = batched_number )
682+ self .prepare_current_timestep (t = t , transformer_options = transformer_options )
677683 # if should not perform any actions for the controlnet, exit without doing any work
678684 if self .strength == 0.0 or self ._current_timestep_keyframe .strength == 0.0 :
679685 return self .default_control_actions (x_noisy , t , cond , batched_number , transformer_options )
0 commit comments