@@ -120,8 +120,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
120
120
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
121
121
sampling, and `solver_order=3` for unconditional sampling.
122
122
prediction_type (`str`, default `epsilon`):
123
- indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
124
- `v-prediction` is not supported for this scheduler .
123
+ indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
124
+ or `v-prediction`.
125
125
thresholding (`bool`, default `False`):
126
126
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
127
127
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -252,7 +252,7 @@ def convert_model_output(
252
252
"""
253
253
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
254
254
255
- DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
255
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
256
256
discretize an integral of the data prediction model. So we need to first convert the model output to the
257
257
corresponding type to match the algorithm.
258
258
@@ -275,10 +275,13 @@ def convert_model_output(
275
275
x0_pred = (sample - sigma_t * model_output ) / alpha_t
276
276
elif self .config .prediction_type == "sample" :
277
277
x0_pred = model_output
278
+ elif self .config .prediction_type == "v_prediction" :
279
+ alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
280
+ x0_pred = alpha_t * sample - sigma_t * model_output
278
281
else :
279
282
raise ValueError (
280
- f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample` "
281
- " for the FlaxDPMSolverMultistepScheduler."
283
+ f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, "
284
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
282
285
)
283
286
284
287
if self .config .thresholding :
@@ -299,10 +302,14 @@ def convert_model_output(
299
302
alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
300
303
epsilon = (sample - alpha_t * model_output ) / sigma_t
301
304
return epsilon
305
+ elif self .config .prediction_type == "v_prediction" :
306
+ alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
307
+ epsilon = alpha_t * model_output + sigma_t * sample
308
+ return epsilon
302
309
else :
303
310
raise ValueError (
304
- f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample` "
305
- " for the FlaxDPMSolverMultistepScheduler."
311
+ f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, "
312
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
306
313
)
307
314
308
315
def dpm_solver_first_order_update (
0 commit comments