33import gc
44import itertools
55import os
6- import re
76
87from typing import List , Optional
98
@@ -97,13 +96,6 @@ def __init__(
9796 trust_remote_code: Allow converting models using custom code.
9897 """
9998 self ._model_name_or_path = model_name_or_path
100- self ._model_processor_name = model_name_or_path
101- if model_name_or_path .startswith ('efficient-speech/lite-whisper' ):
102- # If this is a lite-whisper model, use openai's
103- # corresponding preprocessor.
104- regex = r'whisper-[a-z0-9-]+?(?=-(?:fast|acc)|$)'
105- regex_result = re .search (regex , model_name_or_path )
106- self ._model_processor_name = f"openai/{ regex_result .group ()} "
10799 self ._activation_scales = activation_scales
108100 self ._copy_files = copy_files
109101 self ._load_as_float16 = load_as_float16
@@ -127,6 +119,14 @@ def _load(self):
127119 % (config_name , ", " .join (sorted (_MODEL_LOADERS .keys ())))
128120 )
129121
122+ # If lite whisper use corresponding openai tokenizer
123+ if config .model_type == "lite-whisper" :
124+ base_name = self ._model_name_or_path .split ("/" )[- 1 ] # e.g., "lite-whisper-large-v3"
125+ base_name = base_name .replace ("lite-" , "" ) # e.g., "whisper-large-v3"
126+ tokenizer_path = f"openai/{ base_name } "
127+ else :
128+ tokenizer_path = self ._model_name_or_path
129+
130130 tokenizer_class = transformers .AutoTokenizer
131131
132132 kwargs = {
@@ -147,18 +147,15 @@ def _load(self):
147147 if hasattr (transformers , loader .architecture_name ):
148148 model_class = getattr (transformers , loader .architecture_name )
149149 model = self .load_model (model_class , self ._model_name_or_path , ** kwargs )
150- elif self ._model_name_or_path .startswith ('efficient-speech/lite-whisper' ):
151- model = transformers .AutoModel .from_pretrained (self ._model_name_or_path , ** kwargs )
152150 else :
153- raise ValueError (
154- "The model %s is not supported by the converter. " % self ._model_name_or_path )
151+ model = transformers .AutoModel .from_pretrained (self ._model_name_or_path , ** kwargs )
155152
156153 tokenizer_kwargs = {}
157154 if self ._trust_remote_code :
158155 tokenizer_kwargs ["trust_remote_code" ] = self ._trust_remote_code
159156
160157 tokenizer = self .load_tokenizer (
161- tokenizer_class , self . _model_processor_name , ** tokenizer_kwargs
158+ tokenizer_class , tokenizer_path , ** tokenizer_kwargs
162159 )
163160
164161 spec = loader (model , tokenizer )
@@ -251,19 +248,6 @@ def set_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
251248 spec .weight = spec .weight .transpose (0 , 1 )
252249 if module .bias is not None :
253250 spec .bias = module .bias
254-
255- def set_low_rank_linear (self , spec , module , quant_type = common_spec .Quantization .CT2 ):
256- if quant_type == common_spec .Quantization .CT2 :
257- spec .low_rank_weight_1 = module .weight1
258- spec .low_rank_weight_2 = module .weight2
259- else :
260- spec .low_rank_weight_1 = module .qweight1
261- spec .low_rank_weight_2 = module .qweight2
262- spec .weight_scale = module .scales
263- spec .weight_zero = module .qzeros
264-
265- if module .bias is not None :
266- spec .bias = module .bias
267251
268252 def set_embeddings (self , spec , module ):
269253 spec .weight = module .weight
@@ -1044,10 +1028,45 @@ def get_model_spec(self, model):
10441028
10451029 return spec
10461030
1031+
def set_config(self, config, model, tokenizer):
    """Fill the converter ``config`` from the loaded model and tokenizer.

    Suppressed token ids are read from ``model.generation_config`` when the
    model exposes one, otherwise from ``model.config``. Language ids and
    alignment heads are taken from the generation config when it provides
    them and are otherwise resolved through the fallbacks below.

    Args:
        config: Object receiving ``suppress_ids``, ``suppress_ids_begin``,
            ``lang_ids`` and ``alignment_heads``.
        model: Loaded model; ``generation_config``, ``config`` and
            ``name_or_path`` are read from it as needed.
        tokenizer: Passed to ``self._get_lang_ids_from_tokenizer`` when no
            language ids could be read from the generation config.
    """
    gen_config = getattr(model, "generation_config", None)

    if gen_config is not None:
        config.suppress_ids = gen_config.suppress_tokens
        config.suppress_ids_begin = gen_config.begin_suppress_tokens
        if hasattr(gen_config, "alignment_heads"):
            config.alignment_heads = gen_config.alignment_heads
        # The attribute can be present but unset; only use a real mapping,
        # otherwise fall through to the tokenizer-based lookup.
        lang_to_id = getattr(gen_config, "lang_to_id", None)
        if lang_to_id is not None:
            config.lang_ids = sorted(lang_to_id.values())
    else:
        config.suppress_ids = model.config.suppress_tokens
        config.suppress_ids_begin = model.config.begin_suppress_tokens
        # No generation config: discard any previous value so the
        # alignment heads are resolved by the fallback chain below.
        config.alignment_heads = None

    if getattr(config, "lang_ids", None) is None:
        config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)

    # Resolve alignment heads: explicit value, then the per-model table,
    # then a computed default.
    heads = getattr(config, "alignment_heads", None)
    if heads is None:
        heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
    if heads is None:
        # Use the last half layers for alignment by default.
        num_layers = model.config.decoder_layers
        num_heads = model.config.decoder_attention_heads
        heads = list(
            itertools.product(
                range(num_layers // 2, num_layers),
                range(num_heads),
            )
        )
    config.alignment_heads = heads
1062+
10471063 def set_encoder (self , spec , encoder ):
1064+ """
1065+ Override encoder mapping for LiteWhisper.
1066+ """
10481067 self .set_conv1d (spec .conv1 , encoder .conv1 )
10491068 self .set_conv1d (spec .conv2 , encoder .conv2 )
1050-
1069+
10511070 self .set_common_layers (spec , encoder )
10521071
10531072 for layer_spec , layer in zip (spec .layer , encoder .layers ):
@@ -1060,29 +1079,42 @@ def set_encoder(self, spec, encoder):
10601079 layer .self_attn_layer_norm ,
10611080 )
10621081
1063- # Double check if these are low rank or not because of potential
1064- # fall backs to full precision.
1065- if hasattr (layer .fc1 , 'weight1' ):
1082+ if hasattr (layer .fc1 , "weight1" ):
1083+ # low rank
10661084 self .set_low_rank_linear (layer_spec .ffn .linear_0 , layer .fc1 )
10671085 else :
10681086 layer_spec .ffn .linear_0 = common_spec .LinearSpec ()
10691087 self .set_linear (layer_spec .ffn .linear_0 , layer .fc1 )
1070-
1071- if hasattr (layer .fc2 , 'weight1' ):
1088+
1089+ if hasattr (layer .fc2 , "weight1" ):
1090+ # low rank
10721091 self .set_low_rank_linear (layer_spec .ffn .linear_1 , layer .fc2 )
10731092 else :
10741093 layer_spec .ffn .linear_1 = common_spec .LinearSpec ()
10751094 self .set_linear (layer_spec .ffn .linear_1 , layer .fc2 )
10761095
10771096 self .set_layer_norm (layer_spec .ffn .layer_norm , layer .final_layer_norm )
10781097
def set_low_rank_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
    """Copy the two low-rank factors of ``module`` onto ``spec``.

    Each factor is transposed over its first two dimensions and made
    contiguous before being stored. For the ``CT2`` quantization type the
    factors are read from ``weight1``/``weight2``; for any other type they
    are read from ``qweight1``/``qweight2`` and the module's ``scales`` and
    ``qzeros`` are copied as well. The bias is copied when it is not None.
    """

    def _swap_axes(tensor):
        # Swap dims 0 and 1 and force a contiguous layout.
        return tensor.transpose(0, 1).contiguous()

    uses_ct2 = quant_type == common_spec.Quantization.CT2
    if not uses_ct2:
        spec.low_rank_weight_1 = _swap_axes(module.qweight1)
        spec.low_rank_weight_2 = _swap_axes(module.qweight2)
        spec.weight_scale = module.scales
        spec.weight_zero = module.qzeros
    else:
        spec.low_rank_weight_1 = _swap_axes(module.weight1)
        spec.low_rank_weight_2 = _swap_axes(module.weight2)

    if module.bias is not None:
        spec.bias = module.bias
1110+
def set_low_rank_or_linear_router(self, spec, module, i):
    """Populate ``spec.linear[i]`` from ``module``, low-rank or dense.

    A module exposing a ``weight1`` attribute is handled by
    ``set_low_rank_linear`` on the existing ``spec.linear[i]``. Any other
    module gets a fresh ``LinearSpec`` in that slot, filled by
    ``set_linear``.
    """
    if hasattr(module, "weight1"):
        self.set_low_rank_linear(spec.linear[i], module)
        return

    # No low-rank factors: replace the slot with a plain linear spec.
    spec.linear[i] = common_spec.LinearSpec()
    self.set_linear(spec.linear[i], module)
1085-
1117+
10861118 def set_low_rank_attention (self , spec , attention ):
10871119 self .set_low_rank_or_linear_router (spec , attention .q_proj , 0 )
10881120 self .set_low_rank_or_linear_router (spec , attention .k_proj , 1 )
@@ -3000,6 +3032,7 @@ def main():
30003032 (3 , 4 ),
30013033 ],
30023034 "openai/whisper-tiny" : [(2 , 2 ), (3 , 0 ), (3 , 2 ), (3 , 3 ), (3 , 4 ), (3 , 5 )],
3035+ "efficient-speech/whisper-tiny" : [(2 , 2 ), (3 , 0 ), (3 , 2 ), (3 , 3 ), (3 , 4 ), (3 , 5 )],
30033036 "openai/whisper-base.en" : [(3 , 3 ), (4 , 7 ), (5 , 1 ), (5 , 5 ), (5 , 7 )],
30043037 "openai/whisper-base" : [
30053038 (3 , 1 ),
@@ -3113,4 +3146,16 @@ def main():
31133146 (24 , 1 ),
31143147 (25 , 6 ),
31153148 ],
3149+ "efficient-speech/whisper-large-v3" : [
3150+ (7 , 0 ),
3151+ (10 , 17 ),
3152+ (12 , 18 ),
3153+ (13 , 12 ),
3154+ (16 , 1 ),
3155+ (17 , 14 ),
3156+ (19 , 11 ),
3157+ (21 , 4 ),
3158+ (24 , 1 ),
3159+ (25 , 6 ),
3160+ ],
31163161}
0 commit comments