Commit 04f2ff0

use state dict when setting up LoRA.
1 parent 3d2f8ae commit 04f2ff0

4 files changed (+5, -41 lines)

setup.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@
     "librosa",
     "numpy",
     "parameterized",
-    "peft>=0.15.0",
+    "peft>=0.16.1",
     "protobuf>=3.20.3,<4",
     "pytest",
     "pytest-timeout",

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
     "librosa": "librosa",
     "numpy": "numpy",
     "parameterized": "parameterized",
-    "peft": "peft>=0.15.0",
+    "peft": "peft>=0.16.1",
     "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",

src/diffusers/loaders/peft.py

Lines changed: 3 additions & 1 deletion
@@ -319,7 +319,9 @@ def map_state_dict_for_hotswap(sd):
                 # it to None
                 incompatible_keys = None
             else:
-                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                inject_adapter_in_model(
+                    lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+                )
                 incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
             if self._prepare_lora_hotswap_kwargs is not None:
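
Per this commit, peft>=0.16.1 lets `inject_adapter_in_model` receive the LoRA `state_dict`, so peft can align the injected adapter modules with the checkpoint keys itself instead of diffusers pre-computing an `exclude_modules` list (that derivation is removed in src/diffusers/utils/peft_utils.py below). A minimal sketch of the pattern, using a placeholder model and made-up checkpoint keys rather than anything from the commit:

import torch
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

class TinyModel(torch.nn.Module):
    # Stand-in for a diffusers model; real callers pass e.g. a transformer.
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)
        self.to_k = torch.nn.Linear(8, 8)

model = TinyModel()
lora_config = LoraConfig(r=4, target_modules=["to_q", "to_k"])  # illustrative config

# Hypothetical LoRA checkpoint that only covers `to_q` (lora_A: r x in, lora_B: out x r).
state_dict = {
    "to_q.lora_A.weight": torch.zeros(4, 8),
    "to_q.lora_B.weight": torch.zeros(8, 4),
}

# Forwarding the state dict here is the point of the change: peft sees which
# modules the checkpoint actually contains while injecting the adapter.
inject_adapter_in_model(lora_config, model, adapter_name="default", state_dict=state_dict)
incompatible_keys = set_peft_model_state_dict(model, state_dict, adapter_name="default")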

src/diffusers/utils/peft_utils.py

Lines changed: 0 additions & 38 deletions
@@ -197,20 +197,6 @@ def get_peft_kwargs(
         "lora_bias": lora_bias,
     }
 
-    # Example: try load FusionX LoRA into Wan VACE
-    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
-    if exclude_modules:
-        if not is_peft_version(">=", "0.14.0"):
-            msg = """
-It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
-version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
-peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
-https://github.com/huggingface/diffusers/issues/new
-"""
-            logger.debug(msg)
-        else:
-            lora_config_kwargs.update({"exclude_modules": exclude_modules})
-
     return lora_config_kwargs
 
 
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
     if warn_msg:
         logger.warning(warn_msg)
-
-
-def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
-    """
-    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
-    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
-    doesn't exist in `peft_state_dict`.
-    """
-    if model_state_dict is None:
-        return
-    all_modules = set()
-    string_to_replace = f"{adapter_name}." if adapter_name else ""
-
-    for name in model_state_dict.keys():
-        if string_to_replace:
-            name = name.replace(string_to_replace, "")
-        if "." in name:
-            module_name = name.rsplit(".", 1)[0]
-            all_modules.add(module_name)
-
-    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
-    exclude_modules = list(all_modules - target_modules_set)
-
-    return exclude_modules
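
For reference, the deleted helper's exclusion logic boils down to a set difference between module names derived from the model's state dict and module names covered by the incoming LoRA state dict. A toy illustration with made-up keys (not taken from the commit):

model_state_dict = {
    "blocks.0.attn.to_q.weight": None,
    "blocks.0.attn.to_k.weight": None,
    "blocks.1.ffn.proj.weight": None,
}
peft_state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": None,
    "blocks.0.attn.to_k.lora_A.weight": None,
}

# Module names present in the model but not touched by the LoRA checkpoint.
all_modules = {name.rsplit(".", 1)[0] for name in model_state_dict}
target_modules = {name.split(".lora")[0] for name in peft_state_dict}
print(sorted(all_modules - target_modules))  # ['blocks.1.ffn.proj']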
