Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.hoist_vector_transfer",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collects patterns to hoist vector transfer reads/writes, where possible,
outside the reduction and k-loop
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: double space typo

for a batch reduce matmul operation.
}];

let assemblyFormat = "attr-dict";
}

def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.pad_vectorization",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,10 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
/// suffices for achieving the sum.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);

/// Pattern to hoist vector transfer reads/writes outside the reduction and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "hoists the" -> "hoist"

/// k-loop for batch reduce matmul operation if licm fails.
void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);

/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
linalg::populateFoldAddIntoDestPatterns(patterns);
}

// Fills `patterns` for this transform op by forwarding to
// linalg::populateHoistVectorTransferPatterns.
void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populateHoistVectorTransferPatterns(patterns);
}

void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populatePadOpVectorizationPatterns(patterns);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
DecomposeGenericByUnfoldingPermutation.cpp
Vectorization.cpp
WinogradConv2D.cpp
HoistVectorTransfers.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
Expand Down
267 changes: 267 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
//===- HoistVectorTransfers.cpp - Hoist vector transfers out of loops -----===//
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: right justification appears off.

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Function to retrieve the vector.transfer_read operations (LHS, RHS, and Acc)
// that define the operands of a vector.contract operation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I believe comments should also keep to the 80-char limit. Also holds for comments further down.

Does clang-format leave this comment line (and the others) as is?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit1: "retrieves" and "(RHS, LHS, Acc)"
nit2: we probably want to write out op names as "vector.transfer_read".

static FailureOr<SmallVector<vector::TransferReadOp>>
getContractOperands(vector::ContractionOp contractOp) {
SmallVector<vector::TransferReadOp> list;
for (int i = 0; i < 3; i++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: https://llvm.org/docs/CodingStandards.html#use-range-based-for-loops-wherever-possible

That is, for (OpOperand operand : contractOp.getOperands()) { etc.

auto vectorReadOp =
contractOp.getOperand(i).getDefiningOp<vector::TransferReadOp>();
if (!vectorReadOp)
return failure();
list.push_back(vectorReadOp);
}
return list;
}

// Function to retrieve the memref.subview that defines the source of each
// vector.transfer_read operation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "retrive" -> "retrieve" (also in other places)

static FailureOr<SmallVector<memref::SubViewOp>>
getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
  SmallVector<memref::SubViewOp> subViews;
  for (vector::TransferReadOp transferRead : readOps) {
    // Operand 0 of the read is the value being read from; it must be produced
    // directly by a memref.subview for the match to succeed.
    Value source = transferRead.getOperand(0);
    if (auto producer = source.getDefiningOp<memref::SubViewOp>()) {
      subViews.push_back(producer);
      continue;
    }
    // Any other producer (or a block argument) aborts the whole lookup.
    return failure();
  }
  // Same order as `readOps`.
  return subViews;
}

// Function to retrieve the tiled nested loop structure (m->n->reduction->k) for
// the contract operation. Loops are returned innermost (k) first.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: missing full stop.

static FailureOr<SmallVector<scf::ForOp>>
getNestedLoop(vector::ContractionOp contractOp) {
  // Number of enclosing scf.for loops that must be found.
  constexpr unsigned kLoopDepth = 4;

  SmallVector<scf::ForOp> enclosingLoops;
  Operation *cursor = contractOp;
  while (enclosingLoops.size() < kLoopDepth) {
    // Walk outwards to the closest enclosing scf.for of the current op.
    auto enclosingFor = cursor->getParentOfType<scf::ForOp>();
    if (!enclosingFor)
      return failure();
    enclosingLoops.push_back(enclosingFor);
    cursor = enclosingFor;
  }
  // The closest loop comes first, the outermost one last.
  return enclosingLoops;
}

