Skip to content

Commit 337bdee

Browse files
junjiang-labcopybara-github
authored andcommitted
Add direct lowering for aten.abs
PiperOrigin-RevId: 729271527
1 parent fc6c888 commit 337bdee

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

ai_edge_torch/odml_torch/lowerings/_basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,15 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
215215
return stablehlo.floor(x)
216216

217217

218+
# Schema:
219+
# - aten::abs(Tensor input) -> Tensor
220+
# Torch Reference:
221+
# - https://pytorch.org/docs/main/generated/torch.abs.html
222+
@lower(torch.ops.aten.abs.default)
223+
def _aten_abs(lctx, input: ir.Value, *, out=None) -> ir.Value:
224+
return stablehlo.abs(input)
225+
226+
218227
# Schema:
219228
# - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
220229
# Torch Reference:

ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def lower_by_torch_xla2(op):
7777
lower_by_torch_xla2(torch.ops.aten._to_copy)
7878
lower_by_torch_xla2(torch.ops.aten._unsafe_index)
7979
lower_by_torch_xla2(torch.ops.aten._unsafe_view)
80-
lower_by_torch_xla2(torch.ops.aten.abs)
8180
lower_by_torch_xla2(torch.ops.aten.acos)
8281
lower_by_torch_xla2(torch.ops.aten.acosh)
8382
lower_by_torch_xla2(torch.ops.aten.add.Scalar)

ai_edge_torch/odml_torch/test/test_core_aten_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,11 @@ def _run_export_and_compare(
136136

137137
@parameterized.named_parameters(
138138
# fmt: off
139-
# pyformat: disabledef
139+
# pyformat: disabledef
140140
("aten_abs_0", torch.ops.aten.abs, (rnd(torch.float32, (10, 10)),), dict()),
141+
("aten_abs_1", torch.ops.aten.abs, (rnd(torch.float32, (10, 10), -10, 0),), dict()),
142+
("aten_abs_2", torch.ops.aten.abs, (rnd(torch.float32, (10, 10), 0, 10),), dict()),
143+
("aten_abs_3", torch.ops.aten.abs, (rnd(torch.int64, (10, 10), -100, 100),), dict()),
141144
("aten_acos_0", torch.ops.aten.acos, (rnd(torch.float32, (10, 10)),), dict()),
142145
("aten_acosh_0", torch.ops.aten.acosh, (rnd(torch.float32, (10, 10)),), dict()),
143146
("aten_unsqueeze_0", torch.ops.aten.unsqueeze, (rnd(torch.float32, (1, 3, 10)), -2,), dict()),

0 commit comments

Comments
 (0)