Skip to content

Commit 5795843

Browse files
homink and hkwon authored
Backward compatibility for the Wav2Vec2 ASR model (OpenNMT#1810)
* update description for the wav2vec2 model * the backward compatibility support for the wav2vec2 ASR model * dummy * dummy push * dummy push * dummy push * header update * dummy push --------- Co-authored-by: hkwon <homin.kwon@sri.com>
1 parent 383d063 commit 5795843

File tree

3 files changed

+58
-41
lines changed

3 files changed

+58
-41
lines changed

include/ctranslate2/layers/wav2vec2.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <optional>
34
#include "ctranslate2/layers/transformer.h"
45

56
namespace ctranslate2 {
@@ -81,17 +82,18 @@ namespace ctranslate2 {
8182
}
8283

8384
private:
84-
const Wav2Vec2LayerNormConvLayer _feat_layer0;
85-
const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
86-
const LayerNorm _fp_norm;
87-
const Dense _fp_ff;
88-
const Wav2Vec2PosConvLayer _pos_conv_embed;
85+
const StorageView* _upgraded_model;
86+
std::optional<Wav2Vec2LayerNormConvLayer> _feat_layer0;
87+
std::optional<std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>>> _feat_layers;
88+
std::optional<LayerNorm> _fp_norm;
89+
std::optional<Dense> _fp_ff;
90+
std::optional<Wav2Vec2PosConvLayer> _pos_conv_embed;
8991
const ops::Transpose _transpose;
9092
const ops::GELU _gelu;
9193
const dim_t _num_heads;
9294
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
9395
const LayerNorm _output_norm;
94-
const Dense _lm_head;
96+
std::optional<Dense> _lm_head;
9597
};
9698

9799
}

