Commit c4c568a
Arm backend: Add logit decomposition pass and test (#13366)

Decomposes logit into other operators.

Signed-off-by: Teo Bergkvist <[email protected]>
Co-authored-by: Sebastian Larsson <[email protected]>

1 parent f2a7f46 · commit c4c568a
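For reference, the identity the new pass implements (restated from its docstring, see backends/arm/_passes/decompose_logit_pass.py below); when eps is given, the input is first clamped to [eps, 1 - eps]:

\[
\operatorname{logit}(x) \;=\; \log\!\left(\frac{x}{1-x}\right) \;=\; \log\bigl(x \cdot \operatorname{reciprocal}\bigl((-1)\cdot x + 1\bigr)\bigr)
\]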

File tree

6 files changed: +222 -0 lines changed

backends/arm/_passes/__init__.py
backends/arm/_passes/arm_pass_manager.py
backends/arm/_passes/decompose_logit_pass.py
backends/arm/operator_support/tosa_supported_operators.py
backends/arm/test/ops/test_logit.py
backends/arm/tosa_partitioner.py

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa
+from .decompose_logit_pass import DecomposeLogitPass  # noqa
 from .decompose_masked_fill import DecomposeMaskedFill  # noqa
 from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass  # noqa
 from .decompose_meandim_pass import DecomposeMeanDimPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions

@@ -49,6 +49,7 @@
     DecomposeLeakyReLUPass,
     DecomposeLinearPass,
     DecomposeLinearVectorNormPass,
+    DecomposeLogitPass,
     DecomposeMaskedFill,
     DecomposeMaxPool2DPass,
     DecomposeMeanDimPass,
@@ -166,6 +167,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

     def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeExpm1Pass())
+        self.add_pass(DecomposeLogitPass())
         self.add_pass(DecomposeMaskedFill())
         self.add_pass(DecomposeRoundPass())
         self.add_pass(DecomposeAcoshPass())
@@ -257,6 +259,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoundPass())
+        self.add_pass(DecomposeLogitPass())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSignPass())
         self.add_pass(DecomposeAddmmPass())
backends/arm/_passes/decompose_logit_pass.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+# For FP case
+edge_logit = exir_ops.edge.aten.logit.default
+# For INT case
+aten_logit = torch.ops.aten.logit.default
+
+
+def get_ops(op):
+    """Returns the appropriate operator functions based on the input operator."""
+    if op == edge_logit:
+        return (
+            exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.add.Scalar,
+            exir_ops.edge.aten.reciprocal.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.mul.Scalar,
+            exir_ops.edge.aten.clamp.default,
+        )
+    elif op == aten_logit:
+        return (
+            torch.ops.aten.log.default,
+            torch.ops.aten.add.Scalar,
+            torch.ops.aten.reciprocal.default,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.mul.Scalar,
+            torch.ops.aten.clamp.default,
+        )
+    else:
+        raise ValueError(f"Unsupported operator: {op}")
+
+
+class DecomposeLogitPass(ArmPass):
+    """
+    Decomposes the `logit` operator into a sequence of primitive operations.
+
+    If `eps` is provided, the input tensor `x` is first clamped to the range
+    [eps, 1 - eps].
+
+    The decomposition follows the identity:
+
+        logit(x) = log(x / (1 - x))
+
+    Examples:
+
+        logit(x) becomes:
+            log(x * reciprocal((-1) * x + 1))
+
+        logit(x, eps) becomes:
+            y = clamp(x, eps, 1 - eps)
+            log(y * reciprocal((-1) * y + 1))
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in [edge_logit, aten_logit]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        X = args[0]
+        eps = args[1] if len(args) > 1 else kwargs.get("eps", None)
+
+        (
+            log_op,
+            add_scalar_op,
+            recip_op,
+            mul_tensor_op,
+            mul_scalar_op,
+            clamp_op,
+        ) = get_ops(op)
+
+        if eps is not None:
+            X = super().call_operator(
+                clamp_op, (X, eps, 1.0 - eps), {}, meta, updated=True
+            )
+
+        neg_X = super().call_operator(mul_scalar_op, (X, -1.0), {}, meta, updated=True)
+
+        denom = super().call_operator(
+            add_scalar_op, (neg_X, 1.0), {}, meta, updated=True
+        )
+
+        frac = super().call_operator(recip_op, (denom,), {}, meta, updated=True)
+
+        log_input = super().call_operator(
+            mul_tensor_op, (X, frac), {}, meta, updated=True
+        )
+
+        return super().call_operator(log_op, (log_input,), {}, meta, updated=True)
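As a quick sanity check of the arithmetic used above, the same decomposition can be reproduced in eager PyTorch and compared against torch.logit. This is only an illustrative sketch, not part of the commit; it uses standard torch ops (clamp, reciprocal, log) rather than the ExecuTorch pass infrastructure.

import torch


def decomposed_logit(x: torch.Tensor, eps=None) -> torch.Tensor:
    # Mirror the pass: optionally clamp to [eps, 1 - eps], then
    # logit(x) = log(x * reciprocal((-1) * x + 1)).
    if eps is not None:
        x = torch.clamp(x, eps, 1.0 - eps)
    return torch.log(x * torch.reciprocal(x * -1.0 + 1.0))


x = torch.rand(8) * 0.98 + 0.01  # values safely inside (0, 1)
assert torch.allclose(decomposed_logit(x), torch.logit(x), atol=1e-6)

# eps path: clamping makes logit(0) finite, matching torch.logit's eps handling.
zeros = torch.zeros(4)
assert torch.allclose(decomposed_logit(zeros, eps=1e-6), torch.logit(zeros, eps=1e-6))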

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions

@@ -260,6 +260,7 @@ def is_node_supported(
             exir_ops.edge.aten.asinh.default,
             exir_ops.edge.aten.cosh.default,
             exir_ops.edge.aten.glu.default,
+            exir_ops.edge.aten.logit.default,
         ]

         return supported
@@ -302,6 +303,7 @@ def is_node_supported(
             exir_ops.edge.aten.round.default: None,
             exir_ops.edge.aten.addmm.default: None,
             exir_ops.edge.aten.glu.default: None,
+            exir_ops.edge.aten.logit.default: None,
         }

         if node.target in needs_decomp_dict:

backends/arm/test/ops/test_logit.py

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)

+aten_op = "torch.ops.aten.logit.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten__logit_default"
+
+input_t1 = Tuple[torch.Tensor]
+
+test_data_suite = {
+    "zeros": [torch.zeros((10, 10, 10)), None],
+    "ones": [torch.ones((10, 10, 10)), None],
+    "uniform_valid": [torch.rand((10, 10, 10)), None],
+    "near_zero": [torch.full((10, 10), 1e-8), None],
+    "near_one": [torch.full((10, 10), 1 - 1e-8), None],
+    "mixed": [torch.tensor([0.0, 1e-5, 0.5, 1 - 1e-5, 1.0]), None],
+    "multi_dim": [torch.rand((2, 3, 4)), None],
+    "eps": [torch.zeros((10, 10, 10)), 1e-6],
+    "invalid_neg": [torch.full((5,), -0.1), 1e-6],
+    "invalid_gt1": [torch.full((5,), 1.1), 1e-6],
+}
+
+
+class Logit(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor, eps: torch.float32):
+        return torch.logit(x, eps=eps)
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_logit_tosa_FP(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t1](
+        Logit(),
+        (*test_data,),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_logit_tosa_INT(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t1](
+        Logit(),
+        (*test_data,),
+        aten_op=[],
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_data", test_data_suite)
+def test_logit_u55_INT(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t1](
+        Logit(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_data", test_data_suite)
+def test_logit_u85_INT(test_data: Tuple):
+    pipeline = EthosU85PipelineINT[input_t1](
+        Logit(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+@common.SkipIfNoModelConverter
+def test_logit_vgf_FP(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        Logit(),
+        (*test_data,),
+        [],
+        [],
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+@common.SkipIfNoModelConverter
+def test_logit_vgf_INT(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        Logit(),
+        (*test_data,),
+        [],
+        [],
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()
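Each entry in test_data_suite is a [tensor, eps] pair that the tests splat as (*test_data,) into Logit.forward. A minimal eager-mode sketch of what one case exercises, with no TOSA or Ethos-U pipelines involved; the Logit class is redeclared here only for illustration:

import torch


class Logit(torch.nn.Module):
    # Same module as in the test file: eps is passed through to torch.logit.
    def forward(self, x: torch.Tensor, eps):
        return torch.logit(x, eps=eps)


test_data = [torch.zeros((10, 10, 10)), 1e-6]  # the "eps" case from the suite
out = Logit()(*test_data)
assert torch.allclose(out, torch.logit(test_data[0], eps=test_data[1]))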

backends/arm/tosa_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -160,6 +160,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
            torch.ops.aten.linear.default,
            torch.ops.aten.eye.default,
            torch.ops.aten.linspace.default,
+            torch.ops.aten.logit.default,
        ] + ops_to_not_decompose_if_quant_op

        tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs)

0 commit comments