Commit baeea19

Init pipeline
1 parent 37204c4 commit baeea19

7 files changed: +233 additions, -9 deletions

fastvideo/configs/models/vaes/cosmosvae.py

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,14 @@ class CosmosVAEArchConfig(VAEArchConfig):
     attn_scales: tuple[float, ...] = ()
     temperal_downsample: tuple[bool, ...] = (False, True, True)
     dropout: float = 0.0
+    decoder_base_dim: int | None = None
+    is_residual: bool = False
+    in_channels: int = 3
+    out_channels: int = 3
+    patch_size: int | None = None
+    scale_factor_temporal: int = 4
+    scale_factor_spatial: int = 8
+    clip_output: bool = True
     latents_mean: tuple[float, ...] = (
         -0.7571,
         -0.7089,
@@ -62,6 +70,8 @@ def __post_init__(self):
             self.latents_std).view(1, self.z_dim, 1, 1, 1)
         self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view(
             1, self.z_dim, 1, 1, 1)
+        self.temporal_compression_ratio = self.scale_factor_temporal
+        self.spatial_compression_ratio = self.scale_factor_spatial


 @dataclass
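
The two attributes set at the end of __post_init__ expose the VAE's compression factors to the rest of the pipeline. As a rough sketch of the shape math they imply (the "(frames - 1) // ratio + 1" form is the usual causal video-VAE convention and an assumption here, not something this diff shows):

# Sketch only: illustrates how scale_factor_temporal / scale_factor_spatial
# translate into latent dimensions. The causal-VAE frame formula below is an
# assumption; the real mapping lives in the VAE implementation.
def latent_shape(num_frames: int, height: int, width: int,
                 scale_factor_temporal: int = 4,
                 scale_factor_spatial: int = 8) -> tuple[int, int, int]:
    latent_frames = (num_frames - 1) // scale_factor_temporal + 1
    return (latent_frames,
            height // scale_factor_spatial,
            width // scale_factor_spatial)

print(latent_shape(93, 704, 1280))  # -> (24, 88, 160)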

fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 20 additions & 3 deletions
@@ -88,6 +88,14 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin,
             The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
         stochastic_sampling (`bool`, defaults to False):
             Whether to use stochastic sampling.
+        final_sigmas_type (`str`, defaults to "sigma_min"):
+            The type of final sigmas to use. Either "sigma_min" or "zero".
+        sigma_max (`float`, *optional*):
+            The maximum sigma value for the noise schedule.
+        sigma_min (`float`, *optional*):
+            The minimum sigma value for the noise schedule.
+        sigma_data (`float`, *optional*):
+            The sigma data value for scaling.
     """

     _compatibles: list[Any] = []
@@ -110,6 +118,10 @@ def __init__(
         use_beta_sigmas: bool | None = False,
         time_shift_type: str = "exponential",
         stochastic_sampling: bool = False,
+        final_sigmas_type: str = "sigma_min",
+        sigma_max: float | None = None,
+        sigma_min: float | None = None,
+        sigma_data: float | None = None,
     ):
         if sum([
                 self.config.use_beta_sigmas, self.config.use_exponential_sigmas,
@@ -403,9 +415,14 @@ def set_timesteps(
                 [sigmas_tensor,
                  torch.ones(1, device=sigmas_tensor.device)])
         else:
-            sigmas_tensor = torch.cat(
-                [sigmas_tensor,
-                 torch.zeros(1, device=sigmas_tensor.device)])
+            # Handle final_sigmas_type parameter
+            if self.config.final_sigmas_type == "sigma_min":
+                # Use sigma_min instead of zero for final sigma
+                final_sigma = torch.tensor([self.sigma_min], device=sigmas_tensor.device)
+            else:  # "zero" or default
+                final_sigma = torch.zeros(1, device=sigmas_tensor.device)
+
+            sigmas_tensor = torch.cat([sigmas_tensor, final_sigma])

         self.timesteps = timesteps_tensor
         self.sigmas = sigmas_tensor
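
A minimal sketch of constructing the scheduler with the new options. The constructor arguments come from the diff above and the set_timesteps(device=..., sigmas=...) call mirrors how CosmosDenoisingStage invokes it below; the concrete sigma values and step count are illustrative, and whether the "sigma_min" branch is taken depends on the scheduler's other config flags:

import torch
from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler)

# With final_sigmas_type="sigma_min" the appended final sigma is sigma_min
# rather than zero, so the last step does not land exactly on zero noise.
scheduler = FlowMatchEulerDiscreteScheduler(final_sigmas_type="sigma_min",
                                            sigma_min=0.002,
                                            sigma_max=80.0)
sigmas = torch.linspace(0, 1, 35, dtype=torch.float64)
scheduler.set_timesteps(device="cpu", sigmas=sigmas)
print(scheduler.sigmas[-1])  # expected to be sigma_min on the non-inverted path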

fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py

Lines changed: 38 additions & 5 deletions
@@ -14,26 +14,59 @@
 from fastvideo.logger import init_logger
 from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase
 from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage,
-                                        DenoisingStage, InputValidationStage,
+                                        CosmosDenoisingStage, InputValidationStage,
                                         LatentPreparationStage,
                                         TextEncodingStage,
                                         TimestepPreparationStage)
 from fastvideo.pipelines.stages.base import PipelineStage
+from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
+    FlowMatchEulerDiscreteScheduler)

 logger = init_logger(__name__)


 class Cosmos2VideoToWorldPipeline(ComposedPipelineBase):

     _required_config_modules = [
-        "text_encoder", "tokenizer", "vae", "transformer", "scheduler"
+        "text_encoder", "tokenizer", "vae", "transformer", "scheduler", "safety_checker"
     ]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+
+    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
+
+        self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
+            shift=fastvideo_args.pipeline_config.flow_shift)

     def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
         """Set up pipeline stages with proper dependency injection."""
+
+        self.add_stage(stage_name="input_validation_stage",
+                       stage=InputValidationStage())
+
+        self.add_stage(stage_name="prompt_encoding_stage",
+                       stage=TextEncodingStage(
+                           text_encoders=[self.get_module("text_encoder")],
+                           tokenizers=[self.get_module("tokenizer")],
+                       ))
+
+        self.add_stage(stage_name="conditioning_stage",
+                       stage=ConditioningStage())
+
+        self.add_stage(stage_name="timestep_preparation_stage",
+                       stage=TimestepPreparationStage(
+                           scheduler=self.get_module("scheduler")))
+
+        self.add_stage(stage_name="latent_preparation_stage",
+                       stage=LatentPreparationStage(
+                           scheduler=self.get_module("scheduler"),
+                           transformer=self.get_module("transformer", None)))
+
+        self.add_stage(stage_name="denoising_stage",
+                       stage=CosmosDenoisingStage(
+                           transformer=self.get_module("transformer"),
+                           scheduler=self.get_module("scheduler")))
+
+        self.add_stage(stage_name="decoding_stage",
+                       stage=DecodingStage(vae=self.get_module("vae")))


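
For orientation, the add_stage calls above register the stages in the order the pipeline will run them. The sketch below only illustrates that flow; the actual execution loop lives in ComposedPipelineBase and is not part of this diff, so the stage accessor and forward call shown here are assumptions:

# Hypothetical illustration of the stage order wired up in
# create_pipeline_stages(); the names match the add_stage() calls above.
STAGE_ORDER = [
    "input_validation_stage",
    "prompt_encoding_stage",
    "conditioning_stage",
    "timestep_preparation_stage",
    "latent_preparation_stage",
    "denoising_stage",
    "decoding_stage",
]

def run_stages(pipeline, batch, fastvideo_args):
    # Each stage consumes the ForwardBatch produced by the previous one.
    for name in STAGE_ORDER:
        stage = pipeline.stages[name]  # assumed accessor, for illustration only
        batch = stage.forward(batch, fastvideo_args)
    return batch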

fastvideo/pipelines/composed_pipeline_base.py

Lines changed: 1 addition & 0 deletions
@@ -282,6 +282,7 @@ def load_modules(
         for module_name, (transformers_or_diffusers,
                           architecture) in model_index.items():
             if transformers_or_diffusers is None:
+                print("REQUIRED", self.required_config_modules, module_name)
                 self.required_config_modules.remove(module_name)
                 continue
             if module_name not in required_modules:

fastvideo/pipelines/pipeline_registry.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
     "WanImageToVideoPipeline": "wan",
     "StepVideoPipeline": "stepvideo",
     "HunyuanVideoPipeline": "hunyuan",
+    "Cosmos2VideoToWorldPipeline": "cosmos"
 }

 _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME: dict[WorkloadType, str] = {

fastvideo/pipelines/stages/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -10,7 +10,8 @@
 from fastvideo.pipelines.stages.conditioning import ConditioningStage
 from fastvideo.pipelines.stages.decoding import DecodingStage
 from fastvideo.pipelines.stages.denoising import (DenoisingStage,
-                                                  DmdDenoisingStage)
+                                                  DmdDenoisingStage,
+                                                  CosmosDenoisingStage)
 from fastvideo.pipelines.stages.encoding import EncodingStage
 from fastvideo.pipelines.stages.image_encoding import (ImageEncodingStage,
                                                        ImageVAEEncodingStage)
@@ -30,6 +31,7 @@
     "ConditioningStage",
     "DenoisingStage",
     "DmdDenoisingStage",
+    "CosmosDenoisingStage",
     "EncodingStage",
     "DecodingStage",
     "ImageEncodingStage",

fastvideo/pipelines/stages/denoising.py

Lines changed: 160 additions & 0 deletions
@@ -600,6 +600,166 @@ def verify_output(self, batch: ForwardBatch,
         return result


+class CosmosDenoisingStage(PipelineStage):
+    """
+    Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
+
+    This stage implements the diffusers-compatible Cosmos denoising process with
+    velocity prediction, classifier-free guidance, and conditional video
+    generation support. Compatible with Hugging Face Cosmos models.
+    """
+
+    def __init__(self,
+                 transformer,
+                 scheduler,
+                 pipeline=None) -> None:
+        super().__init__()
+        self.transformer = transformer
+        self.scheduler = scheduler  # FlowMatchEulerDiscreteScheduler
+        self.pipeline = weakref.ref(pipeline) if pipeline else None
+
+    def forward(
+        self,
+        batch: ForwardBatch,
+        fastvideo_args: FastVideoArgs,
+    ) -> ForwardBatch:
+        """
+        Run the diffusers-style Cosmos denoising loop.
+
+        Args:
+            batch: The current batch information.
+            fastvideo_args: The inference arguments.
+
+        Returns:
+            The batch with denoised latents.
+        """
+        pipeline = self.pipeline() if self.pipeline else None
+        if not fastvideo_args.model_loaded["transformer"]:
+            loader = TransformerLoader()
+            self.transformer = loader.load(
+                fastvideo_args.model_paths["transformer"], fastvideo_args)
+            if pipeline:
+                pipeline.add_module("transformer", self.transformer)
+            fastvideo_args.model_loaded["transformer"] = True
+
+        # Setup precision and autocast settings
+        target_dtype = torch.bfloat16
+        autocast_enabled = (target_dtype != torch.float32
+                            ) and not fastvideo_args.disable_autocast
+
+        # Get latents and setup
+        latents = batch.latents
+        num_inference_steps = batch.num_inference_steps
+        guidance_scale = batch.guidance_scale
+
+        # Setup scheduler with sigma schedule
+        sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
+        timesteps = torch.arange(num_inference_steps, device=latents.device, dtype=torch.long)
+        self.scheduler.set_timesteps(device=latents.device, sigmas=sigmas)
+
+        # Initialize with maximum noise
+        latents = torch.randn_like(latents, dtype=torch.float32) * self.scheduler.config.sigma_max
+
+        # Prepare conditional frame handling (if needed)
+        # This would be implemented based on batch.conditioning_latents or similar
+
+        # Sampling loop
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # Skip if interrupted
+                if hasattr(self, 'interrupt') and self.interrupt:
+                    continue
+
+                # Get current sigma and preconditioning coefficients
+                current_sigma = self.scheduler.sigmas[i]
+                current_t = current_sigma / (current_sigma + 1)
+                c_in = 1 - current_t
+                c_skip = 1 - current_t
+                c_out = -current_t
+
+                # Prepare timestep tensor
+                timestep = current_t.view(1, 1, 1, 1, 1).expand(
+                    latents.size(0), -1, latents.size(2), -1, -1)
+
+                with torch.autocast(device_type="cuda",
+                                    dtype=target_dtype,
+                                    enabled=autocast_enabled):
+
+                    # Conditional forward pass
+                    cond_latent = latents * c_in
+                    # Add conditional frame handling here if needed:
+                    # cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
+
+                    cond_velocity = self.transformer(
+                        hidden_states=cond_latent.to(target_dtype),
+                        timestep=timestep.to(target_dtype),
+                        encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
+                        return_dict=False,
+                    )[0]
+
+                    # Apply preconditioning
+                    cond_pred = (c_skip * latents + c_out * cond_velocity.float()).to(target_dtype)
+
+                    # Classifier-free guidance
+                    if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
+                        uncond_latent = latents * c_in
+
+                        uncond_velocity = self.transformer(
+                            hidden_states=uncond_latent.to(target_dtype),
+                            timestep=timestep.to(target_dtype),
+                            encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
+                            return_dict=False,
+                        )[0]
+
+                        uncond_pred = (c_skip * latents + c_out * uncond_velocity.float()).to(target_dtype)
+
+                        # Apply guidance
+                        velocity_pred = cond_pred + guidance_scale * (cond_pred - uncond_pred)
+                    else:
+                        velocity_pred = cond_pred
+
+                # Convert velocity to noise for scheduler
+                noise_pred = (latents - velocity_pred) / current_sigma
+
+                # Standard scheduler step
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                progress_bar.update()
+
+        # Update batch with final latents
+        batch.latents = latents
+
+        return batch
+
+    def verify_input(self, batch: ForwardBatch,
+                     fastvideo_args: FastVideoArgs) -> VerificationResult:
+        """Verify Cosmos denoising stage inputs."""
+        result = VerificationResult()
+        result.add_check("latents", batch.latents,
+                         [V.is_tensor, V.with_dims(5)])
+        result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
+        result.add_check("num_inference_steps", batch.num_inference_steps,
+                         V.positive_int)
+        result.add_check("guidance_scale", batch.guidance_scale,
+                         V.positive_float)
+        result.add_check("do_classifier_free_guidance",
+                         batch.do_classifier_free_guidance, V.bool_value)
+        result.add_check(
+            "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
+            not batch.do_classifier_free_guidance or V.list_not_empty(x))
+        return result
+
+    def verify_output(self, batch: ForwardBatch,
+                      fastvideo_args: FastVideoArgs) -> VerificationResult:
+        """Verify Cosmos denoising stage outputs."""
+        result = VerificationResult()
+        result.add_check("latents", batch.latents,
+                         [V.is_tensor, V.with_dims(5)])
+        return result
+
+
 class DmdDenoisingStage(DenoisingStage):
     """
     Denoising stage for DMD.
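
The preconditioning coefficients in the loop above come from rewriting the sigma-based parameterization in terms of t = sigma / (sigma + 1). A small self-contained sketch of that arithmetic, mirroring the expressions in CosmosDenoisingStage (the sigma and tensor values are illustrative, not taken from the repository):

import torch

# Illustrative sigma; in the stage it is scheduler.sigmas[i] at step i.
sigma = torch.tensor(10.0)

t = sigma / (sigma + 1)  # map sigma into [0, 1)
c_in = 1 - t             # scales the noisy latent fed to the transformer
c_skip = 1 - t           # weight on the noisy latent in the prediction
c_out = -t               # weight on the predicted velocity

# Given a model output `velocity`, the denoised estimate and the noise
# estimate handed to the scheduler are formed as in the stage:
latent = torch.randn(1)
velocity = torch.randn(1)
x0_pred = c_skip * latent + c_out * velocity
noise_pred = (latent - x0_pred) / sigma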
