Implement ensure_weight_tying for trainable_token_indices (#2864) #2870
Conversation
Thanks a lot for handling the update of weight tying for trainable tokens. What's there already looks quite good, but I wonder if we can simplify the implementation; please check my suggestions.
Regarding the tests, I wanted to map the tests you wrote onto the table from #2864; this is what I ended up with:
| weights tied | ensure_weight_tying | LoraConfig trainable_token_indices | result | test |
|---|---|---|---|---|
| False | False | `[1, 2, 3]` | trainable tokens on embeddings only | |
| False | True | `[1, 2, 3]` | warn & trainable tokens on embeddings only | `test_ensure_weight_tying_warns_when_model_not_tied_list_format` |
| True | False | `[1, 2, 3]` | tied trainable tokens | |
| True | True | `[1, 2, 3]` | tied trainable tokens | `test_ensure_weight_tying_with_single_layer` |
| False | False | `{"lm_head": [1,2], "embed_tokens": [1,2]}` | treat as separate | |
| False | True | `{"lm_head": [1,2], "embed_tokens": [1,2]}` | warn & treat as separate | |
| True | False | `{"lm_head": [1,2], "embed_tokens": [1,2]}` | tied trainable tokens | `test_weight_tying_bc_same_indices_applied` |
| True | True | `{"lm_head": [1,2], "embed_tokens": [1,2]}` | tied trainable tokens | `test_ensure_weight_tying_applied_with_same_indices` |
| False | False | `{"lm_head": [1,2], "embed_tokens": [3,4]}` | treat as separate | |
| False | True | `{"lm_head": [1,2], "embed_tokens": [3,4]}` | warn & treat as separate | |
| True | False | `{"lm_head": [1,2], "embed_tokens": [3,4]}` | *treat as separate | `test_weight_tying_bc_different_indices_treated_separately` |
| True | True | `{"lm_head": [1,2], "embed_tokens": [3,4]}` | *error | `test_ensure_weight_tying_errors_with_different_indices` |
Does this look right to you? I think it means there are still a few gaps in the tests; could you please provide the missing ones? Some tests could be combined via `pytest.mark.parametrize` if the expected outcomes are the same.
tests/test_trainable_tokens.py (outdated)

        ]
        assert warnings_found

    def test_ensure_weight_tying_warns_when_model_not_tied_dict_format(self, model_weight_untied, recwarn):
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test can be merged with `test_ensure_weight_tying_warns_when_model_not_tied_list_format` by parametrizing the `trainable_token_indices` argument.
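For illustration, a minimal sketch of what the merged test could look like. The fixture names `model_weight_untied` and `recwarn` and the expected warning text come from the surrounding diff; `target_modules` and the rest of the config are assumptions, and in the actual suite this would be a method on the test class:

```python
import pytest

from peft import LoraConfig, get_peft_model


@pytest.mark.parametrize(
    "trainable_token_indices",
    [
        [1, 2, 3],  # list format
        {"embed_tokens": [1, 2, 3]},  # dict format
    ],
)
def test_ensure_weight_tying_warns_when_model_not_tied(model_weight_untied, recwarn, trainable_token_indices):
    # ensure_weight_tying=True on a model without tied weights should only warn
    config = LoraConfig(
        target_modules=["q_proj"],
        trainable_token_indices=trainable_token_indices,
        ensure_weight_tying=True,
    )
    get_peft_model(model_weight_untied, config)

    expected = "ensure_weight_tying=True but the model does not have tied weights"
    warnings_list = [w.message.args[0] for w in recwarn]
    assert any(expected in msg for msg in warnings_list)
```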
resolved in 232c6e7
tests/test_trainable_tokens.py (outdated)

        warnings_list = [w.message.args[0] for w in recwarn]
        warnings_found = [
            msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
        ]
        assert warnings_found
I think it's a bit more elegant to do:

    expected = ...
    assert any(expected in msg for msg in warnings_list)
resolved in 232c6e7
        embedding_output = peft_embedding(x)
        assert (embedding_output == 0.0).all()

    # Tests for ensure_weight_tying parameter with trainable_token_indices
Let's mention #2864 here; I think it helps with understanding the tests.
Added! I've included the comment `# See #2864 for details on the expected behavior` at the beginning of the `ensure_weight_tying` test section. This helps readers understand the context and refer back to the original issue for the full specification.
tests/test_trainable_tokens.py (outdated)

            ensure_weight_tying=True,
        )

        with pytest.raises(ValueError) as e:
Let's use:

    msg = "Cannot ensure weight tying when different token indices are specified"
    with pytest.raises(ValueError, match=msg):
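(A small aside on this pattern: `pytest.raises(..., match=msg)` treats `msg` as a regular expression and applies it with `re.search`, which is fine here since the message contains no regex metacharacters; a message with parentheses or brackets would need `re.escape`.)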
src/peft/utils/other.py (outdated)

    ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

    # Check if we're dealing with dict format that specifies both embed_tokens and lm_head
    is_dict_format = isinstance(peft_config.trainable_token_indices, dict)
I don't think we need `is_dict_format`. The check below, `len(target_layers) > 1`, is already enough, is it not?
src/peft/utils/other.py (outdated)

    if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower):
        embed_key = key
    elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower):
        lm_head_key = key
I wonder if we overcomplicate things here. If there are multiple `target_layers`, can we not just compare them to the tied weights? Is it important to identify here which one is for the embedding and which one is for the LM head?
Below, you're using the names for the error message, which is a nice touch, but if we can refrain from guessing here, it would be worth making the error message more generic IMO.
src/peft/utils/other.py (outdated)

            indices_mismatch = True
        else:
            # Same indices - if weights are tied and we're applying tying, skip lm_head (it'll be tied later)
            if weights_tied and not (not ensure_weight_tying and False):  # Will apply tying
This check makes no sense to me: why `and False`?
Implement ensure_weight_tying for trainable_token_indices

Summary

This PR implements consistent weight tying behavior for `trainable_token_indices` as specified in issue #2864. It extends the `ensure_weight_tying` parameter (introduced in PR #2803) to work with `trainable_token_indices`, providing users explicit control over weight tying between embeddings and LM head.

Fixes #2864 (trainable_token_indices portion)
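For illustration, a minimal sketch of how a user opts in to the new behavior; `target_modules` and the layer names are placeholders that depend on the model architecture, and `ensure_weight_tying` assumes a PEFT version that includes PR #2803 and this PR:

```python
from peft import LoraConfig

config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    # dict format: token indices per targeted layer (module names depend on the model)
    trainable_token_indices={"embed_tokens": [0, 1, 2], "lm_head": [0, 1, 2]},
    # request that the trainable tokens on the embedding and the LM head stay tied
    ensure_weight_tying=True,
)
```

Passed to `get_peft_model` on a model with `tie_word_embeddings=True`, this results in tied trainable tokens; on a model without tied weights it warns and falls back to treating the layers separately.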
Problem Statement

Background

PEFT models sometimes need to handle tied weights between embedding layers and LM head layers (when `tie_word_embeddings=True`). The `ensure_weight_tying` parameter was introduced in PR #2803 to give users explicit control over this behavior for `modules_to_save`. However, the same control was missing for `trainable_token_indices`.

The Issue

Issue #2864 identified that the weight tying behavior for `trainable_token_indices` was not consistent across different scenarios. Specifically, there were four cases that needed to be implemented; they are listed under "Four Cases Implemented" below.

Solution Approach
Implementation Strategy: extend the existing `ensure_weight_tying` handling from PR #2803 to the `trainable_token_indices` code path in `src/peft/utils/other.py`, covering the four cases from #2864 while keeping the default behavior unchanged.
Changes Made

1. Updated Configuration Documentation

File: `src/peft/tuners/lora/config.py`

Updated the `ensure_weight_tying` parameter docstring to clarify that it now applies to both `modules_to_save` and `trainable_token_indices`, making the documentation consistent with the implementation.
2. Implemented Weight Tying Logic

File: `src/peft/utils/other.py`

Added comprehensive logic within the existing `trainable_token_indices` handling block.

Key Components: the code reads the `ensure_weight_tying` flag from the config (defaulting to `False`) and the configured token indices, then decides whether to warn, raise, tie, or keep the layers separate.

Four Cases Implemented:

- Case 1 - Warning for Untied Models: `weights_tied=False` + `ensure_weight_tying=True`
- Case 2 - Error for Contradictory Configuration: `weights_tied=True` + `ensure_weight_tying=True` + different indices
- Case 3 - Backwards Compatibility: `weights_tied=True` + `ensure_weight_tying=False` + different indices
- Case 4 - Apply Tying: `weights_tied=True` + matching indices (or the list format)
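To make the four cases concrete, here is a schematic restatement of the decision logic (a simplified sketch, not the actual code in `src/peft/utils/other.py`; `same_indices` stands for "all targeted layers were given identical token indices, or the list format was used"):

```python
import warnings


def resolve_weight_tying(weights_tied: bool, ensure_weight_tying: bool, same_indices: bool) -> bool:
    """Return True when the trainable token adapters should be tied."""
    if not weights_tied:
        if ensure_weight_tying:
            # Case 1: tying was requested but the model has no tied weights
            warnings.warn("ensure_weight_tying=True but the model does not have tied weights")
        return False  # nothing to tie, layers stay separate
    if not same_indices:
        if ensure_weight_tying:
            # Case 2: contradictory configuration
            raise ValueError("Cannot ensure weight tying when different token indices are specified")
        # Case 3: backwards compatibility, treat the layers separately
        return False
    # Case 4: tied weights and matching indices, apply tying
    return True
```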
3. Comprehensive Test Suite

File: `tests/test_trainable_tokens.py`

Added 7 new test methods covering all scenarios.

Test Coverage:

- `test_ensure_weight_tying_warns_when_model_not_tied_list_format`: Verifies warning for list format
- `test_ensure_weight_tying_warns_when_model_not_tied_dict_format`: Verifies warning for dict format
- `test_weight_tying_bc_different_indices_treated_separately`: Verifies backwards compatibility
- `test_ensure_weight_tying_errors_with_different_indices`: Verifies error for contradictory config
- `test_ensure_weight_tying_applied_with_same_indices`: Verifies tying with same indices
- `test_weight_tying_bc_same_indices_applied`: Verifies BC for same indices
- `test_ensure_weight_tying_with_single_layer`: Verifies list format tying
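As a sketch of the error case (Case 2), the corresponding test presumably looks roughly like this; the fixture name `model_weight_tied` and `target_modules` are assumptions, and in the real suite this is a method on the test class:

```python
import pytest

from peft import LoraConfig, get_peft_model


def test_ensure_weight_tying_errors_with_different_indices(model_weight_tied):
    # contradictory configuration: tying requested, but different indices per layer
    config = LoraConfig(
        target_modules=["q_proj"],
        trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [3, 4]},
        ensure_weight_tying=True,
    )
    msg = "Cannot ensure weight tying when different token indices are specified"
    with pytest.raises(ValueError, match=msg):
        get_peft_model(model_weight_tied, config)
```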
Testing Results

New Tests

All 7 new tests pass successfully:

- `test_ensure_weight_tying_warns_when_model_not_tied_list_format`
- `test_ensure_weight_tying_warns_when_model_not_tied_dict_format`
- `test_weight_tying_bc_different_indices_treated_separately`
- `test_ensure_weight_tying_errors_with_different_indices`
- `test_ensure_weight_tying_applied_with_same_indices`
- `test_weight_tying_bc_same_indices_applied`
- `test_ensure_weight_tying_with_single_layer`

Backwards Compatibility
This implementation maintains full backwards compatibility:
✅ Default Behavior Unchanged: `ensure_weight_tying` defaults to `False`, preserving existing behavior

✅ No Breaking Changes: Existing code continues to work without modification

✅ Opt-in Enhancement: Users must explicitly set `ensure_weight_tying=True` to use new features

✅ BC Mode Preserved: When `ensure_weight_tying=False`, existing automatic tying still works for compatible configurations

Screenshots
Checklist
cc: @BenjaminBossan