Skip to content

Commit 1bc1554

Browse files
committed
fix convert_model_output in schedulers
1 parent 884d29e commit 1bc1554

File tree

6 files changed

+24
-12
lines changed

6 files changed

+24
-12
lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,13 @@ def convert_model_output(
501501
x0_pred = model_output
502502
elif self.config.prediction_type == "v_prediction":
503503
x0_pred = alpha_t * sample - sigma_t * model_output
504+
elif self.config.prediction_type == "flow_prediction":
505+
sigma_t = self.sigmas[self.step_index]
506+
x0_pred = sample - sigma_t * model_output
504507
else:
505508
raise ValueError(
506-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
507-
" `v_prediction` for the DEISMultistepScheduler."
509+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
510+
"`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler."
508511
)
509512

510513
if self.config.thresholding:

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,8 +666,8 @@ def convert_model_output(
666666
x0_pred = sample - sigma_t * model_output
667667
else:
668668
raise ValueError(
669-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
670-
" `v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
669+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
670+
"`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
671671
)
672672

673673
if self.config.thresholding:

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,8 @@ def convert_model_output(
538538
x0_pred = sample - sigma_t * model_output
539539
else:
540540
raise ValueError(
541-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
542-
" `v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
541+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
542+
"`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
543543
)
544544

545545
if self.config.thresholding:

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,10 +606,13 @@ def convert_model_output(
606606
sigma = self.sigmas[self.step_index]
607607
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
608608
x0_pred = alpha_t * sample - sigma_t * model_output
609+
elif self.config.prediction_type == "flow_prediction":
610+
sigma_t = self.sigmas[self.step_index]
611+
x0_pred = sample - sigma_t * model_output
609612
else:
610613
raise ValueError(
611-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
612-
" `v_prediction` for the DPMSolverSinglestepScheduler."
614+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
615+
"`v_prediction`, or `flow_prediction` for the DPMSolverSinglestepScheduler."
613616
)
614617

615618
if self.config.thresholding:

src/diffusers/schedulers/scheduling_sasolver.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,10 +542,13 @@ def convert_model_output(
542542
x0_pred = model_output
543543
elif self.config.prediction_type == "v_prediction":
544544
x0_pred = alpha_t * sample - sigma_t * model_output
545+
elif self.config.prediction_type == "flow_prediction":
546+
sigma_t = self.sigmas[self.step_index]
547+
x0_pred = sample - sigma_t * model_output
545548
else:
546549
raise ValueError(
547-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
548-
" `v_prediction` for the SASolverScheduler."
550+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
551+
"`v_prediction`, or `flow_prediction` for the SASolverScheduler."
549552
)
550553

551554
if self.config.thresholding:

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,10 +605,13 @@ def convert_model_output(
605605
x0_pred = model_output
606606
elif self.config.prediction_type == "v_prediction":
607607
x0_pred = alpha_t * sample - sigma_t * model_output
608+
elif self.config.prediction_type == "flow_prediction":
609+
sigma_t = self.sigmas[self.step_index]
610+
x0_pred = sample - sigma_t * model_output
608611
else:
609612
raise ValueError(
610-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
611-
" `v_prediction` for the UniPCMultistepScheduler."
613+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
614+
"`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
612615
)
613616

614617
if self.config.thresholding:

0 commit comments

Comments
 (0)