python/cpp/wav2vec2.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ namespace ctranslate2 {
8686
Encodes the input features.
8787
8888
Arguments:
89-
features: Mel spectogram of the audio, as a float array with shape
90-
``[batch_size, 80, 3000]``.
89+
features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or
90+
raw audio, as a float array with shape (followed by VAD)
91+
``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``
9192
to_cpu: Copy the encoder output to the CPU before returning the value.
9293
9394
Returns:

src/layers/wav2vec2.cc

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,7 @@ namespace ctranslate2 {
4646
}
4747

4848
Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope)
49-
: _feat_layer0(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0)
50-
, _feat_layers(build_layers_list<const Wav2Vec2LayerNormConvLayer>(model,
51-
scope + "/feat_layer",
52-
/*stride=*/2,
53-
/*padding=*/0))
54-
, _fp_norm(model, scope + "/fp_layer_norm")
55-
, _fp_ff(model, scope + "/fp_projection", nullptr, true)
56-
, _pos_conv_embed(model, scope + "/pos_conv_embed")
49+
: _upgraded_model(model.get_variable_if_exists(scope + "/lm_head/weight"))
5750
, _num_heads(model.get_attribute_with_default<int32_t>(scope + "/num_heads", 8))
5851
, _transpose({0, 2, 1})
5952
, _layers(build_layers_list<const TransformerEncoderLayer>(model,
@@ -62,8 +55,18 @@ namespace ctranslate2 {
6255
/*pre_norm=*/true,
6356
ops::ActivationType::GELU))
6457
, _output_norm(model, scope + "/layer_norm")
65-
, _lm_head(model, scope + "/lm_head", nullptr, true)
6658
{
59+
if (_upgraded_model) {
60+
_feat_layer0.emplace(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0);
61+
_feat_layers.emplace(build_layers_list<const Wav2Vec2LayerNormConvLayer>(model,
62+
scope + "/feat_layer",
63+
/*stride=*/2,
64+
/*padding=*/0));
65+
_fp_norm.emplace(model, scope + "/fp_layer_norm");
66+
_fp_ff.emplace(model, scope + "/fp_projection", nullptr, true);
67+
_pos_conv_embed.emplace(model, scope + "/pos_conv_embed");
68+
_lm_head.emplace(model, scope + "/lm_head", nullptr, true);
69+
}
6770
}
6871

6972
void Wav2Vec2Encoder::operator()(const StorageView& features, StorageView& output) {
@@ -74,33 +77,44 @@ namespace ctranslate2 {
7477
throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
7578
+ std::to_string(features.rank())
7679
+ " dimension(s) instead");
77-
78-
// Wav2Vec2FeatureExtractor------------------------------------
79-
StorageView feat_buffer(features.dtype(), features.device());
80-
StorageView feat_buffer2(features.dtype(), features.device());
81-
feat_buffer = std::move(features);
82-
_feat_layer0(feat_buffer, output);
83-
feat_buffer = std::move(output);
84-
for (dim_t l = 0; l < _feat_layers.size(); l++) {
85-
(*_feat_layers[l])(feat_buffer, output);
86-
if (l < _feat_layers.size() - 1 ) {
87-
feat_buffer = std::move(output);
80+
if (_upgraded_model) {
81+
// Wav2Vec2FeatureExtractor------------------------------------
82+
StorageView feat_buffer(features.dtype(), features.device());
83+
StorageView feat_buffer2(features.dtype(), features.device());
84+
feat_buffer = std::move(features);
85+
(*_feat_layer0)(feat_buffer, output); //_feat_layer0(feat_buffer, output);
86+
feat_buffer = std::move(output);
87+
for (dim_t l = 0; l < _feat_layers->size(); l++) {
88+
(*_feat_layers.value()[l])(feat_buffer, output);
89+
if (l < _feat_layers->size() - 1 ) {
90+
feat_buffer = std::move(output);
91+
}
8892
}
93+
_transpose(output, feat_buffer);
94+
// Wav2Vec2FeatureProjection-----------------------------------
95+
(*_fp_norm)(feat_buffer, output); //_fp_norm(feat_buffer, output);
96+
(*_fp_ff)(output, feat_buffer); //_fp_ff(output, feat_buffer);
97+
// Wav2Vec2PositionalConvEmbedding-----------------------------
98+
(*_pos_conv_embed)(feat_buffer, feat_buffer2); //_pos_conv_embed(feat_buffer, feat_buffer2);
99+
// Wav2Vec2EncoderLayerStableLayerNorm-------------------------
100+
for (const auto& layer : _layers) {
101+
(*layer)(feat_buffer2, nullptr, feat_buffer);
102+
feat_buffer2 = std::move(feat_buffer);
103+
}
104+
_output_norm(feat_buffer2, feat_buffer);
105+
106+
(*_lm_head)(feat_buffer, output); //_lm_head(feat_buffer, output);
89107
}
90-
_transpose(output, feat_buffer);
91-
// Wav2Vec2FeatureProjection-----------------------------------
92-
_fp_norm(feat_buffer, output);
93-
_fp_ff(output, feat_buffer);
94-
// Wav2Vec2PositionalConvEmbedding-----------------------------
95-
_pos_conv_embed(feat_buffer, feat_buffer2);
96-
// Wav2Vec2EncoderLayerStableLayerNorm-------------------------
97-
for (const auto& layer : _layers) {
98-
(*layer)(feat_buffer2, nullptr, feat_buffer);
99-
feat_buffer2 = std::move(feat_buffer);
100-
}
101-
_output_norm(feat_buffer2, feat_buffer);
108+
else { // backward compatibility for the previous converted model
109+
StorageView input(output_type(), features.device());
110+
input = features;
111+
for (const auto& layer : _layers) {
112+
(*layer)(input, nullptr, output);
113+
input = std::move(output);
114+
}
102115

103-
_lm_head(feat_buffer, output);
116+
_output_norm(input, output);
117+
}
104118
}
105119

106120
}

0 commit comments

Comments
 (0)