Commit b4f1cbf

control lora updates
1 parent 217e90c commit b4f1cbf

File tree

1 file changed: +125 -11 lines

src/diffusers/loaders/lora_pipeline.py

Lines changed: 125 additions & 11 deletions
@@ -1819,24 +1819,27 @@ def prune_state_dict_(state_dict):
             if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys)
         }
 
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        self._maybe_expand_transformer_param_shape_(
+            transformer, transformer_lora_state_dict, transformer_norm_state_dict
+        )
+        print(transformer)
+
         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
                 transformer_lora_state_dict,
                 network_alphas=network_alphas,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
+                transformer=transformer,
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
         if len(transformer_norm_state_dict) > 0:
-            self.load_norm_into_transformer(
+            self._transformer_norm_layers = self.load_norm_into_transformer(
                 transformer_norm_state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
+                transformer=transformer,
+                discard_original_layers=False,
             )
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
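
The hunk above resolves the transformer once, pre-expands any Linear layers whose shapes do not match the incoming LoRA, and records the returned norm layers on the pipeline. A minimal usage sketch, assuming this mixin backs a Flux pipeline and using a placeholder LoRA repo id (neither appears in this diff):

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    # "some-org/flux-control-lora" is a placeholder id. load_lora_weights() splits the checkpoint into
    # LoRA and norm entries, expands mismatched Linear shapes, then loads each group separately.
    pipe.load_lora_weights("some-org/flux-control-lora")
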
@@ -1899,10 +1902,41 @@ def load_lora_into_transformer(
     def load_norm_into_transformer(
         cls,
         state_dict,
-        transformer: torch.nn.Module,
-    ):
-        print(state_dict.keys())
-        transformer.load_state_dict(state_dict, strict=True)
+        transformer,
+        prefix=None,
+        discard_original_layers=False,
+    ) -> Dict[str, torch.Tensor]:
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+
+        # Find invalid keys
+        transformer_state_dict = transformer.state_dict()
+        transformer_keys = set(transformer_state_dict.keys())
+        state_dict_keys = set(state_dict.keys())
+        extra_keys = list(state_dict_keys - transformer_keys)
+        logger.warning(
+            f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+        )
+
+        for key in extra_keys:
+            state_dict.pop(key)
+
+        # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+        overwritten_layers = {}
+        if not discard_original_layers:
+            for key in state_dict.keys():
+                overwritten_layers[key] = transformer_state_dict[key]
+
+        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+        logger.info(
+            "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers."
+        )
+        transformer.load_state_dict(state_dict, strict=False)
+
+        return overwritten_layers
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
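
Taken on its own, the new `load_norm_into_transformer` strips the `transformer.` prefix, snapshots the layers it is about to overwrite, and loads with `strict=False`. A self-contained sketch of that pattern on a toy module (module, keys, and shapes are made up; the snapshot is cloned here so it survives the in-place load):

    import torch

    transformer = torch.nn.Sequential(torch.nn.LayerNorm(4), torch.nn.Linear(4, 4))
    norm_state_dict = {"transformer.0.weight": torch.ones(4), "transformer.0.bias": torch.zeros(4)}

    # Strip the prefix so keys line up with transformer.state_dict()
    for key in list(norm_state_dict.keys()):
        if key.split(".")[0] == "transformer":
            norm_state_dict[key.replace("transformer.", "")] = norm_state_dict.pop(key)

    # Snapshot the originals so they can be restored later (clone, since load_state_dict copies in place)
    overwritten_layers = {k: transformer.state_dict()[k].clone() for k in norm_state_dict}

    # strict=False because the state dict only covers the norm layers
    transformer.load_state_dict(norm_state_dict, strict=False)
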
@@ -2139,6 +2173,11 @@ def fuse_lora(
         pipeline.fuse_lora(lora_scale=0.7)
         ```
         """
+        if len(self._transformer_norm_layers.keys()) > 0:
+            logger.info(
+                "Normalization layers cannot be loaded without fusing. Calls to `.fuse_lora()` will only affect the actual LoRA layers."
+            )
+
         super().fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
@@ -2157,8 +2196,83 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
         """
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        transformer.load_state_dict(self._transformer_norm_layers)
+
         super().unfuse_lora(components=components)
 
+    @classmethod
+    def _maybe_expand_transformer_param_shape_(
+        cls,
+        transformer: torch.nn.Module,
+        lora_state_dict=None,
+        norm_state_dict=None,
+        prefix=None,
+    ):
+        state_dict = {}
+        if lora_state_dict is not None:
+            state_dict.update(lora_state_dict)
+        if norm_state_dict is not None:
+            state_dict.update(norm_state_dict)
+
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+
+        def get_submodule(module, name):
+            for part in name.split("."):
+                if len(name) == 0:
+                    break
+                if not hasattr(module, part):
+                    raise AttributeError(f"Submodule '{part}' not found in '{module}'.")
+                module = getattr(module, part)
+            return module
+
+        # Expand transformer parameter shapes if they don't match lora
+        for name, module in transformer.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module_weight = module.weight.data
+                module_bias = module.bias.data if hasattr(module, "bias") else None
+                bias = module_bias is not None
+                name_split = name.split(".")
+
+                lora_A_name = f"{name}.lora_A"
+                lora_B_name = f"{name}.lora_B"
+                lora_A_weight_name = f"{lora_A_name}.weight"
+                lora_B_weight_name = f"{lora_B_name}.weight"
+
+                if lora_A_weight_name not in state_dict.keys():
+                    continue
+
+                in_features = state_dict[lora_A_weight_name].shape[1]
+                out_features = state_dict[lora_B_weight_name].shape[0]
+
+                if tuple(module_weight.shape) == (out_features, in_features):
+                    continue
+
+                parent_module_name = ".".join(name_split[:-1])
+                current_module_name = name_split[-1]
+                parent_module = get_submodule(transformer, parent_module_name)
+
+                expanded_module = torch.nn.Linear(
+                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
+                )
+
+                new_weight = module_weight.new_zeros(expanded_module.weight.data.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                new_weight[slices] = module_weight
+                expanded_module.weight.data.copy_(new_weight)
+
+                if bias:
+                    new_bias = module_bias.new_zeros(expanded_module.bias.data.shape)
+                    slices = tuple(slice(0, dim) for dim in module_bias.shape)
+                    new_bias[slices] = module_bias
+                    expanded_module.bias.data.copy_(new_bias)
+
+                setattr(parent_module, current_module_name, expanded_module)
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
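
The zero-padding in `_maybe_expand_transformer_param_shape_` can be shown in isolation: when `lora_A`/`lora_B` imply a wider layer (for example, extra input channels carrying a control signal), the existing `nn.Linear` is swapped for a larger one whose top-left block holds the old weights and whose new entries stay zero. A standalone sketch with made-up shapes:

    import torch

    old = torch.nn.Linear(64, 128)        # existing transformer layer
    in_features, out_features = 128, 128  # shapes implied by lora_A.weight / lora_B.weight

    expanded = torch.nn.Linear(
        in_features, out_features, bias=old.bias is not None, device=old.weight.device, dtype=old.weight.dtype
    )

    # Copy the original weight into the top-left block; the newly added columns remain zero
    new_weight = old.weight.data.new_zeros(expanded.weight.shape)
    slices = tuple(slice(0, dim) for dim in old.weight.shape)
    new_weight[slices] = old.weight.data
    expanded.weight.data.copy_(new_weight)

    if old.bias is not None:
        new_bias = old.bias.data.new_zeros(expanded.bias.shape)
        new_bias[: old.bias.shape[0]] = old.bias.data
        expanded.bias.data.copy_(new_bias)
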
