Commit 86de50d

Whisper word-level timestamps (#184)
* Support outputting attentions in generate function
* Add unit tests for concatenating tensors
* Implement `cat` for `dim>0`
* Add `cat` unit tests for > 2 tensors
* Allow for negative indexing + bounds checking
* Add test case for `cat` with negative indexing
* Clean up `safeIndex` helper function
* Allow indexing error message to include dimension
* Reuse `safeIndex` helper function for `normalize_`
* Optimize `cat` indexing
* Implement `stack` tensor operation + add unit tests
* Add TODOs
* Implement `mean` tensor operation
* Implement `std_mean` tensor ops
* Fix order of `std_mean` returns
* Implement median filter
* Implement dynamic time warping
* Implement `neg` tensor op
* Throw error if audio sent to processor is not a `Float32Array`
* Add `round` helper function
* [WIP] Implement basic version of word-level timestamps. Known issues: timestamps not correct for index > 0; punctuation not the same as the Python version
* Fix typo
* Fix timestamps
* Round to 2 decimals
* Fix punctuation
* Fix typing
* Remove debug statements
* Cleanup code
* Cleanup
* Remove debug statements
* Update JSDoc for extract token timestamps function
* Add return type for `std_mean` tensor function
* Improve typing of private whisper tokenizer functions
* Indicate method is private
* Allow whisper feature extractor to be called with Float64Array input
* Fix typo
* Throw error if `cross_attentions` are not present in model output when extracting token timestamps
* Throw error during generate function
* Allow whisper models to be exported with `output_attentions=True`
* Add alignment heads to generation config
* Remove print statement
* Update versions
* Override protobufjs version
* Update package-lock.json
* Require onnx==1.13.1 for conversion (will update once onnxruntime-web supports onnx IR version 9)
* Add unit test for word-level timestamps
* Extract add attentions function out of `generate`
* Fix `findLongestCommonSequence` return types
* Downgrade back to onnxruntime 1.14.0 (1.15.1 is a little too unstable right now)
* Cleanup: use `.map`, rename variables
* Update comments
* Add examples for how to transcribe w/ word-level timestamps
* Add example for transcribing/translating audio longer than 30 seconds
* Make example more compact
1 parent aceab9b commit 86de50d
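
At the heart of the feature, token-level timestamps are recovered by median-filtering the cross-attention weights of the model's alignment heads and then running dynamic time warping (DTW) over the resulting token-by-frame cost matrix. Below is a minimal NumPy sketch of that DTW step, for illustration only: the commit ships an equivalent JavaScript implementation inside the library, and the names here are not the library's API.

```python
import numpy as np

def dtw_align(cost: np.ndarray):
    """Minimal dynamic time warping over a (num_tokens x num_frames) cost
    matrix (e.g. negated, filtered cross-attention weights). Returns a
    monotonic alignment as (token_indices, frame_indices)."""
    N, M = cost.shape
    # Cumulative cost, with an infinite border so every path starts at (0, 0).
    D = np.full((N + 1, M + 1), np.inf)
    D[0, 0] = 0.0
    trace = np.zeros((N + 1, M + 1), dtype=np.int8)  # 0=diagonal, 1=up, 2=left

    for i in range(1, N + 1):
        for j in range(1, M + 1):
            steps = (D[i - 1, j - 1], D[i - 1, j], D[i, j - 1])
            best = int(np.argmin(steps))
            D[i, j] = cost[i - 1, j - 1] + steps[best]
            trace[i, j] = best

    # Backtrack from the bottom-right corner to recover the path.
    i, j = N, M
    tokens, frames = [], []
    while i > 0 and j > 0:
        tokens.append(i - 1)
        frames.append(j - 1)
        if trace[i, j] == 0:
            i, j = i - 1, j - 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    return np.array(tokens[::-1]), np.array(frames[::-1])
```

Each frame index on the returned path maps to a timestamp (Whisper emits one attention frame per 20 ms of audio), and the per-token boundaries are then grouped into word boundaries by the tokenizer.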

File tree

14 files changed: +1893 −672 lines

package-lock.json

Lines changed: 479 additions & 505 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,6 +1,6 @@
 {
   "name": "@xenova/transformers",
-  "version": "2.3.1",
+  "version": "2.4.0",
   "description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
   "main": "./src/transformers.js",
   "types": "./types/transformers.d.ts",
@@ -57,6 +57,10 @@
     "webpack-cli": "^5.0.2",
     "webpack-dev-server": "^4.13.3"
   },
+  "overrides": {
+    "semver": "^7.5.4",
+    "protobufjs": "^7.2.4"
+  },
   "files": [
     "src",
     "dist",
```

scripts/convert.py

Lines changed: 67 additions & 43 deletions
```diff
@@ -11,7 +11,6 @@
     AutoTokenizer,
     HfArgumentParser
 )
-from transformers.utils import cached_file
 
 import onnx
 from optimum.exporters.onnx import main_export
@@ -21,6 +20,18 @@
     QuantType
 )
 
+DEFAULT_QUANTIZE_PARAMS = {
+    'per_channel': True,
+    'reduce_range': True,
+}
+
+MODEL_SPECIFIC_QUANTIZE_PARAMS = {
+    'whisper': {
+        'per_channel': False,
+        'reduce_range': False,
+    }
+}
+
 
 @dataclass
 class ConversionArguments:
@@ -79,18 +90,25 @@ class ConversionArguments:
     )
 
     per_channel: bool = field(
-        default=True,
+        default=None,
         metadata={
            "help": "Whether to quantize weights per channel"
        }
    )
    reduce_range: bool = field(
-        default=True,
+        default=None,
        metadata={
            "help": "Whether to quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode"
        }
    )
 
+    output_attentions: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to output attentions from the model. NOTE: This is only supported for whisper models right now."
+        }
+    )
+
 
 def get_operators(model: onnx.ModelProto) -> Set[str]:
     operators = set()
@@ -107,7 +125,7 @@ def traverse_graph(graph):
     return operators
 
 
-def quantize(model_names_or_paths, conv_args: ConversionArguments):
+def quantize(model_names_or_paths, **quantize_kwargs):
     """
     Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
 
@@ -119,9 +137,8 @@ def quantize(model_names_or_paths, conv_args: ConversionArguments):
    Returns: The Path generated for the quantized
    """
 
-    quant_config = dict(
-        per_channel=conv_args.per_channel,
-        reduce_range=conv_args.reduce_range,
+    quantize_config = dict(
+        **quantize_kwargs,
        per_model_config={}
    )
 
@@ -148,34 +165,25 @@ def quantize(model_names_or_paths, conv_args: ConversionArguments):
            model_output=os.path.join(
                directory_path, f'{file_name_without_extension}_quantized.onnx'),
 
-            per_channel=conv_args.per_channel,
-            reduce_range=conv_args.reduce_range,
-
            weight_type=weight_type,
            optimize_model=False,
 
            # TODO allow user to specify these
            # op_types_to_quantize=['MatMul', 'Add', 'Conv'],
            extra_options=dict(
                EnableSubgraph=True
-            )
+            ),
+            **quantize_kwargs
        )
 
-        quant_config['per_model_config'][file_name_without_extension] = dict(
+        quantize_config['per_model_config'][file_name_without_extension] = dict(
            op_types=list(op_types),
            weight_type=str(weight_type),
        )
 
    # Save quantization config
-    with open(os.path.join(directory_path, 'quant_config.json'), 'w') as fp:
-        json.dump(quant_config, fp, indent=4)
-
-
-def copy_if_exists(model_path, file_name, destination):
-    file = cached_file(model_path, file_name,
-                       _raise_exceptions_for_missing_entries=False)
-    if file is not None:
-        shutil.copy(file, destination)
+    with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp:
+        json.dump(quantize_config, fp, indent=4)
 
 
 def main():
@@ -192,35 +200,18 @@ def main():
    # Create output folder
    os.makedirs(output_model_folder, exist_ok=True)
 
-    # Copy certain JSON files, which save_pretrained doesn't handle
-    # copy_if_exists(model_id, 'tokenizer.json', output_model_folder)
-
-    # copy_if_exists(model_id, 'preprocessor_config.json', output_model_folder)
-    # copy_if_exists(model_id, 'generation_config.json', output_model_folder)
-
-    # # Saving the model config
+    # Saving the model config
    config = AutoConfig.from_pretrained(model_id)
-    # config.save_pretrained(output_model_folder)
 
+    tokenizer = None
    try:
        # Save tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        # tokenizer.save_pretrained(output_model_folder)
-
-        # Handle special cases
-        if config.model_type == 'marian':
-            import json
-            from .extra.marian import generate_tokenizer_json
-            tokenizer_json = generate_tokenizer_json(model_id, tokenizer)
-
-            with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
-                json.dump(tokenizer_json, fp)
 
    except KeyError:
        pass  # No Tokenizer
 
-    # Step 1. convert huggingface model to onnx
-    main_export(
+    export_kwargs = dict(
        model_name_or_path=model_id,
        output=output_model_folder,
        task=conv_args.task,
@@ -229,21 +220,54 @@ def main():
        do_validation=not conv_args.skip_validation,
    )
 
+    # Handle special cases
+    if config.model_type == 'marian':
+        from .extra.marian import generate_tokenizer_json
+        tokenizer_json = generate_tokenizer_json(model_id, tokenizer)
+
+        with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
+            json.dump(tokenizer_json, fp)
+
+    elif config.model_type == 'whisper':
+        if conv_args.output_attentions:
+            from .extra.whisper import get_main_export_kwargs
+
+            export_kwargs.update(
+                **get_main_export_kwargs(config, "automatic-speech-recognition")
+            )
+    else:
+        pass  # TODO
+
+    # Step 1. convert huggingface model to onnx
+    main_export(**export_kwargs)
+
    # Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size.
    if conv_args.quantize:
+        # Update quantize config with model specific defaults
+        quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
+            config.model_type, DEFAULT_QUANTIZE_PARAMS)
+
        quantize([
            os.path.join(output_model_folder, x)
            for x in os.listdir(output_model_folder)
            if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
-        ], conv_args)
+        ], **quantize_config)
 
    # Step 3. Move .onnx files to the 'onnx' subfolder
    os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
    for file in os.listdir(output_model_folder):
-        if file.endswith('.onnx') or file.endswith('.onnx_data'):
+        if file.endswith(('.onnx', '.onnx_data')):
            shutil.move(os.path.join(output_model_folder, file),
                        os.path.join(output_model_folder, 'onnx', file))
 
+    # Step 4. Update the generation config if necessary
+    if config.model_type == 'whisper':
+        from transformers import GenerationConfig
+        from .extra.whisper import get_alignment_heads
+
+        generation_config = GenerationConfig.from_pretrained(model_id)
+        generation_config.alignment_heads = get_alignment_heads(config)
+        generation_config.save_pretrained(output_model_folder)
 
 if __name__ == '__main__':
    main()
```
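
A subtlety in the new quantization defaults: the model-specific parameters are selected with a single `dict.get`, so they replace the defaults wholesale rather than merging key-by-key. A small self-contained sketch of that resolution logic, mirroring the dicts defined in the diff above (the helper name is illustrative, not part of the script):

```python
DEFAULT_QUANTIZE_PARAMS = {'per_channel': True, 'reduce_range': True}
MODEL_SPECIFIC_QUANTIZE_PARAMS = {
    'whisper': {'per_channel': False, 'reduce_range': False},
}

def resolve_quantize_params(model_type: str) -> dict:
    # Whole-dict fallback: a model type either uses its own full config
    # or the defaults; per-key merging does not happen.
    return MODEL_SPECIFIC_QUANTIZE_PARAMS.get(model_type, DEFAULT_QUANTIZE_PARAMS)

assert resolve_quantize_params('whisper') == {'per_channel': False, 'reduce_range': False}
assert resolve_quantize_params('bert') == {'per_channel': True, 'reduce_range': True}
```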

scripts/extra/whisper.py

Lines changed: 76 additions & 0 deletions
```python
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig

from optimum.exporters.onnx.base import ConfigBehavior
from typing import Dict

# List of [layer, head] pairs that select the cross-attention heads that are highly correlated to word-level timing.
# Source: https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a
ALIGNMENT_HEADS_MAPPING = {
    'whisper-tiny.en': [[1, 0], [2, 0], [2, 5], [3, 0], [3, 1], [3, 2], [3, 3], [3, 4]],
    'whisper-tiny': [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]],
    'whisper-base.en': [[3, 3], [4, 7], [5, 1], [5, 5], [5, 7]],
    'whisper-base': [[3, 1], [4, 2], [4, 3], [4, 7], [5, 1], [5, 2], [5, 4], [5, 6]],
    'whisper-small.en': [[6, 6], [7, 0], [7, 3], [7, 8], [8, 2], [8, 5], [8, 7], [9, 0], [9, 4], [9, 8], [9, 10], [10, 0], [10, 1], [10, 2], [10, 3], [10, 6], [10, 11], [11, 2], [11, 4]],
    'whisper-small': [[5, 3], [5, 9], [8, 0], [8, 4], [8, 7], [8, 8], [9, 0], [9, 7], [9, 9], [10, 5]],
    'whisper-medium.en': [[11, 4], [14, 1], [14, 12], [14, 14], [15, 4], [16, 0], [16, 4], [16, 9], [17, 12], [17, 14], [18, 7], [18, 10], [18, 15], [20, 0], [20, 3], [20, 9], [20, 14], [21, 12]],
    'whisper-medium': [[13, 15], [15, 4], [15, 15], [16, 1], [20, 0], [23, 4]],
    'whisper-large-v2': [[10, 12], [13, 17], [16, 11], [16, 12], [16, 13], [17, 15], [17, 16], [18, 4], [18, 11], [18, 19], [19, 11], [21, 2], [21, 3], [22, 3], [22, 9], [22, 12], [23, 5], [23, 7], [23, 13], [25, 5], [26, 1], [26, 12], [27, 15]],
    'whisper-large': [[9, 19], [11, 2], [11, 4], [11, 17], [22, 7], [22, 11], [22, 17], [23, 2], [23, 15]],
}


class CustomWhisperOnnxConfig(WhisperOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = super().outputs

        if self._behavior is ConfigBehavior.ENCODER:
            for i in range(self._config.encoder_layers):
                common_outputs[f"encoder_attentions.{i}"] = {0: "batch_size"}
        elif self._behavior is ConfigBehavior.DECODER:
            for i in range(self._config.decoder_layers):
                common_outputs[f"decoder_attentions.{i}"] = {
                    0: "batch_size", 3: "decoder_sequence_length"}
            for i in range(self._config.decoder_layers):
                common_outputs[f"cross_attentions.{i}"] = {
                    0: "batch_size", 3: "cross_attention_length"}

        return common_outputs

    @property
    def torch_to_onnx_output_map(self):
        if self._behavior is ConfigBehavior.ENCODER:
            # The encoder export uses WhisperEncoder that returns the key "attentions"
            return {"attentions": "encoder_attentions"}
        else:
            return {}


def get_main_export_kwargs(config, task):

    custom_config = CustomWhisperOnnxConfig(config=config, task=task)

    custom_onnx_configs = dict(
        encoder_model=custom_config.with_behavior("encoder"),
        decoder_model=custom_config.with_behavior("decoder", use_past=False),
        decoder_with_past_model=custom_config.with_behavior(
            "decoder", use_past=True),
    )

    return dict(
        model_kwargs={"output_attentions": True},
        custom_onnx_configs=custom_onnx_configs,
    )


def get_alignment_heads(config):
    if getattr(config, '_name_or_path', None) is None:
        raise ValueError(
            "Unable to determine model type from config. Please specify `_name_or_path` in the config.")

    for model_name, heads in ALIGNMENT_HEADS_MAPPING.items():
        if model_name in config._name_or_path:
            return heads

    raise ValueError(
        f"Unknown model type: {config._name_or_path}. Please add one of the following model types to `_name_or_path` in the config file: {list(ALIGNMENT_HEADS_MAPPING.keys())}")
```

scripts/requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,3 +1,4 @@
 transformers[torch]@git+https://github.com/huggingface/transformers
 optimum[onnxruntime]@git+https://github.com/huggingface/optimum
 tqdm
+onnx==1.13.1
```
