Skip to content

Commit a10d0a6

Browse files
committed
Fix scheduling
1 parent 56be69d commit a10d0a6

File tree

3 files changed

+128
-20
lines changed

3 files changed

+128
-20
lines changed

fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,20 @@ def set_timesteps(
348348
sigmas_array: np.ndarray
349349
if sigmas is None:
350350
if timesteps_array is None:
351-
timesteps_array = np.linspace(self._sigma_to_t(self.sigma_max),
352-
self._sigma_to_t(self.sigma_min),
353-
num_inference_steps)
351+
t_max = self._sigma_to_t(self.sigma_max)
352+
t_min = self._sigma_to_t(self.sigma_min)
353+
print(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] sigma_max={self.sigma_max}, sigma_min={self.sigma_min}")
354+
print(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] t_max={t_max}, t_min={t_min}, num_inference_steps={num_inference_steps}")
355+
timesteps_array = np.linspace(t_max, t_min, num_inference_steps)
356+
print(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] timesteps_array first few: {timesteps_array[:3]}")
357+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
358+
f.write(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] sigma_max={self.sigma_max}, sigma_min={self.sigma_min}\n")
359+
f.write(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] t_max={t_max}, t_min={t_min}, num_inference_steps={num_inference_steps}\n")
360+
f.write(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] timesteps_array first few: {timesteps_array[:3]}\n")
354361
sigmas_array = timesteps_array / self.config.num_train_timesteps
362+
print(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] sigmas_array before shifting first few: {sigmas_array[:3]}")
363+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
364+
f.write(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] sigmas_array before shifting first few: {sigmas_array[:3]}\n")
355365
else:
356366
sigmas_array = np.array(sigmas).astype(np.float32)
357367
num_inference_steps = len(sigmas_array)
@@ -362,8 +372,14 @@ def set_timesteps(
362372
assert mu is not None, "mu cannot be None when use_dynamic_shifting is True"
363373
sigmas_array = self.time_shift(mu, 1.0, sigmas_array)
364374
else:
375+
print(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] Before shifting - self.shift={self.shift}, sigmas_array first few: {sigmas_array[:3]}")
376+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
377+
f.write(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] Before shifting - self.shift={self.shift}, sigmas_array first few: {sigmas_array[:3]}\n")
365378
sigmas_array = self.shift * sigmas_array / (
366379
1 + (self.shift - 1) * sigmas_array)
380+
print(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] After shifting - sigmas_array first few: {sigmas_array[:3]}")
381+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
382+
f.write(f"[FASTVIDEO SCHEDULER SIGMA DEBUG] After shifting - sigmas_array first few: {sigmas_array[:3]}\n")
367383

368384
# 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
369385
if self.config.shift_terminal:
@@ -415,14 +431,7 @@ def set_timesteps(
415431
[sigmas_tensor,
416432
torch.ones(1, device=sigmas_tensor.device)])
417433
else:
418-
# Handle final_sigmas_type parameter
419-
if self.config.final_sigmas_type == "sigma_min":
420-
# Use sigma_min instead of zero for final sigma
421-
final_sigma = torch.tensor([self.sigma_min], device=sigmas_tensor.device)
422-
else: # "zero" or default
423-
final_sigma = torch.zeros(1, device=sigmas_tensor.device)
424-
425-
sigmas_tensor = torch.cat([sigmas_tensor, final_sigma])
434+
sigmas_tensor = torch.cat([sigmas_tensor, torch.zeros(1, device=sigmas_tensor.device)])
426435

427436
self.timesteps = timesteps_tensor
428437
self.sigmas = sigmas_tensor
@@ -522,24 +531,43 @@ def step(
522531
next_sigma = lower_sigmas[..., None]
523532
dt = current_sigma - next_sigma
524533
else:
525-
assert self.step_index is not None, "step_index should not be None"
534+
if self.step_index is None:
535+
self._init_step_index(timestep)
536+
526537
sigma_idx = self.step_index
527538
sigma = self.sigmas[sigma_idx]
528539
sigma_next = self.sigmas[sigma_idx + 1]
529540

541+
# DETAILED SCHEDULER DEBUG LOGGING
542+
print(f"[FASTVIDEO SCHEDULER DEBUG] step_index: {self.step_index}, sigma_idx: {sigma_idx}")
543+
print(f"[FASTVIDEO SCHEDULER DEBUG] sigma: {sigma:.10f}, sigma_next: {sigma_next:.10f}")
544+
print(f"[FASTVIDEO SCHEDULER DEBUG] sigmas array length: {len(self.sigmas)}, first few: {self.sigmas[:3]}")
545+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
546+
f.write(f"[FASTVIDEO SCHEDULER DEBUG] step_index: {self.step_index}, sigma_idx: {sigma_idx}\n")
547+
f.write(f"[FASTVIDEO SCHEDULER DEBUG] sigma: {sigma:.10f}, sigma_next: {sigma_next:.10f}\n")
548+
f.write(f"[FASTVIDEO SCHEDULER DEBUG] sigmas array length: {len(self.sigmas)}, first few: {self.sigmas[:3]}\n")
549+
530550
current_sigma = sigma
531551
next_sigma = sigma_next
532552
dt = sigma_next - sigma
533553

554+
print(f"[FASTVIDEO SCHEDULER DEBUG] dt: {dt:.10f}, current_sigma: {current_sigma:.10f}, next_sigma: {next_sigma:.10f}")
555+
print(f"[FASTVIDEO SCHEDULER DEBUG] sample sum before step: {sample.float().sum().item():.6f}, model_output sum: {model_output.float().sum().item():.6f}")
556+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
557+
f.write(f"[FASTVIDEO SCHEDULER DEBUG] dt: {dt:.10f}, current_sigma: {current_sigma:.10f}, next_sigma: {next_sigma:.10f}\n")
558+
f.write(f"[FASTVIDEO SCHEDULER DEBUG] sample sum before step: {sample.float().sum().item():.6f}, model_output sum: {model_output.float().sum().item():.6f}\n")
559+
534560
if self.config.stochastic_sampling:
535561
x0 = sample - current_sigma * model_output
536562
noise = torch.randn_like(sample)
537563
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
538564
else:
539565
prev_sample = sample + dt * model_output
566+
print(f"[FASTVIDEO SCHEDULER DEBUG] final prev_sample sum: {prev_sample.float().sum().item():.6f}")
567+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
568+
f.write(f"[FASTVIDEO SCHEDULER DEBUG] final prev_sample sum: {prev_sample.float().sum().item():.6f}\n")
540569

541570
# upon completion increase step index by one
542-
assert self._step_index is not None, "_step_index should not be None"
543571
self._step_index += 1
544572
if per_token_timesteps is None:
545573
# Cast sample back to model compatible dtype
@@ -575,7 +603,7 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor,
575603
min_inv_rho = sigma_min**(1 / rho)
576604
max_inv_rho = sigma_max**(1 / rho)
577605
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
578-
return sigmas
606+
return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device)
579607

580608
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
581609
def _convert_to_exponential(self, in_sigmas: torch.Tensor,
@@ -600,7 +628,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor,
600628
sigmas = np.exp(
601629
np.linspace(math.log(sigma_max), math.log(sigma_min),
602630
num_inference_steps))
603-
return sigmas
631+
return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device)
604632

605633
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
606634
def _convert_to_beta(self,
@@ -631,7 +659,7 @@ def _convert_to_beta(self,
631659
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
632660
]
633661
])
634-
return sigmas
662+
return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device)
635663

