
Commit e2d037b

yiyixuxu, a-r-r-o-w, sayakpaul, and apolinario authored
minor doc/test update (#9734)
* update some docs and tests!

---------

Co-authored-by: Aryan <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: apolinário <[email protected]>
1 parent bcd61fd commit e2d037b

7 files changed: +310 -47 lines changed

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

Lines changed: 5 additions & 0 deletions

@@ -54,6 +54,11 @@ image = pipe(
 image.save("sd3_hello_world.png")
 ```
 
+**Note:** Stable Diffusion 3.5 can also be run with the SD3 pipeline, and all of the optimizations and techniques mentioned here apply to it as well. In total there are three official models in the SD3 family:
+- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)
+- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
+- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
+
 ## Memory Optimisations for SD3
 
 SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.
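
Since the added note says SD3.5 checkpoints load through the same `StableDiffusion3Pipeline` class, a minimal usage sketch follows (the prompt and generation settings are illustrative assumptions, not part of this commit; a recent `diffusers` install and a CUDA GPU are assumed):

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Load an SD3.5 checkpoint through the same pipeline class used for SD3.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd35_hello_world.png")
```

The memory optimizations referenced in the note apply unchanged, since the pipeline class is identical.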

scripts/convert_sd3_to_diffusers.py

Lines changed: 111 additions & 8 deletions

@@ -16,10 +16,9 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--checkpoint_path", type=str)
 parser.add_argument("--output_path", type=str)
-parser.add_argument("--dtype", type=str, default="fp16")
+parser.add_argument("--dtype", type=str)
 
 args = parser.parse_args()
-dtype = torch.float16 if args.dtype == "fp16" else torch.float32
 
 
 def load_original_checkpoint(ckpt_path):
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
-def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
+def convert_sd3_transformer_checkpoint_to_diffusers(
+    original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
+):
     converted_state_dict = {}
 
     # Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )
 
+        # attn2
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
     )
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_pos_embed_max_size(state_dict):
+    num_patches = state_dict["pos_embed"].shape[1]
+    pos_embed_max_size = int(num_patches**0.5)
+    return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def main(args):
     original_ckpt = load_original_checkpoint(args.checkpoint_path)
+    original_dtype = next(iter(original_ckpt.values())).dtype
+
+    # Initialize dtype with a default value
+    dtype = None
+
+    if args.dtype is None:
+        dtype = original_dtype
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    if dtype != original_dtype:
+        print(
+            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
+        )
+
     num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
+    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+    attn2_layers = get_attn2_layers(original_ckpt)
+
+    # sd3.5 uses qk norm ("rms_norm")
+    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
+
+    # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
+    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)
 
     converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
-        original_ckpt, num_layers, caption_projection_dim
+        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
     )
 
     with CTX():
         transformer = SD3Transformer2DModel(
-            sample_size=64,
+            sample_size=128,
             patch_size=2,
             in_channels=16,
             joint_attention_dim=4096,
             num_layers=num_layers,
             caption_projection_dim=caption_projection_dim,
-            num_attention_heads=24,
-            pos_embed_max_size=192,
+            num_attention_heads=num_layers,
+            pos_embed_max_size=pos_embed_max_size,
+            qk_norm="rms_norm" if has_qk_norm else None,
+            dual_attention_layers=attn2_layers,
         )
         if is_accelerate_available():
             load_model_dict_into_meta(transformer, converted_transformer_state_dict)
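
To make the new auto-detection in `main()` concrete, here is a small self-contained sketch that runs the same checks on a toy state dict. The key names follow the patterns the converter inspects, but the shapes and values are illustrative stand-ins rather than real checkpoint contents:

```python
import torch

# Toy "original" state dict with the key patterns the converter looks for.
toy_ckpt = {
    "pos_embed": torch.zeros(1, 192 * 192, 1536),         # -> pos_embed_max_size = 192
    "context_embedder.weight": torch.zeros(1536, 4096),   # -> caption_projection_dim = 1536
    "joint_blocks.0.x_block.attn.ln_q.weight": torch.zeros(64),              # qk norm present (SD3.5-style)
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(3 * 1536, 1536),  # dual attention in layer 0
}

# Same checks as the updated main():
has_qk_norm = any("ln_q" in key for key in toy_ckpt)
attn2_layers = tuple(sorted({int(k.split(".")[1]) for k in toy_ckpt if "attn2." in k}))
pos_embed_max_size = int(toy_ckpt["pos_embed"].shape[1] ** 0.5)
caption_projection_dim = toy_ckpt["context_embedder.weight"].shape[0]

# dtype resolution mirroring the new --dtype handling (None -> keep the checkpoint's own dtype).
requested = None  # could be "fp16", "bf16", or "fp32"
dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
dtype = dtype_map[requested] if requested is not None else toy_ckpt["pos_embed"].dtype

print(has_qk_norm, attn2_layers, pos_embed_max_size, caption_projection_dim, dtype)
# True (0,) 192 1536 torch.float32
```

The point of these checks is that one script now handles SD3.0 and the SD3.5 variants without hard-coded values such as `caption_projection_dim = 1536` or `pos_embed_max_size=192`.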

src/diffusers/models/attention.py

Lines changed: 45 additions & 4 deletions

@@ -22,7 +22,7 @@
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
 logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
             processing of `context` conditions.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        use_dual_attention: bool = False,
+    ):
         super().__init__()
 
+        self.use_dual_attention = use_dual_attention
         self.context_pre_only = context_pre_only
         context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
 
-        self.norm1 = AdaLayerNormZero(dim)
+        if use_dual_attention:
+            self.norm1 = SD35AdaLayerNormZeroX(dim)
+        else:
+            self.norm1 = AdaLayerNormZero(dim)
 
         if context_norm_type == "ada_norm_continous":
             self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
             raise ValueError(
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
+
         if hasattr(F, "scaled_dot_product_attention"):
             processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
             )
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -134,8 +148,25 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
             context_pre_only=context_pre_only,
             bias=True,
             processor=processor,
+            qk_norm=qk_norm,
+            eps=1e-6,
         )
 
+        if use_dual_attention:
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=None,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                out_dim=dim,
+                bias=True,
+                processor=processor,
+                qk_norm=qk_norm,
+                eps=1e-6,
+            )
+        else:
+            self.attn2 = None
+
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
@@ -159,7 +190,12 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
     def forward(
         self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         if self.context_pre_only:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -177,6 +213,11 @@ def forward(
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         if self._chunk_size is not None:
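
As a reading aid for the forward-pass change, the sketch below mimics the dual-attention residual path in a standalone module. It deliberately uses plain `nn.LayerNorm` and `nn.MultiheadAttention` as stand-ins for `SD35AdaLayerNormZeroX` and `Attention`, so the timestep-conditioned gates (`gate_msa`, `gate_msa2`) and the joint text/image attention are omitted; it is not the diffusers implementation itself:

```python
import torch
import torch.nn as nn


class DualAttentionSketch(nn.Module):
    """Simplified illustration of the extra attn2 residual branch added for SD3.5."""

    def __init__(self, dim: int = 64, heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)  # stand-in for SD35AdaLayerNormZeroX (no adaptive modulation here)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)   # joint attention stand-in
        self.attn2 = nn.MultiheadAttention(dim, heads, batch_first=True)  # extra self-attention branch

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        normed = self.norm1(hidden_states)
        attn_out, _ = self.attn(normed, normed, normed)
        hidden_states = hidden_states + attn_out      # first residual add (as before)
        attn_out2, _ = self.attn2(normed, normed, normed)
        hidden_states = hidden_states + attn_out2     # second residual add from the new attn2 branch
        return hidden_states


x = torch.randn(2, 16, 64)  # (batch, tokens, dim)
print(DualAttentionSketch()(x).shape)  # torch.Size([2, 16, 64])
```

When `use_dual_attention=False` (any layer not listed in `dual_attention_layers`), `attn2` is `None` and the block behaves exactly as before, which is how the change stays backward compatible.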
