diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 51994b456f..405277b6b5 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -461,7 +461,9 @@ def inject_adapter( and len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION ): names_no_target = [ - name for name in key_list if not any(name.endswith(suffix) for suffix in peft_config.target_modules) + name + for name in key_list + if not any((name == suffix) or name.endswith("." + suffix) for suffix in peft_config.target_modules) ] new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target) if len(new_target_modules) < len(peft_config.target_modules): diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index 5e742f3c88..06a47deb26 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -1400,3 +1400,45 @@ def test_suffix_is_substring_of_other_suffix(self): expected = {"time_emb_proj", "proj", "proj_out"} result = find_minimal_target_modules(target_modules, other_module_names) assert result == expected + + def test_get_peft_modules_module_name_is_suffix_of_another_module(self): + # Solves the following bug: + # https://github.com/huggingface/diffusers/pull/9622#issuecomment-2404789721 + + # The cause for the bug is as follows: When we have, say, a module called "bar.0.query" that we want to target + # and another module called "foo_bar.0.query" that we don't want to target, there was potential for an error. + # This is not caused by _find_minimal_target_modules directly, but rather the bug was inside of + # BaseTuner.inject_adapter and how the names_no_target were chosen. Those used to be chosen based on suffix. In + # our example, however, "bar.0.query" is a suffix of "foo_bar.0.query", therefore "foo_bar.0.query" was *not* + # added to names_no_target when it should have. As a consequence, during the optimization, it looks like "query" + # is safe to use as target_modules because we don't see that it wrongly matches "foo_bar.0.query". + + # ensure that we have sufficiently many modules to trigger the optimization + n_layers = MIN_TARGET_MODULES_FOR_OPTIMIZATION + 1 + + class InnerModule(nn.Module): + def __init__(self): + super().__init__() + self.query = nn.Linear(10, 10) + + class OuterModule(nn.Module): + def __init__(self): + super().__init__() + # note that "transformer_blocks" is a suffix of "single_transformer_blocks" + self.transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)]) + self.single_transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)]) + + # we want to match all "transformer_blocks" layers but not "single_transformer_blocks" + target_modules = [f"transformer_blocks.{i}.query" for i in range(n_layers)] + model = get_peft_model(OuterModule(), LoraConfig(target_modules=target_modules)) + + # sanity check: we should have n_layers PEFT layers in model.transformer_blocks + transformer_blocks = model.base_model.model.transformer_blocks + assert sum(isinstance(module, BaseTunerLayer) for module in transformer_blocks.modules()) == n_layers + + # we should not have any PEFT layers in model.single_transformer_blocks + single_transformer_blocks = model.base_model.model.single_transformer_blocks + assert not any(isinstance(module, BaseTunerLayer) for module in single_transformer_blocks.modules()) + + # target modules should *not* be simplified to "query" as that would match "single_transformers_blocks" too + assert model.peft_config["default"].target_modules != {"query"}