
Commit 258a398

Commit message: working.
1 parent d3e177c commit 258a398

File tree

1 file changed: +42 -41 lines

src/diffusers/loaders/lora_pipeline.py

Lines changed: 42 additions & 41 deletions
@@ -2312,7 +2312,6 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
-
        for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2332,54 +2331,52 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     continue
 
                 module_out_features, module_in_features = module_weight.shape
-                if out_features < module_out_features or in_features < module_in_features:
-                    raise NotImplementedError(
-                        f"Only LoRAs with input/output features higher than the current module's input/output features "
-                        f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
-                        f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
-                        f"this please open an issue at https://github.com/huggingface/diffusers/issues."
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
                     )
-
-                debug_message = (
-                    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
-                    f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}"
-                )
-                if module_out_features != out_features:
+                if out_features > module_out_features:
                     debug_message += (
                         ", and the number of output features will be "
                         f"expanded from {module_out_features} to {out_features}."
                     )
                 else:
                     debug_message += "."
-                logger.debug(debug_message)
+                if debug_message:
+                    logger.debug(debug_message)
 
-                has_param_with_shape_update = True
-                parent_module_name, _, current_module_name = name.rpartition(".")
-                parent_module = transformer.get_submodule(parent_module_name)
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
 
-                # TODO: consider initializing this under meta device for optims.
-                expanded_module = torch.nn.Linear(
-                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
-                )
-                # Only weights are expanded and biases are not.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
-                )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
-                new_weight[slices] = module_weight
-                expanded_module.weight.data.copy_(new_weight)
-                if module_bias is not None:
-                    expanded_module.bias.data.copy_(module_bias)
-
-                setattr(parent_module, current_module_name, expanded_module)
-
-                if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
-                    attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
-                    new_value = int(expanded_module.weight.data.shape[1])
-                    old_value = getattr(transformer.config, attribute_name)
-                    setattr(transformer.config, attribute_name, new_value)
-                    logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
+                    # TODO: consider initializing this under meta device for optims.
+                    expanded_module = torch.nn.Linear(
+                        in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    # Only weights are expanded and biases are not.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    expanded_module.weight.data.copy_(new_weight)
+                    if module_bias is not None:
+                        expanded_module.bias.data.copy_(module_bias)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
 
         return has_param_with_shape_update
 
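With this change, a Linear module is only rebuilt when the LoRA checkpoint is wider than the base layer, and the new in_features/out_features comparison gates the whole expansion block. Below is a minimal standalone sketch of that zero-padded expansion, assuming a hypothetical helper name (this is not the diffusers API). The original weight is copied into the top-left corner of a larger, zero-initialized nn.Linear; the bias handling here is a simplification, since the diff copies the bias directly.

import torch

# Hypothetical helper illustrating the zero-padded expansion in the hunk above;
# the function name and the bias padding are assumptions, not the library API.
def expand_linear(module: torch.nn.Linear, in_features: int, out_features: int) -> torch.nn.Linear:
    old_weight = module.weight.data
    expanded = torch.nn.Linear(
        in_features,
        out_features,
        bias=module.bias is not None,
        device=old_weight.device,
        dtype=old_weight.dtype,
    )
    # Copy the old weight into the top-left corner; the remaining entries stay zero.
    new_weight = torch.zeros_like(expanded.weight.data)
    slices = tuple(slice(0, dim) for dim in old_weight.shape)
    new_weight[slices] = old_weight
    expanded.weight.data.copy_(new_weight)
    if module.bias is not None:
        # Zero-pad the bias the same way (simplification of the diff, which copies it as-is).
        expanded.bias.data.zero_()
        expanded.bias.data[: module.out_features].copy_(module.bias.data)
    return expanded

# Example: widen a 4 -> 8 layer so it accepts 16 input features.
layer = torch.nn.Linear(4, 8)
wider = expand_linear(layer, in_features=16, out_features=8)
assert torch.equal(wider.weight.data[:, :4], layer.weight.data)
assert wider.weight.data[:, 4:].abs().sum() == 0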
@@ -2405,10 +2402,14 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    "We currently don't support loading LoRAs for this use case. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+                )
 
         if expanded_module_names:
             logger.info(
-                f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+                f"Found some LoRA modules for which the weights were zero-padded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
             )
         return lora_state_dict
 
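The log message in _maybe_expand_lora_state_dict now says "zero-padded" instead of "expanded", which matches what actually happens to lora_A. A small sketch of that padding under assumed shapes (the sizes are made up; only the slicing mirrors the diff):

import torch

# Assumed sizes for illustration only.
rank, lora_in_features, base_in_features = 16, 64, 128
lora_A_param = torch.randn(rank, lora_in_features)

# Zero-pad lora_A along its input dimension to match the (expanded) base layer.
expanded_state_dict_weight = torch.zeros(rank, base_in_features, dtype=lora_A_param.dtype)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
assert expanded_state_dict_weight.shape == (rank, base_in_features)

# The reverse case (LoRA wider than the base weight) is now rejected with
# NotImplementedError, as shown in the hunk above.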
