Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable
from ai_edge_torch import fx_infra
from ai_edge_torch import lowertools
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
import torch
import torch.utils._pytree as pytree

Expand Down Expand Up @@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
# Explicitly reshape back to the original shape. This places the ReshapeOp
# outside of the HLFB.
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
return output

node.target = embedding
Expand Down
95 changes: 95 additions & 0 deletions ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Verifies the reauthored SmolVLM2 Image Encoder model."""

import logging

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
from PIL import Image
import requests
import torch
import transformers

_IMAGE_URL = flags.DEFINE_string(
"image_url",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
"The image URI to encode.",
)

_CHECKPOINT = flags.DEFINE_string(
"checkpoint",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"The checkpoint to verify.",
)

_REAUTHORTHED_CHECKPOINT = flags.DEFINE_string(
"pretrained_weights_path",
None,
"The path to the model's pretrained weights.",
)


def main(_):
checkpoint = _CHECKPOINT.value
logging.info("Loading the original model from: %s", checkpoint)
original_model = transformers.AutoModelForImageTextToText.from_pretrained(
checkpoint
)
original_model = original_model.eval().model

logging.info("Building the reauthored checkpoint from: %s", checkpoint)
reauthored_checkpoint = _REAUTHORTHED_CHECKPOINT.value
if reauthored_checkpoint is None:
raise ValueError("reauthored_checkpoint is required.")

logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint)

logging.info("Loading the tokenizer from: %s", checkpoint)
processor = transformers.AutoProcessor.from_pretrained(checkpoint)

logging.info("Loading the image from: %s", _IMAGE_URL.value)
image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]

logging.info("Forwarding the original model...")
outputs_original = original_model.get_image_features(pixel_values)
logging.info("outputs_original's shape: %s", outputs_original.shape)

pixel_values = pixel_values.reshape(
pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
)
logging.info("Forwarding the reauthored model...")
outputs_reauthored = reauthored_model.forward(
pixel_values=pixel_values
)
logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape)

try:
assert torch.allclose(
outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
)
except AssertionError as e:
logging.error("*** FAILED *** verify with an image")
raise e
else:
logging.info("*** PASSED *** verify with an image")


if __name__ == "__main__":
app.run(main)
13 changes: 12 additions & 1 deletion ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,18 @@ def forward(
pixel_values: torch.Tensor,
export_config: export_cfg.ExportConfig = None,
) -> torch.Tensor:
x = self.siglip_encoder(pixel_values)
# Embed the image according to SiplipVisionEmbeddings.
x = self.siglip_encoder.tok_embedding(pixel_values)
x = x.flatten(2).transpose(1, 2)
x = x + self.siglip_encoder.tok_embedding_position

# Pass a dummy mask because SDPA attention impl expects non-None mask.
mask = torch.zeros(x.shape[0], 1, x.shape[1], x.shape[1])
for _, block in enumerate(self.siglip_encoder.transformer_blocks):
x = block(x, mask=mask)
x = self.siglip_encoder.final_norm(x)

# Project the image embeddings to text hidden size.
x = self.connector(x)
return x

Expand Down
37 changes: 37 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def _aten_rsqrt_decomp(x):
return torch.ops.tfl.rsqrt(x)


@register_decomp(torch.ops.aten.neg.default)
def _aten_neg_decomp(x):
return torch.ops.tfl.neg(x)


@register_decomp(torch.ops.aten.gelu.default)
def _aten_gelu_decomp(x, approximate="none"):
return torch.ops.tfl.gelu(x, approximate != "none")
Expand Down Expand Up @@ -317,6 +322,38 @@ def _aten_select_int_decomp(x, dim, index):
return torch.ops.tfl.squeeze(sliced, [dim])


@register_decomp(torch.ops.aten.slice.Tensor)
def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1):
rank = x.dim()
dim_size = x.shape[dim]

# Initialize begin, end, strides for tfl.strided_slice
begin = [0] * rank
end_vec = list(x.shape)
strides = [1] * rank

# The logic below is to match PyTorch's `slice` behavior.
# `start` and `end` can be negative, which means they count from the end.
# `start=None` defaults to 0.
# `end=None` or a large number defaults to `dim_size` after clamping.

start_val = 0 if start is None else start
if start_val < 0:
start_val += dim_size

end_val = dim_size if end is None else end
if end_val < 0:
end_val += dim_size

# Clamp start and end to be within the dimension size, following PyTorch's
# logic.
start_val = max(0, min(start_val, dim_size))
end_val = max(start_val, min(end_val, dim_size))

begin[dim], end_vec[dim], strides[dim] = start_val, end_val, step
return torch.ops.tfl.strided_slice(x, begin, end_vec, strides)


@register_decomp(torch.ops.aten.where.self)
def _aten_where_self_decomp(condition, x, y):
x, y = _promote_types_for_binary_op(x, y)
Expand Down
12 changes: 12 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,18 @@ def _tfl_rsqrt_lowering(
)


@lower(torch.ops.tfl.neg.default)
def _tfl_neg_lowering(
lctx: LoweringContext,
x: ir.Value,
) -> ir.Value:
return _ir_operation(
"tfl.neg",
results=lowering_utils.node_meta_to_ir_types(lctx.node),
operands=[x],
)


@lower(torch.ops.tfl.gelu.default)
def _tfl_gelu_lowering(
lctx: LoweringContext,
Expand Down
5 changes: 5 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ def tfl_rsqrt(x: torch.Tensor) -> torch.Tensor:
return torch.rsqrt(x)


@custom_op_with_fake("tfl::neg")
def tfl_neg(x: torch.Tensor) -> torch.Tensor:
return torch.neg(x)


@custom_op_with_fake("tfl::gelu")
def tfl_gelu(x: torch.Tensor, approximate: bool = False) -> torch.Tensor:
gelu_approximate = "tanh" if approximate else "none"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def _assert_export_and_close(
("aten_cos_1", torch.ops.aten.cos.default, (rnd(torch.float32, (1, 10)),), dict()),
("aten_rsqrt_0", torch.ops.aten.rsqrt.default, (rnd(torch.float32, (10, 10)),), dict()),
("aten_rsqrt_1", torch.ops.aten.rsqrt.default, (rnd(torch.float32, (1, 10)),), dict()),
("aten_neg_0", torch.ops.aten.neg.default, (rnd(torch.float32, (10, 10)),), dict()),
("aten_gelu_0", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict()),
("aten_gelu_1", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict(approximate="tanh")),
("aten_gelu_2", torch.ops.aten.gelu.default, (rnd(torch.float32, (1, 10)),), dict()),
Expand Down Expand Up @@ -186,6 +187,14 @@ def _assert_export_and_close(
("aten_squeeze_dims_0", torch.ops.aten.squeeze.dims, (rnd(torch.float32, (2, 1, 2, 1, 2)), [1, 2, 3],), dict()),
("aten_select_int_0", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 0, 1,), dict()),
("aten_select_int_1", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 1, 1,), dict()),
("aten_slice_tensor_0", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=3)),
("aten_slice_tensor_1", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=5)),
("aten_slice_tensor_2", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=5)),
("aten_slice_tensor_3", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=2, end=None)),
("aten_slice_tensor_4", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=-5, end=-2)),
("aten_slice_tensor_5", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=8, step=2)),
("aten_slice_tensor_6", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=100)),
("aten_slice_tensor_7", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=None)),
("aten_where_self_0", torch.ops.aten.where.self, (rnd(torch.bool, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
("aten_embedding_0", torch.ops.aten.embedding.default, (rnd(torch.float32, (10, 10)), torch.tensor([[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]),), dict()),
("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
Expand Down
71 changes: 71 additions & 0 deletions ai_edge_torch/odml_torch/optimization_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimization barrier op definition and lowering."""

