 
 class WhisperTinyModel(EagerModelBase):
     def __init__(self):
+        # self.max_cache_length = 1024
+        # self.batch_size = 1
         pass
 
     def get_eager_model(self) -> torch.nn.Module:
@@ -28,13 +30,27 @@ def get_eager_model(self) -> torch.nn.Module:
         return model
 
     def get_example_inputs(self):
+        # input_ids = torch.tensor([[0]], dtype=torch.long)
+        # encoder_hidden_states = torch.rand(1, 1500, 384)
+        # cache_position = torch.tensor([0], dtype=torch.long)
+        # atten_mask = torch.full((1, self.max_cache_length), torch.tensor(-255.0))
+        # atten_mask *= torch.arange(self.max_cache_length) > cache_position.reshape(
+        #     -1, 1
+        # )
+        # atten_mask = atten_mask[None, None, :, :].expand(self.batch_size, 1, -1, -1)
+        # return (input_ids, atten_mask, encoder_hidden_states, cache_position)
+
         processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", return_dict=False)
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
         input_features = inputs.input_features
+        # expected_shape = (1, processor.feature_extractor.feature_size, processor.feature_extractor.nb_max_frames)
+        # print("Expected shape: " + str(expected_shape))
+        print("Input features has shape: " + str(input_features.shape))
         # generated_ids = model.generate(inputs=input_features)
-        return (input_features[0],)  # (generated_ids,)
+        # return (torch.rand(expected_shape),)  # (input_features,)  # (generated_ids,)
+        return (input_features,)  # (generated_ids,)
 
         # feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
         # ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
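
The commented-out block in get_example_inputs builds an additive attention mask for a fixed-length KV cache: cache slots filled up to and including cache_position get a bias of 0.0, while not-yet-written slots get -255.0 so they are suppressed when the bias is added to attention scores. A standalone sketch of that construction, not part of the commit (the commit's commented code uses max_cache_length=1024 and batch_size=1; a small cache length is used here for readability):

    import torch

    max_cache_length = 8  # the commit's commented code uses 1024
    batch_size = 1
    cache_position = torch.tensor([0], dtype=torch.long)  # next cache slot to fill

    # Start with -255.0 everywhere, then zero out positions up to and
    # including cache_position; the remaining large negative bias masks
    # the cache slots that have not been written yet.
    atten_mask = torch.full((1, max_cache_length), -255.0)
    atten_mask *= torch.arange(max_cache_length) > cache_position.reshape(-1, 1)
    atten_mask = atten_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    print(atten_mask.shape)  # torch.Size([1, 1, 1, 8])
    print(atten_mask)        # 0.0 at position 0, -255.0 at positions 1..7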
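For context on the change at the return statement: processor(...).input_features keeps its batch dimension, (1, 80, 3000) for whisper-tiny, i.e. (batch, feature_size, nb_max_frames), and that batched tensor is what the commented-out generate call consumes. A minimal end-to-end sketch, not part of the commit, reusing the same calls the file already makes:

    from datasets import load_dataset
    from transformers import AutoProcessor, WhisperForConditionalGeneration

    processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny.en", return_dict=False
    )
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
    input_features = inputs.input_features
    print("Input features has shape: " + str(input_features.shape))  # (1, 80, 3000)

    # generate expects the batched tensor, not input_features[0]; this is
    # why the commit returns (input_features,) instead of (input_features[0],).
    generated_ids = model.generate(inputs=input_features)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))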