Skip to content

Commit 3852dc4

Browse files
Export just a single API
1 parent bba3921 commit 3852dc4

File tree

3 files changed

+450
-250
lines changed

3 files changed

+450
-250
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 3 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -111,123 +111,13 @@ std::optional<SmallVector<ReassociationIndices>>
111111
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
112112

113113
//===----------------------------------------------------------------------===//
114-
// Convolution matcher utilities
114+
// Convolution matcher utility
115115
//===----------------------------------------------------------------------===//
116116

117-
bool isaConv1DOp(LinalgOp op);
118-
bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
119-
SmallVector<int64_t> *strides = nullptr);
120-
bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
121-
SmallVector<int64_t> *strides = nullptr);
122-
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
123-
SmallVector<int64_t> *dilations = nullptr,
124-
SmallVector<int64_t> *strides = nullptr);
125-
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
126-
SmallVector<int64_t> *dilations = nullptr,
127-
SmallVector<int64_t> *strides = nullptr);
128-
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
129-
SmallVector<int64_t> *dilations = nullptr,
130-
SmallVector<int64_t> *strides = nullptr);
131-
bool isaConv2DOp(LinalgOp op);
132-
bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
133-
SmallVector<int64_t> *strides = nullptr);
134-
bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
135-
SmallVector<int64_t> *strides = nullptr);
136-
bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
137-
SmallVector<int64_t> *strides = nullptr);
138-
bool isaConv2DNhwcFhwcQOp(LinalgOp op,
139-
SmallVector<int64_t> *dilations = nullptr,
140-
SmallVector<int64_t> *strides = nullptr);
141-
bool isaConv2DNchwFchwQOp(LinalgOp op,
142-
SmallVector<int64_t> *dilations = nullptr,
143-
SmallVector<int64_t> *strides = nullptr);
144-
bool isaConv2DNgchwFgchwOp(LinalgOp op,
145-
SmallVector<int64_t> *dilations = nullptr,
146-
SmallVector<int64_t> *strides = nullptr);
147-
bool isaConv2DNgchwGfchwOp(LinalgOp op,
148-
SmallVector<int64_t> *dilations = nullptr,
149-
SmallVector<int64_t> *strides = nullptr);
150-
bool isaConv2DNhwcHwcfQOp(LinalgOp op,
151-
SmallVector<int64_t> *dilations = nullptr,
152-
SmallVector<int64_t> *strides = nullptr);
153-
bool isaConv2DNhwgcGfhwcQOp(LinalgOp op,
117+
template <typename ConvOpTy>
118+
bool isaConvolutionOpOfType(LinalgOp op,
154119
SmallVector<int64_t> *dilations = nullptr,
155120
SmallVector<int64_t> *strides = nullptr);
156-
bool isaConv2DNgchwGfchwQOp(LinalgOp op,
157-
SmallVector<int64_t> *dilations = nullptr,
158-
SmallVector<int64_t> *strides = nullptr);
159-
bool isaConv2DNhwgcGfhwcOp(LinalgOp op,
160-
SmallVector<int64_t> *dilations = nullptr,
161-
SmallVector<int64_t> *strides = nullptr);
162-
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
163-
SmallVector<int64_t> *dilations = nullptr,
164-
SmallVector<int64_t> *strides = nullptr);
165-
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
166-
SmallVector<int64_t> *dilations = nullptr,
167-
SmallVector<int64_t> *strides = nullptr);
168-
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
169-
SmallVector<int64_t> *dilations = nullptr,
170-
SmallVector<int64_t> *strides = nullptr);
171-
bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
172-
SmallVector<int64_t> *dilations = nullptr,
173-
SmallVector<int64_t> *strides = nullptr);
174-
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
175-
SmallVector<int64_t> *dilations = nullptr,
176-
SmallVector<int64_t> *strides = nullptr);
177-
bool isaConv3DOp(LinalgOp op);
178-
bool isaConv3DNcdhwFcdhwOp(LinalgOp op,
179-
SmallVector<int64_t> *dilations = nullptr,
180-
SmallVector<int64_t> *strides = nullptr);
181-
bool isaConv3DNdhwcDhwcfOp(LinalgOp op,
182-
SmallVector<int64_t> *dilations = nullptr,
183-
SmallVector<int64_t> *strides = nullptr);
184-
bool isaConv3DNdhwcDhwcfQOp(LinalgOp op,
185-
SmallVector<int64_t> *dilations = nullptr,
186-
SmallVector<int64_t> *strides = nullptr);
187-
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
188-
SmallVector<int64_t> *dilations = nullptr,
189-
SmallVector<int64_t> *strides = nullptr);
190-
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
191-
SmallVector<int64_t> *dilations = nullptr,
192-
SmallVector<int64_t> *strides = nullptr);
193-
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
194-
SmallVector<int64_t> *dilations = nullptr,
195-
SmallVector<int64_t> *strides = nullptr);
196-
bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
197-
SmallVector<int64_t> *strides = nullptr);
198-
bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
199-
SmallVector<int64_t> *strides = nullptr);
200-
bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
201-
SmallVector<int64_t> *strides = nullptr);
202-
bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
203-
SmallVector<int64_t> *strides = nullptr);
204-
bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
205-
SmallVector<int64_t> *strides = nullptr);
206-
bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
207-
SmallVector<int64_t> *dilations = nullptr,
208-
SmallVector<int64_t> *strides = nullptr);
209-
bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
210-
SmallVector<int64_t> *dilations = nullptr,
211-
SmallVector<int64_t> *strides = nullptr);
212-
bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
213-
SmallVector<int64_t> *strides = nullptr);
214-
bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
215-
SmallVector<int64_t> *strides = nullptr);
216-
bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
217-
SmallVector<int64_t> *strides = nullptr);
218-
bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
219-
SmallVector<int64_t> *strides = nullptr);
220-
bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
221-
SmallVector<int64_t> *strides = nullptr);
222-
bool isaPoolingNdhwcMaxOp(LinalgOp op,
223-
SmallVector<int64_t> *dilations = nullptr,
224-
SmallVector<int64_t> *strides = nullptr);
225-
bool isaPoolingNdhwcMinOp(LinalgOp op,
226-
SmallVector<int64_t> *dilations = nullptr,
227-
SmallVector<int64_t> *strides = nullptr);
228-
bool isaPoolingNdhwcSumOp(LinalgOp op,
229-
SmallVector<int64_t> *dilations = nullptr,
230-
SmallVector<int64_t> *strides = nullptr);
231121

232122
//===----------------------------------------------------------------------===//
233123
// Fusion / Tiling utilities

0 commit comments

Comments
 (0)