Skip to content

Commit f1b8e80

Browse files
Make use of dilations/strides info
1 parent 5aeb371 commit f1b8e80

File tree

3 files changed

+548
-437
lines changed

3 files changed

+548
-437
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -115,50 +115,50 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
115115
//===----------------------------------------------------------------------===//
116116

117117
bool isaConv1DOp(LinalgOp op);
118-
bool isaConv1DNwcWcfOp(LinalgOp op);
119-
bool isaConv1DNcwFcwOp(LinalgOp op);
120-
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
121-
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
122-
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
118+
bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
119+
bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
120+
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
121+
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
122+
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
123123
bool isaConv2DOp(LinalgOp op);
124-
bool isaConv2DNhwcFhwcOp(LinalgOp op);
125-
bool isaConv2DNhwcHwcfOp(LinalgOp op);
126-
bool isaConv2DNchwFchwOp(LinalgOp op);
127-
bool isaConv2DNhwcFhwcQOp(LinalgOp op);
128-
bool isaConv2DNchwFchwQOp(LinalgOp op);
129-
bool isaConv2DNgchwFgchwOp(LinalgOp op);
130-
bool isaConv2DNgchwGfchwOp(LinalgOp op);
131-
bool isaConv2DNhwcHwcfQOp(LinalgOp op);
132-
bool isaConv2DNhwgcGfhwcQOp(LinalgOp op);
133-
bool isaConv2DNgchwGfchwQOp(LinalgOp op);
134-
bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
135-
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
136-
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
137-
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
138-
bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op);
139-
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
124+
bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
125+
bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
126+
bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
127+
bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
128+
bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
129+
bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
130+
bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
131+
bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
132+
bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
133+
bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
134+
bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
135+
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
136+
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
137+
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
138+
bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
139+
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
140140
bool isaConv3DOp(LinalgOp op);
141-
bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
142-
bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
143-
bool isaConv3DNdhwcDhwcfQOp(LinalgOp op);
144-
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
145-
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
146-
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
147-
bool isaPoolingNchwMaxOp(LinalgOp op);
148-
bool isaPoolingNchwSumOp(LinalgOp op);
149-
bool isaPoolingNhwcMaxOp(LinalgOp op);
150-
bool isaPoolingNhwcMinOp(LinalgOp op);
151-
bool isaPoolingNhwcSumOp(LinalgOp op);
152-
bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op);
153-
bool isaPoolingNhwcMinUnsignedOp(LinalgOp op);
154-
bool isaPoolingNcwMaxOp(LinalgOp op);
155-
bool isaPoolingNcwSumOp(LinalgOp op);
156-
bool isaPoolingNwcMaxOp(LinalgOp op);
157-
bool isaPoolingNwcMinOp(LinalgOp op);
158-
bool isaPoolingNwcSumOp(LinalgOp op);
159-
bool isaPoolingNdhwcMaxOp(LinalgOp op);
160-
bool isaPoolingNdhwcMinOp(LinalgOp op);
161-
bool isaPoolingNdhwcSumOp(LinalgOp op);
141+
bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
142+
bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
143+
bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
144+
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
145+
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
146+
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
147+
bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
148+
bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
149+
bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
150+
bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
151+
bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
152+
bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
153+
bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
154+
bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
155+
bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
156+
bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
157+
bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
158+
bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
159+
bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
160+
bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
161+
bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
162162

163163
//===----------------------------------------------------------------------===//
164164
// Fusion / Tiling utilities

0 commit comments

Comments
 (0)