37 | 37 |     IA3Config, |
38 | 38 |     LoHaConfig, |
39 | 39 |     LoraConfig, |
| 40 | +     PeftModel, |
40 | 41 |     PromptTuningConfig, |
41 | 42 |     VeraConfig, |
42 | 43 |     get_layer_status, |
@@ -1502,6 +1503,36 @@ def __init__(self): |
1502 | 1503 |         # target modules should *not* be simplified to "query" as that would match "single_transformers_blocks" too |
1503 | 1504 |         assert model.peft_config["default"].target_modules != {"query"} |
1504 | 1505 |
| 1506 | +     def test_find_minimal_target_modules_does_not_error_with_ia3(self, tmp_path): |
| 1507 | +         # See #2429 |
| 1508 | +         # There is an issue with the compression of the target_modules attribute when using IA³. There, we additionally |
| 1509 | +         # have the feedforward_modules attribute, which must be a subset of target_modules. When target_modules is shrunk, |
| 1510 | +         # the subset check will fail. This test ensures that this doesn't happen. |
| 1511 | +         n_layers = MIN_TARGET_MODULES_FOR_OPTIMIZATION + 1 |
| 1512 | + |
| 1513 | +         class InnerModule(nn.Module): |
| 1514 | +             def __init__(self): |
| 1515 | +                 super().__init__() |
| 1516 | +                 self.query = nn.Linear(10, 10) |
| 1517 | + |
| 1518 | +         class OuterModule(nn.Module): |
| 1519 | +             def __init__(self): |
| 1520 | +                 super().__init__() |
| 1521 | +                 self.blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)]) |
| 1522 | + |
| 1523 | +         target_modules = [f"blocks.{i}.query" for i in range(n_layers)] |
| 1524 | +         feedforward_modules = [f"blocks.{i}.query" for i in range(n_layers)] |
| 1525 | +         # the subset check happens here |
| 1526 | +         config = IA3Config(target_modules=target_modules, feedforward_modules=feedforward_modules) |
| 1527 | +         # the optimization step happens here, after the subset check, so at first we're fine, but we will run into an |
| 1528 | +         # issue after a save/load roundtrip |
| 1529 | +         model = get_peft_model(OuterModule(), config) |
| 1530 | +         model.save_pretrained(tmp_path) |
| 1531 | +         del model |
| 1532 | + |
| 1533 | +         # does not raise |
| 1534 | +         PeftModel.from_pretrained(OuterModule(), tmp_path) |
| 1535 | + |
1505 | 1536 |
1506 | 1537 | class TestRankAndAlphaPattern: |
1507 | 1538 |     @pytest.fixture |
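As background for the comment in the new test: the optimization can compress the stored target_modules down to a minimal set of name suffixes, while feedforward_modules keeps the fully qualified module names, so a literal subset check between the two no longer passes after a save/load roundtrip. The snippet below is only an illustrative sketch of that mismatch in plain Python; the concrete values ({"query"}, three blocks) and variable names are made up for illustration, and it does not call PEFT's actual validation code.

```python
# Illustrative sketch only (not code from this commit): why shrinking target_modules can
# conflict with IA³'s requirement that feedforward_modules is a subset of target_modules.
# All concrete values below are hypothetical.
target_modules_compressed = {"query"}  # minimal suffix form that may be stored after optimization
feedforward_modules = [f"blocks.{i}.query" for i in range(3)]  # still spelled out per layer

# A literal subset check between the two now fails ...
assert not set(feedforward_modules).issubset(target_modules_compressed)

# ... even though every feedforward module is still covered by suffix-style matching,
# i.e. each fully qualified name ends with one of the compressed targets.
covered = all(
    any(name == t or name.endswith("." + t) for t in target_modules_compressed)
    for name in feedforward_modules
)
assert covered
```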