Skip to content

Commit 878920e

Browse files
hominkhkwon
and authored
Support for returning the hidden vector in the Wav2Vec2 and Wav2Vec2Bert models (#1867)
* adding the hidden vector return feature * format update * format update * dummy * dummy --------- Co-authored-by: hkwon <[email protected]>
1 parent d3e37ed commit 878920e

File tree

9 files changed

+108
-30
lines changed

9 files changed

+108
-30
lines changed

include/ctranslate2/layers/wav2vec2.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,21 @@ namespace ctranslate2 {
5959
void operator()(const StorageView& features, StorageView& output);
6060

6161
DataType output_type() const override {
62-
return _output_norm.output_type();
62+
if (_lm_head) {
63+
return (*_lm_head).output_type();
64+
}
65+
else {
66+
return _output_norm.output_type();
67+
}
6368
}
6469

6570
dim_t output_size() const override {
66-
return _output_norm.output_size();
71+
if (_lm_head) {
72+
return (*_lm_head).output_size();
73+
}
74+
else {
75+
return _output_norm.output_size();
76+
}
6777
}
6878

6979
dim_t input_size() const {
@@ -81,8 +91,10 @@ namespace ctranslate2 {
8191
&& features.dim(1) != input_size());
8292
}
8393

84-
private:
8594
const StorageView* _upgraded_model;
95+
96+
private:
97+
const StorageView* _return_logits;
8698
std::optional<Wav2Vec2LayerNormConvLayer> _feat_layer0;
8799
std::optional<std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>>> _feat_layers;
88100
std::optional<LayerNorm> _fp_norm;

include/ctranslate2/layers/wav2vec2bert.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <optional>
34
#include "ctranslate2/layers/attention.h"
45
#include "ctranslate2/layers/flash_attention.h"
56
#include "ctranslate2/layers/common.h"
@@ -92,11 +93,21 @@ namespace ctranslate2 {
9293
void operator()(const StorageView& features, StorageView& output);
9394

9495
DataType output_type() const override {
95-
return _lm_head.output_type();
96+
if (_lm_head) {
97+
return (*_lm_head).output_type();
98+
}
99+
else {
100+
return DataType::FLOAT32;
101+
}
96102
}
97103

98104
dim_t output_size() const override {
99-
return _lm_head.output_size();
105+
if (_lm_head) {
106+
return (*_lm_head).output_size();
107+
}
108+
else {
109+
return 1024;
110+
}
100111
}
101112

102113
dim_t input_size() const {
@@ -115,11 +126,12 @@ namespace ctranslate2 {
115126
}
116127

117128
private:
129+
const StorageView* _return_logits;
118130
const LayerNorm _fp_layer_norm;
119131
const Dense _fp_projection;
120132
const std::vector<std::unique_ptr<const EncoderLayer>> _encoder_layers;
121133
const std::vector<std::unique_ptr<const AdapterLayer>> _adapt_layers;
122-
const Dense _lm_head;
134+
std::optional<Dense> _lm_head;
123135
};
124136

125137
}

python/ctranslate2/converters/transformers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,10 +1004,13 @@ def architecture_name(self):
10041004
return "Wav2Vec2ForCTC"
10051005

10061006
def get_model_spec(self, model):
1007+
return_hidden = getattr(model.wav2vec2.config, "return_hidden", False)
10071008
spec = wav2vec2_spec.Wav2Vec2Spec(
10081009
model.wav2vec2.config.num_feat_extract_layers,
10091010
model.wav2vec2.encoder.config.num_hidden_layers,
10101011
model.wav2vec2.encoder.config.num_attention_heads,
1012+
model.lm_head.weight.shape[0],
1013+
return_hidden,
10111014
)
10121015

10131016
# layer component name matching (no duplications saving)
@@ -1065,7 +1068,9 @@ def set_encoder(self, spec, model, config):
10651068
self.set_feature_projection(spec, model.wav2vec2.feature_projection)
10661069
self.set_pos_conv_embed(spec, model.wav2vec2.encoder, config)
10671070
super().set_encoder(spec, model.wav2vec2.encoder)
1068-
self.set_linear(spec.lm_head, model.lm_head)
1071+
return_hidden = getattr(model.wav2vec2.config, "return_hidden", False)
1072+
if not return_hidden:
1073+
self.set_linear(spec.lm_head, model.lm_head)
10691074

10701075
def set_common_layers(self, spec, module):
10711076
self.set_layer_norm(spec.layer_norm, module.layer_norm)
@@ -1078,9 +1083,12 @@ def architecture_name(self):
10781083
return "Wav2Vec2BertForCTC"
10791084

10801085
def get_model_spec(self, model):
1086+
return_hidden = getattr(model.wav2vec2_bert.config, "return_hidden", False)
10811087
spec = wav2vec2bert_spec.Wav2Vec2BertSpec(
10821088
model.wav2vec2_bert.config.num_adapter_layers,
10831089
model.wav2vec2_bert.config.num_hidden_layers,
1090+
model.lm_head.weight.shape[0],
1091+
return_hidden,
10841092
)
10851093
self.set_encoder(spec.encoder, model)
10861094
return spec
@@ -1170,7 +1178,9 @@ def set_encoder(self, spec, model):
11701178
self.set_wav2vec2bert_adapter(
11711179
spec.adapter_layers, model.wav2vec2_bert.adapter.layers
11721180
)
1173-
self.set_linear(spec.lm_head, model.lm_head)
1181+
return_hidden = getattr(model.wav2vec2_bert.config, "return_hidden", False)
1182+
if not return_hidden:
1183+
self.set_linear(spec.lm_head, model.lm_head)
11741184

11751185
def set_conv1d(self, spec, module):
11761186
spec.weight = module.weight

python/ctranslate2/specs/wav2vec2_spec.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,22 @@ def __init__(self):
1313

1414

1515
class Wav2Vec2Spec(model_spec.LanguageModelSpec):
16-
def __init__(self, feat_layers, num_layers, num_heads):
16+
def __init__(
17+
self,
18+
feat_layers,
19+
num_layers,
20+
num_heads,
21+
vocab_size,
22+
return_hidden,
23+
):
1724
super().__init__()
18-
self.encoder = Wav2Vec2EncoderSpec(feat_layers, num_layers, num_heads)
25+
self.vocab_size = np.dtype("int16").type(vocab_size)
26+
self.encoder = Wav2Vec2EncoderSpec(
27+
feat_layers,
28+
num_layers,
29+
num_heads,
30+
return_hidden,
31+
)
1932

2033
@property
2134
def name(self):
@@ -29,7 +42,7 @@ def get_default_config(self):
2942
return Wav2Vec2Config()
3043

3144
def get_vocabulary_size(self):
32-
return self.encoder.lm_head.weight.shape[0]
45+
return int(self.vocab_size.numpy())
3346

3447

3548
class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
@@ -44,7 +57,7 @@ def __init__(self):
4457

4558

4659
class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
47-
def __init__(self, feat_layers, num_layers, num_heads):
60+
def __init__(self, feat_layers, num_layers, num_heads, return_hidden):
4861
self.num_heads = np.dtype("int16").type(num_heads)
4962
self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
5063
self.feat_layer = [Wav2Vec2LayerNormConvLayer() for i in range(feat_layers - 1)]
@@ -55,4 +68,5 @@ def __init__(self, feat_layers, num_layers, num_heads):
5568
self.layer = [
5669
transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
5770
]
58-
self.lm_head = common_spec.LinearSpec()
71+
if not return_hidden:
72+
self.lm_head = common_spec.LinearSpec()

python/ctranslate2/specs/wav2vec2bert_spec.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from ctranslate2.specs import attention_spec, common_spec, model_spec
24

35

@@ -9,11 +11,19 @@ def __init__(self):
911

1012

1113
class Wav2Vec2BertSpec(model_spec.LanguageModelSpec):
12-
def __init__(self, num_hidden_layers, num_adapter_layers):
14+
def __init__(
15+
self,
16+
num_hidden_layers,
17+
num_adapter_layers,
18+
vocab_size,
19+
return_hidden,
20+
):
1321
super().__init__()
22+
self.vocab_size = np.dtype("int16").type(vocab_size)
1423
self.encoder = Wav2Vec2BertEncoderSpec(
1524
num_adapter_layers,
1625
num_hidden_layers,
26+
return_hidden,
1727
)
1828

1929
@property
@@ -28,7 +38,7 @@ def get_default_config(self):
2838
return Wav2Vec2BertConfig()
2939

3040
def get_vocabulary_size(self):
31-
return self.encoder.lm_head.weight.shape[0]
41+
return int(self.vocab_size.numpy())
3242

3343

3444
class Wav2Vec2BertFeedForwardSpec(model_spec.LayerSpec):
@@ -78,9 +88,10 @@ def __init__(self):
7888

7989

8090
class Wav2Vec2BertEncoderSpec(model_spec.LayerSpec):
81-
def __init__(self, num_hidden_layers, num_adapter_layers):
91+
def __init__(self, num_hidden_layers, num_adapter_layers, return_hidden):
8292
self.fp_layer_norm = common_spec.LayerNormSpec()
8393
self.fp_projection = common_spec.LinearSpec()
8494
self.encoder_layers = [EncoderSpec() for _ in range(num_hidden_layers)]
8595
self.adapter_layers = [AdapterSpec() for _ in range(num_adapter_layers)]
86-
self.lm_head = common_spec.LinearSpec()
96+
if not return_hidden:
97+
self.lm_head = common_spec.LinearSpec()

src/layers/wav2vec2.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ namespace ctranslate2 {
4646
}
4747

4848
Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope)
49-
: _upgraded_model(model.get_variable_if_exists(scope + "/lm_head/weight"))
49+
: _return_logits(model.get_variable_if_exists(scope + "/lm_head/weight"))
50+
, _upgraded_model(model.get_variable_if_exists(scope + "/fp_projection/weight"))
5051
, _num_heads(model.get_attribute_with_default<int32_t>(scope + "/num_heads", 8))
5152
, _transpose({0, 2, 1})
5253
, _layers(build_layers_list<const TransformerEncoderLayer>(model,
@@ -65,7 +66,9 @@ namespace ctranslate2 {
6566
_fp_norm.emplace(model, scope + "/fp_layer_norm");
6667
_fp_ff.emplace(model, scope + "/fp_projection", nullptr, true);
6768
_pos_conv_embed.emplace(model, scope + "/pos_conv_embed");
68-
_lm_head.emplace(model, scope + "/lm_head", nullptr, true);
69+
if (_return_logits) {
70+
_lm_head.emplace(model, scope + "/lm_head", nullptr, true);
71+
}
6972
}
7073
}
7174

@@ -101,12 +104,16 @@ namespace ctranslate2 {
101104
(*layer)(feat_buffer2, nullptr, feat_buffer);
102105
feat_buffer2 = std::move(feat_buffer);
103106
}
104-
_output_norm(feat_buffer2, feat_buffer);
105-
106-
(*_lm_head)(feat_buffer, output); //_lm_head(feat_buffer, output);
107+
if (_return_logits) {
108+
_output_norm(feat_buffer2, feat_buffer);
109+
(*_lm_head)(feat_buffer, output);
110+
}
111+
else {
112+
_output_norm(feat_buffer2, output);
113+
}
107114
}
108115
else { // backward compatibility for the previous converted model
109-
StorageView input(output_type(), features.device());
116+
StorageView input(features.dtype(), features.device());
110117
input = features;
111118
for (const auto& layer : _layers) {
112119
(*layer)(input, nullptr, output);

src/layers/wav2vec2bert.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ namespace ctranslate2 {
164164
}
165165

166166
Wav2Vec2BertEncoder::Wav2Vec2BertEncoder(const models::Model& model, const std::string& scope)
167-
: _fp_layer_norm(model, scope + "/fp_layer_norm")
167+
: _return_logits(model.get_variable_if_exists(scope + "/lm_head/weight"))
168+
, _fp_layer_norm(model, scope + "/fp_layer_norm")
168169
, _fp_projection(model, scope + "/fp_projection", nullptr, true)
169170
, _encoder_layers(build_layers_list<const EncoderLayer>(model,
170171
scope + "/encoder_layers",
@@ -175,8 +176,10 @@ namespace ctranslate2 {
175176
scope + "/adapter_layers",
176177
/*pre_norm=*/true,
177178
ops::ActivationType::ReLU,
178-
/*use_flash_attention=*/false))
179-
, _lm_head(model, scope + "/lm_head", nullptr, true) {
179+
/*use_flash_attention=*/false)) {
180+
if (_return_logits) {
181+
_lm_head.emplace(model, scope + "/lm_head", nullptr, true);
182+
}
180183
}
181184

182185
void Wav2Vec2BertEncoder::operator()(const StorageView& features, StorageView& output) {
@@ -203,7 +206,12 @@ namespace ctranslate2 {
203206
buffer2 = std::move(buffer1);
204207
}
205208

206-
_lm_head(buffer2, output);
209+
if (_return_logits) {
210+
(*_lm_head)(buffer2, output);
211+
}
212+
else {
213+
output = std::move(buffer2);
214+
}
207215
}
208216

209217
}

src/models/wav2vec2.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ namespace ctranslate2 {
7878
features.move_to(device, dtype);
7979

8080
StorageView encoder_output(dtype, device);
81-
(*_encoder)(features, encoder_output);
81+
if (_encoder->_upgraded_model) {
82+
encoder_output = maybe_encode(std::move(features));
83+
}
84+
else {
85+
(*_encoder)(features, encoder_output);
86+
}
8287

8388
if (to_cpu) {
8489
if (device != Device::CPU)

src/models/wav2vec2bert.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ namespace ctranslate2 {
7777
const DataType dtype = _encoder->output_type();
7878
features.move_to(device, dtype);
7979

80-
StorageView encoder_output(dtype, device);
81-
(*_encoder)(features, encoder_output);
80+
StorageView encoder_output = maybe_encode(std::move(features));
8281

8382
if (to_cpu) {
8483
if (device != Device::CPU)

0 commit comments

Comments (0)