
Commit bcba858

Merge branch 'main' into allegro-impl
2 parents 412cd7c + 1b64772

File tree

12 files changed: +338, -67 lines

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

Lines changed: 5 additions & 0 deletions

@@ -54,6 +54,11 @@ image = pipe(
 image.save("sd3_hello_world.png")
 ```

+**Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all mentioned optimizations and techniques apply to it as well. In total there are three official models in the SD3 family:
+- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)
+- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
+- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
+
 ## Memory Optimisations for SD3

 SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
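The note above can be exercised with the same `StableDiffusion3Pipeline` shown earlier in that document. The following is only a sketch, assuming access to the gated `stabilityai/stable-diffusion-3.5-large` checkpoint and a GPU with enough VRAM; the prompt, step count, and guidance scale are illustrative and not taken from the diff.

```py
import torch
from diffusers import StableDiffusion3Pipeline

# Load an SD3.5 checkpoint through the SD3 pipeline class (sketch only).
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=4.5,
).images[0]
image.save("sd3_5_hello_world.png")
```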

examples/community/README.md

Lines changed: 5 additions & 6 deletions

@@ -4336,19 +4336,19 @@ The Abstract of the paper:

 **64x64**
 :-------------------------:
-| <img src="https://github.com/user-attachments/assets/9e7bb2cd-45a0-4bd1-adb8-23e283baed39" width="222" height="222" alt="bird_64"> |
+| <img src="https://github.com/user-attachments/assets/032738eb-c6cd-4fd9-b4d7-a7317b4b6528" width="222" height="222" alt="bird_64_64"> |

 - `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps:

 **64x64** | **256x256**
 :-------------------------:|:-------------------------:
-| <img src="https://github.com/user-attachments/assets/6b724c2e-5e6a-4b63-9b65-c1182cbb67e0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/7dbab2ad-bf40-4a73-ab04-f178347cb7d5" width="222" height="222" alt="256x256"> |
+| <img src="https://github.com/user-attachments/assets/21b9ad8b-eea6-4603-80a2-31180f391589" width="222" height="222" alt="bird_256_64"> | <img src="https://github.com/user-attachments/assets/fc411682-8a36-422c-9488-395b77d4406e" width="222" height="222" alt="bird_256_256"> |

-- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps:
+- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible in this context! With `250` DDIM inference steps:

 **64x64** | **256x256** | **1024x1024**
 :-------------------------:|:-------------------------:|:-------------------------:
-| <img src="https://github.com/user-attachments/assets/4a9454e4-e20a-4736-a196-270e2ae796c0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/4a96555d-0fda-4303-82b1-a4d886f770b9" width="222" height="222" alt="256x256"> | <img src="https://github.com/user-attachments/assets/e0239b7a-ab73-4d45-8f3e-b4e6b4b50abe" width="222" height="222" alt="1024x1024"> |
+| <img src="https://github.com/user-attachments/assets/febf4b98-3dee-4a8e-9946-fd42e1f232e6" width="222" height="222" alt="bird_1024_64"> | <img src="https://github.com/user-attachments/assets/c5f85b40-5d6d-4267-a92a-c89dff015b9b" width="222" height="222" alt="bird_1024_256"> | <img src="https://github.com/user-attachments/assets/ad66b913-4367-4cb9-889e-bc06f4d96148" width="222" height="222" alt="bird_1024_1024"> |

 ```py
 from diffusers import DiffusionPipeline

@@ -4362,8 +4362,7 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-model

 prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
 prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
-negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
-image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
+image = pipe(prompt, num_inference_steps=50).images
 make_image_grid(image, rows=1, cols=len(image))

 # pipe.change_nesting_level(<int>) # 0, 1, or 2
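Putting the README snippet together with the docstring change in `examples/community/matryoshka.py` below, an end-to-end run at `nesting_level=1` might look like the following sketch; the keyword arguments are taken from the examples in this diff rather than from a stable API, and the step count follows the `150` DDIM steps quoted above.

```py
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid

# nesting_level=1 -> 256x256 and 64x64 outputs, ~1.776 GiB according to the README (sketch only).
pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/matryoshka-diffusion-models",
    nesting_level=1,
    custom_pipeline="matryoshka",
    trust_remote_code=False,  # one needs to give permission for this code to run
).to("cuda")

prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"

image = pipe(prompt, num_inference_steps=150).images  # 150+ steps recommended for nesting_level=1
make_image_grid(image, rows=1, cols=len(image))
```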

examples/community/matryoshka.py

Lines changed: 16 additions & 12 deletions

@@ -107,15 +107,16 @@

 >>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
 >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
->>>                                           custom_pipeline="matryoshka").to("cuda")
+...                                           nesting_level=0,
+...                                           trust_remote_code=False,  # One needs to give permission for this code to run
+...                                           ).to("cuda")

 >>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
 >>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
->>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
->>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
+>>> image = pipe(prompt, num_inference_steps=50).images
 >>> make_image_grid(image, rows=1, cols=len(image))

->>> pipe.change_nesting_level(<int>) # 0, 1, or 2
+>>> # pipe.change_nesting_level(<int>) # 0, 1, or 2
 >>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
 ```
 """

@@ -420,6 +421,7 @@ def __init__(
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

         self.scales = None
+        self.schedule_shifted_power = 1.0

     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """

@@ -532,6 +534,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

     def get_schedule_shifted(self, alpha_prod, scale_factor=None):
         if (scale_factor is not None) and (scale_factor > 1):  # rescale noise schedule
+            scale_factor = scale_factor**self.schedule_shifted_power
             snr = alpha_prod / (1 - alpha_prod)
             scaled_snr = snr / scale_factor
             alpha_prod = 1 / (1 + 1 / scaled_snr)

@@ -639,17 +642,14 @@ def step(
         # 4. Clip or threshold "predicted x_0"
         if self.config.thresholding:
             if len(model_output) > 1:
-                pred_original_sample = [
-                    self._threshold_sample(p_o_s * scale) / scale
-                    for p_o_s, scale in zip(pred_original_sample, self.scales)
-                ]
+                pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
             else:
                 pred_original_sample = self._threshold_sample(pred_original_sample)
         elif self.config.clip_sample:
             if len(model_output) > 1:
                 pred_original_sample = [
-                    (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
-                    for p_o_s, scale in zip(pred_original_sample, self.scales)
+                    p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
+                    for p_o_s in pred_original_sample
                 ]
             else:
                 pred_original_sample = pred_original_sample.clamp(

@@ -3816,6 +3816,8 @@ def __init__(

         if hasattr(unet, "nest_ratio"):
             scheduler.scales = unet.nest_ratio + [1]
+            if nesting_level == 2:
+                scheduler.schedule_shifted_power = 2.0

         self.register_modules(
             text_encoder=text_encoder,

@@ -3842,12 +3844,14 @@ def change_nesting_level(self, nesting_level: int):
             ).to(self.device)
             self.config.nesting_level = 1
             self.scheduler.scales = self.unet.nest_ratio + [1]
+            self.scheduler.schedule_shifted_power = 1.0
         elif nesting_level == 2:
             self.unet = NestedUNet2DConditionModel.from_pretrained(
                 "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
             ).to(self.device)
             self.config.nesting_level = 2
             self.scheduler.scales = self.unet.nest_ratio + [1]
+            self.scheduler.schedule_shifted_power = 2.0
         else:
             raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")

@@ -4627,8 +4631,8 @@ def __call__(
             image = latents

         if self.scheduler.scales is not None:
-            for i, (img, scale) in enumerate(zip(image, self.scheduler.scales)):
-                image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
+            for i, img in enumerate(image):
+                image[i] = self.image_processor.postprocess(img, output_type=output_type)[0]
         else:
             image = self.image_processor.postprocess(image, output_type=output_type)

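The scheduler edits above add a `schedule_shifted_power` attribute, which raises the per-resolution scale factor to a power (2.0 only at `nesting_level=2`) before the SNR-based schedule shift. A standalone sketch of that arithmetic, with variable names chosen for illustration rather than copied from the pipeline:

```py
import torch

def schedule_shifted(alpha_prod: torch.Tensor, scale_factor: float, power: float = 1.0) -> torch.Tensor:
    """Shift a noise schedule toward lower SNR, mirroring get_schedule_shifted above (sketch only)."""
    if scale_factor is not None and scale_factor > 1:  # rescale noise schedule
        scale_factor = scale_factor**power             # power is 2.0 only for nesting_level=2
        snr = alpha_prod / (1 - alpha_prod)            # signal-to-noise ratio of the base schedule
        scaled_snr = snr / scale_factor                # dividing the SNR makes the schedule noisier
        alpha_prod = 1 / (1 + 1 / scaled_snr)
    return alpha_prod

alpha_prod = torch.tensor([0.99, 0.5, 0.01])
print(schedule_shifted(alpha_prod, scale_factor=4.0, power=2.0))  # ~tensor([0.8609, 0.0588, 0.0006])
```
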
examples/controlnet/train_controlnet.py

Lines changed: 3 additions & 1 deletion

@@ -1048,7 +1048,9 @@ def load_model_hook(models, input_dir):

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+                    dtype=weight_dtype
+                )

                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
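The change above (mirrored in `examples/controlnet/train_controlnet_sdxl.py` below) upcasts the latents and noise to `float32` so the forward-diffusion step is not computed in half precision, then casts the result back to `weight_dtype` for the model. A minimal standalone sketch of the same pattern, with made-up tensor shapes:

```py
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
weight_dtype = torch.float16  # e.g. training with --mixed_precision="fp16"

latents = torch.randn(2, 4, 64, 64, dtype=weight_dtype)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))

# add_noise follows the dtype of its inputs, so upcast for accuracy and cast back afterwards.
noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(dtype=weight_dtype)
print(noisy_latents.dtype)  # torch.float16
```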

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 3 additions & 1 deletion

@@ -1210,7 +1210,9 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+                    dtype=weight_dtype
+                )

                 # ControlNet conditioning.
                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

scripts/convert_sd3_to_diffusers.py

Lines changed: 111 additions & 8 deletions

@@ -16,10 +16,9 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--checkpoint_path", type=str)
 parser.add_argument("--output_path", type=str)
-parser.add_argument("--dtype", type=str, default="fp16")
+parser.add_argument("--dtype", type=str)

 args = parser.parse_args()
-dtype = torch.float16 if args.dtype == "fp16" else torch.float32


 def load_original_checkpoint(ckpt_path):

@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
     return new_weight


-def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
+def convert_sd3_transformer_checkpoint_to_diffusers(
+    original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
+):
     converted_state_dict = {}

     # Positional and patch embeddings.

@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"

@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )

+        # attn2
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )

+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"

@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
     )


+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_pos_embed_max_size(state_dict):
+    num_patches = state_dict["pos_embed"].shape[1]
+    pos_embed_max_size = int(num_patches**0.5)
+    return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def main(args):
     original_ckpt = load_original_checkpoint(args.checkpoint_path)
+    original_dtype = next(iter(original_ckpt.values())).dtype
+
+    # Initialize dtype with a default value
+    dtype = None
+
+    if args.dtype is None:
+        dtype = original_dtype
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    if dtype != original_dtype:
+        print(
+            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
+        )
+
     num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
+    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+    attn2_layers = get_attn2_layers(original_ckpt)
+
+    # sd3.5 uses qk norm ("rms_norm")
+    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
+
+    # sd3.5 2B uses pos_embed_max_size=384, while sd3.0 and sd3.5 8B use 192
+    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

     converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
-        original_ckpt, num_layers, caption_projection_dim
+        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
     )

     with CTX():
         transformer = SD3Transformer2DModel(
-            sample_size=64,
+            sample_size=128,
             patch_size=2,
             in_channels=16,
             joint_attention_dim=4096,
             num_layers=num_layers,
             caption_projection_dim=caption_projection_dim,
-            num_attention_heads=24,
-            pos_embed_max_size=192,
+            num_attention_heads=num_layers,
+            pos_embed_max_size=pos_embed_max_size,
+            qk_norm="rms_norm" if has_qk_norm else None,
+            dual_attention_layers=attn2_layers,
         )
     if is_accelerate_available():
         load_model_dict_into_meta(transformer, converted_transformer_state_dict)

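The new helper functions infer the model variant directly from the checkpoint's state dict: which blocks carry a second (`attn2`) attention, whether qk-norm weights are present, and the positional-embedding grid size. A toy illustration of those heuristics with a made-up, SD3.5-style state dict (tensor shapes are placeholders, not real checkpoint shapes):

```py
import torch

# Minimal fake state dict with SD3.5-style key names; shapes are placeholders for illustration.
fake_ckpt = {
    "pos_embed": torch.zeros(1, 384 * 384, 4),                      # 384x384 patch grid
    "context_embedder.weight": torch.zeros(1536, 8),                # first dim -> caption projection dim
    "joint_blocks.0.x_block.attn.ln_q.weight": torch.zeros(64),     # qk norm weights present
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(12, 4),  # block 0 has a second attention
}

attn2_layers = tuple(sorted({int(k.split(".")[1]) for k in fake_ckpt if "attn2." in k}))
has_qk_norm = any("ln_q" in k for k in fake_ckpt)
pos_embed_max_size = int(fake_ckpt["pos_embed"].shape[1] ** 0.5)
caption_projection_dim = fake_ckpt["context_embedder.weight"].shape[0]

print(attn2_layers, has_qk_norm, pos_embed_max_size, caption_projection_dim)  # (0,) True 384 1536
```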