Skip to content

Commit bc2ba00

Browse files
sayakpaulpcuencapatrickvonplaten
authored
[LCM] add: locm docs. (#5723)
* add: locm docs. * correct path * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * up * add --------- Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 3d7eaf8 commit bc2ba00

File tree

2 files changed

+156
-0
lines changed

2 files changed

+156
-0
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272
title: Overview
7373
- local: using-diffusers/sdxl
7474
title: Stable Diffusion XL
75+
- local: using-diffusers/lcm
76+
title: Latent Consistency Models
7577
- local: using-diffusers/kandinsky
7678
title: Kandinsky
7779
- local: using-diffusers/controlnet

docs/source/en/using-diffusers/lcm.md

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# Performing inference with LCM
14+
15+
Latent Consistency Models (LCM) enable quality image generation in typically 2-4 steps making it possible to use diffusion models in almost real-time settings.
16+
17+
From the [official website](https://latent-consistency-models.github.io/):
18+
19+
> LCMs can be distilled from any pre-trained Stable Diffusion (SD) in only 4,000 training steps (~32 A100 GPU Hours) for generating high quality 768 x 768 resolution images in 2~4 steps or even one step, significantly accelerating text-to-image generation. We employ LCM to distill the Dreamshaper-V7 version of SD in just 4,000 training iterations.
20+
21+
For a more technical overview of LCMs, refer to [the paper](https://huggingface.co/papers/2310.04378).
22+
23+
This guide shows how to perform inference with LCMs for text-to-image and image-to-image generation tasks. It will also cover performing inference with LoRA checkpoints.
24+
25+
## Text-to-image
26+
27+
You'll use the [`StableDiffusionXLPipeline`] here changing the `unet`. The UNet was distilled from the SDXL UNet using the framework introduced in LCM. Another important component is the scheduler: [`LCMScheduler`]. Together with the distilled UNet and the scheduler, LCM enables a fast inference workflow overcoming the slow iterative nature of diffusion models.
28+
29+
```python
30+
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
31+
import torch
32+
33+
unet = UNet2DConditionModel.from_pretrained(
34+
"latent-consistency/lcm-sdxl",
35+
torch_dtype=torch.float16,
36+
variant="fp16",
37+
)
38+
pipe = DiffusionPipeline.from_pretrained(
39+
"stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
40+
).to("cuda")
41+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
42+
43+
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
44+
45+
generator = torch.manual_seed(0)
46+
image = pipe(
47+
prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=8.0
48+
).images[0]
49+
```
50+
51+
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_intro.png)
52+
53+
Notice that we use only 4 steps for generation which is way less than what's typically used for standard SDXL.
54+
55+
Some details to keep in mind:
56+
57+
* To perform classifier-free guidance, batch size is usually doubled inside the pipeline. LCM, however, applies guidance using guidance embeddings, so the batch size does not have to be doubled in this case. This leads to a faster inference time, with the drawback that negative prompts don't have any effect on the denoising process.
58+
* The UNet was trained using the [3., 13.] guidance scale range. So, that is the ideal range for `guidance_scale`. However, disabling `guidance_scale` using a value of 1.0 is also effective in most cases.
59+
60+
## Image-to-image
61+
62+
The findings above apply to image-to-image tasks too. Let's look at how we can perform image-to-image generation with LCMs:
63+
64+
```python
65+
from diffusers import AutoPipelineForImage2Image, UNet2DConditionModel, LCMScheduler
66+
from diffusers.utils import load_image
67+
import torch
68+
69+
unet = UNet2DConditionModel.from_pretrained(
70+
"latent-consistency/lcm-sdxl",
71+
torch_dtype=torch.float16,
72+
variant="fp16",
73+
)
74+
pipe = AutoPipelineForImage2Image.from_pretrained(
75+
"stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
76+
).to("cuda")
77+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
78+
79+
prompt = "High altitude snowy mountains"
80+
image = load_image(
81+
"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/snowy_mountains.jpeg"
82+
)
83+
84+
generator = torch.manual_seed(0)
85+
image = pipe(
86+
prompt=prompt,
87+
image=image,
88+
num_inference_steps=4,
89+
generator=generator,
90+
guidance_scale=8.0,
91+
).images[0]
92+
```
93+
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_i2i.png)
94+
95+
## LoRA
96+
97+
It is possible to generalize the LCM framework to use with [LoRA](../training/lora.md). It effectively eliminates the need to conduct expensive fine-tuning runs as LoRA training concerns just a few number of parameters compared to full fine-tuning. During inference, the [`LCMScheduler`] comes to the advantage as it enables very few-steps inference without compromising the quality.
98+
99+
We recommend to disable `guidance_scale` by setting it 0. The model is trained to follow prompts accurately
100+
even without using guidance scale. You can however, still use guidance scale in which case we recommend
101+
using values between 1.0 and 2.0.
102+
103+
### Text-to-image
104+
105+
```python
106+
from diffusers import DiffusionPipeline, LCMScheduler
107+
import torch
108+
109+
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
110+
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
111+
112+
pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16).to("cuda")
113+
114+
pipe.load_lora_weights(lcm_lora_id)
115+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
116+
117+
prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
118+
image = pipe(
119+
prompt=prompt,
120+
num_inference_steps=4,
121+
guidance_scale=0, # set guidance scale to 0 to disable it
122+
).images[0]
123+
```
124+
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lora_lcm.png)
125+
126+
### Image-to-image
127+
128+
Extending LCM LoRA to image-to-image is possible:
129+
130+
```python
131+
from diffusers import StableDiffusionXLImg2ImgPipeline, LCMScheduler
132+
from diffusers.utils import load_image
133+
import torch
134+
135+
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
136+
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
137+
138+
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16).to("cuda")
139+
140+
pipe.load_lora_weights(lcm_lora_id)
141+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
142+
143+
prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
144+
145+
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lora_lcm.png")
146+
147+
image = pipe(
148+
prompt=prompt,
149+
image=image,
150+
num_inference_steps=4,
151+
guidance_scale=0, # set guidance scale to 0 to disable it
152+
).images[0]
153+
```
154+
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_lora_i2i.png)

0 commit comments

Comments
 (0)