Skip to content

Commit 50f0ab8

Browse files
committed
use aoti decomposition on lowable graph
1 parent 23afe52 commit 50f0ab8

File tree

3 files changed

+186
-167
lines changed

3 files changed

+186
-167
lines changed

backends/aoti/aoti_partitioner.py

Lines changed: 182 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-unsafe
88

99
import operator
10-
from typing import cast, final, List
10+
from typing import Callable, cast, Dict, final, List, Optional, Set, Tuple
1111

1212
import torch
1313
from executorch.backends.aoti.aoti_backend import AotiBackend # usort: skip
@@ -24,167 +24,168 @@
2424

2525
from torch.fx.passes.operator_support import OperatorSupportBase
2626

27-
supported_fallback_operators = []
28-
29-
inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
30-
"aten._adaptive_avg_pool2d_backward.default": {},
31-
"aten._adaptive_avg_pool2d.default": {},
32-
"aten._adaptive_avg_pool3d_backward.default": {},
33-
"aten._adaptive_avg_pool3d.default": {},
34-
"aten._addmm_activation.default": {},
35-
"aten._cdist_backward.default": {},
36-
"aten._cdist_forward.default": {},
37-
"aten._cudnn_rnn.default": {},
38-
"aten._dyn_quant_matmul_4bit.default": {},
39-
"aten._dyn_quant_pack_4bit_weight.default": {},
40-
"aten._efficient_attention_backward.default": {},
41-
"aten._efficient_attention_forward.default": {},
42-
"aten._efficientzerotensor.default": {},
43-
"aten._embedding_bag_dense_backward.default": {},
44-
"aten._embedding_bag_forward_only.default": {},
45-
"aten._embedding_bag_per_sample_weights_backward.default": {},
46-
"aten._embedding_bag.default": {},
47-
"aten._fft_c2c.default": {},
48-
"aten._fft_r2c.default": {},
49-
"aten._flash_attention_backward.default": {},
50-
"aten._flash_attention_forward.default": {},
51-
"aten._fused_moving_avg_obs_fq_helper_functional.default": {},
52-
"aten._fused_moving_avg_obs_fq_helper.default": {},
53-
"aten._fused_rms_norm.default": {},
54-
"aten._histogramdd_from_bin_cts.default": {},
55-
"aten._int_mm.out": {},
56-
"aten._pdist_backward.default": {},
57-
"aten._pdist_forward.default": {},
58-
"aten._scaled_dot_product_attention_math_for_mps.default": {},
59-
"aten._scaled_dot_product_cudnn_attention_backward.default": {},
60-
"aten._scaled_dot_product_cudnn_attention.default": {},
61-
"aten._scaled_dot_product_efficient_attention_backward.default": {},
62-
"aten._scaled_dot_product_efficient_attention.default": {},
63-
"aten._scaled_dot_product_flash_attention_backward.default": {},
64-
"aten._scaled_dot_product_flash_attention_for_cpu_backward.default": {},
65-
"aten._scaled_dot_product_flash_attention_for_cpu.default": {},
66-
"aten._scaled_dot_product_flash_attention.default": {},
67-
"aten._scaled_dot_product_fused_attention_overrideable_backward.default": {},
68-
"aten._scaled_dot_product_fused_attention_overrideable.default": {},
69-
"aten._scaled_mm.default": {},
70-
"aten._scaled_mm.out": {},
71-
"aten._segment_reduce_backward.default": {},
72-
"aten._thnn_fused_lstm_cell.default": {},
73-
"aten._to_sparse.default": {},
74-
"aten._trilinear.default": {},
75-
"aten._weight_int4pack_mm.default": {},
76-
"aten._weight_int8pack_mm.default": {},
77-
"aten.abs.default": {},
78-
"aten.adaptive_max_pool2d_backward.default": {},
79-
"aten.adaptive_max_pool2d.default": {},
80-
"aten.adaptive_max_pool3d_backward.default": {},
81-
"aten.adaptive_max_pool3d.default": {},
82-
"aten.add.Scalar": {},
83-
"aten.add.Tensor": {},
84-
"aten.addbmm.default": {},
85-
"aten.addmm.out": {},
86-
"aten.addmv.default": {},
87-
"aten.angle.default": {},
88-
"aten.avg_pool2d_backward.default": {},
89-
"aten.avg_pool2d.default": {},
90-
"aten.avg_pool3d_backward.default": {},
91-
"aten.avg_pool3d.default": {},
92-
"aten.baddbmm.out": {},
93-
"aten.bernoulli_.float": {},
94-
"aten.bernoulli_.Tensor": {},
95-
"aten.bmm.out": {},
96-
"aten.bucketize.Tensor": {},
97-
"aten.cat.default": {},
98-
"aten.cholesky_inverse.default": {},
99-
"aten.cholesky_solve.default": {},
100-
"aten.convolution_backward.default": {},
101-
"aten.convolution.default": {},
102-
"aten.cummax.default": {},
103-
"aten.cummin.default": {},
104-
"aten.cumprod.default": {},
105-
"aten.cumsum.default": {},
106-
"aten.exponential.default": {},
107-
"aten.fill_.Scalar": {},
108-
"aten.fractional_max_pool2d_backward.default": {},
109-
"aten.fractional_max_pool2d.default": {},
110-
"aten.fractional_max_pool3d_backward.default": {},
111-
"aten.fractional_max_pool3d.default": {},
112-
"aten.gcd.default": {},
113-
"aten.geqrf.default": {},
114-
"aten.grid_sampler_2d_backward.default": {},
115-
"aten.hann_window.default": {},
116-
"aten.histc.default": {},
117-
"aten.histogram.bin_ct": {},
118-
"aten.index_put.default": {},
119-
"aten.index_reduce.default": {},
120-
"aten.index.Tensor": {},
121-
"aten.kthvalue.default": {},
122-
"aten.logcumsumexp.default": {},
123-
"aten.lu_unpack.default": {},
124-
"aten.masked_scatter_backward.default": {},
125-
"aten.masked_scatter.default": {},
126-
"aten.masked_select.default": {},
127-
"aten.max_pool2d_with_indices_backward.default": {},
128-
"aten.max_pool2d_with_indices.default": {},
129-
"aten.max_pool3d_with_indices_backward.default": {},
130-
"aten.max_pool3d_with_indices.default": {},
131-
"aten.max_unpool2d.default": {},
132-
"aten.max_unpool3d.default": {},
133-
"aten.median.default": {},
134-
"aten.mm.out": {},
135-
"aten.mode.default": {},
136-
"aten.mul.Scalar": {},
137-
"aten.mul.Tensor": {},
138-
"aten.nanmedian.default": {},
139-
"aten.narrow.default": {},
140-
"aten.native_dropout.default": {},
141-
"aten.nonzero.default": {},
142-
"aten.normal_functional.default": {},
143-
"aten.ormqr.default": {},
144-
"aten.pad.default": {},
145-
"aten.permute.default": {},
146-
"aten.polar.default": {},
147-
"aten.pow.Scalar": {},
148-
"aten.pow.Tensor_Scalar": {},
149-
"aten.pow.Tensor_Tensor": {},
150-
"aten.rand.default": {},
151-
"aten.rand.generator": {},
152-
"aten.randint.default": {},
153-
"aten.randint.generator": {},
154-
"aten.randint.low_out": {},
155-
"aten.randint.low": {},
156-
"aten.randn.default": {},
157-
"aten.randn.generator": {},
158-
"aten.randperm.default": {},
159-
"aten.repeat_interleave.Tensor": {},
160-
"aten.replication_pad1d_backward.default": {},
161-
"aten.replication_pad2d_backward.default": {},
162-
"aten.reshape.default": {},
163-
"aten.resize_.default": {},
164-
"aten.resize_as_.default": {},
165-
"aten.scatter_reduce.two_out": {},
166-
"aten.scatter.src_out": {},
167-
"aten.scatter.value_out": {},
168-
"aten.searchsorted.Scalar": {},
169-
"aten.searchsorted.Tensor": {},
170-
"aten.segment_reduce.default": {},
171-
"aten.set_.source_Tensor": {},
172-
"aten.slice.Tensor": {},
173-
"aten.soft_margin_loss_backward.default": {},
174-
"aten.sort.default": {},
175-
"aten.sort.stable": {},
176-
"aten.squeeze.dim": {},
177-
"aten.to_sparse.default": {},
178-
"aten.topk.default": {},
179-
"aten.triangular_solve.default": {},
180-
"aten.uniform.default": {},
181-
"aten.upsample_bicubic2d_backward.default": {},
182-
"aten.upsample_linear1d_backward.default": {},
183-
"aten.upsample_trilinear3d_backward.default": {},
184-
"aten.view_as_complex.default": {},
185-
"aten.view_as_real.default": {},
186-
"aten.view.dtype": {},
187-
"aten._weight_int4pack_mm_with_scales_and_zeros.default": {},
27+
# exist fallback operators in et namespace; should map to inductor_fallback_ops
28+
supported_fallback_operators: Dict[str, Dict[str, List[str]]] = {}
29+
30+
inductor_fallback_ops: Set[str] = {
31+
"aten._adaptive_avg_pool2d_backward.default",
32+
"aten._adaptive_avg_pool2d.default",
33+
"aten._adaptive_avg_pool3d_backward.default",
34+
"aten._adaptive_avg_pool3d.default",
35+
"aten._addmm_activation.default",
36+
"aten._cdist_backward.default",
37+
"aten._cdist_forward.default",
38+
"aten._cudnn_rnn.default",
39+
"aten._dyn_quant_matmul_4bit.default",
40+
"aten._dyn_quant_pack_4bit_weight.default",
41+
"aten._efficient_attention_backward.default",
42+
"aten._efficient_attention_forward.default",
43+
"aten._efficientzerotensor.default",
44+
"aten._embedding_bag_dense_backward.default",
45+
"aten._embedding_bag_forward_only.default",
46+
"aten._embedding_bag_per_sample_weights_backward.default",
47+
"aten._embedding_bag.default",
48+
"aten._fft_c2c.default",
49+
"aten._fft_r2c.default",
50+
"aten._flash_attention_backward.default",
51+
"aten._flash_attention_forward.default",
52+
"aten._fused_moving_avg_obs_fq_helper_functional.default",
53+
"aten._fused_moving_avg_obs_fq_helper.default",
54+
"aten._fused_rms_norm.default",
55+
"aten._histogramdd_from_bin_cts.default",
56+
"aten._int_mm.out",
57+
"aten._pdist_backward.default",
58+
"aten._pdist_forward.default",
59+
"aten._scaled_dot_product_attention_math_for_mps.default",
60+
"aten._scaled_dot_product_cudnn_attention_backward.default",
61+
"aten._scaled_dot_product_cudnn_attention.default",
62+
"aten._scaled_dot_product_efficient_attention_backward.default",
63+
"aten._scaled_dot_product_efficient_attention.default",
64+
"aten._scaled_dot_product_flash_attention_backward.default",
65+
"aten._scaled_dot_product_flash_attention_for_cpu_backward.default",
66+
"aten._scaled_dot_product_flash_attention_for_cpu.default",
67+
"aten._scaled_dot_product_flash_attention.default",
68+
"aten._scaled_dot_product_fused_attention_overrideable_backward.default",
69+
"aten._scaled_dot_product_fused_attention_overrideable.default",
70+
"aten._scaled_mm.default",
71+
"aten._scaled_mm.out",
72+
"aten._segment_reduce_backward.default",
73+
"aten._thnn_fused_lstm_cell.default",
74+
"aten._to_sparse.default",
75+
"aten._trilinear.default",
76+
"aten._weight_int4pack_mm.default",
77+
"aten._weight_int8pack_mm.default",
78+
"aten.abs.default",
79+
"aten.adaptive_max_pool2d_backward.default",
80+
"aten.adaptive_max_pool2d.default",
81+
"aten.adaptive_max_pool3d_backward.default",
82+
"aten.adaptive_max_pool3d.default",
83+
"aten.add.Scalar",
84+
"aten.add.Tensor",
85+
"aten.addbmm.default",
86+
"aten.addmm.out",
87+
"aten.addmv.default",
88+
"aten.angle.default",
89+
"aten.avg_pool2d_backward.default",
90+
"aten.avg_pool2d.default",
91+
"aten.avg_pool3d_backward.default",
92+
"aten.avg_pool3d.default",
93+
"aten.baddbmm.out",
94+
"aten.bernoulli_.float",
95+
"aten.bernoulli_.Tensor",
96+
"aten.bmm.out",
97+
"aten.bucketize.Tensor",
98+
"aten.cat.default",
99+
"aten.cholesky_inverse.default",
100+
"aten.cholesky_solve.default",
101+
"aten.convolution_backward.default",
102+
"aten.convolution.default",
103+
"aten.cummax.default",
104+
"aten.cummin.default",
105+
"aten.cumprod.default",
106+
"aten.cumsum.default",
107+
"aten.exponential.default",
108+
"aten.fill_.Scalar",
109+
"aten.fractional_max_pool2d_backward.default",
110+
"aten.fractional_max_pool2d.default",
111+
"aten.fractional_max_pool3d_backward.default",
112+
"aten.fractional_max_pool3d.default",
113+
"aten.gcd.default",
114+
"aten.geqrf.default",
115+
"aten.grid_sampler_2d_backward.default",
116+
"aten.hann_window.default",
117+
"aten.histc.default",
118+
"aten.histogram.bin_ct",
119+
"aten.index_put.default",
120+
"aten.index_reduce.default",
121+
"aten.index.Tensor",
122+
"aten.kthvalue.default",
123+
"aten.logcumsumexp.default",
124+
"aten.lu_unpack.default",
125+
"aten.masked_scatter_backward.default",
126+
"aten.masked_scatter.default",
127+
"aten.masked_select.default",
128+
"aten.max_pool2d_with_indices_backward.default",
129+
"aten.max_pool2d_with_indices.default",
130+
"aten.max_pool3d_with_indices_backward.default",
131+
"aten.max_pool3d_with_indices.default",
132+
"aten.max_unpool2d.default",
133+
"aten.max_unpool3d.default",
134+
"aten.median.default",
135+
"aten.mm.out",
136+
"aten.mode.default",
137+
"aten.mul.Scalar",
138+
"aten.mul.Tensor",
139+
"aten.nanmedian.default",
140+
"aten.narrow.default",
141+
"aten.native_dropout.default",
142+
"aten.nonzero.default",
143+
"aten.normal_functional.default",
144+
"aten.ormqr.default",
145+
"aten.pad.default",
146+
"aten.permute.default",
147+
"aten.polar.default",
148+
"aten.pow.Scalar",
149+
"aten.pow.Tensor_Scalar",
150+
"aten.pow.Tensor_Tensor",
151+
"aten.rand.default",
152+
"aten.rand.generator",
153+
"aten.randint.default",
154+
"aten.randint.generator",
155+
"aten.randint.low_out",
156+
"aten.randint.low",
157+
"aten.randn.default",
158+
"aten.randn.generator",
159+
"aten.randperm.default",
160+
"aten.repeat_interleave.Tensor",
161+
"aten.replication_pad1d_backward.default",
162+
"aten.replication_pad2d_backward.default",
163+
"aten.reshape.default",
164+
"aten.resize_.default",
165+
"aten.resize_as_.default",
166+
"aten.scatter_reduce.two_out",
167+
"aten.scatter.src_out",
168+
"aten.scatter.value_out",
169+
"aten.searchsorted.Scalar",
170+
"aten.searchsorted.Tensor",
171+
"aten.segment_reduce.default",
172+
"aten.set_.source_Tensor",
173+
"aten.slice.Tensor",
174+
"aten.soft_margin_loss_backward.default",
175+
"aten.sort.default",
176+
"aten.sort.stable",
177+
"aten.squeeze.dim",
178+
"aten.to_sparse.default",
179+
"aten.topk.default",
180+
"aten.triangular_solve.default",
181+
"aten.uniform.default",
182+
"aten.upsample_bicubic2d_backward.default",
183+
"aten.upsample_linear1d_backward.default",
184+
"aten.upsample_trilinear3d_backward.default",
185+
"aten.view_as_complex.default",
186+
"aten.view_as_real.default",
187+
"aten.view.dtype",
188+
"aten._weight_int4pack_mm_with_scales_and_zeros.default",
188189
}
189190

