Skip to content

Commit 0aeea18

Browse files
junjiang-lab authored and copybara-github committed
Add impl for aten.full_like and lowering.
PiperOrigin-RevId: 768220783
1 parent 92c6eb6 commit 0aeea18

File tree

5 files changed

+105
-31
lines changed

5 files changed

+105
-31
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ def _aten_cat_decomp(tensors, dim=0):
203203
return torch.ops.tfl.concatenation(processed_tensors, dim)
204204

205205

206+
@register_decomp(torch.ops.aten.full_like.default)
def _aten_full_like_decomp(
    x,
    fill_value,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
  """Decomposes aten.full_like into tfl.fill over the input's shape.

  The dtype/layout/device/pin_memory/memory_format arguments are accepted
  for schema compatibility but are not consumed here; the result element
  type is resolved from the node's output metadata during lowering.
  """
  out_shape = tuple(x.shape)
  return torch.ops.tfl.fill(out_shape, fill_value)
217+
218+
206219
@register_decomp(torch.ops.aten.view.default)
def _aten_view_decomp(x, shape: Sequence[int]):
  """Decomposes aten.view into the equivalent tfl.reshape op."""
  reshaped = torch.ops.tfl.reshape(x, shape)
  return reshaped

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -317,45 +317,69 @@ def _tfl_concatenation_lowering(
317317
)
318318

319319

320+
@lower(torch.ops.tfl.fill.default)
def _tfl_fill_lowering(
    lctx: LoweringContext,
    dims: Sequence[int | ir.Value],
    fill_value: ir.Value,
) -> ir.Value:
  """Lowers tfl.fill to a custom tfl.fill IR operation.

  Args:
    lctx: Lowering context; its node metadata supplies the result type.
    dims: Output shape as a mix of static ints and dynamic ir.Values.
    fill_value: The value to broadcast into the output tensor.

  Returns:
    The ir.Value produced by the emitted tfl.fill operation.

  Raises:
    ValueError: If the node's result type cannot be determined or is not a
      ranked tensor.
    TypeError: If the converted fill value is not a ranked tensor type.
  """
  dims_ir_value = lowering_utils.convert_shape_to_ir_value(dims)
  fill_value_ir_value = lowering_utils.convert_to_ir_value(fill_value)

  # Ensure fill_value_ir_value is a scalar (0-D tensor) for TFLite Fill op.
  # The TFLite Fill kernel expects the value to be a 0-D tensor.
  if isinstance(fill_value_ir_value.type, ir.RankedTensorType):
    tensor_type = fill_value_ir_value.type
    # If it's a 1-D tensor with a single element, reshape to 0-D.
    if list(tensor_type.shape) == [1]:
      scalar_type = ir.RankedTensorType.get([], tensor_type.element_type)
      fill_value_ir_value = stablehlo.reshape(scalar_type, fill_value_ir_value)

  # Determine the target element type from the node's output definition.
  result_types = lowering_utils.node_meta_to_ir_types(lctx.node)
  if not result_types or not isinstance(result_types[0], ir.RankedTensorType):
    raise ValueError(
        "tfl.fill: Unable to determine result tensor type or result is not a"
        " ranked tensor."
    )
  target_element_type = result_types[0].element_type

  # Ensure fill_value_ir_value is a RankedTensorType to access its properties.
  if not isinstance(fill_value_ir_value.type, ir.RankedTensorType):
    raise TypeError(
        "tfl.fill: fill_value_ir_value expected to be RankedTensorType, got"
        f" {fill_value_ir_value.type}"
    )

  current_fill_tensor_type = fill_value_ir_value.type
  current_element_type = current_fill_tensor_type.element_type

  # If the element type of the (scalar) fill_value doesn't match the target
  # output element type, cast fill_value_ir_value to the target_element_type
  # while maintaining its current shape (which should be scalar).
  if current_element_type != target_element_type:
    cast_to_type = ir.RankedTensorType.get(
        current_fill_tensor_type.shape, target_element_type
    )
    fill_value_ir_value = stablehlo.convert(cast_to_type, fill_value_ir_value)

  return _ir_operation(
      "tfl.fill",
      results=result_types,
      operands=[dims_ir_value, fill_value_ir_value],
  )
371+
372+
320373
@lower(torch.ops.tfl.reshape.default)
def _tfl_reshape_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    shape: Sequence[int | ir.Value],
) -> ir.Value:
  """Lowers tfl.reshape to a custom tfl.reshape IR operation.

  The shape sequence (static ints and/or dynamic ir.Values) is converted
  into a single 1-D shape tensor by the shared helper before being passed
  as the second operand.
  """
  return _ir_operation(
      "tfl.reshape",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[x, lowering_utils.convert_shape_to_ir_value(shape)],
  )
