Skip to content

Commit 00c3a33

Browse files
committed
Forward declare Conv1DGenerator for validaty
1 parent 80474f7 commit 00c3a33

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ using namespace mlir::linalg;
5252
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
5353
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
5454

55+
// Forward declaration of Conv1DGenerator and its validator
56+
namespace {
57+
struct Conv1DGenerator;
58+
bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
59+
} // namespace
60+
5561
/// Try to vectorize `convOp` as a convolution.
5662
static FailureOr<Operation *>
5763
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1991,14 +1997,17 @@ static LogicalResult vectorizeLinalgOpPrecondition(
19911997
// features. But we will still need stride/dilation attributes that will be
19921998
// annoying to reverse-engineer...
19931999
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
1994-
// Check if it is 2d+ convolution. If it is, return failure because we don't
2000+
// Create a dummy rewriter first, a rewriter is not required for
2001+
// validation
2002+
IRRewriter dummyBuilder(linalgOp.getContext());
2003+
// Check if we can successfully construct a 1d convolution generator.
2004+
// For example, if it is 2d+ convolution, return failure because we don't
19952005
// support it. To use this pass on a 2d+ convolution, it should have already
19962006
// been decomposed to 1d convolution via
19972007
// DecomposeConvolutionToLowerDimOpsPass.
1998-
if (linalgOp.getNumParallelLoops() >= 4) {
1999-
LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
2008+
if (!validateConv1DGenerator(dummyBuilder, linalgOp))
20002009
return failure();
2001-
}
2010+
20022011
return success();
20032012
}
20042013

@@ -3197,6 +3206,8 @@ struct Conv1DGenerator
31973206
valid = true;
31983207
}
31993208

3209+
bool isValid() { return valid; }
3210+
32003211
/// Generate a vector implementation for:
32013212
/// ```
32023213
/// Op def: ( w, kw )
@@ -3907,6 +3918,13 @@ struct Conv1DGenerator
39073918
}
39083919
}
39093920
};
3921+
3922+
// Helper function to construct Conv1DGenerator
3923+
bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
3924+
Conv1DGenerator conv1dGen(rewriter, linalgOp);
3925+
return conv1dGen.isValid();
3926+
}
3927+
39103928
} // namespace
39113929

39123930
/// Helper function to vectorize a LinalgOp with convolution semantics.

0 commit comments

Comments
 (0)