|
14 | 14 |
|
15 | 15 | import re
|
16 | 16 |
|
| 17 | +import torch |
| 18 | + |
17 | 19 | from ..utils import is_peft_version, logging
|
18 | 20 |
|
19 | 21 |
|
@@ -326,3 +328,294 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
326 | 328 | prefix = "text_encoder_2."
|
327 | 329 | new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
328 | 330 | return {new_name: alpha}
|
| 331 | + |
| 332 | + |
| 333 | +# The utilities under `_convert_kohya_flux_lora_to_diffusers()` |
| 334 | +# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py |
| 335 | +# All credits go to `kohya-ss`. |
def _convert_kohya_flux_lora_to_diffusers(state_dict):
    """Convert a Kohya (sd-scripts) Flux LoRA state dict to ai-toolkit style keys.

    Every ``<module>.lora_down.weight`` / ``<module>.lora_up.weight`` / ``<module>.alpha``
    triple found under the ``lora_unet_double_blocks_{i}_*`` (i in 0..18) and
    ``lora_unet_single_blocks_{i}_*`` (i in 0..37) prefixes is rewritten to
    ``transformer.<...>.lora_A.weight`` / ``transformer.<...>.lora_B.weight``, with the
    ``alpha / rank`` factor folded into the weights. Fused projections (``*_attn_qkv`` and
    the single-block ``linear1``) are split into one LoRA pair per target module.

    Args:
        state_dict: Mapping from Kohya key names to tensors. It is consumed in place:
            every converted entry is popped from it.

    Returns:
        A new dict holding the converted tensors.

    Raises:
        KeyError: If a ``lora_down`` weight has no matching ``lora_up`` weight.
        ValueError: If explicit split sizes do not add up to the fused ``lora_up`` rows.
    """

    def _pop_scales(sds_sd, sds_key, rank):
        """Pop ``<sds_key>.alpha`` and return the ``(scale_down, scale_up)`` pair."""
        alpha = sds_sd.pop(sds_key + ".alpha", None)
        if alpha is None:
            # No alpha stored: fall back to alpha == rank, i.e. an overall scale of 1.
            alpha = float(rank)
        elif hasattr(alpha, "item"):
            alpha = alpha.item()  # alpha is a scalar tensor

        # LoRA is scaled by `alpha / rank` in the forward pass, so bake it into the weights.
        scale = alpha / rank

        # Spread the factor over both matrices. The loop only runs when `scale` is below 0.5:
        # it moves powers of two from `scale_up` to `scale_down`, keeping their product equal
        # to `scale`. For `scale >= 0.5` the whole factor stays on the down weight.
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        """Convert one plain (non-fused) LoRA pair; a no-op when `sds_key` is absent."""
        down_key = sds_key + ".lora_down.weight"
        if down_key not in sds_sd:
            return
        down_weight = sds_sd.pop(down_key)
        scale_down, scale_up = _pop_scales(sds_sd, sds_key, down_weight.shape[0])

        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        """Convert one fused LoRA pair into a separate pair for each entry of `ait_keys`.

        `dims` gives the number of `lora_up` rows belonging to each target; when omitted,
        the rows are divided evenly between the targets.
        """
        down_key = sds_key + ".lora_down.weight"
        if down_key not in sds_sd:
            return
        down_weight = sds_sd.pop(down_key)
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
        sd_lora_rank = down_weight.shape[0]

        scale_down, scale_up = _pop_scales(sds_sd, sds_key, sd_lora_rank)
        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        # calculate dims if not provided
        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        elif sum(dims) != up_weight.shape[0]:
            raise ValueError(
                f"`dims` {dims} must sum to the number of rows of `{sds_key}.lora_up.weight` "
                f"({up_weight.shape[0]})."
            )

        # The up weight is "sparse" when it is block-diagonal: the rows of split `j` are
        # non-zero only in the `j`-th slice of the rank dimension.
        is_sparse = False
        if sd_lora_rank % num_splits == 0:
            ait_rank = sd_lora_rank // num_splits
            is_sparse = True
            row = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    off_diagonal = up_weight[row : row + dims[j], k * ait_rank : (k + 1) * ait_rank]
                    is_sparse = is_sparse and bool(torch.all(off_diagonal == 0))
                row += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {sds_key}")

        # make ai-toolkit weight
        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
        if not is_sparse:
            # down_weight is shared (same tensor object) by every split
            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

            # up_weight is split row-wise between the targets
            ait_sd.update(zip(ait_up_keys, torch.split(up_weight, dims, dim=0)))
        else:
            # down_weight is chunked along the rank dimension, one chunk per split
            ait_sd.update(zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0)))

            # up_weight is sparse: only the diagonal (non-zero) block is copied to each split
            row = 0
            for j in range(len(dims)):
                ait_sd[ait_up_keys[j]] = up_weight[
                    row : row + dims[j], j * ait_rank : (j + 1) * ait_rank
                ].contiguous()
                row += dims[j]

    # (Kohya module suffix, Diffusers target) in conversion order. A tuple target marks a
    # fused projection that is split across several Diffusers modules.
    double_block_map = (
        ("img_attn_proj", "attn.to_out.0"),
        ("img_attn_qkv", ("attn.to_q", "attn.to_k", "attn.to_v")),
        ("img_mlp_0", "ff.net.0.proj"),
        ("img_mlp_2", "ff.net.2"),
        ("img_mod_lin", "norm1.linear"),
        ("txt_attn_proj", "attn.to_add_out"),
        ("txt_attn_qkv", ("attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj")),
        ("txt_mlp_0", "ff_context.net.0.proj"),
        ("txt_mlp_2", "ff_context.net.2"),
        ("txt_mod_lin", "norm1_context.linear"),
    )

    sds_sd = state_dict
    ait_sd = {}

    for i in range(19):
        src_prefix = f"lora_unet_double_blocks_{i}_"
        dst_prefix = f"transformer.transformer_blocks.{i}."
        for suffix, target in double_block_map:
            if isinstance(target, str):
                _convert_to_ai_toolkit(sds_sd, ait_sd, src_prefix + suffix, dst_prefix + target)
            else:
                _convert_to_ai_toolkit_cat(
                    sds_sd, ait_sd, src_prefix + suffix, [dst_prefix + name for name in target]
                )

    for i in range(38):
        src_prefix = f"lora_unet_single_blocks_{i}_"
        dst_prefix = f"transformer.single_transformer_blocks.{i}."
        # `linear1` fuses q, k, v and the MLP input projection, whose row counts differ.
        _convert_to_ai_toolkit_cat(
            sds_sd,
            ait_sd,
            src_prefix + "linear1",
            [dst_prefix + name for name in ("attn.to_q", "attn.to_k", "attn.to_v", "proj_mlp")],
            dims=[3072, 3072, 3072, 12288],
        )
        _convert_to_ai_toolkit(sds_sd, ait_sd, src_prefix + "linear2", dst_prefix + "proj_out")
        _convert_to_ai_toolkit(sds_sd, ait_sd, src_prefix + "modulation_lin", dst_prefix + "norm.linear")

    # Anything still left in the input was not recognised by the mapping above.
    if len(sds_sd) > 0:
        logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

    return ait_sd
