diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..3682a68b0e2c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Adds patterns that reduce the rank of named contraction ops that have
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`, `<corresponding contraction op>`, `expand_shape` (if on
+/// tensors). For example, a `linalg.batch_matmul` with unit batch size becomes
+/// `linalg.matmul`, and a `linalg.matvec` with a unit spatial dim in the lhs becomes `linalg.dot`.
+void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c0829397f1f85..36f8696bf1b27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -833,4 +833,265 @@ struct LinalgFoldUnitExtentDimsPass
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
+
+} // namespace
+
+namespace {
+
+/// Returns reassociation indices for collapsing/expanding a
+/// tensor of rank `rank` at position `pos`.
+static SmallVector<ReassociationIndices>
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+  bool lastDim = pos == rank - 1;
+  if (rank > 2) {
+    for (int64_t i = 0; i < rank - 1; i++) {
+      if (i == pos || (lastDim && i == pos - 1))
+        reassociation[i] = ReassociationIndices{i, i + 1};
+      else if (i < pos)
+        reassociation[i] = ReassociationIndices{i};
+      else
+        reassociation[i] = ReassociationIndices{i + 1};
+    }
+  }
+  return reassociation;
+}
+
+/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
+/// If `pos < 0`, then don't collapse.
+static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+                                    int64_t pos) {
+  if (pos < 0)
+    return val;
+  auto valType = cast<ShapedType>(val.getType());
+  SmallVector<int64_t> collapsedShape(valType.getShape());
+  collapsedShape.erase(collapsedShape.begin() + pos);
+  return collapseValue(
+      rewriter, val.getLoc(), val, collapsedShape,
+      getReassociationForReshapeAtDim(valType.getRank(), pos),
+      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
+/// Base class for all rank-reduction patterns for contraction ops
+/// with unit dimensions. Each pattern converts one named op to another
+/// named op and is intended to reduce only one iteration-space dim at a
+/// time; reducing multiple dims happens through recursive application of
+/// the pattern rewrites.
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+  using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+  /// Collapse all collapsible operands.
+  SmallVector<Value>
+  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+                   ArrayRef<int64_t> operandCollapseDims) const {
+    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+           "expected 3 operands and dims");
+    return llvm::map_to_vector(
+        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
+          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
+                                        std::get<1>(pair));
+        });
+  }
+
+  /// Expand result tensor.
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType, int64_t dim) const {
+    return rewriter.create<tensor::ExpandShapeOp>(
+        result.getLoc(), expandedType, result,
+        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
+  }
+
+  LogicalResult matchAndRewrite(FromOpTy contractionOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = contractionOp.getLoc();
+    auto inputs = contractionOp.getDpsInputs();
+    auto inits = contractionOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+    SmallVector<Value> operands{lhs, rhs, init};
+
+    SmallVector<int64_t> operandUnitDims;
+    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "no reducible dims found");
+
+    SmallVector<Value> collapsedOperands =
+        collapseOperands(rewriter, operands, operandUnitDims);
+    Value collapsedLhs = collapsedOperands[0];
+    Value collapsedRhs = collapsedOperands[1];
+    Value collapsedInit = collapsedOperands[2];
+    SmallVector<Type, 1> collapsedResultTy;
+    if (isa<RankedTensorType>(collapsedInit.getType()))
+      collapsedResultTy.push_back(collapsedInit.getType());
+    auto collapsedOp = rewriter.create<ToOpTy>(
+        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
+    for (auto attr : contractionOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    auto results = contractionOp.getResults();
+    assert(results.size() < 2 && "expected at most one result");
+    if (results.empty()) {
+      rewriter.replaceOp(contractionOp, collapsedOp);
+    } else {
+      rewriter.replaceOp(
+          contractionOp,
+          expandResult(rewriter, collapsedOp.getResultTensors()[0],
+                       cast<RankedTensorType>(results[0].getType()),
+                       operandUnitDims[2]));
+    }
+
+    return success();
+  }
+
+  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
+  /// for each operand that should be collapsed in this pattern. If an
+  /// operand shouldn't be collapsed, the index should be negative.
+  virtual LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
+};
+
+/// Patterns for unbatching batched contraction ops
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  /// Look for unit batch dims to collapse.
+  LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
+    FailureOr<ContractionDimensions> maybeContractionDims =
+        inferContractionDims(op);
+    if (failed(maybeContractionDims)) {
+      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
+      return failure();
+    }
+    ContractionDimensions contractionDims = maybeContractionDims.value();
+
+    if (contractionDims.batch.size() != 1)
+      return failure();
+    auto batchDim = contractionDims.batch[0];
+    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
+    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
+    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
+          return cast<ShapedType>(std::get<0>(pair).getType())
+                     .getShape()[std::get<1>(pair)] != 1;
+        })) {
+      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
+      return failure();
+    }
+
+    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+                                           std::get<1>(bOperands[1]),
+                                           std::get<1>(bOperands[2])};
+    return success();
+  }
+};
+
+/// Patterns for reducing non-batch dimensions
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  /// Helper for determining whether the lhs/init or rhs/init are reduced.
+  static bool constexpr reduceLeft =
+      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
+       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
+       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatmulOp> &&
+       std::is_same_v<ToOpTy, VecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
+       std::is_same_v<ToOpTy, VecmatOp>) ||
+      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
+
+  /// Look for non-batch spatial dims to collapse.
+  LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
+    FailureOr<ContractionDimensions> maybeContractionDims =
+        inferContractionDims(op);
+    if (failed(maybeContractionDims)) {
+      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
+      return failure();
+    }
+    ContractionDimensions contractionDims = maybeContractionDims.value();
+
+    if constexpr (reduceLeft) {
+      auto m = contractionDims.m[0];
+      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
+      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
+      if (mOperands.size() != 2)
+        return failure();
+      if (llvm::all_of(mOperands, [](auto pair) {
+            return cast<ShapedType>(std::get<0>(pair).getType())
+                       .getShape()[std::get<1>(pair)] == 1;
+          })) {
+        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
+                                               std::get<1>(mOperands[1])};
+        return success();
+      }
+    } else {
+      auto n = contractionDims.n[0];
+      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
+      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
+      if (nOperands.size() != 2)
+        return failure();
+      if (llvm::all_of(nOperands, [](auto pair) {
+            return cast<ShapedType>(std::get<0>(pair).getType())
+                       .getShape()[std::get<1>(pair)] == 1;
+          })) {
+        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
+                                               std::get<1>(nOperands[1])};
+        return success();
+      }
+    }
+    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
+    return failure();
+  }
+};
+
 } // namespace
+
+void mlir::linalg::populateContractionOpRankReducingPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  // Unbatching patterns for unit batch size
+  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
+  patterns
+      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
+          context);
+  patterns
+      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
+          context);
+  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
+  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
+
+  // Non-batch rank 1 reducing patterns
+  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
+  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
+  // Batch rank 1 reducing patterns
+  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
+  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
+  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
+      context);
+  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
+      context);
+
+  // Non-batch rank 0 reducing patterns
+  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
+  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
+}
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
new file mode 100644
index 0000000000000..c086d0fd7e633
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -0,0 +1,267 @@
+// RUN: mlir-opt -test-linalg-rank-reduce-contraction-ops --canonicalize -split-input-file %s | FileCheck %s
+
+func.func @singleton_batch_matmul_tensor(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512x256xf32>, %arg2: tensor<1x128x256xf32>) -> tensor<1x128x256xf32> {
+  // CHECK-LABEL: @singleton_batch_matmul_tensor
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512x256xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128x256xf32>
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512x256xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128x256xf32>)
+  // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1], [2]] output_shape [1, 128, 256]
+  // CHECK-NEXT: return %[[RES]]
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512x256xf32>)
+      outs(%arg2 : tensor<1x128x256xf32>) -> tensor<1x128x256xf32>
+  return %1 : tensor<1x128x256xf32>
+}
+
+// -----
+
+func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @singleton_batch_matmul_memref
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x?x?xf32>
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: linalg.matmul ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[COLLAPSED_INIT]] : memref<?x?xf32>)
+  // CHECK-NEXT: return
+  linalg.batch_matmul ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+      outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
+  // CHECK-LABEL: @singleton_batch_matvec
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x512xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x128xf32>
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
+  // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
+  // CHECK-NEXT: return %[[RES]]
+  %1 = linalg.batch_matvec ins(%arg0, %arg1 :
+      tensor<1x128x512xf32>, tensor<1x512xf32>)
+      outs(%arg2 : tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %1 : tensor<1x128xf32>
+}
+
+// -----
+
+func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @singleton_batch_vecmat
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  // CHECK-NEXT: return %[[RES]]
+  %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
+      outs(%arg2 : tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %1 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_a
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  // CHECK-NEXT: return
+  linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
+  // CHECK-LABEL: @singleton_batchmatmul_transpose_b
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
+  // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
+  // CHECK-NEXT: return
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK-LABEL: @matmul_to_matvec_tensor
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x1xf32>
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0
+  // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+  // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
+  // CHECK-NEXT: return %[[RES]]
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_to_matvec
+  // CHECK: linalg.matvec
+  linalg.matmul ins(%arg0, %arg1: memref<?x?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+  return
+}
+
+// -----
+
+func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
+  // CHECK-LABEL: @matmul_to_vecmat
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
+  // CHECK-NEXT: return %[[RES]]
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
+  return %0 : tensor<1x?xf32>
+}
+
+// -----
+
+func.func @batch_matmul_to_vecmat(%arg0: memref<1x1x?xf32>, %arg1: memref<1x?x?xf32>, %arg2: memref<1x1x?xf32>) {
+  // CHECK-LABEL: @batch_matmul_to_vecmat
+  // CHECK: linalg.vecmat
+  linalg.batch_matmul ins(%arg0, %arg1: memref<1x1x?xf32>, memref<1x?x?xf32>) outs(%arg2: memref<1x1x?xf32>)
+  return
+}
+
+// -----
+
+func.func @matvec_to_dot(%arg0: memref<1x?xf32>, %arg1: memref<?xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @matvec_to_dot
+  // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x?xf32>
+  // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<?xf32>
+  // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1xf32>
+  // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] []
+  // CHECK-NEXT: linalg.dot ins(%[[COLLAPSED_LHS]], %[[RHS]] : memref<?xf32>, memref<?xf32>) outs(%[[COLLAPSED_INIT]] : memref<f32>)
+  linalg.matvec ins(%arg0, %arg1: memref<1x?xf32>, memref<?xf32>) outs(%arg2: memref<1xf32>)
+  return
+}
+
+// -----
+
+func.func @vecmat_to_dot(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<1xf32>) {
+  // CHECK-LABEL: @vecmat_to_dot
+  // CHECK: linalg.dot
+  linalg.vecmat ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<1xf32>)
+  return
+}
+
+// -----
+
+func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-LABEL: @matvec_to_dot_tensor
+  // CHECK: linalg.dot
+  %0 = linalg.matvec ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?xf32>) outs(%arg2: tensor<1xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>,
+                                        %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
+  // CHECK-LABEL: @matmul_transpose_a_to_vecmat
+  // CHECK: collapse_shape {{.*}} into tensor<256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<512xf32>
+  // CHECK: linalg.vecmat
+  // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
+  %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
+  return %0 : tensor<1x512xf32>
+}
+
+// -----
+
+func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
+  // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
+  // CHECK: linalg.batch_vecmat
+  // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
+  %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
+  return %0 : tensor<64x1x512xf32>
+}
+
+// -----
+
+func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @matmul_transpose_b_to_matvec
+  // CHECK: linalg.matvec
+  linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
+  return
+}
+
+// -----
+
+func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
+  // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
+  // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
+  // CHECK: linalg.batch_matvec
+  // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
+  %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32>
+  return %0 : tensor<64x128x1xf32>
+}
+
+// -----
+
+func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
+  // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot
+  // CHECK: linalg.dot
+  %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+  return %0 : tensor<1x1x1xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul
+  // CHECK-NOT: collapse_shape
+  // CHECK: linalg.batch_matmul
+  // CHECK-NOT: expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+      outs(%arg2 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+  return %1 : tensor<2x?x?xf32>
+}
+
+// -----
+
+func.func @nonsingleton_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-LABEL: @nonsingleton_batch_matmul_dynamic
+  // CHECK-NOT: collapse_shape
+  // CHECK: linalg.batch_matmul
+  // CHECK-NOT: expand_shape
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index b28f2b3564662..283e426b4e594 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -5,6 +5,7 @@
 add_mlir_library(MLIRLinalgTestPasses
   TestLinalgDropUnitDims.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
+  TestLinalgRankReduceContractionOps.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
new file mode 100644
index 0000000000000..8b455d7d68c30
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp
@@ -0,0 +1,67 @@
+//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing rank-reducing patterns for named
+// contraction ops with unit dims.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLinalgRankReduceContractionOps
+    : public PassWrapper<TestLinalgRankReduceContractionOps,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestLinalgRankReduceContractionOps)
+
+  TestLinalgRankReduceContractionOps() = default;
+  TestLinalgRankReduceContractionOps(
+      const TestLinalgRankReduceContractionOps &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-rank-reduce-contraction-ops";
+  }
+  StringRef getDescription() const final {
+    return "Test Linalg rank reduce contraction ops with unit dims";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    func::FuncOp funcOp = this->getOperation();
+
+    RewritePatternSet patterns(context);
+    linalg::populateContractionOpRankReducingPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                            std::move(patterns))))
+      return signalPassFailure();
+    return;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgRankReduceContractionOps() {
+  PassRegistration<TestLinalgRankReduceContractionOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0e8b161d51345..d4ea7a9cae0d2 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestLinalgDecomposeOps();
 void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
+void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -235,6 +236,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgDropUnitDims();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
+  mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();
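
In short, each pattern collapses the unit dimension out of the operands, emits the smaller named op, and, on tensors, expands the result back to the original shape. A condensed before/after sketch of the first test case above (value names here are illustrative only, not taken from the patch):

```mlir
// Before: batch_matmul with a unit batch dimension.
%0 = linalg.batch_matmul ins(%lhs, %rhs : tensor<1x128x512xf32>, tensor<1x512x256xf32>)
       outs(%init : tensor<1x128x256xf32>) -> tensor<1x128x256xf32>

// After: collapse the batch dim, run the rank-reduced op, expand the result.
%lhs_c  = tensor.collapse_shape %lhs  [[0, 1], [2]] : tensor<1x128x512xf32> into tensor<128x512xf32>
%rhs_c  = tensor.collapse_shape %rhs  [[0, 1], [2]] : tensor<1x512x256xf32> into tensor<512x256xf32>
%init_c = tensor.collapse_shape %init [[0, 1], [2]] : tensor<1x128x256xf32> into tensor<128x256xf32>
%mm = linalg.matmul ins(%lhs_c, %rhs_c : tensor<128x512xf32>, tensor<512x256xf32>)
        outs(%init_c : tensor<128x256xf32>) -> tensor<128x256xf32>
// %1 replaces %0 above.
%1 = tensor.expand_shape %mm [[0, 1], [2]] output_shape [1, 128, 256]
       : tensor<128x256xf32> into tensor<1x128x256xf32>
```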