// Checks that the induction variables of the nested loops are exactly the
// offsets used by the LHS, RHS and Acc subviews.
//
// `loops` is expected innermost first: k, reduction, n, m. `subviews` is
// expected in the order LHS, RHS, Acc. The offset layout that is matched is:
//   LHS[reduction, m, k], RHS[reduction, k, n], Acc[m, n].
//
// Offsets are queried with getMixedOffsets() so that positions always refer to
// subview dimensions. getOffsets() only lists the dynamic offsets, so a subview
// with any static offset would shift the positional lookups or index past the
// end of the range.
//
// Returns failure if fewer loops/subviews than required are provided, if a
// subview has too few offsets, or if any checked offset is not the expected
// induction variable.
static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
                                     SmallVector<memref::SubViewOp> subviews) {
  if (loops.size() < 4 || subviews.size() < 3)
    return failure();

  SmallVector<OpFoldResult> lhsOffsets = subviews[0].getMixedOffsets();
  SmallVector<OpFoldResult> rhsOffsets = subviews[1].getMixedOffsets();
  SmallVector<OpFoldResult> accOffsets = subviews[2].getMixedOffsets();
  if (lhsOffsets.size() < 3 || rhsOffsets.size() < 3 || accOffsets.size() < 2)
    return failure();

  // A static offset is stored as an attribute; it yields a null Value here and
  // therefore never compares equal to an induction variable.
  auto isIv = [](OpFoldResult offset, Value iv) {
    return llvm::dyn_cast_if_present<Value>(offset) == iv;
  };

  Value ivK = loops[0].getInductionVar();
  if (!isIv(lhsOffsets[2], ivK) || !isIv(rhsOffsets[1], ivK))
    return failure();

  Value ivReduction = loops[1].getInductionVar();
  if (!isIv(lhsOffsets[0], ivReduction) || !isIv(rhsOffsets[0], ivReduction))
    return failure();

  Value ivN = loops[2].getInductionVar();
  if (!isIv(accOffsets[1], ivN) || !isIv(rhsOffsets[2], ivN))
    return failure();

  Value ivM = loops[3].getInductionVar();
  if (!isIv(lhsOffsets[1], ivM) || !isIv(accOffsets[0], ivM))
    return failure();

  return success();
}

/// Hoists the accumulator vector.transfer_read and the result
/// vector.transfer_write of a tiled batch reduce matmul operation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you are matching concrete ops, better to write out the op names, i.e. "vector.transfer_read" and "vector.transfer_write".

/// outside the reduction and k-loop.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you include the matching conditions in the docstring, .e.g. loop nest of four scf.for + ...?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, are the m- and n-loop really needed for this hoisting to be valid? In light of that, can we generalize the pattern?

///
/// As an example, the following pseudo-code will be rewritten
/// scf.for %arg3 = %c0 to %c32 step %c4 // m-loop
/// scf.for %arg4 = %c0 to %c64 step %c64 // n-loop
/// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
/// scf.for %arg5 = %c0 to %c24 step %c1 // reduction-loop
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one too many indents

/// scf.for %arg6 = %c0 to %c64 step %c1 // k-loop
/// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1]
/// %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1]
/// %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
/// %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
/// %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]}
/// %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3
/// vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]}
/// to:
/// scf.for %arg3 = %c0 to %c32 step %c4
/// scf.for %arg4 = %c0 to %c64 step %c64
/// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
/// %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]}
/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
/// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1]
/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
/// scf.yield %6 : !type
/// }
/// scf.yield %3 : !type
/// }
/// vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]}
///
struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {

// Check the vector contract operation satisfies the required pattern.
// Check the Acc, Lhs, and Rhs of contract operation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: missing full stop (in general comments should aim to be (full) sentences).

auto operands = getContractOperands(contractOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either expand the type, or document in a comment the check you are performing on the operands (i.e. you need them to be transfer_reads)

if (failed(operands))
return rewriter.notifyMatchFailure(contractOp,
"Invalid operands for contract op");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about "not all contract op operands are transfer_reads"? Or something else a bit more descriptive.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and elsewhere: match failure messages are not capitalized, i.e. "Invalid" -> "invalid"


auto readOps = *operands;
auto vectorReadOpAcc = readOps[2];
auto vectorReadOpLhs = readOps[0];
auto vectorReadOpRhs = readOps[1];

// Check whether the source of each vector.transfer_read is a memref.subview.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "the" -> "each" and missing full stop

auto subviews = getReadOperands(readOps);
if (failed(subviews))
return rewriter.notifyMatchFailure(
contractOp, "Vector read op operands are not a subview");

// Check the operation type MatMul, B-MatMul, or BR-MatMul
SmallVector<vector::IteratorType> contractIteratorTypes =
contractOp.getIteratorTypesArray();
int reductionCount =
std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
vector::IteratorType::reduction);

auto vectorReadOpLhsType = cast<ShapedType>(vectorReadOpLhs.getType());
auto vectorReadOpRhsRank =
(cast<ShapedType>(vectorReadOpRhs.getType())).getRank();

if (reductionCount == 2 &&
(vectorReadOpLhsType.getRank() != 3 || vectorReadOpRhsRank != 3))
return rewriter.notifyMatchFailure(
contractOp, "Invalid rank for batch reduce operation");

if (reductionCount == 1)
return rewriter.notifyMatchFailure(
contractOp, "Batch matmul operation not supported yet");

if (reductionCount > 2)
return rewriter.notifyMatchFailure(
contractOp, "The vector contract operation is not a gemm");

// Check the K-dim to be 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: not a full sentence

int64_t K =
vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1);
if (K != 1)
return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");

