77# pyre-unsafe
88
99import operator
10- from typing import cast , final , List
10+ from typing import Callable , cast , Dict , final , List , Optional , Set , Tuple
1111
1212import torch
1313from executorch .backends .aoti .aoti_backend import AotiBackend # usort: skip
2424
2525from torch .fx .passes .operator_support import OperatorSupportBase
2626
27- supported_fallback_operators = []
28-
29- inductor_fallback_ops : dict [str , dict [str , list [str ]]] = {
30- "aten._adaptive_avg_pool2d_backward.default" : {},
31- "aten._adaptive_avg_pool2d.default" : {},
32- "aten._adaptive_avg_pool3d_backward.default" : {},
33- "aten._adaptive_avg_pool3d.default" : {},
34- "aten._addmm_activation.default" : {},
35- "aten._cdist_backward.default" : {},
36- "aten._cdist_forward.default" : {},
37- "aten._cudnn_rnn.default" : {},
38- "aten._dyn_quant_matmul_4bit.default" : {},
39- "aten._dyn_quant_pack_4bit_weight.default" : {},
40- "aten._efficient_attention_backward.default" : {},
41- "aten._efficient_attention_forward.default" : {},
42- "aten._efficientzerotensor.default" : {},
43- "aten._embedding_bag_dense_backward.default" : {},
44- "aten._embedding_bag_forward_only.default" : {},
45- "aten._embedding_bag_per_sample_weights_backward.default" : {},
46- "aten._embedding_bag.default" : {},
47- "aten._fft_c2c.default" : {},
48- "aten._fft_r2c.default" : {},
49- "aten._flash_attention_backward.default" : {},
50- "aten._flash_attention_forward.default" : {},
51- "aten._fused_moving_avg_obs_fq_helper_functional.default" : {},
52- "aten._fused_moving_avg_obs_fq_helper.default" : {},
53- "aten._fused_rms_norm.default" : {},
54- "aten._histogramdd_from_bin_cts.default" : {},
55- "aten._int_mm.out" : {},
56- "aten._pdist_backward.default" : {},
57- "aten._pdist_forward.default" : {},
58- "aten._scaled_dot_product_attention_math_for_mps.default" : {},
59- "aten._scaled_dot_product_cudnn_attention_backward.default" : {},
60- "aten._scaled_dot_product_cudnn_attention.default" : {},
61- "aten._scaled_dot_product_efficient_attention_backward.default" : {},
62- "aten._scaled_dot_product_efficient_attention.default" : {},
63- "aten._scaled_dot_product_flash_attention_backward.default" : {},
64- "aten._scaled_dot_product_flash_attention_for_cpu_backward.default" : {},
65- "aten._scaled_dot_product_flash_attention_for_cpu.default" : {},
66- "aten._scaled_dot_product_flash_attention.default" : {},
67- "aten._scaled_dot_product_fused_attention_overrideable_backward.default" : {},
68- "aten._scaled_dot_product_fused_attention_overrideable.default" : {},
69- "aten._scaled_mm.default" : {},
70- "aten._scaled_mm.out" : {},
71- "aten._segment_reduce_backward.default" : {},
72- "aten._thnn_fused_lstm_cell.default" : {},
73- "aten._to_sparse.default" : {},
74- "aten._trilinear.default" : {},
75- "aten._weight_int4pack_mm.default" : {},
76- "aten._weight_int8pack_mm.default" : {},
77- "aten.abs.default" : {},
78- "aten.adaptive_max_pool2d_backward.default" : {},
79- "aten.adaptive_max_pool2d.default" : {},
80- "aten.adaptive_max_pool3d_backward.default" : {},
81- "aten.adaptive_max_pool3d.default" : {},
82- "aten.add.Scalar" : {},
83- "aten.add.Tensor" : {},
84- "aten.addbmm.default" : {},
85- "aten.addmm.out" : {},
86- "aten.addmv.default" : {},
87- "aten.angle.default" : {},
88- "aten.avg_pool2d_backward.default" : {},
89- "aten.avg_pool2d.default" : {},
90- "aten.avg_pool3d_backward.default" : {},
91- "aten.avg_pool3d.default" : {},
92- "aten.baddbmm.out" : {},
93- "aten.bernoulli_.float" : {},
94- "aten.bernoulli_.Tensor" : {},
95- "aten.bmm.out" : {},
96- "aten.bucketize.Tensor" : {},
97- "aten.cat.default" : {},
98- "aten.cholesky_inverse.default" : {},
99- "aten.cholesky_solve.default" : {},
100- "aten.convolution_backward.default" : {},
101- "aten.convolution.default" : {},
102- "aten.cummax.default" : {},
103- "aten.cummin.default" : {},
104- "aten.cumprod.default" : {},
105- "aten.cumsum.default" : {},
106- "aten.exponential.default" : {},
107- "aten.fill_.Scalar" : {},
108- "aten.fractional_max_pool2d_backward.default" : {},
109- "aten.fractional_max_pool2d.default" : {},
110- "aten.fractional_max_pool3d_backward.default" : {},
111- "aten.fractional_max_pool3d.default" : {},
112- "aten.gcd.default" : {},
113- "aten.geqrf.default" : {},
114- "aten.grid_sampler_2d_backward.default" : {},
115- "aten.hann_window.default" : {},
116- "aten.histc.default" : {},
117- "aten.histogram.bin_ct" : {},
118- "aten.index_put.default" : {},
119- "aten.index_reduce.default" : {},
120- "aten.index.Tensor" : {},
121- "aten.kthvalue.default" : {},
122- "aten.logcumsumexp.default" : {},
123- "aten.lu_unpack.default" : {},
124- "aten.masked_scatter_backward.default" : {},
125- "aten.masked_scatter.default" : {},
126- "aten.masked_select.default" : {},
127- "aten.max_pool2d_with_indices_backward.default" : {},
128- "aten.max_pool2d_with_indices.default" : {},
129- "aten.max_pool3d_with_indices_backward.default" : {},
130- "aten.max_pool3d_with_indices.default" : {},
131- "aten.max_unpool2d.default" : {},
132- "aten.max_unpool3d.default" : {},
133- "aten.median.default" : {},
134- "aten.mm.out" : {},
135- "aten.mode.default" : {},
136- "aten.mul.Scalar" : {},
137- "aten.mul.Tensor" : {},
138- "aten.nanmedian.default" : {},
139- "aten.narrow.default" : {},
140- "aten.native_dropout.default" : {},
141- "aten.nonzero.default" : {},
142- "aten.normal_functional.default" : {},
143- "aten.ormqr.default" : {},
144- "aten.pad.default" : {},
145- "aten.permute.default" : {},
146- "aten.polar.default" : {},
147- "aten.pow.Scalar" : {},
148- "aten.pow.Tensor_Scalar" : {},
149- "aten.pow.Tensor_Tensor" : {},
150- "aten.rand.default" : {},
151- "aten.rand.generator" : {},
152- "aten.randint.default" : {},
153- "aten.randint.generator" : {},
154- "aten.randint.low_out" : {},
155- "aten.randint.low" : {},
156- "aten.randn.default" : {},
157- "aten.randn.generator" : {},
158- "aten.randperm.default" : {},
159- "aten.repeat_interleave.Tensor" : {},
160- "aten.replication_pad1d_backward.default" : {},
161- "aten.replication_pad2d_backward.default" : {},
162- "aten.reshape.default" : {},
163- "aten.resize_.default" : {},
164- "aten.resize_as_.default" : {},
165- "aten.scatter_reduce.two_out" : {},
166- "aten.scatter.src_out" : {},
167- "aten.scatter.value_out" : {},
168- "aten.searchsorted.Scalar" : {},
169- "aten.searchsorted.Tensor" : {},
170- "aten.segment_reduce.default" : {},
171- "aten.set_.source_Tensor" : {},
172- "aten.slice.Tensor" : {},
173- "aten.soft_margin_loss_backward.default" : {},
174- "aten.sort.default" : {},
175- "aten.sort.stable" : {},
176- "aten.squeeze.dim" : {},
177- "aten.to_sparse.default" : {},
178- "aten.topk.default" : {},
179- "aten.triangular_solve.default" : {},
180- "aten.uniform.default" : {},
181- "aten.upsample_bicubic2d_backward.default" : {},
182- "aten.upsample_linear1d_backward.default" : {},
183- "aten.upsample_trilinear3d_backward.default" : {},
184- "aten.view_as_complex.default" : {},
185- "aten.view_as_real.default" : {},
186- "aten.view.dtype" : {},
187- "aten._weight_int4pack_mm_with_scales_and_zeros.default" : {},
27+ # exist fallback operators in et namespace; should map to inductor_fallback_ops
28+ supported_fallback_operators : Dict [str , Dict [str , List [str ]]] = {}
29+
30+ inductor_fallback_ops : Set [str ] = {
31+ "aten._adaptive_avg_pool2d_backward.default" ,
32+ "aten._adaptive_avg_pool2d.default" ,
33+ "aten._adaptive_avg_pool3d_backward.default" ,
34+ "aten._adaptive_avg_pool3d.default" ,
35+ "aten._addmm_activation.default" ,
36+ "aten._cdist_backward.default" ,
37+ "aten._cdist_forward.default" ,
38+ "aten._cudnn_rnn.default" ,
39+ "aten._dyn_quant_matmul_4bit.default" ,
40+ "aten._dyn_quant_pack_4bit_weight.default" ,
41+ "aten._efficient_attention_backward.default" ,
42+ "aten._efficient_attention_forward.default" ,
43+ "aten._efficientzerotensor.default" ,
44+ "aten._embedding_bag_dense_backward.default" ,
45+ "aten._embedding_bag_forward_only.default" ,
46+ "aten._embedding_bag_per_sample_weights_backward.default" ,
47+ "aten._embedding_bag.default" ,
48+ "aten._fft_c2c.default" ,
49+ "aten._fft_r2c.default" ,
50+ "aten._flash_attention_backward.default" ,
51+ "aten._flash_attention_forward.default" ,
52+ "aten._fused_moving_avg_obs_fq_helper_functional.default" ,
53+ "aten._fused_moving_avg_obs_fq_helper.default" ,
54+ "aten._fused_rms_norm.default" ,
55+ "aten._histogramdd_from_bin_cts.default" ,
56+ "aten._int_mm.out" ,
57+ "aten._pdist_backward.default" ,
58+ "aten._pdist_forward.default" ,
59+ "aten._scaled_dot_product_attention_math_for_mps.default" ,
60+ "aten._scaled_dot_product_cudnn_attention_backward.default" ,
61+ "aten._scaled_dot_product_cudnn_attention.default" ,
62+ "aten._scaled_dot_product_efficient_attention_backward.default" ,
63+ "aten._scaled_dot_product_efficient_attention.default" ,
64+ "aten._scaled_dot_product_flash_attention_backward.default" ,
65+ "aten._scaled_dot_product_flash_attention_for_cpu_backward.default" ,
66+ "aten._scaled_dot_product_flash_attention_for_cpu.default" ,
67+ "aten._scaled_dot_product_flash_attention.default" ,
68+ "aten._scaled_dot_product_fused_attention_overrideable_backward.default" ,
69+ "aten._scaled_dot_product_fused_attention_overrideable.default" ,
70+ "aten._scaled_mm.default" ,
71+ "aten._scaled_mm.out" ,
72+ "aten._segment_reduce_backward.default" ,
73+ "aten._thnn_fused_lstm_cell.default" ,
74+ "aten._to_sparse.default" ,
75+ "aten._trilinear.default" ,
76+ "aten._weight_int4pack_mm.default" ,
77+ "aten._weight_int8pack_mm.default" ,
78+ "aten.abs.default" ,
79+ "aten.adaptive_max_pool2d_backward.default" ,
80+ "aten.adaptive_max_pool2d.default" ,
81+ "aten.adaptive_max_pool3d_backward.default" ,
82+ "aten.adaptive_max_pool3d.default" ,
83+ "aten.add.Scalar" ,
84+ "aten.add.Tensor" ,
85+ "aten.addbmm.default" ,
86+ "aten.addmm.out" ,
87+ "aten.addmv.default" ,
88+ "aten.angle.default" ,
89+ "aten.avg_pool2d_backward.default" ,
90+ "aten.avg_pool2d.default" ,
91+ "aten.avg_pool3d_backward.default" ,
92+ "aten.avg_pool3d.default" ,
93+ "aten.baddbmm.out" ,
94+ "aten.bernoulli_.float" ,
95+ "aten.bernoulli_.Tensor" ,
96+ "aten.bmm.out" ,
97+ "aten.bucketize.Tensor" ,
98+ "aten.cat.default" ,
99+ "aten.cholesky_inverse.default" ,
100+ "aten.cholesky_solve.default" ,
101+ "aten.convolution_backward.default" ,
102+ "aten.convolution.default" ,
103+ "aten.cummax.default" ,
104+ "aten.cummin.default" ,
105+ "aten.cumprod.default" ,
106+ "aten.cumsum.default" ,
107+ "aten.exponential.default" ,
108+ "aten.fill_.Scalar" ,
109+ "aten.fractional_max_pool2d_backward.default" ,
110+ "aten.fractional_max_pool2d.default" ,
111+ "aten.fractional_max_pool3d_backward.default" ,
112+ "aten.fractional_max_pool3d.default" ,
113+ "aten.gcd.default" ,
114+ "aten.geqrf.default" ,
115+ "aten.grid_sampler_2d_backward.default" ,
116+ "aten.hann_window.default" ,
117+ "aten.histc.default" ,
118+ "aten.histogram.bin_ct" ,
119+ "aten.index_put.default" ,
120+ "aten.index_reduce.default" ,
121+ "aten.index.Tensor" ,
122+ "aten.kthvalue.default" ,
123+ "aten.logcumsumexp.default" ,
124+ "aten.lu_unpack.default" ,
125+ "aten.masked_scatter_backward.default" ,
126+ "aten.masked_scatter.default" ,
127+ "aten.masked_select.default" ,
128+ "aten.max_pool2d_with_indices_backward.default" ,
129+ "aten.max_pool2d_with_indices.default" ,
130+ "aten.max_pool3d_with_indices_backward.default" ,
131+ "aten.max_pool3d_with_indices.default" ,
132+ "aten.max_unpool2d.default" ,
133+ "aten.max_unpool3d.default" ,
134+ "aten.median.default" ,
135+ "aten.mm.out" ,
136+ "aten.mode.default" ,
137+ "aten.mul.Scalar" ,
138+ "aten.mul.Tensor" ,
139+ "aten.nanmedian.default" ,
140+ "aten.narrow.default" ,
141+ "aten.native_dropout.default" ,
142+ "aten.nonzero.default" ,
143+ "aten.normal_functional.default" ,
144+ "aten.ormqr.default" ,
145+ "aten.pad.default" ,
146+ "aten.permute.default" ,
147+ "aten.polar.default" ,
148+ "aten.pow.Scalar" ,
149+ "aten.pow.Tensor_Scalar" ,
150+ "aten.pow.Tensor_Tensor" ,
151+ "aten.rand.default" ,
152+ "aten.rand.generator" ,
153+ "aten.randint.default" ,
154+ "aten.randint.generator" ,
155+ "aten.randint.low_out" ,
156+ "aten.randint.low" ,
157+ "aten.randn.default" ,
158+ "aten.randn.generator" ,
159+ "aten.randperm.default" ,
160+ "aten.repeat_interleave.Tensor" ,
161+ "aten.replication_pad1d_backward.default" ,
162+ "aten.replication_pad2d_backward.default" ,
163+ "aten.reshape.default" ,
164+ "aten.resize_.default" ,
165+ "aten.resize_as_.default" ,
166+ "aten.scatter_reduce.two_out" ,
167+ "aten.scatter.src_out" ,
168+ "aten.scatter.value_out" ,
169+ "aten.searchsorted.Scalar" ,
170+ "aten.searchsorted.Tensor" ,
171+ "aten.segment_reduce.default" ,
172+ "aten.set_.source_Tensor" ,
173+ "aten.slice.Tensor" ,
174+ "aten.soft_margin_loss_backward.default" ,
175+ "aten.sort.default" ,
176+ "aten.sort.stable" ,
177+ "aten.squeeze.dim" ,
178+ "aten.to_sparse.default" ,
179+ "aten.topk.default" ,
180+ "aten.triangular_solve.default" ,
181+ "aten.uniform.default" ,
182+ "aten.upsample_bicubic2d_backward.default" ,
183+ "aten.upsample_linear1d_backward.default" ,
184+ "aten.upsample_trilinear3d_backward.default" ,
185+ "aten.view_as_complex.default" ,
186+ "aten.view_as_real.default" ,
187+ "aten.view.dtype" ,
188+ "aten._weight_int4pack_mm_with_scales_and_zeros.default" ,
188189}
189190
190191
@@ -193,13 +194,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
193194 supported = node .op == "call_function" and (
194195 node .target == operator .getitem
195196 or node .target ._op not in inductor_fallback_ops
197+ or node .target ._op in supported_fallback_operators
196198 )
197199
198- # if node.op == "call_function" and node.target != operator.getitem:
199- # print(node.target._op)
200- # print(supported)
201- # print('------------------')
202-
203200 return supported
204201
205202 def is_node_supported_custom (self , node : torch .fx .Node ) -> bool :
@@ -248,3 +245,21 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
248245 return PartitionResult (
249246 tagged_exported_program = exported_program , partition_tags = partition_tags
250247 )
248+
249+ def ops_to_not_decompose (
250+ self , ep : ExportedProgram
251+ ) -> Tuple [List [torch ._ops .OpOverload ], Optional [Callable [[torch .fx .Node ], bool ]]]:
252+ """
253+ Return a list of operations that should not be decomposed and let the AOT compiler handle them.
254+ """
255+ do_not_decompose = set ()
256+ op_support = AOTISupportedOperators ()
257+
258+ for node in ep .graph .nodes :
259+ if (
260+ node .op == "call_function"
261+ and isinstance (node .target , torch ._ops .OpOverload )
262+ and op_support .is_node_supported (None , node )
263+ ):
264+ do_not_decompose .add (node .target )
265+ return list (do_not_decompose ), None
0 commit comments