
Commit f227e15

Merge branch 'main' into groupwise-offloading
2 parents af62c93 + 3e35f56

File tree: 66 files changed (+374, -1823 lines)


docs/source/en/using-diffusers/img2img.md

Lines changed: 2 additions & 2 deletions
@@ -461,12 +461,12 @@ Chain it to an upscaler pipeline to increase the image resolution:
 from diffusers import StableDiffusionLatentUpscalePipeline
 
 upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
-    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, use_safetensors=True
 )
 upscaler.enable_model_cpu_offload()
 upscaler.enable_xformers_memory_efficient_attention()
 
-image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
+image_2 = upscaler(prompt, image=image_1).images[0]
 ```
 
 Finally, chain it to a super-resolution pipeline to further enhance the resolution:
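For readers following the guide, the snippet above assumes `prompt` and `image_1` come from a base img2img pass earlier on that docs page. A hedged sketch of the full chain with the updated upscaler call (the checkpoint id, prompt, and input image below are illustrative placeholders, not part of this commit):

import torch
from diffusers import AutoPipelineForImage2Image, StableDiffusionLatentUpscalePipeline
from diffusers.utils import load_image

# Base img2img pass producing a latent to hand to the upscaler
# (checkpoint id, prompt, and image path are placeholders)
pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
)
pipeline.enable_model_cpu_offload()

init_image = load_image("init.png")  # placeholder local image
prompt = "Astronaut in a jungle, cold color palette, detailed, 8k"
image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]

# Upscaler pass, as in the updated docs
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, use_safetensors=True
)
upscaler.enable_model_cpu_offload()
image_2 = upscaler(prompt, image=image_1).images[0]

The only behavioral change here is the second call: the upscaler now decodes straight to a PIL image instead of returning latents, and no `variant="fp16"` weight files are requested; the checkpoint is simply loaded in `torch.float16`.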

docs/source/en/using-diffusers/write_own_pipeline.md

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ Let's try it out!
 
 ## Deconstruct the Stable Diffusion pipeline
 
-Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder to convert the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
+Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder converts the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
 
 As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
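The corrected paragraph lists the pieces you end up loading separately when deconstructing the pipeline: VAE (encoder/decoder), tokenizer plus text encoder, UNet, and scheduler. A minimal sketch of pulling them from a checkpoint's subfolders (the repo id is an assumption, not taken from this commit):

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, PNDMScheduler, UNet2DConditionModel

# Illustrative checkpoint id; the guide itself may load a different one
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae")                     # latent encoder/decoder
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")         # text -> token ids
text_encoder = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")   # token ids -> embeddings
unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")            # denoising network
scheduler = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")         # noise schedule / step rule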

examples/community/matryoshka.py

Lines changed: 5 additions & 74 deletions
@@ -80,7 +80,6 @@
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
-    is_torch_version,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(
 
         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1192,23 +1159,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
             ]
         )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1365,27 +1312,15 @@ def forward(
         # Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
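Every hunk in this file makes the same swap: the per-call `create_custom_forward` wrapper, `ckpt_kwargs`, and explicit `torch.utils.checkpoint.checkpoint` call are replaced by a single call to `self._gradient_checkpointing_func`, and the `_set_gradient_checkpointing` hooks are dropped. A minimal standalone sketch of the pattern, assuming the helper behaves like non-reentrant `torch.utils.checkpoint.checkpoint` (the `TinyBlock`/`TinyModel` names are made up for illustration):

import functools

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    """Stand-in for a resnet-style block taking (hidden_states, temb)."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb):
        return self.proj(hidden_states) + temb


class TinyModel(nn.Module):
    def __init__(self, dim=8, num_blocks=3):
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock(dim) for _ in range(num_blocks))
        self.gradient_checkpointing = False
        # Assumed equivalent of the helper wired up by diffusers' ModelMixin:
        # non-reentrant activation checkpointing with a callable-first signature.
        self._gradient_checkpointing_func = functools.partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, hidden_states, temb):
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # One call replaces the old create_custom_forward / ckpt_kwargs boilerplate
                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
            else:
                hidden_states = block(hidden_states, temb)
        return hidden_states


model = TinyModel()
model.gradient_checkpointing = True
x = torch.randn(2, 8, requires_grad=True)
t = torch.randn(2, 8)
model(x, t).sum().backward()  # block activations are recomputed during backward

The callable-first interface is what makes the refactor possible: the block is passed straight through with its positional arguments, so the closure boilerplate is no longer needed.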

examples/community/stable_diffusion_xl_controlnet_reference.py

Lines changed: 7 additions & 1 deletion
@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
 
     def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
         refimage = refimage.to(device=device)
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
             self.upcast_vae()
             refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
         if refimage.dtype != self.vae.dtype:
@@ -223,6 +224,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
 
         # aligning device to prevent device errors when concating it with the latent model input
         ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
         return ref_image_latents
 
     def prepare_ref_image(
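Both reference pipelines in this commit adopt the `needs_upcasting` idiom used elsewhere in the SDXL pipelines: upcast an fp16 VAE to float32 before encoding (its config asks for this via `force_upcast`), then cast it back afterwards so the rest of the run keeps the fp16 memory savings. A rough standalone sketch of the idea (the free function and its arguments are illustrative, not the pipeline method itself):

import torch


def encode_reference(vae, ref_image, generator=None):
    """Encode a reference image, temporarily upcasting an fp16 VAE when its config requests it."""
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    if needs_upcasting:
        vae.to(dtype=torch.float32)          # stands in for the pipeline's self.upcast_vae()
        ref_image = ref_image.to(torch.float32)

    latents = vae.encode(ref_image).latent_dist.sample(generator=generator)
    latents = vae.config.scaling_factor * latents

    # cast back to fp16 so later decode calls keep the memory savings
    if needs_upcasting:
        vae.to(dtype=torch.float16)
    return latents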

examples/community/stable_diffusion_xl_reference.py

Lines changed: 7 additions & 1 deletion
@@ -139,7 +139,8 @@ def retrieve_timesteps(
 class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
     def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
         refimage = refimage.to(device=device)
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
             self.upcast_vae()
             refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
         if refimage.dtype != self.vae.dtype:
@@ -169,6 +170,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
 
         # aligning device to prevent device errors when concating it with the latent model input
         ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
         return ref_image_latents
 
     def prepare_ref_image(

examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 2 additions & 2 deletions
@@ -695,7 +695,7 @@ def preprocess_images(examples):
         )
         # We need to ensure that the original and the edited images undergo the same
         # augmentation transforms.
-        images = np.concatenate([original_images, edited_images])
+        images = np.stack([original_images, edited_images])
         images = torch.tensor(images)
         images = 2 * (images / 255) - 1
         return train_transforms(images)
@@ -706,7 +706,7 @@ def preprocess_train(examples):
         # Since the original and edited images were concatenated before
         # applying the transformations, we need to separate them and reshape
         # them accordingly.
-        original_images, edited_images = preprocessed_images.chunk(2)
+        original_images, edited_images = preprocessed_images
         original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
         edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
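The two training scripts change how the original/edited image pair is batched before the shared augmentation transforms. `np.concatenate` merged both arrays along the batch axis and later needed `chunk(2)` to split them apart; `np.stack` keeps an explicit pair axis, so plain tuple unpacking recovers the two halves while both still pass through the same transform call. A small shape-only illustration (toy arrays, not the script's real data):

import numpy as np
import torch

# Toy stand-ins for the batched original/edited images, shape (N, C, H, W)
original_images = np.zeros((4, 3, 8, 8), dtype=np.float32)
edited_images = np.ones((4, 3, 8, 8), dtype=np.float32)

# Old approach: join along the batch axis, split back out later
joined = torch.tensor(np.concatenate([original_images, edited_images]))  # (8, 3, 8, 8)
orig_a, edit_a = joined.chunk(2)                                         # two (4, 3, 8, 8) tensors

# New approach: keep the pair as a leading axis, unpack directly
stacked = torch.tensor(np.stack([original_images, edited_images]))       # (2, 4, 3, 8, 8)
orig_b, edit_b = stacked                                                 # two (4, 3, 8, 8) tensors

assert torch.equal(orig_a, orig_b) and torch.equal(edit_a, edit_b)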

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 2 additions & 2 deletions
@@ -766,7 +766,7 @@ def preprocess_images(examples):
         )
         # We need to ensure that the original and the edited images undergo the same
         # augmentation transforms.
-        images = np.concatenate([original_images, edited_images])
+        images = np.stack([original_images, edited_images])
         images = torch.tensor(images)
         images = 2 * (images / 255) - 1
         return train_transforms(images)
@@ -906,7 +906,7 @@ def preprocess_train(examples):
         # Since the original and edited images were concatenated before
         # applying the transformations, we need to separate them and reshape
         # them accordingly.
-        original_images, edited_images = preprocessed_images.chunk(2)
+        original_images, edited_images = preprocessed_images
         original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
         edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)

examples/research_projects/pixart/controlnet_pixart_alpha.py

Lines changed: 2 additions & 18 deletions
@@ -8,7 +8,6 @@
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.torch_utils import is_torch_version
 
 
 class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ def __init__(
         self.transformer = transformer
         self.controlnet = controlnet
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,26 +215,15 @@ def forward(
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)
 
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 # the control nets are only used for the blocks 1 to self.blocks_num

src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 5 deletions
@@ -405,11 +405,12 @@ def set_use_memory_efficient_attention_xformers(
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
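The old smoke test always probed `xformers.ops.memory_efficient_attention` with default (float32) tensors, which can fail when the caller pins an operator that only supports half precision; the new code derives the probe dtype from the forward op's `SUPPORTED_DTYPES`. A hedged usage sketch of the case this fixes, assuming your xformers build still exports the flash-attention op tuple and the checkpoint id is a placeholder:

import torch
from diffusers import StableDiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp  # (FwOp, BwOp) pair, fp16/bf16 only

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16  # illustrative checkpoint id
).to("cuda")

# Previously the internal probe could raise because its test tensors were fp32;
# with this change, the probe tensor is created in a dtype the pinned op supports.
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)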

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 0 additions & 4 deletions
@@ -138,10 +138,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (Encoder, Decoder)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
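Like the other `_set_gradient_checkpointing` removals in this commit, this appears to delete only the old per-module hook; activation checkpointing is still toggled through the public model API. A minimal sketch (the checkpoint id is an illustrative assumption):

from diffusers import AutoencoderKL

# Toggling activation checkpointing through the public ModelMixin API
vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
vae.enable_gradient_checkpointing()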
