@@ -1863,6 +1863,9 @@ def load_lora_weights(
18631863 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
18641864 "To get a comprehensive list of parameter names that were modified, enable debug logging."
18651865 )
1866+ transformer_lora_state_dict = self ._maybe_expand_lora_state_dict (
1867+ transformer = transformer , lora_state_dict = transformer_lora_state_dict
1868+ )
18661869
18671870 if len (transformer_lora_state_dict ) > 0 :
18681871 self .load_lora_into_transformer (
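
Note: the `_maybe_expand_lora_state_dict` call added above (defined in the last hunk of this diff) zero-pads the checkpoint's `lora_A` matrices along the input dimension whenever the corresponding transformer layer has been expanded. A minimal sketch, not part of the patch and with illustrative names, of why that padding is lossless: the padded input features only ever multiply zero weights.

    import torch

    rank, in_features, in_features_expanded = 4, 8, 12
    lora_A = torch.randn(rank, in_features)

    # Zero-pad the input dimension, mirroring what the patch does.
    lora_A_expanded = torch.zeros(rank, in_features_expanded)
    lora_A_expanded[:, :in_features].copy_(lora_A)

    x = torch.randn(2, in_features_expanded)
    # The extra input features hit zero weights, so the LoRA projection
    # is unchanged on the original feature slice.
    assert torch.allclose(x @ lora_A_expanded.T, x[:, :in_features] @ lora_A.T)
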
@@ -2309,16 +2312,17 @@ def _maybe_expand_transformer_param_shape_or_error_(

         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
-
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None

-                lora_A_weight_name = f"{name}.lora_A.weight"
-                lora_B_weight_name = f"{name}.lora_B.weight"
-                if lora_A_weight_name not in state_dict.keys():
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
                     continue

                 in_features = state_dict[lora_A_weight_name].shape[1]
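
Note: the `lora_base_name` change above accounts for PEFT wrapping. Once a LoRA has already been loaded, PEFT replaces each target `nn.Linear` with a wrapper whose original layer lives under `.base_layer`, so `named_modules()` yields names like `...to_q.base_layer` while checkpoint keys keep the bare `...to_q` path. A tiny sketch of the key remapping (module and key names here are illustrative, not from the patch):

    # Module name as reported by named_modules() after PEFT wrapping.
    name = "single_transformer_blocks.0.attn.to_q.base_layer"
    # LoRA checkpoints key their weights by the unwrapped module path.
    state_dict = {"single_transformer_blocks.0.attn.to_q.lora_A.weight": None}

    is_peft_loaded = True
    lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
    assert f"{lora_base_name}.lora_A.weight" in state_dict
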
@@ -2329,56 +2333,105 @@ def _maybe_expand_transformer_param_shape_or_error_(
                     continue

                 module_out_features, module_in_features = module_weight.shape
-                if out_features < module_out_features or in_features < module_in_features:
-                    raise NotImplementedError(
-                        f"Only LoRAs with input/output features higher than the current module's input/output features "
-                        f"are currently supported. The provided LoRA contains {in_features=} and {out_features=}, which "
-                        f"are lower than {module_in_features=} and {module_out_features=}. If you require support for "
-                        f"this please open an issue at https://github.com/huggingface/diffusers/issues. "
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
                     )
-
-                debug_message = (
-                    f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
-                    f"checkpoint contains higher number of features than expected. The number of input_features will be "
-                    f"expanded from {module_in_features} to {in_features}"
-                )
-                if module_out_features != out_features:
+                if out_features > module_out_features:
                     debug_message += (
                         ", and the number of output features will be "
                         f"expanded from {module_out_features} to {out_features}."
                     )
                 else:
                     debug_message += "."
-                logger.debug(debug_message)
+                if debug_message:
+                    logger.debug(debug_message)
+
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    with torch.device("meta"):
+                        expanded_module = torch.nn.Linear(
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        )
+                    # Only weights are expanded and biases are not. This is because only the input dimensions
+                    # are changed while the output dimensions remain the same. The shape of the weight tensor
+                    # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+                    # explains the reason why only weights are expanded.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    tmp_state_dict = {"weight": new_weight}
+                    if module_bias is not None:
+                        tmp_state_dict["bias"] = module_bias
+                    expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )

-                has_param_with_shape_update = True
-                parent_module_name, _, current_module_name = name.rpartition(".")
-                parent_module = transformer.get_submodule(parent_module_name)
+        return has_param_with_shape_update

-                # TODO: consider initializing this under meta device for optims.
-                expanded_module = torch.nn.Linear(
-                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
-                )
-                # Only weights are expanded and biases are not.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+    @classmethod
+    def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        expanded_module_names = set()
+        transformer_state_dict = transformer.state_dict()
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
+            base_param_name = (
+                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+            )
+            base_weight_param = transformer_state_dict[base_param_name]
+            lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+                lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+                expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
-                new_weight[slices] = module_weight
-                expanded_module.weight.data.copy_(new_weight)
-                if module_bias is not None:
-                    expanded_module.bias.data.copy_(module_bias)
-
-                setattr(parent_module, current_module_name, expanded_module)

-                if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
-                    attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
-                    new_value = int(expanded_module.weight.data.shape[1])
-                    old_value = getattr(transformer.config, attribute_name)
-                    setattr(transformer.config, attribute_name, new_value)
-                    logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")
+        if expanded_module_names:
+            logger.info(
+                f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )

-        return has_param_with_shape_update
+        return lora_state_dict


 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
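
Note: the rewritten expansion path above builds the larger `nn.Linear` on the meta device and materializes it with `load_state_dict(..., assign=True)`, so no throwaway weight tensor is allocated before the zero-padded copy is swapped in (`assign=True` requires PyTorch 2.1+). A standalone sketch of the same pattern, not part of the patch and with illustrative sizes, that also checks the expansion is output-preserving on the original input slice:

    import torch

    old = torch.nn.Linear(8, 4, bias=True)
    in_features, out_features = 12, 4  # only the input dimension grows

    with torch.device("meta"):  # parameters are allocated as meta tensors (no real memory)
        expanded = torch.nn.Linear(in_features, out_features, bias=True, dtype=old.weight.dtype)

    # Zero-pad the weight; the bias shape (out_features,) is unchanged.
    new_weight = torch.zeros(out_features, in_features, dtype=old.weight.dtype)
    new_weight[:, : old.in_features] = old.weight.data
    # assign=True replaces the meta tensors with these real ones instead of copying into them.
    expanded.load_state_dict({"weight": new_weight, "bias": old.bias.data}, strict=True, assign=True)

    x = torch.randn(2, in_features)
    # Padded input features multiply zero weights, so outputs are preserved.
    assert torch.allclose(expanded(x), old(x[:, :8]))
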