
Commit d52388f

Deprecate predict_epsilon (#1393)
* Adapt ddpm, ddpmsolver to prediction_type.
* Deprecate predict_epsilon in __init__.
* Bring FlaxDDIMScheduler up to date with DDIMScheduler.
* Set prediction_type as an ivar for consistency.
* Convert pipeline_ddpm
* Adapt tests.
* Adapt unconditional training script.
* Adapt BitDiffusion example.
* Add missing kwargs in dpmsolver_multistep
* Ugly workaround to accept deprecated predict_epsilon when loading schedulers using from_pretrained.
* make style
* Remove import no longer in use.
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Use config.prediction_type everywhere
* Add a couple of Flax prediction type tests.
* make style
* fix register deprecated arg

Co-authored-by: Patrick von Platen <[email protected]>
1 parent babfb8a commit d52388f

17 files changed (+260 / -87 lines)
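The user-facing change in this commit: schedulers take a string `prediction_type` ("epsilon" or "sample") instead of the boolean `predict_epsilon`, which stays accepted with a deprecation warning until 0.10.0. A minimal sketch of the migration, assuming a diffusers version that contains this commit:

from diffusers import DDPMScheduler

# New API: pass a string prediction_type instead of the boolean predict_epsilon.
scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")

# The old keyword is still accepted until 0.10.0; it emits a deprecation
# warning and is translated into the scheduler config's prediction_type.
legacy = DDPMScheduler(num_train_timesteps=1000, predict_epsilon=False)
assert legacy.config.prediction_type == "sample"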

examples/community/bit_diffusion.py

Lines changed: 7 additions & 5 deletions
@@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
     model_output: torch.FloatTensor,
     timestep: int,
     sample: torch.FloatTensor,
-    predict_epsilon=True,
+    prediction_type="epsilon",
     generator=None,
     return_dict: bool = True,
 ) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
         timestep (`int`): current discrete timestep in the diffusion chain.
         sample (`torch.FloatTensor`):
             current instance of sample being created by diffusion process.
-        predict_epsilon (`bool`):
-            optional flag to use when model predicts the samples directly instead of the noise, epsilon.
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
         generator: random number generator.
         return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
     Returns:
@@ -174,10 +174,12 @@ def ddpm_bit_scheduler_step(

     # 2. compute predicted original sample from predicted noise also called
     # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-    if predict_epsilon:
+    if prediction_type == "epsilon":
         pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-    else:
+    elif prediction_type == "sample":
         pred_original_sample = model_output
+    else:
+        raise ValueError(f"Unsupported prediction_type {prediction_type}.")

     # 3. Clip "predicted x_0"
     scale = self.bit_scale
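The `epsilon` branch above inverts the forward noising step x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * eps. A small standalone sketch (not part of the diff; variable names mirror the hunk) checking that identity numerically:

import torch

# Forward-noise a known x_0 with a known eps, then apply the "epsilon"
# branch from the hunk above; it recovers x_0 exactly.
alpha_prod_t = torch.tensor(0.7)
beta_prod_t = 1 - alpha_prod_t
x0 = torch.randn(4)
eps = torch.randn(4)
sample = alpha_prod_t ** 0.5 * x0 + beta_prod_t ** 0.5 * eps

pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
assert torch.allclose(pred_original_sample, x0, atol=1e-6)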

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 11 additions & 8 deletions
@@ -194,9 +194,10 @@ def parse_args():
     )

     parser.add_argument(
-        "--predict_epsilon",
-        action="store_true",
-        default=True,
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        choices=["epsilon", "sample"],
         help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
     )

@@ -256,13 +257,13 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())

-    if accepts_predict_epsilon:
+    if accepts_prediction_type:
         noise_scheduler = DDPMScheduler(
             num_train_timesteps=args.ddpm_num_steps,
             beta_schedule=args.ddpm_beta_schedule,
-            predict_epsilon=args.predict_epsilon,
+            prediction_type=args.prediction_type,
         )
     else:
         noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -365,9 +366,9 @@ def transforms(examples):
                 # Predict the noise residual
                 model_output = model(noisy_images, timesteps).sample

-                if args.predict_epsilon:
+                if args.prediction_type == "epsilon":
                     loss = F.mse_loss(model_output, noise)  # this could have different weights!
-                else:
+                elif args.prediction_type == "sample":
                     alpha_t = _extract_into_tensor(
                         noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                     )
@@ -376,6 +377,8 @@ def transforms(examples):
                         model_output, clean_images, reduction="none"
                     )  # use SNR weighting from distillation paper
                     loss = loss.mean()
+                else:
+                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

                 accelerator.backward(loss)
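The `inspect.signature` check above lets the training script run against both old and new diffusers releases: `prediction_type` is only forwarded when the installed `DDPMScheduler` accepts it. A hedged sketch of the same pattern in isolation:

import inspect

from diffusers import DDPMScheduler

# Only pass prediction_type if the installed DDPMScheduler knows about it;
# otherwise fall back to the scheduler's defaults.
scheduler_kwargs = {"num_train_timesteps": 1000, "beta_schedule": "linear"}
if "prediction_type" in inspect.signature(DDPMScheduler.__init__).parameters:
    scheduler_kwargs["prediction_type"] = "epsilon"
noise_scheduler = DDPMScheduler(**scheduler_kwargs)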

src/diffusers/configuration_utils.py

Lines changed: 5 additions & 0 deletions
@@ -195,6 +195,11 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
         if "dtype" in unused_kwargs:
             init_dict["dtype"] = unused_kwargs.pop("dtype")

+        if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
+            deprecate("remove this", "0.10.0", "remove")
+            predict_epsilon = unused_kwargs.pop("predict_epsilon")
+            init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
+
         # Return model and optionally state and/or unused_kwargs
         model = cls(**init_dict)
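This is the "ugly workaround" named in the commit message: when `from_config` (and therefore `from_pretrained`) receives a config that still carries `predict_epsilon`, the flag is mapped onto `prediction_type` unless one is already set. A minimal sketch of just that mapping (dict contents are illustrative):

# Legacy kwargs carrying predict_epsilon are rewritten to prediction_type.
init_dict = {"num_train_timesteps": 1000}
unused_kwargs = {"predict_epsilon": False}

if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
    predict_epsilon = unused_kwargs.pop("predict_epsilon")
    init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"

assert init_dict["prediction_type"] == "sample"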

src/diffusers/experimental/rl/value_guided_sampling.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
             x = x + scale * grad
             x = self.reset_x0(x, conditions, self.action_dim)
             prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+            # TODO: set prediction_type when instantiating the model
             x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

             # apply conditions to the trajectory
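The value-guided RL pipeline still passes the deprecated flag on every `step()` call; the TODO marks the intended follow-up of configuring the scheduler once instead. A hypothetical sketch of that follow-up (not part of this commit):

from diffusers import DDPMScheduler

# Hypothetical follow-up to the TODO: configure prediction_type up front
# so step() no longer needs the deprecated predict_epsilon kwarg.
scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="sample")
# ...then inside run_diffusion:
# x = self.scheduler.step(prev_x, i, x)["prev_sample"]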

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 4 additions & 6 deletions
@@ -70,14 +70,14 @@ def __call__(
                 generated images.
         """
         message = (
-            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
         predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)

         if predict_epsilon is not None:
             new_config = dict(self.scheduler.config)
-            new_config["predict_epsilon"] = predict_epsilon
+            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
             self.scheduler._internal_dict = FrozenDict(new_config)

         if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
@@ -114,9 +114,7 @@ def __call__(
             model_output = self.unet(image, t).sample

             # 2. compute previous image: x_t -> x_t-1
-            image = self.scheduler.step(
-                model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
-            ).prev_sample
+            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
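`DDPMPipeline.__call__` keeps accepting the deprecated kwarg, but instead of forwarding it to `step()` it now rewrites `scheduler.config.prediction_type` once. A hedged usage sketch (the model id is only an example, and the call arguments follow the standard DDPMPipeline API):

from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")

# Preferred: behavior is driven entirely by the scheduler config.
image = pipe(num_inference_steps=50).images[0]

# Still accepted until 0.10.0: the kwarg is translated into
# scheduler.config.prediction_type and a deprecation warning is raised.
image = pipe(num_inference_steps=50, predict_epsilon=True).images[0]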

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 17 additions & 7 deletions
@@ -23,7 +23,7 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin


@@ -106,6 +106,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
+            `v-prediction` is not supported for this scheduler.

     """

@@ -123,7 +126,16 @@ def __init__(
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
+        **kwargs,
     ):
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         elif beta_schedule == "linear":
@@ -139,8 +151,6 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

-        self.prediction_type = prediction_type
-
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

@@ -261,17 +271,17 @@ def step(

         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        if self.prediction_type == "epsilon":
+        if self.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        elif self.prediction_type == "sample":
+        elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
-        elif self.prediction_type == "v_prediction":
+        elif self.config.prediction_type == "v_prediction":
             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
             # predict V
             model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
         else:
             raise ValueError(
-                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                 " `v_prediction`"
             )
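The `v_prediction` branch and the `epsilon` branch are two parameterizations of the same quantities: with v = sqrt(alpha_prod_t) * eps - sqrt(1 - alpha_prod_t) * x_0 (the convention from the progressive-distillation paper, assumed here), the formulas in the hunk recover the same x_0 and eps as the epsilon parameterization. A quick numeric check, independent of the scheduler:

import torch

alpha_prod_t = torch.tensor(0.7)
beta_prod_t = 1 - alpha_prod_t
x0, eps = torch.randn(4), torch.randn(4)
sample = alpha_prod_t ** 0.5 * x0 + beta_prod_t ** 0.5 * eps
v = alpha_prod_t ** 0.5 * eps - beta_prod_t ** 0.5 * x0

# Formulas from the v_prediction branch above.
pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * v
pred_eps = (alpha_prod_t ** 0.5) * v + (beta_prod_t ** 0.5) * sample
assert torch.allclose(pred_original_sample, x0, atol=1e-6)
assert torch.allclose(pred_eps, eps, atol=1e-6)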

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 28 additions & 1 deletion
@@ -23,6 +23,7 @@
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
 from .scheduling_utils_flax import (
     _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     FlaxSchedulerMixin,
@@ -108,6 +109,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
+            `v-prediction` is not supported for this scheduler.
+
     """

     _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -125,7 +130,17 @@ def __init__(
         beta_schedule: str = "linear",
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
+        prediction_type: str = "epsilon",
+        **kwargs,
     ):
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
         if beta_schedule == "linear":
             self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
         elif beta_schedule == "scaled_linear":
@@ -259,7 +274,19 @@ def step(

         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction`"
+            )

         # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
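The Flax scheduler now mirrors the PyTorch deprecation path in its constructor. A short sketch, assuming `jax` and `flax` are installed so the class is importable:

from diffusers import FlaxDDIMScheduler

# The deprecated boolean is translated into the config, as in PyTorch.
scheduler = FlaxDDIMScheduler(num_train_timesteps=1000, predict_epsilon=True)
assert scheduler.config.prediction_type == "epsilon"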

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 24 additions & 10 deletions
@@ -99,9 +99,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
-        predict_epsilon (`bool`):
-            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
-
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
+            `v-prediction` is not supported for this scheduler.
     """

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -116,8 +116,17 @@ def __init__(
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
-        predict_epsilon: bool = True,
+        prediction_type: str = "epsilon",
+        **kwargs,
     ):
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         elif beta_schedule == "linear":
@@ -241,13 +250,13 @@ def step(

         """
         message = (
-            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
         predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
-        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+        if predict_epsilon is not None:
             new_config = dict(self.config)
-            new_config["predict_epsilon"] = predict_epsilon
+            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
             self._internal_dict = FrozenDict(new_config)

         t = timestep
@@ -265,10 +274,15 @@ def step(

         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if self.config.predict_epsilon:
+        if self.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        else:
+        elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
+                " for the DDPMScheduler."
+            )

         # 3. Clip "predicted x_0"
         if self.config.clip_sample:
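At the `step()` level, passing the old kwarg no longer switches behavior for a single call; it rewrites `scheduler.config.prediction_type` (with a deprecation warning) and the normal branch above then runs. A hedged end-to-end sketch with toy tensors (shapes are illustrative):

import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 8, 8)        # current noisy sample x_t (toy shape)
model_output = torch.randn(1, 3, 8, 8)  # stand-in for the model prediction
t = scheduler.timesteps[0]

# Deprecated path: the kwarg is absorbed into the scheduler config.
out = scheduler.step(model_output, t, sample, predict_epsilon=False)
assert scheduler.config.prediction_type == "sample"
prev_sample = out.prev_sample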
