@@ -2312,7 +2312,6 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
-
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2332,54 +2331,52 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     continue
 
                 module_out_features, module_in_features = module_weight.shape
-                if out_features < module_out_features or in_features < module_in_features:
-                    raise NotImplementedError(
-                        f"Only LoRAs with input/output features higher than the current module's input/output features "
-                        f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
-                        f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
-                        f"this please open an issue at https://github.com/huggingface/diffusers/issues."
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
                     )
-
-                debug_message = (
-                    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
-                    f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}"
-                )
-                if module_out_features != out_features:
+                if out_features > module_out_features:
                     debug_message += (
                         ", and the number of output features will be "
                         f"expanded from {module_out_features} to {out_features}."
                     )
                 else:
                     debug_message += "."
-                logger.debug(debug_message)
+                if debug_message:
+                    logger.debug(debug_message)
 
-                has_param_with_shape_update = True
-                parent_module_name, _, current_module_name = name.rpartition(".")
-                parent_module = transformer.get_submodule(parent_module_name)
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
 
-                # TODO: consider initializing this under meta device for optims.
-                expanded_module = torch.nn.Linear(
-                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
-                )
-                # Only weights are expanded and biases are not.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
-                )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
-                new_weight[slices] = module_weight
-                expanded_module.weight.data.copy_(new_weight)
-                if module_bias is not None:
-                    expanded_module.bias.data.copy_(module_bias)
-
-                setattr(parent_module, current_module_name, expanded_module)
-
-                if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
-                    attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
-                    new_value = int(expanded_module.weight.data.shape[1])
-                    old_value = getattr(transformer.config, attribute_name)
-                    setattr(transformer.config, attribute_name, new_value)
-                    logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
+                    # TODO: consider initializing this under meta device for optims.
+                    expanded_module = torch.nn.Linear(
+                        in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    # Only weights are expanded and biases are not.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    expanded_module.weight.data.copy_(new_weight)
+                    if module_bias is not None:
+                        expanded_module.bias.data.copy_(module_bias)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
 
         return has_param_with_shape_update
 
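For context, the hunk above swaps an undersized nn.Linear for a larger one whose weight is zero-padded from the original. A minimal standalone sketch of that expansion technique, under the same assumptions as the diff (the expand_linear helper name is hypothetical, not part of this change):

import torch

def expand_linear(module: torch.nn.Linear, in_features: int, out_features: int) -> torch.nn.Linear:
    # Hypothetical helper mirroring the expansion in the diff: build a larger
    # Linear and copy the original weight into its top-left corner, leaving
    # the padded rows/columns as zeros. Biases are copied, not expanded, so
    # this assumes out_features matches when a bias is present.
    weight = module.weight.data
    bias = module.bias is not None
    expanded = torch.nn.Linear(in_features, out_features, bias=bias, device=weight.device, dtype=weight.dtype)
    new_weight = torch.zeros_like(expanded.weight.data)
    slices = tuple(slice(0, dim) for dim in weight.shape)
    new_weight[slices] = weight
    expanded.weight.data.copy_(new_weight)
    if bias:
        expanded.bias.data.copy_(module.bias.data)
    return expanded

# Usage: the zero columns ignore the padded inputs, so outputs match the
# original module on the original input slice.
small = torch.nn.Linear(4, 2)
big = expand_linear(small, in_features=8, out_features=2)
x = torch.randn(1, 4)
assert torch.allclose(small(x), big(torch.nn.functional.pad(x, (0, 4))), atol=1e-6)
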
@@ -2405,10 +2402,14 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    "We currently don't support loading LoRAs for this use case. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+                )
 
         if expanded_module_names:
             logger.info(
-                f"Found some LoRA modules for which the weights were expanded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+                f"Found some LoRA modules for which the weights were zero-padded: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
             )
         return lora_state_dict
 
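The second hunk applies the complementary fix on the state-dict side: instead of touching the model, it pads each lora_A matrix along its input dimension, and raises for the unsupported shrink case. A minimal standalone sketch of that zero-padding (the pad_lora_A name is illustrative, not from this change):

import torch

def pad_lora_A(lora_A: torch.Tensor, target_in_features: int) -> torch.Tensor:
    # Hypothetical helper mirroring the state-dict expansion above:
    # lora_A has shape (rank, in_features); grow the input dimension with zeros.
    rank, in_features = lora_A.shape
    if in_features == target_in_features:
        return lora_A
    if in_features > target_in_features:
        # Mirrors the NotImplementedError branch added in the diff.
        raise NotImplementedError("Shrinking lora_A is not supported.")
    expanded = torch.zeros(rank, target_in_features, device=lora_A.device, dtype=lora_A.dtype)
    expanded[:, :in_features].copy_(lora_A)
    return expanded

# Usage: a rank-4 lora_A trained against a 64-wide layer, padded to suit a
# 128-wide base layer; the original values occupy the leading columns.
lora_A = torch.randn(4, 64)
padded = pad_lora_A(lora_A, target_in_features=128)
assert padded.shape == (4, 128) and torch.equal(padded[:, :64], lora_A)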