Skip to content

Commit 1f44f8f

Browse files
committed
*Added logic and tests to verify the size of supplied indexing_map attribute.
1 parent 0a16982 commit 1f44f8f

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3542,6 +3542,10 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
35423542
SmallVector<AffineMap, 3> defaultIndexingMaps =
35433543
batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
35443544

3545+
if (opIndexingMaps.size() != 3)
3546+
return batchMatmulOp->emitOpError()
3547+
<< "Indexing_map attribute must have 3 affine maps.";
3548+
35453549
auto opIndexingMap = opIndexingMaps[opIndex];
35463550
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
35473551

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,33 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
12611261

12621262
// -----
12631263

1264+
func.func @indexing_map_size_mismatch_batch_matmul(%arg0: memref<?x?x?xf32>,
1265+
%arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
1266+
// expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
1267+
linalg.batch_matmul indexing_maps = [
1268+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
1269+
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
1270+
]
1271+
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
1272+
outs(%arg2: memref<?x?x?xf32>)
1273+
return
1274+
}
1275+
1276+
// -----
1277+
1278+
func.func @indexing_map_size_one_batch_matmul(%arg0: memref<?x?x?xf32>,
1279+
%arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
1280+
// expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
1281+
linalg.batch_matmul indexing_maps = [
1282+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1283+
]
1284+
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
1285+
outs(%arg2: memref<?x?x?xf32>)
1286+
return
1287+
}
1288+
1289+
// -----
1290+
12641291
func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
12651292
// expected-error @+1 {{expected attribute value}}
12661293
linalg.batch_matmul indexing_maps = [

0 commit comments

Comments
 (0)