@@ -698,30 +698,17 @@ def __call__(
698698 # 6. Denoising loop
699699 with self .progress_bar (total = num_inference_steps ) as progress_bar :
700700 for i , t in enumerate (timesteps ):
701+ # compute whether to apply classifier-free truncation on this timestep
702+ do_classifier_free_truncation = (i + 1 ) / num_inference_steps > cfg_trunc_ratio
703+
701704 # expand the latents if we are doing classifier free guidance
702705 latent_model_input = (
703706 torch .cat ([latents ] * 2 )
704- if do_classifier_free_guidance
705- and 1 - t .item () / self .scheduler .config .num_train_timesteps < cfg_trunc_ratio
707+ if do_classifier_free_guidance and not do_classifier_free_truncation
706708 else latents
707709 )
710+
708711 current_timestep = t
709- if not torch .is_tensor (current_timestep ):
710- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
711- # This would be a good case for the `match` statement (Python 3.10+)
712- is_mps = latent_model_input .device .type == "mps"
713- if isinstance (current_timestep , float ):
714- dtype = torch .float32 if is_mps else torch .float64
715- else :
716- dtype = torch .int32 if is_mps else torch .int64
717- current_timestep = torch .tensor (
718- [current_timestep ],
719- dtype = dtype ,
720- device = latent_model_input .device ,
721- )
722- elif len (current_timestep .shape ) == 0 :
723- current_timestep = current_timestep [None ].to (latent_model_input .device )
724-
725712 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
726713 current_timestep = current_timestep .expand (latent_model_input .shape [0 ])
727714 # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
@@ -736,7 +723,7 @@ def __call__(
736723 )[0 ]
737724
738725 # perform normalization-based guidance scale on a truncated timestep interval
739- if do_classifier_free_guidance and current_timestep [ 0 ] < cfg_trunc_ratio :
726+ if self . do_classifier_free_guidance and not do_classifier_free_truncation :
740727 noise_pred_cond , noise_pred_uncond = torch .split (noise_pred , len (noise_pred ) // 2 , dim = 0 )
741728 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond )
742729 # apply normalization after classifier-free guidance
0 commit comments