Skip to content

Commit f9fd511

Browse files
authored
[LoRA] support Kohya Flux LoRAs that have text encoders as well (#9542)
* Support Kohya Flux LoRAs that also include text-encoder (TE) weights.
1 parent 8e7d6c0 commit f9fd511

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
516516
f"transformer.single_transformer_blocks.{i}.norm.linear",
517517
)
518518

519+
remaining_keys = list(sds_sd.keys())
520+
te_state_dict = {}
521+
if remaining_keys:
522+
if not all(k.startswith("lora_te1") for k in remaining_keys):
523+
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
524+
for key in remaining_keys:
525+
if not key.endswith("lora_down.weight"):
526+
continue
527+
528+
lora_name = key.split(".")[0]
529+
lora_name_up = f"{lora_name}.lora_up.weight"
530+
lora_name_alpha = f"{lora_name}.alpha"
531+
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
532+
533+
if lora_name.startswith(("lora_te_", "lora_te1_")):
534+
down_weight = sds_sd.pop(key)
535+
sd_lora_rank = down_weight.shape[0]
536+
te_state_dict[diffusers_name] = down_weight
537+
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
538+
539+
if lora_name_alpha in sds_sd:
540+
alpha = sds_sd.pop(lora_name_alpha).item()
541+
scale = alpha / sd_lora_rank
542+
543+
scale_down = scale
544+
scale_up = 1.0
545+
while scale_down * 2 < scale_up:
546+
scale_down *= 2
547+
scale_up /= 2
548+
549+
te_state_dict[diffusers_name] *= scale_down
550+
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
551+
519552
if len(sds_sd) > 0:
520-
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
553+
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
554+
555+
if te_state_dict:
556+
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
521557

522-
return ait_sd
558+
new_state_dict = {**ait_sd, **te_state_dict}
559+
return new_state_dict
523560

524561
return _convert_sd_scripts_to_ai_toolkit(state_dict)
525562

tests/lora/test_lora_layers_flux.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,26 @@ def test_flux_kohya(self):
228228

229229
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
230230

231+
def test_flux_kohya_with_text_encoder(self):
    """Load a Kohya-format Flux LoRA that also ships text-encoder weights,
    fuse it into the pipeline, and check the output against a reference slice.
    """
    pipe = self.pipeline
    pipe.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
    pipe.fuse_lora()
    pipe.unload_lora_weights()
    pipe.enable_model_cpu_offload()

    # Fixed seed keeps the sampled image deterministic for the slice check.
    images = pipe(
        "optimus is cleaning the house with broomstick",
        num_inference_steps=self.num_inference_steps,
        guidance_scale=4.5,
        output_type="np",
        generator=torch.manual_seed(self.seed),
    ).images

    out_slice = images[0, -3:, -3:, -1].flatten()
    expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])

    assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
250+
231251
def test_flux_xlabs(self):
232252
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
233253
self.pipeline.fuse_lora()

0 commit comments

Comments (0)