Skip to content

Commit a78263d

Browse files
committed
fix
1 parent dc11a3c commit a78263d

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/transformers/models/bark/modeling_bark.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,16 +1067,6 @@ def tie_weights(self):
10671067
If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
10681068
weights instead.
10691069
"""
1070-
if getattr(self.config, "tie_word_embeddings", True):
1071-
self._tied_weights_keys = []
1072-
output_embeddings = self.get_output_embeddings()
1073-
input_embeddings = self.get_input_embeddings()
1074-
1075-
for i in range(self.config.n_codes_total - self.config.n_codes_given):
1076-
# self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
1077-
self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
1078-
self._tied_weights_keys.append(f"lm_heads.{i}.weight")
1079-
10801070
for module in self.modules():
10811071
if hasattr(module, "_tie_weights"):
10821072
module._tie_weights()
@@ -1621,6 +1611,17 @@ def generate(
16211611

16221612
return audio
16231613

1614+
def tie_weights(self):
    """
    Tie the weights between the input embeddings list and the output embeddings list.

    If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
    weights instead.
    """
    # Walk every submodule and let each one that defines a custom
    # `_tie_weights` hook perform its own tying/cloning.
    for submodule in self.modules():
        if not hasattr(submodule, "_tie_weights"):
            continue
        submodule._tie_weights()
1624+
16241625

16251626
__all__ = [
16261627
"BarkFineModel",

0 commit comments

Comments
 (0)