diff --git a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py index 3e92d9cd..fcc45513 100644 --- a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +++ b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py @@ -16,6 +16,7 @@ from typing import Any, Callable from ai_edge_torch import fx_infra from ai_edge_torch import lowertools +from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib import torch import torch.utils._pytree as pytree @@ -276,6 +277,7 @@ def embedding(*args, **kwargs): # Explicitly reshape back to the original shape. This places the ReshapeOp # outside of the HLFB. output = torch.reshape(output, (*(original_idx_shape), embedding_dim)) + output, _ = optimization_barrier_lib.optimization_barrier(output, idx) return output node.target = embedding diff --git a/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py b/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py new file mode 100644 index 00000000..7811821b --- /dev/null +++ b/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py @@ -0,0 +1,95 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Verifies the reauthored SmolVLM2 Image Encoder model.""" + +import logging + +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2 +from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder +from PIL import Image +import requests +import torch +import transformers + +_IMAGE_URL = flags.DEFINE_string( + "image_url", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true", + "The image URI to encode.", +) + +_CHECKPOINT = flags.DEFINE_string( + "checkpoint", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "The checkpoint to verify.", +) + +_REAUTHORTHED_CHECKPOINT = flags.DEFINE_string( + "pretrained_weights_path", + None, + "The path to the model's pretrained weights.", +) + + +def main(_): + checkpoint = _CHECKPOINT.value + logging.info("Loading the original model from: %s", checkpoint) + original_model = transformers.AutoModelForImageTextToText.from_pretrained( + checkpoint + ) + original_model = original_model.eval().model + + logging.info("Building the reauthored checkpoint from: %s", checkpoint) + reauthored_checkpoint = _REAUTHORTHED_CHECKPOINT.value + if reauthored_checkpoint is None: + raise ValueError("reauthored_checkpoint is required.") + + logging.info("Building the reauthored model from: %s", reauthored_checkpoint) + reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint) + + logging.info("Loading the tokenizer from: %s", checkpoint) + processor = transformers.AutoProcessor.from_pretrained(checkpoint) + + logging.info("Loading the image from: %s", _IMAGE_URL.value) + image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw) + pixel_values = processor(images=image, return_tensors="pt")["pixel_values"] + + logging.info("Forwarding the original model...") + outputs_original = 
original_model.get_image_features(pixel_values) + logging.info("outputs_original's shape: %s", outputs_original.shape) + + pixel_values = pixel_values.reshape( + pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:] + ) + logging.info("Forwarding the reauthored model...") + outputs_reauthored = reauthored_model.forward( + pixel_values=pixel_values + ) + logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape) + + try: + assert torch.allclose( + outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04 + ) + except AssertionError as e: + logging.error("*** FAILED *** verify with an image") + raise e + else: + logging.info("*** PASSED *** verify with an image") + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py b/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py index 676c5c32..3b5d15da 100644 --- a/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py +++ b/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py @@ -129,7 +129,18 @@ def forward( pixel_values: torch.Tensor, export_config: export_cfg.ExportConfig = None, ) -> torch.Tensor: - x = self.siglip_encoder(pixel_values) + # Embed the image according to SiglipVisionEmbeddings. + x = self.siglip_encoder.tok_embedding(pixel_values) + x = x.flatten(2).transpose(1, 2) + x = x + self.siglip_encoder.tok_embedding_position + + # Pass a dummy mask because SDPA attention impl expects non-None mask. + mask = torch.zeros(x.shape[0], 1, x.shape[1], x.shape[1]) + for _, block in enumerate(self.siglip_encoder.transformer_blocks): + x = block(x, mask=mask) + x = self.siglip_encoder.final_norm(x) + + # Project the image embeddings to text hidden size. 
x = self.connector(x) return x diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py index 250118ef..7b93eb93 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py @@ -180,6 +180,11 @@ def _aten_rsqrt_decomp(x): return torch.ops.tfl.rsqrt(x) +@register_decomp(torch.ops.aten.neg.default) +def _aten_neg_decomp(x): + return torch.ops.tfl.neg(x) + + @register_decomp(torch.ops.aten.gelu.default) def _aten_gelu_decomp(x, approximate="none"): return torch.ops.tfl.gelu(x, approximate != "none") @@ -317,6 +322,38 @@ def _aten_select_int_decomp(x, dim, index): return torch.ops.tfl.squeeze(sliced, [dim]) +@register_decomp(torch.ops.aten.slice.Tensor) +def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1): + rank = x.dim() + dim_size = x.shape[dim] + + # Initialize begin, end, strides for tfl.strided_slice + begin = [0] * rank + end_vec = list(x.shape) + strides = [1] * rank + + # The logic below is to match PyTorch's `slice` behavior. + # `start` and `end` can be negative, which means they count from the end. + # `start=None` defaults to 0. + # `end=None` or a large number defaults to `dim_size` after clamping. + + start_val = 0 if start is None else start + if start_val < 0: + start_val += dim_size + + end_val = dim_size if end is None else end + if end_val < 0: + end_val += dim_size + + # Clamp start and end to be within the dimension size, following PyTorch's + # logic. 
+ start_val = max(0, min(start_val, dim_size)) + end_val = max(start_val, min(end_val, dim_size)) + + begin[dim], end_vec[dim], strides[dim] = start_val, end_val, step + return torch.ops.tfl.strided_slice(x, begin, end_vec, strides) + + @register_decomp(torch.ops.aten.where.self) def _aten_where_self_decomp(condition, x, y): x, y = _promote_types_for_binary_op(x, y) diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py index 5d3425a7..8f7d69ae 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py @@ -286,6 +286,18 @@ def _tfl_rsqrt_lowering( ) +@lower(torch.ops.tfl.neg.default) +def _tfl_neg_lowering( + lctx: LoweringContext, + x: ir.Value, +) -> ir.Value: + return _ir_operation( + "tfl.neg", + results=lowering_utils.node_meta_to_ir_types(lctx.node), + operands=[x], + ) + + @lower(torch.ops.tfl.gelu.default) def _tfl_gelu_lowering( lctx: LoweringContext, diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py index 9364a487..e13d7c9e 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py @@ -110,6 +110,11 @@ def tfl_rsqrt(x: torch.Tensor) -> torch.Tensor: return torch.rsqrt(x) +@custom_op_with_fake("tfl::neg") +def tfl_neg(x: torch.Tensor) -> torch.Tensor: + return torch.neg(x) + + @custom_op_with_fake("tfl::gelu") def tfl_gelu(x: torch.Tensor, approximate: bool = False) -> torch.Tensor: gelu_approximate = "tanh" if approximate else "none" diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py index 6f2250aa..0f8f351b 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py +++ 
b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py @@ -152,6 +152,7 @@ def _assert_export_and_close( ("aten_cos_1", torch.ops.aten.cos.default, (rnd(torch.float32, (1, 10)),), dict()), ("aten_rsqrt_0", torch.ops.aten.rsqrt.default, (rnd(torch.float32, (10, 10)),), dict()), ("aten_rsqrt_1", torch.ops.aten.rsqrt.default, (rnd(torch.float32, (1, 10)),), dict()), + ("aten_neg_0", torch.ops.aten.neg.default, (rnd(torch.float32, (10, 10)),), dict()), ("aten_gelu_0", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict()), ("aten_gelu_1", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict(approximate="tanh")), ("aten_gelu_2", torch.ops.aten.gelu.default, (rnd(torch.float32, (1, 10)),), dict()), @@ -186,6 +187,14 @@ def _assert_export_and_close( ("aten_squeeze_dims_0", torch.ops.aten.squeeze.dims, (rnd(torch.float32, (2, 1, 2, 1, 2)), [1, 2, 3],), dict()), ("aten_select_int_0", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 0, 1,), dict()), ("aten_select_int_1", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 1, 1,), dict()), + ("aten_slice_tensor_0", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=3)), + ("aten_slice_tensor_1", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=5)), + ("aten_slice_tensor_2", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=5)), + ("aten_slice_tensor_3", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=2, end=None)), + ("aten_slice_tensor_4", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=-5, end=-2)), + ("aten_slice_tensor_5", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=8, step=2)), + ("aten_slice_tensor_6", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=100)), + ("aten_slice_tensor_7", 
torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=None)), ("aten_where_self_0", torch.ops.aten.where.self, (rnd(torch.bool, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()), ("aten_embedding_0", torch.ops.aten.embedding.default, (rnd(torch.float32, (10, 10)), torch.tensor([[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]),), dict()), ("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()), diff --git a/ai_edge_torch/odml_torch/optimization_barrier.py b/ai_edge_torch/odml_torch/optimization_barrier.py new file mode 100644 index 00000000..88778b37 --- /dev/null +++ b/ai_edge_torch/odml_torch/optimization_barrier.py @@ -0,0 +1,71 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Optimization barrier op definition and lowering.""" + +from ai_edge_torch.odml_torch import _torch_library +from ai_edge_torch.odml_torch.lowerings import registry +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo as stablehlo +import torch +import torch.utils._pytree as pytree + +_torch_library.ODML_TORCH_LIB.define( + "optimization_barrier(Tensor[] inputs) -> Tensor[]" +) + +optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default + + +def optimization_barrier(*inputs: pytree.PyTree): + """Apply optimization barrier to the tensors nested within arbitrary pytrees. + + Args: + *inputs: A list of tensors or tensor pytrees. + + Returns: + The tensors after the optimization barrier, in the same pytree structure as the inputs. + """ + if len(inputs) == 1: + inputs = inputs[0] + tensors, spec = pytree.tree_flatten(inputs) + tensors = optimization_barrier_op(tuple(tensors)) + outputs = pytree.tree_unflatten(tensors, spec) + return outputs + + +@torch.library.impl( + _torch_library.ODML_TORCH_LIB, + "optimization_barrier", + "CompositeExplicitAutograd", +) +def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]): + return tuple(inputs) + + +@torch.library.impl( + _torch_library.ODML_TORCH_LIB, + "optimization_barrier", + "Meta", +) +def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]): + return tuple([torch.empty_like(x) for x in inputs]) + + +@registry.lower(torch.ops.odml_torch.optimization_barrier.default) +def _optimization_barrier_lowering( + lctx, inputs: tuple[ir.Value, ...] 
+) -> ir.Value: + del lctx + return stablehlo.optimization_barrier(inputs) diff --git a/ai_edge_torch/odml_torch/test/test_optimization_barrier.py b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py new file mode 100644 index 00000000..d25d8c8e --- /dev/null +++ b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py @@ -0,0 +1,80 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from ai_edge_torch import odml_torch +from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op. 
+import torch + +from absl.testing import absltest as googletest + +optimization_barrier = optimization_barrier_lib.optimization_barrier + + +class TestOptimizationBarrier(googletest.TestCase): + """Test optimization barrier op implementation and lowering.""" + + def test_applied_optimization_barrier_op(self): + """Test optimization barrier op application and lowering.""" + + class TestModel(torch.nn.Module): + + def forward(self, x, y): + x, _ = optimization_barrier(x, y) + return x + + x = torch.randn(1, 5) + ep = torch.export.export(TestModel().eval(), (x, x)) + mlir = odml_torch.export.exported_program_to_mlir(ep) + mlir_text = mlir.get_text() + self.assertEqual( + mlir_text.count( + "stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>," + " tensor<1x5xf32>" + ), + 1, + ) + + def test_input_single_tensor(self): + """Test optimization barrier with single tensor input.""" + x = torch.randn(1, 5) + y = optimization_barrier(x) + self.assertIsInstance(y, torch.Tensor) + self.assertEqual(y.shape, (1, 5)) + + def test_input_multiple_tensors(self): + """Test optimization barrier with multiple tensors input.""" + x = torch.randn(1, 5) + y = torch.randn(1, 6) + z = optimization_barrier(x, y) + self.assertIsInstance(z, tuple) + self.assertLen(z, 2) + self.assertIsInstance(z[0], torch.Tensor) + self.assertIsInstance(z[1], torch.Tensor) + self.assertEqual(z[0].shape, (1, 5)) + self.assertEqual(z[1].shape, (1, 6)) + + def test_input_nested_tensors(self): + """Test optimization barrier with nested tensor inputs.""" + x = {"foo": torch.randn(1, 5), "bar": torch.randn(1, 6)} + z = optimization_barrier(x) + self.assertIsInstance(z, dict) + self.assertLen(z, 2) + self.assertIsInstance(z["foo"], torch.Tensor) + self.assertIsInstance(z["bar"], torch.Tensor) + self.assertEqual(z["foo"].shape, (1, 5)) + self.assertEqual(z["bar"].shape, (1, 6)) + + +if __name__ == "__main__": + googletest.main()