Skip to content

Commit b9e921f

Browse files
authored
added initial v-pred support to DPM-solver (#1421)
* added initial v-pred support to DPM-solver * fix sign * added v_prediction to flax * fixed typo
1 parent 7684518 commit b9e921f

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
8888
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
8989
sampling, and `solver_order=3` for unconditional sampling.
9090
prediction_type (`str`, default `epsilon`):
91-
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
92-
`v-prediction` is not supported for this scheduler.
91+
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
92+
or `v-prediction`.
9393
thresholding (`bool`, default `False`):
9494
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
9595
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -212,7 +212,7 @@ def convert_model_output(
212212
"""
213213
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
214214
215-
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
215+
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
216216
discretize an integral of the data prediction model. So we need to first convert the model output to the
217217
corresponding type to match the algorithm.
218218
@@ -235,10 +235,13 @@ def convert_model_output(
235235
x0_pred = (sample - sigma_t * model_output) / alpha_t
236236
elif self.config.prediction_type == "sample":
237237
x0_pred = model_output
238+
elif self.config.prediction_type == "v_prediction":
239+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
240+
x0_pred = alpha_t * sample - sigma_t * model_output
238241
else:
239242
raise ValueError(
240-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
241-
" for the DPMSolverMultistepScheduler."
243+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
244+
" `v_prediction` for the DPMSolverMultistepScheduler."
242245
)
243246

244247
if self.config.thresholding:
@@ -260,10 +263,14 @@ def convert_model_output(
260263
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
261264
epsilon = (sample - alpha_t * model_output) / sigma_t
262265
return epsilon
266+
elif self.config.prediction_type == "v_prediction":
267+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
268+
epsilon = alpha_t * model_output + sigma_t * sample
269+
return epsilon
263270
else:
264271
raise ValueError(
265-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
266-
" for the DPMSolverMultistepScheduler."
272+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
273+
" `v_prediction` for the DPMSolverMultistepScheduler."
267274
)
268275

269276
def dpm_solver_first_order_update(

src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
120120
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
121121
sampling, and `solver_order=3` for unconditional sampling.
122122
prediction_type (`str`, default `epsilon`):
123-
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
124-
`v-prediction` is not supported for this scheduler.
123+
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
124+
or `v-prediction`.
125125
thresholding (`bool`, default `False`):
126126
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
127127
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -252,7 +252,7 @@ def convert_model_output(
252252
"""
253253
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
254254
255-
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
255+
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
256256
discretize an integral of the data prediction model. So we need to first convert the model output to the
257257
corresponding type to match the algorithm.
258258
@@ -275,10 +275,13 @@ def convert_model_output(
275275
x0_pred = (sample - sigma_t * model_output) / alpha_t
276276
elif self.config.prediction_type == "sample":
277277
x0_pred = model_output
278+
elif self.config.prediction_type == "v_prediction":
279+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
280+
x0_pred = alpha_t * sample - sigma_t * model_output
278281
else:
279282
raise ValueError(
280-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
281-
" for the FlaxDPMSolverMultistepScheduler."
283+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
284+
" or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
282285
)
283286

284287
if self.config.thresholding:
@@ -299,10 +302,14 @@ def convert_model_output(
299302
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
300303
epsilon = (sample - alpha_t * model_output) / sigma_t
301304
return epsilon
305+
elif self.config.prediction_type == "v_prediction":
306+
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
307+
epsilon = alpha_t * model_output + sigma_t * sample
308+
return epsilon
302309
else:
303310
raise ValueError(
304-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
305-
" for the FlaxDPMSolverMultistepScheduler."
311+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
312+
" or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
306313
)
307314

308315
def dpm_solver_first_order_update(

0 commit comments

Comments
 (0)