Commit 26c2f0f

Remove diffusers vae dependency
1 parent e45dab4

File tree

3 files changed: +55, -49 lines

3 files changed

+55
-49
lines changed

fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py

15 additions, 36 deletions

@@ -12,8 +12,8 @@

 # TEMPORARY: Import diffusers VAE for comparison
 import sys
-sys.path.insert(0, '/workspace/diffusers/src')
-from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan as DiffusersAutoencoderKLWan
+# sys.path.insert(0, '/mnt/fast-disks/nfs/hao_lab/kevin/diffusers/src')
+# from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan as DiffusersAutoencoderKLWan

 from fastvideo.fastvideo_args import FastVideoArgs
 from fastvideo.logger import init_logger
@@ -27,6 +27,8 @@
 from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
     FlowMatchEulerDiscreteScheduler)

+from fastvideo.models.vaes.wanvae import AutoencoderKLWan
+
 logger = init_logger(__name__)


@@ -38,93 +40,70 @@ class Cosmos2VideoToWorldPipeline(ComposedPipelineBase):

     def initialize_pipeline(self, fastvideo_args: FastVideoArgs):

-        # TEMPORARY: Replace FastVideo VAE with diffusers VAE for testing
-        print("[TEMPORARY] Replacing FastVideo VAE with diffusers VAE...")
-        original_vae = self.modules["vae"]
-        print(f"[TEMPORARY] Original VAE type: {type(original_vae)}")
+        # original_vae = self.modules["vae"]
+
+        # diffusers_vae = DiffusersAutoencoderKLWan.from_pretrained(
+        #     self.model_path,
+        #     subfolder="vae",
+        #     torch_dtype=torch.bfloat16,
+        # )

-        # Load diffusers VAE with same config
-        diffusers_vae = DiffusersAutoencoderKLWan.from_pretrained(
-            self.model_path,
-            subfolder="vae",
-            torch_dtype=torch.bfloat16,
-        )
-        print(f"[TEMPORARY] Diffusers VAE type: {type(diffusers_vae)}")
+        # with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        #     f.write(f"[TEMPORARY] Diffusers VAE type: {type(diffusers_vae)}\n")

         # Replace the VAE module
-        self.modules["vae"] = diffusers_vae
-        print("[TEMPORARY] VAE replacement complete!")
+        # self.modules["vae"] = diffusers_vae
+

         self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
             shift=fastvideo_args.pipeline_config.flow_shift,
             use_karras_sigmas=True)

-        # Configure Cosmos-specific scheduler parameters (matching diffusers)
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:209-219
         sigma_max = 80.0
         sigma_min = 0.002
         sigma_data = 1.0
         final_sigmas_type = "sigma_min"

         if self.modules["scheduler"] is not None:
-            # Update scheduler config and attributes directly
             scheduler = self.modules["scheduler"]
             scheduler.config.sigma_max = sigma_max
             scheduler.config.sigma_min = sigma_min
             scheduler.config.sigma_data = sigma_data
             scheduler.config.final_sigmas_type = final_sigmas_type
-            # Also set the direct attributes used by the scheduler
             scheduler.sigma_max = sigma_max
             scheduler.sigma_min = sigma_min
             scheduler.sigma_data = sigma_data

     def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
         """Set up pipeline stages with proper dependency injection."""

-        # Input validation - corresponds to diffusers check_inputs method
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:427-456
         self.add_stage(stage_name="input_validation_stage",
                        stage=InputValidationStage())

-        # Text encoding - corresponds to diffusers encode_prompt method
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:265-346
-        # Also uses _get_t5_prompt_embeds method: lines 222-262
         self.add_stage(stage_name="prompt_encoding_stage",
                        stage=TextEncodingStage(
                            text_encoders=[self.get_module("text_encoder")],
                            tokenizers=[self.get_module("tokenizer")],
                        ))

-        # Conditioning preparation - part of main __call__ method setup
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:607-628
         self.add_stage(stage_name="conditioning_stage",
                        stage=ConditioningStage())

-        # Timestep preparation - corresponds to timestep setup in __call__
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:630-637
-        # Uses retrieve_timesteps function: lines 81-137
         self.add_stage(stage_name="timestep_preparation_stage",
                        stage=TimestepPreparationStage(
                            scheduler=self.get_module("scheduler")))

-        # Latent preparation - corresponds to prepare_latents method
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:348-424
-        # Also includes video preprocessing: lines 642-661
         self.add_stage(stage_name="latent_preparation_stage",
                        stage=CosmosLatentPreparationStage(
                            scheduler=self.get_module("scheduler"),
                            transformer=self.get_module("transformer"),
                            vae=self.get_module("vae")))

-        # Denoising loop - corresponds to main denoising loop in __call__
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:673-752
         self.add_stage(stage_name="denoising_stage",
                        stage=CosmosDenoisingStage(
                            transformer=self.get_module("transformer"),
                            scheduler=self.get_module("scheduler")))

-        # VAE decoding - corresponds to final decoding section in __call__
-        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:755-784
         self.add_stage(stage_name="decoding_stage",
                        stage=DecodingStage(vae=self.get_module("vae")))
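The scheduler block above pins Cosmos-style EDM parameters (sigma_max = 80.0, sigma_min = 0.002, sigma_data = 1.0) on a FlowMatchEulerDiscreteScheduler with use_karras_sigmas=True. For reference, a standalone sketch of a Karras-spaced sigma schedule over those bounds is shown below; the rho exponent and step count are illustrative assumptions, not values taken from this commit, and FastVideo's scheduler may space sigmas differently.

import torch

def karras_sigmas(num_steps: int = 35,
                  sigma_min: float = 0.002,
                  sigma_max: float = 80.0,
                  rho: float = 7.0) -> torch.Tensor:
    # Karras et al. (2022) spacing from sigma_max down to sigma_min.
    # Sketch only; not FastVideo's implementation.
    ramp = torch.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

sigmas = karras_sigmas()
print(sigmas[0].item(), sigmas[-1].item())  # ~80.0 down to ~0.002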

fastvideo/pipelines/stages/decoding.py

37 additions, 8 deletions

@@ -83,7 +83,7 @@ def forward(
             raise ValueError("Latents must be provided")

         print(f"[FASTVIDEO VAE DEBUG] Before scaling/shifting - latents sum: {latents.float().sum().item():.6f}, shape: {latents.shape}, dtype: {latents.dtype}")
-        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
             f.write(f"[FASTVIDEO VAE DEBUG] Before scaling/shifting - latents sum: {latents.float().sum().item():.6f}, shape: {latents.shape}, dtype: {latents.dtype}\n")

         # Skip decoding if output type is latent
@@ -107,6 +107,14 @@ def forward(
             if hasattr(scheduler, 'config') and hasattr(scheduler.config, 'sigma_data'):
                 sigma_data = scheduler.config.sigma_data

+            print(f"[FASTVIDEO VAE DEBUG] sigma_data = {sigma_data}")
+            print(f"[FASTVIDEO VAE DEBUG] latents_mean config = {self.vae.config.latents_mean}")
+            print(f"[FASTVIDEO VAE DEBUG] latents_std config = {self.vae.config.latents_std}")
+            with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO VAE DEBUG] sigma_data = {sigma_data}\n")
+                f.write(f"[FASTVIDEO VAE DEBUG] latents_mean config = {self.vae.config.latents_mean}\n")
+                f.write(f"[FASTVIDEO VAE DEBUG] latents_std config = {self.vae.config.latents_std}\n")
+
             latents_mean = (
                 torch.tensor(self.vae.config.latents_mean)
                 .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -117,7 +125,29 @@ def forward(
                 .view(1, self.vae.config.z_dim, 1, 1, 1)
                 .to(latents.device, latents.dtype)
             )
-            latents = latents * latents_std / sigma_data + latents_mean
+            print(f"[FASTVIDEO VAE DEBUG] latents dtype = {latents.dtype}, latents_mean dtype = {latents_mean.dtype}, latents_std dtype = {latents_std.dtype}")
+            print(f"[FASTVIDEO VAE DEBUG] latents_mean tensor sum = {latents_mean.sum().item():.6f}")
+            print(f"[FASTVIDEO VAE DEBUG] latents_std tensor sum = {latents_std.sum().item():.6f}")
+            with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO VAE DEBUG] latents dtype = {latents.dtype}, latents_mean dtype = {latents_mean.dtype}, latents_std dtype = {latents_std.dtype}\n")
+                f.write(f"[FASTVIDEO VAE DEBUG] latents_mean tensor sum = {latents_mean.sum().item():.6f}\n")
+                f.write(f"[FASTVIDEO VAE DEBUG] latents_std tensor sum = {latents_std.sum().item():.6f}\n")
+
+            print(f"[FASTVIDEO VAE DEBUG] latents shape = {latents.shape}, latents_mean shape = {latents_mean.shape}, latents_std shape = {latents_std.shape}")
+            with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO VAE DEBUG] latents shape = {latents.shape}, latents_mean shape = {latents_mean.shape}, latents_std shape = {latents_std.shape}\n")
+
+            latents_after_mul = latents * latents_std / sigma_data
+            print(f"[FASTVIDEO VAE DEBUG] After multiply (latents * latents_std / sigma_data) sum = {latents_after_mul.float().sum().item():.6f}")
+            print(f"[FASTVIDEO VAE DEBUG] latents_after_mul shape = {latents_after_mul.shape}")
+            with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO VAE DEBUG] After multiply sum = {latents_after_mul.float().sum().item():.6f}\n")
+                f.write(f"[FASTVIDEO VAE DEBUG] latents_after_mul shape = {latents_after_mul.shape}\n")
+
+            latents = latents_after_mul + latents_mean
+            print(f"[FASTVIDEO VAE DEBUG] After adding latents_mean, latents shape = {latents.shape}")
+            with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO VAE DEBUG] After adding latents_mean shape = {latents.shape}\n")
         # Fallback to scaling_factor for other VAE types
         elif hasattr(self.vae, 'scaling_factor'):
             if isinstance(self.vae.scaling_factor, torch.Tensor):
@@ -126,11 +156,10 @@ def forward(
             else:
                 latents = latents / self.vae.scaling_factor
         elif hasattr(self.vae, 'config') and hasattr(self.vae.config, 'scaling_factor'):
-            # Fallback to config scaling factor for other diffusers VAEs
             latents = latents / self.vae.config.scaling_factor

-        # Apply shifting if needed (for other VAE types)
-        if (hasattr(self.vae, "shift_factor")
+        # NOTE: Skip this if we already applied latents_mean (for Cosmos VAE)
+        elif (hasattr(self.vae, "shift_factor")
                 and self.vae.shift_factor is not None):
             if isinstance(self.vae.shift_factor, torch.Tensor):
                 latents += self.vae.shift_factor.to(latents.device,
@@ -139,7 +168,7 @@ def forward(
                 latents += self.vae.shift_factor

         print(f"[FASTVIDEO VAE DEBUG] After scaling/shifting - latents sum: {latents.float().sum().item():.6f}")
-        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
             f.write(f"[FASTVIDEO VAE DEBUG] After scaling/shifting - latents sum: {latents.float().sum().item():.6f}\n")

         # Decode latents
@@ -163,14 +192,14 @@ def forward(
             image = decode_output

         print(f"[FASTVIDEO VAE DEBUG] After decode - image sum: {image.float().sum().item():.6f}, shape: {image.shape}")
-        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
             f.write(f"[FASTVIDEO VAE DEBUG] After decode - image sum: {image.float().sum().item():.6f}, shape: {image.shape}\n")

         # Normalize image to [0, 1] range
         image = (image / 2 + 0.5).clamp(0, 1)

         print(f"[FASTVIDEO VAE DEBUG] After normalization - image sum: {image.float().sum().item():.6f}")
-        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        with open("/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/fastvideo_hidden_states.log", "a") as f:
             f.write(f"[FASTVIDEO VAE DEBUG] After normalization - image sum: {image.float().sum().item():.6f}\n")

         # Convert to CPU float32 for compatibility
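The added logging brackets the per-channel un-normalization that the Wan/Cosmos branch applies before decoding: latents * latents_std / sigma_data + latents_mean, with mean and std reshaped to (1, z_dim, 1, 1, 1). A minimal, self-contained sketch of just that step follows; the z_dim, tensor shapes, and channel statistics are made up for illustration and would normally come from vae.config and the scheduler config.

import torch

# Hypothetical shapes and statistics for illustration only.
z_dim, frames, height, width = 16, 4, 8, 8
latents = torch.randn(1, z_dim, frames, height, width, dtype=torch.bfloat16)
latents_mean = torch.randn(z_dim).view(1, z_dim, 1, 1, 1).to(latents.dtype)
latents_std = (torch.rand(z_dim) + 0.5).view(1, z_dim, 1, 1, 1).to(latents.dtype)
sigma_data = 1.0  # set on the scheduler config for Cosmos in this commit

# Per-channel un-normalization before VAE decode, mirroring the diff.
latents = latents * latents_std / sigma_data + latents_mean
print(latents.shape)  # torch.Size([1, 16, 4, 8, 8])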

test_fastvideo_pipeline.py

3 additions, 5 deletions

@@ -6,8 +6,6 @@
 import os
 import sys

-# Add FastVideo to path
-sys.path.insert(0, "/workspace/FastVideo")

 from fastvideo.entrypoints.video_generator import VideoGenerator

@@ -18,10 +16,10 @@ def generate_video():
     # Configuration
     #input_image_path = "/workspace/FastVideo/tennis.jpg"
     #prompt = "A tennis ball bouncing on a racquet, the ball moves in a smooth arc as it hits the strings and rebounds with natural physics. The racquet strings vibrate slightly from the impact, and the ball continues its trajectory with realistic motion."
-    input_image_path = "/workspace/FastVideo/yellow-scrubber.png"
+    input_image_path = "/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/yellow-scrubber.png"
     prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
     negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
-    output_path = "/workspace/FastVideo/cosmos2_fastvideo_output.mp4"
+    output_path = "/mnt/fast-disks/nfs/hao_lab/kevin/FastVideo/cosmos2_fastvideo_output.mp4"

     # Check if input image exists
     if not os.path.exists(input_image_path):
@@ -51,7 +49,7 @@ def generate_video():
         guidance_scale=7.0,
         seed=1,
         save_video=True,
-        output_path=output_path
+        output_path=output_path,
     )

     if result:
