@@ -56,6 +56,7 @@ def __init__(
5656 beta_end = 0.02 ,
5757 beta_schedule = "linear" ,
5858 tensor_format = "pt" ,
59+ skip_prk_steps = False ,
5960 ):
6061
6162 if beta_schedule == "linear" :
@@ -88,24 +89,35 @@ def __init__(
8889 # setable values
8990 self .num_inference_steps = None
9091 self ._timesteps = np .arange (0 , num_train_timesteps )[::- 1 ].copy ()
92+ self ._offset = 0
9193 self .prk_timesteps = None
9294 self .plms_timesteps = None
9395 self .timesteps = None
9496
9597 self .tensor_format = tensor_format
9698 self .set_format (tensor_format = tensor_format )
9799
98- def set_timesteps (self , num_inference_steps ):
100+ def set_timesteps (self , num_inference_steps , offset = 0 ):
99101 self .num_inference_steps = num_inference_steps
100102 self ._timesteps = list (
101103 range (0 , self .config .num_train_timesteps , self .config .num_train_timesteps // num_inference_steps )
102104 )
105+ self ._offset = offset
106+ self ._timesteps = [t + self ._offset for t in self ._timesteps ]
107+
108+ if self .config .skip_prk_steps :
109+ # for some models like stable diffusion the prk steps can/should be skipped to
110+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
111+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
112+ self .prk_timesteps = []
113+ self .plms_timesteps = list (reversed (self ._timesteps [:- 1 ] + self ._timesteps [- 2 :- 1 ] + self ._timesteps [- 1 :]))
114+ else :
115+ prk_timesteps = np .array (self ._timesteps [- self .pndm_order :]).repeat (2 ) + np .tile (
116+ np .array ([0 , self .config .num_train_timesteps // num_inference_steps // 2 ]), self .pndm_order
117+ )
118+ self .prk_timesteps = list (reversed (prk_timesteps [:- 1 ].repeat (2 )[1 :- 1 ]))
119+ self .plms_timesteps = list (reversed (self ._timesteps [:- 3 ]))
103120
104- prk_timesteps = np .array (self ._timesteps [- self .pndm_order :]).repeat (2 ) + np .tile (
105- np .array ([0 , self .config .num_train_timesteps // num_inference_steps // 2 ]), self .pndm_order
106- )
107- self .prk_timesteps = list (reversed (prk_timesteps [:- 1 ].repeat (2 )[1 :- 1 ]))
108- self .plms_timesteps = list (reversed (self ._timesteps [:- 3 ]))
109121 self .timesteps = self .prk_timesteps + self .plms_timesteps
110122
111123 self .counter = 0
@@ -117,7 +129,7 @@ def step(
117129 timestep : int ,
118130 sample : Union [torch .FloatTensor , np .ndarray ],
119131 ):
120- if self .counter < len (self .prk_timesteps ):
132+ if self .counter < len (self .prk_timesteps ) and not self . config . skip_prk_steps :
121133 return self .step_prk (model_output = model_output , timestep = timestep , sample = sample )
122134 else :
123135 return self .step_plms (model_output = model_output , timestep = timestep , sample = sample )
@@ -166,7 +178,7 @@ def step_plms(
166178 Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
167179 times to approximate the solution.
168180 """
169- if len (self .ets ) < 3 :
181+ if not self . config . skip_prk_steps and len (self .ets ) < 3 :
170182 raise ValueError (
171183 f"{ self .__class__ } can only be run AFTER scheduler has been run "
172184 "in 'prk' mode for at least 12 iterations "
@@ -175,9 +187,26 @@ def step_plms(
175187 )
176188
177189 prev_timestep = max (timestep - self .config .num_train_timesteps // self .num_inference_steps , 0 )
178- self .ets .append (model_output )
179190
180- model_output = (1 / 24 ) * (55 * self .ets [- 1 ] - 59 * self .ets [- 2 ] + 37 * self .ets [- 3 ] - 9 * self .ets [- 4 ])
191+ if self .counter != 1 :
192+ self .ets .append (model_output )
193+ else :
194+ prev_timestep = timestep
195+ timestep = timestep + self .config .num_train_timesteps // self .num_inference_steps
196+
197+ if len (self .ets ) == 1 and self .counter == 0 :
198+ model_output = model_output
199+ self .cur_sample = sample
200+ elif len (self .ets ) == 1 and self .counter == 1 :
201+ model_output = (model_output + self .ets [- 1 ]) / 2
202+ sample = self .cur_sample
203+ self .cur_sample = None
204+ elif len (self .ets ) == 2 :
205+ model_output = (3 * self .ets [- 1 ] - self .ets [- 2 ]) / 2
206+ elif len (self .ets ) == 3 :
207+ model_output = (23 * self .ets [- 1 ] - 16 * self .ets [- 2 ] + 5 * self .ets [- 3 ]) / 12
208+ else :
209+ model_output = (1 / 24 ) * (55 * self .ets [- 1 ] - 59 * self .ets [- 2 ] + 37 * self .ets [- 3 ] - 9 * self .ets [- 4 ])
181210
182211 prev_sample = self ._get_prev_sample (sample , timestep , prev_timestep , model_output )
183212 self .counter += 1
@@ -197,8 +226,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
197226 # sample -> x_t
198227 # model_output -> e_θ(x_t, t)
199228 # prev_sample -> x_(t−δ)
200- alpha_prod_t = self .alphas_cumprod [timestep + 1 ]
201- alpha_prod_t_prev = self .alphas_cumprod [timestep_prev + 1 ]
229+ alpha_prod_t = self .alphas_cumprod [timestep + 1 - self . _offset ]
230+ alpha_prod_t_prev = self .alphas_cumprod [timestep_prev + 1 - self . _offset ]
202231 beta_prod_t = 1 - alpha_prod_t
203232 beta_prod_t_prev = 1 - alpha_prod_t_prev
204233
0 commit comments