Skip to content

Commit 132258e

Browse files
committed
fix: issues with lite-whisper models
1 parent 76c0bb4 commit 132258e

File tree

9 files changed

+141
-90
lines changed

9 files changed

+141
-90
lines changed

include/ctranslate2/layers/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ namespace ctranslate2 {
150150
const models::QUANTIZATION_TYPE _quant_method;
151151
const bool _quantized_gemm;
152152
const ops::Gemm _gemm_op;
153+
const ops::Gemm _gemm_op_low_rank;
153154
const ops::Quantize _quantize_op;
154155
const ops::Dequantize _dequantize_op;
155156
const ops::ActivationType* _activation_type;

python/ctranslate2/converters/transformers.py

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import gc
44
import itertools
55
import os
6-
import re
76

87
from typing import List, Optional
98

@@ -97,13 +96,6 @@ def __init__(
9796
trust_remote_code: Allow converting models using custom code.
9897
"""
9998
self._model_name_or_path = model_name_or_path
100-
self._model_processor_name = model_name_or_path
101-
if model_name_or_path.startswith('efficient-speech/lite-whisper'):
102-
# If this is a lite-whisper model, use openai's
103-
# corresponding preprocessor.
104-
regex = r'whisper-[a-z0-9-]+?(?=-(?:fast|acc)|$)'
105-
regex_result = re.search(regex, model_name_or_path)
106-
self._model_processor_name = f"openai/{regex_result.group()}"
10799
self._activation_scales = activation_scales
108100
self._copy_files = copy_files
109101
self._load_as_float16 = load_as_float16
@@ -127,6 +119,14 @@ def _load(self):
127119
% (config_name, ", ".join(sorted(_MODEL_LOADERS.keys())))
128120
)
129121

122+
# If this is a lite-whisper model, use the corresponding OpenAI tokenizer
123+
if config.model_type == "lite-whisper":
124+
base_name = self._model_name_or_path.split("/")[-1] # e.g., "lite-whisper-large-v3"
125+
base_name = base_name.replace("lite-", "") # e.g., "whisper-large-v3"
126+
tokenizer_path = f"openai/{base_name}"
127+
else:
128+
tokenizer_path = self._model_name_or_path
129+
130130
tokenizer_class = transformers.AutoTokenizer
131131

132132
kwargs = {
@@ -147,18 +147,15 @@ def _load(self):
147147
if hasattr(transformers, loader.architecture_name):
148148
model_class = getattr(transformers, loader.architecture_name)
149149
model = self.load_model(model_class, self._model_name_or_path, **kwargs)
150-
elif self._model_name_or_path.startswith('efficient-speech/lite-whisper'):
151-
model = transformers.AutoModel.from_pretrained(self._model_name_or_path, **kwargs)
152150
else:
153-
raise ValueError(
154-
"The model %s is not supported by the converter. " % self._model_name_or_path)
151+
model = transformers.AutoModel.from_pretrained(self._model_name_or_path, **kwargs)
155152

156153
tokenizer_kwargs = {}
157154
if self._trust_remote_code:
158155
tokenizer_kwargs["trust_remote_code"] = self._trust_remote_code
159156

160157
tokenizer = self.load_tokenizer(
161-
tokenizer_class, self._model_processor_name, **tokenizer_kwargs
158+
tokenizer_class, tokenizer_path, **tokenizer_kwargs
162159
)
163160

164161
spec = loader(model, tokenizer)
@@ -251,19 +248,6 @@ def set_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
251248
spec.weight = spec.weight.transpose(0, 1)
252249
if module.bias is not None:
253250
spec.bias = module.bias
254-
255-
def set_low_rank_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
256-
if quant_type == common_spec.Quantization.CT2:
257-
spec.low_rank_weight_1 = module.weight1
258-
spec.low_rank_weight_2 = module.weight2
259-
else:
260-
spec.low_rank_weight_1 = module.qweight1
261-
spec.low_rank_weight_2 = module.qweight2
262-
spec.weight_scale = module.scales
263-
spec.weight_zero = module.qzeros
264-
265-
if module.bias is not None:
266-
spec.bias = module.bias
267251

268252
def set_embeddings(self, spec, module):
269253
spec.weight = module.weight
@@ -1044,10 +1028,45 @@ def get_model_spec(self, model):
10441028

10451029
return spec
10461030

1031+
1032+
def set_config(self, config, model, tokenizer):
1033+
gen_config = getattr(model, "generation_config", None)
1034+
1035+
if gen_config is not None:
1036+
config.suppress_ids = gen_config.suppress_tokens
1037+
config.suppress_ids_begin = gen_config.begin_suppress_tokens
1038+
if hasattr(gen_config, "alignment_heads"):
1039+
config.alignment_heads = gen_config.alignment_heads
1040+
if hasattr(gen_config, "lang_to_id"):
1041+
config.lang_ids = sorted(gen_config.lang_to_id.values())
1042+
else:
1043+
config.suppress_ids = model.config.suppress_tokens
1044+
config.suppress_ids_begin = model.config.begin_suppress_tokens
1045+
config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
1046+
1047+
if getattr(config, "lang_ids", None) is None:
1048+
config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)
1049+
1050+
if config.alignment_heads is None:
1051+
config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
1052+
if config.alignment_heads is None:
1053+
# By default, use the last half of the decoder layers for alignment.
1054+
num_layers = model.config.decoder_layers
1055+
num_heads = model.config.decoder_attention_heads
1056+
config.alignment_heads = list(
1057+
itertools.product(
1058+
range(num_layers // 2, num_layers),
1059+
range(num_heads),
1060+
)
1061+
)
1062+
10471063
def set_encoder(self, spec, encoder):
1064+
"""
1065+
Override encoder mapping for LiteWhisper.
1066+
"""
10481067
self.set_conv1d(spec.conv1, encoder.conv1)
10491068
self.set_conv1d(spec.conv2, encoder.conv2)
1050-
1069+
10511070
self.set_common_layers(spec, encoder)
10521071

10531072
for layer_spec, layer in zip(spec.layer, encoder.layers):
@@ -1060,29 +1079,42 @@ def set_encoder(self, spec, encoder):
10601079
layer.self_attn_layer_norm,
10611080
)
10621081

1063-
# Double check if these are low rank or not because of potential
1064-
# fall backs to full precision.
1065-
if hasattr(layer.fc1, 'weight1'):
1082+
if hasattr(layer.fc1, "weight1"):
1083+
# low rank
10661084
self.set_low_rank_linear(layer_spec.ffn.linear_0, layer.fc1)
10671085
else:
10681086
layer_spec.ffn.linear_0 = common_spec.LinearSpec()
10691087
self.set_linear(layer_spec.ffn.linear_0, layer.fc1)
1070-
1071-
if hasattr(layer.fc2, 'weight1'):
1088+
1089+
if hasattr(layer.fc2, "weight1"):
1090+
# low rank
10721091
self.set_low_rank_linear(layer_spec.ffn.linear_1, layer.fc2)
10731092
else:
10741093
layer_spec.ffn.linear_1 = common_spec.LinearSpec()
10751094
self.set_linear(layer_spec.ffn.linear_1, layer.fc2)
10761095

10771096
self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)
10781097

1098+
def set_low_rank_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
1099+
if quant_type == common_spec.Quantization.CT2:
1100+
spec.low_rank_weight_1 = module.weight1.transpose(0, 1).contiguous()
1101+
spec.low_rank_weight_2 = module.weight2.transpose(0, 1).contiguous()
1102+
else:
1103+
spec.low_rank_weight_1 = module.qweight1.transpose(0, 1).contiguous()
1104+
spec.low_rank_weight_2 = module.qweight2.transpose(0, 1).contiguous()
1105+
spec.weight_scale = module.scales
1106+
spec.weight_zero = module.qzeros
1107+
1108+
if module.bias is not None:
1109+
spec.bias = module.bias
1110+
10791111
def set_low_rank_or_linear_router(self, spec, module, i):
10801112
if hasattr(module, "weight1"):
10811113
self.set_low_rank_linear(spec.linear[i], module)
10821114
else:
10831115
spec.linear[i] = common_spec.LinearSpec()
10841116
self.set_linear(spec.linear[i], module)
1085-
1117+
10861118
def set_low_rank_attention(self, spec, attention):
10871119
self.set_low_rank_or_linear_router(spec, attention.q_proj, 0)
10881120
self.set_low_rank_or_linear_router(spec, attention.k_proj, 1)
@@ -3000,6 +3032,7 @@ def main():
30003032
(3, 4),
30013033
],
30023034
"openai/whisper-tiny": [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)],
3035+
"efficient-speech/whisper-tiny": [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)],
30033036
"openai/whisper-base.en": [(3, 3), (4, 7), (5, 1), (5, 5), (5, 7)],
30043037
"openai/whisper-base": [
30053038
(3, 1),
@@ -3113,4 +3146,16 @@ def main():
31133146
(24, 1),
31143147
(25, 6),
31153148
],
3149+
"efficient-speech/whisper-large-v3": [
3150+
(7, 0),
3151+
(10, 17),
3152+
(12, 18),
3153+
(13, 12),
3154+
(16, 1),
3155+
(17, 14),
3156+
(19, 11),
3157+
(21, 4),
3158+
(24, 1),
3159+
(25, 6),
3160+
],
31163161
}

python/ctranslate2/specs/attention_spec.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,9 @@ def __init__(
3737
self.queries_scale = model_spec.OPTIONAL
3838

3939
self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
40-
if low_rank:
41-
self.linear = [common_spec.LowRankLinearSpec() for _ in range(4)]
42-
else:
43-
self.linear = [
44-
common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
45-
]
40+
linear_cls = common_spec.LinearLowRankSpec if low_rank else common_spec.LinearSpec
41+
count = 4 if low_rank else (2 if self_attention else 3)
42+
self.linear = [linear_cls() for _ in range(count)]
4643

4744
if relative_position:
4845
self.relative_position_keys = None

python/ctranslate2/specs/common_spec.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,6 @@ def __init__(self):
5151
def has_bias(self):
5252
return not isinstance(self.bias, str)
5353

54-
class LowRankLinearSpec(model_spec.LayerSpec):
55-
def __init__(self):
56-
super().__init__()
57-
self.low_rank_weight_1 = None
58-
self.low_rank_weight_2 = None
59-
self.weight_scale = model_spec.OPTIONAL
60-
self.weight_zero = model_spec.OPTIONAL
61-
self.bias = model_spec.OPTIONAL
62-
63-
def has_bias(self):
64-
return not isinstance(self.bias, str)
65-
6654

6755
class Conv1DSpec(model_spec.LayerSpec):
6856
def __init__(self):
@@ -76,3 +64,15 @@ def __init__(self):
7664
self.weight = None
7765
self.weight_scale = model_spec.OPTIONAL
7866
self.multiply_by_sqrt_depth = model_spec.OPTIONAL
67+
68+
69+
class LinearLowRankSpec(model_spec.LayerSpec):
70+
def __init__(self):
71+
self.low_rank_weight_1 = None
72+
self.low_rank_weight_2 = None
73+
self.weight_scale = model_spec.OPTIONAL
74+
self.weight_zero = model_spec.OPTIONAL
75+
self.bias = model_spec.OPTIONAL
76+
77+
def has_bias(self):
78+
return not isinstance(self.bias, str)

python/ctranslate2/specs/transformer_spec.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def __init__(
253253
rms_norm=False,
254254
num_heads_kv=None,
255255
sliding_window=None,
256-
low_rank=False
256+
low_rank=False,
257257
):
258258
self.self_attention = attention_spec.MultiHeadAttentionSpec(
259259
self_attention=True,
@@ -344,8 +344,9 @@ def __init__(
344344
class FeedForwardSpec(model_spec.LayerSpec):
345345
def __init__(self, glu=False, rms_norm=False, low_rank=False):
346346
self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
347-
self.linear_0 = common_spec.LinearSpec() if not low_rank else common_spec.LowRankLinearSpec()
348-
self.linear_1 = common_spec.LinearSpec() if not low_rank else common_spec.LowRankLinearSpec()
347+
linear_cls = common_spec.LinearLowRankSpec if low_rank else common_spec.LinearSpec
348+
self.linear_0 = linear_cls()
349+
self.linear_1 = linear_cls()
349350
if glu:
350351
self.linear_0_noact = common_spec.LinearSpec()
351352

python/ctranslate2/specs/whisper_spec.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ def __init__(
4141
num_encoder_heads: The number of encoder attention heads.
4242
num_decoder_layers: The number of decoder layers.
4343
num_decoder_heads: The number of decoder attention heads.
44+
low_rank: Whether the model is a lite-whisper (low-rank) variant.
4445
"""
4546
super().__init__()
46-
self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads, low_rank)
47+
self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads, low_rank=low_rank)
4748
self.decoder = transformer_spec.TransformerDecoderSpec(
4849
num_decoder_layers,
4950
num_decoder_heads,

src/layers/attention.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -360,13 +360,12 @@ namespace ctranslate2 {
360360
q = &queries_proj;
361361
}
362362

363-
if (!_is_low_rank) {
364-
_linear[0](*q, fused_proj);
365-
} else {
366-
// Low-rank attention does not fuse qkv.
367-
_linear[0](*q, queries_proj);
363+
_linear[0](*q, fused_proj);
364+
365+
if (_is_low_rank) { // support low-rank
368366
_linear[1](*q, keys_proj);
369367
_linear[2](*q, values_proj);
368+
queries_proj = std::move(fused_proj);
370369
}
371370

372371
dim_t beam_size = 1;
@@ -375,7 +374,7 @@ namespace ctranslate2 {
375374

376375
if (!_self_attention) {
377376
if (_is_low_rank)
378-
throw std::invalid_argument("MultiHeadAttention does not support low-rank attention with cross-attention");
377+
throw std::invalid_argument("lite whisper doesn't use low-rank for cross-attention");
379378
queries_proj = std::move(fused_proj);
380379

381380
if (cached_keys == nullptr || cached_keys->empty()) {
@@ -411,7 +410,7 @@ namespace ctranslate2 {
411410

412411
if (_num_heads_kv < _num_heads) {
413412
if (_is_low_rank)
414-
throw std::invalid_argument("MutliHeadAttention does not support low-rank attention with multi-query or GQA");
413+
throw std::invalid_argument("lite whisper doesn't use low-rank for multi-query or GQA");
415414
if (queries_padder)
416415
queries_padder->add_padding(fused_proj);
417416

@@ -430,10 +429,11 @@ namespace ctranslate2 {
430429
}
431430

432431
} else {
433-
if (!_is_low_rank) {
432+
if (!_is_low_rank){
434433
split_heads(fused_proj, 3 * _num_heads, queries_padder);
435434
ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
436-
} else {
435+
}
436+
else{
437437
split_heads(queries_proj, _num_heads, queries_padder);
438438
split_heads(keys_proj, _num_heads_kv, queries_padder);
439439
split_heads(values_proj, _num_heads_kv, queries_padder);

src/layers/attention_layer.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,24 @@ namespace ctranslate2 {
5252
}
5353

5454
static bool set_low_rank(const models::Model& model, const std::string& scope) {
55-
const StorageView* low_rank_weight = model.get_variable_if_exists(scope + "/linear_0/low_rank_weight_1");
56-
if (low_rank_weight) {
57-
return true;
55+
const dim_t max_layers = 4;
56+
for (int i = 0; i < max_layers; ++i) {
57+
std::string prefix = scope + "/linear_" + std::to_string(i);
58+
const StorageView* w1 = model.get_variable_if_exists(prefix + "/low_rank_weight_1");
59+
const StorageView* w2 = model.get_variable_if_exists(prefix + "/low_rank_weight_2");
60+
if (w1 && w2) {
61+
return true;
62+
}
5863
}
64+
// If no low-rank pair is found, then it is not low-rank
5965
return false;
6066
}
6167

6268
static std::vector<Dense> make_linear_layers(const models::Model& model,
6369
const std::string& scope,
6470
bool self_attention,
6571
bool _is_low_rank) {
66-
dim_t num_linear_layers;
67-
if (!_is_low_rank) {
68-
num_linear_layers = self_attention ? 2 : 3;
69-
} else {
70-
num_linear_layers = 4;
71-
}
72-
72+
const dim_t num_linear_layers = !_is_low_rank ? (self_attention ? 2 : 3) : 4;
7373
std::vector<Dense> layers;
7474
layers.reserve(num_linear_layers);
7575
for (dim_t i = 0; i < num_linear_layers; ++i)

0 commit comments

Comments
 (0)