diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0fcaa96ade403..6f1c243cc4396 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.broadcast`. Returns the broadcast dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaBroadcastOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.transpose`. Returns the permutation if true.
+std::optional<SmallVector<int64_t>>
+isaTransposeOpInterface(GenericOp genericOp);
+
 /// Checks whether a given `genericOp` is semantically equivalent to a single
 /// linalg elementwise unary op, e.g. linalg.exp.
 /// A linalg.generic body could be a series of unary elementwise ops, e.g.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..0a404194569c2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -243,6 +243,18 @@ def LinalgStructuredInterface
                                  utils::IteratorType::parallel);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if all loops are parallel.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isAllParallelLoops",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getNumParallelLoops() == getNumLoops();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the dims that are parallel loops.
@@ -327,6 +339,18 @@ def LinalgStructuredInterface
         return !getBlock()->getArgument(bbArgNumber).use_empty();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true only if the linalg op takes exactly one input and
+        produces exactly one result.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSingleInputOutput",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getNumDpsInputs() == 1 && $_op.getNumDpsInits() == 1;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return true if `opOperand` is an init tensor. This is true when it is
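Note: the two interface methods above are cheap structural predicates that the detection helpers below build on. For illustration (a hypothetical example, not part of the patch), both hold for a generic op like this:

```mlir
#id = affine_map<(d0, d1) -> (d0, d1)>

// Both loops are parallel and there is exactly one DPS input and one DPS
// init, so isAllParallelLoops() and isSingleInputOutput() both return true.
func.func @all_parallel(%in: tensor<4x8xf32>, %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.generic
      {indexing_maps = [#id, #id], iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>) {
  ^bb0(%a: f32, %b: f32):
    %e = math.exp %a : f32
    linalg.yield %e : f32
  } -> tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}
```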
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 31f2913924726..a27c666a2aba4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -210,6 +210,24 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
     }
 
     MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    // Return true only if the GenericOp has a single input and a single
+    // output, and the body is a single yieldOp that yields the input.
+    // This check is useful when trying to determine if the op is
+    // essentially a transpose, broadcast, or copy.
+    bool isSingleYieldOp() {
+      if (!isSingleInputOutput())
+        return false;
+      Block *body = getBody();
+      if (body->getOperations().size() != 1)
+        return false;
+
+      auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+      if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+          yieldOp->getOperand(0) != body->getArgument(0))
+        return false;
+      return true;
+    }
   }];
 
   let hasCanonicalizer = 1;
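Note: `isSingleYieldOp` matches bodies that merely forward the input, so all semantics live in the indexing maps. A hypothetical example (not part of the patch) that it accepts:

```mlir
// No compute in the body, only a yield of the first block argument; the maps
// alone (here a 2-D transpose) describe the data movement, so
// isSingleYieldOp() returns true.
func.func @single_yield(%in: tensor<8x4xf32>, %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<8x4xf32>) outs(%init : tensor<4x8xf32>) {
  ^bb0(%a: f32, %b: f32):
    linalg.yield %a : f32
  } -> tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}
```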
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 40795879c3026..bd77965194b27 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include <algorithm>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -53,112 +54,180 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+bool linalg::isaCopyOpInterface(LinalgOp op) {
+  // Check all loops are parallel and the op has a single input and output.
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
     return false;
 
-  // Operands and maps.
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return false;
-  auto mapRange = linalgOp.getIndexingMapsArray();
+  auto mapRange = op.getIndexingMapsArray();
   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
       !mapRange.back().isIdentity()) {
     return false;
   }
   // Region.
-  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
+  return llvm::hasSingleElement(op.getBlock()->getOperations());
 }
 
 //===----------------------------------------------------------------------===//
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
-std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
+std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
   // Structural.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;
 
   // Input should be referenced and init should not.
-  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
-      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
+      op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
     return std::nullopt;
 
-  OpOperand *value = genericOp.getDpsInputOperand(0);
-  if (!genericOp.isScalar(value))
+  OpOperand *value = op.getDpsInputOperand(0);
+  if (!op.isScalar(value))
     return std::nullopt;
+  return value->get();
+}
 
-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp op) {
+  // Structural.
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;
 
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
+  auto srcTy = op.getDpsInputOperand(0)->get().getType();
+  auto dstTy = op.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
+      !isa<MemRefType, RankedTensorType>(dstTy))
     return std::nullopt;
-  return value->get();
+
+  // Check the output map is an identity. A broadcast could additionally
+  // permute indices, which is expressible in linalg.generic but not in the
+  // named linalg.broadcast op.
+  auto dstMap = op.getIndexingMapsArray()[1];
+  if (!dstMap.isIdentity())
+    return std::nullopt;
+
+  SmallVector<int64_t> position;
+  auto srcMap = op.getIndexingMapsArray()[0];
+
+  if (srcMap.getResults().size() >= dstMap.getResults().size())
+    return std::nullopt;
+
+  // Check the input map is a strictly increasing sequence of DimIds.
+  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (i > 0 && pos <= position[i - 1])
+      return std::nullopt;
+    position.push_back(expr.getPosition());
+  }
+
+  SmallVector<int64_t> broadcastedDims;
+  auto numDims = srcMap.getNumDims();
+  // This is quadratic, but the number of items is generally small.
+  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  }
+  return broadcastedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp op) {
+  // To specialize as a transpose op, the generic op must have all parallel
+  // loops, a single input and a single output, and a body that is just a
+  // yield op, yielding the input as-is (no compute).
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
+    return std::nullopt;
+
+  auto mapRange = op.getIndexingMapsArray();
+  if (mapRange.size() != 2)
+    return std::nullopt;
+
+  auto mapOfInput = mapRange.front();
+  auto mapOfResult = mapRange.back();
+
+  // linalg.transpose permutes the dimensions of the input using this
+  // rule: dim(result, i) = dim(input, permutation[i]).
+  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
+    return std::nullopt;
+
+  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
+  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
+    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
+    permutation[expr.getPosition()] = i;
+  }
+  return permutation;
+}
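Worked example (illustrative, mirroring the transpose3D test added below): with input map `(d0, d1, d2) -> (d1, d2, d0)` and an identity result map, the loop above sets `permutation[1] = 0`, `permutation[2] = 1`, `permutation[0] = 2`, i.e. permutation `[2, 0, 1]`:

```mlir
// dim(result, 0) = dim(input, 2), dim(result, 1) = dim(input, 0), and
// dim(result, 2) = dim(input, 1), so this generic is equivalent to
// linalg.transpose ... permutation = [2, 0, 1].
func.func @generic_transpose(%in: tensor<7x8x9xf32>, %init: tensor<9x7x8xf32>) -> tensor<9x7x8xf32> {
  %0 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%in : tensor<7x8x9xf32>) outs(%init : tensor<9x7x8xf32>) {
  ^bb0(%a: f32, %b: f32):
    linalg.yield %a : f32
  } -> tensor<9x7x8xf32>
  return %0 : tensor<9x7x8xf32>
}
```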
 //===----------------------------------------------------------------------===//
 // Elementwise Single Unary/Binary-OpInterface implementation
 //===----------------------------------------------------------------------===//
-static bool
-isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
-                                          unsigned arity) {
+static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
+                                                      unsigned arity) {
   // Check all loops are parallel.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1)
+  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.
-  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
-      !llvm::all_of(genericOp.getIndexingMapsArray(),
+  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
+      !llvm::all_of(op.getIndexingMapsArray(),
                     [](AffineMap map) { return map.isIdentity(); }))
     return false;
 
   // Init should not be referenced for elementwise operations.
-  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
     return false;
 
   // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
   // such as resulting from producer-consumer fusion. Here, we restrict to two
   // ops in the body, where the first is the elementwise single op and the
   // second a yield.
-  Block *body = genericOp.getBody();
+  Block *body = op.getBody();
   if (body->getOperations().size() != 2)
     return false;
 
-  Operation *op = &body->front();
-  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+  Operation *oper = &body->front();
+  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
     return false;
 
   auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
   if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0).getDefiningOp() != op)
+      yieldOp->getOperand(0).getDefiningOp() != oper)
     return false;
   return true;
 }
 
-bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
   // All basic elemwise checks.
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
     return false;
 
   // Check input is actually used.
-  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
     return false;
   return true;
 }
 
-bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
     return false;
 
   // Check both inputs are used (elementwise).
-  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
-  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
-  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
-      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
+  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !op.payloadUsesValueFromOperand(inputOpOperand1))
     return false;
   return true;
 }
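Note: the binary case accepts exactly one compute op feeding the yield, with both inputs used and all maps identities. A hypothetical example (not from the patch) that `isaElemwiseSingleBinaryOpInterface` accepts:

```mlir
#id1 = affine_map<(d0) -> (d0)>

// Two used inputs, identity maps, and a body of exactly two ops: one binary
// payload op (arith.addf) plus the yield of its result.
func.func @elemwise_add(%a: tensor<32xf32>, %b: tensor<32xf32>, %init: tensor<32xf32>) -> tensor<32xf32> {
  %0 = linalg.generic
      {indexing_maps = [#id1, #id1, #id1], iterator_types = ["parallel"]}
      ins(%a, %b : tensor<32xf32>, tensor<32xf32>) outs(%init : tensor<32xf32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %s = arith.addf %x, %y : f32
    linalg.yield %s : f32
  } -> tensor<32xf32>
  return %0 : tensor<32xf32>
}
```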
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4d7b748d7200e..dfafffce9d9b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
+  // Copy
   if (isaCopyOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Fill
   if (isaFillOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Broadcast
+  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
+      isaBroadcastOpInterface(genericOp);
+  if (equivalentToBroadcast) {
+    auto dims = *equivalentToBroadcast;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        dims);
+    return namedOp;
+  }
+
+  // Transpose
+  std::optional<SmallVector<int64_t>> equivalentToTranspose =
+      isaTransposeOpInterface(genericOp);
+  if (equivalentToTranspose) {
+    auto permutation = *equivalentToTranspose;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        permutation);
+    return namedOp;
+  }
+
+  // Elementwise Unary
   if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
     Operation *op = &genericOp.getBody()->front();
     if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Elementwise Binary
   if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
     bool swap = areBinOpsSwapped(genericOp);
     Operation *op = &genericOp.getBody()->front();
@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Contraction - e.g. matmul
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
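The tests below check named → generic → named roundtrips. The intermediate form after `-linalg-generalize-named-ops` (sketched here for `broadcast_mid_dimension`; illustrative, not checked verbatim by the tests) is what `specializeGenericOp` actually sees:

```mlir
// The input map skips d1, so isaBroadcastOpInterface returns [1] and the
// specialization pass rebuilds linalg.broadcast ... dimensions = [1].
func.func @generalized_broadcast(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
  %0 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%A : tensor<3x5xf32>) outs(%Out : tensor<3x4x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<3x4x5xf32>
  return %0 : tensor<3x4x5xf32>
}
```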
diff --git a/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
new file mode 100644
index 0000000000000..d6915ec8fbbf6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: broadcast_first_dimension
+// CHECK-SAME:  %[[A:.+]]: tensor<?xf32>, %[[Out:.+]]: tensor<?x?xf32>)
+// CHECK-NOT:   linalg.generic
+// CHECK:       %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?xf32>) outs(%[[Out]] : tensor<?x?xf32>) dimensions = [0]
+//
+func.func @broadcast_first_dimension(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %res = linalg.broadcast ins(%A: tensor<?xf32>) outs(%Out: tensor<?x?xf32>) dimensions = [0]
+  return %res : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: broadcast_mid_dimension
+// CHECK-SAME:  %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
+// CHECK-NOT:   linalg.generic
+// CHECK:       %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
+//
+func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+  %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
+  return %res : tensor<3x4x5xf32>
+}
+
+
+// CHECK-LABEL: broadcast_multiple_dimensions
+// CHECK-SAME:  %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
+// CHECK-NOT:   linalg.generic
+// CHECK:       %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+//
+func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
+  %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
+  return %res : tensor<3x4x5x6x7x8x9xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
new file mode 100644
index 0000000000000..21b7b348f1c7f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: transpose2D
+// CHECK-SAME:  %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
+// CHECK-NOT:   linalg.generic
+// CHECK:       %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
+//
+func.func @transpose2D(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+  %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
+  return %res : tensor<64x16xf32>
+}
+
+
+// CHECK-LABEL: transpose3D
+// CHECK-SAME:  %[[A:.+]]: tensor<7x8x9xf32>, %[[Out:.+]]: tensor<9x7x8xf32>
+// CHECK-NOT:   linalg.generic
+// CHECK:       %transposed = linalg.transpose ins(%[[A]] : tensor<7x8x9xf32>) outs(%[[Out]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+//
+func.func @transpose3D(%arg0: tensor<7x8x9xf32>, %arg1: tensor<9x7x8xf32>) -> tensor<9x7x8xf32> {
+  %transposed = linalg.transpose ins(%arg0 : tensor<7x8x9xf32>) outs(%arg1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+  return %transposed : tensor<9x7x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
new file mode 100644
index 0000000000000..542a7ed4a198b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// This test checks that linalg.generic is not incorrectly specialized to transpose or broadcast.
+// CHECK-LABEL: @transpose_and_broadcast
+// CHECK:   linalg.generic
+func.func @transpose_and_broadcast(%arg0: tensor<7x8xf32>, %arg1: tensor<8x7x9xf32>) -> tensor<8x7x9xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<7x8xf32>) outs(%arg1 : tensor<8x7x9xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<8x7x9xf32>
+  return %0 : tensor<8x7x9xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 35679db7412f3..31f2f6b1ab513 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -4,18 +4,6 @@
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
 
-func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
-  // expected-note @below {{when applied to this op}}
-  linalg.generic {
-    indexing_maps = [#map1, #map],
-    iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  }
-  return
-}
-
 func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {
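Note on the last hunk: `@broadcast_copy_expect_no_match` is deleted because that generic is now intentionally specialized rather than rejected. Its input map `#map1 = (d0, d1) -> (d0)` leaves `d1` unreferenced, so `isaBroadcastOpInterface` returns `[1]` and specialization would produce roughly the following (a sketch, not verified against the pass output):

```mlir
// Hypothetical specialized form of the deleted test's generic: the 1-D input
// is broadcast along dimension 1 of the 2-D output.
func.func @specialized_broadcast(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
  linalg.broadcast ins(%arg0 : memref<?xf32>)
                   outs(%arg1 : memref<?x?xf32>)
                   dimensions = [1]
  return
}
```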