
Commit 9898978

junjiang-lab authored and copybara-github committed
Add impl for aten.view.default, tfl.reshape and lowering.
PiperOrigin-RevId: 743746772
1 parent 12aca92 commit 9898978

File tree: 4 files changed, +31 −0 lines changed


ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 5 additions & 0 deletions
@@ -116,6 +116,11 @@ def _aten_permute_decomp(x, dims: Sequence[int]):
   return torch.ops.tfl.transpose(x, dims)


+@register_decomp(torch.ops.aten.view.default)
+def _aten_view_decomp(x, shape: Sequence[int]):
+  return torch.ops.tfl.reshape(x, shape)
+
+
 @register_decomp(torch.ops.aten._softmax.default)
 def _aten__softmax_decomp(
     x, dim: int, half_to_float: bool  # pylint: disable=unused-argument
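The hunk above registers a decomposition that rewrites aten.view.default into the tfl.reshape custom op. A minimal stand-alone sketch of the same mapping, with torch.reshape standing in for torch.ops.tfl.reshape (the custom op registered in _ops.py below); the helper name here is illustrative only:

import torch
from typing import Sequence

def aten_view_as_reshape(x: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
  # aten.view.default rearranges the same elements into a new shape, so the
  # decomposition can forward it to a reshape-style op unchanged.
  return torch.reshape(x, tuple(shape))

x = torch.arange(12.0).reshape(3, 4)
assert torch.equal(aten_view_as_reshape(x, [2, 6]),
                   torch.ops.aten.view.default(x, [2, 6]))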

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 16 additions & 0 deletions
@@ -265,6 +265,22 @@ def _tfl_transpose_lowering(
   )


+@lower(torch.ops.tfl.reshape.default)
+def _tfl_reshape_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    shape: Sequence[int],
+) -> ir.Value:
+  constant_shape = lowering_utils.numpy_array_constant(
+      np.array(shape, dtype=np.int32)
+  )
+  return _ir_operation(
+      "tfl.reshape",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x, constant_shape],
+  )
+
+
 @lower(torch.ops.tfl.softmax.default)
 def _tfl_softmax_lowering(
     lctx: LoweringContext,
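The lowering above emits a "tfl.reshape" IR operation whose second operand is the target shape materialized as a rank-1 int32 constant. A small hedged sketch of just that shape-constant step, using plain NumPy in place of lowering_utils.numpy_array_constant:

import numpy as np

shape = [2, 5, 10]
# The requested shape becomes a rank-1 int32 array; in the lowering this array
# is wrapped into an IR constant and passed as the second operand of tfl.reshape.
constant_shape = np.array(shape, dtype=np.int32)
assert constant_shape.dtype == np.int32
assert constant_shape.shape == (len(shape),)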

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 7 additions & 0 deletions
@@ -102,6 +102,13 @@ def tfl_transpose(input: torch.Tensor, perm: Sequence[int]) -> torch.Tensor:
   return torch.permute(input, perm).clone()


+@custom_op_with_fake("tfl::reshape")
+def tfl_reshape(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+  assert torch.Size(shape).numel() == input.numel()
+
+  return input.view(shape).clone()
+
+
 @custom_op_with_fake("tfl::softmax")
 def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
   return torch.nn.functional.softmax(x)
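The reference implementation above first checks that the requested shape implies exactly the same number of elements as the input, then returns a cloned view so the result does not alias the input's storage. A hedged stand-alone illustration of those semantics in plain PyTorch:

import torch

x = torch.randn(10, 10)
shape = [2, 5, 10]
# torch.Size(shape).numel() is the element count implied by the target shape;
# view is only legal when it matches the input's element count.
assert torch.Size(shape).numel() == x.numel()
y = x.view(shape).clone()  # clone() copies the data, so y does not share x's storage
assert y.shape == torch.Size(shape)
assert y.data_ptr() != x.data_ptr()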

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 3 additions & 0 deletions
@@ -130,6 +130,9 @@ def _assert_export_and_close(
       ("aten_gelu_3", torch.ops.aten.gelu.default, (rnd(torch.float32, (1, 10)),), dict(approximate="tanh")),
       ("aten_permute_0", torch.ops.aten.permute.default, (rnd(torch.float32, (10, 10)), [0, 1],), dict()),
       ("aten_permute_1", torch.ops.aten.permute.default, (rnd(torch.float32, (1, 10)), [0, 1],), dict()),
+      ("aten_view_0", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [1, 100],), dict()),
+      ("aten_view_1", torch.ops.aten.view.default, (rnd(torch.float32, (1, 10)), [10, 1],), dict()),
+      ("aten_view_2", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [2, 5, 10],), dict()),
       ("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
       ("aten__softmax_1", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), -1, False), dict()),
       ("aten__softmax_2", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 0, False), dict()),