| 525 | + |
| 526 | + |
| 527 | +# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6 |
| 528 | +# Some utilities were reused from |
| 529 | +# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py |
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
    """Convert an XLabs Flux LoRA state dict to Diffusers key names.

    Keys starting with ``double_blocks`` / ``single_blocks`` (optionally prefixed with
    ``diffusion_model.``) are renamed to ``transformer.transformer_blocks.{n}.*`` and
    ``transformer.single_transformer_blocks.{n}.*``; ``down`` weights become
    ``lora_A.weight`` and ``up`` weights become ``lora_B.weight``. The fused double-block
    QKV LoRAs are split into separate q/k/v entries. Any other key is copied through
    unchanged, unless it contains ``qkv``, in which case it is left in the input.

    Args:
        old_state_dict: Mapping from XLabs key names to tensors. It is consumed in place
            and must be fully converted by the end of the call.

    Returns:
        A new dict holding the converted tensors.

    Raises:
        KeyError: If a fused QKV ``down`` weight has no matching ``up`` weight.
        ValueError: If any key of `old_state_dict` could not be converted.
    """
    new_state_dict = {}
    orig_keys = list(old_state_dict.keys())

    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        """Split the fused pair whose down weight is `sds_key` across `ait_keys`."""
        down_weight = sds_sd.pop(sds_key)
        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))

        # calculate dims if not provided
        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        elif sum(dims) != up_weight.shape[0]:
            raise ValueError(f"`dims` {dims} must sum to the number of rows of the up weight ({up_weight.shape[0]}).")

        # down_weight is shared (same tensor object) by every split
        ait_sd.update(dict.fromkeys([k + ".lora_A.weight" for k in ait_keys], down_weight))

        # up_weight is split row-wise between the targets
        ait_sd.update(zip([k + ".lora_B.weight" for k in ait_keys], torch.split(up_weight, dims, dim=0)))

    def lora_suffix(key):
        """Return the Diffusers weight suffix matching the direction named in `key`."""
        if "down" in key:
            return ".lora_A.weight"
        if "up" in key:
            return ".lora_B.weight"
        return ""

    for old_key in orig_keys:
        if old_key not in old_state_dict:
            # Already consumed by `handle_qkv` as the `up` half of a fused QKV pair.
            continue

        # Handle double_blocks
        if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
            block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
            new_key = f"transformer.transformer_blocks.{block_num}"

            if "qkv" in old_key:
                # Fused QKV: the `down` key drives the conversion and pulls its `up` partner
                # along. `lora1` is mapped to the image stream and `lora2` to the text stream,
                # in line with the `proj_lora1` / `proj_lora2` mapping below.
                if "up" not in old_key:
                    if "processor.qkv_lora1" in old_key:
                        handle_qkv(
                            old_state_dict,
                            new_state_dict,
                            old_key,
                            [f"{new_key}.attn.{name}" for name in ("to_q", "to_k", "to_v")],
                        )
                    elif "processor.qkv_lora2" in old_key:
                        handle_qkv(
                            old_state_dict,
                            new_state_dict,
                            old_key,
                            [f"{new_key}.attn.{name}" for name in ("add_q_proj", "add_k_proj", "add_v_proj")],
                        )
                continue

            if "processor.proj_lora1" in old_key:
                new_key += ".attn.to_out.0"
            elif "processor.proj_lora2" in old_key:
                new_key += ".attn.to_add_out"
            new_key += lora_suffix(old_key)

        # Handle single_blocks. `str.startswith` needs a tuple to test several prefixes.
        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
            block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
            new_key = f"transformer.single_transformer_blocks.{block_num}"

            # Substring checks cover both numbered (`proj_lora1`) and unnumbered (`proj_lora`)
            # module names. NOTE(review): presumably XLabs single-block processors use the
            # unnumbered names -- confirm against a real checkpoint.
            if "proj_lora" in old_key:
                new_key += ".proj_out"
            elif "qkv_lora" in old_key:
                # Not split: stored as a single pair on `norm.linear`.
                new_key += ".norm.linear"
            new_key += lora_suffix(old_key)

        elif "qkv" in old_key:
            # Unrecognised fused key: leave it in place so the final check reports it.
            continue

        else:
            # Handle other potential key patterns here
            new_key = old_key

        new_state_dict[new_key] = old_state_dict.pop(old_key)

    if len(old_state_dict) > 0:
        raise ValueError(
            f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}."
        )

    return new_state_dict
| 621 | + return new_state_dict |
0 commit comments