from transformers import VoxtralForConditionalGeneration, AutoProcessor
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "mistralai/Voxtral-Mini-3B-2507"

processor = AutoProcessor.from_pretrained(repo_id)
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)

conversation = [
    {
        "role": "user",
        "content": [
            {
                "type": "audio",
                "url": "https://huggingface.co/datasets/eustlb/audio-samples/resolve/main/dude_where_is_my_car.wav",
            },
            # {"type": "text", "text": "What can you tell me about this audio?"},
        ],
    }
]

inputs = processor.apply_chat_template(conversation)
inputs = inputs.to(device, dtype=torch.bfloat16)

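# Optional sanity check (a sketch added here, not part of the original script): the
# processor should return log-mel `input_features` of shape (num_30s_chunks, 128, 3000)
# alongside the token ids, which is the shape the encoder export below assumes.
# print({k: tuple(v.shape) for k, v in inputs.items() if hasattr(v, "shape")})
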
class VoxtralEncoderForExecuTorch(nn.Module):
    """Wraps Voxtral's audio tower and multi-modal projector so that the
    audio-embedding path can be exported as a single graph."""

    def __init__(self, audio_encoder: nn.Module, mm_projector: nn.Module, intermediate_size: int):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.mm_projector = mm_projector
        self.intermediate_size = intermediate_size

    def forward(self, input_features: torch.FloatTensor):
        audio_outputs = self.audio_encoder(input_features)
        audio_hidden_states = audio_outputs.last_hidden_state
        # Reshape so each row has `intermediate_size` features before projecting into the text embedding space.
        audio_hidden_states = audio_hidden_states.reshape(-1, self.intermediate_size)
        audio_embeds = self.mm_projector(audio_hidden_states)

        # TODO: add the below two lines after confirming equality with eager.
        # audio_token_mask = input_ids == self.config.audio_token_id
        # inputs_embeds[audio_token_mask] = audio_embeds

        return audio_embeds

voxtral_encoder = VoxtralEncoderForExecuTorch(
    model.audio_tower,
    model.multi_modal_projector,
    model.config.audio_config.intermediate_size,
)

# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/modeling_voxtral.py#L342,
# should be equal to 3000.
expected_seq_length = model.audio_tower.config.max_source_positions * model.audio_tower.conv1.stride[0] * model.audio_tower.conv2.stride[0]

# Shape of input_features from the sample Voxtral audio input in voxtral.md, but with batch size = 1
# (representing < 30 seconds of audio). See
# https://github.com/huggingface/transformers/blob/fbeaf96f9e2291c21277ac658a33ea8752728bf3/src/transformers/models/voxtral/processing_voxtral.py#L91 for more info.
# The dtype/device must match the model's weights for both the eager reference and the export to run.
sample_encoder_inputs = (torch.rand(1, 128, expected_seq_length, dtype=torch.bfloat16, device=device),)

# All dimensions are kept static for this export, fixed to the sample input's shape.
dynamic_shapes = {
    "input_features": {0: torch.export.Dim.STATIC, 1: torch.export.Dim.STATIC, 2: torch.export.Dim.STATIC}
}
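# A minimal sketch (an assumption, not in the original script) of how dim 0 could instead be
# exported as a dynamic batch dimension, e.g. for audio longer than 30 seconds that the
# processor splits into several chunks; the max of 8 chunks is an arbitrary choice.
# batch_dim = torch.export.Dim("batch", min=1, max=8)
# dynamic_shapes = {"input_features": {0: batch_dim, 1: torch.export.Dim.STATIC, 2: torch.export.Dim.STATIC}}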

ep = torch.export.export(
    voxtral_encoder,
    args=sample_encoder_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=True,
)

# Numerical check: the exported graph should reproduce the eager audio-embedding path.
eager_output = model.get_audio_embeds(sample_encoder_inputs[0])
ep_output = ep.module()(*sample_encoder_inputs)
print("Exported output matches eager:", torch.allclose(eager_output, ep_output))

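# Next step once the check passes: lower the exported program for on-device execution.
# A minimal sketch assuming the `executorch` package and its exir to_edge / to_executorch
# APIs are available; backend partitioners and quantization are omitted.
# from executorch.exir import to_edge
# et_program = to_edge(ep).to_executorch()
# with open("voxtral_encoder.pte", "wb") as f:
#     f.write(et_program.buffer)
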
# outputs = model.generate(**inputs, max_new_tokens=500)
# decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

# print("\nGenerated response:")
# print("=" * 80)
# print(decoded_outputs[0])
# print("=" * 80)