13 changes: 12 additions & 1 deletion ai_edge_torch/_convert/conversion.py
@@ -32,6 +32,7 @@

def _run_convert_passes(
exported_program: torch.export.ExportedProgram,
cast_i64_inputs_to_i32: bool,
) -> torch.export.ExportedProgram:
exported_program = generative_fx_passes.run_generative_passes(
exported_program
@@ -46,6 +47,10 @@ def _run_convert_passes(
fx_passes.CastInputsBf16ToF32Pass(),
]

  if cast_i64_inputs_to_i32:
    passes += [fx_passes.CastInputsI64ToI32Pass()]

# Debuginfo is not injected automatically by odml_torch. Only inject
# debuginfo via fx pass when using torch_xla.
if ai_edge_torch.config.use_torch_xla:
@@ -82,6 +87,7 @@ def convert_signatures(
signatures: list[signature.Signature],
*,
strict_export: Union[Literal["auto"], bool] = True,
cast_i64_inputs_to_i32: bool = False,
quant_config: Optional[qcfg.QuantConfig] = None,
_tfl_converter_flags: Optional[dict[str, Any]] = None,
_saved_model_dir: Optional[str] = None,
@@ -96,6 +102,8 @@
and ensure the soundness of the exported graph. When
strict_export="auto", the function will try to export module in both
modes and use the first one succeeds for downstream conversion.
cast_i64_inputs_to_i32: If true, casts all inputs with torch.int64 type to
torch.int32.
quant_config: User-defined quantization method and scheme of the model.
_tfl_converter_flags: A nested dictionary allowing setting flags for the
underlying tflite converter.
@@ -147,7 +155,10 @@ def export(**kwargs):
]

# Apply default fx passes
exported_programs = list(map(_run_convert_passes, exported_programs))
exported_programs = [
_run_convert_passes(ep, cast_i64_inputs_to_i32)
for ep in exported_programs
]
tflite_model = lowertools.exported_programs_to_tflite(
exported_programs,
signatures,
8 changes: 8 additions & 0 deletions ai_edge_torch/_convert/converter.py
@@ -132,6 +132,7 @@ def convert(
sample_kwargs=None,
*,
strict_export: Union[Literal["auto"], bool] = True,
cast_i64_inputs_to_i32: bool = False,
quant_config: Optional[qcfg.QuantConfig] = None,
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -159,6 +160,8 @@
and ensure the soundness of the exported graph. When
strict_export="auto", the function will try to export module in both
modes and use the first one succeeds for downstream conversion.
cast_i64_inputs_to_i32: If true, casts all inputs with torch.int64 type to
torch.int32.
quant_config: User-defined quantization method and scheme of the model.
    dynamic_shapes: Optional dict or tuple that specifies dynamic shape
      specifications for each input in the original order. See
@@ -203,6 +206,7 @@ def convert(
converted_model = conversion.convert_signatures(
self._signatures,
strict_export=strict_export,
cast_i64_inputs_to_i32=cast_i64_inputs_to_i32,
quant_config=quant_config,
_tfl_converter_flags=_ai_edge_converter_flags,
_saved_model_dir=_saved_model_dir,
@@ -271,6 +275,7 @@ def convert(
sample_kwargs=None,
*,
strict_export: Union[Literal["auto"], bool] = True,
cast_i64_inputs_to_i32: bool = False,
quant_config: Optional[qcfg.QuantConfig] = None,
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
@@ -289,6 +294,8 @@
and ensure the soundness of the exported graph. When strict_export="auto",
      the function will try to export the module in both modes and use the
      first one that succeeds for downstream conversion.
cast_i64_inputs_to_i32: If true, casts all inputs with torch.int64 type to
torch.int32.
quant_config: User-defined quantization method and scheme of the model.
    dynamic_shapes: Optional dict or tuple that specifies dynamic shape
      specifications for each input in the original order. See
@@ -317,6 +324,7 @@
sample_args,
sample_kwargs,
strict_export=strict_export,
cast_i64_inputs_to_i32=cast_i64_inputs_to_i32,
quant_config=quant_config,
dynamic_shapes=dynamic_shapes,
_ai_edge_converter_flags=_ai_edge_converter_flags,
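For reference, a minimal sketch of driving the new flag through the public API, mirroring the new test added in test_convert.py below (the toy module is illustrative):

```python
import ai_edge_torch
import torch
from torch import nn


class SampleModel(nn.Module):
  """Toy module whose input is an int64 index tensor."""

  def forward(self, x: torch.Tensor):
    return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)


model = SampleModel().eval()
args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)

# The i64 input stays in the converted signature but is cast to i32 right
# after the model inputs.
edge_model = ai_edge_torch.convert(model, args, cast_i64_inputs_to_i32=True)
```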
1 change: 1 addition & 0 deletions ai_edge_torch/_convert/fx_passes/__init__.py
@@ -17,6 +17,7 @@

from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
from ai_edge_torch._convert.fx_passes.cast_inputs_i64_to_i32_pass import CastInputsI64ToI32Pass
from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
2 changes: 2 additions & 0 deletions ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
@@ -16,6 +16,7 @@
from typing import Any, Callable
from ai_edge_torch import fx_infra
from ai_edge_torch import lowertools
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
import torch
import torch.utils._pytree as pytree

@@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
# Explicitly reshape back to the original shape. This places the ReshapeOp
# outside of the HLFB.
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
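    # The barrier below should keep the compiler from folding the reshape
    # back into the HLFB (it ties `output` to `idx`).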
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
return output

node.target = embedding
52 changes: 52 additions & 0 deletions ai_edge_torch/_convert/fx_passes/cast_inputs_i64_to_i32_pass.py
@@ -0,0 +1,52 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pass to cast all inputs with torch.int64 type to torch.int32."""


from ai_edge_torch import fx_infra
import torch


def cast_i32(x):
  return x.to(torch.int32)


class CastInputsI64ToI32Pass(fx_infra.ExportedProgramPassBase):
"""This pass casts all inputs with torch.int64 type to torch.int32."""

def call(self, exported_program: torch.export.ExportedProgram):
modified = False
for node in exported_program.graph.nodes:
if (
node.op in ("placeholder", "call_function")
and node.meta.get("val") is not None
and node.meta.get("val").dtype == torch.int64
):
if not node.users:
continue

modified = True
user = next(iter(node.users))
with exported_program.graph.inserting_before(user):
cast_node = exported_program.graph.call_function(
cast_i32,
(node,),
)
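          # Reroute every consumer of the i64 node through the cast, then
          # restore the cast's own input, which replace_all_uses_with also
          # rewired to the cast itself.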
node.replace_all_uses_with(cast_node)
cast_node.replace_input_with(cast_node, node)

exported_program.graph_module.recompile()
return fx_infra.ExportedProgramPassResult(exported_program, modified)
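A quick way to exercise the pass in isolation (a sketch; it assumes the ExportedProgramPassResult returned above exposes an exported_program field):

```python
import torch

from ai_edge_torch._convert import fx_passes


class AddOne(torch.nn.Module):
  """Toy module with an int64 input."""

  def forward(self, x):
    return x + 1


ep = torch.export.export(AddOne().eval(), (torch.ones(4, dtype=torch.int64),))
result = fx_passes.CastInputsI64ToI32Pass().call(ep)
# The recompiled graph should now route the i64 placeholder through cast_i32.
print(result.exported_program.graph_module.code)
```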
35 changes: 35 additions & 0 deletions ai_edge_torch/_convert/test/test_convert.py
@@ -19,7 +19,9 @@
from typing import Tuple

import ai_edge_torch
from ai_edge_torch import fx_infra
from ai_edge_torch._convert import conversion_utils
from ai_edge_torch.odml_torch.experimental import torch_tfl
from ai_edge_torch.quantize import pt2e_quantizer
from ai_edge_torch.testing import model_coverage
import numpy as np
@@ -576,6 +578,39 @@ def forward(self, x: torch.Tensor):
self.fail(f"Conversion failed with bloat16 inputs: {err}")
# pylint: enable=broad-except

def test_convert_model_with_i64_inputs_legalization_error(self):
"""Test converting a simple model with torch.int64 input.

    i64 inputs remain in the converted model signature but are cast to i32
    right after the model inputs.
"""

class SampleModel(nn.Module):

def forward(self, x: torch.Tensor):
return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)

model = SampleModel().eval()
args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)

# pylint: disable=broad-except
try:
      # Conversion without the flag is expected to fail with a legalization
      # error.
ai_edge_torch.convert(model, args, cast_i64_inputs_to_i32=False)
self.fail("Conversion succeeded unexpectedly")
except Exception as err:
print(f"Conversion failed as expected: {err}")
expected_error_message = "failed to legalize operation 'tfl.less'"
if expected_error_message not in str(err):
self.fail(f"Unexpected error message: {err}")

try:
# Expect this to fix the error during conversion
ai_edge_torch.convert(model, args, cast_i64_inputs_to_i32=True)
except Exception as err:
self.fail(f"Conversion failed with int64 inputs: {err}")
# pylint: enable=broad-except

def test_compile_model(self):
"""Tests AOT compilation of a simple Add module."""

51 changes: 43 additions & 8 deletions ai_edge_torch/generative/examples/gemma/gemma1.py
@@ -23,7 +23,7 @@
import torch
from torch import nn

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
ff_up_proj="model.layers.{}.mlp.up_proj",
ff_down_proj="model.layers.{}.mlp.down_proj",
ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -36,6 +36,24 @@
lm_head=None,
)

TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
ff_up_proj="model.layers.{}.mlp.up_proj",
ff_down_proj="model.layers.{}.mlp.down_proj",
ff_gate_proj="model.layers.{}.mlp.gate_proj",
attn_query_proj="model.layers.{}.self_attn.q_proj",
attn_key_proj="model.layers.{}.self_attn.k_proj",
attn_value_proj="model.layers.{}.self_attn.v_proj",
attn_output_proj="model.layers.{}.self_attn.o_proj",
pre_attn_norm="model.layers.{}.input_layernorm",
post_attn_norm="model.layers.{}.post_attention_layernorm",
embedding="model.embed_tokens",
final_norm="model.norm",
)

TENSOR_NAMES_DICT = {
"safetensors": TENSOR_NAMES_SEP_QKV,
"kaggle": TENSOR_NAMES_FUSED_QKV,
}

class Gemma1(model_builder.DecoderOnlyModel):
"""A Gemma1 model built from the Edge Generative API layers."""
@@ -94,11 +112,28 @@ def build_2b_model(
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
mask_cache_size: int = 0,
) -> nn.Module:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config_2b(),
tensor_names=TENSOR_NAMES,
model_class=Gemma1,
custom_loader=custom_loader,
mask_cache_size=mask_cache_size,

# A list to store the reasons for each failure
key_errors = []

for tensor_names in TENSOR_NAMES_DICT.values():
try:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config_2b(),
tensor_names=tensor_names,
model_class=Gemma1,
custom_loader=custom_loader,
mask_cache_size=mask_cache_size,
)
except KeyError as ke:
# Store the specific key that was missing for later
key_errors.append(f"Missing key: {ke}")
continue

# If the loop finishes, raise an error with all the collected details
error_details = "\n".join(key_errors)
raise RuntimeError(
"Failed to build model after trying all configurations. "
f"Encountered the following errors:\n{error_details}"
)
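For context, a hedged usage sketch of the new fallback (checkpoint path hypothetical):

```python
from ai_edge_torch.generative.examples.gemma import gemma1

# Tries the safetensors (separate q/k/v) tensor names first, then the kaggle
# (fused qkv) names; raises RuntimeError if neither layout matches.
model = gemma1.build_2b_model("/tmp/gemma-1.1-2b-it")
```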
95 changes: 95 additions & 0 deletions ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py
@@ -0,0 +1,95 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Verifies the reauthored SmolVLM2 Image Encoder model."""

import logging

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
from PIL import Image
import requests
import torch
import transformers

_IMAGE_URL = flags.DEFINE_string(
"image_url",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
"The image URI to encode.",
)

_CHECKPOINT = flags.DEFINE_string(
"checkpoint",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"The checkpoint to verify.",
)

_REAUTHORED_CHECKPOINT = flags.DEFINE_string(
"pretrained_weights_path",
None,
"The path to the model's pretrained weights.",
)


def main(_):
checkpoint = _CHECKPOINT.value
logging.info("Loading the original model from: %s", checkpoint)
original_model = transformers.AutoModelForImageTextToText.from_pretrained(
checkpoint
)
original_model = original_model.eval().model

logging.info("Building the reauthored checkpoint from: %s", checkpoint)
reauthored_checkpoint = _REAUTHORTHED_CHECKPOINT.value
if reauthored_checkpoint is None:
raise ValueError("reauthored_checkpoint is required.")

logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint)

logging.info("Loading the tokenizer from: %s", checkpoint)
processor = transformers.AutoProcessor.from_pretrained(checkpoint)

logging.info("Loading the image from: %s", _IMAGE_URL.value)
image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]

logging.info("Forwarding the original model...")
outputs_original = original_model.get_image_features(pixel_values)
logging.info("outputs_original's shape: %s", outputs_original.shape)

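  # Merge the leading (batch, images-per-sample) dims into a single batch
  # dim, which the reauthored encoder expects.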
pixel_values = pixel_values.reshape(
pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
)
logging.info("Forwarding the reauthored model...")
  outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape)

try:
assert torch.allclose(
outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
)
except AssertionError as e:
logging.error("*** FAILED *** verify with an image")
raise e
else:
logging.info("*** PASSED *** verify with an image")


if __name__ == "__main__":
app.run(main)