
Conversation

@sambhavnoobcoder (Contributor)

Implement ensure_weight_tying for trainable_token_indices

Summary

This PR implements consistent weight tying behavior for trainable_token_indices as specified in issue #2864. It extends the ensure_weight_tying parameter (introduced in PR #2803) to work with trainable_token_indices, providing users explicit control over weight tying between embeddings and LM head.

Fixes #2864 (trainable_token_indices portion)


Problem Statement

Background

PEFT models sometimes need to handle tied weights between embedding layers and LM head layers (when tie_word_embeddings=True). The ensure_weight_tying parameter was introduced in PR #2803 to give users explicit control over this behavior for modules_to_save. However, the same control was missing for trainable_token_indices.
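For context, this is roughly what tied weights look like on a stock transformers model; GPT-2 is used here purely as an illustration:

```python
from transformers import AutoModelForCausalLM

# GPT-2 ties its LM head to the input embedding matrix (tie_word_embeddings=True).
model = AutoModelForCausalLM.from_pretrained("gpt2")
print(model.config.tie_word_embeddings)                              # True
print(model.lm_head.weight is model.get_input_embeddings().weight)   # True: same tensor object
```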

The Issue

Issue #2864 identified that the weight tying behavior for trainable_token_indices was not consistent across scenarios. Specifically, four cases needed to be handled:

  1. Untied model with ensure_weight_tying=True: Should warn users that weight tying cannot be applied
  2. Tied model with ensure_weight_tying=True and different indices: Should error, as it's impossible to tie adapters with different token indices
  3. Tied model with ensure_weight_tying=False and different indices: Should treat layers as separate (backwards compatibility behavior)
  4. Tied model with ensure_weight_tying=True and same indices: Should apply weight tying correctly

Solution Approach

Implementation Strategy:

  1. Check weight tying configuration early (before creating wrappers)
  2. Detect whether the user specified both the embedding and lm_head layers in dict format
  3. Check if their token indices match or differ
  4. Apply appropriate logic based on the configuration matrix from the issue
  5. Skip creating wrappers for layers that will be tied later

Changes Made

1. Updated Configuration Documentation

File: src/peft/tuners/lora/config.py

Updated the ensure_weight_tying parameter docstring to clarify that it now applies to both modules_to_save and trainable_token_indices, making the documentation consistent with the implementation.
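As a usage sketch (per this PR and #2803, ensure_weight_tying lives on LoraConfig; layer names such as embed_tokens and lm_head depend on the model architecture, and target_modules here is only a placeholder):

```python
from peft import LoraConfig

# List form: the indices apply to the input embedding layer.
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices=[1, 2, 3],
    ensure_weight_tying=True,  # ask that the tied LM head follows the embedding adapter
)

# Dict form: indices are given per layer; layer names vary by model.
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [1, 2]},
    ensure_weight_tying=True,
)
```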

2. Implemented Weight Tying Logic

File: src/peft/utils/other.py

Added comprehensive logic within the existing trainable_token_indices handling block:

Key Components:

  • Early Detection: Check weight tying configuration before creating any wrappers
  • Layer Detection: Identify if both embedding and lm_head layers are specified
  • Index Comparison: Determine if token indices match between the layers
  • Skip Logic: Prevent double-wrapping by skipping layers that will be tied
  • Warning System: Inform users when their configuration cannot be applied
  • Error Handling: Raise clear errors for contradictory configurations
  • Backwards Compatibility: Preserve existing behavior when ensure_weight_tying=False

Four Cases Implemented:

  1. Case 1 - Warning for Untied Models:

    • When: weights_tied=False + ensure_weight_tying=True
    • Action: Issue warning that weight tying cannot be applied
    • Rationale: Model doesn't have tied weights, so user's request cannot be fulfilled
  2. Case 2 - Error for Contradictory Configuration:

    • When: weights_tied=True + ensure_weight_tying=True + different indices
    • Action: Raise ValueError with clear explanation
    • Rationale: Cannot tie adapters that operate on different token indices
  3. Case 3 - Backwards Compatibility:

    • When: weights_tied=True + ensure_weight_tying=False + different indices
    • Action: Treat layers as separate (no tying)
    • Rationale: User explicitly opted out, respect their choice even if model supports tying
  4. Case 4 - Apply Tying:

    • When: Other combinations where tying is appropriate
    • Action: Create tied adapters that share parameters
    • Rationale: Normal weight tying behavior
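
Putting these pieces together, the behavior can be summarized by the following hypothetical helper (a sketch of the decision logic only; names and message wording are illustrative, not the exact code added to src/peft/utils/other.py):

```python
import warnings

def resolve_weight_tying(peft_config, model_config):
    """Hypothetical helper illustrating the four cases; returns whether adapters should be tied."""
    token_indices = peft_config.trainable_token_indices
    weights_tied = getattr(model_config, "tie_word_embeddings", False)
    ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

    # Normalize: the list form targets only the input embedding layer.
    target_layers = token_indices if isinstance(token_indices, dict) else {"embed_tokens": token_indices}
    index_sets = [sorted(indices) for indices in target_layers.values()]
    indices_match = all(s == index_sets[0] for s in index_sets)

    if not weights_tied:
        if ensure_weight_tying:
            # Case 1: the model has no tied weights, so the request cannot be honored.
            warnings.warn(
                "ensure_weight_tying=True but the model does not have tied weights; "
                "trainable token adapters will not be tied."
            )
        return False

    if len(target_layers) > 1 and not indices_match:
        if ensure_weight_tying:
            # Case 2: contradictory configuration, different indices cannot be tied.
            raise ValueError(
                "Cannot ensure weight tying when different token indices are specified "
                "for the targeted layers."
            )
        # Case 3: backwards compatibility, treat the layers as separate adapters.
        return False

    # Case 4: weights are tied and indices are compatible, so tie the adapters.
    return True
```

When the helper returns True, the second wrapper for the tied layer is skipped and the adapters share parameters, matching the skip logic described above.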

3. Comprehensive Test Suite

File: tests/test_trainable_tokens.py

Added 7 new test methods covering all scenarios:

Test Coverage:

  • test_ensure_weight_tying_warns_when_model_not_tied_list_format: Verifies warning for list format
  • test_ensure_weight_tying_warns_when_model_not_tied_dict_format: Verifies warning for dict format
  • test_weight_tying_bc_different_indices_treated_separately: Verifies backwards compatibility
  • test_ensure_weight_tying_errors_with_different_indices: Verifies error for contradictory config
  • test_ensure_weight_tying_applied_with_same_indices: Verifies tying with same indices
  • test_weight_tying_bc_same_indices_applied: Verifies BC for same indices
  • test_ensure_weight_tying_with_single_layer: Verifies list format tying
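
For illustration, the contradictory-configuration case is exercised roughly like this sketch (placed inside the existing test class; the model_weight_tied fixture and target_modules value are assumptions, not the exact test code):

```python
import pytest
from peft import LoraConfig, get_peft_model

def test_ensure_weight_tying_errors_with_different_indices(self, model_weight_tied):
    config = LoraConfig(
        target_modules=["q_proj"],
        trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [3, 4]},
        ensure_weight_tying=True,
    )
    msg = "Cannot ensure weight tying when different token indices are specified"
    with pytest.raises(ValueError, match=msg):
        get_peft_model(model_weight_tied, config)
```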

Testing Results

New Tests

All 7 new tests pass successfully.

Backwards Compatibility

This implementation maintains full backwards compatibility:

  • Default Behavior Unchanged: ensure_weight_tying defaults to False, preserving existing behavior
  • No Breaking Changes: Existing code continues to work without modification
  • Opt-in Enhancement: Users must explicitly set ensure_weight_tying=True to use new features
  • BC Mode Preserved: When ensure_weight_tying=False, existing automatic tying still works for compatible configurations



Checklist

  • Implementation follows the specification in issue #2864 (Deal with weight tying consistently)
  • All 7 new tests pass
  • Backwards compatibility maintained
  • Documentation updated (docstring)
  • Code is scoped only to trainable_token_indices
  • Error messages are clear and actionable
  • Warning messages inform users appropriately

cc: @BenjaminBossan


@BenjaminBossan (Member) left a comment:

Thanks a lot for handling the update of weight tying of trainable tokens. What's there already looks quite good, but I wonder if we can simplify the implementation; please check my suggestions.

Regarding the tests, I wanted to map the tests you wrote onto the table from #2864, this is what I ended up with:

| weights tied | ensure_weight_tying | LoraConfig trainable_token_indices | result | test |
|---|---|---|---|---|
| False | False | [1, 2, 3] | trainable tokens on embeddings only | |
| False | True | [1, 2, 3] | warn & trainable tokens on embeddings only | test_ensure_weight_tying_warns_when_model_not_tied_list_format |
| True | False | [1, 2, 3] | tied trainable tokens | |
| True | True | [1, 2, 3] | tied trainable tokens | test_ensure_weight_tying_with_single_layer |
| False | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_weight_tying_bc_same_indices_applied |
| True | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_ensure_weight_tying_applied_with_same_indices |
| False | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | *treat as separate | test_weight_tying_bc_different_indices_treated_separately |
| True | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | *error | test_ensure_weight_tying_errors_with_different_indices |

Does this look right to you? I think it means there are still a few gaps in the tests, could you please provide the missing ones? Some tests could be combined via pytest.mark.parametrize if the expected outcomes are the same.
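
For illustration only, a sketch of how such a parametrization could look (this is not the reviewer's code; the fixture name comes from this PR, while target_modules is an assumption, and the method sits inside the existing test class):

```python
import pytest
from peft import LoraConfig, get_peft_model

@pytest.mark.parametrize(
    "trainable_token_indices",
    [
        [1, 2, 3],                                    # list format
        {"embed_tokens": [1, 2], "lm_head": [1, 2]},  # dict format
    ],
)
def test_ensure_weight_tying_warns_when_model_not_tied(self, model_weight_untied, recwarn, trainable_token_indices):
    config = LoraConfig(
        target_modules=["q_proj"],
        trainable_token_indices=trainable_token_indices,
        ensure_weight_tying=True,
    )
    get_peft_model(model_weight_untied, config)
    expected = "ensure_weight_tying=True but the model does not have tied weights"
    assert any(expected in str(w.message) for w in recwarn)
```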

]
assert warnings_found

def test_ensure_weight_tying_warns_when_model_not_tied_dict_format(self, model_weight_untied, recwarn):
@BenjaminBossan (Member):

This test can be merged with test_ensure_weight_tying_warns_when_model_not_tied_list_format by parametrizing the trainable_token_indices argument.

@sambhavnoobcoder (Contributor, Author):

resolved in 232c6e7

Comment on lines 992 to 996
warnings_list = [w.message.args[0] for w in recwarn]
warnings_found = [
msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
]
assert warnings_found
@BenjaminBossan (Member):

I think it's a bit more elegant to do:

expected = ...
assert any(expected in msg for msg in warnings_list)

@sambhavnoobcoder (Contributor, Author):

resolved in 232c6e7

embedding_output = peft_embedding(x)
assert (embedding_output == 0.0).all()

# Tests for ensure_weight_tying parameter with trainable_token_indices
@BenjaminBossan (Member):

Let's mention #2864 here, I think it helps understanding the tests.

@sambhavnoobcoder (Contributor, Author), Oct 29, 2025:

Added! I've included the comment # See #2864 for details on the expected behavior at the beginning of the ensure_weight_tying test section. This helps readers understand the context and refer back to the original issue for the full specification.

ensure_weight_tying=True,
)

with pytest.raises(ValueError) as e:
@BenjaminBossan (Member):

Let's use:

msg = "Cannot ensure weight tying when different token indices are specified"
with pytest.raises(ValueError, match=msg):

ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

# Check if we're dealing with dict format that specifies both embed_tokens and lm_head
is_dict_format = isinstance(peft_config.trainable_token_indices, dict)
@BenjaminBossan (Member):

I don't think we need is_dict_format. The check below, len(target_layers) > 1, is already enough, is it not?

Comment on lines 1487 to 1490
if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower):
embed_key = key
elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower):
lm_head_key = key
@BenjaminBossan (Member):

I wonder if we overcomplicate things here. If there are multiple target_layers, can we not just compare them to the tied weights? Is it important to identify here which one is for the embedding and which one is for the LM head?

Below, you're using the names for the error message, which is a nice touch, but if we can refrain from guessing here, it would be worth it to make the error message more generic IMO.

indices_mismatch = True
else:
# Same indices - if weights are tied and we're applying tying, skip lm_head (it'll be tied later)
if weights_tied and not (not ensure_weight_tying and False): # Will apply tying
@BenjaminBossan (Member):

This check makes no sense to me. Why the and False?
