@@ -2031,18 +2031,36 @@ def lora_state_dict(
20312031 if is_kohya :
20322032 state_dict = _convert_kohya_flux_lora_to_diffusers (state_dict )
20332033 # Kohya already takes care of scaling the LoRA parameters with alpha.
2034- return (state_dict , None ) if return_alphas else state_dict
2034+ return cls ._prepare_outputs (
2035+ state_dict ,
2036+ metadata = metadata ,
2037+ alphas = None ,
2038+ return_alphas = return_alphas ,
2039+ return_metadata = return_lora_metadata ,
2040+ )
20352041
20362042 is_xlabs = any ("processor" in k for k in state_dict )
20372043 if is_xlabs :
20382044 state_dict = _convert_xlabs_flux_lora_to_diffusers (state_dict )
20392045 # xlabs doesn't use `alpha`.
2040- return (state_dict , None ) if return_alphas else state_dict
2046+ return cls ._prepare_outputs (
2047+ state_dict ,
2048+ metadata = metadata ,
2049+ alphas = None ,
2050+ return_alphas = return_alphas ,
2051+ return_metadata = return_lora_metadata ,
2052+ )
20412053
20422054 is_bfl_control = any ("query_norm.scale" in k for k in state_dict )
20432055 if is_bfl_control :
20442056 state_dict = _convert_bfl_flux_control_lora_to_diffusers (state_dict )
2045- return (state_dict , None ) if return_alphas else state_dict
2057+ return cls ._prepare_outputs (
2058+ state_dict ,
2059+ metadata = metadata ,
2060+ alphas = None ,
2061+ return_alphas = return_alphas ,
2062+ return_metadata = return_lora_metadata ,
2063+ )
20462064
20472065 # For state dicts like
20482066 # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -2061,12 +2079,13 @@ def lora_state_dict(
20612079 )
20622080
20632081 if return_alphas or return_lora_metadata :
2064- outputs = [state_dict ]
2065- if return_alphas :
2066- outputs .append (network_alphas )
2067- if return_lora_metadata :
2068- outputs .append (metadata )
2069- return tuple (outputs )
2082+ return cls ._prepare_outputs (
2083+ state_dict ,
2084+ metadata = metadata ,
2085+ alphas = network_alphas ,
2086+ return_alphas = return_alphas ,
2087+ return_metadata = return_lora_metadata ,
2088+ )
20702089 else :
20712090 return state_dict
20722091
@@ -2785,6 +2804,15 @@ def _get_weight_shape(weight: torch.Tensor):
27852804
27862805 raise ValueError ("Either `base_module` or `base_weight_param_name` must be provided." )
27872806
2807+ @staticmethod
2808+ def _prepare_outputs (state_dict , metadata , alphas = None , return_alphas = False , return_metadata = False ):
2809+ outputs = [state_dict ]
2810+ if return_alphas :
2811+ outputs .append (alphas )
2812+ if return_metadata :
2813+ outputs .append (metadata )
2814+ return tuple (outputs ) if (return_alphas or return_metadata ) else state_dict
2815+
27882816
27892817# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
27902818# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
0 commit comments