99import torch
1010
1111from transformers import AutoFeatureExtractor , WhisperModel # @manual
12+ from transformers import AutoProcessor , WhisperForConditionalGeneration # @manual
1213from datasets import load_dataset
1314
1415from ..model_base import EagerModelBase
@@ -21,18 +22,29 @@ def __init__(self):
def get_eager_model(self) -> torch.nn.Module:
    """Load the pretrained Whisper model and put it in eval mode.

    Returns:
        The ``WhisperForConditionalGeneration`` instance loaded from the
        ``openai/whisper-tiny.en`` checkpoint with ``return_dict=False``,
        after ``eval()`` has been called on it.
    """
    logging.info("Loading whisper-tiny model")
    # The checkpoint id must not carry surrounding whitespace; it has to
    # match the id used for the processor in get_example_inputs.
    # pyre-ignore
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny.en", return_dict=False
    )
    model.eval()
    logging.info("Loaded whisper-tiny model")
    return model
2829
def get_example_inputs(self):
    """Build example inputs from one sample of a dummy LibriSpeech split.

    Loads the ``openai/whisper-tiny.en`` processor, runs it over the audio
    array of the first sample in the ``validation`` split of
    ``hf-internal-testing/librispeech_asr_dummy`` (config ``clean``), and
    returns the resulting features.

    Returns:
        A one-element tuple holding ``input_features[0]``, the first entry
        of the processor's ``input_features`` output.
    """
    processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    ds = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )
    inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
    # NOTE(review): indexing with [0] presumably drops the leading (batch)
    # dimension of input_features -- confirm the consumer expects that.
    # NOTE(review): an earlier draft also returned decoder_input_ids built
    # from model.config.decoder_start_token_id, and another returned the ids
    # produced by model.generate(); confirm whether either is required.
    return (inputs.input_features[0],)
3648 # Raw audio input: 1 second of 16kHz audio
3749 #input_values = torch.randn(1, 16000)
3850 #print(input_values)
0 commit comments