@@ -47,17 +47,16 @@ def __contains__(self, op):
4747 operator .getitem ,
4848]
4949
50- BINARY_OPS = [
50+ SUPPORTS_DYNAMIC_SHAPE = [
51+ # Binary broadcasting
5152 exir_ops .edge .aten .add .Tensor ,
5253 exir_ops .edge .aten .sub .Tensor ,
5354 exir_ops .edge .aten .minimum .default ,
5455 exir_ops .edge .aten .mul .Tensor ,
5556 exir_ops .edge .aten .div .Tensor ,
5657 exir_ops .edge .aten .div .Tensor_mode ,
5758 exir_ops .edge .aten .pow .Tensor_Tensor ,
58- ]
59-
60- UNARY_OPS = [
59+ # Unary elementwise
6160 exir_ops .edge .aten .abs .default ,
6261 exir_ops .edge .aten .clamp .default ,
6362 exir_ops .edge .aten .cos .default ,
@@ -71,60 +70,46 @@ def __contains__(self, op):
7170 exir_ops .edge .aten .sin .default ,
7271 exir_ops .edge .aten .sqrt .default ,
7372 exir_ops .edge .aten .tanh .default ,
74- ]
75-
76- MATMUL_OPS = [
73+ # Matrix Multiplication
7774 exir_ops .edge .aten .bmm .default ,
7875 exir_ops .edge .aten .mm .default ,
7976 exir_ops .edge .aten .addmm .default ,
8077 exir_ops .edge .aten .linear .default ,
81- ]
82-
83- POOLING_OPS = [
78+ # Reduction
79+ exir_ops .edge .aten ._log_softmax .default ,
80+ exir_ops .edge .aten ._softmax .default ,
81+ # 2D Pooling
8482 exir_ops .edge .aten .avg_pool2d .default ,
8583 exir_ops .edge .aten .max_pool2d_with_indices .default ,
86- ]
87-
88- CONVOLUTION_OPS = [
84+ # Convolution
8985 exir_ops .edge .aten .convolution .default ,
9086 exir_ops .edge .et_vk .conv_with_clamp .default ,
9187]
9288
93- REDUCTION_OPS = [
89+ NO_DYNAMIC_SHAPE = [
90+ # Reduction
9491 exir_ops .edge .aten .mean .dim ,
9592 exir_ops .edge .aten .sum .dim_IntList ,
96- exir_ops .edge .aten ._log_softmax .default ,
97- exir_ops .edge .aten ._softmax .default ,
98- ]
99-
100- NORMALIZATION_OPS = [
93+ # Normalization
10194 exir_ops .edge .aten ._native_batch_norm_legit_no_training .default ,
10295 exir_ops .edge .aten .native_layer_norm .default ,
103- ]
104-
105- SHAPE_MANIPULATION_OPS = [
96+ # Shape Manipulation
10697 exir_ops .edge .aten .squeeze_copy .dims ,
10798 exir_ops .edge .aten .unsqueeze_copy .default ,
10899 exir_ops .edge .aten .view_copy .default ,
109100 exir_ops .edge .aten .permute_copy .default ,
110101 exir_ops .edge .aten .t_copy .default ,
111- ]
112-
113- INDEXING_OPS = [
102+ # Indexing and lookup
114103 exir_ops .edge .aten .embedding .default ,
115104 exir_ops .edge .aten .index_select .default ,
116105 exir_ops .edge .aten .select_copy .int ,
117106 exir_ops .edge .aten .slice_copy .Tensor ,
118- ]
119-
120- ORCHESTRATION_OPS = [
107+ # Tensor combination
121108 exir_ops .edge .aten .cat .default ,
122109 exir_ops .edge .aten .split_with_sizes_copy .default ,
123110 exir_ops .edge .aten .split .Tensor ,
124111 exir_ops .edge .aten .repeat .default ,
125- ]
126-
127- CREATION_OPS = [
112+ # Tensor creation
128113 exir_ops .edge .aten .arange .start_step ,
129114 exir_ops .edge .aten .clone .default ,
130115 exir_ops .edge .aten .constant_pad_nd .default ,
@@ -139,39 +124,20 @@ def __contains__(self, op):
139124]
140125
141126
142- def register_prim_ops (ops : OpList ):
143- for op in PRIM_OPS :
144- ops [op ].supports_texture = True
145- ops [op ].supports_buffer = True
146- ops [op ].supports_dynamic_shape = True
127+ def enumerate_supported_ops ():
128+ ops = OpList ()
147129
130+ # Register in order of least to most capabilities
148131
149- def register_no_dynamic_shape_ops (ops : OpList ):
150- for op in [
151- * REDUCTION_OPS ,
152- * NORMALIZATION_OPS ,
153- * SHAPE_MANIPULATION_OPS ,
154- * INDEXING_OPS ,
155- * ORCHESTRATION_OPS ,
156- * CREATION_OPS ,
157- ]:
132+ for op in NO_DYNAMIC_SHAPE :
158133 ops [op ].supports_dynamic_shape = False
159134
160-
161- def register_dynamic_shape_ops (ops : OpList ):
162- for op in [
163- * BINARY_OPS ,
164- * UNARY_OPS ,
165- * MATMUL_OPS ,
166- * POOLING_OPS ,
167- * CONVOLUTION_OPS ,
168- ]:
135+ for op in SUPPORTS_DYNAMIC_SHAPE :
169136 ops [op ].supports_dynamic_shape = True
170137
138+ for op in PRIM_OPS :
139+ ops [op ].supports_texture = True
140+ ops [op ].supports_buffer = True
141+ ops [op ].supports_dynamic_shape = True
171142
172- def enumerate_supported_ops ():
173- ops = OpList ()
174- register_prim_ops (ops )
175- register_no_dynamic_shape_ops (ops )
176- register_dynamic_shape_ops (ops )
177143 return ops
0 commit comments