From 0a17475c1448c9b01962b3bee27972618ace6895 Mon Sep 17 00:00:00 2001 From: Jun Jiang Date: Wed, 3 Sep 2025 11:34:18 -0700 Subject: [PATCH] Add litert dialects for tfl ops. PiperOrigin-RevId: 802639397 --- .../odml_torch/experimental/torch_tfl/_decomps.py | 12 ++++++++++++ .../odml_torch/experimental/torch_tfl/_lowerings.py | 3 ++- .../torch_tfl/test/test_torch_tfl_impls.py | 2 ++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py index 3e249c77..582f2c16 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py @@ -243,6 +243,18 @@ def _aten_cat_decomp(tensors, dim=0): return torch.ops.tfl.concatenation(processed_tensors, dim) +@register_decomp(torch.ops.aten.full.default) +def _aten_full_decomp( + size, + fill_value, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + return torch.ops.tfl.fill(tuple(size), fill_value) + + @register_decomp(torch.ops.aten.full_like.default) def _aten_full_like_decomp( x, diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py index a2861b8b..1b6a6471 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py @@ -72,10 +72,11 @@ def _tfl_batch_matmul_lowering( @lower(torch.ops.tfl.add.default) def _tfl_add_lowering( lctx: LoweringContext, - lhs: ir.Value, + lhs: ir.Value | int | float, rhs: ir.Value | int | float, fused_activation_function: str = "NONE", ) -> ir.Value: + lhs = lowering_utils.convert_to_ir_value(lhs) rhs = lowering_utils.convert_to_ir_value(rhs) return _ir_operation( "tfl.add", diff --git a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py index 28103bb7..5ce72a21 100644 --- a/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py +++ b/ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py @@ -179,6 +179,8 @@ def _assert_export_and_close( ("aten_cat_2", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()), ("aten_cat_3", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()), ("aten_cat_4", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))],), dict()), + ("aten_full_0", torch.ops.aten.full.default, ([10, 10], 0.123,), dict()), + ("aten_full_1", torch.ops.aten.full.default, ([10, 10], 123,), dict()), ("aten_full_like_0", torch.ops.aten.full_like.default, (rnd(torch.float32, (10, 10)), 0.123,), dict()), ("aten_full_like_1", torch.ops.aten.full_like.default, (rnd(torch.int64, (10, 10)), 123,), dict()), ("aten_view_0", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [1, 100],), dict()),