// Check whether the BR-matmul tiling + vector contract pattern matches for the
// 4-nested loop structure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: full stop.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also below a couple times.

auto loops = getNestedLoop(contractOp);
if (failed(loops))
return rewriter.notifyMatchFailure(
contractOp, "Invalid loop nest in contract pattern");

auto checkLoops = checkNestedLoop(*loops, *subviews);
if (failed(checkLoops))
return rewriter.notifyMatchFailure(
contractOp, "Loops doesn't match the iv in subviews");

auto nestedLoops = *loops;
auto kForOp = nestedLoops[0];
auto reductionForOp = nestedLoops[1];

// Move the vector transfer read before the reduction and k loop
rewriter.setInsertionPoint(reductionForOp);
auto *cloneVectorReadOp = rewriter.clone(*vectorReadOpAcc);

// Re-create the reduction and k loops with the hoisted read as iter_args.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: leave out "code to"

auto vectorReadOpValue = cloneVectorReadOp->getResult(0);
auto newReductionForOp = rewriter.create<scf::ForOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would guess LoopLikeOpInterface::replaceWithAdditionalYields is well suited to what you are trying to do here. Please have a look if that would help simplify this section.

reductionForOp.getLoc(), reductionForOp.getLowerBound(),
reductionForOp.getUpperBound(), reductionForOp.getStep(),
ValueRange{vectorReadOpValue},
[&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case of such nested lambdas, can we please pull out these lambdas into a name-bound lambda or, equivalently, into a method? This is to reduce the rightward drift.

Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) {
auto newKForOp = rewriter.create<scf::ForOp>(
kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
kForOp.getStep(), iterArgsNewReductionForOp,
[&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
IRMapping mapper;
mapper.map(reductionForOp.getInductionVar(),
ivNewReductionForOp);
mapper.map(kForOp.getInductionVar(), ivNewKForOp);

for (auto &op : kForOp.getBody()->without_terminator()) {
rewriterNewKForOp.clone(op, mapper);
}
rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp,
iterArgsNewKForOp);
});
rewriterNewReductionForOp.create<scf::YieldOp>(
locNewReductionForOp, newKForOp.getResult(0));
});

// Code to hoist vector transfer write after reduction loop and also to
// update the yield of k loop
auto newKForOp =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here as well: I am not sure why this section appears a bit involved. Per the example in the docs, you are just hoisting one transfer_read and "threading" that vector through the iter_args of two loops, right? Doesn't two applications of LoopLikeOpInterface::replaceWithAdditionalYields suffice for this?

llvm::dyn_cast<scf::ForOp>(newReductionForOp.getBody()->front());
Value newcontractOpValue;
vector::TransferWriteOp vectorWriteOperation;
Block *bodyBlock = newKForOp.getBody();
for (auto &op : bodyBlock->getOperations()) {
if (auto vectorContractOp = llvm::dyn_cast<vector::ContractionOp>(op)) {
vectorContractOp.setOperand(vectorContractOp.getNumOperands() - 1,
newKForOp.getRegionIterArgs()[0]);
newcontractOpValue = vectorContractOp.getResult();
}
if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) {
yieldOp.setOperand(0, newcontractOpValue);
}
if (auto vectorWriteOp = llvm::dyn_cast<vector::TransferWriteOp>(op)) {
vectorWriteOperation = vectorWriteOp;
}
}

vectorWriteOperation.setOperand(0, newReductionForOp.getResult(0));
vectorWriteOperation->moveBefore(reductionForOp);

// Erase the old vector contract operation
for (auto result : contractOp->getResults()) {
for (auto *userOp : result.getUsers()) {
userOp->erase();
}
}
contractOp.erase();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to make use of the replaceAllUsesWith API. Alternatively, modify the relevant operand once the region containing the contract has been transplanted to the loops with the new iter_args (though this approach is not always safe so probably best avoided as well).


return success();
}
};

// Adds the HoistVectorTransferOp rewrite pattern to `patterns`, constructed
// with the context returned by `patterns.getContext()`.
void linalg::populateHoistVectorTransferPatterns(RewritePatternSet &patterns) {
patterns.add<HoistVectorTransferOp>(patterns.getContext());
}
Loading