File tree Expand file tree Collapse file tree 1 file changed +11
-10
lines changed
src/transformers/models/bark Expand file tree Collapse file tree 1 file changed +11
-10
lines changed Original file line number Diff line number Diff line change @@ -1067,16 +1067,6 @@ def tie_weights(self):
1067
1067
If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
1068
1068
weights instead.
1069
1069
"""
1070
- if getattr (self .config , "tie_word_embeddings" , True ):
1071
- self ._tied_weights_keys = []
1072
- output_embeddings = self .get_output_embeddings ()
1073
- input_embeddings = self .get_input_embeddings ()
1074
-
1075
- for i in range (self .config .n_codes_total - self .config .n_codes_given ):
1076
- # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
1077
- self ._tie_or_clone_weights (output_embeddings [i ], input_embeddings [i + 1 ])
1078
- self ._tied_weights_keys .append (f"lm_heads.{ i } .weight" )
1079
-
1080
1070
for module in self .modules ():
1081
1071
if hasattr (module , "_tie_weights" ):
1082
1072
module ._tie_weights ()
@@ -1621,6 +1611,17 @@ def generate(
1621
1611
1622
1612
return audio
1623
1613
1614
+ def tie_weights (self ):
1615
+ """
1616
+ Tie the weights between the input embeddings list and the output embeddings list.
1617
+
1618
+ If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
1619
+ weights instead.
1620
+ """
1621
+ for module in self .modules ():
1622
+ if hasattr (module , "_tie_weights" ):
1623
+ module ._tie_weights ()
1624
+
1624
1625
1625
1626
__all__ = [
1626
1627
"BarkFineModel" ,
You can’t perform that action at this time.
0 commit comments