diff --git a/ai_edge_torch/odml_torch/__init__.py b/ai_edge_torch/odml_torch/__init__.py
index 34830cd5..4fbaed18 100644
--- a/ai_edge_torch/odml_torch/__init__.py
+++ b/ai_edge_torch/odml_torch/__init__.py
@@ -18,3 +18,4 @@
 from . import export_utils
 from . import lowerings
 from . import passes
+from . import experimental
diff --git a/ai_edge_torch/odml_torch/experimental/__init__.py b/ai_edge_torch/odml_torch/experimental/__init__.py
new file mode 100644
index 00000000..9137fd9c
--- /dev/null
+++ b/ai_edge_torch/odml_torch/experimental/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Torch-TFL ops definitions, decompositions, and lowerings."""
+from . import torch_tfl
diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
index 582f2c16..8be5dc2c 100644
--- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
+++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
@@ -15,6 +15,7 @@
 """Torch ops to Torch-TFL decompositions."""
 from typing import Sequence
 from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
+import numpy as np
 import torch
 
 decomps = {}
@@ -424,3 +425,54 @@ def _aten_topk_decomp(self, k, dim=-1, largest=True, sorted=True):
   # torch.topk returns int64 indices, but tfl.topk_v2 returns indices in int32.
   indices = indices.to(torch.int64)
   return out, indices
+
+
+@register_decomp(torch.ops.aten.scatter.src)
+def _aten_scatter_src_decomp(self, dim, index, src):
+  index = index.to(torch.int32)
+
+  # Step 1: Prepare the `updates` tensor.
+  # The number of updates is determined by the shape of the `index` tensor,
+  # so `src` must be expanded or sliced to match `index`.
+  if src.dim() == 0:
+    # If `src` is a scalar, expand it to the shape of `index`.
+    updates = src.expand(index.shape)
+  else:
+    # If `src` is a tensor, slice its leading corner to match `index`.
+    slicing_tuple = tuple(slice(s) for s in index.shape)
+    updates = src[slicing_tuple]
+
+  # Step 2: Build full N-d coordinates for tfl.scatter_nd.
+  # The coordinate grid must match the shape of the `index` tensor, as this
+  # defines the number of updates to perform.
+  grid_ranges = [torch.arange(s, dtype=torch.int32) for s in index.shape]
+  grid = list(torch.meshgrid(*grid_ranges, indexing="ij"))
+
+  # Handle negative dimension indexing.
+  if dim < 0:
+    dim = self.dim() + dim
+
+  # Replace the coordinates for the scatter dimension with the `index` values.
+  grid[dim] = index
+
+  # Stack the coordinates along the *last* dimension to create the final
+  # indices tensor, the layout expected by tfl.scatter_nd.
+  # Final shape: (*index.shape, rank_of_input).
+  indices = torch.stack(grid, dim=-1)
+
+  # Step 3: Scatter the updates.
+  # tfl.scatter_nd writes `updates` into a zero-initialized tensor of
+  # `self.shape`, so the result has to be merged with the original `self`
+  # below. (torch.scatter_add_ semantics would need an add-based scatter.)
+  update_tensor = torch.ops.tfl.scatter_nd(indices, updates, self.shape)
+
+  # Also scatter a boolean mask marking the updated locations, by writing
+  # `True` values to the same indices.
+  mask_updates = torch.full_like(updates, True, dtype=torch.bool)
+  scatter_mask = torch.ops.tfl.scatter_nd(indices, mask_updates, self.shape)
+
+  # Step 4: Combine the original tensor with the scattered updates.
+  # Take values from `update_tensor` where an update occurred and from the
+  # original `self` everywhere else.
+  result = torch.where(scatter_mask, update_tensor, self)
+  return result
diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
index 1b6a6471..bf60116e 100644
--- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
+++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
@@ -16,8 +16,8 @@
 
 from typing import Sequence
 
-from ai_edge_torch import odml_torch
 from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
+from ai_edge_torch.odml_torch.lowerings import context
 from ai_edge_torch.odml_torch.lowerings import registry
 from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
 from jax._src.lib.mlir import ir
@@ -27,7 +27,7 @@
 
 lower = registry.lower
 
-LoweringContext = odml_torch.lowerings.context.LoweringContext
+LoweringContext = context.LoweringContext
 
 
 def _ir_operation(
@@ -704,3 +704,22 @@ def _tfl_topk_v2_lowering(
       ],
       attributes={},
   )
+
+
+@lower(torch.ops.tfl.scatter_nd.default)
+def _tfl_scatter_nd_lowering(
+    lctx: LoweringContext,
+    indices: ir.Value,
+    updates: ir.Value,
+    shape: Sequence[int],
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.scatter_nd",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[
+          indices,
+          updates,
+          lowering_utils.numpy_array_constant(np.array(shape, dtype=np.int32)),
+      ],
+      attributes={},
+  )
diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
index 95e1e534..6e38af4a 100644
--- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
+++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
@@ -19,6 +19,7 @@
 
 from ai_edge_torch.odml_torch.experimental.torch_tfl import torch_library_utils
 import numpy as np
+import tensorflow as tf
 import torch
 
 
@@ -318,6 +319,28 @@ def tfl_slice(
   return input[tuple(slices)].clone()
 
 
+@torch.library.custom_op("tfl::scatter_nd", mutates_args=())
+def tfl_scatter_nd(
+    indices: torch.Tensor,
+    updates: torch.Tensor,
+    shape: Sequence[int],
+) -> torch.Tensor:
+  # TODO: Implement this in native torch.
+  indices = indices.detach().numpy()
+  updates = updates.detach().numpy()
+  out = tf.scatter_nd(indices, updates, shape)
+  return torch.tensor(out)
+
+
+@torch.library.register_fake("tfl::scatter_nd")
+def tfl_scatter_nd_fake(
+    indices: torch.Tensor,
+    updates: torch.Tensor,
+    shape: Sequence[int],
+) -> torch.Tensor:
+  return torch.empty(shape, dtype=updates.dtype)
+
+
 @torch.library.custom_op("tfl::slice.tensor", mutates_args=())
 def tfl_slice_tensor(
     input: torch.Tensor,
diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py
index 5ce72a21..88fcb83f 100644
--- a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py
+++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py
@@ -218,6 +218,7 @@ def _assert_export_and_close(
       ("aten__softmax_5", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), 1, False), dict()),
      ("aten_topk_0", torch.ops.aten.topk.default, (rnd(torch.float32, (4, 10)), 3), dict()),
      ("aten_topk_1", torch.ops.aten.topk.default, (rnd(torch.float32, (4, 10)), 3), dict(dim=0)),
+      ("aten_scatter_src_0", torch.ops.aten.scatter.src, (rnd(torch.float32, (10, 10)), 1, rnd(torch.int64, (10, 10)), rnd(torch.float32, (10, 10))), dict()),
       # fmt: on
       # pyformat: enable
   )
diff --git a/ai_edge_torch/odml_torch/export.py b/ai_edge_torch/odml_torch/export.py
index 0a883721..ab384dbe 100644
--- a/ai_edge_torch/odml_torch/export.py
+++ b/ai_edge_torch/odml_torch/export.py
@@ -21,6 +21,7 @@
 from typing import Any, Callable, Optional
 
 from ai_edge_torch import fx_infra
+from ai_edge_torch.odml_torch.experimental import torch_tfl
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
 from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -370,6 +371,20 @@ def exported_program_to_mlir(
 
   if _pre_lower_pass:
     _pre_lower_pass(exported_program)
 
+  # DO NOT SUBMIT: Lower via torch_tfl for specific aten ops.
+  exported_program = fx_infra.safe_run_decompositions(
+      exported_program,
+      fx_infra.decomp.pre_convert_decomp()
+      | fx_infra.decomp.pre_lower_decomp()
+      | {
+          op: torch_tfl.decomps[op]
+          for op in [
+              torch.ops.aten.topk.default,
+              torch.ops.aten.scatter.src,
+          ]
+      },
+  )
+
   if not ir_context:
     ir_context = export_utils.create_ir_context()
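
The `_aten_scatter_src_decomp` above rebuilds `aten.scatter.src` from full N-d coordinates. Below is a minimal sketch in plain PyTorch (no TFLite ops) of that coordinate-grid construction, checked against `torch.Tensor.scatter`; the helper name `scatter_via_full_indices` is illustrative only and not part of this change.

```python
import torch


def scatter_via_full_indices(inp, dim, index, src):
  dim = dim % inp.dim()
  # Slice `src` so its shape matches `index`, mirroring the decomposition.
  src = src[tuple(slice(s) for s in index.shape)]
  # One coordinate tensor per dimension of `index`.
  grid = list(
      torch.meshgrid(*[torch.arange(s) for s in index.shape], indexing="ij")
  )
  # The scatter dimension takes its coordinates from `index` itself.
  grid[dim] = index
  # Advanced-indexing assignment plays the role of tfl.scatter_nd plus
  # torch.where in the decomposition.
  out = inp.clone()
  out[tuple(grid)] = src
  return out


inp = torch.zeros(4, 5)
src = torch.rand(4, 5)
# Unique indices per row keep the comparison deterministic.
index = torch.stack([torch.randperm(5) for _ in range(4)])
torch.testing.assert_close(
    scatter_via_full_indices(inp, 1, index, src),
    inp.scatter(1, index, src),
)
```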
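For the `TODO: Implement this in native torch.` in `tfl_scatter_nd`, one possible direction is `index_put_` on a zero tensor. This is only a sketch: `scatter_nd_reference` is a hypothetical name, and `accumulate=True` is used because `tf.scatter_nd` sums values at duplicate indices.

```python
import torch


def scatter_nd_reference(
    indices: torch.Tensor, updates: torch.Tensor, shape
) -> torch.Tensor:
  out = torch.zeros(list(shape), dtype=updates.dtype)
  # Split the trailing coordinate axis into one index tensor per dimension.
  idx = tuple(indices[..., d].long() for d in range(indices.shape[-1]))
  # Accumulate duplicates, matching tf.scatter_nd semantics.
  return out.index_put_(idx, updates, accumulate=True)


indices = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)
updates = torch.tensor([10.0, 20.0])
print(scatter_nd_reference(indices, updates, (4, 4)))
```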
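Rough usage sketch for the `export.py` wiring: export a module that hits `aten.scatter.src` and look for `tfl.scatter_nd` in the lowered MLIR. This assumes the `exported_program_to_mlir` entry point touched in this diff and a `get_text()` accessor on its result; treat both as assumptions rather than a verified recipe.

```python
import torch
from ai_edge_torch.odml_torch import export as odml_export


class ScatterModule(torch.nn.Module):

  def forward(self, x, index, src):
    return torch.scatter(x, 1, index, src)


x = torch.zeros(4, 5)
index = torch.stack([torch.randperm(5) for _ in range(4)])
src = torch.rand(4, 5)
ep = torch.export.export(ScatterModule().eval(), (x, index, src))
mlir = odml_export.exported_program_to_mlir(ep)
print("tfl.scatter_nd" in mlir.get_text())
```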