@@ -293,25 +293,6 @@ def num_timesteps(self):
293293 def interrupt (self ):
294294 return self ._interrupt
295295
296- @staticmethod
297- def value_from_time_aware_config (config , t ):
298- if isinstance (config , (float , int , str )):
299- return config
300- elif isinstance (config , torch .Tensor ):
301- assert config .numel () == 1
302- return config .item ()
303- elif isinstance (config , (tuple , list )):
304- assert isinstance (config [0 ], (float , int , str ))
305- result = config [0 ]
306- for thresh , val in config [1 :]:
307- if t >= thresh :
308- result = val
309- else :
310- break
311- return result
312- else :
313- raise ValueError (f"invalid time-aware config { config } of type { type (config )} " )
314-
315296 @torch .no_grad ()
316297 @replace_example_docstring (EXAMPLE_DOC_STRING )
317298 def __call__ (
@@ -507,9 +488,8 @@ def __call__(
507488 and self ._cfg_truncation is not None
508489 and float (self ._cfg_truncation ) <= 1
509490 ):
510- current_guidance_scale = self .value_from_time_aware_config (
511- (self .guidance_scale , (self ._cfg_truncation , 0.0 )), t_norm
512- )
491+ if t_norm > self ._cfg_truncation :
492+ current_guidance_scale = 0.0
513493
514494 # Run CFG only if configured AND scale is non-zero
515495 apply_cfg = self .do_classifier_free_guidance and current_guidance_scale > 0
0 commit comments