-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] Refactor to create vectorization convOp precondition check #130181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
a33b211
80474f7
00c3a33
74a8986
2ef5555
13f5183
2b5d1dc
a58b8da
f153804
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -52,6 +52,12 @@ using namespace mlir::linalg; | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Forward declaration of Conv1DGenerator and its validator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| struct Conv1DGenerator; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Try to vectorize `convOp` as a convolution. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static FailureOr<Operation *> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1990,8 +1996,21 @@ static LogicalResult vectorizeLinalgOpPrecondition( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // TODO: isaConvolutionOpInterface that can also infer from generic | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // features. But we will still need stride/dilation attributes that will be | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // annoying to reverse-engineer... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Create a dummy rewriter first, a rewriter is not required for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // validation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IRRewriter dummyBuilder(linalgOp.getContext()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Check if we can successfully construct a 1d convolution generator. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // For example, if it is 2d+ convolution, return failure because we don't | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // support it. To use this pass on a 2d+ convolution, it should have already | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // been decomposed to 1d convolution via | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // DecomposeConvolutionToLowerDimOpsPass. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Determine whether `linalgOp` can be generated with this generator | |
| if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) | |
| return; | |
| lhsShaped = linalgOp.getDpsInputOperand(0)->get(); | |
| rhsShaped = linalgOp.getDpsInputOperand(1)->get(); | |
| resShaped = linalgOp.getDpsInitOperand(0)->get(); | |
| lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType()); | |
| rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType()); | |
| resShapedType = dyn_cast<ShapedType>(resShaped.getType()); | |
| if (!lhsShapedType || !rhsShapedType || !resShapedType) | |
| return; | |
| // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR | |
| // (non-channeled convolution -> LHS and RHS both have single dimensions). | |
| if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) && | |
| (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1)) | |
| return; | |
| Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); | |
| if (!reduceOp) | |
| return; | |
| redOp = reduceOp->getName().getIdentifier(); | |
| if (!setOperKind(reduceOp)) | |
| return; | |
| auto maybeKind = getCombinerOpKind(reduceOp); | |
| // Typically convolution will have a `Add` CombiningKind but for i1 type it | |
| // can get strength reduced to `OR` which is also supported. This strength | |
| // reduction logic is in `buildBinaryFn` helper in the Linalg dialect. | |
| if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && | |
| *maybeKind != vector::CombiningKind::OR) && | |
| (oper != Pool || !isSupportedPoolKind(*maybeKind)))) { | |
| return; | |
| } | |
| reductionKind = maybeKind.value(); | |
| auto rhsRank = rhsShapedType.getRank(); | |
| switch (oper) { | |
| case Conv: | |
| if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) | |
| return; | |
| break; | |
| case Pool: | |
| if (rhsRank != 1) | |
| return; | |
| break; | |
| } | |
| // The op is now known to be valid. | |
| valid = true; |
The valid variable is only used in assertions in few methods, e.g., depthwiseConv and conv. I think it's mainly created for sanity check, while the new codes did not take it into account. Thus, we crashed in the other place.
The code is quite old and the precondition was added later than the conv code. I think to make it in better structure, we can refactor the generator because everything is started from the generator. How about we have a static class method which returns true when the given operation is supported? That said, we move the above logic check to a static method (e.g., vectorizePrecondition) without initializing any variables.
In the construction, I'd suggest doing simple things as much as possible. And we move the assertion out of the constructor. In the context, they are moved to an initializer method. Because I'd prefer avoiding a crash in the constructor, and we can expose the failure handling to external users. (I don't know what the style is in LLVM, but it is quite common in environments where exceptions are disallowed. See https://abseil.io/tips/42 for more details.)
Thus, it can be something like
Conv1DGenerator : : public StructuredGenerator<LinalgOp, utils::IteratorType> {
// constructor only takes the rewriter and linalgop
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {}
// vectorization precond
bool/LogicalResult vectorizePrecondition(LinalgOp linalgOp) { ... }
// The initialization method
LogicalResult init() {
// or do an assertion here.
if (failed(vectorizedPrecondition(...))) {
return failure();
}
// Initial the values for class members.
}
Does it look better structured?
Uh oh!
There was an error while loading. Please reload this page.