
Commit 12aca92

junjiang-lab authored and copybara-github committed
Add impl for aten._softmax.default, tfl.softmax and lowering.
PiperOrigin-RevId: 743639346
1 parent 78bf19b commit 12aca92

File tree

4 files changed: +44 −0 lines changed


ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 17 additions & 0 deletions
@@ -114,3 +114,20 @@ def _aten_gelu_decomp(x, approximate="none"):
 @register_decomp(torch.ops.aten.permute.default)
 def _aten_permute_decomp(x, dims: Sequence[int]):
   return torch.ops.tfl.transpose(x, dims)
+
+
+@register_decomp(torch.ops.aten._softmax.default)
+def _aten__softmax_decomp(
+    x, dim: int, half_to_float: bool  # pylint: disable=unused-argument
+):
+  if dim == -1 or dim == x.dim() - 1:
+    return torch.ops.tfl.softmax(x)
+  else:
+    dims = list(range(x.dim()))
+    # Transpose the input by swapping the dim with the last dimension.
+    dims[dim], dims[-1] = dims[-1], dims[dim]
+    x_permuted = torch.ops.tfl.transpose(x, dims)
+    # Compute the softmax on the last dimension.
+    softmax_result = torch.ops.tfl.softmax(x_permuted)
+    # Transpose the result back to the original dimensions.
+    return torch.ops.tfl.transpose(softmax_result, dims)
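Note that tfl.softmax only normalizes over the last axis, so the decomposition moves the requested dim into the last position, applies softmax, and swaps back; since the permutation is a single swap, it is its own inverse. A minimal sketch of the same trick in plain PyTorch (not part of this commit; softmax_via_last_dim is a hypothetical helper for illustration only):

import torch

def softmax_via_last_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
  # Mirror of the decomposition above, using plain torch ops.
  if dim == -1 or dim == x.dim() - 1:
    return torch.softmax(x, dim=-1)
  dims = list(range(x.dim()))
  # Swap `dim` with the last axis; applying the same swap again undoes it.
  dims[dim], dims[-1] = dims[-1], dims[dim]
  y = torch.softmax(x.permute(dims), dim=-1)
  return y.permute(dims)

x = torch.randn(4, 5, 6)
assert torch.allclose(softmax_via_last_dim(x, 0), torch.softmax(x, dim=0))
assert torch.allclose(softmax_via_last_dim(x, 1), torch.softmax(x, dim=1))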

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 16 additions & 0 deletions
@@ -263,3 +263,19 @@ def _tfl_transpose_lowering(
       results=lowering_utils.node_meta_to_ir_types(lctx.node),
       operands=[x, constant_perm],
   )
+
+
+@lower(torch.ops.tfl.softmax.default)
+def _tfl_softmax_lowering(
+    lctx: LoweringContext,
+    x: ir.Value,
+    beta: float = 1.0,
+) -> ir.Value:
+  return _ir_operation(
+      "tfl.softmax",
+      results=lowering_utils.node_meta_to_ir_types(lctx.node),
+      operands=[x],
+      attributes={
+          "beta": ir.FloatAttr.get(ir.F32Type.get(), beta),
+      },
+  )
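The lowering emits a tfl.softmax op with a float beta attribute that defaults to 1.0. In TFLite, beta scales the logits before normalization, so the default reduces to the ordinary softmax. A small reference-semantics sketch (an assumption about the attribute's meaning, not code from this commit):

import torch

def softmax_with_beta(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
  # Assumed TFLite semantics: scale logits by beta, then normalize.
  return torch.softmax(beta * x, dim=-1)

x = torch.randn(2, 8)
# With beta == 1.0 this matches the plain softmax.
assert torch.allclose(softmax_with_beta(x, 1.0), torch.softmax(x, dim=-1))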

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 5 additions & 0 deletions
@@ -102,6 +102,11 @@ def tfl_transpose(input: torch.Tensor, perm: Sequence[int]) -> torch.Tensor:
   return torch.permute(input, perm).clone()
 
 
+@custom_op_with_fake("tfl::softmax")
+def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
+  return torch.nn.functional.softmax(x)
+
+
 @custom_op_with_fake("tfl::slice")
 def tfl_slice(
     input: torch.Tensor, begin: Sequence[int], size: Sequence[int]
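With the custom op and its fake implementation registered, tfl.softmax can be called directly from Python. A usage sketch (the import path is inferred from the file location in this commit and may differ; registration as a side effect of the import is also an assumption):

import torch
# Importing the ops module is assumed to register the tfl::* custom ops.
from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops  # noqa: F401

x = torch.randn(3, 7)
y = torch.ops.tfl.softmax(x)
# The decomposition only feeds tfl.softmax tensors whose softmax dim is last,
# so the result should match an explicit last-dim softmax.
assert torch.allclose(y, torch.nn.functional.softmax(x, dim=-1))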

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 6 additions & 0 deletions
@@ -130,6 +130,12 @@ def _assert_export_and_close(
       ("aten_gelu_3", torch.ops.aten.gelu.default, (rnd(torch.float32, (1, 10)),), dict(approximate="tanh")),
       ("aten_permute_0", torch.ops.aten.permute.default, (rnd(torch.float32, (10, 10)), [0, 1],), dict()),
       ("aten_permute_1", torch.ops.aten.permute.default, (rnd(torch.float32, (1, 10)), [0, 1],), dict()),
+      ("aten__softmax_0", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), -1, False), dict()),
+      ("aten__softmax_1", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), -1, False), dict()),
+      ("aten__softmax_2", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 0, False), dict()),
+      ("aten__softmax_3", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), 0, False), dict()),
+      ("aten__softmax_4", torch.ops.aten._softmax.default, (rnd(torch.float32, (10, 10)), 1, False), dict()),
+      ("aten__softmax_5", torch.ops.aten._softmax.default, (rnd(torch.float32, (1, 10)), 1, False), dict()),
       # fmt: on
       # pyformat: enable
   )
