
Exporting google/gemma-3n-e4b-it language_model (decoder) into ONNX format #56

@arkadaz

Description


System Info

import os

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoProcessor, Gemma3nForConditionalGeneration


# Wrapper for the Vision Encoder
class VisionEncoderWrapper(nn.Module):
    def __init__(self, vision_encoder, embed_vision, config):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.embed_vision = embed_vision
        self.config = config

    def forward(self, inputs):
        image_features = self.vision_encoder.timm_model.forward_features(inputs)
        image_features_reshaped = image_features.reshape(
            image_features.shape[0],
            self.config.vision_config.hidden_size,
            self.config.vision_soft_tokens_per_image,
        ).permute(0, 2, 1)
        return self.embed_vision(inputs_embeds=image_features_reshaped)


# Wrapper for the Audio Encoder
class AudioEncoderWrapper(nn.Module):
    def __init__(self, audio_encoder, embed_audio, config):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.embed_audio = embed_audio
        self.config = config

    def forward(self, inputs_features, inputs_features_mask):
        audio_features, audio_mask = self.audio_encoder(
            inputs_features.to(torch.float32), ~inputs_features_mask
        )
        audio_embeds = self.embed_audio(inputs_embeds=audio_features)
        # The last vocabulary id (262400 - 1) is used as the audio padding token.
        padding_token_id = 262400 - 1
        audio_padding_toks = torch.tensor([[padding_token_id]], dtype=torch.long)
        audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
        audio_features = torch.where(
            audio_mask.unsqueeze(-1), audio_padding_embs, audio_embeds
        )

        audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
        extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
        extra_padding_features = audio_padding_embs.expand(
            audio_batch_size, extra_padding_tokens, audio_embed_dim
        )
        return torch.cat((audio_features, extra_padding_features), dim=1)


# Wrapper for the Token Embedder
class TokenEmbedderWrapper(nn.Module):
    def __init__(self, text_embedder, per_layer_inputs_model, config):
        super().__init__()
        self.text_embedder = text_embedder
        self.per_layer_inputs_model = per_layer_inputs_model
        self.config = config

    def forward(self, input_ids):
        # This layer converts token IDs to dense vectors
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0,
            input_ids < self.config.text_config.vocab_size_per_layer_input,
        )
        per_layer_inputs_tokens = torch.where(
            per_layer_inputs_mask,
            input_ids,
            torch.zeros_like(input_ids),
        )
        per_layer_inputs = self.per_layer_inputs_model(per_layer_inputs_tokens)
        return self.text_embedder(input_ids), per_layer_inputs

class LanguageWrapper(nn.Module):
    def __init__(self, language_model):
        super().__init__()
        self.language_model = language_model

    def forward(self, inputs_embeds, position_ids, per_layer_inputs, **kwargs):
        # Pass by keyword so the arguments cannot be misrouted positionally.
        return self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )

if __name__ == "__main__":

    model_id = "google/gemma-3n-e4b-it"
    output_dir = "onnx_exported_model_float32"
    os.makedirs(output_dir, exist_ok=True)

    print("πŸ”„ Loading processor and model...")
    # Using float32 for higher precision
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float32
    )
    processor = AutoProcessor.from_pretrained(model_id)
    config = AutoConfig.from_pretrained(model_id)

    # Move model to CPU for export
    model.to("cpu")

    # === Dummy input ===
    # Using the same multimodal inputs as before
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
                },
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "audio",
                    "audio": "https://raw.githubusercontent.com/google-gemini/gemma-cookbook/refs/heads/main/Demos/sample-data/shopping3.wav",
                },
            ],
        },
    ]

    # Process inputs to get all necessary tensors
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    # Ensure pixel_values are float32 to match the model
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)

    # === 1. Export Vision Encoder ===
    print("\nπŸ“¦ [1/4] Exporting Vision Encoder...")
    vision_wrapper = VisionEncoderWrapper(
        model.model.vision_tower, model.model.embed_vision, config
    ).eval()
    torch.onnx.export(
        model=vision_wrapper,
        args=(inputs["pixel_values"],),
        f=os.path.join(output_dir, "vision_encoder.onnx"),
        input_names=["pixel_values"],
        output_names=["image_features"],
        dynamic_axes={
            "pixel_values": {0: "batch_size"},
            "image_features": {0: "batch_size"},
        },
        opset_version=21,
        dynamo=True,
    )
    print("βœ… Vision encoder exported.")

    # === 2. Export Audio Encoder ===
    print("\nπŸ”Š [2/4] Exporting Audio Encoder...")
    audio_wrapper = AudioEncoderWrapper(
        model.model.audio_tower, model.model.embed_audio, config
    ).eval()
    torch.onnx.export(
        model=audio_wrapper,
        args=(
            inputs["input_features"].to(torch.float32),
            inputs["input_features_mask"],
        ),
        f=os.path.join(output_dir, "audio_encoder.onnx"),
        input_names=["input_features", "input_features_mask"],
        output_names=["audio_features"],
        dynamic_axes={
            "input_features": {0: "batch_size", 1: "sequence_length"},
            "input_features_mask": {0: "batch_size", 1: "sequence_length"},
            "audio_features": {0: "batch_size"},
        },
        opset_version=21,
        dynamo=True,
    )
    print(f"   - Audio features shape: {inputs['input_features'].shape}")
    print("βœ… Audio encoder exported.")

    # === 3. Export Token Embedder ===
    print("\nπŸ“ [3/4] Exporting Token Embedder...")
    token_embedder_wrapper = TokenEmbedderWrapper(
        model.model.language_model.embed_tokens,
        model.language_model.get_per_layer_inputs,
        config,
    ).eval()
    torch.onnx.export(
        model=token_embedder_wrapper,
        args=(inputs["input_ids"],),
        f=os.path.join(output_dir, "token_embedder.onnx"),
        input_names=["input_ids"],
        output_names=["text_embeds"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "text_embeds": {0: "batch_size", 1: "sequence_length"},
        },
        opset_version=21,
        dynamo=True,
    )

    #################################################################################
    model.language_model.save_pretrained('language-model')
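As a quick sanity check on the parts that do export (my own addition, not required for the export itself; it assumes onnxruntime is installed), the vision encoder can be loaded back and compared against the PyTorch wrapper:

    # Optional: verify the exported vision encoder against the PyTorch wrapper.
    import onnxruntime as ort

    ort_session = ort.InferenceSession(
        os.path.join(output_dir, "vision_encoder.onnx"),
        providers=["CPUExecutionProvider"],
    )
    onnx_out = ort_session.run(
        None, {"pixel_values": inputs["pixel_values"].numpy()}
    )[0]
    with torch.no_grad():
        torch_out = vision_wrapper(inputs["pixel_values"]).numpy()
    print("Vision encoder max abs diff:", np.abs(onnx_out - torch_out).max())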

I am working on exporting the "google/gemma-3n-e4b-it" model to the ONNX format and am encountering issues with the language model (decoder) component.

I have been following the approach outlined in a Colab Notebook on how to export Llava, which was referenced in a related GitHub issue (#38924):

Notebook: https://colab.research.google.com/drive/1mtcxKWHAR7D9LbjzRcTB_b5CURm2jw-z

GitHub Issue: huggingface/transformers#38924

Could you please suggest how to export the language model (decoder) to ONNX? I have found that an E2B variant is already available in ONNX format here: https://huggingface.co/onnx-community/gemma-3n-E2B-it-ONNX/tree/main, and I would like to know how to reproduce that export myself.
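For concreteness, this is roughly the export call I have in mind for the missing step 4, reusing the LanguageWrapper and token_embedder_wrapper from the script above. It is only a sketch: the input/output names and file name are my own guesses, and it ignores the attention mask and KV cache, which is the part I am unsure how to handle.

# === 4. Export Language Model (decoder) — illustrative sketch only ===
with torch.no_grad():
    text_embeds, per_layer_inputs = token_embedder_wrapper(inputs["input_ids"])
position_ids = torch.arange(
    inputs["input_ids"].shape[1], dtype=torch.long
).unsqueeze(0)

language_wrapper = LanguageWrapper(model.model.language_model).eval()
torch.onnx.export(
    model=language_wrapper,
    args=(text_embeds, position_ids, per_layer_inputs),
    f=os.path.join(output_dir, "decoder_model.onnx"),
    input_names=["inputs_embeds", "position_ids", "per_layer_inputs"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        "per_layer_inputs": {0: "batch_size", 1: "sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=21,
    dynamo=True,
)
# NOTE: no past_key_values / attention_mask inputs are wired up here,
# so this would not support incremental decoding as exported.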

Who can help?

@xenova @zucchini-nlp
A tutorial on how to export the decoder model would be much appreciated. 🙏