190191

@@ -193,13 +194,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
193194
supported = node.op == "call_function" and (
194195
node.target == operator.getitem
195196
or node.target._op not in inductor_fallback_ops
197+
or node.target._op in supported_fallback_operators
196198
)
197199

198-
# if node.op == "call_function" and node.target != operator.getitem:
199-
# print(node.target._op)
200-
# print(supported)
201-
# print('------------------')
202-
203200
return supported
204201

205202
def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
@@ -248,3 +245,21 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
248245
return PartitionResult(
249246
tagged_exported_program=exported_program, partition_tags=partition_tags
250247
)
248+
249+
def ops_to_not_decompose(
250+
self, ep: ExportedProgram
251+
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
252+
"""
253+
Return a list of operations that should not be decomposed and let the AOT compiler handle them.
254+
"""
255+
do_not_decompose = set()
256+
op_support = AOTISupportedOperators()
257+
258+
for node in ep.graph.nodes:
259+
if (
260+
node.op == "call_function"
261+
and isinstance(node.target, torch._ops.OpOverload)
262+
and op_support.is_node_supported(None, node)
263+
):
264+
do_not_decompose.add(node.target)
265+
return list(do_not_decompose), None
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 187af0d41fe75d08d2a7ec84c1b4d24b9b641ed2

export_aoti.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def export_model(model, example_inputs, output_filename="aoti_model.pte"):
143143
# edge_program = edge_program.to_backend(AotiPartitioner([]))
144144
# print("To backend done.")
145145

146+
# aoti part should be decomposed by the internal torch._inductor.aot_compile
147+
# we should preserve the lowerable part and waiting for aoti backend handle that
148+
# Q: maybe need to turn on fallback_random?
146149
edge_program = to_edge_transform_and_lower(
147150
aten_dialect, partitioner=[AotiPartitioner([])]
148151
)

0 commit comments

Comments
 (0)