Skip to content

Commit 1a0c464

Browse files
committed
Using WhisperForConditionalGeneration instead of WhisperModel; seems to be the more correct thing
1 parent 3e5aa6b commit 1a0c464

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

examples/models/whisper_tiny/model.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010

1111
from transformers import AutoFeatureExtractor, WhisperModel # @manual
12+
from transformers import AutoProcessor, WhisperForConditionalGeneration # @manual
1213
from datasets import load_dataset
1314

1415
from ..model_base import EagerModelBase
@@ -21,18 +22,29 @@ def __init__(self):
2122
def get_eager_model(self) -> torch.nn.Module:
    """Load and return the pretrained whisper-tiny model in eval mode.

    Returns:
        A ``WhisperForConditionalGeneration`` loaded from
        ``openai/whisper-tiny.en`` with ``return_dict=False`` so that
        ``forward()`` returns plain tuples (friendlier for tracing/export),
        switched to eval mode.
    """
    # Fixed typo in the log message: "whipser" -> "whisper".
    logging.info("Loading whisper-tiny model")
    # pyre-ignore
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny.en", return_dict=False
    )
    model.eval()
    logging.info("Loaded whisper-tiny model")
    return model
2829

2930
def get_example_inputs(self):
    """Build an example input tuple for tracing/exporting the model.

    Extracts log-mel input features from a single sample of the dummy
    LibriSpeech validation split via the whisper-tiny.en processor.

    Returns:
        A one-element tuple containing ``input_features[0]``.
        NOTE(review): this drops the batch dimension — confirm downstream
        export expects an unbatched features tensor.
    """
    processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    # Removed a dead WhisperForConditionalGeneration.from_pretrained() call:
    # the loaded model was never used (only referenced from commented-out
    # code) and triggered an expensive download on every invocation.
    ds = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )
    inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
    return (inputs.input_features[0],)

0 commit comments

Comments
 (0)