636664
def _time_shift_exponential(
637665
self, mu: float, sigma: float,

fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
5656
print("[TEMPORARY] VAE replacement complete!")
5757

5858
self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
59-
shift=fastvideo_args.pipeline_config.flow_shift)
59+
shift=fastvideo_args.pipeline_config.flow_shift,
60+
use_karras_sigmas=True)
6061

6162
# Configure Cosmos-specific scheduler parameters (matching diffusers)
6263
# Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:209-219

fastvideo/pipelines/stages/denoising.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,21 @@ def forward(
660660
pipeline.add_module("transformer", self.transformer)
661661
fastvideo_args.model_loaded["transformer"] = True
662662

663+
# Prepare extra step kwargs for scheduler
664+
extra_step_kwargs = self.prepare_extra_func_kwargs(
665+
self.scheduler.step,
666+
{
667+
"generator": batch.generator,
668+
"eta": batch.eta
669+
},
670+
)
671+
672+
# Log the extra step kwargs
673+
print(f"[FASTVIDEO DEBUG] Extra step kwargs: {extra_step_kwargs}")
674+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
675+
f.write(f"[FASTVIDEO DEBUG] Extra step kwargs: {extra_step_kwargs}\n")
676+
677+
663678
# Setup precision to match diffusers exactly
664679
# Diffusers uses transformer.dtype (bfloat16) and converts inputs before transformer calls
665680
# For FSDP wrapped models, we need to access the underlying module
@@ -682,11 +697,43 @@ def forward(
682697
f.write(f"Denoising init: latents sum = {sum_value:.6f}, shape = {latents.shape}\n")
683698

684699

700+
# Configure scheduler to match Diffusers exactly (MUST be before set_timesteps)
701+
sigma_max = 80.0
702+
sigma_min = 0.002
703+
sigma_data = 1.0
704+
final_sigmas_type = "sigma_min"
705+
706+
print(f"[FASTVIDEO DEBUG] BEFORE config - scheduler.config: {self.scheduler.config}")
707+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
708+
f.write(f"[FASTVIDEO DEBUG] BEFORE config - scheduler.config: {self.scheduler.config}\n")
709+
710+
if self.scheduler is not None:
711+
self.scheduler.register_to_config(
712+
sigma_max=sigma_max,
713+
sigma_min=sigma_min,
714+
sigma_data=sigma_data,
715+
final_sigmas_type=final_sigmas_type,
716+
)
717+
print(f"[FASTVIDEO DEBUG] Applied scheduler config: sigma_max={sigma_max}, sigma_min={sigma_min}, sigma_data={sigma_data}, final_sigmas_type={final_sigmas_type}")
718+
print(f"[FASTVIDEO DEBUG] AFTER config - scheduler.config: {self.scheduler.config}")
719+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
720+
f.write(f"[FASTVIDEO DEBUG] Applied scheduler config: sigma_max={sigma_max}, sigma_min={sigma_min}, sigma_data={sigma_data}, final_sigmas_type={final_sigmas_type}\n")
721+
f.write(f"[FASTVIDEO DEBUG] AFTER config - scheduler.config: {self.scheduler.config}\n")
722+
685723
# Setup scheduler timesteps - use default scheduler sigma generation
686724
# The torch.linspace(0, 1, num_inference_steps) approach was incorrect for FlowMatchEulerDiscreteScheduler
687725
# Let the scheduler generate its own sigmas using the configured sigma_max, sigma_min, etc.
688726
self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
689727
timesteps = self.scheduler.timesteps
728+
729+
# Debug what sigmas were actually generated
730+
print(f"[FASTVIDEO DEBUG] Generated sigmas - length: {len(self.scheduler.sigmas)}, first few: {self.scheduler.sigmas[:3]}")
731+
print(f"[FASTVIDEO DEBUG] Scheduler config after set_timesteps: sigma_max={getattr(self.scheduler.config, 'sigma_max', 'NOT_SET')}, sigma_min={getattr(self.scheduler.config, 'sigma_min', 'NOT_SET')}")
732+
print(f"[FASTVIDEO DEBUG] Scheduler properties: self.sigma_max={getattr(self.scheduler, 'sigma_max', 'NOT_SET')}, self.sigma_min={getattr(self.scheduler, 'sigma_min', 'NOT_SET')}")
733+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
734+
f.write(f"[FASTVIDEO DEBUG] Generated sigmas - length: {len(self.scheduler.sigmas)}, first few: {self.scheduler.sigmas[:3]}\n")
735+
f.write(f"[FASTVIDEO DEBUG] Scheduler config after set_timesteps: sigma_max={getattr(self.scheduler.config, 'sigma_max', 'NOT_SET')}, sigma_min={getattr(self.scheduler.config, 'sigma_min', 'NOT_SET')}\n")
736+
f.write(f"[FASTVIDEO DEBUG] Scheduler properties: self.sigma_max={getattr(self.scheduler, 'sigma_max', 'NOT_SET')}, self.sigma_min={getattr(self.scheduler, 'sigma_min', 'NOT_SET')}\n")
690737

691738
# Handle final sigmas like diffusers
692739
if hasattr(self.scheduler.config, 'final_sigmas_type') and self.scheduler.config.final_sigmas_type == "sigma_min":
@@ -844,6 +891,18 @@ def forward(
844891
print(f"[FASTVIDEO DEBUG] Step {i}: Preconditioning - c_skip={c_skip:.6f}, c_out={c_out:.6f}, latents_sum={latents.float().sum().item():.6f}")
845892
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
846893
f.write(f"[FASTVIDEO DEBUG] Step {i}: Preconditioning - c_skip={c_skip:.6f}, c_out={c_out:.6f}, latents_sum={latents.float().sum().item():.6f}\n")
894+
895+
# PRECONDITIONING DTYPE VERIFICATION
896+
print(f"[FASTVIDEO DTYPE DEBUG] Step {i}: Preconditioning dtypes")
897+
print(f"[FASTVIDEO DTYPE DEBUG] noise_pred dtype: {noise_pred.dtype}, latents dtype: {latents.dtype}")
898+
print(f"[FASTVIDEO DTYPE DEBUG] c_skip: {c_skip:.10f} (type: {type(c_skip)}), c_out: {c_out:.10f} (type: {type(c_out)})")
899+
print(f"[FASTVIDEO DTYPE DEBUG] target_dtype: {target_dtype}")
900+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
901+
f.write(f"[FASTVIDEO DTYPE DEBUG] Step {i}: Preconditioning dtypes\n")
902+
f.write(f"[FASTVIDEO DTYPE DEBUG] noise_pred dtype: {noise_pred.dtype}, latents dtype: {latents.dtype}\n")
903+
f.write(f"[FASTVIDEO DTYPE DEBUG] c_skip: {c_skip:.10f} (type: {type(c_skip)}), c_out: {c_out:.10f} (type: {type(c_out)})\n")
904+
f.write(f"[FASTVIDEO DTYPE DEBUG] target_dtype: {target_dtype}\n")
905+
847906
cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
848907

849908
if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
@@ -954,14 +1013,34 @@ def forward(
9541013
else:
9551014
logger.warning(f"Step {i}: current_sigma too small ({current_sigma}), using final_pred directly")
9561015
noise_for_scheduler = final_pred
957-
1016+
9581017
# Debug: Check for NaN values before scheduler step
9591018
if torch.isnan(noise_for_scheduler).sum() > 0:
9601019
logger.error(f"Step {i}: NaN detected in noise_for_scheduler, sum: {noise_for_scheduler.float().sum().item()}")
9611020
logger.error(f"Step {i}: latents sum: {latents.float().sum().item()}, final_pred sum: {final_pred.float().sum().item()}, current_sigma: {current_sigma}")
962-
1021+
1022+
# DTYPE VERIFICATION LOGS
1023+
print(f"[FASTVIDEO DTYPE DEBUG] Step {i}: Before scheduler step")
1024+
print(f"[FASTVIDEO DTYPE DEBUG] latents dtype: {latents.dtype}, sum: {latents.float().sum().item():.6f}")
1025+
print(f"[FASTVIDEO DTYPE DEBUG] final_pred dtype: {final_pred.dtype}, sum: {final_pred.float().sum().item():.6f}")
1026+
print(f"[FASTVIDEO DTYPE DEBUG] noise_for_scheduler dtype: {noise_for_scheduler.dtype}, sum: {noise_for_scheduler.float().sum().item():.6f}")
1027+
print(f"[FASTVIDEO DTYPE DEBUG] current_sigma: {current_sigma:.10f} (type: {type(current_sigma)})")
1028+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
1029+
f.write(f"[FASTVIDEO DTYPE DEBUG] Step {i}: Before scheduler step\n")
1030+
f.write(f"[FASTVIDEO DTYPE DEBUG] latents dtype: {latents.dtype}, sum: {latents.float().sum().item():.6f}\n")
1031+
f.write(f"[FASTVIDEO DTYPE DEBUG] final_pred dtype: {final_pred.dtype}, sum: {final_pred.float().sum().item():.6f}\n")
1032+
f.write(f"[FASTVIDEO DTYPE DEBUG] noise_for_scheduler dtype: {noise_for_scheduler.dtype}, sum: {noise_for_scheduler.float().sum().item():.6f}\n")
1033+
f.write(f"[FASTVIDEO DTYPE DEBUG] current_sigma: {current_sigma:.10f} (type: {type(current_sigma)})\n")
1034+
9631035
# Standard scheduler step like diffusers
964-
latents = self.scheduler.step(noise_for_scheduler, t, latents, return_dict=False)[0]
1036+
latents = self.scheduler.step(noise_for_scheduler, t, latents, **extra_step_kwargs, return_dict=False)[0]
1037+
1038+
# DTYPE VERIFICATION LOGS AFTER SCHEDULER
1039+
print(f"[FASTVIDEO DTYPE DEBUG] Step {i}: After scheduler step")
1040+
print(f"[FASTVIDEO DTYPE DEBUG] latents dtype: {latents.dtype}, sum: {latents.float().sum().item():.6f}")
1041+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
1042+
f.write(f"[FASTVIDEO DTYPE DEBUG] Step {i}: After scheduler step\n")
1043+
f.write(f"[FASTVIDEO DTYPE DEBUG] latents dtype: {latents.dtype}, sum: {latents.float().sum().item():.6f}\n")
9651044
sum_value = latents.float().sum().item()
9661045
logger.info(f"CosmosDenoisingStage: step {i}, updated latents sum = {sum_value:.6f}")
9671046
# Write to output file

0 commit comments

Comments
 (0)