Skip to content

Commit 0a16982

Browse files
committed
*Added checks for extended semantics and exit gracefully in user passes.
*Added and udated test cases. *Refactored verification logic.
1 parent 80ea697 commit 0a16982

File tree

6 files changed

+85
-39
lines changed

6 files changed

+85
-39
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3474,7 +3474,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34743474
/// It checks if the first result dimension is a function of the first
34753475
/// dimension.
34763476
static bool isValidBatchDim(AffineMap bcastMap) {
3477-
assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
34783477
AffineExpr exp = bcastMap.getResult(0);
34793478
return exp.isFunctionOfDim(0);
34803479
}
@@ -3490,6 +3489,48 @@ static bool isValidOutputResultDim(AffineMap outputMap) {
34903489
exp2.isFunctionOfDim(nPos);
34913490
}
34923491

3492+
// Check general validity of input indexing map.
3493+
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
3494+
AffineMap opIndexingMap,
3495+
AffineMap defaultIndexingMap, bool isLHS) {
3496+
// Check the result dims are valid.
3497+
if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3498+
return batchMatmulOp->emitOpError()
3499+
<< "Unexpected dim expression in map result.";
3500+
3501+
// Check for valid number of result dims of input maps.
3502+
if (opIndexingMap.getNumResults() > 3)
3503+
return batchMatmulOp->emitOpError()
3504+
<< "no. of result dim expression cannot exceed 3.";
3505+
3506+
// Check if the requested broadcast is valid.
3507+
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3508+
if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3509+
return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3510+
} else if (!isValidBatchDim(opIndexingMap)) {
3511+
return batchMatmulOp->emitOpError()
3512+
<< "Invalid batch dimension expression.";
3513+
}
3514+
return success();
3515+
}
3516+
3517+
/// This function checks if the given AffineMap for the output of a
3518+
/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
3519+
/// dimensions are valid.
3520+
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
3521+
AffineMap opIndexingMap) {
3522+
if (opIndexingMap.getNumResults() != 3)
3523+
return batchMatmulOp->emitOpError()
3524+
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3525+
<< ").";
3526+
3527+
if (!isValidOutputResultDim(opIndexingMap))
3528+
return batchMatmulOp->emitOpError()
3529+
<< "Invalid output map result dimension.";
3530+
3531+
return success();
3532+
}
3533+
34933534
/// Verifies the broadcast and transpose semantic specified by the explicit
34943535
/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
34953536
/// opIndex.
@@ -3503,37 +3544,14 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
35033544

35043545
auto opIndexingMap = opIndexingMaps[opIndex];
35053546
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3506-
// Check general validity of indexing map results.
3507-
if (opIndex < 2) {
3508-
if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3509-
return batchMatmulOp->emitOpError()
3510-
<< "Unexpected dim expression in map result.";
3511-
// Check if the requested broadcast is valid.
3512-
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3513-
if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap,
3514-
opIndex == 0)) {
3515-
return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3516-
}
3517-
} else {
3518-
// Check for valid number of result dims of input maps.
3519-
if (opIndexingMap.getNumResults() != 3)
3520-
return batchMatmulOp->emitOpError()
3521-
<< "no. of result dim expression cannot exceed 3.";
3522-
3523-
if (!isValidBatchDim(opIndexingMap))
3524-
return batchMatmulOp->emitOpError()
3525-
<< "Invalid batch dimension expression.";
3526-
}
3527-
} else {
3528-
// Check for valid number of result dims of output map.
3529-
if (opIndexingMap.getNumResults() != 3)
3530-
return batchMatmulOp->emitOpError()
3531-
<< "no. of result dim expression cannot exceed 3.";
35323547

3533-
if (!isValidOutputResultDim(opIndexingMap))
3534-
return batchMatmulOp->emitOpError()
3535-
<< "Invalid output map result dimension.";
3536-
}
3548+
if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
3549+
return failure();
3550+
3551+
if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3552+
opIndex == 0)))
3553+
return failure();
3554+
35373555
return success();
35383556
}
35393557

mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,16 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
138138
FailureOr<PackResult>
139139
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
140140
const ControlBlockPackMatmulFn &controlPackMatmul) {
141+
// Check to not let go the batch_matmul with extended semantic, through this
142+
// transform.
143+
if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
144+
if (batchMatmulOp->hasUserDefinedMaps()) {
145+
return rewriter.notifyMatchFailure(
146+
*batchMatmulOp,
147+
"only batch_matmul ops with non-extended semantics are supported");
148+
}
149+
}
150+
141151
if (linalgOp.hasPureBufferSemantics())
142152
return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
143153

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,15 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
906906

907907
LogicalResult matchAndRewrite(FromOpTy contractionOp,
908908
PatternRewriter &rewriter) const override {
909+
// Check to not let go the batch_matmul with extended semantic, through this
910+
// transform.
911+
if (std::is_same<FromOpTy, BatchMatmulOp>::value) {
912+
if (contractionOp.hasUserDefinedMaps()) {
913+
return rewriter.notifyMatchFailure(
914+
contractionOp,
915+
"only batch_matmul ops with non-extended semantics are supported");
916+
}
917+
}
909918

910919
auto loc = contractionOp.getLoc();
911920
auto inputs = contractionOp.getDpsInputs();

mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ FailureOr<Operation *>
8888
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
8989
linalg::BatchMatmulOp batchMatmulOp,
9090
bool transposeLHS) {
91+
// Check to not let go the batch_matmul with extended semantic, through this
92+
// transform.
93+
if (batchMatmulOp.hasUserDefinedMaps()) {
94+
return rewriter.notifyMatchFailure(
95+
batchMatmulOp,
96+
"only batch_matmul ops with non-extended semantics are supported");
97+
}
98+
9199
if (!bufferization::hasTensorSemantics(batchMatmulOp))
92100
return rewriter.notifyMatchFailure(
93101
batchMatmulOp, "only matmul ops with tensors are supported");

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,21 +1004,22 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
10041004
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
10051005

10061006
// CHECK-LABEL: func.func @batch_matmul(
1007-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>, %[[VAL_1:.*]]: tensor<?x?x?xf32>,
1008-
// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
1009-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_2]] : tensor<?x?x?xf32>) {
1007+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x5xf32>,
1008+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x5x7xf32>,
1009+
// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
1010+
// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<2x3x7xf32>) {
10101011
// CHECK: arith.mulf
10111012
// CHECK: arith.addf
10121013

1013-
func.func @batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
1014+
func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
10141015
%0 = linalg.batch_matmul indexing_maps = [
10151016
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
10161017
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
10171018
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
10181019
]
1019-
ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
1020-
outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
1021-
return %0 : tensor<?x?x?xf32>
1020+
ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
1021+
outs(%arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32>
1022+
return %0 : tensor<2x3x7xf32>
10221023
}
10231024

10241025
// -----

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1408,7 +1408,7 @@ func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
14081408
// -----
14091409

14101410
func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
1411-
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
1411+
// expected-error @+1 {{'linalg.batch_matmul' op expects 3 dims, but got (2).}}
14121412
linalg.batch_matmul indexing_maps = [
14131413
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
14141414
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,

0 commit comments

Comments
 (0)