Skip to content

Commit d7b6920

Browse files
authored
Add K-LMS scheduler from k-diffusion (#185)
* test LMS with LDM * test LMS with LDM * Interchangeable sigma and timestep. Added dummy objects * Debug * cuda generator * Fix derivatives * Update tests * Rename Lms->LMS
1 parent 9070c39 commit d7b6920

File tree

7 files changed

+225
-7
lines changed

7 files changed

+225
-7
lines changed

src/diffusers/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# flake8: noqa
22
# There's no way to ignore "F401 '...' imported but unused" warnings in this
33
# module, but to preserve other warnings. So, don't check this module at all.
4-
from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
4+
from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available
55

66

77
__version__ = "0.1.3"
@@ -27,11 +27,17 @@
2727
SchedulerMixin,
2828
ScoreSdeVeScheduler,
2929
)
30+
31+
32+
if is_scipy_available():
33+
from .schedulers import LMSDiscreteScheduler
34+
3035
from .training_utils import EMAModel
3136

3237

3338
if is_transformers_available():
3439
from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline
3540

41+
3642
else:
3743
from .utils.dummy_transformers_objects import *

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ...models import AutoencoderKL, UNet2DConditionModel
1010
from ...pipeline_utils import DiffusionPipeline
11-
from ...schedulers import DDIMScheduler, PNDMScheduler
11+
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
1212

1313

