Skip to content

Commit f06e66a

Browse files
committed
input dimensions seem correct, but getting a value error
1 parent 9c119ea commit f06e66a

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

examples/models/whisper_tiny/model.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
class WhisperTinyModel(EagerModelBase):
1919
def __init__(self):
20+
#self.max_cache_length=1024
21+
#self.batch_size=1
2022
pass
2123

2224
def get_eager_model(self) -> torch.nn.Module:
@@ -28,13 +30,27 @@ def get_eager_model(self) -> torch.nn.Module:
2830
return model
2931

3032
def get_example_inputs(self):
33+
#input_ids = torch.tensor([[0]], dtype=torch.long)
34+
#encoder_hidden_states = torch.rand(1, 1500, 384)
35+
#cache_position = torch.tensor([0], dtype=torch.long)
36+
#atten_mask = torch.full((1, self.max_cache_length), torch.tensor(-255.0))
37+
#atten_mask *= torch.arange(self.max_cache_length) > cache_position.reshape(
38+
# -1, 1
39+
#)
40+
#atten_mask = atten_mask[None, None, :, :].expand(self.batch_size, 1, -1, -1)
41+
#return (input_ids, atten_mask, encoder_hidden_states, cache_position)
42+
3143
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
3244
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", return_dict=False)
3345
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
3446
inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
3547
input_features = inputs.input_features
48+
#expected_shape = (1, processor.feature_extractor.feature_size, processor.feature_extractor.nb_max_frames)
49+
#print("Expected shape: " + str(expected_shape))
50+
print("Input features has shape: " + str(input_features.shape))
3651
#generated_ids = model.generate(inputs=input_features)
37-
return (input_features[0],) #(generated_ids,)
52+
#return (torch.rand(expected_shape),) #(input_features,) #(generated_ids,)
53+
return (input_features,) #(generated_ids,)
3854

3955
#feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
4056
#ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

0 commit comments

Comments
 (0)