From 425035d810f8a9bff8642820d89d79ff5cd4119d Mon Sep 17 00:00:00 2001 From: cmeka Date: Tue, 21 Oct 2025 12:46:58 -0400 Subject: [PATCH] Fix custom sigmas for supported schedulers Previously, only unipc, dpm++, and dpm++_sde schedulers preserved custom input sigmas exactly. Other schedulers such as (euler, lcm, deis, etc.) would transform or modify the sigmas through their set_timesteps() methods, causing inconsistent behavior. --- wanvideo/schedulers/__init__.py | 55 +++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/wanvideo/schedulers/__init__.py b/wanvideo/schedulers/__init__.py index 8df1749f..f83c3ac4 100644 --- a/wanvideo/schedulers/__init__.py +++ b/wanvideo/schedulers/__init__.py @@ -31,6 +31,11 @@ "rcm" ] +def _apply_custom_sigmas(sample_scheduler, sigmas, device): + sample_scheduler.sigmas = sigmas.to(device) + sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device) + sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps) + def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False, **kwargs): timesteps = None if 'unipc' in scheduler: @@ -38,16 +43,17 @@ def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transfo if sigmas is None: sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler)) else: - sample_scheduler.sigmas = sigmas.to(device) - sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device) - sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps) + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif scheduler in ['euler/beta', 'euler']: sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta')) - if flowedit_args: #seems to work better - timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=get_sampling_sigmas(steps, shift)) + if sigmas is None: + if flowedit_args: #seems to work better + timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=get_sampling_sigmas(steps, shift)) + else: + sample_scheduler.set_timesteps(steps, device=device) else: - sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas[:-1].tolist() if sigmas is not None else None) + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif 'dpm' in scheduler: if 'sde' in scheduler: algorithm_type = "sde-dpmsolver++" @@ -57,16 +63,20 @@ def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transfo if sigmas is None: sample_scheduler.set_timesteps(steps, device=device, use_beta_sigmas=('beta' in scheduler)) else: - sample_scheduler.sigmas = sigmas.to(device) - sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device) - sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps) + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif scheduler == 'deis': sample_scheduler = DEISMultistepScheduler(use_flow_sigmas=True, prediction_type="flow_prediction", flow_shift=shift) - sample_scheduler.set_timesteps(steps, device=device) - sample_scheduler.sigmas[-1] = 1e-6 + if sigmas is None: + sample_scheduler.set_timesteps(steps, device=device) + sample_scheduler.sigmas[-1] = 1e-6 + else: + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif 'lcm' in scheduler: sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta')) - sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas[:-1].tolist() if sigmas is not None else None) + if sigmas is None: + sample_scheduler.set_timesteps(steps, device=device) + else: + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif 'flowmatch_causvid' in scheduler: if sigmas is not None: raise NotImplementedError("This scheduler does not support custom sigmas") @@ -99,17 +109,28 @@ def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transfo sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)]) elif 'flowmatch_pusa' in scheduler: sample_scheduler = FlowMatchSchedulerPusa(shift=shift, sigma_min=0.0, extra_one_step=True) - sample_scheduler.set_timesteps(steps+1, denoising_strength=denoise_strength, shift=shift, - sigmas=sigmas[:-1].tolist() if sigmas is not None else None) + if sigmas is None: + sample_scheduler.set_timesteps(steps+1, denoising_strength=denoise_strength, shift=shift) + else: + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif scheduler == 'res_multistep': sample_scheduler = FlowMatchSchedulerResMultistep(shift=shift) - sample_scheduler.set_timesteps(steps, denoising_strength=denoise_strength, sigmas=sigmas[:-1].tolist() if sigmas is not None else None) + if sigmas is None: + sample_scheduler.set_timesteps(steps, denoising_strength=denoise_strength) + else: + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif "sa_ode_stable" in scheduler: sample_scheduler = FlowMatchSAODEStableScheduler(shift=shift, **kwargs) - sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas[:-1].tolist() if sigmas is not None else None) + if sigmas is None: + sample_scheduler.set_timesteps(steps, device=device) + else: + _apply_custom_sigmas(sample_scheduler, sigmas, device) elif 'rcm' in scheduler: sample_scheduler = rCMFlowMatchScheduler() - sample_scheduler.set_timesteps(steps, sigma_max=120) + if sigmas is None: + sample_scheduler.set_timesteps(steps, sigma_max=120) + else: + _apply_custom_sigmas(sample_scheduler, sigmas, device) if timesteps is None: timesteps = sample_scheduler.timesteps