Commit e116774

Export voxtral encoder

1 parent a624083

File tree: 1 file changed (+74, -0 lines)

examples/models/voxtral/export.py

Lines changed: 74 additions & 0 deletions
from transformers import VoxtralForConditionalGeneration, AutoProcessor
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "mistralai/Voxtral-Mini-3B-2507"

processor = AutoProcessor.from_pretrained(repo_id)
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)

conversation = [
    {
        "role": "user",
        "content": [
            {
                "type": "audio",
                "url": "https://huggingface.co/datasets/eustlb/audio-samples/resolve/main/dude_where_is_my_car.wav",
            },
            # {"type": "text", "text": "What can you tell me about this audio?"},
        ],
    }
]

inputs = processor.apply_chat_template(conversation)
inputs = inputs.to(device, dtype=torch.bfloat16)


class VoxtralEncoderForExecuTorch(nn.Module):
    def __init__(self, audio_encoder: nn.Module, mm_projector: nn.Module, intermediate_size: int):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.mm_projector = mm_projector
        self.intermediate_size = intermediate_size

    def forward(self, input_features: torch.FloatTensor):
        audio_outputs = self.audio_encoder(input_features)
        audio_hidden_states = audio_outputs.last_hidden_state
        audio_hidden_states = audio_hidden_states.reshape(-1, self.intermediate_size)
        audio_embeds = self.mm_projector(audio_hidden_states)

        # TODO: add the below two lines after confirming equality with eager.
        # audio_token_mask = input_ids == self.config.audio_token_id
        # inputs_embeds[audio_token_mask] = audio_embeds

        return audio_embeds


voxtral_encoder = VoxtralEncoderForExecuTorch(
    model.audio_tower,
    model.multi_modal_projector,
    model.config.audio_config.intermediate_size,
)

# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/modeling_voxtral.py#L342,
# should be equal to 3000.
expected_seq_length = model.audio_tower.config.max_source_positions * model.audio_tower.conv1.stride[0] * model.audio_tower.conv2.stride[0]

# Shape of input_features from sample Voxtral audio input from voxtral.md, but with batch size = 1
# (representing < 30 seconds of audio). See
# https://github.com/huggingface/transformers/blob/fbeaf96f9e2291c21277ac658a33ea8752728bf3/src/transformers/models/voxtral/processing_voxtral.py#L91
# for more info.
sample_encoder_inputs = (torch.rand(1, 128, expected_seq_length),)

dynamic_shapes = {
    "input_features": {0: torch.export.Dim.STATIC, 1: torch.export.Dim.STATIC, 2: torch.export.Dim.STATIC}
}

ep = torch.export.export(
    voxtral_encoder,
    args=sample_encoder_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=True,
)

# Sanity-check the exported program against the eager model on the same sample input.
eager_output = model.get_audio_embeds(sample_encoder_inputs[0])
ep_output = ep.module()(*sample_encoder_inputs)
torch.allclose(eager_output, ep_output)

# outputs = model.generate(**inputs, max_new_tokens=500)
# decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

# print("\nGenerated response:")
# print("=" * 80)
# print(decoded_outputs[0])
# print("=" * 80)
