Skip to content

Commit 8bef102

Browse files
authored
Update conversion script versions (#1204)
* Update conversion script versions * Use custom float16 converter script * Add MIT license header * Use relative import * Fix transformers to stable 4.48.x branch * Add strictness check when saving model * Create new output nodes after casts * Use onnxslim after fp16 conversion * Prevent in-place modification while iterating (infinite loops) * Finalize fp16 quantization script * Only warn if alignment heads can't be found
1 parent c2ab81a commit 8bef102

File tree

6 files changed

+1007
-14
lines changed

6 files changed

+1007
-14
lines changed

scripts/convert.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,9 @@ def main():
446446

447447
generation_config = GenerationConfig.from_pretrained(
448448
model_id, **from_pretrained_kwargs)
449-
generation_config.alignment_heads = get_alignment_heads(config)
449+
alignment_heads = get_alignment_heads(config)
450+
if alignment_heads:
451+
generation_config.alignment_heads = alignment_heads
450452
generation_config.save_pretrained(output_model_folder)
451453

452454

scripts/extra/whisper.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
from optimum.exporters.onnx.base import ConfigBehavior
44
from typing import Dict
5+
import logging
6+
7+
logger = logging.getLogger(__name__)
58

69
# List of [layer, head] pairs that select the cross-attention heads that are highly correlated to word-level timing.
710
# Source: https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a
@@ -57,12 +60,13 @@ def get_main_export_kwargs(config, task):
5760

5861
def get_alignment_heads(config):
5962
if getattr(config, '_name_or_path', None) is None:
60-
raise ValueError(
63+
logger.warning(
6164
"Unable to determine model type from config. Please specify `_name_or_path` in the config.")
65+
return None
6266

6367
for model_name, heads in ALIGNMENT_HEADS_MAPPING.items():
6468
if model_name in config._name_or_path:
6569
return heads
6670

67-
raise ValueError(
68-
f"Unknown model type: {config._name_or_path}. Please add one of the following model types to `_name_or_path` in the config file: {list(ALIGNMENT_HEADS_MAPPING.keys())}")
71+
logger.warning(f"Unknown model type: {config._name_or_path}. Please add one of the following model types to `_name_or_path` in the config file: {list(ALIGNMENT_HEADS_MAPPING.keys())}")
72+
return None

0 commit comments

Comments
 (0)