Skip to content

Commit 30f6f44

Browse files
patil-surajpcuenca
andauthored
add v prediction (#1386)
* add v prediction * adat euler for v pred * velocity -> v_prediction * simplify * fix naming * Update src/diffusers/schedulers/scheduling_euler_discrete.py Co-authored-by: Pedro Cuenca <[email protected]> * style Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 9f47638 commit 30f6f44

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
clip_sample: bool = True,
123123
set_alpha_to_one: bool = True,
124124
steps_offset: int = 0,
125+
prediction_type: str = "epsilon",
125126
):
126127
if trained_betas is not None:
127128
self.betas = torch.from_numpy(trained_betas)
@@ -138,6 +139,8 @@ def __init__(
138139
else:
139140
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
140141

142+
self.prediction_type = prediction_type
143+
141144
self.alphas = 1.0 - self.betas
142145
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
143146

@@ -258,7 +261,19 @@ def step(
258261

259262
# 3. compute predicted original sample from predicted noise also called
260263
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
261-
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
264+
if self.prediction_type == "epsilon":
265+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
266+
elif self.prediction_type == "sample":
267+
pred_original_sample = model_output
268+
elif self.prediction_type == "v_prediction":
269+
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
270+
# predict V
271+
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
272+
else:
273+
raise ValueError(
274+
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
275+
" `v_prediction`"
276+
)
262277

263278
# 4. Clip "predicted x_0"
264279
if self.config.clip_sample:

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
beta_end: float = 0.02,
7979
beta_schedule: str = "linear",
8080
trained_betas: Optional[np.ndarray] = None,
81+
prediction_type: str = "epsilon",
8182
):
8283
if trained_betas is not None:
8384
self.betas = torch.from_numpy(trained_betas)
@@ -91,6 +92,8 @@ def __init__(
9192
else:
9293
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
9394

95+
self.prediction_type = prediction_type
96+
9497
self.alphas = 1.0 - self.betas
9598
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
9699

@@ -229,7 +232,15 @@ def step(
229232
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
230233

231234
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
232-
pred_original_sample = sample - sigma_hat * model_output
235+
if self.prediction_type == "epsilon":
236+
pred_original_sample = sample - sigma_hat * model_output
237+
elif self.prediction_type == "v_prediction":
238+
# * c_out + input * c_skip
239+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
240+
else:
241+
raise ValueError(
242+
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
243+
)
233244

234245
# 2. Convert to an ODE derivative
235246
derivative = (sample - pred_original_sample) / sigma_hat

0 commit comments

Comments
 (0)