@@ -122,6 +122,7 @@ def __init__(
122
122
clip_sample : bool = True ,
123
123
set_alpha_to_one : bool = True ,
124
124
steps_offset : int = 0 ,
125
+ prediction_type : str = "epsilon" ,
125
126
):
126
127
if trained_betas is not None :
127
128
self .betas = torch .from_numpy (trained_betas )
@@ -138,6 +139,8 @@ def __init__(
138
139
else :
139
140
raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
140
141
142
+ self .prediction_type = prediction_type
143
+
141
144
self .alphas = 1.0 - self .betas
142
145
self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
143
146
@@ -258,7 +261,19 @@ def step(
258
261
259
262
# 3. compute predicted original sample from predicted noise also called
260
263
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
261
- pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
264
+ if self .prediction_type == "epsilon" :
265
+ pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
266
+ elif self .prediction_type == "sample" :
267
+ pred_original_sample = model_output
268
+ elif self .prediction_type == "v_prediction" :
269
+ pred_original_sample = (alpha_prod_t ** 0.5 ) * sample - (beta_prod_t ** 0.5 ) * model_output
270
+ # predict V
271
+ model_output = (alpha_prod_t ** 0.5 ) * model_output + (beta_prod_t ** 0.5 ) * sample
272
+ else :
273
+ raise ValueError (
274
+ f"prediction_type given as { self .prediction_type } must be one of `epsilon`, `sample`, or"
275
+ " `v_prediction`"
276
+ )
262
277
263
278
# 4. Clip "predicted x_0"
264
279
if self .config .clip_sample :
0 commit comments