|  | 
|  | 1 | +# Copyright 2024 The AI Edge Torch Authors. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | +# ============================================================================== | 
|  | 15 | +"""Torch export decompositions to run before lowering.""" | 
|  | 16 | + | 
|  | 17 | +import functools | 
|  | 18 | + | 
|  | 19 | +import torch | 
|  | 20 | + | 
|  | 21 | + | 
|  | 22 | +@functools.cache | 
|  | 23 | +def decompositions(): | 
|  | 24 | +  # Base: Core ATen decompositions | 
|  | 25 | +  decompositions = torch._decomp.core_aten_decompositions() | 
|  | 26 | + | 
|  | 27 | +  decompositions.update( | 
|  | 28 | +      torch._decomp.get_decompositions([ | 
|  | 29 | +          torch.ops.aten.upsample_nearest2d, | 
|  | 30 | +          torch.ops.aten._native_batch_norm_legit.no_stats, | 
|  | 31 | +          torch.ops.aten._native_batch_norm_legit_functional, | 
|  | 32 | +          torch.ops.aten._adaptive_avg_pool2d, | 
|  | 33 | +          torch.ops.aten._adaptive_avg_pool3d, | 
|  | 34 | +          torch.ops.aten.grid_sampler_2d, | 
|  | 35 | +          torch.ops.aten.native_group_norm, | 
|  | 36 | +          torch.ops.aten.native_dropout, | 
|  | 37 | +          torch.ops.aten.reflection_pad1d, | 
|  | 38 | +          torch.ops.aten.reflection_pad2d, | 
|  | 39 | +          torch.ops.aten.reflection_pad3d, | 
|  | 40 | +          torch.ops.aten.replication_pad1d, | 
|  | 41 | +          torch.ops.aten.replication_pad2d, | 
|  | 42 | +          torch.ops.aten.replication_pad3d, | 
|  | 43 | +          torch.ops.aten.addmm, | 
|  | 44 | +      ]) | 
|  | 45 | +  ) | 
|  | 46 | + | 
|  | 47 | +  torch._decomp.remove_decompositions( | 
|  | 48 | +      decompositions, | 
|  | 49 | +      [torch.ops.aten.roll], | 
|  | 50 | +  ) | 
|  | 51 | + | 
|  | 52 | +  # Override _safe_softmax decompositions with regular softmax. | 
|  | 53 | +  # _safe_softmax introduces additional check-select ops to guard extreme | 
|  | 54 | +  # input values to softmax, which could make the converted model inefficient | 
|  | 55 | +  # on-device. | 
|  | 56 | +  if hasattr(torch.ops.aten, "_safe_softmax"): | 
|  | 57 | +    decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax | 
|  | 58 | + | 
|  | 59 | +  return decompositions | 
0 commit comments