
Commit a96cccd

Tie weights recursively on all submodels (#39996)
* recursive call
* add missing keys
* remove bad keys
1 parent a78263d commit a96cccd

5 files changed: 18 additions, 5 deletions

src/transformers/modeling_utils.py

Lines changed: 12 additions & 2 deletions
@@ -2992,9 +2992,10 @@ def smart_apply(self, fn):
         # Let the magic happen with this simple call
         self.smart_apply(self._initialize_weights)

-    def tie_weights(self):
+    def tie_embeddings_and_encoder_decoder(self):
         """
-        Tie the weights between the input embeddings and the output embeddings.
+        If set in the config, tie the weights between the input embeddings and the output embeddings,
+        and the encoder and decoder.

         If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
         weights instead.
@@ -3015,7 +3016,16 @@ def tie_weights(self):
         # Leading to issues on subsequent calls by different tests or subsequent calls.
         self._dynamic_tied_weights_keys = tied_weights

+    def tie_weights(self):
+        """
+        Recursively (for all submodels) tie all the weights of the model.
+        """
+        # Note that `self` is included in `self.modules` so we also apply to current PreTrainedModel with this call
         for module in self.modules():
+            # If it's a PreTrainedModel, may need to tie the embeddings and/or encoder/decoder weights
+            if isinstance(module, PreTrainedModel):
+                module.tie_embeddings_and_encoder_decoder()
+            # Additionally, if it has a custom `_tie_weights`, honor it
             if hasattr(module, "_tie_weights"):
                 module._tie_weights()
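To illustrate the new behavior, here is a minimal, self-contained sketch (not the library implementation) of the recursion that the new tie_weights performs: every submodule exposing a per-model tying hook gets it called, so nested submodels are tied without the parent re-implementing anything. ToyTextModel and ToyMultimodalModel are hypothetical names used only for illustration, and the hasattr check stands in for the isinstance(module, PreTrainedModel) check above.

import torch.nn as nn

class ToyTextModel(nn.Module):
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)

    def tie_embeddings_and_encoder_decoder(self):
        # Per-model tying: share the input embedding matrix with the output projection
        self.lm_head.weight = self.embed.weight

class ToyMultimodalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = nn.Linear(4, 4)      # submodel with nothing to tie
        self.text_model = ToyTextModel()   # submodel that needs tying

    def tie_weights(self):
        # self.modules() yields self and all submodules, so the text submodel's
        # tying hook runs even though the parent defines none of its own
        for module in self.modules():
            if hasattr(module, "tie_embeddings_and_encoder_decoder"):
                module.tie_embeddings_and_encoder_decoder()
            if hasattr(module, "_tie_weights"):
                module._tie_weights()

model = ToyMultimodalModel()
model.tie_weights()
assert model.text_model.lm_head.weight is model.text_model.embed.weight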

src/transformers/models/blip/modeling_blip_text.py

Lines changed: 2 additions & 0 deletions
@@ -842,6 +842,8 @@ def forward(

 # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
 class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)

src/transformers/models/seamless_m4t/modeling_seamless_m4t.py

Lines changed: 1 addition & 1 deletion
@@ -1989,7 +1989,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
         "text_encoder",
         "text_decoder",
     ]
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(
         self,

src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py

Lines changed: 1 addition & 1 deletion
@@ -2191,7 +2191,7 @@ class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedMod
         "text_encoder",
         "text_decoder",
     ]
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.__init__ with SeamlessM4T->SeamlessM4Tv2
     def __init__(

tests/test_modeling_common.py

Lines changed: 2 additions & 1 deletion
@@ -48,6 +48,7 @@
     is_deepspeed_zero3_enabled,
     unset_hf_deepspeed_config,
 )
+from transformers.modeling_utils import _get_tied_weight_keys
 from transformers.models.auto import get_values
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -2572,7 +2573,7 @@ def test_tied_weights_keys(self):
         copied_config.get_text_config().tie_word_embeddings = True
         model_tied = model_class(copied_config)

-        tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
+        tied_weight_keys = _get_tied_weight_keys(model_tied)
         # If we don't find any tied weights keys, and by default we don't tie the embeddings, it's because the model
         # does not tie them
         if len(tied_weight_keys) == 0 and not original_config.tie_word_embeddings:
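The helper imported here, _get_tied_weight_keys, collects the tied-weight keys of a model including its submodels, which matches the recursive tying introduced in modeling_utils.py. A rough usage sketch follows; the checkpoint is chosen only for illustration and the expected output is a plausible example, not a verified result.

from transformers import AutoModelForCausalLM
from transformers.modeling_utils import _get_tied_weight_keys

# Load a model whose LM head is tied to its input embeddings (gpt2 ties them by default)
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Collect the tied-weight keys from the model and all of its submodels
print(_get_tied_weight_keys(model))  # expected to include "lm_head.weight"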