from ai_edge_torch.odml_torch import _torch_library
from ai_edge_torch.odml_torch.lowerings import registry
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch
import torch.utils._pytree as pytree

_torch_library.ODML_TORCH_LIB.define(
"optimization_barrier(Tensor[] inputs) -> Tensor[]"
)

optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default


def optimization_barrier(*inputs: pytree.PyTree):
"""Apply optimization barrier to the tensors nested within arbitrary pytrees.

Args:
*inputs: A list of tensors or tensor pytrees.

Returns:
The tensors after optimization barrier in the same pytrees structures.
"""
if len(inputs) == 1:
inputs = inputs[0]
tensors, spec = pytree.tree_flatten(inputs)
tensors = optimization_barrier_op(tuple(tensors))
outputs = pytree.tree_unflatten(tensors, spec)
return outputs


@torch.library.impl(
_torch_library.ODML_TORCH_LIB,
"optimization_barrier",
"CompositeExplicitAutograd",
)
def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]):
return tuple(inputs)


@torch.library.impl(
_torch_library.ODML_TORCH_LIB,
"optimization_barrier",
"Meta",
)
def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]):
return tuple([torch.empty_like(x) for x in inputs])


@registry.lower(torch.ops.odml_torch.optimization_barrier.default)
def _optimization_barrier_lowering(
lctx, inputs: tuple[ir.Value, ...]
) -> ir.Value:
del lctx
return stablehlo.optimization_barrier(inputs)
80 changes: 80 additions & 0 deletions ai_edge_torch/odml_torch/test/test_optimization_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from ai_edge_torch import odml_torch
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op.
import torch

from absl.testing import absltest as googletest

optimization_barrier = optimization_barrier_lib.optimization_barrier


class TestOptimizationBarrier(googletest.TestCase):
"""Test optimization barrier op implementation and lowering."""

def test_applied_optimization_barrier_op(self):
"""Test optimization barrier op application and lowering."""

class TestModel(torch.nn.Module):

def forward(self, x, y):
x, _ = optimization_barrier(x, y)
return x

x = torch.randn(1, 5)
ep = torch.export.export(TestModel().eval(), (x, x))
mlir = odml_torch.export.exported_program_to_mlir(ep)
mlir_text = mlir.get_text()
self.assertEqual(
mlir_text.count(
"stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>,"
" tensor<1x5xf32>"
),
1,
)

def test_input_single_tensor(self):
"""Test optimization barrier with single tensor input."""
x = torch.randn(1, 5)
y = optimization_barrier(x)
self.assertIsInstance(y, torch.Tensor)
self.assertEqual(y.shape, (1, 5))

def test_input_multiple_tensors(self):
"""Test optimization barrier with multiple tensors input."""
x = torch.randn(1, 5)
y = torch.randn(1, 6)
z = optimization_barrier(x, y)
self.assertIsInstance(z, tuple)
self.assertLen(z, 2)
self.assertIsInstance(z[0], torch.Tensor)
self.assertIsInstance(z[1], torch.Tensor)
self.assertEqual(z[0].shape, (1, 5))
self.assertEqual(z[1].shape, (1, 6))

def test_input_nested_tensors(self):
"""Test optimization barrier with nested tensor inputs."""
x = {"foo": torch.randn(1, 5), "bar": torch.randn(1, 6)}
z = optimization_barrier(x)
self.assertIsInstance(z, dict)
self.assertLen(z, 2)
self.assertIsInstance(z["foo"], torch.Tensor)
self.assertIsInstance(z["bar"], torch.Tensor)
self.assertEqual(z["foo"].shape, (1, 5))
self.assertEqual(z["bar"].shape, (1, 6))


if __name__ == "__main__":
googletest.main()
Loading