diff --git a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py index 3e92d9cd..fcc45513 100644 --- a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +++ b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py @@ -16,6 +16,7 @@ from typing import Any, Callable from ai_edge_torch import fx_infra from ai_edge_torch import lowertools +from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib import torch import torch.utils._pytree as pytree @@ -276,6 +277,7 @@ def embedding(*args, **kwargs): # Explicitly reshape back to the original shape. This places the ReshapeOp # outside of the HLFB. output = torch.reshape(output, (*(original_idx_shape), embedding_dim)) + output, _ = optimization_barrier_lib.optimization_barrier(output, idx) return output node.target = embedding diff --git a/ai_edge_torch/_convert/test/test_convert.py b/ai_edge_torch/_convert/test/test_convert.py index b613489d..f8e50262 100644 --- a/ai_edge_torch/_convert/test/test_convert.py +++ b/ai_edge_torch/_convert/test/test_convert.py @@ -576,6 +576,24 @@ def forward(self, x: torch.Tensor): self.fail(f"Conversion failed with bloat16 inputs: {err}") # pylint: enable=broad-except + def test_convert_model_with_torch_linspace_operation(self): + """Test converting a simple model with torch.linspace operation.""" + + class SampleModel(nn.Module): + + def forward(self, x: torch.Tensor): + return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64) + + model = SampleModel().eval() + args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),) + + try: + # Expect this to fix the error during conversion + ai_edge_torch.convert(model, args) + except Exception as err: + self.fail(f"Conversion failed with int64 inputs: {err}") + # pylint: enable=broad-except + def test_compile_model(self): """Tests AOT compilation of a simple Add module.""" diff --git 
a/ai_edge_torch/generative/examples/gemma/gemma1.py b/ai_edge_torch/generative/examples/gemma/gemma1.py index 0a04ed3f..c2549ff7 100644 --- a/ai_edge_torch/generative/examples/gemma/gemma1.py +++ b/ai_edge_torch/generative/examples/gemma/gemma1.py @@ -23,7 +23,7 @@ import torch from torch import nn -TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( +TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames( ff_up_proj="model.layers.{}.mlp.up_proj", ff_down_proj="model.layers.{}.mlp.down_proj", ff_gate_proj="model.layers.{}.mlp.gate_proj", @@ -36,6 +36,24 @@ lm_head=None, ) +TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames( + ff_up_proj="model.layers.{}.mlp.up_proj", + ff_down_proj="model.layers.{}.mlp.down_proj", + ff_gate_proj="model.layers.{}.mlp.gate_proj", + attn_query_proj="model.layers.{}.self_attn.q_proj", + attn_key_proj="model.layers.{}.self_attn.k_proj", + attn_value_proj="model.layers.{}.self_attn.v_proj", + attn_output_proj="model.layers.{}.self_attn.o_proj", + pre_attn_norm="model.layers.{}.input_layernorm", + post_attn_norm="model.layers.{}.post_attention_layernorm", + embedding="model.embed_tokens", + final_norm="model.norm", +) + +TENSOR_NAMES_DICT = { + "safetensors": TENSOR_NAMES_SEP_QKV, + "kaggle": TENSOR_NAMES_FUSED_QKV, +} class Gemma1(model_builder.DecoderOnlyModel): """A Gemma1 model built from the Edge Generative API layers.""" @@ -94,11 +112,28 @@ def build_2b_model( custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None, mask_cache_size: int = 0, ) -> nn.Module: - return model_builder.build_decoder_only_model( - checkpoint_path=checkpoint_path, - config=get_model_config_2b(), - tensor_names=TENSOR_NAMES, - model_class=Gemma1, - custom_loader=custom_loader, - mask_cache_size=mask_cache_size, + + # A list to store the reasons for each failure + key_errors = [] + + for tensor_names in TENSOR_NAMES_DICT.values(): + try: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + 
config=get_model_config_2b(), + tensor_names=tensor_names, + model_class=Gemma1, + custom_loader=custom_loader, + mask_cache_size=mask_cache_size, + ) + except KeyError as ke: + # Store the specific key that was missing for later + key_errors.append(f"Missing key: {ke}") + continue + + # If the loop finishes, raise an error with all the collected details + error_details = "\n".join(key_errors) + raise RuntimeError( + "Failed to build model after trying all configurations. " + f"Encountered the following errors:\n{error_details}" ) diff --git a/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py b/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py new file mode 100644 index 00000000..f2395b04 --- /dev/null +++ b/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py @@ -0,0 +1,95 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Verifies the reauthored SmolVLM2 Image Encoder model.""" + +import logging + +from absl import app +from absl import flags +from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2 +from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder +from PIL import Image +import requests +import torch +import transformers + +_IMAGE_URL = flags.DEFINE_string( + "image_url", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true", + "The image URI to encode.", +) + +_CHECKPOINT = flags.DEFINE_string( + "checkpoint", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "The checkpoint to verify.", +) + +_REAUTHORED_CHECKPOINT = flags.DEFINE_string( + "pretrained_weights", + None, + "The path to the model's pretrained weights.", +) + + +def main(_): + checkpoint = _CHECKPOINT.value + logging.info("Loading the original model from: %s", checkpoint) + original_model = transformers.AutoModelForImageTextToText.from_pretrained( + checkpoint + ) + original_model = original_model.eval().model + + logging.info("Building the reauthored checkpoint from: %s", checkpoint) + reauthored_checkpoint = _REAUTHORED_CHECKPOINT.value + if reauthored_checkpoint is None: + raise ValueError("reauthored_checkpoint is required.") + + logging.info("Building the reauthored model from: %s", reauthored_checkpoint) + reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint) + + logging.info("Loading the tokenizer from: %s", checkpoint) + processor = transformers.AutoProcessor.from_pretrained(checkpoint) + + logging.info("Loading the image from: %s", _IMAGE_URL.value) + image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw) + pixel_values = processor(images=image, return_tensors="pt")["pixel_values"] + + logging.info("Forwarding the original model...") + outputs_original = 
original_model.get_image_features(pixel_values) + logging.info("outputs_original's shape: %s", outputs_original.shape) + + pixel_values = pixel_values.reshape( + pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:] + ) + logging.info("Forwarding the reauthored model...") + outputs_reauthored = reauthored_model.forward( + pixel_values=pixel_values + ) + logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape) + + try: + assert torch.allclose( + outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04 + ) + except AssertionError as e: + logging.error("*** FAILED *** verify with an image") + raise e + else: + logging.info("*** PASSED *** verify with an image") + + +if __name__ == "__main__": + app.run(main) diff --git a/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py b/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py index 676c5c32..ad6bb4c6 100644 --- a/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py +++ b/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py @@ -20,7 +20,7 @@ """ from dataclasses import dataclass -from typing import Callable, Dict +from typing import Callable, Dict, Optional from ai_edge_torch.generative.examples.paligemma import image_encoder import ai_edge_torch.generative.layers.model_config as cfg @@ -127,9 +127,20 @@ def __init__( def forward( self, pixel_values: torch.Tensor, - export_config: export_cfg.ExportConfig = None, + export_config: Optional[export_cfg.ExportConfig] = None, ) -> torch.Tensor: - x = self.siglip_encoder(pixel_values) + # Embed the image according to SiplipVisionEmbeddings. + x = self.siglip_encoder.tok_embedding(pixel_values) + x = x.flatten(2).transpose(1, 2) + x = x + self.siglip_encoder.tok_embedding_position + + # Pass a dummy mask because SDPA attention impl expects non-None mask. 
+ mask = torch.zeros(x.shape[0], 1, x.shape[1], x.shape[1]) + for _, block in enumerate(self.siglip_encoder.transformer_blocks): + x = block(x, mask=mask) + x = self.siglip_encoder.final_norm(x) + + # Project the image embeddings to text hidden size. x = self.connector(x) return x @@ -166,7 +177,7 @@ def get_image_encoder_config() -> cfg.ModelConfig: output_proj_use_bias=True, ) norm_config = cfg.NormalizationConfig( - type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6 + type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6, ) ff_config = cfg.FeedForwardConfig( type=cfg.FeedForwardType.SEQUENTIAL, @@ -189,15 +200,13 @@ def get_image_encoder_config() -> cfg.ModelConfig: image_embedding=image_embedding_config, block_configs=block_config, final_norm_config=norm_config, - # num_mm_tokens_per_image=81, - # enable_hlfb=False ) return config def build_image_encoder( checkpoint_path: str, - custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None, + custom_loader: Optional[Callable[[str], Dict[str, torch.Tensor]]] = None, ) -> FullVisionEncoder: """Builds a FullVisionEncoder from the checkpoint path.""" encoder_config = get_image_encoder_config() @@ -208,7 +217,6 @@ def build_image_encoder( ) loader.load(encoder.siglip_encoder, strict=False) - loader = loading_utils.ModelLoader(checkpoint_path, None, custom_loader) state = loader.get_state() converted_state = dict() converted_state["modality_projection.weight"] = state.pop( diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py index 250118ef..3e249c77 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py @@ -180,6 +180,11 @@ def _aten_rsqrt_decomp(x): return torch.ops.tfl.rsqrt(x) +@register_decomp(torch.ops.aten.neg.default) +def _aten_neg_decomp(x): + return torch.ops.tfl.neg(x) + + @register_decomp(torch.ops.aten.gelu.default) def _aten_gelu_decomp(x, 
approximate="none"): return torch.ops.tfl.gelu(x, approximate != "none") @@ -317,6 +322,38 @@ def _aten_select_int_decomp(x, dim, index): return torch.ops.tfl.squeeze(sliced, [dim]) +@register_decomp(torch.ops.aten.slice.Tensor) +def _aten_slice_tensor_decomp(x, dim=0, start=None, end=None, step=1): + rank = x.dim() + dim_size = x.shape[dim] + + # Initialize begin, end, strides for tfl.strided_slice + begin = [0] * rank + end_vec = list(x.shape) + strides = [1] * rank + + # The logic below is to match PyTorch's `slice` behavior. + # `start` and `end` can be negative, which means they count from the end. + # `start=None` defaults to 0. + # `end=None` or a large number defaults to `dim_size` after clamping. + + start_val = 0 if start is None else start + if start_val < 0: + start_val += dim_size + + end_val = dim_size if end is None else end + if end_val < 0: + end_val += dim_size + + # Clamp start and end to be within the dimension size, following PyTorch's + # logic. + start_val = max(0, min(start_val, dim_size)) + end_val = max(start_val, min(end_val, dim_size)) + + begin[dim], end_vec[dim], strides[dim] = start_val, end_val, step + return torch.ops.tfl.strided_slice(x, begin, end_vec, strides) + + @register_decomp(torch.ops.aten.where.self) def _aten_where_self_decomp(condition, x, y): x, y = _promote_types_for_binary_op(x, y) @@ -351,3 +388,27 @@ def _aten__softmax_decomp( softmax_result = torch.ops.tfl.softmax(x_permuted) # Transpose the result back to the original dimensions. return torch.ops.tfl.transpose(softmax_result, dims) + + +@register_decomp(torch.ops.aten.topk.default) +def _aten_topk_decomp(self, k, dim=-1, largest=True, sorted=True): + if not largest: + raise ValueError("Only largest=True is supported for torch.topk.") + + if dim < 0: + dim = self.dim() + dim + + if dim != self.dim() - 1: + self = torch.transpose(self, dim, -1) + + # Ignores sorted value: tfl.topk_v2 only supports sorted=True, but it doesn't + # affect the correctness of the output. 
+ out, indices = torch.ops.tfl.topk_v2(self, k) + + if dim != self.dim() - 1: + out = torch.transpose(out, dim, -1) + indices = torch.transpose(indices, dim, -1) + + # torch.topk returns int64 indices, but tfl.topk_v2 returns indices in int32. + indices = indices.to(torch.int64) + return out, indices diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py index 5d3425a7..a2861b8b 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py @@ -286,6 +286,18 @@ def _tfl_rsqrt_lowering( ) +@lower(torch.ops.tfl.neg.default) +def _tfl_neg_lowering( + lctx: LoweringContext, + x: ir.Value, +) -> ir.Value: + return _ir_operation( + "tfl.neg", + results=lowering_utils.node_meta_to_ir_types(lctx.node), + operands=[x], + ) + + @lower(torch.ops.tfl.gelu.default) def _tfl_gelu_lowering( lctx: LoweringContext, @@ -674,3 +686,20 @@ def _tfl_softmax_lowering( "beta": ir.FloatAttr.get(ir.F32Type.get(), beta), }, ) + + +@lower(torch.ops.tfl.topk_v2.default) +def _tfl_topk_v2_lowering( + lctx: LoweringContext, + x: ir.Value, + k: int, +) -> tuple[ir.Value, ir.Value]: + return _ir_operation( + "tfl.topk_v2", + results=lowering_utils.node_meta_to_ir_types(lctx.node), + operands=[ + x, + lowering_utils.numpy_array_constant(np.array(k, dtype=np.int32)), + ], + attributes={}, + ) diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py index 9364a487..95e1e534 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py @@ -110,6 +110,11 @@ def tfl_rsqrt(x: torch.Tensor) -> torch.Tensor: return torch.rsqrt(x) +@custom_op_with_fake("tfl::neg") +def tfl_neg(x: torch.Tensor) -> torch.Tensor: + return torch.neg(x) + + @custom_op_with_fake("tfl::gelu") def tfl_gelu(x: torch.Tensor, 
approximate: bool = False) -> torch.Tensor: gelu_approximate = "tanh" if approximate else "none" @@ -292,6 +297,13 @@ def tfl_softmax(x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.softmax(x, dim=-1) +@custom_op_with_fake("tfl::topk_v2") +def tfl_topk_v2(x: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]: + out, indices = torch.topk(x, k, dim=-1, largest=True, sorted=True) + indices = indices.to(torch.int32) + return out, indices + + @custom_op_with_fake( "tfl::slice", schema="(Tensor x, SymInt[] begin, SymInt[] size) -> Tensor" ) diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py index 6f2250aa..28103bb7 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py @@ -30,6 +30,14 @@ export_with_tensor_inputs_only = testing.export_with_tensor_inputs_only +def tree_map_list_to_tuple(x): + if isinstance(x, (list, tuple)): + return tuple(tree_map_list_to_tuple(y) for y in x) + if isinstance(x, dict): + return {k: tree_map_list_to_tuple(v) for k, v in x.items()} + return x + + def rnd(dtype, shape, min_v=None, max_v=None): """Shortcut for creating a random torch tensor.""" if dtype in (torch.int32, torch.int64, torch.bool): @@ -98,6 +106,9 @@ def _assert_export_and_close( actual = edge_model(*args, **kwargs) with self.subTest("torch_convert_eval_diff:" + str(atol)): + expected = tree_map_list_to_tuple(expected) + actual = tree_map_list_to_tuple(actual) + expected_flat, expected_spec = pytree.tree_flatten(expected) actual_flat, actual_spec = pytree.tree_flatten(actual) @@ -152,6 +163,7 @@ def _assert_export_and_close( ("aten_cos_1", torch.ops.aten.cos.default, (rnd(torch.float32, (1, 10)),), dict()), ("aten_rsqrt_0", torch.ops.aten.rsqrt.default, (rnd(torch.float32, (10, 10)),), dict()), ("aten_rsqrt_1", 
torch.ops.aten.rsqrt.default, (rnd(torch.float32, (1, 10)),), dict()), + ("aten_neg_0", torch.ops.aten.neg.default, (rnd(torch.float32, (10, 10)),), dict()), ("aten_gelu_0", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict()), ("aten_gelu_1", torch.ops.aten.gelu.default, (rnd(torch.float32, (10, 10)),), dict(approximate="tanh")), ("aten_gelu_2", torch.ops.aten.gelu.default, (rnd(torch.float32, (1, 10)),), dict()), @@ -186,6 +198,14 @@ def _assert_export_and_close( ("aten_squeeze_dims_0", torch.ops.aten.squeeze.dims, (rnd(torch.float32, (2, 1, 2, 1, 2)), [1, 2, 3],), dict()), ("aten_select_int_0", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 0, 1,), dict()), ("aten_select_int_1", torch.ops.aten.select.int, (rnd(torch.float32, (2, 3, 4)), 1, 1,), dict()), + ("aten_slice_tensor_0", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=3)), + ("aten_slice_tensor_1", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=5)), + ("aten_slice_tensor_2", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=5)), + ("aten_slice_tensor_3", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=2, end=None)), + ("aten_slice_tensor_4", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=-5, end=-2)), + ("aten_slice_tensor_5", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=1, end=8, step=2)), + ("aten_slice_tensor_6", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=1, start=2, end=100)), + ("aten_slice_tensor_7", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)),), dict(dim=0, start=None, end=None)), ("aten_where_self_0", torch.ops.aten.where.self, (rnd(torch.bool, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()), ("aten_embedding_0", torch.ops.aten.embedding.default, (rnd(torch.float32, (10, 10)), 
torch.tensor([[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]),), dict()), ("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()), @@ -194,6 +214,8 @@ def _assert_export_and_close( ("aten__softmax_3", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), 0, False), dict()), ("aten__softmax_4", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 1, False), dict()), ("aten__softmax_5", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), 1, False), dict()), + ("aten_topk_0", torch.ops.aten.topk.default, (rnd(torch.float32, (4, 10)), 3), dict()), + ("aten_topk_1", torch.ops.aten.topk.default, (rnd(torch.float32, (4, 10)), 3), dict(dim=0)), # fmt: on # pyformat: enable ) diff --git a/ai_edge_torch/odml_torch/export.py b/ai_edge_torch/odml_torch/export.py index 2783af11..8e8c17e4 100644 --- a/ai_edge_torch/odml_torch/export.py +++ b/ai_edge_torch/odml_torch/export.py @@ -21,7 +21,7 @@ from typing import Any, Callable, Optional from ai_edge_torch import fx_infra -from jax.lib import xla_extension +import jax.extend from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import hlo as stablehlo @@ -233,7 +233,7 @@ def module_bytecode_vhlo(self) -> bytes: target_version = stablehlo.get_version_from_compatibility_requirement( stablehlo.StablehloCompatibilityRequirement.WEEK_12 ) - module_bytecode = xla_extension.mlir.serialize_portable_artifact( + module_bytecode = jax.extend.mlir.serialize_portable_artifact( self.module_bytecode, target_version ) return module_bytecode diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py index 41cdf033..50cae945 100644 --- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py @@ -78,8 +78,6 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten._unsafe_view) 
lower_by_torch_xla2(torch.ops.aten.acos) lower_by_torch_xla2(torch.ops.aten.acosh) -lower_by_torch_xla2(torch.ops.aten.add.Scalar) -lower_by_torch_xla2(torch.ops.aten.add.Tensor) lower_by_torch_xla2(torch.ops.aten.addbmm.default) lower_by_torch_xla2(torch.ops.aten.addmm) lower_by_torch_xla2(torch.ops.aten.addmv) @@ -159,7 +157,6 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.logical_not) lower_by_torch_xla2(torch.ops.aten.logical_or) lower_by_torch_xla2(torch.ops.aten.logical_xor) -lower_by_torch_xla2(torch.ops.aten.lt) lower_by_torch_xla2(torch.ops.aten.max) lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices) lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward) @@ -170,8 +167,6 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.min) lower_by_torch_xla2(torch.ops.aten.minimum) lower_by_torch_xla2(torch.ops.aten.mm) -lower_by_torch_xla2(torch.ops.aten.mul.Scalar) -lower_by_torch_xla2(torch.ops.aten.mul.Tensor) lower_by_torch_xla2(torch.ops.aten.native_batch_norm) lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward) lower_by_torch_xla2(torch.ops.aten.ne) @@ -215,8 +210,6 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.squeeze) lower_by_torch_xla2(torch.ops.aten.squeeze_copy) lower_by_torch_xla2(torch.ops.aten.stack) -lower_by_torch_xla2(torch.ops.aten.sub.Scalar) -lower_by_torch_xla2(torch.ops.aten.sub.Tensor) lower_by_torch_xla2(torch.ops.aten.sum) lower_by_torch_xla2(torch.ops.aten.t) lower_by_torch_xla2(torch.ops.aten.tan) @@ -244,7 +237,6 @@ def lower_by_torch_xla2(op): lower_by_torch_xla2(torch.ops.aten.view_copy) lower_by_torch_xla2(torch.ops.aten.where.ScalarOther) lower_by_torch_xla2(torch.ops.aten.where.ScalarSelf) -lower_by_torch_xla2(torch.ops.aten.where.self) lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim) lower_by_torch_xla2(torch.ops.prims.var) @@ -259,6 +251,149 @@ def _aten_copy(self, src, **kwargs): return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src) 
+@registry.lower(torch.ops.aten.add.Scalar) +def _aten_add_scalar(lctx: LoweringContext, self, other): + _log_usage(torch.ops.aten.add.Scalar) + + @jax_bridge.wrap + def jax_lowering(self, other): + other_dtype = jnp.result_type(other) + promoted_type = jnp.promote_types(self.dtype, other_dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.add( + self.astype(promoted_type), jnp.array(other, dtype=promoted_type) + ) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.add.Tensor) +def _aten_add_tensor(lctx: LoweringContext, self, other): + _log_usage(torch.ops.aten.add.Tensor) + + @jax_bridge.wrap + def jax_lowering(self, other): + promoted_type = jnp.promote_types(self.dtype, other.dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.add(self.astype(promoted_type), other.astype(promoted_type)) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.sub.Scalar) +def _aten_sub_scalar(lctx: LoweringContext, self, other): + _log_usage(torch.ops.aten.sub.Scalar) + + @jax_bridge.wrap + def jax_lowering(self, other): + other_dtype = jnp.result_type(other) + promoted_type = jnp.promote_types(self.dtype, other_dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.subtract( + self.astype(promoted_type), jnp.array(other, dtype=promoted_type) + ) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.sub.Tensor) +def _aten_sub_tensor(lctx: LoweringContext, self, other): + _log_usage(torch.ops.aten.sub.Tensor) + + @jax_bridge.wrap + def jax_lowering(self, other): + promoted_type = jnp.promote_types(self.dtype, other.dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.subtract(self.astype(promoted_type), other.astype(promoted_type)) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.lt.Scalar) +def _aten_lt_scalar(lctx: LoweringContext, self, other): + 
_log_usage(torch.ops.aten.lt.Scalar) + + @jax_bridge.wrap + def jax_lowering(self, other): + other_dtype = jnp.result_type(other) + promoted_type = jnp.promote_types(self.dtype, other_dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.less( + self.astype(promoted_type), jnp.array(other, dtype=promoted_type) + ) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.lt.Tensor) +def _aten_lt_tensor(lctx: LoweringContext, self, other): + _log_usage(torch.ops.aten.lt.Tensor) + + @jax_bridge.wrap + def jax_lowering(self, other): + promoted_type = jnp.promote_types(self.dtype, other.dtype) + return jnp.less(self.astype(promoted_type), other.astype(promoted_type)) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.mul.Scalar) +def _aten_mul_scalar(lctx: LoweringContext, self, other): + _log_usage(torch.ops.aten.mul.Scalar) + + @jax_bridge.wrap + def jax_lowering(self, other): + other_dtype = jnp.result_type(other) + promoted_type = jnp.promote_types(self.dtype, other_dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.multiply( + self.astype(promoted_type), jnp.array(other, dtype=promoted_type) + ) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.mul.Tensor) +def _aten_mul_tensor(lctx: LoweringContext, self, other): + _log_usage(torch.ops.aten.mul.Tensor) + + @jax_bridge.wrap + def jax_lowering(self, other): + other_dtype = other.dtype + promoted_type = jnp.promote_types(self.dtype, other_dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.multiply( + self.astype(promoted_type), other.astype(promoted_type) + ) + + return jax_lowering(lctx, self, other) + + +@registry.lower(torch.ops.aten.where.self) +def _aten_where_self(lctx: LoweringContext, condition, self, other): + _log_usage(torch.ops.aten.where.self) + + @jax_bridge.wrap + def jax_lowering(condition, self, other): + 
promoted_type = jnp.promote_types(self.dtype, other.dtype) + if promoted_type == jnp.float64: + promoted_type = jnp.float32 + return jnp.where( + condition, + self.astype(promoted_type), + other.astype(promoted_type), + ) + + return jax_lowering(lctx, condition, self, other) + + # Schema: # - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) # -> Tensor diff --git a/ai_edge_torch/odml_torch/optimization_barrier.py b/ai_edge_torch/odml_torch/optimization_barrier.py new file mode 100644 index 00000000..88778b37 --- /dev/null +++ b/ai_edge_torch/odml_torch/optimization_barrier.py @@ -0,0 +1,71 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Optimization barrier op definition and lowering.""" + +from ai_edge_torch.odml_torch import _torch_library +from ai_edge_torch.odml_torch.lowerings import registry +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo as stablehlo +import torch +import torch.utils._pytree as pytree + +_torch_library.ODML_TORCH_LIB.define( + "optimization_barrier(Tensor[] inputs) -> Tensor[]" +) + +optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default + + +def optimization_barrier(*inputs: pytree.PyTree): + """Apply optimization barrier to the tensors nested within arbitrary pytrees. + + Args: + *inputs: A list of tensors or tensor pytrees. 
+ + Returns: + The tensors after optimization barrier in the same pytrees structures. + """ + if len(inputs) == 1: + inputs = inputs[0] + tensors, spec = pytree.tree_flatten(inputs) + tensors = optimization_barrier_op(tuple(tensors)) + outputs = pytree.tree_unflatten(tensors, spec) + return outputs + + +@torch.library.impl( + _torch_library.ODML_TORCH_LIB, + "optimization_barrier", + "CompositeExplicitAutograd", +) +def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]): + return tuple(inputs) + + +@torch.library.impl( + _torch_library.ODML_TORCH_LIB, + "optimization_barrier", + "Meta", +) +def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]): + return tuple([torch.empty_like(x) for x in inputs]) + + +@registry.lower(torch.ops.odml_torch.optimization_barrier.default) +def _optimization_barrier_lowering( + lctx, inputs: tuple[ir.Value, ...] +) -> ir.Value: + del lctx + return stablehlo.optimization_barrier(inputs) diff --git a/ai_edge_torch/odml_torch/test/test_optimization_barrier.py b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py new file mode 100644 index 00000000..d25d8c8e --- /dev/null +++ b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py @@ -0,0 +1,80 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from ai_edge_torch import odml_torch +from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op. +import torch + +from absl.testing import absltest as googletest + +optimization_barrier = optimization_barrier_lib.optimization_barrier + + +class TestOptimizationBarrier(googletest.TestCase): + """Test optimization barrier op implementation and lowering.""" + + def test_applied_optimization_barrier_op(self): + """Test optimization barrier op application and lowering.""" + + class TestModel(torch.nn.Module): + + def forward(self, x, y): + x, _ = optimization_barrier(x, y) + return x + + x = torch.randn(1, 5) + ep = torch.export.export(TestModel().eval(), (x, x)) + mlir = odml_torch.export.exported_program_to_mlir(ep) + mlir_text = mlir.get_text() + self.assertEqual( + mlir_text.count( + "stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>," + " tensor<1x5xf32>" + ), + 1, + ) + + def test_input_single_tensor(self): + """Test optimization barrier with single tensor input.""" + x = torch.randn(1, 5) + y = optimization_barrier(x) + self.assertIsInstance(y, torch.Tensor) + self.assertEqual(y.shape, (1, 5)) + + def test_input_multiple_tensors(self): + """Test optimization barrier with multiple tensors input.""" + x = torch.randn(1, 5) + y = torch.randn(1, 6) + z = optimization_barrier(x, y) + self.assertIsInstance(z, tuple) + self.assertLen(z, 2) + self.assertIsInstance(z[0], torch.Tensor) + self.assertIsInstance(z[1], torch.Tensor) + self.assertEqual(z[0].shape, (1, 5)) + self.assertEqual(z[1].shape, (1, 6)) + + def test_input_nested_tensors(self): + """Test optimization barrier with nested tensor inputs.""" + x = {"foo": torch.randn(1, 5), "bar": torch.randn(1, 6)} + z = optimization_barrier(x) + self.assertIsInstance(z, dict) + self.assertLen(z, 2) + self.assertIsInstance(z["foo"], torch.Tensor) + 
self.assertIsInstance(z["bar"], torch.Tensor) + self.assertEqual(z["foo"].shape, (1, 5)) + self.assertEqual(z["bar"].shape, (1, 6)) + + +if __name__ == "__main__": + googletest.main()