@@ -750,15 +750,20 @@ def forward(
                 print(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING conditioning frame injection!")
                 logger.warning(f"Step {i}: Missing conditioning data - cond_indicator: {hasattr(batch, 'cond_indicator')}, conditioning_latents: {conditioning_latents is not None}")
 
-            # cond_latent = cond_latent.to(target_dtype)
+            # Convert cond_latent to target dtype BEFORE debug logging to match Diffusers
+            cond_latent = cond_latent.to(target_dtype)
 
-            # # Apply conditional timestep processing like diffusers (lines 720-721)
-            # cond_timestep = timestep
-            # if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
-            # cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
-            # cond_timestep = cond_timestep.to(target_dtype)
-            # if i < 3:
-            # logger.info(f"Step {i}: Applied conditional timestep - t_conditioning: {t_conditioning:.6f}, cond_timestep sum: {cond_timestep.float().sum().item():.6f}")
+            # Apply conditional timestep processing like Diffusers (lines 792-793)
+            cond_timestep = timestep
+            if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
+                # Exactly match Diffusers: cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
+                # First get t_conditioning (the sigma_conditioning value from Diffusers)
+                sigma_conditioning = 0.0001  # Same as Diffusers default
+                t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
+                cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
+                cond_timestep = cond_timestep.to(target_dtype)
+                if i < 3:
+                    logger.info(f"Step {i}: Applied conditional timestep - t_conditioning: {t_conditioning:.6f}, cond_timestep sum: {cond_timestep.float().sum().item():.6f}")
 
             with set_forward_context(
                 current_timestep=i,
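For reference, the blending this hunk introduces can be exercised in isolation. The helper below is a minimal sketch, not FastVideo API: the function name, tensor shapes, and usage example are illustrative, and only the formula and the 0.0001 default come from the diff. Conditioning frames are pinned to the near-zero noise level t_conditioning = sigma_conditioning / (sigma_conditioning + 1), while generated frames keep the scheduler timestep.

import torch

def blend_conditioning_timestep(timestep: torch.Tensor,
                                cond_indicator: torch.Tensor,
                                sigma_conditioning: float = 0.0001) -> torch.Tensor:
    # cond_indicator is 1 for provided conditioning frames and 0 for frames being
    # generated; it broadcasts against timestep.
    t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
    return cond_indicator * t_conditioning + (1 - cond_indicator) * timestep

# Tiny usage example: a 2-frame latent where frame 0 is a conditioning frame.
ts = blend_conditioning_timestep(
    torch.full((1, 1, 2, 1, 1), 0.7),
    torch.tensor([1.0, 0.0]).view(1, 1, 2, 1, 1))
print(ts.flatten())  # ~[1e-4, 0.7]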
@@ -767,7 +772,8 @@ def forward(
             ):
                 # Use conditioning masks from CosmosLatentPreparationStage
                 condition_mask = batch.cond_mask.to(target_dtype) if hasattr(batch, 'cond_mask') else None
-                padding_mask = torch.zeros(1, 1, cond_latent.shape[3], cond_latent.shape[4],
+                # Padding mask should match original image dimensions like Diffusers (704, 1280)
+                padding_mask = torch.zeros(1, 1, batch.height, batch.width,
                                            device=cond_latent.device, dtype=target_dtype)
 
                 # Fallback if masks not available
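The padding-mask change above swaps latent spatial dimensions for the original image dimensions. A quick standalone illustration of the size difference, assuming the 704x1280 example from the comment and an 8x VAE spatial downsample factor (the factor is an assumption, not taken from the diff):

import torch

# 704x1280 comes from the comment in the hunk above; the 8x VAE spatial
# downsample factor is an assumption used only to show the size difference.
height, width = 704, 1280
vae_spatial_factor = 8
latent_mask = torch.zeros(1, 1, height // vae_spatial_factor, width // vae_spatial_factor)  # old behaviour
pixel_mask = torch.zeros(1, 1, height, width)                                               # new behaviour
print(latent_mask.shape, pixel_mask.shape)  # [1, 1, 88, 160] vs [1, 1, 704, 1280]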
@@ -786,10 +792,34 @@ def forward(
                     logger.info(f" condition_mask shape: {condition_mask.shape if condition_mask is not None else None}")
                     logger.info(f" padding_mask shape: {padding_mask.shape}")
 
+                # Log detailed transformer inputs for comparison with Diffusers
+                if i < 3:
+                    print(f"FASTVIDEO TRANSFORMER INPUTS (step {i}):")
+                    print(f" hidden_states: shape={cond_latent.shape}, sum={cond_latent.float().sum().item():.6f}, mean={cond_latent.float().mean().item():.6f}")
+                    print(f" timestep: shape={cond_timestep.shape}, sum={cond_timestep.float().sum().item():.6f}, values={cond_timestep.flatten()[:5].float()}")
+                    print(f" encoder_hidden_states: shape={batch.prompt_embeds[0].shape}, sum={batch.prompt_embeds[0].float().sum().item():.6f}")
+                    print(f" condition_mask: shape={condition_mask.shape if condition_mask is not None else None}, sum={condition_mask.float().sum().item() if condition_mask is not None else None}")
+                    print(f" padding_mask: shape={padding_mask.shape}, sum={padding_mask.float().sum().item():.6f}")
+                    print(f" fps: {24}, target_dtype: {target_dtype}")
+                    print(f" DTYPES: hidden_states={cond_latent.dtype}, timestep={cond_timestep.dtype}, encoder_hidden_states={batch.prompt_embeds[0].dtype}")
+                    print(f" hidden_states first 5 values: {cond_latent.flatten()[:5].float()}")
+                    print(f" encoder_hidden_states first 5 values: {batch.prompt_embeds[0].flatten()[:5].float()}")
+                    with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                        f.write(f"FASTVIDEO TRANSFORMER INPUTS (step {i}):\n")
+                        f.write(f" hidden_states: shape={cond_latent.shape}, sum={cond_latent.float().sum().item():.6f}, mean={cond_latent.float().mean().item():.6f}\n")
+                        f.write(f" timestep: shape={cond_timestep.shape}, sum={cond_timestep.float().sum().item():.6f}, values={cond_timestep.flatten()[:5].float()}\n")
+                        f.write(f" encoder_hidden_states: shape={batch.prompt_embeds[0].shape}, sum={batch.prompt_embeds[0].float().sum().item():.6f}\n")
+                        f.write(f" condition_mask: shape={condition_mask.shape if condition_mask is not None else None}, sum={condition_mask.float().sum().item() if condition_mask is not None else None}\n")
+                        f.write(f" padding_mask: shape={padding_mask.shape}, sum={padding_mask.float().sum().item():.6f}\n")
+                        f.write(f" fps: {24}, target_dtype: {target_dtype}\n")
+                        f.write(f" DTYPES: hidden_states={cond_latent.dtype}, timestep={cond_timestep.dtype}, encoder_hidden_states={batch.prompt_embeds[0].dtype}\n")
+                        f.write(f" hidden_states first 5 values: {cond_latent.flatten()[:5].float()}\n")
+                        f.write(f" encoder_hidden_states first 5 values: {batch.prompt_embeds[0].flatten()[:5].float()}\n")
+
                 print(f"[FASTVIDEO DENOISING] About to call transformer with hidden_states sum = {cond_latent.float().sum().item()}")
                 noise_pred = self.transformer(
-                    hidden_states=cond_latent.to(target_dtype),
-                    timestep=timestep.to(target_dtype),
+                    hidden_states=cond_latent,  # Already converted to target_dtype above
+                    timestep=cond_timestep.to(target_dtype),
                     encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
                     fps=24,  # TODO: get fps from batch or config
                     condition_mask=condition_mask,
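The logging block added in this hunk dumps shape/sum/mean/first-values summaries of every transformer input to stdout and a log file so they can be compared line by line against a Diffusers run. A small helper in the same spirit, illustrative only and not part of the patch (function name, shapes, and the example call are assumptions):

import torch

def tensor_summary(name: str, t: torch.Tensor, n: int = 5) -> str:
    # One line per tensor: shape, sum, mean and the first n values, so two runs
    # (e.g. FastVideo vs Diffusers) can be compared with a plain text diff.
    flat = t.detach().float().flatten()
    return (f"{name}: shape={tuple(t.shape)}, sum={flat.sum().item():.6f}, "
            f"mean={flat.mean().item():.6f}, first={flat[:n].tolist()}")

print(tensor_summary("hidden_states", torch.randn(1, 16, 8, 88, 160)))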
@@ -805,11 +835,7 @@ def forward(
             # Apply preconditioning exactly like diffusers
             cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
 
-            # Apply conditional indicator masking (from CosmosLatentPreparationStage)
-            if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
-                conditioning_latents = batch.conditioning_latents if batch.conditioning_latents is not None else torch.zeros_like(latents)
-                cond_pred = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_pred
-
+            # NOTE: Conditioning frame injection is applied to cond_latent BEFORE the transformer call (line 746), not after
             # Classifier-free guidance
             if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
                 # Unconditional pass - match diffusers logic (lines 755-759)
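The deleted block re-applied conditioning-frame masking to cond_pred after the transformer; the NOTE records that the same blend now happens on cond_latent before the call. As a standalone sketch of the blend that was removed (the function name is illustrative; the formula is taken from the deleted lines):

from typing import Optional

import torch

def inject_conditioning(pred: torch.Tensor,
                        cond_indicator: torch.Tensor,
                        conditioning_latents: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Overwrite conditioning-frame positions with the provided clean latents,
    # leaving generated positions untouched.
    if conditioning_latents is None:
        conditioning_latents = torch.zeros_like(pred)
    return cond_indicator * conditioning_latents + (1 - cond_indicator) * pred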
@@ -830,9 +856,17 @@ def forward(
                     logger.info(f" negative_prompt_embeds shape: {batch.negative_prompt_embeds[0].shape}")
                     # sum: {uncond_timestep.float().sum().item():.6f}")
 
+                # Apply same conditional timestep processing for unconditional pass
+                uncond_timestep = timestep
+                if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
+                    sigma_conditioning = 0.0001  # Same as Diffusers default
+                    t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
+                    uncond_timestep = batch.uncond_indicator * t_conditioning + (1 - batch.uncond_indicator) * timestep
+                    uncond_timestep = uncond_timestep.to(target_dtype)
+
                 noise_pred_uncond = self.transformer(
                     hidden_states=uncond_latent.to(target_dtype),
-                    timestep=timestep.to(target_dtype),
+                    timestep=uncond_timestep.to(target_dtype),
                     encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
                     fps=24,  # TODO: get fps from batch or config
                     condition_mask=uncond_condition_mask,
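After the unconditional pass shown above, the conditional and unconditional predictions are typically combined with standard classifier-free guidance. The combine step itself is outside this hunk, so the sketch below is only the textbook formula under that assumption (names are illustrative, not FastVideo variables):

import torch

def classifier_free_guidance(noise_pred_cond: torch.Tensor,
                             noise_pred_uncond: torch.Tensor,
                             guidance_scale: float) -> torch.Tensor:
    # Standard CFG: move the unconditional prediction toward the conditional one.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)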