diff --git a/examples/models/__init__.py b/examples/models/__init__.py
index 76469846608..329580594aa 100644
--- a/examples/models/__init__.py
+++ b/examples/models/__init__.py
@@ -37,6 +37,7 @@ class Model(str, Enum):
     EfficientSam = "efficient_sam"
     Qwen25 = "qwen2_5"
     Phi4Mini = "phi_4_mini"
+    WhisperTiny = "whisper_tiny"
 
     def __str__(self) -> str:
         return self.value
@@ -82,6 +83,7 @@ def __str__(self) -> str:
     str(Model.EfficientSam): ("efficient_sam", "EfficientSAM"),
     str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"),
     str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"),
+    str(Model.WhisperTiny): ("whisper_tiny", "WhisperTinyModel"),
 }
 
 __all__ = [
diff --git a/examples/models/whisper_tiny/__init__.py b/examples/models/whisper_tiny/__init__.py
new file mode 100644
index 00000000000..ca800c7cad4
--- /dev/null
+++ b/examples/models/whisper_tiny/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import WhisperTinyModel
+
+__all__ = [
+    "WhisperTinyModel",
+]
diff --git a/examples/models/whisper_tiny/model.py b/examples/models/whisper_tiny/model.py
new file mode 100644
index 00000000000..e7efc27fa2f
--- /dev/null
+++ b/examples/models/whisper_tiny/model.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+
+from datasets import load_dataset
+from transformers import AutoProcessor, WhisperForConditionalGeneration  # @manual
+
+from ..model_base import EagerModelBase
+
+
+class WhisperTinyModel(EagerModelBase):
+    def __init__(self):
+        pass
+
+    def get_eager_model(self) -> torch.nn.Module:
+        logging.info("Loading whisper-tiny model")
+        # pyre-ignore
+        model = WhisperForConditionalGeneration.from_pretrained(
+            "openai/whisper-tiny.en", return_dict=False
+        )
+        model.eval()
+        logging.info("Loaded whisper-tiny model")
+        return model
+
+    def get_example_inputs(self):
+        # Build a representative log-mel spectrogram input from the
+        # LibriSpeech dummy dataset; shape is (1, 80, 3000).
+        processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
+        ds = load_dataset(
+            "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+        )
+        inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+        input_features = inputs.input_features
+        logging.info("Input features shape: %s", input_features.shape)
+        return (input_features,)
diff --git a/export.log b/export.log
new file mode 100644
index 00000000000..f3fa834ece2
--- /dev/null
+++ b/export.log
@@ -0,0 +1 @@
+Input features has shape: torch.Size([1, 80, 3000])
diff --git a/requirements-examples.txt b/requirements-examples.txt
index 7426df861a2..3cab53469c3 100644
--- a/requirements-examples.txt
+++ b/requirements-examples.txt
@@ -5,3 +5,4 @@ timm == 1.0.7
 torchsr == 1.0.4
 torchtune >= 0.6.1
 transformers >= 4.53.1
+librosa >= 0.11.0
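
For reference, a minimal export sketch exercising the new registry entry. This is illustrative only: it assumes the EagerModelBase flow shown above, that the import path matches running from the repo root, that whisper-tiny.en traces with only the log-mel input_features (the conditional-generation model may also require decoder inputs, which this change does not wire up), and an arbitrary output file name.

    import torch
    from executorch.exir import to_edge

    from examples.models.whisper_tiny import WhisperTinyModel

    wrapper = WhisperTinyModel()
    eager_model = wrapper.get_eager_model()        # WhisperForConditionalGeneration, eval mode
    example_inputs = wrapper.get_example_inputs()  # (input_features,) with shape [1, 80, 3000]

    # Capture the graph and lower it to an ExecuTorch program.
    exported = torch.export.export(eager_model, example_inputs)
    et_program = to_edge(exported).to_executorch()

    with open("whisper_tiny.pte", "wb") as f:
        f.write(et_program.buffer)

Once registered, the model name should also work with the generic examples entry point, e.g. python -m examples.portable.scripts.export --model_name="whisper_tiny" (path assumed from the existing examples layout).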