Skip to content

Commit bba3921

Browse files
Format code
1 parent 1a9417d commit bba3921

File tree

3 files changed

+1527
-913
lines changed

3 files changed

+1527
-913
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 111 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -115,50 +115,119 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
115115
//===----------------------------------------------------------------------===//
116116

117117
bool isaConv1DOp(LinalgOp op);
118-
bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
119-
bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
120-
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
121-
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
122-
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
118+
bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
119+
SmallVector<int64_t> *strides = nullptr);
120+
bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
121+
SmallVector<int64_t> *strides = nullptr);
122+
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
123+
SmallVector<int64_t> *dilations = nullptr,
124+
SmallVector<int64_t> *strides = nullptr);
125+
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
126+
SmallVector<int64_t> *dilations = nullptr,
127+
SmallVector<int64_t> *strides = nullptr);
128+
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
129+
SmallVector<int64_t> *dilations = nullptr,
130+
SmallVector<int64_t> *strides = nullptr);
123131
bool isaConv2DOp(LinalgOp op);
124-
bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
125-
bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
126-
bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
127-
bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
128-
bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
129-
bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
130-
bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
131-
bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
132-
bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
133-
bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
134-
bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
135-
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
136-
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
137-
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
138-
bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
139-
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
132+
bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
133+
SmallVector<int64_t> *strides = nullptr);
134+
bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
135+
SmallVector<int64_t> *strides = nullptr);
136+
bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
137+
SmallVector<int64_t> *strides = nullptr);
138+
bool isaConv2DNhwcFhwcQOp(LinalgOp op,
139+
SmallVector<int64_t> *dilations = nullptr,
140+
SmallVector<int64_t> *strides = nullptr);
141+
bool isaConv2DNchwFchwQOp(LinalgOp op,
142+
SmallVector<int64_t> *dilations = nullptr,
143+
SmallVector<int64_t> *strides = nullptr);
144+
bool isaConv2DNgchwFgchwOp(LinalgOp op,
145+
SmallVector<int64_t> *dilations = nullptr,
146+
SmallVector<int64_t> *strides = nullptr);
147+
bool isaConv2DNgchwGfchwOp(LinalgOp op,
148+
SmallVector<int64_t> *dilations = nullptr,
149+
SmallVector<int64_t> *strides = nullptr);
150+
bool isaConv2DNhwcHwcfQOp(LinalgOp op,
151+
SmallVector<int64_t> *dilations = nullptr,
152+
SmallVector<int64_t> *strides = nullptr);
153+
bool isaConv2DNhwgcGfhwcQOp(LinalgOp op,
154+
SmallVector<int64_t> *dilations = nullptr,
155+
SmallVector<int64_t> *strides = nullptr);
156+
bool isaConv2DNgchwGfchwQOp(LinalgOp op,
157+
SmallVector<int64_t> *dilations = nullptr,
158+
SmallVector<int64_t> *strides = nullptr);
159+
bool isaConv2DNhwgcGfhwcOp(LinalgOp op,
160+
SmallVector<int64_t> *dilations = nullptr,
161+
SmallVector<int64_t> *strides = nullptr);
162+
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
163+
SmallVector<int64_t> *dilations = nullptr,
164+
SmallVector<int64_t> *strides = nullptr);
165+
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
166+
SmallVector<int64_t> *dilations = nullptr,
167+
SmallVector<int64_t> *strides = nullptr);
168+
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
169+
SmallVector<int64_t> *dilations = nullptr,
170+
SmallVector<int64_t> *strides = nullptr);
171+
bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
172+
SmallVector<int64_t> *dilations = nullptr,
173+
SmallVector<int64_t> *strides = nullptr);
174+
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
175+
SmallVector<int64_t> *dilations = nullptr,
176+
SmallVector<int64_t> *strides = nullptr);
140177
bool isaConv3DOp(LinalgOp op);
141-
bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
142-
bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
143-
bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
144-
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
145-
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
146-
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
147-
bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
148-
bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
149-
bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
150-
bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
151-
bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
152-
bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
153-
bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
154-
bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
155-
bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
156-
bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
157-
bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
158-
bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
159-
bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
160-
bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
161-
bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
178+
bool isaConv3DNcdhwFcdhwOp(LinalgOp op,
179+
SmallVector<int64_t> *dilations = nullptr,
180+
SmallVector<int64_t> *strides = nullptr);
181+
bool isaConv3DNdhwcDhwcfOp(LinalgOp op,
182+
SmallVector<int64_t> *dilations = nullptr,
183+
SmallVector<int64_t> *strides = nullptr);
184+
bool isaConv3DNdhwcDhwcfQOp(LinalgOp op,
185+
SmallVector<int64_t> *dilations = nullptr,
186+
SmallVector<int64_t> *strides = nullptr);
187+
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
188+
SmallVector<int64_t> *dilations = nullptr,
189+
SmallVector<int64_t> *strides = nullptr);
190+
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
191+
SmallVector<int64_t> *dilations = nullptr,
192+
SmallVector<int64_t> *strides = nullptr);
193+
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
194+
SmallVector<int64_t> *dilations = nullptr,
195+
SmallVector<int64_t> *strides = nullptr);
196+
bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
197+
SmallVector<int64_t> *strides = nullptr);
198+
bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
199+
SmallVector<int64_t> *strides = nullptr);
200+
bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
201+
SmallVector<int64_t> *strides = nullptr);
202+
bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
203+
SmallVector<int64_t> *strides = nullptr);
204+
bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
205+
SmallVector<int64_t> *strides = nullptr);
206+
bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
207+
SmallVector<int64_t> *dilations = nullptr,
208+
SmallVector<int64_t> *strides = nullptr);
209+
bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
210+
SmallVector<int64_t> *dilations = nullptr,
211+
SmallVector<int64_t> *strides = nullptr);
212+
bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
213+
SmallVector<int64_t> *strides = nullptr);
214+
bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
215+
SmallVector<int64_t> *strides = nullptr);
216+
bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
217+
SmallVector<int64_t> *strides = nullptr);
218+
bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
219+
SmallVector<int64_t> *strides = nullptr);
220+
bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
221+
SmallVector<int64_t> *strides = nullptr);
222+
bool isaPoolingNdhwcMaxOp(LinalgOp op,
223+
SmallVector<int64_t> *dilations = nullptr,
224+
SmallVector<int64_t> *strides = nullptr);
225+
bool isaPoolingNdhwcMinOp(LinalgOp op,
226+
SmallVector<int64_t> *dilations = nullptr,
227+
SmallVector<int64_t> *strides = nullptr);
228+
bool isaPoolingNdhwcSumOp(LinalgOp op,
229+
SmallVector<int64_t> *dilations = nullptr,
230+
SmallVector<int64_t> *strides = nullptr);
162231

163232
//===----------------------------------------------------------------------===//
164233
// Fusion / Tiling utilities

0 commit comments

Comments
 (0)