@@ -1652,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
16521652    _lora_loadable_modules  =  ["transformer" , "text_encoder" ]
16531653    transformer_name  =  TRANSFORMER_NAME 
16541654    text_encoder_name  =  TEXT_ENCODER_NAME 
1655+     _control_lora_supported_norm_keys  =  ["norm_q" , "norm_k" , "norm_added_q" , "norm_added_k" ]
16551656
16561657    @classmethod  
16571658    @validate_hf_hub_args  
@@ -1835,8 +1836,9 @@ def load_lora_weights(
18351836        has_lora_keys  =  any ("lora"  in  key  for  key  in  state_dict .keys ())
18361837
18371838        # Flux Control LoRAs also have norm keys 
1838-         supported_norm_keys  =  ["norm_q" , "norm_k" , "norm_added_q" , "norm_added_k" ]
1839-         has_norm_keys  =  any (norm_key  in  key  for  key  in  state_dict .keys () for  norm_key  in  supported_norm_keys )
1839+         has_norm_keys  =  any (
1840+             norm_key  in  key  for  key  in  state_dict .keys () for  norm_key  in  self ._control_lora_supported_norm_keys 
1841+         )
18401842
18411843        if  not  (has_lora_keys  or  has_norm_keys ):
18421844            raise  ValueError ("Invalid LoRA checkpoint." )
@@ -1847,7 +1849,7 @@ def load_lora_weights(
18471849        transformer_norm_state_dict  =  {
18481850            k : state_dict .pop (k )
18491851            for  k  in  list (state_dict .keys ())
1850-             if  "transformer."  in  k  and  any (norm_key  in  k  for  norm_key  in  supported_norm_keys )
1852+             if  "transformer."  in  k  and  any (norm_key  in  k  for  norm_key  in  self . _control_lora_supported_norm_keys )
18511853        }
18521854
18531855        transformer  =  getattr (self , self .transformer_name ) if  not  hasattr (self , "transformer" ) else  self .transformer 
@@ -1977,7 +1979,15 @@ def _load_norm_into_transformer(
19771979        )
19781980
19791981        # We can't load with strict=True because the current state_dict does not contain all the transformer keys 
1980-         transformer .load_state_dict (state_dict , strict = False )
1982+         incompatible_keys  =  transformer .load_state_dict (state_dict , strict = False )
1983+         unexpected_keys  =  getattr (incompatible_keys , "unexpected_keys" , None )
1984+ 
1985+         # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. 
1986+         if  unexpected_keys :
1987+             if  any (norm_key  in  k  for  k  in  unexpected_keys  for  norm_key  in  cls ._control_lora_supported_norm_keys ):
1988+                 raise  ValueError (
1989+                     f"Found { unexpected_keys }  
1990+                 )
19811991
19821992        return  overwritten_layers_state_dict 
19831993
0 commit comments