
Commit 8ba828c

Wav2Vec2 upgrade with Conv1D options (#1758)

Authored by hominkhkwon and hkwon.

Commit message:

* Wav2Vec2 upgrade with Conv1D options
* refining scripts
* refining script again
* fix the formats
* fix the isort format
* refining the library
* update based on the suggestions
* update the variable name
* adding unk_token removal for the Python testing
* adding whitespace
* update Python format
* update variables
* update variables
* update variables
* update variables

Co-authored-by: hkwon <[email protected]>

1 parent: d202032

File tree: 6 files changed, +207 −117 lines


include/ctranslate2/layers/wav2vec2.h — 53 additions, 2 deletions

@@ -5,6 +5,52 @@
 namespace ctranslate2 {
   namespace layers {

+    class Wav2Vec2LayerNormConvLayer : public Layer {
+    public:
+      Wav2Vec2LayerNormConvLayer(const models::Model& model,
+                                 const std::string& scope,
+                                 dim_t stride,
+                                 dim_t padding);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      dim_t _stride;
+      dim_t _padding;
+      const Conv1D _conv;
+      const LayerNorm _output_norm;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
+    class Wav2Vec2PosConvLayer : public Layer {
+    public:
+      Wav2Vec2PosConvLayer(const models::Model& model, const std::string& scope);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      const Conv1D _conv;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
     class Wav2Vec2Encoder : public Layer {
     public:
       Wav2Vec2Encoder(const models::Model& model, const std::string& scope);

@@ -35,12 +81,17 @@ namespace ctranslate2 {
       }

     private:
+      const Wav2Vec2LayerNormConvLayer _feat_layer0;
+      const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
+      const LayerNorm _fp_norm;
+      const Dense _fp_ff;
+      const Wav2Vec2PosConvLayer _pos_conv_embed;
+      const ops::Transpose _transpose;
       const ops::GELU _gelu;
-      // wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
-      //const ops::Transpose _transpose;
       const dim_t _num_heads;
       const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
       const LayerNorm _output_norm;
+      const Dense _lm_head;
     };

   }
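The two new layers mirror the Hugging Face feature extractor and positional embedding: each Wav2Vec2LayerNormConvLayer runs a strided Conv1D, transposes so the channel dimension is last, applies LayerNorm, transposes back, and finishes with a GELU, while Wav2Vec2PosConvLayer wraps the grouped positional convolution (groups=16 in the reference model) that the new Conv1D options make representable. As orientation only, a minimal PyTorch sketch of the per-layer computation (module names and shapes are illustrative, not the CTranslate2 implementation):

import torch
import torch.nn.functional as F

def layer_norm_conv_layer(x, conv, layer_norm):
    # x: [batch, in_channels, time]
    y = conv(x)            # strided Conv1D
    y = y.transpose(1, 2)  # put channels last so LayerNorm normalizes them
    y = layer_norm(y)
    y = y.transpose(1, 2)  # restore [batch, channels, time]
    return F.gelu(y)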

python/ctranslate2/converters/transformers.py — 38 additions, 7 deletions

@@ -992,9 +992,8 @@ def architecture_name(self):
         return "Wav2Vec2ForCTC"

     def get_model_spec(self, model):
-        # Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16
-        # that doesn't look available here so we make Wav2Vec2 encoder layers only
         spec = wav2vec2_spec.Wav2Vec2Spec(
+            model.wav2vec2.config.num_feat_extract_layers,
             model.wav2vec2.encoder.config.num_hidden_layers,
             model.wav2vec2.encoder.config.num_attention_heads,
         )

@@ -1007,9 +1006,7 @@ def get_model_spec(self, model):
             layer.fc1 = layer.feed_forward.intermediate_dense
             layer.fc2 = layer.feed_forward.output_dense

-        self.set_encoder(spec.encoder, model.wav2vec2.encoder)
-        self.set_linear(spec.lm_head, model.lm_head)
-        # only for Wav2Vec2Spec.get_vocabulary_size()
+        self.set_encoder(spec.encoder, model, model.wav2vec2.config)
         return spec

     def set_config(self, config, model, tokenizer):

@@ -1021,8 +1018,42 @@ def get_vocabulary(self, model, tokenizer):
     def set_vocabulary(self, spec, tokens):
         spec.register_vocabulary(tokens)

-    def set_encoder(self, spec, encoder):
-        super().set_encoder(spec, encoder)
+    def set_feature_extractor(self, spec, feature_extractor):
+        spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
+        spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias
+        self.set_layer_norm(
+            spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
+        )
+        for spec_layer, module_layer in zip(
+            spec.feat_layer, feature_extractor.conv_layers[1:]
+        ):
+            spec_layer.conv.weight = module_layer.conv.weight
+            spec_layer.conv.bias = module_layer.conv.bias
+            self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)
+
+    def set_feature_projection(self, spec, feature_projection):
+        self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
+        self.set_linear(spec.fp_projection, feature_projection.projection)
+
+    def set_pos_conv_embed(self, spec, encoder, config):
+        # force the parameters to be set because some transformers versions initialize garbage numbers
+        # conv parameters are float16 so force float32 for the loading
+        encoder.pos_conv_embed.conv.weight.data = (
+            encoder.pos_conv_embed.conv.weight.data.float()
+        )
+        encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
+        for param in encoder.pos_conv_embed.parameters():
+            param.data = param.data.float()
+        encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
+        spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
+        spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias
+
+    def set_encoder(self, spec, model, config):
+        self.set_feature_extractor(spec, model.wav2vec2.feature_extractor)
+        self.set_feature_projection(spec, model.wav2vec2.feature_projection)
+        self.set_pos_conv_embed(spec, model.wav2vec2.encoder, config)
+        super().set_encoder(spec, model.wav2vec2.encoder)
+        self.set_linear(spec.lm_head, model.lm_head)

     def set_common_layers(self, spec, module):
         self.set_layer_norm(spec.layer_norm, module.layer_norm)
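With set_encoder now wiring the feature extractor, feature projection, positional conv embedding, Transformer layers, and lm_head in a single pass, a Wav2Vec2ForCTC checkpoint converts end to end. A minimal usage sketch (the checkpoint name is only an example; any Wav2Vec2ForCTC checkpoint should work):

import ctranslate2

converter = ctranslate2.converters.TransformersConverter(
    "facebook/wav2vec2-large-960h-lv60-self"
)
output_dir = converter.convert("wav2vec2_ct2", force=True)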

python/ctranslate2/specs/wav2vec2_spec.py — 21 additions, 6 deletions

@@ -13,10 +13,9 @@ def __init__(self):


 class Wav2Vec2Spec(model_spec.LanguageModelSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
         super().__init__()
-        self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
-        self.lm_head = common_spec.LinearSpec()
+        self.encoder = Wav2Vec2EncoderSpec(feat_layers, num_layers, num_heads)

     @property
     def name(self):

@@ -30,14 +29,30 @@ def get_default_config(self):
         return Wav2Vec2Config()

     def get_vocabulary_size(self):
-        return self.lm_head.weight.shape[0]
+        return self.encoder.lm_head.weight.shape[0]
+
+
+class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+        self.layer_norm = common_spec.LayerNormSpec()
+
+
+class Wav2Vec2PosEmbedConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()


 class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
         self.num_heads = np.dtype("int16").type(num_heads)
-        # wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
+        self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
+        self.feat_layer = [Wav2Vec2LayerNormConvLayer() for i in range(feat_layers - 1)]
+        self.fp_layer_norm = common_spec.LayerNormSpec()
+        self.fp_projection = common_spec.LinearSpec()
+        self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
         self.layer_norm = common_spec.LayerNormSpec()
         self.layer = [
             transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
         ]
+        self.lm_head = common_spec.LinearSpec()
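The spec tree now describes the full model rather than just the encoder stack, and lm_head moved under the encoder spec, which is why get_vocabulary_size() reads self.encoder.lm_head. A quick sketch of constructing the spec (the dimensions are an example matching the large Wav2Vec2 configuration: 7 feature-extractor conv layers, 24 encoder layers, 16 attention heads):

from ctranslate2.specs import wav2vec2_spec

spec = wav2vec2_spec.Wav2Vec2Spec(7, 24, 16)
print(len(spec.encoder.feat_layer))  # 6: feat_layer0 holds the first conv layer
print(len(spec.encoder.layer))       # 24 Transformer encoder layers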

python/tests/test_transformers.py — 17 additions, 72 deletions

@@ -979,24 +979,16 @@ def test_transformers_wav2vec2(
         )
         output_dir = str(tmp_dir.join("ctranslate2_model"))
         output_dir = converter.convert(output_dir)
-        # 24 x Wav2Vec2EncoderLayerStableLayerNorm converted & saved

-        w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name)
-        del w2v2_model.wav2vec2.encoder.layers
-        del w2v2_model.wav2vec2.encoder.layer_norm
-        w2v2_model.save_pretrained(output_dir + "/wav2vec2_partial.bin")
         w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
-        torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin")
+        w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor")
+        processor = transformers.AutoProcessor.from_pretrained(
+            output_dir + "/wav2vec2_processor"
+        )

         device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
         cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
-        w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(
-            output_dir + "/wav2vec2_partial.bin"
-        ).to(device)
-        del w2v2_model.wav2vec2.encoder.layers
-        del w2v2_model.wav2vec2.encoder.layer_norm
-        w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin")
-        ct2_w2v2_model = ctranslate2.models.Wav2Vec2(
+        model = ctranslate2.models.Wav2Vec2(
             output_dir,
             device=device,
             device_index=[0],

@@ -1008,73 +1000,26 @@

         speech_array = np.load(
             os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy")
         )
-        input_values = w2v2_processor(
+        input_values = processor(
             speech_array,
             padding=True,
             return_tensors="pt",
             sampling_rate=16000,
         ).input_values

-        with torch.no_grad():
-            extract_features = w2v2_model.wav2vec2.feature_extractor(
-                input_values.to(w2v2_model.device)
-            ).transpose(1, 2)
-            hidden_states, extract_features = w2v2_model.wav2vec2.feature_projection(
-                extract_features
-            )
-            position_embeddings = w2v2_model.wav2vec2.encoder.pos_conv_embed(
-                hidden_states
-            )
-            hidden_states = position_embeddings + hidden_states
-            # hidden_states = w2v2_model.encoder.dropout(hidden_states)
-            # Dropout(p=0.0, inplace=False) bypassed
-
-        if ct2_w2v2_model.device == "cuda":
-            hidden_states = hidden_states.cpu()
-        else:
-            hidden_states.numpy()
-
-        hidden_states = np.ascontiguousarray(hidden_states)
+        hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
         hidden_states = ctranslate2.StorageView.from_array(hidden_states)
-        to_cpu = (
-            ct2_w2v2_model.device == "cuda" and len(ct2_w2v2_model.device_index) > 1
-        )
-        ct2_output = ct2_w2v2_model.encode(
-            hidden_states,
-            to_cpu=to_cpu,
-        )  # 24 x Wav2Vec2EncoderLayerStableLayerNorm processed
-        if ct2_w2v2_model.device == "cuda":
-            hidden_states = torch.as_tensor(
-                ct2_output,
-                device=ct2_w2v2_model.device,
-            )
+        to_cpu = model.device == "cuda" and len(model.device_index) > 1
+        output = model.encode(hidden_states, to_cpu=to_cpu)
+        if model.device == "cuda":
+            logits = torch.as_tensor(output, device=model.device)[0]
         else:
-            hidden_states = torch.as_tensor(
-                np.array(ct2_output),
-                dtype=torch.float32,
-                device=ct2_w2v2_model.device,
-            )
-
-        encoder_outputs = transformers.modeling_outputs.BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=None,
-            attentions=None,
-        )
-        hidden_states = encoder_outputs[0]
-        outputs = transformers.modeling_outputs.Wav2Vec2BaseModelOutput(
-            last_hidden_state=hidden_states,
-            extract_features=extract_features,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-        hidden_states = outputs[0]
-        # hidden_states = w2v2_model.dropout(hidden_states)
-        # Dropout(p=0.0, inplace=False) bypassed
-
-        with torch.no_grad():
-            logits = w2v2_model.lm_head(hidden_states.to(torch.float32))[0]
+            logits = torch.as_tensor(
+                np.array(output), dtype=torch.float32, device=model.device
+            )[0]

         predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = w2v2_processor.decode(predicted_ids, output_word_offsets=True)
+        transcription = processor.decode(predicted_ids, output_word_offsets=True)
+        transcription = transcription[0].replace(processor.tokenizer.unk_token, "")

-        assert transcription[0] == expected_transcription[0]
+        assert transcription == expected_transcription[0]
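The simplified test doubles as a usage recipe: the converted model now consumes the raw waveform directly, so the PyTorch-side feature extraction, positional embedding, and lm_head calls are gone. An end-to-end sketch under assumed inputs (the model directory, checkpoint name, and audio path are placeholders):

import numpy as np
import torch
import ctranslate2
import transformers

model = ctranslate2.models.Wav2Vec2("wav2vec2_ct2", device="cpu")
processor = transformers.Wav2Vec2Processor.from_pretrained(
    "facebook/wav2vec2-large-960h-lv60-self"
)

speech = np.load("audio.npy")  # 1-D float32 waveform sampled at 16 kHz
inputs = processor(speech, return_tensors="pt", sampling_rate=16000).input_values

# The model takes the raw waveform as a [1, 1, time] array and returns CTC logits.
features = ctranslate2.StorageView.from_array(
    np.ascontiguousarray(inputs.unsqueeze(0))
)
logits = torch.as_tensor(np.array(model.encode(features)), dtype=torch.float32)[0]

predicted_ids = torch.argmax(logits, dim=-1)
text = processor.decode(predicted_ids, output_word_offsets=True)[0]
print(text.replace(processor.tokenizer.unk_token, ""))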
