From 84130f65a351a731f6ba153f5b147db59d7ec09c Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Thu, 2 Jan 2025 07:27:29 -0800 Subject: [PATCH 1/3] =?UTF-8?q?initial=20changes=20for=20upstreaming=20hoi?= =?UTF-8?q?st=20vector=20transfers=C2=A0and=20contract=20to=20fma?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Linalg/TransformOps/LinalgTransformOps.td | 24 ++ .../Dialect/Linalg/Transforms/Transforms.h | 10 + .../TransformOps/LinalgTransformOps.cpp | 13 + .../Dialect/Linalg/Transforms/CMakeLists.txt | 2 + .../Transforms/HoistVectorTransfers.cpp | 234 ++++++++++++ .../Linalg/Transforms/VectorContractToFMA.cpp | 356 ++++++++++++++++++ .../Dialect/Linalg/hoist-vector-transfer.mlir | 97 +++++ .../Linalg/vector-contract-to-fma.mlir | 113 ++++++ 8 files changed, 849 insertions(+) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp create mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp create mode 100644 mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir create mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 2e713bca24efc..d614bb4789767 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -106,6 +106,30 @@ def ApplyFoldAddIntoDestPatternsOp : Op]> { + let description = [{ + Hoists the vector transfer reads/writes outside the reduction and k-loop. + }]; + + let assemblyFormat = "attr-dict"; +} + + +def ApplyVectorContractToFMAPatternsOp : Op]> { + let description = [{ + Implements the lowering of vector contraction op for GEMM of size MxN to + sequence of vector FMAs wrapped inside scf.for loop with iterargs to + accumulate the result of FMAs. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyPadVectorizationPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 1dc700f22c202..d35f99826004e 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1824,6 +1824,16 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, /// suffices for achieving the sum. void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns); + +/// Pattern to hoists the vector transfer reads/writes outside the reduction and +/// k-loop. +void populateHoistVectorTransferPatterns(RewritePatternSet &patterns); + + +/// Pattern to lower vector contraction op for GEMM of size MxN to +/// sequence of vector FMAs +void populateVectorContractToFMAPatterns(RewritePatternSet &patterns); + /// Pattern to fuse a `tensor.pad` operation with the producer of its source, /// if the producer is a `linalg` operation with all parallel iterator types. void populateFuseTensorPadWithProducerLinalgOpPatterns( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 221ca27b80fdd..849632aed8a13 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -262,6 +262,19 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns( linalg::populateFoldAddIntoDestPatterns(patterns); } + +void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateHoistVectorTransferPatterns(patterns); +} + + +void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateVectorContractToFMAPatterns(patterns); +} + + void transform::ApplyPadVectorizationPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::populatePadOpVectorizationPatterns(patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 3594b08413812..90d926201cd75 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -41,6 +41,8 @@ add_mlir_dialect_library(MLIRLinalgTransforms DecomposeGenericByUnfoldingPermutation.cpp Vectorization.cpp WinogradConv2D.cpp + HoistVectorTransfers.cpp + VectorContractToFMA.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp new file mode 100644 index 0000000000000..5911d36cbafef --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp @@ -0,0 +1,234 @@ +//===-HoistVectorTransfers.cpp -----------------------------------------*- +// C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements tile configuration hoisting on parallel loops. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +static FailureOr> +getContractOperands(vector::ContractionOp contractOp) { + SmallVector list; + for (int i = 0; i < 3; i++) { + auto vectorReadOp = + contractOp.getOperand(i).getDefiningOp(); + if (!vectorReadOp) + return failure(); + list.push_back(vectorReadOp); + } + return list; +} + +static FailureOr> +getReadOperands(SmallVector readOps) { + SmallVector list; + for (vector::TransferReadOp readOp : readOps) { + auto subViewOp = readOp.getOperand(0).getDefiningOp(); + if (!subViewOp) + return failure(); + list.push_back(subViewOp); + } + return list; +} + +static FailureOr> +getNestedLoop(vector::ContractionOp contractOp) { + SmallVector list; + Operation *current = contractOp; + for (int i = 0; i < 4; i++) { + Operation *parent = current->getParentOfType(); + if (!parent) + return failure(); + list.push_back(dyn_cast(parent)); + current = parent; + } + return list; +} + +static LogicalResult checkNestedLoop(SmallVector loops, + SmallVector subviews) { + auto subviewOpLhsOffsets = subviews[0].getOffsets(); + auto subviewOpRhsOffsets = subviews[1].getOffsets(); + auto subviewOpAccOffsets = subviews[2].getOffsets(); + + Value ivK = loops[0].getInductionVar(); + if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1]) + return failure(); + + Value ivReduction = loops[1].getInductionVar(); + if (ivReduction != subviewOpLhsOffsets[0] || + ivReduction != subviewOpRhsOffsets[0]) + return failure(); + + Value ivN = loops[2].getInductionVar(); + if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2]) + return failure(); + + Value ivM = loops[3].getInductionVar(); + if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0]) + return failure(); + + return success(); +} + +struct HoistVectorTransferOp : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + // Check the vector contract operation satisfies the required pattern. + // Check the Acc, Lhs, and Rhs of contract operation + + auto operands = getContractOperands(contractOp); + if (failed(operands)) + return rewriter.notifyMatchFailure(contractOp, + "Invalid operands for contract op"); + + auto readOps = *operands; + auto vectorReadOpAcc = readOps[2]; + auto vectorReadOpLhs = readOps[0]; + auto vectorReadOpRhs = readOps[1]; + + // Check whether the operand of vector transfer read is a subview + auto subviews = getReadOperands(readOps); + if (failed(subviews)) + return rewriter.notifyMatchFailure( + contractOp, "Vector read op operands are not a subview"); + + // Check the operation type MatMul, B-MatMul, or BR-MatMul + SmallVector contractIteratorTypes = + contractOp.getIteratorTypesArray(); + int reductionCount = + std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(), + vector::IteratorType::reduction); + + auto vectorReadOpLhsType = cast(vectorReadOpLhs.getType()); + auto vectorReadOpRhsRank = + (cast(vectorReadOpRhs.getType())).getRank(); + + if (reductionCount == 2 && + (vectorReadOpLhsType.getRank() != 3 || vectorReadOpRhsRank != 3)) + return rewriter.notifyMatchFailure( + contractOp, "Invalid rank for batch reduce operation"); + + if (reductionCount == 1) + return rewriter.notifyMatchFailure( + contractOp, "Batch matmul operation not supported yet"); + + if (reductionCount > 2) + return rewriter.notifyMatchFailure( + contractOp, "The vector contract operation is not a gemm"); + + // Check the K-dim to be 1 + int64_t K = + vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1); + if (K != 1) + return rewriter.notifyMatchFailure(contractOp, "K dim is not 1"); + + // Check whether the linalg tiling + vector contract pattern matches for the + // 4-nested loop structure + auto loops = getNestedLoop(contractOp); + if (failed(loops)) + return rewriter.notifyMatchFailure( + contractOp, "Invalid loop nest in contract pattern"); + + auto checkLoops = checkNestedLoop(*loops, *subviews); + if (failed(checkLoops)) + return rewriter.notifyMatchFailure( + contractOp, "Loops doesn't match the iv in subviews"); + + auto nestedLoops = *loops; + auto kForOp = nestedLoops[0]; + auto reductionForOp = nestedLoops[1]; + + // Move the vector transfer read before the reduction and k loop + rewriter.setInsertionPoint(reductionForOp); + auto *cloneVectorReadOp = rewriter.clone(*vectorReadOpAcc); + + // Code to re-create the reduction and k loop with iter args + auto vectorReadOpValue = cloneVectorReadOp->getResult(0); + auto newReductionForOp = rewriter.create( + reductionForOp.getLoc(), reductionForOp.getLowerBound(), + reductionForOp.getUpperBound(), reductionForOp.getStep(), + ValueRange{vectorReadOpValue}, + [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp, + Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) { + auto newKForOp = rewriter.create( + kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), + kForOp.getStep(), iterArgsNewReductionForOp, + [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp, + Value ivNewKForOp, ValueRange iterArgsNewKForOp) { + IRMapping mapper; + mapper.map(reductionForOp.getInductionVar(), + ivNewReductionForOp); + mapper.map(kForOp.getInductionVar(), ivNewKForOp); + + for (auto &op : kForOp.getBody()->without_terminator()) { + rewriterNewKForOp.clone(op, mapper); + } + rewriterNewKForOp.create(locNewKForOp, + iterArgsNewKForOp); + }); + rewriterNewReductionForOp.create( + locNewReductionForOp, newKForOp.getResult(0)); + }); + + // Code to hoist vector transfer write after reduction loop and also to + // update the yield of k loop + auto newKForOp = + llvm::dyn_cast(newReductionForOp.getBody()->front()); + Value newcontractOpValue; + vector::TransferWriteOp vectorWriteOperation; + Block *bodyBlock = newKForOp.getBody(); + for (auto &op : bodyBlock->getOperations()) { + if (auto vectorContractOp = llvm::dyn_cast(op)) { + vectorContractOp.setOperand(vectorContractOp.getNumOperands() - 1, + newKForOp.getRegionIterArgs()[0]); + newcontractOpValue = vectorContractOp.getResult(); + } + if (auto yieldOp = llvm::dyn_cast(op)) { + yieldOp.setOperand(0, newcontractOpValue); + } + if (auto vectorWriteOp = llvm::dyn_cast(op)) { + vectorWriteOperation = vectorWriteOp; + } + } + + vectorWriteOperation.setOperand(0, newReductionForOp.getResult(0)); + vectorWriteOperation->moveBefore(reductionForOp); + + // Erase the old vector contract operation + for (auto result : contractOp->getResults()) { + for (auto *userOp : result.getUsers()) { + userOp->erase(); + } + } + contractOp.erase(); + + return success(); + } +}; + +void linalg::populateHoistVectorTransferPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp new file mode 100644 index 0000000000000..2a8132b93bdcb --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp @@ -0,0 +1,356 @@ + +//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering of vector contraction to vector fma. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "vector-contract-to-fma" + +using namespace mlir; + +/// Returns true if the \p map is transposed. +static bool isTransposed(AffineMap map) { + auto results = map.getResults(); + // Assert if the map does not have 3 or 4 inputs ([] m, n, k). + assert((map.getNumInputs() == 3 || map.getNumInputs() == 4) && + "3 or 4 input dim expected"); + // Assert if the result is not 2D. + assert(map.getNumResults() == 2 && "Only 2 output dim expected"); + + // Check the last two dimensions for transposition. + auto dimExpr0 = dyn_cast(results[0]); + auto dimExpr1 = dyn_cast(results[1]); + assert((dimExpr0 && dimExpr1) && "Unexpected dim expression"); + + // Exclude output map result. + bool isOutputResultMap = + dimExpr0 == + mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext()) && + dimExpr1 == + mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()); + assert(!isOutputResultMap && "Output result map not expected"); + + // It's transposed if result found as (k, m) or (n, k), else not transposed. + if ((dimExpr0 == + mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()) && + dimExpr1 == + mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext())) || + (dimExpr0 == + mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()) && + dimExpr1 == + mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()))) + return true; + return false; +} + + +// Structure to hold transformation context +struct TransformationContext { + scf::ForOp innerForOp; + scf::ForOp outerForOp; + scf::ForOp outermostLoop; +}; + +enum class MatMulType { Standard, Batch, BatchReduce }; + +struct VectorContractToFMA + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + if (op.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure( + op, "Unsupported combining kind, only supports ADD at the moment)"); + + auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + return rewriter.notifyMatchFailure(op, "Masked contractOp not supported"); + + SmallVector maps = op.getIndexingMapsArray(); + if (llvm::any_of( + maps, [](AffineMap map) { return !map.isProjectedPermutation(); })) + return rewriter.notifyMatchFailure(op, "Unexpected map"); + + // Check for the variant of matrix multiply. + auto iteratorTypes = op.getIteratorTypesArray(); + MatMulType matmulType; + unsigned outerDimIndex = 0; + if (iteratorTypes.size() > 3) { + outerDimIndex = iteratorTypes.size() - 4; + matmulType = + iteratorTypes[outerDimIndex] == vector::IteratorType::parallel + ? MatMulType::Batch + : MatMulType::BatchReduce; + outerDimIndex++; + } else if (iteratorTypes.size() == 3) { + matmulType = MatMulType::Standard; + } else { + return rewriter.notifyMatchFailure(op, "Not a gemm"); + } + + if (matmulType == MatMulType::Batch) + return rewriter.notifyMatchFailure(op, "Batch matmul not supported"); + if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel || + iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel || + iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction) + return rewriter.notifyMatchFailure(op, "Not a gemm"); + + SmallVector results; + + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + auto acc = op.getAcc(); + auto lhsDefiningOp = lhs.getDefiningOp(); + auto rhsDefiningOp = rhs.getDefiningOp(); + auto accDefiningOp = acc.getDefiningOp(); + if (!lhsDefiningOp || !rhsDefiningOp) + return failure(); + + // Accumulator can be a TransferReadOp but must be coming from the chain of + // iterargs of nested loop. + if (accDefiningOp) + return failure(); + + // Make sure the inputs being read are whole tensor or subview. + if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) || + !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) { + return failure(); + } + + auto lhsType = cast(lhsDefiningOp.getType()); + auto rhsType = cast(rhsDefiningOp.getType()); + // auto accType = acc.getType(); + // auto accType = cast(accDefiningOp.getType()); + + if (matmulType == MatMulType::BatchReduce && + (lhsType.getRank() != 3 || rhsType.getRank() != 3)) + return failure(); + + if (matmulType == MatMulType::Standard && + (lhsType.getRank() != 2 || rhsType.getRank() != 2)) + return failure(); + + // Check for non-transposed matrices. + auto mapLHS = maps[0]; + auto mapRHS = maps[1]; + if (matmulType == MatMulType::BatchReduce) { + mapLHS = mapLHS.dropResult(0); + mapRHS = mapRHS.dropResult(0); + } + if (isTransposed(mapLHS) || isTransposed(mapRHS)) + return rewriter.notifyMatchFailure( + op, "Transposed matrices are not expected"); + + // Verify that the accumulator is coming through a chain of iterargs of + // nested loop and it is define by 'TransferReadOp'. + // + struct TransformationContext ctx; + + ctx.innerForOp = op->getParentOfType(); + if (!ctx.innerForOp) + return failure(); + ctx.outerForOp = ctx.innerForOp->getParentOfType(); + if (!ctx.outerForOp) + return failure(); + ctx.outermostLoop = ctx.outerForOp->getParentOfType(); + if (!ctx.outermostLoop) + return failure(); + + // Verify original inner loop has only one iterarg. + auto origIterArgs = ctx.innerForOp.getRegionIterArgs(); + if (origIterArgs.size() != 1) + return failure(); + + // Verify chain, accumulator must be inner loop's iterarg. + auto bbArg = dyn_cast(acc); + if (!bbArg) + return failure(); + + // This block arg must be init arg, not induction variable. + if (bbArg.getOwner() != ctx.innerForOp.getBody() || + bbArg.getArgNumber() == 0) { + return failure(); + } + + // This iterarg must be intialized by outer loop's iterarg. + auto innerInitValue = + ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1]; + auto outerBBArg = dyn_cast(innerInitValue); + if (!outerBBArg) + return failure(); + + // This block arg must be init arg, not induction variable. + if (outerBBArg.getOwner() != ctx.outerForOp.getBody() || + outerBBArg.getArgNumber() == 0) { + return failure(); + } + + // Outer loop's iterarg initializer must be a TransferReadOp. + acc = ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1]; + + // This must be defined by vector.transfer_read + if (!acc.getDefiningOp()) + return failure(); + + accDefiningOp = acc.getDefiningOp(); + if (!accDefiningOp) + return failure(); + + // Only 2-D output expected. + auto accType = cast(accDefiningOp.getType()); + if (accType.getRank() != 2) + return failure(); + + int64_t M = accType.getDimSize(0); + int64_t N = accType.getDimSize(1); + int64_t K = lhsType.getDimSize(lhsType.getRank() - 1); + + // K must be 1. + if (K != 1) + return failure(); + + auto accSubview = accDefiningOp.getSource(); + Location loc = op.getLoc(); + + // Create M different <1xN> subviews. + auto memrefType = cast(accSubview.getType()); + auto elementType = memrefType.getElementType(); + SmallVector mixedSizes = {rewriter.getIndexAttr(K), + rewriter.getIndexAttr(N)}; + SmallVector mixedStrides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + + rewriter.setInsertionPoint( + ctx.outermostLoop.getBody(), + std::prev(ctx.outermostLoop.getBody()->end(), 1)); + + Value c0 = rewriter.create(loc, 0); + SmallVector subview_2_splits; + for (int i = 0; i < M; i++) { + SmallVector mixedOffsets = { + rewriter.getIndexAttr(i), + rewriter.getIndexAttr(0), + }; + auto split = rewriter.create( + loc, accSubview, mixedOffsets, mixedSizes, mixedStrides); + subview_2_splits.push_back(split); + } + + // Intialize each accumulator with a vector of size N + SmallVector initAccs; + for (auto subview : subview_2_splits) { + auto acc = rewriter.create( + loc, VectorType::get({N}, elementType), subview, ValueRange{c0, c0}); + initAccs.push_back(acc); + } + + // Create new outer loop with M different accumulators. + auto newOuterForOp = rewriter.create( + loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(), + ctx.outerForOp.getStep(), initAccs, + [&](OpBuilder &nestedBuilder, Location loc, Value iv, + ValueRange iterArgs) { + // Create new inner loop with M accumulators. + auto newInnerForOp = nestedBuilder.create( + loc, ctx.innerForOp.getLowerBound(), + ctx.innerForOp.getUpperBound(), ctx.innerForOp.getStep(), + iterArgs, + [&](OpBuilder &innerBuilder, Location loc, Value innerIv, + ValueRange innerIterArgs) { + IRMapping mapping; + mapping.map( + lhsDefiningOp.getSource().getDefiningOp()->getOperand(1), + iv); + mapping.map( + lhsDefiningOp.getSource().getDefiningOp()->getOperand(3), + innerIv); + auto lhsClone = innerBuilder.clone( + *lhsDefiningOp.getSource().getDefiningOp(), mapping); + + // Load and broadcast individual elements + SmallVector broadcasts; + for (int i = 0; i < M; i++) { + auto elem = innerBuilder.create( + loc, lhsClone->getResult(0), + ValueRange{ + c0, + innerBuilder.create(loc, i), + c0}); + auto bcast = innerBuilder.create( + loc, VectorType::get({N}, elem.getType()), elem); + broadcasts.push_back(bcast); + } + + IRMapping rhsMapping; + rhsMapping.map( + rhsDefiningOp.getSource().getDefiningOp()->getOperand(1), + iv); + rhsMapping.map( + rhsDefiningOp.getSource().getDefiningOp()->getOperand(2), + innerIv); + auto rhsClone = innerBuilder.clone( + *rhsDefiningOp.getSource().getDefiningOp(), rhsMapping); + auto rowVec = innerBuilder.create( + loc, VectorType::get({N}, elementType), + rhsClone->getResult(0), ValueRange{c0, c0, c0}); + + // Create M different FMAs using broadcasts and current + // accumulator values. + for (int i = 0; i < M; i++) { + auto fma = innerBuilder.create( + loc, broadcasts[i], rowVec, innerIterArgs[i]); + results.push_back(fma); + } + + // Yield all M results + innerBuilder.create(loc, results); + }); + + // Yield results from inner loop to outer loop + nestedBuilder.create(loc, newInnerForOp.getResults()); + }); + + Value matResult = ctx.outerForOp.getResult(0); + Operation *writeOp; + for (auto user : matResult.getUsers()) { + writeOp = dyn_cast(user); + if (writeOp) + break; + } + + // Store final results back to original locations. + if (writeOp) { + for (int i = 0; i < M; i++) { + rewriter.create(loc, newOuterForOp.getResult(i), + subview_2_splits[i], + ValueRange{c0, c0}); + } + } + + // Erase original write. + if (writeOp) + rewriter.eraseOp(writeOp); + + return success(); + } + +}; + +void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir new file mode 100644 index 0000000000000..f1f24e4e53cb6 --- /dev/null +++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir @@ -0,0 +1,97 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} + func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> + %c1 = arith.constant 1 : index + %c24 = arith.constant 24 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> + scf.forall (%arg1, %arg2) in (8, 24) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> + vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> + %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> + scf.for %arg3 = %c0 to %c32 step %c4 { + scf.for %arg4 = %c0 to %c64 step %c64 { + %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> + scf.for %arg5 = %c0 to %c24 step %c1 { + scf.for %arg6 = %c0 to %c64 step %c1 { + %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> + %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> + %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32> + %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32> + %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> + %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32> + vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> + } + } + } + } + } + return %alloc : memref<8x24x32x64xf32> + } + + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} + +// CHECK-LABEL: func.func @simple_gemm( +// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 32 : index +// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> +// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> +// CHECK: scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (8, 24) { +// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> +// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_11]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_5]] { +// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> +// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<4x64xf32>) { +// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<4x64xf32>) { +// CHECK: %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: %[[VAL_26:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> +// CHECK: %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32> +// CHECK: %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32> +// CHECK: %[[VAL_29:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32> +// CHECK: scf.yield %[[VAL_29]] : vector<4x64xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_22]] : vector<4x64xf32> +// CHECK: } +// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return %[[VAL_10]] : memref<8x24x32x64xf32> +// CHECK: } + + + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.hoist_vector_transfer + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir new file mode 100644 index 0000000000000..ba11074e3c963 --- /dev/null +++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir @@ -0,0 +1,113 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} + func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> + %c1 = arith.constant 1 : index + %c24 = arith.constant 24 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> + scf.forall (%arg1, %arg2) in (8, 24) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> + vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> + %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> + scf.for %arg3 = %c0 to %c32 step %c4 { + scf.for %arg4 = %c0 to %c64 step %c64 { + %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> + %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> + %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) { + %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) { + %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> + %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> + %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32> + %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32> + %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32> + scf.yield %6 : vector<4x64xf32> + } + scf.yield %3 : vector<4x64xf32> + } + vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> + } + } + } + return %alloc : memref<8x24x32x64xf32> + } + +// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} + +// CHECK-LABEL: func.func @simple_gemm( +// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32> +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_8:.*]] = arith.constant 32 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> +// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> +// CHECK: scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) { +// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> +// CHECK: vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] { +// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] { +// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) { +// CHECK: %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) { +// CHECK: %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32> +// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32> +// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32> +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> +// CHECK: %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32> +// CHECK: %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> +// CHECK: %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32> +// CHECK: %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32> +// CHECK: %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32> +// CHECK: %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32> +// CHECK: %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32> +// CHECK: scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32> +// CHECK: } +// CHECK: vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return %[[VAL_11]] : memref<8x24x32x64xf32> +// CHECK: } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.vector.contract_to_fma + } : !transform.any_op + transform.yield + } + } From 83ed5c4db59730cc0b3bd38a0e5c2551904992b1 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Mon, 6 Jan 2025 02:21:34 -0800 Subject: [PATCH 2/3] added few more test-cases and code re-factoring --- .../Linalg/TransformOps/LinalgTransformOps.td | 8 +- .../Dialect/Linalg/Transforms/Transforms.h | 9 +- .../Transforms/HoistVectorTransfers.cpp | 49 ++++++-- .../Linalg/Transforms/VectorContractToFMA.cpp | 56 ++++++++- .../Dialect/Linalg/hoist-vector-transfer.mlir | 93 ++++++++++++++- .../Linalg/vector-contract-to-fma.mlir | 107 ++++-------------- 6 files changed, 216 insertions(+), 106 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index d614bb4789767..9acaf1ba231a0 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -111,7 +111,8 @@ def ApplyHoistVectorTransferPatternsOp : Op]> { let description = [{ - Hoists the vector transfer reads/writes outside the reduction and k-loop. + Finds pattern to hoist the possible vector transfer reads/writes outside the reduction and k-loop + for a batch reduce matmul operation. }]; let assemblyFormat = "attr-dict"; @@ -122,9 +123,8 @@ def ApplyVectorContractToFMAPatternsOp : Op]> { let description = [{ - Implements the lowering of vector contraction op for GEMM of size MxN to - sequence of vector FMAs wrapped inside scf.for loop with iterargs to - accumulate the result of FMAs. + Collects pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to + sequence of vector FMAs. }]; let assemblyFormat = "attr-dict"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index d35f99826004e..6f639b45408d8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1825,13 +1825,12 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns); -/// Pattern to hoists the vector transfer reads/writes outside the reduction and -/// k-loop. +/// Pattern to hoists the vector transfer reads/writes outside the reduction and +/// k-loop for batch reduce matmul operation if licm fails. void populateHoistVectorTransferPatterns(RewritePatternSet &patterns); - -/// Pattern to lower vector contraction op for GEMM of size MxN to -/// sequence of vector FMAs +/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to +/// sequence of vector FMAs. void populateVectorContractToFMAPatterns(RewritePatternSet &patterns); /// Pattern to fuse a `tensor.pad` operation with the producer of its source, diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp index 5911d36cbafef..1e741010c741e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp @@ -1,15 +1,11 @@ -//===-HoistVectorTransfers.cpp -----------------------------------------*- -// C++-*-===// +//===- HoistVectorTransfers.cpp ---------------------------------------*- C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This file implements tile configuration hoisting on parallel loops. -// -//===----------------------------------------------------------------------===// + #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -25,6 +21,7 @@ using namespace mlir; +// Function to retrives vector transfer read operations (Acc, Lhs, and Rhs) from contraction operation. static FailureOr> getContractOperands(vector::ContractionOp contractOp) { SmallVector list; @@ -38,6 +35,7 @@ getContractOperands(vector::ContractionOp contractOp) { return list; } +// Function to retrive subview from vector transfer read operation. static FailureOr> getReadOperands(SmallVector readOps) { SmallVector list; @@ -50,6 +48,7 @@ getReadOperands(SmallVector readOps) { return list; } +// Function to retrive the tiled nested loop structure (m->n->reduction->k) for the contract operation static FailureOr> getNestedLoop(vector::ContractionOp contractOp) { SmallVector list; @@ -64,6 +63,7 @@ getNestedLoop(vector::ContractionOp contractOp) { return list; } +// Function to check iv of nested loops matches with the subview static LogicalResult checkNestedLoop(SmallVector loops, SmallVector subviews) { auto subviewOpLhsOffsets = subviews[0].getOffsets(); @@ -90,6 +90,40 @@ static LogicalResult checkNestedLoop(SmallVector loops, return success(); } +/// Hoist vector transfer read and write operations for the tiled batch reduce matmul operation +/// outside the reduction and k-loop. +/// +/// As an example, the following pseudo-code will be rewritten +/// scf.for %arg3 = %c0 to %c32 step %c4 // m-loop +/// scf.for %arg4 = %c0 to %c64 step %c64 // n-loop +/// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] +/// scf.for %arg5 = %c0 to %c24 step %c1 // reduction-loop +/// scf.for %arg6 = %c0 to %c64 step %c1 // k-loop +/// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] +/// %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] +/// %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} +/// %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} +/// %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} +/// %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 +/// vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} +/// to: +/// scf.for %arg3 = %c0 to %c32 step %c4 +/// scf.for %arg4 = %c0 to %c64 step %c64 +/// %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] +/// %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} +/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) { +/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) { +/// %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] +/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] +/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} +/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} +/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 +/// scf.yield %6 : !type +/// } +/// scf.yield %3 : !type +/// } +/// vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} +/// struct HoistVectorTransferOp : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -98,7 +132,6 @@ struct HoistVectorTransferOp : OpRewritePattern { // Check the vector contract operation satisfies the required pattern. // Check the Acc, Lhs, and Rhs of contract operation - auto operands = getContractOperands(contractOp); if (failed(operands)) return rewriter.notifyMatchFailure(contractOp, @@ -145,7 +178,7 @@ struct HoistVectorTransferOp : OpRewritePattern { if (K != 1) return rewriter.notifyMatchFailure(contractOp, "K dim is not 1"); - // Check whether the linalg tiling + vector contract pattern matches for the + // Check whether the BR-matmul tiling + vector contract pattern matches for the // 4-nested loop structure auto loops = getNestedLoop(contractOp); if (failed(loops)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp index 2a8132b93bdcb..4d3dac6a2b4d0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp @@ -1,4 +1,3 @@ - //===--------------- VectorContractToFMA.cpp ------------*- C++-*-===// // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -67,6 +66,61 @@ struct TransformationContext { enum class MatMulType { Standard, Batch, BatchReduce }; + +/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to +/// sequence of vector FMAs. +/// +/// As an example, the following pseudo-code will be rewritten +/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] +/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]} +/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) { +/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) { +/// %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] +/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] +/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} +/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} +/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 +/// scf.yield %6 : !type +/// } +/// scf.yield %3 : !type +/// } +/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]} +/// to: +/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] +/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1] +/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1] +/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1] +/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1] +/// %1 = vector.load %subview_2[%c0, %c0] +/// %2 = vector.load %subview_3[%c0, %c0] +/// %3 = vector.load %subview_4[%c0, %c0] +/// %4 = vector.load %subview_5[%c0, %c0] +/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) { +/// %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) { +/// %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1] +/// %7 = memref.load %subview_6[%c0, %c0, %c0] +/// %8 = vector.broadcast %7 : f32 to !type +/// %9 = memref.load %subview_6[%c0, %c1, %c0] +/// %10 = vector.broadcast %9 : f32 to !type +/// %11 = memref.load %subview_6[%c0, %c2, %c0] +/// %12 = vector.broadcast %11 : f32 to !type +/// %13 = memref.load %subview_6[%c0, %c3, %c0] +/// %14 = vector.broadcast %13 : f32 to !type +/// %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1] +/// %15 = vector.load %subview_7[%c0, %c0, %c0] +/// %16 = vector.fma %8, %15, %arg11 : !type +/// %17 = vector.fma %10, %15, %arg12 : !type +/// %18 = vector.fma %12, %15, %arg13 : !type +/// %19 = vector.fma %14, %15, %arg14 : !type +/// scf.yield %16, %17, %18, %19 : !type, !type, !type, !type +/// } +/// scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type +/// } +/// vector.store %5#0, %subview_2[%c0, %c0] +/// vector.store %5#1, %subview_3[%c0, %c0] +/// vector.store %5#2, %subview_4[%c0, %c0] +/// vector.store %5#3, %subview_5[%c0, %c0]) +/// struct VectorContractToFMA : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir index f1f24e4e53cb6..3b57f159108ea 100644 --- a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir +++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir @@ -4,7 +4,7 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { + func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { %cst = arith.constant 0.000000e+00 : f32 %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> %c1 = arith.constant 1 : index @@ -46,7 +46,7 @@ // CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} -// CHECK-LABEL: func.func @simple_gemm( +// CHECK-LABEL: func.func @tiled_gemm_hoist_vector_transfer_operations( // CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32> @@ -95,3 +95,92 @@ module attributes {transform.with_named_sequence} { transform.yield } } + + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +module { + memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} + func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> + %c0 = arith.constant 0 : index + %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> + scf.forall (%arg1, %arg2) in (8, 24) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> + vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> + %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> + %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32> + %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32> + %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32> + %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32> + vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> + } + return %alloc : memref<8x24x32x64xf32> + } +} + +// CHECK-LABEL: func.func @gemm_without_tiling_so_no_hoisting +// CHECK: memref.subview +// CHECK-NEXT: vector.transfer_write +// CHECK-NEXT: memref.subview +// CHECK-NEXT: vector.transfer_read +// CHECK-NEXT: vector.transfer_read +// CHECK-NEXT: vector.transfer_read +// CHECK-NEXT: vector.contract +// CHECK-NEXT: vector.transfer_write + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.hoist_vector_transfer + } : !transform.any_op + transform.yield + } +} + + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> + +module { + func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32> + %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32> + %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32> + %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32> + %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32> + return %4 : tensor<4x64xf32> + } +} + + +// CHECK-LABEL: func.func @gemm_with_args_so_no_hoisting +// CHECK: vector.transfer_read +// CHECK-NEXT: vector.transfer_read +// CHECK-NEXT: vector.transfer_read +// CHECK-NEXT: vector.contract +// CHECK-NEXT: vector.transfer_write +// CHECK-NEXT: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.hoist_vector_transfer + } : !transform.any_op + transform.yield + } +} + + diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir index ba11074e3c963..f18c0dcb573d7 100644 --- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir +++ b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir @@ -1,106 +1,41 @@ // RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @simple_gemm(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { + + func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) { %cst = arith.constant 0.000000e+00 : f32 - %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> %c1 = arith.constant 1 : index - %c24 = arith.constant 24 : index + %c16 = arith.constant 16 : index %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index %c4 = arith.constant 4 : index %c32 = arith.constant 32 : index %c0 = arith.constant 0 : index - %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> - scf.forall (%arg1, %arg2) in (8, 24) { - %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> - vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> - %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> - scf.for %arg3 = %c0 to %c32 step %c4 { - scf.for %arg4 = %c0 to %c64 step %c64 { - %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> - %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> - %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) { - %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) { - %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> - %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> - %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32> - %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32> - %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32> - scf.yield %6 : vector<4x64xf32> - } + + scf.for %arg5 = %c0 to %c32 step %c4 { + scf.for %arg6 = %c0 to %c128 step %c64 { + %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>> + %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> + %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> { + %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> { + %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>> + %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>> + %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32> + %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32> + %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32> scf.yield %3 : vector<4x64xf32> } - vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> + scf.yield %con1 : vector<4x64xf32> } + vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> } } - return %alloc : memref<8x24x32x64xf32> + return } -// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} - -// CHECK-LABEL: func.func @simple_gemm( -// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { -// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index -// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32> -// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index -// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index -// CHECK: %[[VAL_7:.*]] = arith.constant 4 : index -// CHECK: %[[VAL_8:.*]] = arith.constant 32 : index -// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> -// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> -// CHECK: scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) { -// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> -// CHECK: vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> -// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> -// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] { -// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] { -// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> -// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> -// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> -// CHECK: %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> -// CHECK: %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>> -// CHECK: %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) { -// CHECK: %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) { -// CHECK: %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> -// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> -// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32> -// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> -// CHECK: %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32> -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> -// CHECK: %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32> -// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> -// CHECK: %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32> -// CHECK: %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> -// CHECK: %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32> -// CHECK: %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32> -// CHECK: %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32> -// CHECK: %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32> -// CHECK: %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32> -// CHECK: scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32> -// CHECK: } -// CHECK: scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32> -// CHECK: } -// CHECK: vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32> -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: return %[[VAL_11]] : memref<8x24x32x64xf32> -// CHECK: } +// CHECK-NOT: vector.fma module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { From 8aa56e403e3a30eaae5d3deba6a73072b6f60df7 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Mon, 6 Jan 2025 21:21:38 -0800 Subject: [PATCH 3/3] created a separate PR for vector.contract to fma --- .../Linalg/TransformOps/LinalgTransformOps.td | 13 - .../Dialect/Linalg/Transforms/Transforms.h | 5 - .../TransformOps/LinalgTransformOps.cpp | 8 - .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 - .../Linalg/Transforms/VectorContractToFMA.cpp | 410 ------------------ .../Dialect/Linalg/hoist-vector-transfer.mlir | 171 ++++---- .../Linalg/vector-contract-to-fma.mlir | 48 -- 7 files changed, 78 insertions(+), 578 deletions(-) delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp delete mode 100644 mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 12b3ddb9c7490..6b890272bb6b4 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -106,7 +106,6 @@ def ApplyFoldAddIntoDestPatternsOp : Op]> { @@ -118,18 +117,6 @@ def ApplyHoistVectorTransferPatternsOp : Op]> { - let description = [{ - Collects pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to - sequence of vector FMAs. - }]; - - let assemblyFormat = "attr-dict"; -} - def ApplyPadVectorizationPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 6f639b45408d8..8a06df4fed363 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1824,15 +1824,10 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, /// suffices for achieving the sum. void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns); - /// Pattern to hoists the vector transfer reads/writes outside the reduction and /// k-loop for batch reduce matmul operation if licm fails. void populateHoistVectorTransferPatterns(RewritePatternSet &patterns); -/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to -/// sequence of vector FMAs. -void populateVectorContractToFMAPatterns(RewritePatternSet &patterns); - /// Pattern to fuse a `tensor.pad` operation with the producer of its source, /// if the producer is a `linalg` operation with all parallel iterator types. void populateFuseTensorPadWithProducerLinalgOpPatterns( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index ed838f8476486..61a3db7302d8d 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -262,19 +262,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns( linalg::populateFoldAddIntoDestPatterns(patterns); } - void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::populateHoistVectorTransferPatterns(patterns); } - -void transform::ApplyVectorContractToFMAPatternsOp::populatePatterns( - RewritePatternSet &patterns) { - linalg::populateVectorContractToFMAPatterns(patterns); -} - - void transform::ApplyPadVectorizationPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::populatePadOpVectorizationPatterns(patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 90d926201cd75..63758a654f803 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -42,7 +42,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms Vectorization.cpp WinogradConv2D.cpp HoistVectorTransfers.cpp - VectorContractToFMA.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp deleted file mode 100644 index 4d3dac6a2b4d0..0000000000000 --- a/mlir/lib/Dialect/Linalg/Transforms/VectorContractToFMA.cpp +++ /dev/null @@ -1,410 +0,0 @@ -//===--------------- VectorContractToFMA.cpp ------------*- C++-*-===// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements lowering of vector contraction to vector fma. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#define DEBUG_TYPE "vector-contract-to-fma" - -using namespace mlir; - -/// Returns true if the \p map is transposed. -static bool isTransposed(AffineMap map) { - auto results = map.getResults(); - // Assert if the map does not have 3 or 4 inputs ([] m, n, k). - assert((map.getNumInputs() == 3 || map.getNumInputs() == 4) && - "3 or 4 input dim expected"); - // Assert if the result is not 2D. - assert(map.getNumResults() == 2 && "Only 2 output dim expected"); - - // Check the last two dimensions for transposition. - auto dimExpr0 = dyn_cast(results[0]); - auto dimExpr1 = dyn_cast(results[1]); - assert((dimExpr0 && dimExpr1) && "Unexpected dim expression"); - - // Exclude output map result. - bool isOutputResultMap = - dimExpr0 == - mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext()) && - dimExpr1 == - mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()); - assert(!isOutputResultMap && "Output result map not expected"); - - // It's transposed if result found as (k, m) or (n, k), else not transposed. - if ((dimExpr0 == - mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()) && - dimExpr1 == - mlir::getAffineDimExpr(map.getNumInputs() - 3, map.getContext())) || - (dimExpr0 == - mlir::getAffineDimExpr(map.getNumInputs() - 2, map.getContext()) && - dimExpr1 == - mlir::getAffineDimExpr(map.getNumInputs() - 1, map.getContext()))) - return true; - return false; -} - - -// Structure to hold transformation context -struct TransformationContext { - scf::ForOp innerForOp; - scf::ForOp outerForOp; - scf::ForOp outermostLoop; -}; - -enum class MatMulType { Standard, Batch, BatchReduce }; - - -/// Pattern to lower vector contraction operation (batch reduce matmul) for GEMM of size MxN to -/// sequence of vector FMAs. -/// -/// As an example, the following pseudo-code will be rewritten -/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] -/// %1 = vector.transfer_read %subview_1[%c0, %c0], %cst {in_bounds = [true, true]} -/// %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) { -/// %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) { -/// %subview_3 = memref.subview %subview_0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] -/// %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] -/// %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} -/// %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} -/// %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 -/// scf.yield %6 : !type -/// } -/// scf.yield %3 : !type -/// } -/// vector.transfer_write %2, %subview_1[%c0, %c0] {in_bounds = [true, true]} -/// to: -/// %subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] -/// %subview_2 = memref.subview %subview_1[0, 0] [1, 64] [1, 1] -/// %subview_3 = memref.subview %subview_1[1, 0] [1, 64] [1, 1] -/// %subview_4 = memref.subview %subview_1[2, 0] [1, 64] [1, 1] -/// %subview_5 = memref.subview %subview_1[3, 0] [1, 64] [1, 1] -/// %1 = vector.load %subview_2[%c0, %c0] -/// %2 = vector.load %subview_3[%c0, %c0] -/// %3 = vector.load %subview_4[%c0, %c0] -/// %4 = vector.load %subview_5[%c0, %c0] -/// %5:4 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1, %arg7 = %2, %arg8 = %3, %arg9 = %4) -> (!type, !type, !type, !type) { -/// %6:4 = scf.for %arg10 = %c0 to %c64 step %c1 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!type, !type, !type, !type) { -/// %subview_6 = memref.subview %subview_0[%arg5, %arg3, %arg10] [1, 4, 1] [1, 1, 1] -/// %7 = memref.load %subview_6[%c0, %c0, %c0] -/// %8 = vector.broadcast %7 : f32 to !type -/// %9 = memref.load %subview_6[%c0, %c1, %c0] -/// %10 = vector.broadcast %9 : f32 to !type -/// %11 = memref.load %subview_6[%c0, %c2, %c0] -/// %12 = vector.broadcast %11 : f32 to !type -/// %13 = memref.load %subview_6[%c0, %c3, %c0] -/// %14 = vector.broadcast %13 : f32 to !type -/// %subview_7 = memref.subview %0[%arg5, %arg10, %arg4] [1, 1, 64] [1, 1, 1] -/// %15 = vector.load %subview_7[%c0, %c0, %c0] -/// %16 = vector.fma %8, %15, %arg11 : !type -/// %17 = vector.fma %10, %15, %arg12 : !type -/// %18 = vector.fma %12, %15, %arg13 : !type -/// %19 = vector.fma %14, %15, %arg14 : !type -/// scf.yield %16, %17, %18, %19 : !type, !type, !type, !type -/// } -/// scf.yield %6#0, %6#1, %6#2, %6#3 : !type, !type, !type, !type -/// } -/// vector.store %5#0, %subview_2[%c0, %c0] -/// vector.store %5#1, %subview_3[%c0, %c0] -/// vector.store %5#2, %subview_4[%c0, %c0] -/// vector.store %5#3, %subview_5[%c0, %c0]) -/// -struct VectorContractToFMA - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override { - if (op.getKind() != vector::CombiningKind::ADD) - return rewriter.notifyMatchFailure( - op, "Unsupported combining kind, only supports ADD at the moment)"); - - auto maskableOp = cast(op.getOperation()); - if (maskableOp.isMasked()) - return rewriter.notifyMatchFailure(op, "Masked contractOp not supported"); - - SmallVector maps = op.getIndexingMapsArray(); - if (llvm::any_of( - maps, [](AffineMap map) { return !map.isProjectedPermutation(); })) - return rewriter.notifyMatchFailure(op, "Unexpected map"); - - // Check for the variant of matrix multiply. - auto iteratorTypes = op.getIteratorTypesArray(); - MatMulType matmulType; - unsigned outerDimIndex = 0; - if (iteratorTypes.size() > 3) { - outerDimIndex = iteratorTypes.size() - 4; - matmulType = - iteratorTypes[outerDimIndex] == vector::IteratorType::parallel - ? MatMulType::Batch - : MatMulType::BatchReduce; - outerDimIndex++; - } else if (iteratorTypes.size() == 3) { - matmulType = MatMulType::Standard; - } else { - return rewriter.notifyMatchFailure(op, "Not a gemm"); - } - - if (matmulType == MatMulType::Batch) - return rewriter.notifyMatchFailure(op, "Batch matmul not supported"); - if (iteratorTypes[outerDimIndex] != vector::IteratorType::parallel || - iteratorTypes[outerDimIndex + 1] != vector::IteratorType::parallel || - iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction) - return rewriter.notifyMatchFailure(op, "Not a gemm"); - - SmallVector results; - - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto acc = op.getAcc(); - auto lhsDefiningOp = lhs.getDefiningOp(); - auto rhsDefiningOp = rhs.getDefiningOp(); - auto accDefiningOp = acc.getDefiningOp(); - if (!lhsDefiningOp || !rhsDefiningOp) - return failure(); - - // Accumulator can be a TransferReadOp but must be coming from the chain of - // iterargs of nested loop. - if (accDefiningOp) - return failure(); - - // Make sure the inputs being read are whole tensor or subview. - if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) || - !llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) { - return failure(); - } - - auto lhsType = cast(lhsDefiningOp.getType()); - auto rhsType = cast(rhsDefiningOp.getType()); - // auto accType = acc.getType(); - // auto accType = cast(accDefiningOp.getType()); - - if (matmulType == MatMulType::BatchReduce && - (lhsType.getRank() != 3 || rhsType.getRank() != 3)) - return failure(); - - if (matmulType == MatMulType::Standard && - (lhsType.getRank() != 2 || rhsType.getRank() != 2)) - return failure(); - - // Check for non-transposed matrices. - auto mapLHS = maps[0]; - auto mapRHS = maps[1]; - if (matmulType == MatMulType::BatchReduce) { - mapLHS = mapLHS.dropResult(0); - mapRHS = mapRHS.dropResult(0); - } - if (isTransposed(mapLHS) || isTransposed(mapRHS)) - return rewriter.notifyMatchFailure( - op, "Transposed matrices are not expected"); - - // Verify that the accumulator is coming through a chain of iterargs of - // nested loop and it is define by 'TransferReadOp'. - // - struct TransformationContext ctx; - - ctx.innerForOp = op->getParentOfType(); - if (!ctx.innerForOp) - return failure(); - ctx.outerForOp = ctx.innerForOp->getParentOfType(); - if (!ctx.outerForOp) - return failure(); - ctx.outermostLoop = ctx.outerForOp->getParentOfType(); - if (!ctx.outermostLoop) - return failure(); - - // Verify original inner loop has only one iterarg. - auto origIterArgs = ctx.innerForOp.getRegionIterArgs(); - if (origIterArgs.size() != 1) - return failure(); - - // Verify chain, accumulator must be inner loop's iterarg. - auto bbArg = dyn_cast(acc); - if (!bbArg) - return failure(); - - // This block arg must be init arg, not induction variable. - if (bbArg.getOwner() != ctx.innerForOp.getBody() || - bbArg.getArgNumber() == 0) { - return failure(); - } - - // This iterarg must be intialized by outer loop's iterarg. - auto innerInitValue = - ctx.innerForOp.getInitArgs()[bbArg.getArgNumber() - 1]; - auto outerBBArg = dyn_cast(innerInitValue); - if (!outerBBArg) - return failure(); - - // This block arg must be init arg, not induction variable. - if (outerBBArg.getOwner() != ctx.outerForOp.getBody() || - outerBBArg.getArgNumber() == 0) { - return failure(); - } - - // Outer loop's iterarg initializer must be a TransferReadOp. - acc = ctx.outerForOp.getInitArgs()[outerBBArg.getArgNumber() - 1]; - - // This must be defined by vector.transfer_read - if (!acc.getDefiningOp()) - return failure(); - - accDefiningOp = acc.getDefiningOp(); - if (!accDefiningOp) - return failure(); - - // Only 2-D output expected. - auto accType = cast(accDefiningOp.getType()); - if (accType.getRank() != 2) - return failure(); - - int64_t M = accType.getDimSize(0); - int64_t N = accType.getDimSize(1); - int64_t K = lhsType.getDimSize(lhsType.getRank() - 1); - - // K must be 1. - if (K != 1) - return failure(); - - auto accSubview = accDefiningOp.getSource(); - Location loc = op.getLoc(); - - // Create M different <1xN> subviews. - auto memrefType = cast(accSubview.getType()); - auto elementType = memrefType.getElementType(); - SmallVector mixedSizes = {rewriter.getIndexAttr(K), - rewriter.getIndexAttr(N)}; - SmallVector mixedStrides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - - rewriter.setInsertionPoint( - ctx.outermostLoop.getBody(), - std::prev(ctx.outermostLoop.getBody()->end(), 1)); - - Value c0 = rewriter.create(loc, 0); - SmallVector subview_2_splits; - for (int i = 0; i < M; i++) { - SmallVector mixedOffsets = { - rewriter.getIndexAttr(i), - rewriter.getIndexAttr(0), - }; - auto split = rewriter.create( - loc, accSubview, mixedOffsets, mixedSizes, mixedStrides); - subview_2_splits.push_back(split); - } - - // Intialize each accumulator with a vector of size N - SmallVector initAccs; - for (auto subview : subview_2_splits) { - auto acc = rewriter.create( - loc, VectorType::get({N}, elementType), subview, ValueRange{c0, c0}); - initAccs.push_back(acc); - } - - // Create new outer loop with M different accumulators. - auto newOuterForOp = rewriter.create( - loc, ctx.outerForOp.getLowerBound(), ctx.outerForOp.getUpperBound(), - ctx.outerForOp.getStep(), initAccs, - [&](OpBuilder &nestedBuilder, Location loc, Value iv, - ValueRange iterArgs) { - // Create new inner loop with M accumulators. - auto newInnerForOp = nestedBuilder.create( - loc, ctx.innerForOp.getLowerBound(), - ctx.innerForOp.getUpperBound(), ctx.innerForOp.getStep(), - iterArgs, - [&](OpBuilder &innerBuilder, Location loc, Value innerIv, - ValueRange innerIterArgs) { - IRMapping mapping; - mapping.map( - lhsDefiningOp.getSource().getDefiningOp()->getOperand(1), - iv); - mapping.map( - lhsDefiningOp.getSource().getDefiningOp()->getOperand(3), - innerIv); - auto lhsClone = innerBuilder.clone( - *lhsDefiningOp.getSource().getDefiningOp(), mapping); - - // Load and broadcast individual elements - SmallVector broadcasts; - for (int i = 0; i < M; i++) { - auto elem = innerBuilder.create( - loc, lhsClone->getResult(0), - ValueRange{ - c0, - innerBuilder.create(loc, i), - c0}); - auto bcast = innerBuilder.create( - loc, VectorType::get({N}, elem.getType()), elem); - broadcasts.push_back(bcast); - } - - IRMapping rhsMapping; - rhsMapping.map( - rhsDefiningOp.getSource().getDefiningOp()->getOperand(1), - iv); - rhsMapping.map( - rhsDefiningOp.getSource().getDefiningOp()->getOperand(2), - innerIv); - auto rhsClone = innerBuilder.clone( - *rhsDefiningOp.getSource().getDefiningOp(), rhsMapping); - auto rowVec = innerBuilder.create( - loc, VectorType::get({N}, elementType), - rhsClone->getResult(0), ValueRange{c0, c0, c0}); - - // Create M different FMAs using broadcasts and current - // accumulator values. - for (int i = 0; i < M; i++) { - auto fma = innerBuilder.create( - loc, broadcasts[i], rowVec, innerIterArgs[i]); - results.push_back(fma); - } - - // Yield all M results - innerBuilder.create(loc, results); - }); - - // Yield results from inner loop to outer loop - nestedBuilder.create(loc, newInnerForOp.getResults()); - }); - - Value matResult = ctx.outerForOp.getResult(0); - Operation *writeOp; - for (auto user : matResult.getUsers()) { - writeOp = dyn_cast(user); - if (writeOp) - break; - } - - // Store final results back to original locations. - if (writeOp) { - for (int i = 0; i < M; i++) { - rewriter.create(loc, newOuterForOp.getResult(i), - subview_2_splits[i], - ValueRange{c0, c0}); - } - } - - // Erase original write. - if (writeOp) - rewriter.eraseOp(writeOp); - - return success(); - } - -}; - -void linalg::populateVectorContractToFMAPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir index 3b57f159108ea..b0b164951d4b3 100644 --- a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir +++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir @@ -3,51 +3,48 @@ #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> - %c1 = arith.constant 1 : index - %c24 = arith.constant 24 : index - %c64 = arith.constant 64 : index - %c4 = arith.constant 4 : index - %c32 = arith.constant 32 : index - %c0 = arith.constant 0 : index - %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> - scf.forall (%arg1, %arg2) in (8, 24) { - %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> - vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> - %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> - scf.for %arg3 = %c0 to %c32 step %c4 { - scf.for %arg4 = %c0 to %c64 step %c64 { - %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> - scf.for %arg5 = %c0 to %c24 step %c1 { - scf.for %arg6 = %c0 to %c64 step %c1 { - %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> - %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> - %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32> - %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32> - %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> - %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32> - vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> - } +memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} +func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> + %c1 = arith.constant 1 : index + %c24 = arith.constant 24 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> + scf.forall (%arg1, %arg2) in (8, 24) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> + vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> + %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> + scf.for %arg3 = %c0 to %c32 step %c4 { + scf.for %arg4 = %c0 to %c64 step %c64 { + %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>> + scf.for %arg5 = %c0 to %c24 step %c1 { + scf.for %arg6 = %c0 to %c64 step %c1 { + %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>> + %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>> + %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32> + %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32> + %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> + %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32> + vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> } } } } - return %alloc : memref<8x24x32x64xf32> } - + return %alloc : memref<8x24x32x64xf32> +} // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - // CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} - // CHECK-LABEL: func.func @tiled_gemm_hoist_vector_transfer_operations( -// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { +// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32> // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index @@ -84,44 +81,39 @@ // CHECK: return %[[VAL_10]] : memref<8x24x32x64xf32> // CHECK: } - - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func { - transform.apply_patterns.vector.hoist_vector_transfer - } : !transform.any_op - transform.yield - } + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.hoist_vector_transfer + } : !transform.any_op + transform.yield + } } - // ----- #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> -module { - memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> - %c0 = arith.constant 0 : index - %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> - scf.forall (%arg1, %arg2) in (8, 24) { - %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> - vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> - %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> - %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32> - %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32> - %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32> - %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32> - vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> - } - return %alloc : memref<8x24x32x64xf32> +memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64} +func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32> + %c0 = arith.constant 0 : index + %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32> + scf.forall (%arg1, %arg2) in (8, 24) { + %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> + vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> + %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> + %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32> + %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32> + %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32> + %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32> + vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>> } + return %alloc : memref<8x24x32x64xf32> } // CHECK-LABEL: func.func @gemm_without_tiling_so_no_hoisting @@ -135,36 +127,31 @@ module { // CHECK-NEXT: vector.transfer_write module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func { - transform.apply_patterns.vector.hoist_vector_transfer - } : !transform.any_op - transform.yield - } + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.hoist_vector_transfer + } : !transform.any_op + transform.yield + } } - // ----- #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> - -module { - func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32> - %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32> - %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32> - %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32> - %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32> - return %4 : tensor<4x64xf32> - } +func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32> + %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32> + %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32> + %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32> + %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32> + return %4 : tensor<4x64xf32> } - // CHECK-LABEL: func.func @gemm_with_args_so_no_hoisting // CHECK: vector.transfer_read // CHECK-NEXT: vector.transfer_read @@ -174,13 +161,11 @@ module { // CHECK-NEXT: return module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func { - transform.apply_patterns.vector.hoist_vector_transfer - } : !transform.any_op - transform.yield - } + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.hoist_vector_transfer + } : !transform.any_op + transform.yield + } } - - diff --git a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir b/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir deleted file mode 100644 index f18c0dcb573d7..0000000000000 --- a/mlir/test/Dialect/Linalg/vector-contract-to-fma.mlir +++ /dev/null @@ -1,48 +0,0 @@ -// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s - -#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - - func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c4 = arith.constant 4 : index - %c32 = arith.constant 32 : index - %c0 = arith.constant 0 : index - - scf.for %arg5 = %c0 to %c32 step %c4 { - scf.for %arg6 = %c0 to %c128 step %c64 { - %subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>> - %2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32> - %con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> { - %con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> { - %subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>> - %subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>> - %0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32> - %1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32> - %3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32> - scf.yield %3 : vector<4x64xf32> - } - scf.yield %con1 : vector<4x64xf32> - } - vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>> - } - } - return - } - -// CHECK-NOT: vector.fma - - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %0 { - transform.apply_patterns.vector.contract_to_fma - } : !transform.any_op - transform.yield - } - }