Skip to content

Commit 43437ef

Browse files
committed
Update denoising
1 parent baeea19 commit 43437ef

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

fastvideo/pipelines/stages/denoising.py

Lines changed: 26 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -600,7 +600,7 @@ def verify_output(self, batch: ForwardBatch,
600600
return result
601601

602602

603-
class CosmosDenoisingStage(PipelineStage):
603+
class CosmosDenoisingStage(DenoisingStage):
604604
"""
605605
Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
606606
@@ -613,10 +613,8 @@ def __init__(self,
613613
transformer,
614614
scheduler,
615615
pipeline=None) -> None:
616-
super().__init__()
617-
self.transformer = transformer
618-
self.scheduler = scheduler # FlowMatchEulerDiscreteScheduler
619-
self.pipeline = weakref.ref(pipeline) if pipeline else None
616+
super().__init__(transformer, scheduler, pipeline)
617+
# FlowMatchEulerDiscreteScheduler is already set by parent
620618

621619
def forward(
622620
self,
@@ -659,7 +657,7 @@ def forward(
659657
self.scheduler.set_timesteps(device=latents.device, sigmas=sigmas)
660658

661659
# Initialize with maximum noise
662-
latents = torch.randn_like(latents, dtype=torch.float32) * self.scheduler.config.sigma_max
660+
latents = torch.randn_like(latents, dtype=torch.float32) * self.scheduler.sigma_max
663661

664662
# Prepare conditional frame handling (if needed)
665663
# This would be implemented based on batch.conditioning_latents or similar
@@ -692,12 +690,17 @@ def forward(
692690
# Add conditional frame handling here if needed:
693691
# cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
694692

695-
cond_velocity = self.transformer(
696-
hidden_states=cond_latent.to(target_dtype),
697-
timestep=timestep.to(target_dtype),
698-
encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
699-
return_dict=False,
700-
)[0]
693+
with set_forward_context(
694+
current_timestep=i,
695+
attn_metadata=None, # TODO: implement attention metadata if needed
696+
forward_batch=batch,
697+
):
698+
cond_velocity = self.transformer(
699+
hidden_states=cond_latent.to(target_dtype),
700+
timestep=timestep.to(target_dtype),
701+
encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
702+
return_dict=False,
703+
)[0]
701704

702705
# Apply preconditioning
703706
cond_pred = (c_skip * latents + c_out * cond_velocity.float()).to(target_dtype)
@@ -706,12 +709,17 @@ def forward(
706709
if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
707710
uncond_latent = latents * c_in
708711

709-
uncond_velocity = self.transformer(
710-
hidden_states=uncond_latent.to(target_dtype),
711-
timestep=timestep.to(target_dtype),
712-
encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
713-
return_dict=False,
714-
)[0]
712+
with set_forward_context(
713+
current_timestep=i,
714+
attn_metadata=None, # TODO: implement attention metadata if needed
715+
forward_batch=batch,
716+
):
717+
uncond_velocity = self.transformer(
718+
hidden_states=uncond_latent.to(target_dtype),
719+
timestep=timestep.to(target_dtype),
720+
encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
721+
return_dict=False,
722+
)[0]
715723

716724
uncond_pred = (c_skip * latents + c_out * uncond_velocity.float()).to(target_dtype)
717725

0 commit comments

Comments (0)