Skip to content

Commit 27b7077

Browse files
chunnienccopybara-github
authored andcommitted
Improve _safe_softmax lowering
PiperOrigin-RevId: 704838624
1 parent cf0e73f commit 27b7077

File tree

3 files changed

+60
-33
lines changed

3 files changed

+60
-33
lines changed

ai_edge_torch/odml_torch/lowerings/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,6 @@
2121
from . import context
2222
from . import registry
2323
from . import utils
24-
from .registry import decompositions
24+
from .decomp import decompositions
2525
from .registry import lookup
2626
from .registry import lower
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2024 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Torch export decompositions to run before lowering."""
16+
17+
import functools
18+
19+
import torch
20+
21+
22+
@functools.cache
23+
def decompositions():
24+
# Base: Core ATen decompositions
25+
decompositions = torch._decomp.core_aten_decompositions()
26+
27+
decompositions.update(
28+
torch._decomp.get_decompositions([
29+
torch.ops.aten.upsample_nearest2d,
30+
torch.ops.aten._native_batch_norm_legit.no_stats,
31+
torch.ops.aten._native_batch_norm_legit_functional,
32+
torch.ops.aten._adaptive_avg_pool2d,
33+
torch.ops.aten._adaptive_avg_pool3d,
34+
torch.ops.aten.grid_sampler_2d,
35+
torch.ops.aten.native_group_norm,
36+
torch.ops.aten.native_dropout,
37+
torch.ops.aten.reflection_pad1d,
38+
torch.ops.aten.reflection_pad2d,
39+
torch.ops.aten.reflection_pad3d,
40+
torch.ops.aten.replication_pad1d,
41+
torch.ops.aten.replication_pad2d,
42+
torch.ops.aten.replication_pad3d,
43+
torch.ops.aten.addmm,
44+
])
45+
)
46+
47+
torch._decomp.remove_decompositions(
48+
decompositions,
49+
[torch.ops.aten.roll],
50+
)
51+
52+
# Override _safe_softmax decompositions with regular softmax.
53+
# _safe_softmax introduces additional check-select ops to guard extreme
54+
# input values to softmax, which could make the converted model inefficient
55+
# on-device.
56+
if hasattr(torch.ops.aten, "_safe_softmax"):
57+
decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
58+
59+
return decompositions

ai_edge_torch/odml_torch/lowerings/registry.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ class LoweringRegistry:
2626

2727
def __init__(self):
2828
self.registered_ops = {}
29-
self.decompositions = {}
3029

3130
def lookup(self, op_or_name):
3231
candidate = self._get_lowering(op_or_name)
@@ -52,33 +51,6 @@ def register(self, op, lowering):
5251

5352

5453
global_registry = LoweringRegistry()
55-
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
56-
global_registry.decompositions.update(
57-
torch._decomp.get_decompositions([
58-
torch.ops.aten.upsample_nearest2d,
59-
torch.ops.aten._native_batch_norm_legit.no_stats,
60-
torch.ops.aten._native_batch_norm_legit_functional,
61-
torch.ops.aten._adaptive_avg_pool2d,
62-
torch.ops.aten._adaptive_avg_pool3d,
63-
torch.ops.aten.grid_sampler_2d,
64-
torch.ops.aten.native_group_norm,
65-
torch.ops.aten.native_dropout,
66-
torch.ops.aten.reflection_pad1d,
67-
torch.ops.aten.reflection_pad2d,
68-
torch.ops.aten.reflection_pad3d,
69-
torch.ops.aten.replication_pad1d,
70-
torch.ops.aten.replication_pad2d,
71-
torch.ops.aten.replication_pad3d,
72-
torch.ops.aten.addmm,
73-
])
74-
)
75-
76-
torch._decomp.remove_decompositions(
77-
global_registry.decompositions,
78-
[
79-
torch.ops.aten.roll,
80-
],
81-
)
8254

8355

8456
def lookup(op):
@@ -91,7 +63,3 @@ def inner(lowering: Callable[[context.LoweringContext, ...], Any]):
9163
return lowering
9264

9365
return inner
94-
95-
96-
def decompositions():
97-
return global_registry.decompositions

0 commit comments

Comments
 (0)