Skip to content

Commit 80ea697

Browse files
committed
-Added output map verification and corresponding tests.
-Replaced assert for the count of number of dim expression with proper error reporting and new test case. -Fixed typos.
1 parent 4968fa7 commit 80ea697

File tree

2 files changed

+101
-18
lines changed

2 files changed

+101
-18
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3479,7 +3479,18 @@ static bool isValidBatchDim(AffineMap bcastMap) {
34793479
return exp.isFunctionOfDim(0);
34803480
}
34813481

3482-
/// Verifies the broadcast and transpose semantic sepecified by the explicit
3482+
/// Checks if the given AffineMap's result dimensions are valid output result
3483+
/// dimensions.
3484+
static bool isValidOutputResultDim(AffineMap outputMap) {
3485+
enum Indices { batchPos, mPos, nPos };
3486+
AffineExpr exp0 = outputMap.getResult(batchPos);
3487+
AffineExpr exp1 = outputMap.getResult(mPos);
3488+
AffineExpr exp2 = outputMap.getResult(nPos);
3489+
return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
3490+
exp2.isFunctionOfDim(nPos);
3491+
}
3492+
3493+
/// Verifies the broadcast and transpose semantic specified by the explicit
34833494
/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
34843495
/// opIndex.
34853496
static LogicalResult
@@ -3493,19 +3504,35 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
34933504
auto opIndexingMap = opIndexingMaps[opIndex];
34943505
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
34953506
// Check general validity of indexing map results.
3496-
if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3497-
return batchMatmulOp->emitOpError()
3498-
<< "Unexpected dim expression in map result.";
3499-
// Check if the requested broadcast is valid.
3500-
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3501-
if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) {
3502-
return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3507+
if (opIndex < 2) {
3508+
if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3509+
return batchMatmulOp->emitOpError()
3510+
<< "Unexpected dim expression in map result.";
3511+
// Check if the requested broadcast is valid.
3512+
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3513+
if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap,
3514+
opIndex == 0)) {
3515+
return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3516+
}
3517+
} else {
3518+
// Check for valid number of result dims of input maps.
3519+
if (opIndexingMap.getNumResults() != 3)
3520+
return batchMatmulOp->emitOpError()
3521+
<< "no. of result dim expression cannot exceed 3.";
3522+
3523+
if (!isValidBatchDim(opIndexingMap))
3524+
return batchMatmulOp->emitOpError()
3525+
<< "Invalid batch dimension expression.";
35033526
}
35043527
} else {
3505-
if (!isValidBatchDim(opIndexingMap)) {
3528+
// Check for valid number of result dims of output map.
3529+
if (opIndexingMap.getNumResults() != 3)
35063530
return batchMatmulOp->emitOpError()
3507-
<< "Invalid batch dimension expression.";
3508-
}
3531+
<< "no. of result dim expression cannot exceed 3.";
3532+
3533+
if (!isValidOutputResultDim(opIndexingMap))
3534+
return batchMatmulOp->emitOpError()
3535+
<< "Invalid output map result dimension.";
35093536
}
35103537
return success();
35113538
}
@@ -3910,7 +3937,7 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
39103937

39113938
void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
39123939
ArrayRef<NamedAttribute> attrs) {
3913-
assert(3 > 0 && block.getNumArguments() == 3 &&
3940+
assert(block.getNumArguments() == 3 &&
39143941
"BatchMatmulOp regionBuilder expects 3 (>=0) args");
39153942
RegionBuilderHelper helper(b, block);
39163943
SmallVector<Value> yields;
@@ -3992,7 +4019,7 @@ LogicalResult BatchMatmulOp::verify() {
39924019
if (!hasUserDefinedMaps())
39934020
return success();
39944021

3995-
for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
4022+
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
39964023
if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
39974024
return failure();
39984025
}

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,7 +1261,7 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
12611261

12621262
// -----
12631263

1264-
func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
1264+
func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
12651265
// expected-error @+1 {{expected attribute value}}
12661266
linalg.batch_matmul indexing_maps = [
12671267
,
@@ -1275,27 +1275,27 @@ func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: te
12751275

12761276
// -----
12771277

1278-
func.func @invalid_dim_expr_batch_matmul_a(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
1278+
func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
12791279
// expected-error @+1 {{Unexpected dim expression in map result}}
12801280
linalg.batch_matmul indexing_maps = [
12811281
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
12821282
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
12831283
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
12841284
]
1285-
ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
1285+
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
12861286
return
12871287
}
12881288

12891289
// -----
12901290

1291-
func.func @invalid_dim_expr_batch_matmul_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
1291+
func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
12921292
// expected-error @+1 {{Unexpected dim expression in map result}}
12931293
linalg.batch_matmul indexing_maps = [
12941294
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
12951295
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
12961296
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
12971297
]
1298-
ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
1298+
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
12991299
return
13001300
}
13011301

@@ -1376,3 +1376,59 @@ func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: mem
13761376
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
13771377
return
13781378
}
1379+
1380+
// -----
1381+
1382+
func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
1383+
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
1384+
linalg.batch_matmul indexing_maps = [
1385+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
1386+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1387+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1388+
]
1389+
ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
1390+
outs(%arg2: memref<?x?xf32>)
1391+
return
1392+
}
1393+
1394+
// -----
1395+
1396+
func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
1397+
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
1398+
linalg.batch_matmul indexing_maps = [
1399+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
1400+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
1401+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1402+
]
1403+
ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
1404+
outs(%arg2: memref<?x?xf32>)
1405+
return
1406+
}
1407+
1408+
// -----
1409+
1410+
func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
1411+
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
1412+
linalg.batch_matmul indexing_maps = [
1413+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
1414+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1415+
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1416+
]
1417+
ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
1418+
outs(%arg2: memref<?x?xf32>)
1419+
return
1420+
}
1421+
1422+
// -----
1423+
1424+
func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
1425+
// expected-error @+1 {{'linalg.batch_matmul' op Invalid output map result dimension.}}
1426+
linalg.batch_matmul indexing_maps = [
1427+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
1428+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1429+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1430+
]
1431+
ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
1432+
outs(%arg2: memref<?x?x?xf32>)
1433+
return
1434+
}

0 commit comments

Comments
 (0)