
Deal with weight tying consistently #2864

@BenjaminBossan

Description


Currently, the way PEFT deals with tied embedding and LM head weights is not always clear. In #2803, a new argument, ensure_weight_tying, was introduced to make it easier for users to automatically tie the PEFT weights while keeping backwards compatibility. However, this makes it even more important to clarify what happens in which situation.

The table below shows the intended behavior in different circumstances. Notably, weight tying can affect modules_to_save, target_modules, and trainable_token_indices. The table lists the expected results for all combinations of these factors.
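For orientation, this is roughly what the three configuration routes from the table look like in code. This is a minimal sketch, assuming a Hugging Face causal LM whose embedding and LM head modules are named `embed_tokens` and `lm_head`; the checkpoint name is a placeholder, and `ensure_weight_tying` is shown as a `LoraConfig` argument based on #2803, which may not match its final placement.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Hypothetical checkpoint name; any causal LM with `embed_tokens`/`lm_head` modules works.
model = AutoModelForCausalLM.from_pretrained("my-org/tied-causal-lm")

# "weights tied" column: does the base model share embedding and LM head weights?
weights_tied = model.get_input_embeddings().weight is model.get_output_embeddings().weight

# modules_to_save route: fully fine-tune the embedding and/or LM head.
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
    ensure_weight_tying=True,  # argument from #2803; placement on LoraConfig assumed here
)

# target_modules route: apply LoRA directly to the embedding and/or LM head.
config = LoraConfig(
    target_modules=["embed_tokens", "lm_head"],
    ensure_weight_tying=True,
)

# trainable_token_indices route: train only selected token rows.
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices=[1, 2, 3],
    ensure_weight_tying=True,
)

model = get_peft_model(model, config)
```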

| weights tied | ensure_weight_tying | LoraConfig | result |
|---|---|---|---|
| False | False | `modules_to_save=[embed_tokens]` / `modules_to_save=[lm_head]` | `ModulesToSaveWrapper` on embedding/LM head |
| False | True | `modules_to_save=[embed_tokens]` / `modules_to_save=[lm_head]` | warn & `ModulesToSaveWrapper` on embedding/LM head |
| True | False | `modules_to_save=[embed_tokens]` / `modules_to_save=[lm_head]` | `ModulesToSaveWrapper` on embedding/LM head (BC) |
| True | True | `modules_to_save=[embed_tokens]` / `modules_to_save=[lm_head]` | `ModulesToSaveWrapper`s share weights |
| False | False | `modules_to_save=[embed_tokens, lm_head]` | treat as separate |
| False | True | `modules_to_save=[embed_tokens, lm_head]` | warn & treat as separate |
| True | False | `modules_to_save=[embed_tokens, lm_head]` | treat as separate (BC) |
| True | True | `modules_to_save=[embed_tokens, lm_head]` | `ModulesToSaveWrapper`s share weights |
| False | False | `target_modules=[embed_tokens]` / `target_modules=[lm_head]` | LoRA on embedding/LM head |
| False | True | `target_modules=[embed_tokens]` / `target_modules=[lm_head]` | *warn & LoRA on embedding/LM head |
| True | False | `target_modules=[embed_tokens]` / `target_modules=[lm_head]` | LoRA on embedding/LM head (BC) |
| True | True | `target_modules=[embed_tokens]` / `target_modules=[lm_head]` | *LoRA weights shared |
| False | False | `target_modules=[embed_tokens, lm_head]` | treat as separate |
| False | True | `target_modules=[embed_tokens, lm_head]` | *warn & treat as separate |
| True | False | `target_modules=[embed_tokens, lm_head]` | treat as separate (BC) |
| True | True | `target_modules=[embed_tokens, lm_head]` | *LoRA weights shared |
| False | False | `trainable_token_indices=[1, 2, 3]` | trainable tokens on embeddings only |
| False | True | `trainable_token_indices=[1, 2, 3]` | warn & trainable tokens on embeddings only |
| True | False | `trainable_token_indices=[1, 2, 3]` | tied trainable tokens |
| True | True | `trainable_token_indices=[1, 2, 3]` | tied trainable tokens |
| False | False | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]}` | treat as separate |
| False | True | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]}` | warn & treat as separate |
| True | False | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]}` | tied trainable tokens |
| True | True | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]}` | tied trainable tokens |
| False | False | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]}` | treat as separate |
| False | True | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]}` | warn & treat as separate |
| True | False | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]}` | *treat as separate |
| True | True | `trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]}` | *error |

Explanation:

- BC means that we keep this behavior for backwards compatibility, even if it might not be the most intuitive behavior.
- \* marks behavior that is not yet implemented as such but should be added.
- For trainable_token_indices, we distinguish between cases where embedding and LM head define the same indices, which would allow weight sharing, and cases where they define distinct indices, which precludes weight sharing.
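To illustrate the last point, here is a sketch of the two dict forms of `trainable_token_indices`, using the module names from the table; this only shows the two configurations, the resulting behavior is what the table above specifies.

```python
from peft import LoraConfig

# Same indices on embedding and LM head: the token deltas can be shared,
# so with tied base weights this corresponds to the "tied trainable tokens" rows.
shared_indices = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [1, 2]},
)

# Distinct indices on embedding and LM head: the updates necessarily differ,
# so weight sharing is impossible; per the table this stays separate
# (or errors when weights are tied and ensure_weight_tying=True).
distinct_indices = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices={"embed_tokens": [3, 4], "lm_head": [1, 2]},
)
```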

Ping @romitjain
