
Commit 03165b9

Merge branch 'chroma-fixes' into chroma-img2img
2 parents 6ac443d + 544dad4 commit 03165b9

File tree

9 files changed: +197 -28 lines changed

examples/dreambooth/test_dreambooth_lora_flux.py

Lines changed: 45 additions & 0 deletions

@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 import os
 import sys
 import tempfile

 import safetensors

+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+

 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -234,3 +237,45 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult
             run_command(self._launch_args + resume_run_args)

             self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_with_metadata(self):
+        # Use a `lora_alpha` that is different from `rank`.
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
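The new test relies on the LoRA adapter metadata that is now embedded in the safetensors header of the saved weights. As a rough standalone sketch of the same check (the checkpoint path below is a placeholder, and the "transformer.r" / "transformer.lora_alpha" key names are simply the ones the test asserts on):

import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

checkpoint = "output/pytorch_lora_weights.safetensors"  # placeholder path

# The adapter metadata is stored as a JSON string in the safetensors header.
with safe_open(checkpoint, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

metadata.pop("format", None)  # the "format" entry is written by safetensors itself
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw is not None:
    adapter_metadata = json.loads(raw)
    print(adapter_metadata.get("transformer.r"), adapter_metadata.get("transformer.lora_alpha"))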

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 17 additions & 5 deletions

@@ -27,7 +27,6 @@

 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -53,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

     parser.add_argument(
@@ -1238,7 +1243,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1247,7 +1252,7 @@
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1280,6 +1287,7 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
@@ -1889,23 +1897,27 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None

         FluxPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

     # Final inference
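With the new --lora_alpha flag the adapter scale is no longer tied to --rank. Since PEFT scales the LoRA update by lora_alpha / r by default, choosing lora_alpha=8 with rank=4 doubles the effective update strength compared to the previous lora_alpha == rank behaviour. A minimal sketch of how the two values feed LoraConfig, mirroring the script above (the target_modules list here is an illustrative subset, not the script's full set):

from peft import LoraConfig

rank = 4        # --rank
lora_alpha = 8  # --lora_alpha, now independent of rank

transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,  # LoRA update is scaled by lora_alpha / r
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # illustrative subset
)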

src/diffusers/loaders/lora_pipeline.py

Lines changed: 37 additions & 9 deletions

@@ -2031,18 +2031,36 @@ def lora_state_dict(
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
             # Kohya already takes care of scaling the LoRA parameters with alpha.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

         is_xlabs = any("processor" in k for k in state_dict)
         if is_xlabs:
             state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
             # xlabs doesn't use `alpha`.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

         is_bfl_control = any("query_norm.scale" in k for k in state_dict)
         if is_bfl_control:
             state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -2061,12 +2079,13 @@
         )

         if return_alphas or return_lora_metadata:
-            outputs = [state_dict]
-            if return_alphas:
-                outputs.append(network_alphas)
-            if return_lora_metadata:
-                outputs.append(metadata)
-            return tuple(outputs)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=network_alphas,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
         else:
             return state_dict

@@ -2785,6 +2804,15 @@ def _get_weight_shape(weight: torch.Tensor):

         raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")

+    @staticmethod
+    def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
+        outputs = [state_dict]
+        if return_alphas:
+            outputs.append(alphas)
+        if return_metadata:
+            outputs.append(metadata)
+        return tuple(outputs) if (return_alphas or return_metadata) else state_dict
+

 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
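The new _prepare_outputs helper centralizes the variable-arity return of lora_state_dict: callers get back only what they asked for via return_alphas / return_lora_metadata. A quick standalone check of that contract, with the helper body copied from the diff above:

def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
    outputs = [state_dict]
    if return_alphas:
        outputs.append(alphas)
    if return_metadata:
        outputs.append(metadata)
    return tuple(outputs) if (return_alphas or return_metadata) else state_dict

sd = {"w": 0}
assert _prepare_outputs(sd, metadata=None) is sd                                    # bare state dict
assert _prepare_outputs(sd, metadata=None, alphas={}, return_alphas=True) == (sd, {})
assert _prepare_outputs(sd, metadata={"k": 1}, return_metadata=True) == (sd, {"k": 1})
assert _prepare_outputs(sd, metadata={"k": 1}, return_alphas=True, return_metadata=True) == (sd, None, {"k": 1})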

src/diffusers/loaders/peft.py

Lines changed: 3 additions & 1 deletion

@@ -187,7 +187,9 @@ def load_lora_adapter(
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
-            metadata: TODO
+            metadata:
+                LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
+                initialize `LoraConfig`.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
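The clarified docstring says that an explicitly supplied metadata dict takes precedence over whatever would be inferred from the state dict. A hypothetical sketch of that precedence (the helper name and the shape-based fallback below are illustrative, not the diffusers internals):

from peft import LoraConfig

def resolve_lora_config(state_dict, metadata=None):
    # Hypothetical helper: metadata wins when supplied.
    if metadata:
        return LoraConfig(**metadata)
    # Otherwise fall back to shape-based inference: lora_A weights have shape (r, in_features).
    rank = next(v.shape[0] for k, v in state_dict.items() if "lora_A" in k)
    return LoraConfig(r=rank, lora_alpha=rank)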

src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 2 deletions

@@ -2543,7 +2543,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2776,7 +2778,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
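Both processors now forward attention_mask to F.scaled_dot_product_attention. PyTorch accepts either a boolean mask (True means the position may be attended to) or an additive float mask, broadcastable to (batch, heads, q_len, kv_len). A small self-contained example with illustrative shapes:

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 6, 8
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Per-token keep mask: mask out the last two key/value positions of the second sample.
keep = torch.ones(batch, seq, dtype=torch.bool)
keep[1, -2:] = False

# Broadcast over heads and query positions: shape (batch, 1, 1, kv_len).
attn_mask = keep[:, None, None, :]

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 6, 8])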

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 15 additions & 5 deletions

@@ -250,15 +250,21 @@ def forward(
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
         joint_attention_kwargs = joint_attention_kwargs or {}
+
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

@@ -312,6 +318,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +328,15 @@
             encoder_hidden_states, emb=temb_txt
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

@@ -570,6 +581,7 @@ def forward(
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
@@ -659,11 +671,7 @@
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
                 )

             else:
@@ -672,6 +680,7 @@
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

@@ -704,6 +713,7 @@
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
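The Chroma blocks expand a per-token mask of shape (batch, seq_len) into a pairwise mask of shape (batch, 1, seq_len, seq_len) via an outer product, so entry [b, 0, i, j] is nonzero only when both token i and token j are kept; the transformer then threads this mask through every block (and through gradient checkpointing) down to the attention processors above. A tiny illustration with made-up values:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # (batch=1, seq=5); last two tokens are padding
pairwise = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

print(pairwise.shape)  # torch.Size([1, 1, 5, 5])
print(pairwise[0, 0])  # ones in the top-left 3x3 block, zeros elsewhere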

0 commit comments