@@ -690,6 +690,11 @@ def forward(
         conditioning_latents = getattr(batch, 'conditioning_latents', None)
         unconditioning_latents = conditioning_latents  # Same for cosmos
 
+        # Add sigma_conditioning logic like diffusers (lines 694-695)
+        # sigma_conditioning = 0.0001  # Default value from diffusers
+        # sigma_conditioning_tensor = torch.tensor(sigma_conditioning, dtype=torch.float32, device=latents.device)
+        # t_conditioning = sigma_conditioning_tensor / (sigma_conditioning_tensor + 1)
+
         # Sampling loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
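For reference, a runnable sketch of what the commented-out block above would compute once enabled. The 0.0001 default and the formula come straight from the comments in the diff; `latents` is a stand-in tensor, and the t = sigma / (sigma + 1) mapping presents conditioning frames to the model as carrying (almost) no noise:

import torch

latents = torch.randn(1, 16, 8, 4, 4)  # stand-in for the pipeline's latent tensor

sigma_conditioning = 0.0001  # default value from diffusers, per the comment above
sigma_conditioning_tensor = torch.tensor(sigma_conditioning,
                                         dtype=torch.float32,
                                         device=latents.device)
# Map the near-zero sigma into the model's timestep parameterization:
# t = sigma / (sigma + 1), so conditioning frames get a near-zero timestep.
t_conditioning = sigma_conditioning_tensor / (sigma_conditioning_tensor + 1)
print(t_conditioning.item())  # ~9.999e-05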
@@ -709,20 +714,39 @@ def forward(
                 logger.info(f"Step {i}: current_sigma={current_sigma:.6f}, current_t={current_t:.6f}")
                 logger.info(f"Step {i}: c_in={c_in:.6f}, c_skip={c_skip:.6f}, c_out={c_out:.6f}")
 
-                # Prepare timestep tensor
+                # Prepare timestep tensor like diffusers (lines 713-715)
                 timestep = current_t.view(1, 1, 1, 1, 1).expand(
                     latents.size(0), -1, latents.size(2), -1, -1
-                )
+                )  # [B, 1, T, 1, 1]
 
                 with torch.autocast(device_type="cuda",
                                     dtype=target_dtype,
                                     enabled=autocast_enabled):
 
-                    # Conditional forward pass
+                    # Conditional forward pass - match diffusers exactly (lines 717-721)
                     cond_latent = latents * c_in
-                    # Add conditional frame handling like diffusers
+                    print(f"[FASTVIDEO DEBUG] Step {i}: After latents * c_in, cond_latent sum = {cond_latent.float().sum().item()}")
+
+                    # CRITICAL: Apply conditioning frame injection like diffusers
+                    print(f"[FASTVIDEO DEBUG] Step {i}: Conditioning check - cond_indicator exists: {hasattr(batch, 'cond_indicator')}, is not None: {batch.cond_indicator is not None if hasattr(batch, 'cond_indicator') else 'N/A'}, conditioning_latents is not None: {conditioning_latents is not None}")
                     if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
+                        print(f"[FASTVIDEO DEBUG] Step {i}: Before conditioning - cond_latent sum: {cond_latent.float().sum().item()}, conditioning_latents sum: {conditioning_latents.float().sum().item()}")
                         cond_latent = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_latent
+                        print(f"[FASTVIDEO DEBUG] Step {i}: After conditioning - cond_latent sum: {cond_latent.float().sum().item()}")
+                        logger.info(f"Step {i}: Applied conditioning frame injection - cond_latent sum: {cond_latent.float().sum().item():.6f}")
+                    else:
+                        print(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING conditioning frame injection!")
+                        logger.warning(f"Step {i}: Missing conditioning data - cond_indicator: {hasattr(batch, 'cond_indicator')}, conditioning_latents: {conditioning_latents is not None}")
+
+                    # cond_latent = cond_latent.to(target_dtype)
+
+                    # # Apply conditional timestep processing like diffusers (lines 720-721)
+                    # cond_timestep = timestep
+                    # if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
+                    #     cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
+                    # cond_timestep = cond_timestep.to(target_dtype)
+                    # if i < 3:
+                    #     logger.info(f"Step {i}: Applied conditional timestep - t_conditioning: {t_conditioning:.6f}, cond_timestep sum: {cond_timestep.float().sum().item():.6f}")
 
                     with set_forward_context(
                         current_timestep=i,
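The injection applied above is a per-frame convex blend between clean conditioning latents and the noisy sample. A minimal self-contained sketch, with the shapes and the first-frame indicator chosen purely for illustration:

import torch

B, C, T, H, W = 1, 16, 8, 4, 4
cond_latent = torch.randn(B, C, T, H, W)           # noisy latents already scaled by c_in
conditioning_latents = torch.randn(B, C, T, H, W)  # clean latents encoding the condition frames
cond_indicator = torch.zeros(B, 1, T, 1, 1)        # broadcasts over channels and spatial dims
cond_indicator[:, :, 0] = 1.0                      # e.g. frame 0 is the conditioning frame

# Frames flagged by the indicator keep their clean latents; every other
# frame keeps the noisy sample currently being denoised.
cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent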
@@ -750,6 +774,7 @@ def forward(
                         logger.info(f"  condition_mask shape: {condition_mask.shape if condition_mask is not None else None}")
                         logger.info(f"  padding_mask shape: {padding_mask.shape}")
 
+                        print(f"[FASTVIDEO DENOISING] About to call transformer with hidden_states sum = {cond_latent.float().sum().item()}")
                         noise_pred = self.transformer(
                             hidden_states=cond_latent.to(target_dtype),
                             timestep=timestep.to(target_dtype),
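The c_in, c_skip, and c_out values logged at the top of the loop are the usual EDM preconditioning coefficients (Karras et al., 2022). A sketch of the standard formulas, with sigma_data = 0.5 as an assumed placeholder; the real values come from the pipeline's scheduler, not this code:

import torch

def edm_coefficients(sigma: torch.Tensor, sigma_data: float = 0.5):
    # Standard EDM preconditioning (Karras et al., 2022).
    c_in = 1.0 / torch.sqrt(sigma**2 + sigma_data**2)
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)
    return c_in, c_skip, c_out

# The denoised estimate then combines the skip path with the model output:
#   x0_pred = c_skip * latents + c_out * noise_pred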
@@ -775,6 +800,7 @@ def forward(
 
                     # Classifier-free guidance
                     if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
+                        # Unconditional pass - match diffusers logic (lines 755-759)
                         uncond_latent = latents * c_in
 
                         with set_forward_context(
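This hunk only sets up the unconditional pass; the guidance combination happens outside the diff. For reference, the standard classifier-free guidance blend, with guidance_scale assumed to come from the batch or sampling config:

import torch

def apply_cfg(noise_pred_cond: torch.Tensor,
              noise_pred_uncond: torch.Tensor,
              guidance_scale: float) -> torch.Tensor:
    # Push the conditional prediction away from the unconditional one;
    # guidance_scale == 1.0 reduces to the conditional prediction.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)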
@@ -790,6 +816,7 @@ def forward(
                             logger.info(f"Step {i}: Uncond transformer inputs:")
                             logger.info(f"  uncond_latent sum: {uncond_latent.float().sum().item():.6f}")
                             logger.info(f"  negative_prompt_embeds shape: {batch.negative_prompt_embeds[0].shape}")
+                            # sum: {uncond_timestep.float().sum().item():.6f}")
 
                             noise_pred_uncond = self.transformer(
                                 hidden_states=uncond_latent.to(target_dtype),
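All of the [FASTVIDEO DEBUG] prints in this commit fingerprint tensors by their float sum to compare a FastVideo run against diffusers step by step. A hypothetical helper that would keep those call sites uniform:

import torch

def debug_sum(tag: str, step: int, tensor: torch.Tensor) -> None:
    # Cheap parity fingerprint: two implementations fed identical inputs
    # should produce identical (or near-identical) sums at each checkpoint.
    print(f"[FASTVIDEO DEBUG] Step {step}: {tag} sum = {tensor.float().sum().item()}")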