System Info
```python
import os

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoProcessor, Gemma3nForConditionalGeneration


# Wrapper for the Vision Encoder
class VisionEncoderWrapper(nn.Module):
    def __init__(self, vision_encoder, embed_vision, config):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.embed_vision = embed_vision
        self.config = config

    def forward(self, inputs):
        image_features = self.vision_encoder.timm_model.forward_features(inputs)
        image_features_reshaped = image_features.reshape(
            image_features.shape[0],
            self.config.vision_config.hidden_size,
            self.config.vision_soft_tokens_per_image,
        ).permute(0, 2, 1)
        return self.embed_vision(inputs_embeds=image_features_reshaped)


# Wrapper for the Audio Encoder
class AudioEncoderWrapper(nn.Module):
    def __init__(self, audio_encoder, embed_audio, config):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.embed_audio = embed_audio
        self.config = config

    def forward(self, input_features, input_features_mask):
        # Use the wrapped audio tower (not the global `model`) so the export is self-contained
        audio_features, audio_mask = self.audio_encoder(
            input_features.to(torch.float32), ~input_features_mask
        )
        audio_embeds = self.embed_audio(inputs_embeds=audio_features)
        # Hardcoded vocab_size - 1, used as the audio padding token
        padding_token_id = 262400 - 1
        audio_padding_toks = torch.tensor([[padding_token_id]], dtype=torch.long)
        audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
        audio_features = torch.where(
            audio_mask.unsqueeze(-1), audio_padding_embs, audio_embeds
        )
        audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
        extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
        extra_padding_features = audio_padding_embs.expand(
            audio_batch_size, extra_padding_tokens, audio_embed_dim
        )
        return torch.cat((audio_features, extra_padding_features), dim=1)


# Wrapper for the Token Embedder
class TokenEmbedderWrapper(nn.Module):
    def __init__(self, text_embedder, per_layer_inputs_model, config):
        super().__init__()
        self.text_embedder = text_embedder
        self.per_layer_inputs_model = per_layer_inputs_model
        self.config = config

    def forward(self, input_ids):
        # This layer converts token IDs to dense vectors and the per-layer inputs
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0,
            input_ids < self.config.text_config.vocab_size_per_layer_input,
        )
        per_layer_inputs_tokens = torch.where(
            per_layer_inputs_mask,
            input_ids,
            torch.zeros_like(input_ids),
        )
        per_layer_inputs = self.per_layer_inputs_model(per_layer_inputs_tokens)
        return self.text_embedder(input_ids), per_layer_inputs


# Wrapper for the Language Model (decoder)
class LanguageWrapper(nn.Module):
    def __init__(self, language_model):
        super().__init__()
        self.language_model = language_model

    def forward(self, inputs_embeds, position_ids, per_layer_inputs, **kwargs):
        # Pass keyword arguments so the tensors are not misread positionally
        return self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )


if __name__ == "__main__":
    model_id = "google/gemma-3n-e4b-it"
    output_dir = "onnx_exported_model_float32"
    os.makedirs(output_dir, exist_ok=True)

    print("Loading processor and model...")
    # Using float32 for higher precision
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float32
    )
    processor = AutoProcessor.from_pretrained(model_id)
    config = AutoConfig.from_pretrained(model_id)

    # Move model to CPU for export
    model.to("cpu")

    # === Dummy input ===
    # Using the same multimodal inputs as before
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
                },
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "audio",
                    "audio": "https://raw.githubusercontent.com/google-gemini/gemma-cookbook/refs/heads/main/Demos/sample-data/shopping3.wav",
                },
            ],
        },
    ]

    # Process inputs to get all necessary tensors
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    # Ensure pixel_values are float32 to match the model
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)

    # === 1. Export Vision Encoder ===
    print("\n[1/4] Exporting Vision Encoder...")
    vision_wrapper = VisionEncoderWrapper(
        model.model.vision_tower, model.model.embed_vision, config
    ).eval()
    torch.onnx.export(
        model=vision_wrapper,
        args=(inputs["pixel_values"],),
        f=os.path.join(output_dir, "vision_encoder.onnx"),
        input_names=["pixel_values"],
        output_names=["image_features"],
        dynamic_axes={
            "pixel_values": {0: "batch_size"},
            "image_features": {0: "batch_size"},
        },
        opset_version=21,
        dynamo=True,
    )
    print("Vision encoder exported.")

    # === 2. Export Audio Encoder ===
    print("\n[2/4] Exporting Audio Encoder...")
    audio_wrapper = AudioEncoderWrapper(
        model.model.audio_tower, model.model.embed_audio, config
    ).eval()
    torch.onnx.export(
        model=audio_wrapper,
        args=(
            inputs["input_features"].to(torch.float32),
            inputs["input_features_mask"],
        ),
        f=os.path.join(output_dir, "audio_encoder.onnx"),
        input_names=["input_features", "input_features_mask"],
        output_names=["audio_features"],
        dynamic_axes={
            "input_features": {0: "batch_size", 1: "sequence_length"},
            "input_features_mask": {0: "batch_size", 1: "sequence_length"},
            "audio_features": {0: "batch_size"},
        },
        opset_version=21,
        dynamo=True,
    )
    print(f" - Audio features shape: {inputs['input_features'].shape}")
    print("Audio encoder exported.")

    # === 3. Export Token Embedder ===
    print("\n[3/4] Exporting Token Embedder...")
    token_embedder_wrapper = TokenEmbedderWrapper(
        model.model.language_model.embed_tokens,
        model.language_model.get_per_layer_inputs,
        config,
    ).eval()
    torch.onnx.export(
        model=token_embedder_wrapper,
        args=(inputs["input_ids"],),
        f=os.path.join(output_dir, "token_embedder.onnx"),
        input_names=["input_ids"],
        output_names=["text_embeds"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "text_embeds": {0: "batch_size", 1: "sequence_length"},
        },
        opset_version=21,
        dynamo=True,
    )

    #################################################################################
    # === 4. Export Language Model (decoder): this is the step I am stuck on ===
    model.language_model.save_pretrained("language-model")
```
I am working on exporting the "google/gemma-3n-e4b-it" model to the ONNX format and am encountering issues with the language model (decoder) component.
I have been following the approach outlined in a Colab Notebook on how to export Llava, which was referenced in a related GitHub issue (#38924):
Notebook: https://colab.research.google.com/drive/1mtcxKWHAR7D9LbjzRcTB_b5CURm2jw-z
GitHub Issue: huggingface/transformers#38924
Could you please suggest how to export this model? I found that the E2B variant is already available in ONNX format here: https://huggingface.co/onnx-community/gemma-3n-E2B-it-ONNX/tree/main, and I am curious how to export it myself.
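For reference, the published decoder graph from that repo can be inspected to see what inputs and outputs a working export exposes. A rough sketch (the exact file name inside the repo is a guess on my part, and the decoder may be split differently):

```python
# Sketch: print the I/O signature of the onnx-community decoder to see what a
# working export looks like. The file path below is hypothetical; check the
# repo's onnx/ folder for the actual file names.
import onnx

decoder = onnx.load(
    "gemma-3n-E2B-it-ONNX/onnx/decoder_model_merged.onnx",  # hypothetical local path
    load_external_data=False,  # weights are not needed to read the graph signature
)
for inp in decoder.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print("input :", inp.name, dims)
for out in decoder.graph.output:
    dims = [d.dim_param or d.dim_value for d in out.type.tensor_type.shape.dim]
    print("output:", out.name, dims)
```

Matching those input and output names, in particular whether and how the KV cache is exposed as explicit past/present tensors, is what I would like the export script above to reproduce.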
Who can help?
@xenova @zucchini-nlp
I need a tutorial on how to export the decoder model.
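To make the question concrete, this is roughly what I would expect the missing "[4/4] Export Language Model" step to look like for a prefill-only pass (no KV cache). The keyword names (`inputs_embeds`, `position_ids`, `per_layer_inputs`), the use of `model.lm_head`, and `use_cache=False` are assumptions on my side rather than something I have verified, and this sketch does not cover cached decoding, which is probably the hard part:

```python
# Sketch of a possible "[4/4] Export Language Model" step (prefill only, no KV cache).
# Assumes the objects from the script above (model, inputs, token_embedder_wrapper,
# output_dir) are in scope; the decoder keyword names and model.lm_head are assumptions.
import os

import torch
import torch.nn as nn


class DecoderPrefillWrapper(nn.Module):
    def __init__(self, language_model, lm_head):
        super().__init__()
        self.language_model = language_model
        self.lm_head = lm_head

    def forward(self, inputs_embeds, position_ids, per_layer_inputs):
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            per_layer_inputs=per_layer_inputs,
            use_cache=False,  # keep cache objects out of the exported graph
        )
        return self.lm_head(outputs.last_hidden_state)


# Dummy inputs with the right shapes, reusing the token embedder from above.
# In the real pipeline the image/audio features would be merged into inputs_embeds first.
text_embeds, per_layer_inputs = token_embedder_wrapper(inputs["input_ids"])
position_ids = torch.arange(inputs["input_ids"].shape[1]).unsqueeze(0)

decoder_wrapper = DecoderPrefillWrapper(model.model.language_model, model.lm_head).eval()
torch.onnx.export(
    model=decoder_wrapper,
    args=(text_embeds, position_ids, per_layer_inputs),
    f=os.path.join(output_dir, "decoder_prefill.onnx"),
    input_names=["inputs_embeds", "position_ids", "per_layer_inputs"],
    output_names=["logits"],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        "per_layer_inputs": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=21,
    dynamo=True,
)
```

What I cannot figure out is how to extend this to a cached decode step that exposes past/present key values as ONNX inputs and outputs, which I assume is what the onnx-community export does.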