1 change: 1 addition & 0 deletions ai_edge_torch/odml_torch/__init__.py
@@ -18,3 +18,4 @@
from . import export_utils
from . import lowerings
from . import passes
from . import experimental
16 changes: 16 additions & 0 deletions ai_edge_torch/odml_torch/experimental/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Torch-TFL ops definitions, decompositions, and lowerings."""
from . import torch_tfl
52 changes: 52 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py
@@ -15,6 +15,7 @@
"""Torch ops to Torch-TFL decompositions."""
from typing import Sequence
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
import numpy as np
import torch

decomps = {}
@@ -424,3 +425,54 @@ def _aten_topk_decomp(self, k, dim=-1, largest=True, sorted=True):
  # torch.topk returns int64 indices, but tfl.topk_v2 returns indices in int32.
  indices = indices.to(torch.int64)
  return out, indices


@register_decomp(torch.ops.aten.scatter.src)
def _aten_scatter_src_decomp(self, dim, index, src):
  index = index.to(torch.int32)

  # --- 1. PREPARE THE `UPDATES` TENSOR ---
  # The number of updates is determined by the shape of the `index` tensor,
  # so `src` must be trimmed to match `index`.
  if src.dim() == 0:
    # If `src` is a scalar, expand it to the shape of `index`.
    updates = src.expand(index.shape)
  else:
    # If `src` is a tensor, slice its top-left corner to match `index`.
    slicing_tuple = tuple(slice(s) for s in index.shape)
    updates = src[slicing_tuple]

  # --- 2. CREATE FULL COORDINATE INDICES FOR TENSORFLOW-STYLE SCATTER ---
  # The coordinate grid must match the shape of the `index` tensor, as this
  # defines the number of updates to perform.
  grid_ranges = [torch.arange(s, dtype=torch.int32) for s in index.shape]
  grid = list(torch.meshgrid(*grid_ranges, indexing="ij"))

  # Handle negative dimension indexing.
  if dim < 0:
    dim = self.dim() + dim

  # Replace the coordinates for the scatter dimension with the `index` values.
  grid[dim] = index

  # Stack the coordinates along the *last* dimension to create the final
  # indices tensor, the format required by TensorFlow-style scatter ops.
  # Final shape: (*index.shape, rank_of_input).
  indices = torch.stack(grid, dim=-1)
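  # Illustrative example (hypothetical shapes): with a rank-2 `self`, dim=1,
  # and `index` of shape (3, 2), grid[0][i, j] == i and grid[1] == index, so
  # indices[i, j] == [i, index[i, j]] and `indices` has shape (3, 2, 2).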

  # --- 3. PERFORM THE SCATTER OPERATION ---
  # `tfl.scatter_nd` writes `updates` at the given indices into a zero tensor
  # of shape `self.shape`. Combined with the mask below, this reproduces the
  # replacement semantics of `torch.scatter` (the accumulating semantics of
  # `torch.scatter_add` would require an add-based scatter instead).
  update_tensor = torch.ops.tfl.scatter_nd(indices, updates, self.shape)

  # --- 4. BUILD A BOOLEAN MASK OF THE UPDATED LOCATIONS ---
  # Scatter `True` values to the same indices to record which output elements
  # received an update.
  mask_updates = torch.full_like(updates, True, dtype=torch.bool)
  scatter_mask = torch.ops.tfl.scatter_nd(indices, mask_updates, self.shape)

  # --- 5. COMBINE ORIGINAL TENSOR WITH SCATTERED UPDATES ---
  # Use the mask to take values from `update_tensor` where updates occurred,
  # and from the original `self` everywhere else.
  result = torch.where(scatter_mask, update_tensor, self)
  return result
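
For reference, a minimal sketch (plain torch, hypothetical values) of the out-of-place aten.scatter.src semantics this decomposition reproduces:

x = torch.zeros(2, 4)
index = torch.tensor([[0, 2], [1, 3]])
src = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
out = torch.scatter(x, 1, index, src)  # out[i, index[i, j]] = src[i, j] for dim=1
# out == [[1., 0., 2., 0.],
#         [0., 3., 0., 4.]]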
23 changes: 21 additions & 2 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py
@@ -16,8 +16,8 @@

from typing import Sequence

from ai_edge_torch import odml_torch
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
from ai_edge_torch.odml_torch.lowerings import context
from ai_edge_torch.odml_torch.lowerings import registry
from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
from jax._src.lib.mlir import ir
@@ -27,7 +27,7 @@


lower = registry.lower
LoweringContext = odml_torch.lowerings.context.LoweringContext
LoweringContext = context.LoweringContext


def _ir_operation(
@@ -704,3 +704,22 @@ def _tfl_topk_v2_lowering(
      ],
      attributes={},
  )


@lower(torch.ops.tfl.scatter_nd.default)
def _tfl_scatter_nd_lowering(
    lctx: LoweringContext,
    indices: ir.Value,
    updates: ir.Value,
    shape: Sequence[int],
) -> ir.Value:
  return _ir_operation(
      "tfl.scatter_nd",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[
          indices,
          updates,
          lowering_utils.numpy_array_constant(np.array(shape, dtype=np.int32)),
      ],
      attributes={},
  )
23 changes: 23 additions & 0 deletions ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py
@@ -19,6 +19,7 @@

from ai_edge_torch.odml_torch.experimental.torch_tfl import torch_library_utils
import numpy as np
import tensorflow as tf
import torch


@@ -318,6 +319,28 @@ def tfl_slice(
  return input[tuple(slices)].clone()


@torch.library.custom_op("tfl::scatter_nd", mutates_args=())
def tfl_scatter_nd(
    indices: torch.Tensor,
    updates: torch.Tensor,
    shape: Sequence[int],
) -> torch.Tensor:
  # TODO: Implement this in native torch.
  indices = indices.detach().numpy()
  updates = updates.detach().numpy()
  out = tf.scatter_nd(indices, updates, shape)
  return torch.tensor(out)


@torch.library.register_fake("tfl::scatter_nd")
def tfl_scatter_nd_fake(
    indices: torch.Tensor,
    updates: torch.Tensor,
    shape: Sequence[int],
) -> torch.Tensor:
  return torch.empty(shape, dtype=updates.dtype)
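
A minimal usage sketch (hypothetical values): the eager call dispatches to the TensorFlow-backed kernel above, while the fake registration supplies the output shape and dtype during export tracing.

indices = torch.tensor([[0], [2]], dtype=torch.int32)  # write rows 0 and 2
updates = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
out = torch.ops.tfl.scatter_nd(indices, updates, [4, 2])
# out == [[1., 1.], [0., 0.], [2., 2.], [0., 0.]]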


@torch.library.custom_op("tfl::slice.tensor", mutates_args=())
def tfl_slice_tensor(
    input: torch.Tensor,
@@ -218,6 +218,7 @@ def _assert_export_and_close(
("aten__softmax_5", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), 1, False), dict()),
("aten_topk_0", torch.ops.aten.topk.default, (rnd(torch.float32, (4, 10)), 3), dict()),
("aten_topk_1", torch.ops.aten.topk.default, (rnd(torch.float32, (4, 10)), 3), dict(dim=0)),
("aten_scatter_src_0", torch.ops.aten.scatter.src, (rnd(torch.float32, (10, 10)), 1, rnd(torch.int64, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
# fmt: on
# pyformat: enable
)
15 changes: 15 additions & 0 deletions ai_edge_torch/odml_torch/export.py
@@ -21,6 +21,7 @@
from typing import Any, Callable, Optional

from ai_edge_torch import fx_infra
from ai_edge_torch.odml_torch.experimental import torch_tfl
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -370,6 +371,20 @@ def exported_program_to_mlir(
  if _pre_lower_pass:
    _pre_lower_pass(exported_program)

  # DO NOT SUBMIT: Lower via torch_tfl for specific aten ops.
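  # The merged table maps each aten OpOverload to its decomposition; running
  # safe_run_decompositions with it re-traces aten.topk.default and
  # aten.scatter.src into the Torch-TFL decompositions so they can be lowered
  # to tfl.* custom ops.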
  exported_program = fx_infra.safe_run_decompositions(
      exported_program,
      fx_infra.decomp.pre_convert_decomp()
      | fx_infra.decomp.pre_lower_decomp()
      | {
          op: torch_tfl.decomps[op]
          for op in [
              torch.ops.aten.topk.default,
              torch.ops.aten.scatter.src,
          ]
      },
  )

  if not ir_context:
    ir_context = export_utils.create_ir_context()
