Skip to content

Commit 9515178

Browse files
committed
Use convert_model_output
1 parent 3b8d017 commit 9515178

File tree

2 files changed

+10
-68
lines changed

2 files changed

+10
-68
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 6 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -527,72 +527,13 @@ def _convert_to_beta(
527527
)
528528
return sigmas
529529

530-
def convert_noise_to_x0(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int) -> torch.Tensor:
531-
"""
532-
Convert to original sample x0 from noise prediction.
533-
534-
Args:
535-
model_output (`torch.Tensor`): The model output.
536-
sample (`torch.Tensor`): A current instance of a sample created by the diffusion process.
537-
timestep (`int`): The current discrete timestep in the diffusion chain.
538-
539-
Returns:
540-
`torch.Tensor`: The predicted original sample (x0).
541-
"""
542-
if self.step_index is None:
543-
self._init_step_index(timestep)
544-
sigma = self.sigmas[self.step_index]
545-
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
546-
547-
if self.config.prediction_type == "epsilon":
548-
return (sample - sigma_t * model_output) / alpha_t
549-
elif self.config.prediction_type == "sample":
550-
return model_output
551-
elif self.config.prediction_type == "v_prediction":
552-
return alpha_t * sample - sigma_t * model_output
553-
else:
554-
raise ValueError(
555-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
556-
" `v_prediction`."
557-
)
558-
559-
def convert_x0_to_noise(self, pred_x0: torch.Tensor, sample: torch.Tensor, timestep: int) -> torch.Tensor:
560-
"""
561-
Convert to noise prediction from original sample x0.
562-
563-
Args:
564-
pred_x0 (`torch.Tensor`): The predicted original sample (x0).
565-
sample (`torch.Tensor`): A current instance of a sample created by the diffusion process.
566-
timestep (`int`): The current discrete timestep in the diffusion chain.
567-
568-
Returns:
569-
`torch.Tensor`: The converted noise prediction.
570-
"""
571-
if self.step_index is None:
572-
self._init_step_index(timestep)
573-
sigma = self.sigmas[self.step_index]
574-
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
575-
576-
if self.config.prediction_type == "epsilon":
577-
x0_pred = (sample - alpha_t * pred_x0) / sigma_t
578-
elif self.config.prediction_type == "sample":
579-
x0_pred = pred_x0
580-
elif self.config.prediction_type == "v_prediction":
581-
x0_pred = alpha_t * pred_x0 + sigma_t * sample
582-
else:
583-
raise ValueError(
584-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
585-
" `v_prediction`."
586-
)
587-
if self.config.thresholding:
588-
x0_pred = self._threshold_sample(x0_pred)
589-
return x0_pred
590-
591530
def convert_model_output(
592531
self,
593532
model_output: torch.Tensor,
594533
*args,
595534
sample: torch.Tensor = None,
535+
predict_x0: bool = True,
536+
step_index: Optional[int] = None,
596537
**kwargs,
597538
) -> torch.Tensor:
598539
r"""
@@ -622,11 +563,12 @@ def convert_model_output(
622563
"1.0.0",
623564
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
624565
)
566+
step_index = step_index if step_index is not None else self.step_index
625567

626-
sigma = self.sigmas[self.step_index]
568+
sigma = self.sigmas[step_index]
627569
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
628570

629-
if self.predict_x0:
571+
if predict_x0:
630572
if self.config.prediction_type == "epsilon":
631573
x0_pred = (sample - sigma_t * model_output) / alpha_t
632574
elif self.config.prediction_type == "sample":
@@ -996,7 +938,7 @@ def step(
996938
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
997939
)
998940

999-
model_output_convert = self.convert_model_output(model_output, sample=sample)
941+
model_output_convert = self.convert_model_output(model_output, sample=sample, predict_x0=self.predict_x0)
1000942
if use_corrector:
1001943
sample = self.multistep_uni_c_bh_update(
1002944
this_model_output=model_output_convert,

tests/schedulers/test_scheduler_unipc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ def test_convert_model_output(self):
297297

298298
for i, t in enumerate(scheduler.timesteps):
299299
residual = model(sample, t)
300-
pred_x0 = scheduler.convert_noise_to_x0(residual, sample, timestep=t)
301-
pred_noise = scheduler.convert_x0_to_noise(pred_x0, sample, timestep=t)
300+
pred_x0 = scheduler.convert_model_output(residual, sample=sample, predict_x0=True, step_index=i)
301+
pred_noise = scheduler.convert_model_output(pred_x0, sample=sample, predict_x0=False, step_index=i)
302302
assert (
303303
abs(torch.mean(torch.abs(pred_noise)).item() - torch.mean(torch.abs(residual)).item()) < 1e-4
304304
), prediction_type
@@ -314,8 +314,8 @@ def test_convert_model_output(self):
314314
scheduler.set_timesteps(num_inference_steps)
315315
for i, t in enumerate(scheduler.timesteps):
316316
residual = model(sample, t)
317-
pred_x0 = scheduler.convert_noise_to_x0(residual, sample, timestep=t)
318-
pred_noise = scheduler.convert_x0_to_noise(pred_x0, sample, timestep=t)
317+
pred_x0 = scheduler.convert_model_output(residual, sample=sample, predict_x0=True, step_index=i)
318+
pred_noise = scheduler.convert_model_output(pred_x0, sample=sample, predict_x0=False, step_index=i)
319319
sample = scheduler.step(residual, t, sample).prev_sample
320320
assert (
321321
abs(torch.mean(torch.abs(pred_noise)).item() - torch.mean(torch.abs(residual)).item()) < 2e-2

0 commit comments

Comments
 (0)