Skip to content

Commit 9d83b50

Browse files
junjiang-labcopybara-github
authored andcommitted
Support torch.ops.aten.sym_size.int and torch Tensor.reshape.
PiperOrigin-RevId: 750232210
1 parent ee4a41e commit 9d83b50

File tree

5 files changed

+107
-11
lines changed

5 files changed

+107
-11
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,41 @@ def _tfl_transpose_lowering(
269269
def _tfl_reshape_lowering(
270270
lctx: LoweringContext,
271271
x: ir.Value,
272-
shape: Sequence[int],
272+
shape: Sequence[int | ir.Value],
273273
) -> ir.Value:
274-
constant_shape = lowering_utils.numpy_array_constant(
275-
np.array(shape, dtype=np.int32)
276-
)
274+
# Check if all elements in the shape sequence are integers.
275+
if not shape or all(isinstance(dim, int) for dim in shape):
276+
# If all are integers, create a constant numpy array.
277+
# Assuming int32 is the required type for TFLite shape tensors.
278+
shape_ir_value = lowering_utils.numpy_array_constant(
279+
np.array(shape, dtype=np.int32)
280+
)
281+
else:
282+
# Handle mixed int and ir.Value shape sequence
283+
processed_dims = []
284+
for dim in shape:
285+
if isinstance(dim, int):
286+
# Convert int to a constant 1D tensor
287+
shape_ir_value = lowering_utils.numpy_array_constant(
288+
np.array([dim], dtype=np.int32)
289+
)
290+
processed_dims.append(shape_ir_value)
291+
else:
292+
assert isinstance(dim, ir.Value)
293+
# Convert ir.Value to a constant 1D tensor
294+
new_type = ir.RankedTensorType.get([1], dim.type.element_type)
295+
reshape_dim = stablehlo.reshape(new_type, dim)
296+
processed_dims.append(reshape_dim)
297+
298+
shape_ir_value = stablehlo.concatenate(
299+
processed_dims,
300+
dimension=0,
301+
)
302+
277303
return _ir_operation(
278304
"tfl.reshape",
279305
results=lowering_utils.node_meta_to_ir_types(lctx.node),
280-
operands=[x, constant_shape],
306+
operands=[x, shape_ir_value],
281307
)
282308

283309

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Numerical validation tests for torch ops and Torch-TFL ops."""
1616

17+
from typing import Any, Dict, Sequence
18+
1719
import ai_edge_torch
1820
from ai_edge_torch import testing
1921
from ai_edge_torch.odml_torch.experimental import torch_tfl
@@ -53,14 +55,23 @@ def setUp(self):
5355
torch.manual_seed(0)
5456

5557
def _assert_export_and_close(
56-
self, func, args, kwargs, atol=1e-3, rtol=1e-5, equal_nan=True
58+
self,
59+
func,
60+
args,
61+
kwargs,
62+
dynamic_shapes: Dict[str, Any] | Sequence[Any] | None = None,
63+
atol=1e-3,
64+
rtol=1e-5,
65+
equal_nan=True,
5766
):
5867
"""Assert func, args, and kwargs can be lowered and pass numerical validation."""
5968
with self.subTest("torch_eval"):
6069
expected = func(*args, **kwargs)
6170

6271
with self.subTest("export_and_decompse"):
63-
exported_program = export_with_tensor_inputs_only(func, args, kwargs)
72+
exported_program = export_with_tensor_inputs_only(
73+
func, args, kwargs, dynamic_shapes
74+
)
6475
exported_program = exported_program.run_decompositions(
6576
torch_tfl.decomps
6677
)
@@ -81,7 +92,9 @@ def _assert_export_and_close(
8192

8293
with self.subTest("convert_eval"):
8394
args, kwargs = exported_program.example_inputs
84-
edge_model = ai_edge_torch.convert(exported_program.module(), args)
95+
edge_model = ai_edge_torch.convert(
96+
exported_program.module(), args, dynamic_shapes=dynamic_shapes
97+
)
8598
actual = edge_model(*args, **kwargs)
8699

87100
with self.subTest("torch_convert_eval_diff:" + str(atol)):
@@ -142,9 +155,41 @@ def _assert_export_and_close(
142155
# fmt: on
143156
# pyformat: enable
144157
)
145-
def test_op(self, op, args, kwargs):
158+
def test_op(
159+
self,
160+
op,
161+
args,
162+
kwargs,
163+
):
146164
self._assert_export_and_close(op, args, kwargs)
147165

166+
@parameterized.named_parameters(
167+
# fmt: off
168+
# pyformat: disabledef
169+
("reshape_without_dynamic_shape_0", (rnd(torch.float32, (10, 2, 3)),), dict(), None),
170+
("reshape_with_dynamic_shape_1", (rnd(torch.float32, (10, 2, 3)),), dict(), ((torch.export.Dim("batch"), None, None),)),
171+
("reshape_with_dynamic_shape_2", (rnd(torch.float32, (10, 2, 3)),), dict(), ({0: torch.export.Dim("batch")},)),
172+
("reshape_with_dynamic_shape_3", (rnd(torch.float32, (10, 2, 3)),), dict(), ((torch.export.Dim("batch"), torch.export.Dim("height"), torch.export.Dim("width")),)),
173+
("reshape_with_dynamic_shape_4", (rnd(torch.float32, (10, 2, 3)),), dict(), ({0: torch.export.Dim("batch"), 1: torch.export.Dim("height"), 2: torch.export.Dim("width")},)),
174+
# fmt: on
175+
# pyformat: enable
176+
)
177+
def test_reshape_op(
178+
self,
179+
args,
180+
kwargs,
181+
dynamic_shapes: Dict[str, Any] | Sequence[Any] | None = None,
182+
):
183+
184+
class ReshapeModel(torch.nn.Module):
185+
186+
def forward(self, x):
187+
x = x + x
188+
x = x.reshape([x.shape[0], x.shape[1] * x.shape[2]])
189+
return x
190+
191+
self._assert_export_and_close(ReshapeModel(), args, kwargs, dynamic_shapes)
192+
148193

149194
if __name__ == "__main__":
150195
googletest.main()

ai_edge_torch/odml_torch/lowerings/_basic.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import math
16+
import operator
1617
from typing import Optional, Union
1718

1819
from ai_edge_torch.odml_torch import export_utils
@@ -24,6 +25,7 @@
2425
import numpy as np
2526
import torch
2627

28+
2729
LoweringContext = context.LoweringContext
2830
lower = registry.lower
2931

@@ -320,3 +322,22 @@ def _aten_to_copy(
320322
),
321323
x,
322324
)
325+
326+
327+
# Schema:
328+
# - aten::sym_size.int(Tensor self, int dim) -> SymInt
329+
@lower(torch.ops.aten.sym_size.int)
330+
def _aten_sym_size_int(lctx, x: ir.Value, dim: int):
331+
return stablehlo.get_dimension_size(x, dim)
332+
333+
334+
# Lowering for the multiplication operator (`*`).
335+
# Handles cases where one operand is an integer (scalar) and the other is a
336+
# tensor, broadcasting the scalar to the tensor's shape before multiplication.
337+
@lower(operator.mul)
338+
def _operator_mul(lctx, self: int | ir.Value, other: int | ir.Value):
339+
if isinstance(self, int) and isinstance(other, ir.Value):
340+
self = utils.splat(self, other.type.element_type, other.type.shape)
341+
if isinstance(other, int) and isinstance(self, ir.Value):
342+
other = utils.splat(other, self.type.element_type, self.type.shape)
343+
return stablehlo.multiply(self, other)

ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ def lower_by_torch_xla2(op):
218218
lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
219219
lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
220220
lower_by_torch_xla2(torch.ops.aten.sum)
221-
lower_by_torch_xla2(torch.ops.aten.sym_size)
222221
lower_by_torch_xla2(torch.ops.aten.t)
223222
lower_by_torch_xla2(torch.ops.aten.tan)
224223
lower_by_torch_xla2(torch.ops.aten.tanh)

ai_edge_torch/testing/export.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Torch export utilities for testing."""
1616

1717
from collections.abc import Callable
18-
from typing import Any
18+
from typing import Any, Dict, Sequence
1919

2020
import torch
2121
from torch.utils import _pytree as pytree
@@ -25,6 +25,7 @@ def export_with_tensor_inputs_only(
2525
model: Callable[..., Any],
2626
args: tuple[Any, ...],
2727
kwargs: dict[str, Any],
28+
dynamic_shapes: Dict[str, Any] | Sequence[Any] | None = None,
2829
) -> torch.export.ExportedProgram:
2930
"""Exports a PyTorch model, treating only tensor inputs as export inputs.
3031
@@ -76,8 +77,12 @@ def forward(self, *export_args):
7677

7778
export_args = tuple(export_args)
7879
export_kwargs = {}
80+
# Need to wrap dynamic_shapes in a tuple to match the inputs structure of
81+
# ModuleWrapper.
82+
dynamic_shapes = (dynamic_shapes,) if dynamic_shapes else None
7983
return torch.export.export(
8084
ModuleWrapper(model, args, kwargs).eval(),
8185
export_args,
8286
export_kwargs,
87+
dynamic_shapes=dynamic_shapes,
8388
)

0 commit comments

Comments
 (0)