@@ -115,50 +115,119 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
115115// ===----------------------------------------------------------------------===//
116116
117117bool isaConv1DOp (LinalgOp op);
118- bool isaConv1DNwcWcfOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
119- bool isaConv1DNcwFcwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
120- bool isaDepthwiseConv1DNcwCwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
121- bool isaDepthwiseConv1DNwcWcOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
122- bool isaDepthwiseConv1DNwcWcmOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
118+ bool isaConv1DNwcWcfOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
119+ SmallVector<int64_t > *strides = nullptr );
120+ bool isaConv1DNcwFcwOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
121+ SmallVector<int64_t > *strides = nullptr );
122+ bool isaDepthwiseConv1DNcwCwOp (LinalgOp op,
123+ SmallVector<int64_t > *dilations = nullptr ,
124+ SmallVector<int64_t > *strides = nullptr );
125+ bool isaDepthwiseConv1DNwcWcOp (LinalgOp op,
126+ SmallVector<int64_t > *dilations = nullptr ,
127+ SmallVector<int64_t > *strides = nullptr );
128+ bool isaDepthwiseConv1DNwcWcmOp (LinalgOp op,
129+ SmallVector<int64_t > *dilations = nullptr ,
130+ SmallVector<int64_t > *strides = nullptr );
123131bool isaConv2DOp (LinalgOp op);
124- bool isaConv2DNhwcFhwcOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
125- bool isaConv2DNhwcHwcfOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
126- bool isaConv2DNchwFchwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
127- bool isaConv2DNhwcFhwcQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
128- bool isaConv2DNchwFchwQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
129- bool isaConv2DNgchwFgchwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
130- bool isaConv2DNgchwGfchwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
131- bool isaConv2DNhwcHwcfQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
132- bool isaConv2DNhwgcGfhwcQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
133- bool isaConv2DNgchwGfchwQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
134- bool isaConv2DNhwgcGfhwcOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
135- bool isaDepthwiseConv2DNchwChwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
136- bool isaDepthwiseConv2DNhwcHwcOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
137- bool isaDepthwiseConv2DNhwcHwcmOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
138- bool isaDepthwiseConv2DNhwcHwcQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
139- bool isaDepthwiseConv2DNhwcHwcmQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
132+ bool isaConv2DNhwcFhwcOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
133+ SmallVector<int64_t > *strides = nullptr );
134+ bool isaConv2DNhwcHwcfOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
135+ SmallVector<int64_t > *strides = nullptr );
136+ bool isaConv2DNchwFchwOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
137+ SmallVector<int64_t > *strides = nullptr );
138+ bool isaConv2DNhwcFhwcQOp (LinalgOp op,
139+ SmallVector<int64_t > *dilations = nullptr ,
140+ SmallVector<int64_t > *strides = nullptr );
141+ bool isaConv2DNchwFchwQOp (LinalgOp op,
142+ SmallVector<int64_t > *dilations = nullptr ,
143+ SmallVector<int64_t > *strides = nullptr );
144+ bool isaConv2DNgchwFgchwOp (LinalgOp op,
145+ SmallVector<int64_t > *dilations = nullptr ,
146+ SmallVector<int64_t > *strides = nullptr );
147+ bool isaConv2DNgchwGfchwOp (LinalgOp op,
148+ SmallVector<int64_t > *dilations = nullptr ,
149+ SmallVector<int64_t > *strides = nullptr );
150+ bool isaConv2DNhwcHwcfQOp (LinalgOp op,
151+ SmallVector<int64_t > *dilations = nullptr ,
152+ SmallVector<int64_t > *strides = nullptr );
153+ bool isaConv2DNhwgcGfhwcQOp (LinalgOp op,
154+ SmallVector<int64_t > *dilations = nullptr ,
155+ SmallVector<int64_t > *strides = nullptr );
156+ bool isaConv2DNgchwGfchwQOp (LinalgOp op,
157+ SmallVector<int64_t > *dilations = nullptr ,
158+ SmallVector<int64_t > *strides = nullptr );
159+ bool isaConv2DNhwgcGfhwcOp (LinalgOp op,
160+ SmallVector<int64_t > *dilations = nullptr ,
161+ SmallVector<int64_t > *strides = nullptr );
162+ bool isaDepthwiseConv2DNchwChwOp (LinalgOp op,
163+ SmallVector<int64_t > *dilations = nullptr ,
164+ SmallVector<int64_t > *strides = nullptr );
165+ bool isaDepthwiseConv2DNhwcHwcOp (LinalgOp op,
166+ SmallVector<int64_t > *dilations = nullptr ,
167+ SmallVector<int64_t > *strides = nullptr );
168+ bool isaDepthwiseConv2DNhwcHwcmOp (LinalgOp op,
169+ SmallVector<int64_t > *dilations = nullptr ,
170+ SmallVector<int64_t > *strides = nullptr );
171+ bool isaDepthwiseConv2DNhwcHwcQOp (LinalgOp op,
172+ SmallVector<int64_t > *dilations = nullptr ,
173+ SmallVector<int64_t > *strides = nullptr );
174+ bool isaDepthwiseConv2DNhwcHwcmQOp (LinalgOp op,
175+ SmallVector<int64_t > *dilations = nullptr ,
176+ SmallVector<int64_t > *strides = nullptr );
140177bool isaConv3DOp (LinalgOp op);
141- bool isaConv3DNcdhwFcdhwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
142- bool isaConv3DNdhwcDhwcfOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
143- bool isaConv3DNdhwcDhwcfQOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
144- bool isaDepthwiseConv3DNdhwcDhwcmOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
145- bool isaDepthwiseConv3DNcdhwCdhwOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
146- bool isaDepthwiseConv3DNdhwcDhwcOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
147- bool isaPoolingNchwMaxOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
148- bool isaPoolingNchwSumOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
149- bool isaPoolingNhwcMaxOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
150- bool isaPoolingNhwcMinOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
151- bool isaPoolingNhwcSumOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
152- bool isaPoolingNhwcMaxUnsignedOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
153- bool isaPoolingNhwcMinUnsignedOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
154- bool isaPoolingNcwMaxOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
155- bool isaPoolingNcwSumOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
156- bool isaPoolingNwcMaxOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
157- bool isaPoolingNwcMinOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
158- bool isaPoolingNwcSumOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
159- bool isaPoolingNdhwcMaxOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
160- bool isaPoolingNdhwcMinOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
161- bool isaPoolingNdhwcSumOp (LinalgOp op, SmallVector<int64_t >* dilations = nullptr , SmallVector<int64_t >* strides = nullptr );
178+ bool isaConv3DNcdhwFcdhwOp (LinalgOp op,
179+ SmallVector<int64_t > *dilations = nullptr ,
180+ SmallVector<int64_t > *strides = nullptr );
181+ bool isaConv3DNdhwcDhwcfOp (LinalgOp op,
182+ SmallVector<int64_t > *dilations = nullptr ,
183+ SmallVector<int64_t > *strides = nullptr );
184+ bool isaConv3DNdhwcDhwcfQOp (LinalgOp op,
185+ SmallVector<int64_t > *dilations = nullptr ,
186+ SmallVector<int64_t > *strides = nullptr );
187+ bool isaDepthwiseConv3DNdhwcDhwcmOp (LinalgOp op,
188+ SmallVector<int64_t > *dilations = nullptr ,
189+ SmallVector<int64_t > *strides = nullptr );
190+ bool isaDepthwiseConv3DNcdhwCdhwOp (LinalgOp op,
191+ SmallVector<int64_t > *dilations = nullptr ,
192+ SmallVector<int64_t > *strides = nullptr );
193+ bool isaDepthwiseConv3DNdhwcDhwcOp (LinalgOp op,
194+ SmallVector<int64_t > *dilations = nullptr ,
195+ SmallVector<int64_t > *strides = nullptr );
196+ bool isaPoolingNchwMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
197+ SmallVector<int64_t > *strides = nullptr );
198+ bool isaPoolingNchwSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
199+ SmallVector<int64_t > *strides = nullptr );
200+ bool isaPoolingNhwcMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
201+ SmallVector<int64_t > *strides = nullptr );
202+ bool isaPoolingNhwcMinOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
203+ SmallVector<int64_t > *strides = nullptr );
204+ bool isaPoolingNhwcSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
205+ SmallVector<int64_t > *strides = nullptr );
206+ bool isaPoolingNhwcMaxUnsignedOp (LinalgOp op,
207+ SmallVector<int64_t > *dilations = nullptr ,
208+ SmallVector<int64_t > *strides = nullptr );
209+ bool isaPoolingNhwcMinUnsignedOp (LinalgOp op,
210+ SmallVector<int64_t > *dilations = nullptr ,
211+ SmallVector<int64_t > *strides = nullptr );
212+ bool isaPoolingNcwMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
213+ SmallVector<int64_t > *strides = nullptr );
214+ bool isaPoolingNcwSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
215+ SmallVector<int64_t > *strides = nullptr );
216+ bool isaPoolingNwcMaxOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
217+ SmallVector<int64_t > *strides = nullptr );
218+ bool isaPoolingNwcMinOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
219+ SmallVector<int64_t > *strides = nullptr );
220+ bool isaPoolingNwcSumOp (LinalgOp op, SmallVector<int64_t > *dilations = nullptr ,
221+ SmallVector<int64_t > *strides = nullptr );
222+ bool isaPoolingNdhwcMaxOp (LinalgOp op,
223+ SmallVector<int64_t > *dilations = nullptr ,
224+ SmallVector<int64_t > *strides = nullptr );
225+ bool isaPoolingNdhwcMinOp (LinalgOp op,
226+ SmallVector<int64_t > *dilations = nullptr ,
227+ SmallVector<int64_t > *strides = nullptr );
228+ bool isaPoolingNdhwcSumOp (LinalgOp op,
229+ SmallVector<int64_t > *dilations = nullptr ,
230+ SmallVector<int64_t > *strides = nullptr );
162231
163232// ===----------------------------------------------------------------------===//
164233// Fusion / Tiling utilities
0 commit comments