
Commit ca83c7e

Commit message: pre-commit

1 parent 283c473 · commit ca83c7e
21 files changed (+316, -470 lines)

21 files changed

+316
-470
lines changed
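Judging by the commit message and the hunks below, this looks like an automated pre-commit formatting pass: imports are re-sorted, long lines are re-wrapped, trailing whitespace is stripped, and legacy typing.Optional/typing.Union annotations are rewritten as PEP 604 unions. A couple of hunks also carry small functional fixes (a stray debug print removed, a corrected type annotation).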
fastvideo/configs/models/dits/__init__.py

Lines changed: 5 additions & 3 deletions

@@ -1,7 +1,9 @@
+from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
 from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
 from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
 from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
-print("WOW")
-from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
 
-__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig", "CosmosVideoConfig"]
+__all__ = [
+    "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig",
+    "CosmosVideoConfig"
+]

fastvideo/configs/models/dits/cosmos.py

Lines changed: 0 additions & 1 deletion
@@ -90,7 +90,6 @@ class CosmosArchConfig(DiTArchConfig):
     qk_norm: str = "rms_norm"
     eps: float = 1e-6
     exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])
-
 
     def __post_init__(self):
         super().__post_init__()

fastvideo/configs/models/encoders/t5.py

Lines changed: 2 additions & 1 deletion
@@ -94,6 +94,7 @@ class T5Config(TextEncoderConfig):
 @dataclass
 class T5LargeConfig(TextEncoderConfig):
     """T5 Large configuration for your specific model."""
-    arch_config: TextEncoderArchConfig = field(default_factory=T5LargeArchConfig)
+    arch_config: TextEncoderArchConfig = field(
+        default_factory=T5LargeArchConfig)
 
     prefix: str = "t5"
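As an aside, the field(default_factory=...) idiom being re-wrapped here is the standard way to give a dataclass a fresh default object per instance; a shared class-level default would be silently reused across instances. A minimal, self-contained illustration (the Holder class is hypothetical, not from the repo):

from dataclasses import dataclass, field

@dataclass
class Holder:
    # default_factory constructs a new object for every instance;
    # bare mutable defaults like list are rejected by dataclasses outright.
    items: list[int] = field(default_factory=list)

a, b = Holder(), Holder()
a.items.append(1)
assert b.items == []  # each Holder got its own list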

fastvideo/configs/models/vaes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
+from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig
 from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
 from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig
 from fastvideo.configs.models.vaes.wanvae import WanVAEConfig
-from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig
 
 __all__ = [
     "HunyuanVAEConfig",

fastvideo/configs/pipelines/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,14 +1,13 @@
 from fastvideo.configs.pipelines.base import (PipelineConfig,
                                               SlidingTileAttnConfig)
+from fastvideo.configs.pipelines.cosmos import CosmosConfig
 from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
 from fastvideo.configs.pipelines.registry import (
     get_pipeline_config_cls_from_name)
 from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig
 from fastvideo.configs.pipelines.wan import (WanI2V480PConfig, WanI2V720PConfig,
                                              WanT2V480PConfig, WanT2V720PConfig)
 
-from fastvideo.configs.pipelines.cosmos import CosmosConfig
-
 __all__ = [
     "HunyuanConfig", "FastHunyuanConfig", "PipelineConfig",
     "SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig",

fastvideo/configs/pipelines/cosmos.py

Lines changed: 8 additions & 15 deletions
@@ -5,10 +5,8 @@
 import torch
 
 from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
-
 from fastvideo.configs.models.dits import CosmosVideoConfig
-from fastvideo.configs.models.encoders import (BaseEncoderOutput,
-                                               T5LargeConfig)
+from fastvideo.configs.models.encoders import BaseEncoderOutput, T5LargeConfig
 from fastvideo.configs.models.vaes import CosmosVAEConfig
 from fastvideo.configs.pipelines.base import PipelineConfig

@@ -19,39 +17,35 @@ def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
     Return raw last_hidden_state without truncation/padding.
     """
     hidden_state = outputs.last_hidden_state
-
+
     if hidden_state is None:
         raise ValueError("T5 Large outputs missing last_hidden_state")
-
+
     nan_count = torch.isnan(hidden_state).sum()
     if nan_count > 0:
         hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)
-
+
     return hidden_state
 
 
 @dataclass
 class CosmosConfig(PipelineConfig):
     """Configuration for Cosmos2 Video2World pipeline matching diffusers."""
 
-
     dit_config: DiTConfig = field(default_factory=CosmosVideoConfig)
-
 
     vae_config: VAEConfig = field(default_factory=CosmosVAEConfig)
-
 
     text_encoder_configs: tuple[EncoderConfig, ...] = field(
         default_factory=lambda: (T5LargeConfig(), ))
     postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
                                   ...] = field(default_factory=lambda:
                                                (t5_large_postprocess_text, ))
 
-
     dit_precision: str = "bf16"
     vae_precision: str = "fp16"
     text_encoder_precisions: tuple[str, ...] = field(
-        default_factory=lambda: ("bf16",))
+        default_factory=lambda: ("bf16", ))
 
     conditioning_strategy: str = "frame_replace"
     min_num_conditional_frames: int = 1

@@ -61,13 +55,12 @@ class CosmosConfig(PipelineConfig):
     state_ch: int = 16
     state_t: int = 24
     text_encoder_class: str = "T5"
-
 
     embedded_cfg_scale: int = 6
-    flow_shift: float = 1.0
+    flow_shift: float = 1.0
 
     def __post_init__(self):
         self.vae_config.load_encoder = True
         self.vae_config.load_decoder = True
-
-        self._vae_latent_dim = 16
+
+        self._vae_latent_dim = 16
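For readability, the text post-processing helper after this cleanup reads as follows (reconstructed verbatim from the hunks above and shown standalone; BaseEncoderOutput is the repo's own type, assumed to expose last_hidden_state):

import torch

def t5_large_postprocess_text(outputs: "BaseEncoderOutput") -> torch.Tensor:
    """Return raw last_hidden_state without truncation/padding."""
    hidden_state = outputs.last_hidden_state

    if hidden_state is None:
        raise ValueError("T5 Large outputs missing last_hidden_state")

    # Replace any NaNs the encoder produced with zeros.
    nan_count = torch.isnan(hidden_state).sum()
    if nan_count > 0:
        hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)

    return hidden_state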

fastvideo/configs/pipelines/registry.py

Lines changed: 1 addition & 1 deletion
@@ -5,13 +5,13 @@
 from collections.abc import Callable
 
 from fastvideo.configs.pipelines.base import PipelineConfig
+from fastvideo.configs.pipelines.cosmos import CosmosConfig
 from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
 from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig
 from fastvideo.configs.pipelines.wan import (FastWan2_1_T2V_480P_Config,
                                              FastWan2_2_TI2V_5B_Config,
                                              WanI2V480PConfig, WanI2V720PConfig,
                                              WanT2V480PConfig, WanT2V720PConfig)
-from fastvideo.configs.pipelines.cosmos import CosmosConfig
 from fastvideo.logger import init_logger
 from fastvideo.utils import (maybe_download_model_index,
                              verify_model_config_and_directory)

fastvideo/image_processor.py

Lines changed: 22 additions & 20 deletions
@@ -4,8 +4,6 @@
 This module provides lightweight image preprocessing without external dependencies beyond PyTorch/NumPy/PIL.
 """
 
-from typing import Optional, Union
-
 import numpy as np
 import PIL.Image
 import torch

@@ -29,9 +27,9 @@ def __init__(self, vae_scale_factor: int = 8) -> None:
 
     def preprocess(
         self,
-        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        image: PIL.Image.Image | np.ndarray | torch.Tensor,
+        height: int | None = None,
+        width: int | None = None,
     ) -> torch.Tensor:
         """
         Preprocess an image to a normalized torch tensor.

@@ -55,14 +53,13 @@ def preprocess(
         else:
             raise ValueError(
                 f"Unsupported image type: {type(image)}. "
-                "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor"
-            )
+                "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor")
 
     def _preprocess_pil(
         self,
         image: PIL.Image.Image,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: int | None = None,
+        width: int | None = None,
     ) -> torch.Tensor:
         """Preprocess a PIL image."""
         if height is None:

@@ -73,7 +70,8 @@ def _preprocess_pil(
         height = height - (height % self.vae_scale_factor)
         width = width - (width % self.vae_scale_factor)
 
-        image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)
+        image = image.resize((width, height),
+                             resample=PIL.Image.Resampling.LANCZOS)
 
         image_np = np.array(image, dtype=np.float32) / 255.0

@@ -85,8 +83,8 @@
     def _preprocess_numpy(
         self,
         image: np.ndarray,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: int | None = None,
+        width: int | None = None,
     ) -> torch.Tensor:
         """Preprocess a numpy array."""
         # Determine target dimensions if not provided

@@ -115,7 +113,8 @@
         image_uint8 = image.astype(np.uint8)
         pil_image = PIL.Image.fromarray(image_uint8)
 
-        pil_image = pil_image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)
+        pil_image = pil_image.resize((width, height),
+                                     resample=PIL.Image.Resampling.LANCZOS)
         image_np = np.array(pil_image, dtype=np.float32) / 255.0
 
         # Ensure 3D shape

@@ -127,8 +126,8 @@
     def _preprocess_tensor(
         self,
         image: torch.Tensor,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: int | None = None,
+        width: int | None = None,
     ) -> torch.Tensor:
         """Preprocess a torch tensor."""
         # Determine target dimensions

@@ -158,9 +157,10 @@ def _preprocess_tensor(
         else:  # (H, W, C) - need to rearrange
             image = image.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
 
-        image = torch.nn.functional.interpolate(
-            image, size=(height, width), mode="bilinear", align_corners=False
-        )
+        image = torch.nn.functional.interpolate(image,
+                                                size=(height, width),
+                                                mode="bilinear",
+                                                align_corners=False)
 
         if image.max() > 1.0:  # Assume [0, 255] range
             image = image / 255.0

@@ -181,9 +181,11 @@ def _normalize_to_tensor(self, image_np: np.ndarray) -> torch.Tensor:
         """
         # Convert to tensor
         if image_np.ndim == 2:  # (H, W) - grayscale
-            tensor = torch.from_numpy(image_np).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
+            tensor = torch.from_numpy(image_np).unsqueeze(0).unsqueeze(
+                0)  # (1, 1, H, W)
         elif image_np.ndim == 3:  # (H, W, C)
-            tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
+            tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(
+                0)  # (1, C, H, W)
         else:
             raise ValueError(f"Expected 2D or 3D array, got {image_np.ndim}D")
 
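Two things worth noting in this file. First, the typing.Optional/typing.Union annotations are rewritten as PEP 604 unions (available since Python 3.10), which is what lets the typing import be dropped. Second, the height - (height % self.vae_scale_factor) arithmetic simply snaps each dimension down to the nearest multiple of the VAE scale factor, e.g. 1083 becomes 1080 for a factor of 8. A minimal sketch of the annotation equivalence (hypothetical function names, not from the repo):

from typing import Optional, Union

# Old spelling, as deleted by this commit:
def resize_old(height: Optional[int] = None,
               width: Union[int, float, None] = None) -> None: ...

# PEP 604 spelling, as added; no typing import required:
def resize_new(height: int | None = None,
               width: int | float | None = None) -> None: ...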

fastvideo/layers/layernorm.py

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ def forward_diffusers(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Forward method that matches Diffusers RMSNorm implementation exactly."""
         input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states * torch.rsqrt(variance +
+                                                    self.variance_epsilon)
 
         if self.has_weight and self.weight is not None:
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
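The re-wrapped line is the heart of RMSNorm: scale by the reciprocal root mean square of the last dimension, with the variance computed in float32 for numerical stability. A self-contained sketch of the visible logic (the weight handling past the hunk's edge is an assumption modeled on the Diffusers implementation this method says it matches):

import torch

def rms_norm_diffusers(hidden_states: torch.Tensor,
                       weight: torch.Tensor | None = None,
                       eps: float = 1e-6) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    # Variance in float32; the multiply below promotes hidden_states too.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    if weight is not None:
        # Assumed continuation: cast to the weight dtype before scaling,
        # mirroring the fp16/bf16 branch visible at the end of the hunk.
        hidden_states = hidden_states.to(weight.dtype) * weight
    else:
        hidden_states = hidden_states.to(input_dtype)
    return hidden_states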

fastvideo/layers/rotary_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -46,10 +46,10 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
 
 def apply_rotary_emb(
     x: torch.Tensor,
-    freqs_cis: torch.Tensor | tuple[torch.Tensor],
+    freqs_cis: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
     use_real: bool = True,
     use_real_unbind_dim: int = -1,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
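Unlike the formatting hunks, this one is an annotation fix: in the real-valued path freqs_cis is a (cos, sin) pair, and the function returns a single rotated tensor, not a tuple. A minimal sketch of that rotation under the GPT-NeoX half-split convention (an assumption; the repo's actual unbind convention is governed by use_real_unbind_dim):

import torch

def apply_rotary_emb_real(
        x: torch.Tensor,
        freqs_cis: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    # freqs_cis is a (cos, sin) pair, each of shape (..., head_dim // 2).
    cos, sin = freqs_cis
    x1, x2 = x.chunk(2, dim=-1)  # half-split (NeoX) convention: assumption
    rotated = torch.cat((-x2, x1), dim=-1)
    cos = torch.cat((cos, cos), dim=-1)
    sin = torch.cat((sin, sin), dim=-1)
    # out1 = x1*cos - x2*sin, out2 = x2*cos + x1*sin; a single tensor is
    # returned, matching the corrected return annotation.
    return x * cos + rotated * sin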
