Skip to content

Commit a33b211

Browse files
committed
Block conv2d from the vectorization pass
1 parent 6c2e170 commit a33b211

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1990,8 +1990,18 @@ static LogicalResult vectorizeLinalgOpPrecondition(
19901990
// TODO: isaConvolutionOpInterface that can also infer from generic
19911991
// features. But we will still need stride/dilation attributes that will be
19921992
// annoying to reverse-engineer...
1993-
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1993+
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
1994+
// Check if it is 2d+ convolution. If it is, return failure because we don't
1995+
// support it. To use this pass on a 2d+ convolution, it should have already
1996+
// been decomposed to 1d convolution via
1997+
// DecomposeConvolutionToLowerDimOpsPass.
1998+
if (linalgOp.getNumParallelLoops() >= 4) {
1999+
LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
2000+
return failure();
2001+
}
19942002
return success();
2003+
}
2004+
19952005
// TODO: the common vector shape is equal to the static loop sizes only when
19962006
// all indexing maps are projected permutations. For convs and stencils the
19972007
// logic will need to evolve.
@@ -3929,9 +3939,11 @@ static FailureOr<Operation *> vectorizeConvolution(
39293939
if (!inputVecSizes.empty()) {
39303940
// Only use the input vector size corresponding to the channel dim. Other
39313941
// vector dims will be inferred from the Ops.
3932-
assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3933-
isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3934-
"Not a 1D depthwise conv!");
3942+
if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
3943+
!isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
3944+
return rewriter.notifyMatchFailure(
3945+
op, "Unexpected convolution: expected 1D depthwise conv");
3946+
}
39353947
size_t chDimIdx =
39363948
TypeSwitch<Operation *, size_t>(op)
39373949
.Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,25 @@ module attributes {transform.with_named_sequence} {
112112

113113
// -----
114114

115+
func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4: tensor<64x64x3x3xf32>) {
116+
%cst = arith.constant 0.000000e+00 : f32
117+
%5 = tensor.empty() : tensor<1x64x56x56xf32>
118+
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
119+
// expected-error @+1 {{Attempted to vectorize, but failed}}
120+
%7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
121+
return
122+
}
123+
124+
module attributes {transform.with_named_sequence} {
125+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
126+
%0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
127+
transform.structured.vectorize %0 : !transform.any_op
128+
transform.yield
129+
}
130+
}
131+
132+
// -----
133+
115134
func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
116135
%pad = arith.constant 0.000000e+00 : f32
117136
// expected-error @+1 {{Attempted to vectorize, but failed}}

0 commit comments

Comments
 (0)