@@ -1004,10 +1004,13 @@ def architecture_name(self):
10041004 return "Wav2Vec2ForCTC"
10051005
10061006 def get_model_spec (self , model ):
1007+ return_hidden = getattr (model .wav2vec2 .config , "return_hidden" , False )
10071008 spec = wav2vec2_spec .Wav2Vec2Spec (
10081009 model .wav2vec2 .config .num_feat_extract_layers ,
10091010 model .wav2vec2 .encoder .config .num_hidden_layers ,
10101011 model .wav2vec2 .encoder .config .num_attention_heads ,
1012+ model .lm_head .weight .shape [0 ],
1013+ return_hidden ,
10111014 )
10121015
10131016 # layer component name matching (no duplications saving)
@@ -1065,7 +1068,9 @@ def set_encoder(self, spec, model, config):
10651068 self .set_feature_projection (spec , model .wav2vec2 .feature_projection )
10661069 self .set_pos_conv_embed (spec , model .wav2vec2 .encoder , config )
10671070 super ().set_encoder (spec , model .wav2vec2 .encoder )
1068- self .set_linear (spec .lm_head , model .lm_head )
1071+ return_hidden = getattr (model .wav2vec2 .config , "return_hidden" , False )
1072+ if not return_hidden :
1073+ self .set_linear (spec .lm_head , model .lm_head )
10691074
10701075 def set_common_layers (self , spec , module ):
10711076 self .set_layer_norm (spec .layer_norm , module .layer_norm )
@@ -1078,9 +1083,12 @@ def architecture_name(self):
10781083 return "Wav2Vec2BertForCTC"
10791084
def get_model_spec(self, model):
    """Create the Wav2Vec2Bert spec for ``model`` and fill in its encoder.

    The spec is constructed positionally from the adapter-layer count, the
    hidden-layer count, the first dimension of the LM head weight
    (presumably the vocabulary size -- confirm against the spec class) and
    the ``return_hidden`` flag read from the model config. The flag falls
    back to ``False`` when the config does not define it.

    Returns the spec after ``set_encoder`` has been applied to its encoder.
    """
    bert_config = model.wav2vec2_bert.config

    # Optional config switch; absent on configs that never set it.
    hidden_flag = getattr(bert_config, "return_hidden", False)

    adapter_layer_count = bert_config.num_adapter_layers
    hidden_layer_count = bert_config.num_hidden_layers
    lm_head_rows = model.lm_head.weight.shape[0]

    spec = wav2vec2bert_spec.Wav2Vec2BertSpec(
        adapter_layer_count,
        hidden_layer_count,
        lm_head_rows,
        hidden_flag,
    )

    self.set_encoder(spec.encoder, model)
    return spec
@@ -1170,7 +1178,9 @@ def set_encoder(self, spec, model):
11701178 self .set_wav2vec2bert_adapter (
11711179 spec .adapter_layers , model .wav2vec2_bert .adapter .layers
11721180 )
1173- self .set_linear (spec .lm_head , model .lm_head )
1181+ return_hidden = getattr (model .wav2vec2_bert .config , "return_hidden" , False )
1182+ if not return_hidden :
1183+ self .set_linear (spec .lm_head , model .lm_head )
11741184
11751185 def set_conv1d (self , spec , module ):
11761186 spec .weight = module .weight
0 commit comments