Commit ecbc4cb

fixes
1 parent d6518b7 commit ecbc4cb

2 files changed: +9 additions, -14 deletions

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 5 additions & 1 deletion
@@ -673,6 +673,10 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0
 
+    for k in original_state_dict:
+        if "bias" in k and "img_in" in k:
+            print(f"{k=}")
+
     def swap_scale_shift(weight):
         shift, scale = weight.chunk(2, dim=0)
         new_weight = torch.cat([scale, shift], dim=0)
@@ -750,7 +754,7 @@ def swap_scale_shift(weight):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
 
-        for lora_key, lora_key in zip(["lora_A", "lora_B"], ["lora_A", "lora_B"]):
+        for lora_key in ["lora_A", "lora_B"]:
             # norms
             converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
                 f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"

src/diffusers/loaders/lora_pipeline.py

Lines changed: 4 additions & 13 deletions
@@ -2318,7 +2318,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
                 lora_A_weight_name = f"{name}.lora_A.weight"
                 lora_B_weight_name = f"{name}.lora_B.weight"
-                lora_B_bias_name = f"{name}.lora_B.bias"
+                # lora_B_bias_name = f"{name}.lora_B.bias"
 
                 if lora_A_weight_name not in state_dict.keys():
                     continue
@@ -2352,24 +2352,15 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 expanded_module = torch.nn.Linear(
                     in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
                 )
-
+                # Only weights are expanded and biases are not.
                 new_weight = torch.zeros_like(
                     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                 )
                 slices = tuple(slice(0, dim) for dim in module_weight.shape)
                 new_weight[slices] = module_weight
                 expanded_module.weight.data.copy_(new_weight)
-
-                bias_present_for_lora_B = lora_B_bias_name in state_dict
-                if bias_present_for_lora_B:
-                    new_bias_shape = state_dict[lora_B_bias_name].shape
-                    if bias and module_bias.shape < new_bias_shape:
-                        new_bias = torch.zeros_like(
-                            expanded_module.bias.data, device=module_bias.device, dtype=module_bias.dtype
-                        )
-                        slices = tuple(slice(0, dim) for dim in module_bias.shape)
-                        new_bias[slices] = module_bias
-                        expanded_module.bias.data.copy_(new_bias)
+                if module_bias is not None:
+                    expanded_module.bias.data.copy_(module_bias)
 
                 setattr(parent_module, current_module_name, expanded_module)
 
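On the lora_pipeline.py side, the bias-expansion branch keyed off `lora_B.bias` is dropped: only `in_features` grows when a Linear is expanded, so the bias shape stays the same and the module's existing bias can simply be copied over. A self-contained sketch of that expansion pattern, assuming illustrative shapes rather than the real pipeline wiring:

```python
import torch

# Illustrative sketch of the expansion logic in the hunk above (hypothetical shapes,
# not the pipeline code): widen a Linear's in_features, zero-pad the old weight into
# the new weight, and copy the bias unchanged since out_features stays the same.
module = torch.nn.Linear(64, 128, bias=True)
module_weight = module.weight.data  # shape (128, 64)
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

in_features, out_features = 384, 128  # pretend the LoRA A/B shapes demand a wider input

expanded_module = torch.nn.Linear(
    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
)

# Only weights are expanded and biases are not: the bias has shape (out_features,),
# which is unchanged by the expansion.
new_weight = torch.zeros_like(expanded_module.weight.data)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight
expanded_module.weight.data.copy_(new_weight)
if module_bias is not None:
    expanded_module.bias.data.copy_(module_bias)

assert expanded_module.weight.shape == (128, 384)
assert torch.equal(expanded_module.weight.data[:, :64], module_weight)
assert torch.equal(expanded_module.bias.data, module_bias)
```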