diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 081bf9b6d3b23..6b890272bb6b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -106,6 +106,17 @@ def ApplyFoldAddIntoDestPatternsOp : Op<Transform_Dialect,
+
+def ApplyHoistVectorTransferPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.hoist_vector_transfer",
+    [DeclareOpInterfaceImplementation<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects a pattern that hoists vector transfer reads/writes, when possible,
+    outside the reduction and k-loop of a tiled batch-reduce matmul operation.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c202..8a06df4fed363 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1824,6 +1824,10 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
 /// suffices for achieving the sum.
 void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
 
+/// Pattern to hoist the vector transfer reads/writes outside the reduction and
+/// k-loop of a batch-reduce matmul operation when LICM fails to do so.
+void populateHoistVectorTransferPatterns(RewritePatternSet &patterns);
+
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a1d619c8cd19d..61a3db7302d8d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -262,6 +262,11 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
   linalg::populateFoldAddIntoDestPatterns(patterns);
 }
 
+void transform::ApplyHoistVectorTransferPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateHoistVectorTransferPatterns(patterns);
+}
+
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b08413812..63758a654f803 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,6 +41,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   DecomposeGenericByUnfoldingPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
+  HoistVectorTransfers.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
new file mode 100644
index 0000000000000..1e741010c741e
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistVectorTransfers.cpp
@@ -0,0 +1,267 @@
+//===- HoistVectorTransfers.cpp - Hoist vector transfers -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+// Retrieves the vector transfer read operations (LHS, RHS, and accumulator)
+// that feed the contraction operation.
+static FailureOr<SmallVector<vector::TransferReadOp>>
+getContractOperands(vector::ContractionOp contractOp) {
+  SmallVector<vector::TransferReadOp> list;
+  for (int i = 0; i < 3; i++) {
+    auto vectorReadOp =
+        contractOp.getOperand(i).getDefiningOp<vector::TransferReadOp>();
+    if (!vectorReadOp)
+      return failure();
+    list.push_back(vectorReadOp);
+  }
+  return list;
+}
+
+// Retrieves the subview that each vector transfer read operation reads from.
+static FailureOr<SmallVector<memref::SubViewOp>>
+getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
+  SmallVector<memref::SubViewOp> list;
+  for (vector::TransferReadOp readOp : readOps) {
+    auto subViewOp = readOp.getOperand(0).getDefiningOp<memref::SubViewOp>();
+    if (!subViewOp)
+      return failure();
+    list.push_back(subViewOp);
+  }
+  return list;
+}
+
+// Retrieves the tiled nested loop structure (m -> n -> reduction -> k) that
+// encloses the contract operation.
+static FailureOr<SmallVector<scf::ForOp>>
+getNestedLoop(vector::ContractionOp contractOp) {
+  SmallVector<scf::ForOp> list;
+  Operation *current = contractOp;
+  for (int i = 0; i < 4; i++) {
+    Operation *parent = current->getParentOfType<scf::ForOp>();
+    if (!parent)
+      return failure();
+    list.push_back(dyn_cast<scf::ForOp>(parent));
+    current = parent;
+  }
+  return list;
+}
+
+// Checks that the induction variables of the nested loops match the offsets
+// of the corresponding subviews.
+static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
+                                     SmallVector<memref::SubViewOp> subviews) {
+  auto subviewOpLhsOffsets = subviews[0].getOffsets();
+  auto subviewOpRhsOffsets = subviews[1].getOffsets();
+  auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+  Value ivK = loops[0].getInductionVar();
+  if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+    return failure();
+
+  Value ivReduction = loops[1].getInductionVar();
+  if (ivReduction != subviewOpLhsOffsets[0] ||
+      ivReduction != subviewOpRhsOffsets[0])
+    return failure();
+
+  Value ivN = loops[2].getInductionVar();
+  if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+    return failure();
+
+  Value ivM = loops[3].getInductionVar();
+  if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+    return failure();
+
+  return success();
+}
+
+/// Hoists the vector transfer read and write operations of a tiled
+/// batch-reduce matmul operation outside the reduction and k-loop.
+///
+/// As an example, the following pseudo-code will be rewritten:
+///   scf.for %arg3 = %c0 to %c32 step %c4 // m-loop
+///     scf.for %arg4 = %c0 to %c64 step %c64 // n-loop
+///       %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+///       scf.for %arg5 = %c0 to %c24 step %c1 // reduction-loop
+///         scf.for %arg6 = %c0 to %c64 step %c1 // k-loop
+///           %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1]
+///           %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1]
+///           %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+///           %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+///           %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]}
+///           %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3
+///           vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]}
+/// to:
+///   scf.for %arg3 = %c0 to %c32 step %c4
+///     scf.for %arg4 = %c0 to %c64 step %c64
+///       %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1]
+///       %1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]}
+///       %2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (!type) {
+///         %3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (!type) {
+///           %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1]
+///           %subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1]
+///           %4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+///           %5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]}
+///           %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8
+///           scf.yield %6 : !type
+///         }
+///         scf.yield %3 : !type
+///       }
+///       vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]}
+///
+struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Check that the vector contract operation satisfies the required pattern.
+    // Check the LHS, RHS, and accumulator operands of the contract operation.
+    auto operands = getContractOperands(contractOp);
+    if (failed(operands))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid operands for contract op");
+
+    auto readOps = *operands;
+    auto vectorReadOpAcc = readOps[2];
+    auto vectorReadOpLhs = readOps[0];
+    auto vectorReadOpRhs = readOps[1];
+
+    // Check whether each operand of the vector transfer reads is a subview.
+    auto subviews = getReadOperands(readOps);
+    if (failed(subviews))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Vector read op operands are not subviews");
+
+    // Determine the operation type: matmul, batch matmul, or batch-reduce
+    // matmul.
+    SmallVector<vector::IteratorType> contractIteratorTypes =
+        contractOp.getIteratorTypesArray();
+    int reductionCount =
+        std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
+                   vector::IteratorType::reduction);
+
+    auto vectorReadOpLhsType = cast<VectorType>(vectorReadOpLhs.getType());
+    auto vectorReadOpRhsRank =
+        cast<VectorType>(vectorReadOpRhs.getType()).getRank();
+
+    if (reductionCount == 2 &&
+        (vectorReadOpLhsType.getRank() != 3 || vectorReadOpRhsRank != 3))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid rank for batch reduce operation");
+
+    if (reductionCount == 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Batch matmul operation not supported yet");
+
+    if (reductionCount > 2)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The vector contract operation is not a gemm");
+
+    // Check that the K dimension is 1.
+    int64_t K =
+        vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1);
+    if (K != 1)
+      return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");
+
+    // Check whether the batch-reduce matmul tiling + vector contract pattern
+    // matches the expected four-level nested loop structure.
+    auto loops = getNestedLoop(contractOp);
+    if (failed(loops))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid loop nest in contract pattern");
+
+    auto checkLoops = checkNestedLoop(*loops, *subviews);
+    if (failed(checkLoops))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Loop induction variables do not match subview offsets");
+
+    auto nestedLoops = *loops;
+    auto kForOp = nestedLoops[0];
+    auto reductionForOp = nestedLoops[1];
+
+    // Move the accumulator vector transfer read above the reduction and k
+    // loops.
+    rewriter.setInsertionPoint(reductionForOp);
+    auto *cloneVectorReadOp = rewriter.clone(*vectorReadOpAcc);
+
+    // Re-create the reduction and k loops with iter_args.
+    auto vectorReadOpValue = cloneVectorReadOp->getResult(0);
+    auto newReductionForOp = rewriter.create<scf::ForOp>(
+        reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+        reductionForOp.getUpperBound(), reductionForOp.getStep(),
+        ValueRange{vectorReadOpValue},
+        [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp,
+            Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) {
+          auto newKForOp = rewriter.create<scf::ForOp>(
+              kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+              kForOp.getStep(), iterArgsNewReductionForOp,
+              [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+                  Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+                IRMapping mapper;
+                mapper.map(reductionForOp.getInductionVar(),
+                           ivNewReductionForOp);
+                mapper.map(kForOp.getInductionVar(), ivNewKForOp);
+
+                for (auto &op : kForOp.getBody()->without_terminator()) {
+                  rewriterNewKForOp.clone(op, mapper);
+                }
+                rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp,
+                                                       iterArgsNewKForOp);
+              });
+          rewriterNewReductionForOp.create<scf::YieldOp>(
+              locNewReductionForOp, newKForOp.getResult(0));
+        });
+
+    // Hoist the vector transfer write after the reduction loop and update the
+    // yield of the k loop.
+    auto newKForOp =
+        llvm::dyn_cast<scf::ForOp>(newReductionForOp.getBody()->front());
+    Value newContractOpValue;
+    vector::TransferWriteOp vectorWriteOperation;
+    Block *bodyBlock = newKForOp.getBody();
+    for (auto &op : bodyBlock->getOperations()) {
+      if (auto vectorContractOp = llvm::dyn_cast<vector::ContractionOp>(op)) {
+        vectorContractOp.setOperand(vectorContractOp.getNumOperands() - 1,
+                                    newKForOp.getRegionIterArgs()[0]);
+        newContractOpValue = vectorContractOp.getResult();
+      }
+      if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) {
+        yieldOp.setOperand(0, newContractOpValue);
+      }
+      if (auto vectorWriteOp = llvm::dyn_cast<vector::TransferWriteOp>(op)) {
+        vectorWriteOperation = vectorWriteOp;
+      }
+    }
+
+    vectorWriteOperation.setOperand(0, newReductionForOp.getResult(0));
+    vectorWriteOperation->moveBefore(reductionForOp);
+
+    // Erase the users of the old vector contract operation and then the
+    // operation itself.
+    for (auto result : contractOp->getResults()) {
+      for (auto *userOp : result.getUsers()) {
+        userOp->erase();
+      }
+    }
+    contractOp.erase();
+
+    return success();
+  }
+};
+
+void linalg::populateHoistVectorTransferPatterns(RewritePatternSet &patterns) {
+  patterns.add<HoistVectorTransferOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
new file mode 100644
index 0000000000000..b0b164951d4b3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/hoist-vector-transfer.mlir
@@ -0,0 +1,171 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @tiled_gemm_hoist_vector_transfer_operations(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+  %c1 = arith.constant 1 : index
+  %c24 = arith.constant 24 : index
+  %c64 = arith.constant 64 : index
+  %c4 = arith.constant 4 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+  scf.forall (%arg1, %arg2) in (8, 24) {
+    %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+    vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+    %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+    scf.for %arg3 = %c0 to %c32 step %c4 {
+      scf.for %arg4 = %c0 to %c64 step %c64 {
+        %subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+        scf.for %arg5 = %c0 to %c24 step %c1 {
+          scf.for %arg6 = %c0 to %c64 step %c1 {
+            %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+            %subview_4 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+            %1 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+            %2 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+            %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+            %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+            vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+          }
+        }
+      }
+    }
+  }
+  return %alloc : memref<8x24x32x64xf32>
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+// CHECK-LABEL: func.func @tiled_gemm_hoist_vector_transfer_operations(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 24 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+// CHECK: scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (8, 24) {
+// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_11]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
+// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<4x64xf32>) {
+// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<4x64xf32>) {
+// CHECK: %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
+// CHECK: %[[VAL_26:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
+// CHECK: %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
+// CHECK: %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_1]] {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
+// CHECK: %[[VAL_29:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_27]], %[[VAL_28]], %[[VAL_24]] : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
+// CHECK: scf.yield %[[VAL_29]] : vector<4x64xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_22]] : vector<4x64xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_10]] : memref<8x24x32x64xf32>
+// CHECK: }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+func.func @gemm_without_tiling_so_no_hoisting(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+  %c0 = arith.constant 0 : index
+  %0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
+  scf.forall (%arg1, %arg2) in (8, 24) {
+    %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
+    vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+    %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
+    %1 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>, vector<24x32x64xf32>
+    %2 = vector.transfer_read %0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<24x64x64xf32>, vector<24x64x64xf32>
+    %3 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32, strided<[64, 1], offset: ?>>, vector<32x64xf32>
+    %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<24x32x64xf32>, vector<24x64x64xf32> into vector<32x64xf32>
+    vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
+  }
+  return %alloc : memref<8x24x32x64xf32>
+}
+
+// CHECK-LABEL: func.func @gemm_without_tiling_so_no_hoisting
+// CHECK: memref.subview
+// CHECK-NEXT: vector.transfer_write
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_write
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @gemm_with_args_so_no_hoisting(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
+  %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
+  %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
+  %4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
+  return %4 : tensor<4x64xf32>
+}
+
+// CHECK-LABEL: func.func @gemm_with_args_so_no_hoisting
+// CHECK: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.transfer_read
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: vector.transfer_write
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.hoist_vector_transfer
+    } : !transform.any_op
+    transform.yield
+  }
+}