1414
class StableDiffusionPipeline(DiffusionPipeline):
@@ -18,7 +18,7 @@ def __init__(
1818
text_encoder: CLIPTextModel,
1919
tokenizer: CLIPTokenizer,
2020
unet: UNet2DConditionModel,
21-
scheduler: Union[DDIMScheduler, PNDMScheduler],
21+
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
2222
):
2323
super().__init__()
2424
scheduler = scheduler.set_format("pt")
@@ -105,9 +105,16 @@ def __call__(
105105
if accepts_eta:
106106
extra_step_kwargs["eta"] = eta
107107

108-
for t in tqdm(self.scheduler.timesteps):
108+
self.scheduler.set_timesteps(num_inference_steps)
109+
if isinstance(self.scheduler, LMSDiscreteScheduler):
110+
latents = latents * self.scheduler.sigmas[0]
111+
112+
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
109113
# expand the latents if we are doing classifier free guidance
110114
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
115+
if isinstance(self.scheduler, LMSDiscreteScheduler):
116+
sigma = self.scheduler.sigmas[i]
117+
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
111118

112119
# predict the noise residual
113120
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
@@ -118,7 +125,10 @@ def __call__(
118125
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
119126

120127
# compute the previous noisy sample x_t -> x_t-1
121-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
128+
if isinstance(self.scheduler, LMSDiscreteScheduler):
129+
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
130+
else:
131+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
122132

123133
# scale and decode the image latents with vae
124134
latents = 1 / 0.18215 * latents

src/diffusers/schedulers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@
1616
# See the License for the specific language governing permissions and
1717
# limitations under the License.
1818

19+
from ..utils import is_scipy_available
1920
from .scheduling_ddim import DDIMScheduler
2021
from .scheduling_ddpm import DDPMScheduler
2122
from .scheduling_karras_ve import KarrasVeScheduler
2223
from .scheduling_pndm import PNDMScheduler
2324
from .scheduling_sde_ve import ScoreSdeVeScheduler
2425
from .scheduling_sde_vp import ScoreSdeVpScheduler
2526
from .scheduling_utils import SchedulerMixin
27+
28+
29+
if is_scipy_available():
30+
from .scheduling_lms_discrete import LMSDiscreteScheduler
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Union
16+
17+
import numpy as np
18+
import torch
19+
20+
from scipy import integrate
21+
22+
from ..configuration_utils import ConfigMixin, register_to_config
23+
from .scheduling_utils import SchedulerMixin
24+
25+
26+
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Linear Multistep Scheduler for discrete beta schedules.

    Based on the original k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181

    The scheduler works with noise levels `sigma = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5`.
    Call `set_timesteps` before sampling: it builds the per-inference-step sigma schedule, terminated by
    `0.0`, that `step` and `get_lms_coefficient` index into.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        tensor_format="pt",
    ):
        """
        Build the beta schedule and the derived training sigmas.

        Args:
            num_train_timesteps: number of entries in the beta schedule.
            beta_start: first beta value.
            beta_end: last beta value.
            beta_schedule: `"linear"` or `"scaled_linear"`; anything else raises `NotImplementedError`.
            trained_betas: accepted but not read by this constructor.
            timestep_values: accepted but not read by this constructor.
            tensor_format: format name forwarded to `set_format`.

        Raises:
            NotImplementedError: if `beta_schedule` is not one of the supported names.
        """

        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        # One sigma per training timestep; replaced by the inference schedule in `set_timesteps`.
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        # History of ODE derivatives from previous `step` calls, oldest first.
        self.derivatives = []

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute a linear multistep coefficient.

        Integrates, from `self.sigmas[t]` to `self.sigmas[t + 1]`, the Lagrange basis polynomial that belongs
        to node `self.sigmas[t - current_order]` among the `order` nodes `self.sigmas[t - k]`, `k < order`.

        Args:
            order: number of sigma nodes (and stored derivatives) taking part in the multistep update.
            t: index of the current step in `self.sigmas`; `t + 1` must also be a valid index.
            current_order: which of the `order` nodes this coefficient weights (0 is the current step).

        Returns:
            The value of the integral as computed by `scipy.integrate.quad`.
        """

        def lms_derivative(tau):
            # Lagrange basis polynomial evaluated at `tau`: skip the node the coefficient belongs to.
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        # `quad` returns (value, abs_error_estimate); only the value is needed.
        integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps):
        """
        Build the inference schedule and reset the derivative history.

        `self.timesteps` becomes `num_inference_steps` evenly spaced (possibly fractional) values from
        `num_train_timesteps - 1` down to 0. `self.sigmas` holds the training sigmas linearly interpolated at
        those values, followed by a final `0.0`.

        Args:
            num_inference_steps: number of sampling steps to schedule.
        """
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

        # Linear interpolation between the two neighbouring integer training timesteps.
        low_idx = np.floor(self.timesteps).astype(int)
        high_idx = np.ceil(self.timesteps).astype(int)
        frac = np.mod(self.timesteps, 1.0)
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
        # Trailing 0.0 gives the last step a target noise level (`sigmas[t + 1]`).
        self.sigmas = np.concatenate([sigmas, [0.0]])

        self.derivatives = []

        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        order: int = 4,
    ):
        """
        Advance `sample` by one step of the linear multistep method.

        Args:
            model_output: predicted noise for the current sample.
            timestep: *index* of the current step in `self.sigmas` (not the model timestep value).
            sample: current noisy sample.
            order: maximum number of past derivatives to combine.

        Returns:
            A dict with key `"prev_sample"` holding the updated sample.
        """
        sigma = self.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        # The effective order cannot exceed the number of sigma nodes behind us (`timestep + 1`) nor the
        # number of derivatives actually stored; otherwise the coefficients would be computed for more
        # terms than the sum below can pair them with.
        order = min(timestep + 1, order, len(self.derivatives))
        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        # Coefficient 0 belongs to the newest derivative, hence the reversed history.
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        return {"prev_sample": prev_sample}

    def add_noise(self, original_samples, noise, timesteps):
        """
        Mix `noise` into `original_samples` using `alphas_cumprod` at `timesteps`.

        Returns `sqrt(alpha_prod) * original_samples + sqrt(1 - alpha_prod) * noise`.
        """
        # NOTE(review): this uses the alphas_cumprod weighting rather than the sigma schedule used by
        # `step` — confirm it matches how callers scale samples for this scheduler.
        alpha_prod = self.alphas_cumprod[timesteps]
        alpha_prod = self.match_shape(alpha_prod, original_samples)

        noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise
        return noisy_samples

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps

src/diffusers/utils/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@
6969
_modelcards_available = False
7070

7171

72+
# scipy is optional: it counts as available only if an import spec is found AND
# its installed package metadata can be read.
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
    _scipy_version = importlib_metadata.version("scipy")
    logger.debug(f"Successfully imported scipy version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
    _scipy_available = False
78+
79+
7280
def is_transformers_available():
7381
return _transformers_available
7482

@@ -85,6 +93,10 @@ def is_modelcards_available():
8593
return _modelcards_available
8694

8795

96+
def is_scipy_available():
    """Return the module-level flag recording whether scipy was detected at import time."""
    return _scipy_available
98+
99+
88100
class RepositoryNotFoundError(HTTPError):
89101
"""
90102
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
@@ -118,11 +130,18 @@ class RevisionNotFoundError(HTTPError):
118130
"""
119131

120132

133+
SCIPY_IMPORT_ERROR = """
134+
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
135+
scipy`
136+
"""
137+
138+
121139
BACKENDS_MAPPING = OrderedDict(
122140
[
123141
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
124142
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
125143
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
144+
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
126145
]
127146
)
128147

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# This file is autogenerated by the command `make fix-copies`, do not edit.
2+
# flake8: noqa
3+
from ..utils import DummyObject, requires_backends
4+
5+
6+
class LMSDiscreteScheduler(metaclass=DummyObject):
    """Placeholder for the scipy-backed `LMSDiscreteScheduler`.

    Instantiating it calls `requires_backends` with the scipy backend.
    """

    _backends = ["scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["scipy"])


# Backward-compatible alias for the earlier capitalisation of the dummy's name.
LmsDiscreteScheduler = LMSDiscreteScheduler
11+
12+
13+
class LDMTextToImagePipeline(metaclass=DummyObject):
    """Placeholder class: instantiating it calls `requires_backends` with the scipy backend."""

    # NOTE(review): the package `__init__` imports the real pipeline under
    # `is_transformers_available()`, yet the backend recorded here is scipy —
    # confirm this generated entry is intended.
    _backends = ["scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["scipy"])
18+
19+
20+
class StableDiffusionPipeline(metaclass=DummyObject):
    """Placeholder class: instantiating it calls `requires_backends` with the scipy backend."""

    # NOTE(review): the package `__init__` imports the real pipeline under
    # `is_transformers_available()`, yet the backend recorded here is scipy —
    # confirm this generated entry is intended.
    _backends = ["scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["scipy"])

tests/test_modeling_utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
KarrasVeScheduler,
3434
LDMPipeline,
3535
LDMTextToImagePipeline,
36+
LMSDiscreteScheduler,
3637
PNDMPipeline,
3738
PNDMScheduler,
3839
ScoreSdeVePipeline,
@@ -841,7 +842,7 @@ def test_ldm_text2img_fast(self):
841842
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
842843

843844
@slow
844-
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
845+
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
845846
def test_stable_diffusion(self):
846847
# make sure here that pndm scheduler skips prk
847848
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
@@ -862,7 +863,7 @@ def test_stable_diffusion(self):
862863
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
863864

864865
@slow
865-
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
866+
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
866867
def test_stable_diffusion_fast_ddim(self):
867868
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
868869

@@ -977,3 +978,22 @@ def test_karras_ve_pipeline(self):
977978
assert image.shape == (1, 256, 256, 3)
978979
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
979980
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
981+
982+
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_lms_stable_diffusion_pipeline(self):
    """Run Stable Diffusion with the LMS scheduler swapped in and compare a 3x3 corner slice."""
    model_id = "CompVis/stable-diffusion-v1-1-diffusers"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
    # Swap in an LMS scheduler built from the same repository's scheduler config.
    pipe.scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)

    # Fixed seed so the generated image is comparable with the reference values below.
    generator = torch.Generator(device=torch_device).manual_seed(0)
    output = pipe(
        ["a photograph of an astronaut riding a horse"],
        generator=generator,
        guidance_scale=7.5,
        num_inference_steps=10,
        output_type="numpy",
    )
    image = output["sample"]

    # Bottom-right 3x3 patch of the last channel of the first image.
    image_slice = image[0, -3:, -3:, -1]
    assert image.shape == (1, 512, 512, 3)
    expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
    max_abs_diff = np.abs(image_slice.flatten() - expected_slice).max()
    assert max_abs_diff < 1e-2

0 commit comments

Comments
 (0)