-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[MLIR][Linalg] Transform pass to hoist vector transfer operations for a batch reduce matmul operation #121748
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
84130f6
83ed5c4
45a7e76
fc4c016
8aa56e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1824,6 +1824,10 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, | |
| /// suffices for achieving the sum. | ||
| void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns); | ||
|
|
||
| /// Pattern to hoist the vector transfer reads/writes outside the reduction and | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: "hoists the" -> "hoist" |
||
| /// k-loop for batch reduce matmul operation if licm fails. | ||
| void populateHoistVectorTransferPatterns(RewritePatternSet &patterns); | ||
|
|
||
| /// Pattern to fuse a `tensor.pad` operation with the producer of its source, | ||
| /// if the producer is a `linalg` operation with all parallel iterator types. | ||
| void populateFuseTensorPadWithProducerLinalgOpPatterns( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,267 @@ | ||
| //===- HoistVectorTransfers.cpp ---------------------------------*- C++ -*-===// | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: right justification appears off. |
||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
| #include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
| #include "mlir/Dialect/SCF/IR/SCF.h" | ||
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" | ||
| #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" | ||
| #include "mlir/IR/BuiltinAttributes.h" | ||
| #include "mlir/IR/IRMapping.h" | ||
| #include "mlir/Pass/Pass.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
|
||
| using namespace mlir; | ||
|
|
||
| // Function to retrieve the vector transfer read operations (LHS, RHS, and Acc) from a contraction operation. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I believe comments should also keep to the 80-char limit. Also holds for comments further down. Does
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit1: "retrieves" and "(RHS, LHS, Acc)" |
||
| static FailureOr<SmallVector<vector::TransferReadOp>> | ||
| getContractOperands(vector::ContractionOp contractOp) { | ||
| SmallVector<vector::TransferReadOp> list; | ||
| for (int i = 0; i < 3; i++) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: https://llvm.org/docs/CodingStandards.html#use-range-based-for-loops-wherever-possible That is, |
||
| auto vectorReadOp = | ||
| contractOp.getOperand(i).getDefiningOp<vector::TransferReadOp>(); | ||
| if (!vectorReadOp) | ||
| return failure(); | ||
| list.push_back(vectorReadOp); | ||
| } | ||
| return list; | ||
| } | ||
|
|
||
| // Function to retrieve the subview from each vector transfer read operation. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: "retrive" -> "retrieve" (also in other places) |
||
| static FailureOr<SmallVector<memref::SubViewOp>> | ||
| getReadOperands(SmallVector<vector::TransferReadOp> readOps) { | ||
| SmallVector<memref::SubViewOp> list; | ||
| for (vector::TransferReadOp readOp : readOps) { | ||
| auto subViewOp = readOp.getOperand(0).getDefiningOp<memref::SubViewOp>(); | ||
| if (!subViewOp) | ||
| return failure(); | ||
| list.push_back(subViewOp); | ||
| } | ||
| return list; | ||
| } | ||
|
|
||
| // Function to retrieve the tiled nested loop structure (m->n->reduction->k) for the contract operation. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: missing full stop. |
||
| static FailureOr<SmallVector<scf::ForOp>> | ||
| getNestedLoop(vector::ContractionOp contractOp) { | ||
| SmallVector<scf::ForOp> list; | ||
| Operation *current = contractOp; | ||
| for (int i = 0; i < 4; i++) { | ||
| Operation *parent = current->getParentOfType<scf::ForOp>(); | ||
| if (!parent) | ||
| return failure(); | ||
| list.push_back(dyn_cast<scf::ForOp>(parent)); | ||
| current = parent; | ||
| } | ||
| return list; | ||
| } | ||
|
|
||
| // Function to check iv of nested loops matches with the subview | ||
| static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops, | ||
| SmallVector<memref::SubViewOp> subviews) { | ||
| auto subviewOpLhsOffsets = subviews[0].getOffsets(); | ||
| auto subviewOpRhsOffsets = subviews[1].getOffsets(); | ||
| auto subviewOpAccOffsets = subviews[2].getOffsets(); | ||
|
|
||
| Value ivK = loops[0].getInductionVar(); | ||
| if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1]) | ||
| return failure(); | ||
|
|
||
| Value ivReduction = loops[1].getInductionVar(); | ||
| if (ivReduction != subviewOpLhsOffsets[0] || | ||
| ivReduction != subviewOpRhsOffsets[0]) | ||
| return failure(); | ||
|
|
||
| Value ivN = loops[2].getInductionVar(); | ||
| if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2]) | ||
| return failure(); | ||
|
|
||
| Value ivM = loops[3].getInductionVar(); | ||
| if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0]) | ||
| return failure(); | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| /// Hoist the `vector.transfer_read` and `vector.transfer_write` operations of a tiled batch reduce matmul operation | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As you are matching concrete ops, better to write out the op names, i.e. "vector.transfer_read" and "vector.transfer_write". |
||
| /// outside the reduction and k-loop. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you include the matching conditions in the docstring, .e.g. loop nest of four scf.for + ...?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, are the m- and n-loop really needed for this hoisting to be valid? In light of that, can we generalize the pattern? |
||
| /// | ||
| /// As an example, the following pseudo-code will be rewritten | ||
| /// scf.for %arg3 = %c0 to %c32 step %c4 // m-loop | ||
| /// scf.for %arg4 = %c0 to %c64 step %c64 // n-loop | ||
| /// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] | ||
| /// scf.for %arg5 = %c0 to %c24 step %c1 // reduction-loop | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one too many indents |
||
| /// scf.for %arg6 = %c0 to %c64 step %c1 // k-loop | ||
| /// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] | ||
| /// %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] | ||
| /// %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} | ||
| /// %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} | ||
| /// %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} | ||
| /// %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 | ||
| /// vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} | ||
| /// to: | ||
| /// scf.for %arg3 = %c0 to %c32 step %c4 | ||
| /// scf.for %arg4 = %c0 to %c64 step %c64 | ||
| /// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] | ||
| /// %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} | ||
| /// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) { | ||
| /// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) { | ||
| /// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] | ||
| /// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] | ||
| /// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} | ||
| /// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} | ||
| /// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 | ||
| /// scf.yield %6 : !type | ||
| /// } | ||
| /// scf.yield %3 : !type | ||
| /// } | ||
| /// vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} | ||
| /// | ||
| struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> { | ||
| using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(vector::ContractionOp contractOp, | ||
| PatternRewriter &rewriter) const override { | ||
|
|
||
| // Check the vector contract operation satisfies the required pattern. | ||
| // Check the Acc, Lhs, and Rhs of contract operation | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: missing full stop (in general comments should aim to be (full) sentences). |
||
| auto operands = getContractOperands(contractOp); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either expand the type, or document in a comment the check you are performing on the operands (i.e. you need them to be transfer_reads) |
||
| if (failed(operands)) | ||
| return rewriter.notifyMatchFailure(contractOp, | ||
| "Invalid operands for contract op"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about "not all contract op operands are transfer_reads"? Or something else a bit more descriptive.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here and elsewhere: match failure messages are not capitalized, i.e. "Invalid" -> "invalid" |
||
|
|
||
| auto readOps = *operands; | ||
| auto vectorReadOpAcc = readOps[2]; | ||
| auto vectorReadOpLhs = readOps[0]; | ||
| auto vectorReadOpRhs = readOps[1]; | ||
|
|
||
| // Check whether the operand of vector transfer read is a subview | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: "the" -> "each" and missing full stop |
||
| auto subviews = getReadOperands(readOps); | ||
| if (failed(subviews)) | ||
| return rewriter.notifyMatchFailure( | ||
| contractOp, "Vector read op operands are not a subview"); | ||
|
|
||
| // Check the operation type MatMul, B-MatMul, or BR-MatMul | ||
| SmallVector<vector::IteratorType> contractIteratorTypes = | ||
| contractOp.getIteratorTypesArray(); | ||
| int reductionCount = | ||
| std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(), | ||
| vector::IteratorType::reduction); | ||
|
|
||
| auto vectorReadOpLhsType = cast<ShapedType>(vectorReadOpLhs.getType()); | ||
| auto vectorReadOpRhsRank = | ||
| (cast<ShapedType>(vectorReadOpRhs.getType())).getRank(); | ||
|
|
||
| if (reductionCount == 2 && | ||
| (vectorReadOpLhsType.getRank() != 3 || vectorReadOpRhsRank != 3)) | ||
| return rewriter.notifyMatchFailure( | ||
| contractOp, "Invalid rank for batch reduce operation"); | ||
|
|
||
| if (reductionCount == 1) | ||
| return rewriter.notifyMatchFailure( | ||
| contractOp, "Batch matmul operation not supported yet"); | ||
|
|
||
| if (reductionCount > 2) | ||
| return rewriter.notifyMatchFailure( | ||
| contractOp, "The vector contract operation is not a gemm"); | ||
|
|
||
| // Check the K-dim to be 1 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: not a full sentence |
||
| int64_t K = | ||
| vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1); | ||
| if (K != 1) | ||
| return rewriter.notifyMatchFailure(contractOp, "K dim is not 1"); | ||
|
|
||
| // Check whether the BR-matmul tiling + vector contract pattern matches for the | ||
| // 4-nested loop structure | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: full stop.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also below a couple times. |
||
| auto loops = getNestedLoop(contractOp); | ||
| if (failed(loops)) | ||
| return rewriter.notifyMatchFailure( | ||
| contractOp, "Invalid loop nest in contract pattern"); | ||
|
|
||
| auto checkLoops = checkNestedLoop(*loops, *subviews); | ||
| if (failed(checkLoops)) | ||
| return rewriter.notifyMatchFailure( | ||
| contractOp, "Loops doesn't match the iv in subviews"); | ||
|
|
||
| auto nestedLoops = *loops; | ||
| auto kForOp = nestedLoops[0]; | ||
| auto reductionForOp = nestedLoops[1]; | ||
|
|
||
| // Move the vector transfer read before the reduction and k loop | ||
| rewriter.setInsertionPoint(reductionForOp); | ||
| auto *cloneVectorReadOp = rewriter.clone(*vectorReadOpAcc); | ||
|
|
||
| // Code to re-create the reduction and k loop with iter args | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: leave out "code to" |
||
| auto vectorReadOpValue = cloneVectorReadOp->getResult(0); | ||
| auto newReductionForOp = rewriter.create<scf::ForOp>( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would guess |
||
| reductionForOp.getLoc(), reductionForOp.getLowerBound(), | ||
| reductionForOp.getUpperBound(), reductionForOp.getStep(), | ||
| ValueRange{vectorReadOpValue}, | ||
| [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In case of such nested lambdas, can we please pull out these lambdas into a name-bound lambda or, equivalently, into a method? This is to reduce the rightward drift. |
||
| Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) { | ||
| auto newKForOp = rewriter.create<scf::ForOp>( | ||
| kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), | ||
| kForOp.getStep(), iterArgsNewReductionForOp, | ||
| [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp, | ||
| Value ivNewKForOp, ValueRange iterArgsNewKForOp) { | ||
| IRMapping mapper; | ||
| mapper.map(reductionForOp.getInductionVar(), | ||
| ivNewReductionForOp); | ||
| mapper.map(kForOp.getInductionVar(), ivNewKForOp); | ||
|
|
||
| for (auto &op : kForOp.getBody()->without_terminator()) { | ||
| rewriterNewKForOp.clone(op, mapper); | ||
| } | ||
| rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, | ||
| iterArgsNewKForOp); | ||
| }); | ||
| rewriterNewReductionForOp.create<scf::YieldOp>( | ||
| locNewReductionForOp, newKForOp.getResult(0)); | ||
| }); | ||
|
|
||
| // Code to hoist vector transfer write after reduction loop and also to | ||
| // update the yield of k loop | ||
| auto newKForOp = | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here as well: I am not sure why this section appears a bit involved. Per the example in the docs, you are just hoisting one transfer_read and "threading" that vector through the iter_args of two loops, right? Doesn't two applications of |
||
| llvm::dyn_cast<scf::ForOp>(newReductionForOp.getBody()->front()); | ||
| Value newcontractOpValue; | ||
| vector::TransferWriteOp vectorWriteOperation; | ||
| Block *bodyBlock = newKForOp.getBody(); | ||
| for (auto &op : bodyBlock->getOperations()) { | ||
| if (auto vectorContractOp = llvm::dyn_cast<vector::ContractionOp>(op)) { | ||
| vectorContractOp.setOperand(vectorContractOp.getNumOperands() - 1, | ||
| newKForOp.getRegionIterArgs()[0]); | ||
| newcontractOpValue = vectorContractOp.getResult(); | ||
| } | ||
| if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) { | ||
| yieldOp.setOperand(0, newcontractOpValue); | ||
| } | ||
| if (auto vectorWriteOp = llvm::dyn_cast<vector::TransferWriteOp>(op)) { | ||
| vectorWriteOperation = vectorWriteOp; | ||
| } | ||
| } | ||
|
|
||
| vectorWriteOperation.setOperand(0, newReductionForOp.getResult(0)); | ||
| vectorWriteOperation->moveBefore(reductionForOp); | ||
|
|
||
| // Erase the old vector contract operation | ||
| for (auto result : contractOp->getResults()) { | ||
| for (auto *userOp : result.getUsers()) { | ||
| userOp->erase(); | ||
| } | ||
| } | ||
| contractOp.erase(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try to make use of the |
||
|
|
||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| void linalg::populateHoistVectorTransferPatterns(RewritePatternSet &patterns) { | ||
| patterns.add<HoistVectorTransferOp>(patterns.getContext()); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: double space typo