360384

361385

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ def tfl_concatenation(
123123
return torch.cat(tensors, dim=dim)
124124

125125

126+
@custom_op_with_fake("tfl::fill", schema="(int[] x, Any y) -> Tensor")
def tfl_fill(dims: Sequence[int], fill_value: Any) -> torch.Tensor:
  """Reference implementation of tfl.fill: a `dims`-shaped tensor filled with `fill_value`."""
  return torch.full(dims, fill_value)
129+
130+
126131
def _normalize_shape(
127132
tensor_input: torch.Tensor, shape: Sequence[int]
128133
) -> Sequence[int]:

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def _assert_export_and_close(
163163
("aten_cat_2", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()),
164164
("aten_cat_3", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()),
165165
("aten_cat_4", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))],), dict()),
166+
("aten_full_like_0", torch.ops.aten.full_like.default, (rnd(torch.float32, (10, 10)), 0.123,), dict()),
167+
("aten_full_like_1", torch.ops.aten.full_like.default, (rnd(torch.int64, (10, 10)), 123,), dict()),
166168
("aten_view_0", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [1, 100],), dict()),
167169
("aten_view_1", torch.ops.aten.view.default, (rnd(torch.float32, (1, 10)), [10, 1],), dict()),
168170
("aten_view_2", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [2, 5, 10],), dict()),

ai_edge_torch/odml_torch/lowerings/utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections.abc import Callable
1818
import functools
1919
import numbers
20-
from typing import Any, Optional, Union
20+
from typing import Any, Optional, Sequence, Union
2121
from ai_edge_torch.odml_torch import export_utils
2222
from jax._src.lib.mlir import ir
2323
from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -281,3 +281,33 @@ def convert_to_ir_value(
281281
if isinstance(value, ir.Value):
282282
return value
283283
raise TypeError(f"Unsupported type for conversion to ir.Value: {type(value)}")
284+
285+
286+
def convert_shape_to_ir_value(
    shape: Sequence[int],
) -> ir.Value:
  """Converts a shape sequence into a single 1-D shape ir.Value.

  A fully static shape (all ints, or empty) becomes one int32 constant
  tensor. A shape mixing ints with dynamic ir.Value dims is built one
  dimension at a time and stitched together with stablehlo.concatenate.
  """
  if not shape or all(isinstance(d, int) for d in shape):
    # Static shape: emit a single constant. int32 is assumed to be the
    # required element type for TFLite shape tensors.
    return numpy_array_constant(np.array(shape, dtype=np.int32))

  # Dynamic/mixed shape: make every dimension a rank-1 tensor first.
  dim_values = []
  for d in shape:
    if isinstance(d, int):
      # Static dim -> constant 1-D int32 tensor.
      dim_values.append(numpy_array_constant(np.array([d], dtype=np.int32)))
    else:
      assert isinstance(d, ir.Value)
      # Dynamic dim -> reshape to rank 1 so it can be concatenated.
      # NOTE(review): static dims are always int32 here while a dynamic
      # dim keeps its own element type — presumably callers guarantee the
      # types agree, since concatenate requires matching element types.
      rank1_type = ir.RankedTensorType.get([1], d.type.element_type)
      dim_values.append(stablehlo.reshape(rank1_type, d))

  return stablehlo.concatenate(dim_values, dimension=0)

0 commit comments

Comments
 (0)