
Commit 283c473

Remove redundant decoding fallbacks
1 parent a4d2086 commit 283c473

9 files changed: +219, -156 lines


fastvideo/image_processor.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
"""
Minimal image processing utilities for FastVideo.
This module provides lightweight image preprocessing without external dependencies beyond PyTorch/NumPy/PIL.
"""

from typing import Optional, Union

import numpy as np
import PIL.Image
import torch


class ImageProcessor:
    """
    Minimal image processor for video frame preprocessing.

    This is a lightweight alternative to diffusers.VideoProcessor that handles:
    - PIL image to tensor conversion
    - Resizing to specified dimensions
    - Normalization to [-1, 1] range

    Args:
        vae_scale_factor: The VAE scale factor used to ensure dimensions are multiples of this value.
    """

    def __init__(self, vae_scale_factor: int = 8) -> None:
        self.vae_scale_factor = vae_scale_factor

    def preprocess(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Preprocess an image to a normalized torch tensor.

        Args:
            image: Input image (PIL Image, NumPy array, or torch tensor)
            height: Target height. If None, uses image's original height.
            width: Target width. If None, uses image's original width.

        Returns:
            torch.Tensor: Normalized tensor of shape (1, 3, height, width) or (1, 1, height, width) for grayscale,
                with values in range [-1, 1].
        """
        # Handle different input types
        if isinstance(image, PIL.Image.Image):
            return self._preprocess_pil(image, height, width)
        elif isinstance(image, np.ndarray):
            return self._preprocess_numpy(image, height, width)
        elif isinstance(image, torch.Tensor):
            return self._preprocess_tensor(image, height, width)
        else:
            raise ValueError(
                f"Unsupported image type: {type(image)}. "
                "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor"
            )

    def _preprocess_pil(
        self,
        image: PIL.Image.Image,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """Preprocess a PIL image."""
        if height is None:
            height = image.height
        if width is None:
            width = image.width

        height = height - (height % self.vae_scale_factor)
        width = width - (width % self.vae_scale_factor)

        image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)

        image_np = np.array(image, dtype=np.float32) / 255.0

        if image_np.ndim == 2:  # Grayscale
            image_np = np.expand_dims(image_np, axis=-1)

        return self._normalize_to_tensor(image_np)

    def _preprocess_numpy(
        self,
        image: np.ndarray,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """Preprocess a numpy array."""
        # Determine target dimensions if not provided
        if image.ndim == 3:
            img_height, img_width = image.shape[:2]
        elif image.ndim == 2:
            img_height, img_width = image.shape
        else:
            raise ValueError(f"Expected 2D or 3D array, got {image.ndim}D")

        if height is None:
            height = img_height
        if width is None:
            width = img_width

        height = height - (height % self.vae_scale_factor)
        width = width - (width % self.vae_scale_factor)

        if image.dtype == np.uint8:
            pil_image = PIL.Image.fromarray(image)
        else:
            # Assume normalized [0, 1] or similar
            if image.max() <= 1.0:
                image_uint8 = (image * 255).astype(np.uint8)
            else:
                image_uint8 = image.astype(np.uint8)
            pil_image = PIL.Image.fromarray(image_uint8)

        pil_image = pil_image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)
        image_np = np.array(pil_image, dtype=np.float32) / 255.0

        # Ensure 3D shape
        if image_np.ndim == 2:
            image_np = np.expand_dims(image_np, axis=-1)

        return self._normalize_to_tensor(image_np)

    def _preprocess_tensor(
        self,
        image: torch.Tensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """Preprocess a torch tensor."""
        # Determine target dimensions
        if image.ndim == 3:  # (H, W, C) or (C, H, W)
            if image.shape[0] in (1, 3, 4):  # Likely (C, H, W)
                img_height, img_width = image.shape[1], image.shape[2]
            else:  # Likely (H, W, C)
                img_height, img_width = image.shape[0], image.shape[1]
        elif image.ndim == 2:  # (H, W)
            img_height, img_width = image.shape
        else:
            raise ValueError(f"Expected 2D or 3D tensor, got {image.ndim}D")

        if height is None:
            height = img_height
        if width is None:
            width = img_width

        height = height - (height % self.vae_scale_factor)
        width = width - (width % self.vae_scale_factor)

        if image.ndim == 2:
            image = image.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
        elif image.ndim == 3:
            if image.shape[0] in (1, 3, 4):  # (C, H, W)
                image = image.unsqueeze(0)  # (1, C, H, W)
            else:  # (H, W, C) - need to rearrange
                image = image.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)

        image = torch.nn.functional.interpolate(
            image, size=(height, width), mode="bilinear", align_corners=False
        )

        if image.max() > 1.0:  # Assume [0, 255] range
            image = image / 255.0

        image = 2.0 * image - 1.0

        return image

    def _normalize_to_tensor(self, image_np: np.ndarray) -> torch.Tensor:
        """
        Convert normalized numpy array [0, 1] to torch tensor [-1, 1].

        Args:
            image_np: NumPy array with shape (H, W) or (H, W, C) with values in [0, 1]

        Returns:
            torch.Tensor: Shape (1, C, H, W) or (1, 1, H, W) with values in [-1, 1]
        """
        # Convert to tensor
        if image_np.ndim == 2:  # (H, W) - grayscale
            tensor = torch.from_numpy(image_np).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
        elif image_np.ndim == 3:  # (H, W, C)
            tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
        else:
            raise ValueError(f"Expected 2D or 3D array, got {image_np.ndim}D")

        # Normalize to [-1, 1]
        tensor = 2.0 * tensor - 1.0

        return tensor
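
A minimal usage sketch, not part of the commit: preprocessing a single RGB frame for a VAE with a spatial downsampling factor of 8. The frame path and target resolution are hypothetical.

import PIL.Image

from fastvideo.image_processor import ImageProcessor

processor = ImageProcessor(vae_scale_factor=8)
frame = PIL.Image.open("frame_0001.png")  # hypothetical input frame
tensor = processor.preprocess(frame, height=480, width=832)

# (1, 3, 480, 832), values in [-1, 1], both dimensions already multiples of 8
print(tensor.shape, tensor.min().item(), tensor.max().item())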

fastvideo/models/dits/cosmos.py

Lines changed: 0 additions & 4 deletions
@@ -471,14 +471,10 @@ def forward(self,
                 w_spatial_freqs)[None, None, :, :].repeat(
                     pe_size[0], pe_size[1], 1, 1)
 
-        # Apply sequence scaling in temporal dimension
         if fps is None:
             emb_t = torch.outer(seq[:pe_size[0]], temporal_freqs)
         else:
-            # Videos
-            print(f"[FASTVIDEO ROPE FORWARD] Using video mode (fps={fps})")
             temporal_scale = seq[:pe_size[0]] / fps * self.base_fps
-            print(f"[FASTVIDEO ROPE FORWARD] temporal_scale range: {temporal_scale.min().item():.6f} to {temporal_scale.max().item():.6f}")
             emb_t = torch.outer(temporal_scale,
                                 temporal_freqs)
 
fastvideo/models/registry.py

Lines changed: 0 additions & 1 deletion
@@ -236,7 +236,6 @@ def register_model(
 
     def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn:
        all_supported_archs = self.get_supported_archs()
-        print('all_supported1', all_supported_archs)
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "

fastvideo/pipelines/composed_pipeline_base.py

Lines changed: 0 additions & 1 deletion
@@ -282,7 +282,6 @@ def load_modules(
         for module_name, (transformers_or_diffusers,
                           architecture) in model_index.items():
             if transformers_or_diffusers is None:
-                print("REQURED", self.required_config_modules, module_name)
                 self.required_config_modules.remove(module_name)
                 continue
             if module_name not in required_modules:

fastvideo/pipelines/stages/decoding.py

Lines changed: 11 additions & 46 deletions
@@ -92,49 +92,21 @@ def forward(
         vae_autocast_enabled = (vae_dtype != torch.float32
                                 ) and not fastvideo_args.disable_autocast
 
-        # Apply latents normalization for Cosmos VAE
-        # Source: /diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:1000-1010
-        if hasattr(self.vae, 'config') and hasattr(self.vae.config, 'latents_mean') and hasattr(self.vae.config, 'latents_std'):
-            # Get scheduler for sigma_data
-            pipeline = self.pipeline() if self.pipeline else None
-            sigma_data = 1.0  # default
-            if pipeline and hasattr(pipeline, 'modules') and 'scheduler' in pipeline.modules:
-                scheduler = pipeline.modules['scheduler']
-                if hasattr(scheduler, 'config') and hasattr(scheduler.config, 'sigma_data'):
-                    sigma_data = scheduler.config.sigma_data
-
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = (
-                torch.tensor(self.vae.config.latents_std)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-
-            latents_after_mul = latents * latents_std / sigma_data
-            latents = latents_after_mul + latents_mean
-
-        # Fallback to scaling_factor for other VAE types
-        elif hasattr(self.vae, 'scaling_factor'):
+        if hasattr(self.vae, 'scaling_factor'):
             if isinstance(self.vae.scaling_factor, torch.Tensor):
                 latents = latents / self.vae.scaling_factor.to(
                     latents.device, latents.dtype)
             else:
                 latents = latents / self.vae.scaling_factor
-        elif hasattr(self.vae, 'config') and hasattr(self.vae.config, 'scaling_factor'):
-            latents = latents / self.vae.config.scaling_factor
-
-        # NOTE: Skip this if we already applied latents_mean (for Cosmos VAE)
-        elif (hasattr(self.vae, "shift_factor")
-              and self.vae.shift_factor is not None):
-            if isinstance(self.vae.shift_factor, torch.Tensor):
-                latents += self.vae.shift_factor.to(latents.device,
-                                                    latents.dtype)
-            else:
-                latents += self.vae.shift_factor
+
+        # Apply shifting if needed
+        if (hasattr(self.vae, "shift_factor")
+                and self.vae.shift_factor is not None):
+            if isinstance(self.vae.shift_factor, torch.Tensor):
+                latents += self.vae.shift_factor.to(latents.device,
+                                                    latents.dtype)
+            else:
+                latents += self.vae.shift_factor
 
         # Decode latents
         with torch.autocast(device_type="cuda",
@@ -146,15 +118,8 @@ def forward(
             # self.vae.enable_parallel()
             if not vae_autocast_enabled:
                 latents = latents.to(vae_dtype)
-            decode_output = self.vae.decode(latents)
 
-        # TEMPORARY: Handle diffusers VAE decode output compatibility
-        if hasattr(decode_output, 'sample'):
-            # Diffusers VAE returns DecoderOutput with .sample attribute
-            image = decode_output.sample
-        else:
-            # FastVideo VAE returns tensor directly
-            image = decode_output
+            image = self.vae.decode(latents)
 
         # Normalize image to [0, 1] range
         image = (image / 2 + 0.5).clamp(0, 1)
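
For orientation, a condensed sketch of the decode path after the fallbacks are removed. This is an assumption about the surrounding stage, not code from this commit; it assumes self.vae exposes scaling_factor and shift_factor attributes and a decode() that returns a tensor directly.

import torch

def decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    # Undo the scaling applied at encode time.
    if hasattr(vae, "scaling_factor"):
        latents = latents / vae.scaling_factor
    # Undo the shift, when the VAE defines one.
    if getattr(vae, "shift_factor", None) is not None:
        latents = latents + vae.shift_factor
    image = vae.decode(latents)  # FastVideo VAEs return a tensor, not a DecoderOutput
    return (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] to [0, 1]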
