Skip to content

Commit 5f22c45

Browse files
junjiang-lab authored and copybara-github committed
Add impl for aten.mean.dim and lowering.
PiperOrigin-RevId: 768227681
1 parent 0aeea18 commit 5f22c45

File tree

4 files changed

+36
-0
lines changed

4 files changed

+36
-0
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ def _aten_bitwise_and_tensor_decomp(x, y):
110110
return torch.ops.tfl.logical_and(x, y)
111111

112112

113+
@register_decomp(torch.ops.aten.mean.dim)
def _aten_mean_dim_decomp(x, dim, keepdim=False):
  """Decomposes aten.mean.dim into the tfl.mean custom op.

  Args:
    x: Input tensor to reduce.
    dim: Dimension(s) to take the mean over; forwarded unchanged.
    keepdim: Whether reduced dimensions are retained with size 1
      (defaults to False, matching aten.mean.dim).
  """
  return torch.ops.tfl.mean(x, dim, keepdim)
116+
117+
113118
@register_decomp(torch.ops.aten.gt.Tensor)
def _aten_gt_tensor_decomp(x, y):
  """Decomposes aten.gt.Tensor into the tfl.greater custom op (x > y)."""
  return torch.ops.tfl.greater(x, y)

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,27 @@ def _tfl_logical_and_lowering(
177177
)
178178

179179

180+
@lower(torch.ops.tfl.mean.default)
def _tfl_mean_lowering(
    lctx: LoweringContext,
    x: ir.Value,
    dims: int | ir.Value | Sequence[int | ir.Value],
    keepdim: bool = False,
) -> ir.Value:
  """Lowers tfl.mean to a `tfl.mean` IR operation.

  Args:
    lctx: Lowering context; its node metadata supplies the result IR types.
    x: Input tensor value to reduce.
    dims: Reduction dimension(s) — a single int/ir.Value, or a sequence of
      them.
    keepdim: Whether reduced dimensions are retained with size 1; emitted as
      the op's `keep_dims` attribute.

  Returns:
    The ir.Value produced by the emitted tfl.mean operation.
  """
  # A scalar dim (plain int or an already-materialized ir.Value) goes through
  # the single-value converter; a sequence is converted as a shape-like list.
  # Idiom fix: one isinstance() with a tuple instead of two or-ed calls.
  if isinstance(dims, (int, ir.Value)):
    dims_ir_value = lowering_utils.convert_to_ir_value(dims)
  else:
    dims_ir_value = lowering_utils.convert_shape_to_ir_value(dims)
  return _ir_operation(
      "tfl.mean",
      results=lowering_utils.node_meta_to_ir_types(lctx.node),
      operands=[x, dims_ir_value],
      attributes={
          "keep_dims": ir.BoolAttr.get(keepdim),
      },
  )
199+
200+
180201
@lower(torch.ops.tfl.greater.default)
181202
def _tfl_greater_lowering(
182203
lctx: LoweringContext,

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ def tfl_logical_and(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
6868
return torch.logical_and(x, y)
6969

7070

71+
@custom_op_with_fake(
    "tfl::mean", schema="(Tensor x, Any dims, bool keepdim) -> Tensor"
)
def tfl_mean(x: torch.Tensor, dims: Any, keepdim: bool = False) -> torch.Tensor:
  """Reference (eager/fake) implementation of tfl.mean.

  Delegates to torch.mean; `dims` selects the reduction dimension(s) and
  `keepdim` controls whether reduced dimensions are kept with size 1.
  """
  return torch.mean(x, dim=dims, keepdim=keepdim)
76+
77+
7178
@custom_op_with_fake("tfl::greater")
def tfl_greater(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  """Reference (eager/fake) implementation of tfl.greater: elementwise x > y."""
  return torch.gt(x, y)

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ def _assert_export_and_close(
134134
("aten_pow_Tensor_Scalar_0", torch.ops.aten.pow.Tensor_Scalar, (rnd(torch.float32, (10, 10)), np.random.rand(),), dict()),
135135
("aten_pow_Tensor_Tensor_0", torch.ops.aten.pow.Tensor_Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
136136
("aten_bitwise_and_Tensor_0", torch.ops.aten.bitwise_and.Tensor, (rnd(torch.bool, (10, 10)), rnd(torch.bool, (10, 10)),), dict()),
137+
("aten_mean_dim_0", torch.ops.aten.mean.dim, (rnd(torch.float32, (10, 10)), 0), dict()),
138+
("aten_mean_dim_1", torch.ops.aten.mean.dim, (rnd(torch.float32, (10, 10)), 0, True), dict()),
139+
("aten_mean_dim_2", torch.ops.aten.mean.dim, (rnd(torch.float32, (10, 10)), 1), dict()),
137140
("aten_gt_Tensor_0", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
138141
("aten_gt_Tensor_1", torch.ops.aten.gt.Tensor, (rnd(torch.float32, (1, 10)), rnd(torch.float32, (10, 1)),), dict()),
139142
("aten_lt_Tensor_0", torch.ops.aten.lt.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),

0 commit comments

Comments
 (0)