Skip to content

Commit 62159eb

Browse files
committed
Fix CustomWhisperOnnxConfig
1 parent 0c2dcc7 commit 62159eb

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

scripts/extra/whisper.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,16 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
3030
elif self._behavior is ConfigBehavior.DECODER:
3131
for i in range(self._config.decoder_layers):
3232
common_outputs[f"decoder_attentions.{i}"] = {
33-
0: "batch_size", 3: "decoder_sequence_length"}
33+
0: "batch_size",
34+
2: "decoder_sequence_length",
35+
3: "past_decoder_sequence_length + 1"
36+
}
3437
for i in range(self._config.decoder_layers):
3538
common_outputs[f"cross_attentions.{i}"] = {
36-
0: "batch_size", 3: "cross_attention_length"}
39+
0: "batch_size",
40+
2: "decoder_sequence_length",
41+
3: "encoder_sequence_length_out"
42+
}
3743

3844
return common_outputs
3945

@@ -48,6 +54,7 @@ def torch_to_onnx_output_map(self):
4854

4955
def get_main_export_kwargs(config, task):
5056

57+
# See https://github.com/huggingface/optimum/blob/a39b1f5637af9725c0c788b86ca1fdf71ad3dcc2/docs/source/exporters/onnx/usage_guides/export_a_model.mdx#L264
5158
custom_config = CustomWhisperOnnxConfig(config=config, task=task)
5259

5360
custom_onnx_configs = dict(

0 commit comments

Comments
 (0)