[mlir][linalg] raise generic to named ops. #110421
@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include <algorithm>
+#include <numeric>

 using namespace mlir;
 using namespace mlir::linalg;

@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }

+// Returns true if all loops of the linalgOp are parallel.
+static bool isAllParallel(LinalgOp op) {
+  return op.getNumParallelLoops() == op.getNumLoops();
+}
+
+// Returns true if and only if linalgOp takes one input and one init.
+static bool isSingleInputOutput(LinalgOp op) {
+  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
+}
+
+// Returns true if the genericOp body is just a yieldOp that yields the
+// input operand as its result.
+static bool isSingleYieldOp(GenericOp op) {
+  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
+    return false;
+
+  Block *body = op.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//

 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+  // Structural and operands.
+  if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
     return false;

-  // Operands and maps.
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return false;
   auto mapRange = linalgOp.getIndexingMapsArray();
   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
       !mapRange.back().isIdentity()) {

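For context, a minimal sketch (not part of the patch; the function and value names are hypothetical) of the kind of linalg.generic these predicates accept: all loops parallel (isAllParallel), one input and one init (isSingleInputOutput), identity maps, and a body that only yields its first block argument (isSingleYieldOp), i.e. a copy.

// Hypothetical example, not from this PR: a copy-like generic.
#id = affine_map<(d0, d1) -> (d0, d1)>
func.func @copy_as_generic(%in: tensor<8x16xf32>, %out: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // Identity maps on both operands, all-parallel iterators.
  %0 = linalg.generic
      {indexing_maps = [#id, #id],
       iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<8x16xf32>) outs(%out : tensor<8x16xf32>) {
  ^bb0(%a: f32, %b: f32):
    // The body is a single linalg.yield of the input block argument.
    linalg.yield %a : f32
  } -> tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}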
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
 //===----------------------------------------------------------------------===//
 std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   // Structural.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;

   // Input should be referenced and init should not.

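Likewise, a hedged sketch (hypothetical names, not from the patch) of a fill-like generic this predicate combination matches: the scalar input is attached through an empty indexing map and yielded directly, so isaFillOpInterface would return the %cst value.

// Hypothetical example, not from this PR: a fill-like generic.
#id = affine_map<(d0, d1) -> (d0, d1)>
#scalar = affine_map<(d0, d1) -> ()>
func.func @fill_as_generic(%cst: f32, %out: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // Scalar input (empty map) written to every element of the init.
  %0 = linalg.generic
      {indexing_maps = [#scalar, #id],
       iterator_types = ["parallel", "parallel"]}
      ins(%cst : f32) outs(%out : tensor<4x8xf32>) {
  ^bb0(%s: f32, %o: f32):
    linalg.yield %s : f32
  } -> tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}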
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   OpOperand *value = genericOp.getDpsInputOperand(0);
   if (!genericOp.isScalar(value))
     return std::nullopt;
-
-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
-    return std::nullopt;
-
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
-    return std::nullopt;
   return value->get();
 }

+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
+    return std::nullopt;
+
+  auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+  auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(t0) ||
+      !isa<MemRefType, RankedTensorType>(t1))
+    return std::nullopt;
+
+  // Check that the output indexing map is the identity. An injective map
+  // could also be a permutation of indices and expressible in
+  // linalg.generic, but it is not expressible by the named broadcast op.
+  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  if (!dstMap.isIdentity())
+    return std::nullopt;
+
+  SmallVector<int64_t> position;
+  auto srcMap = genericOp.getIndexingMapsArray()[0];
+
+  // Check that the input map consists of monotonically increasing DimIds.
+  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (i > 0 && pos <= position[i - 1])
+      return std::nullopt;
+    position.push_back(expr.getPosition());
+  }
+
+  SmallVector<int64_t> broadcastedDims;
+  auto numDims = srcMap.getNumDims();
+  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  }
+  return broadcastedDims;
+}

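To illustrate the shape this detection targets, a hedged sketch (not from the patch; names are hypothetical) of the generic form of a broadcast along dim 0: the input map's results (d1, d2) are strictly increasing AffineDimExprs, the output map is the identity, and d0 appears in no input-map result, so broadcastedDims would be [0].

// Hypothetical example, not from this PR: broadcast along d0 as a generic.
#src = affine_map<(d0, d1, d2) -> (d1, d2)>     // monotonically increasing DimIds
#dst = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // identity output map
func.func @broadcast_as_generic(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = linalg.generic
      {indexing_maps = [#src, #dst],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%A : tensor<?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}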
+//===----------------------------------------------------------------------===//
+// TransposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
+    return std::nullopt;
+
+  // Check the maps.
Contributor:
Why a new file instead of re-using roundtrip.mlir? Note that this file is called "roundtrip-broadcast.mlir", but it tests both broadcasts and transposes.

Contributor (Author):
I double-checked: the transpose and broadcast tests are in separate files, e.g. the linalg.transpose tests are in roundtrip-transpose.mlir. It may be that the browser is displaying them mashed up across comments.
roundtrip-broadcast.mlir (new file)
@@ -0,0 +1,32 @@
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

// CHECK-LABEL: broadcast_first_dimension
// CHECK-SAME:  %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
// CHECK-NOT:   linalg.generic
// CHECK:       %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
//
func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
  return %res : tensor<?x?x?xf32>
}

// CHECK-LABEL: broadcast_mid_dimension
// CHECK-SAME:  %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
// CHECK-NOT:   linalg.generic
// CHECK:       %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
//
func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
  %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
  return %res : tensor<3x4x5xf32>
}

// CHECK-LABEL: broadcast_multiple_dimensions
// CHECK-SAME:  %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
// CHECK-NOT:   linalg.generic
// CHECK:       %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
//
func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
  %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
  return %res : tensor<3x4x5x6x7x8x9xf32>
}

Contributor:
Perhaps add 3D and 1D cases? And identity?

Contributor (Author):
Added a 3D test. Thanks for the suggestion.
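For reference, a hedged sketch of what a 1-D case along the reviewer's suggestion might look like (not part of this patch; the test name is hypothetical):

// Hypothetical 1-D test case, not from this PR.
// CHECK-LABEL: broadcast_1d
// CHECK-NOT:   linalg.generic
// CHECK:       linalg.broadcast {{.*}} dimensions = [1]
func.func @broadcast_1d(%A: tensor<8xf32>, %Out: tensor<8x16xf32>) -> tensor<8x16xf32> {
  %res = linalg.broadcast ins(%A: tensor<8xf32>) outs(%Out: tensor<8x16xf32>) dimensions = [1]
  return %res : tensor<8x16xf32>
}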
roundtrip-transpose.mlir (new file)
@@ -0,0 +1,11 @@
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

// CHECK-LABEL: linalg_transpose
// CHECK-SAME:  %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
// CHECK-NOT:   linalg.generic
// CHECK:       %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
//
func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
  %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
  return %res : tensor<64x16xf32>
}
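For reference, a hedged sketch (not from the patch; names are hypothetical) of the generic form this test round-trips through: the input map is the (d1, d0) permutation and the output map is the identity, which the specialization pass raises back to linalg.transpose with permutation = [1, 0].

// Hypothetical example, not from this PR: transpose as a generic.
#in  = affine_map<(d0, d1) -> (d1, d0)> // permuted input map
#out = affine_map<(d0, d1) -> (d0, d1)> // identity output map
func.func @transpose_as_generic(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
  %0 = linalg.generic
      {indexing_maps = [#in, #out],
       iterator_types = ["parallel", "parallel"]}
      ins(%A : tensor<16x64xf32>) outs(%Out : tensor<64x16xf32>) {
  ^bb0(%a: f32, %b: f32):
    linalg.yield %a : f32
  } -> tensor<64x16xf32>
  return %0 : tensor<64x16xf32>
}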
Do these belong here? IIUC, the comment above ("Interface utility functions") refers to ODS/TableGen "interfaces" (i.e. none of these is an InterfaceMethod). Having said that, why not add them to the interface?