Skip to content

Commit 3c23ade

Browse files
committed
Adding in whisper tiny export script in examples
1 parent 67154d0 commit 3c23ade

File tree

4 files changed

+53
-0
lines changed

4 files changed

+53
-0
lines changed

examples/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Model(str, Enum):
3737
EfficientSam = "efficient_sam"
3838
Qwen25 = "qwen2_5"
3939
Phi4Mini = "phi_4_mini"
40+
WhisperTiny = "whisper_tiny"
4041

4142
def __str__(self) -> str:
4243
return self.value
@@ -82,6 +83,7 @@ def __str__(self) -> str:
8283
str(Model.EfficientSam): ("efficient_sam", "EfficientSAM"),
8384
str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"),
8485
str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"),
86+
str(Model.WhisperTiny): ("whisper_tiny", "WhisperTinyModel"),
8587
}
8688

8789
__all__ = [
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import WhisperTinyModel
8+
9+
__all__ = [
10+
"WhisperTinyModel",
11+
]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
11+
from transformers import AutoFeatureExtractor, WhisperModel # @manual
12+
from datasets import load_dataset
13+
14+
from ..model_base import EagerModelBase
15+
16+
17+
class WhisperTinyModel(EagerModelBase):
18+
def __init__(self):
19+
pass
20+
21+
def get_eager_model(self) -> torch.nn.Module:
22+
logging.info("Loading whipser-tiny model")
23+
# pyre-ignore
24+
model = WhisperModel.from_pretrained("openai/whisper-tiny", return_dict=False)
25+
model.eval()
26+
logging.info("Loaded whisper-tiny model")
27+
return model
28+
29+
def get_example_inputs(self):
30+
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
31+
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
32+
inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
33+
print(inputs)
34+
print(inputs.input_features)
35+
return (inputs.input_features,)
36+
# Raw audio input: 1 second of 16kHz audio
37+
#input_values = torch.randn(1, 16000)
38+
#print(input_values)
39+
#return (input_values,)

requirements-examples.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ timm == 1.0.7
55
torchsr == 1.0.4
66
torchtune >= 0.6.1
77
transformers >= 4.53.1
8+
librosa >= 0.11.0

0 commit comments

Comments
 (0)