Skip to content

Commit 194ed79

Browse files
[PNDM] Stable diffusion (#186)
* [PNDM] Stable diffusino * finish
1 parent 051b346 commit 194ed79

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
beta_end=0.02,
5757
beta_schedule="linear",
5858
tensor_format="pt",
59+
skip_prk_steps=False,
5960
):
6061

6162
if beta_schedule == "linear":
@@ -88,24 +89,35 @@ def __init__(
8889
# setable values
8990
self.num_inference_steps = None
9091
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
92+
self._offset = 0
9193
self.prk_timesteps = None
9294
self.plms_timesteps = None
9395
self.timesteps = None
9496

9597
self.tensor_format = tensor_format
9698
self.set_format(tensor_format=tensor_format)
9799

98-
def set_timesteps(self, num_inference_steps):
100+
def set_timesteps(self, num_inference_steps, offset=0):
99101
self.num_inference_steps = num_inference_steps
100102
self._timesteps = list(
101103
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
102104
)
105+
self._offset = offset
106+
self._timesteps = [t + self._offset for t in self._timesteps]
107+
108+
if self.config.skip_prk_steps:
109+
# for some models like stable diffusion the prk steps can/should be skipped to
110+
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
111+
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
112+
self.prk_timesteps = []
113+
self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
114+
else:
115+
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
116+
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
117+
)
118+
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
119+
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
103120

104-
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
105-
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
106-
)
107-
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
108-
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
109121
self.timesteps = self.prk_timesteps + self.plms_timesteps
110122

111123
self.counter = 0
@@ -117,7 +129,7 @@ def step(
117129
timestep: int,
118130
sample: Union[torch.FloatTensor, np.ndarray],
119131
):
120-
if self.counter < len(self.prk_timesteps):
132+
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
121133
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
122134
else:
123135
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
@@ -166,7 +178,7 @@ def step_plms(
166178
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
167179
times to approximate the solution.
168180
"""
169-
if len(self.ets) < 3:
181+
if not self.config.skip_prk_steps and len(self.ets) < 3:
170182
raise ValueError(
171183
f"{self.__class__} can only be run AFTER scheduler has been run "
172184
"in 'prk' mode for at least 12 iterations "
@@ -175,9 +187,26 @@ def step_plms(
175187
)
176188

177189
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
178-
self.ets.append(model_output)
179190

180-
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
191+
if self.counter != 1:
192+
self.ets.append(model_output)
193+
else:
194+
prev_timestep = timestep
195+
timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
196+
197+
if len(self.ets) == 1 and self.counter == 0:
198+
model_output = model_output
199+
self.cur_sample = sample
200+
elif len(self.ets) == 1 and self.counter == 1:
201+
model_output = (model_output + self.ets[-1]) / 2
202+
sample = self.cur_sample
203+
self.cur_sample = None
204+
elif len(self.ets) == 2:
205+
model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
206+
elif len(self.ets) == 3:
207+
model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
208+
else:
209+
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
181210

182211
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
183212
self.counter += 1
@@ -197,8 +226,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
197226
# sample -> x_t
198227
# model_output -> e_θ(x_t, t)
199228
# prev_sample -> x_(t−δ)
200-
alpha_prod_t = self.alphas_cumprod[timestep + 1]
201-
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
229+
alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
230+
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
202231
beta_prod_t = 1 - alpha_prod_t
203232
beta_prod_t_prev = 1 - alpha_prod_t_prev
204233

tests/test_modeling_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,7 @@ def test_ldm_text2img_fast(self):
843843
@slow
844844
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
845845
def test_stable_diffusion(self):
846+
# make sure here that pndm scheduler skips prk
846847
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
847848

848849
prompt = "A painting of a squirrel eating a burger"
@@ -857,7 +858,7 @@ def test_stable_diffusion(self):
857858
image_slice = image[0, -3:, -3:, -1]
858859

859860
assert image.shape == (1, 512, 512, 3)
860-
expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887])
861+
expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
861862
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
862863

863864
@slow

0 commit comments

Comments
 (0)