Skip to content

Commit abfacb4

Browse files
committed
-Fixed invalid broadcast test.
-Added consistent variable and function naming in test cases. -Improved ops indexing_maps description.
1 parent 7acedcf commit abfacb4

File tree

5 files changed

+158
-144
lines changed

5 files changed

+158
-144
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -691,29 +691,29 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
691691
Example Transpose:
692692
```
693693
linalg.matmul
694-
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
695-
affine_map<(d0, d1, d2) -> (d2, d1)>,
696-
affine_map<(d0, d1, d2) -> (d0, d1)>]
694+
indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
695+
affine_map<(m, n, k) -> (k, n)>,
696+
affine_map<(m, n, k) -> (m, n)>]
697697
ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
698698
outs(%arg2: memref<3x7xf32>)
699699
```
700700

701701
Example Broadcast:
702702
```
703703
linalg.matmul
704-
indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, // broadcast
705-
affine_map<(d0, d1, d2) -> (d2, d1)>,
706-
affine_map<(d0, d1, d2) -> (d0, d1)>]
704+
indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
705+
affine_map<(m, n, k) -> (k, n)>,
706+
affine_map<(m, n, k) -> (m, n)>]
707707
ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
708708
outs(%arg2: memref<3x7xf32>)
709709
```
710710

711711
Example Broadcast and transpose:
712712
```
713713
linalg.matmul
714-
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
715-
affine_map<(d0, d1, d2) -> (d2)>, // broadcast
716-
affine_map<(d0, d1, d2) -> (d0, d1)>]
714+
indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
715+
affine_map<(m, n, k) -> (k)>, // broadcast
716+
affine_map<(m, n, k) -> (m, n)>]
717717
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
718718
outs(%arg2: memref<3x7xf32>)
719719
```
@@ -773,7 +773,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
773773
static void regionBuilder(ImplicitLocOpBuilder &b,
774774
Block &block, ArrayRef<NamedAttribute> attrs);
775775

776-
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
776+
/// Returns a list of AffineMap with the default matmul indexing charactristic.
777777
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
778778

779779
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
@@ -953,29 +953,29 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
953953
Example Transpose:
954954
```
955955
linalg.batch_matmul
956-
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
957-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
958-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
956+
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
957+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
958+
affine_map<(batch, m, n, k) -> (batch, m, n)>]
959959
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
960960
outs(%arg2: memref<2x3x7xf32>)
961961
```
962962

963963
Example Broadcast:
964964
```
965965
linalg.batch_matmul
966-
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
967-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
968-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
966+
indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
967+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
968+
affine_map<(batch, m, n, k) -> (batch, m, n)>]
969969
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
970970
outs(%arg2: memref<2x3x7xf32>)
971971
```
972972

973973
Example Broadcast and Transpose:
974974
```
975975
linalg.batch_matmul
976-
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
977-
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
978-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
976+
indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
977+
affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
978+
affine_map<(batch, m, n, k) -> (batch, m, n)>]
979979
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
980980
outs(%arg2: memref<2x3x7xf32>)
981981
```
@@ -1081,29 +1081,29 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
10811081
Example Transpose:
10821082
```
10831083
linalg.batch_reduce_matmul
1084-
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
1085-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1086-
affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
1084+
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
1085+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
1086+
affine_map<(batch, m, n, k) -> (m, n)>]
10871087
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
10881088
outs(%arg2: memref<3x7xf32>)
10891089
```
10901090

10911091
Example Broadcast:
10921092
```
10931093
linalg.batch_reduce_matmul
1094-
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
1095-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1096-
affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
1094+
indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
1095+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
1096+
affine_map<(batch, m, n, k) -> (m, n)>]
10971097
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
10981098
outs(%arg2: memref<3x7xf32>)
10991099
```
11001100

11011101
Example Broadcast and Transpose:
11021102
```
11031103
linalg.batch_reduce_matmul
1104-
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
1105-
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
1106-
affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
1104+
indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
1105+
affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
1106+
affine_map<(batch, m, n, k) -> (m, n)>]
11071107
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
11081108
outs(%arg2: memref<3x7xf32>)
11091109
```
@@ -1163,7 +1163,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
11631163
static void regionBuilder(ImplicitLocOpBuilder &b,
11641164
Block &block, ArrayRef<NamedAttribute> attrs);
11651165

1166-
/// Returns a list of AffineMap with the typical batch_reducematmul indexing charactristic.
1166+
/// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
11671167
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
11681168

11691169
/// Returns true if the given broadcast map \p bcastMap is valid for this op.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3996,9 +3996,12 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
39963996
} else if (bcastMap.getNumResults() == 2) {
39973997
AffineExpr exp0 = bcastMap.getResult(0);
39983998
AffineExpr exp1 = bcastMap.getResult(1);
3999-
isValid = isLHS
4000-
? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
4001-
: (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
3999+
isValid =
4000+
isLHS
4001+
? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
4002+
exp1.isFunctionOfDim(kPos))
4003+
: ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
4004+
(exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
40024005
}
40034006
return isValid;
40044007
}
@@ -5459,9 +5462,12 @@ bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
54595462
} else if (bcastMap.getNumResults() == 2) {
54605463
AffineExpr exp0 = bcastMap.getResult(0);
54615464
AffineExpr exp1 = bcastMap.getResult(1);
5462-
isValid = isLHS
5463-
? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
5464-
: (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
5465+
isValid =
5466+
isLHS
5467+
? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
5468+
exp1.isFunctionOfDim(kPos))
5469+
: ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
5470+
(exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
54655471
}
54665472
return isValid;
54675473
}

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,24 +1029,24 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
10291029
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
10301030

10311031
// CHECK-LABEL: func.func @batch_reduce_matmul(
1032-
// CHECK-SAME: %[[ARG_A:.*]]: tensor<2x3x5xf32>,
1033-
// CHECK-SAME: %[[ARG_B:.*]]: tensor<2x5x7xf32>,
1034-
// CHECK-SAME: %[[ARG_C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
1032+
// CHECK-SAME: %[[A:.*]]: tensor<2x3x5xf32>,
1033+
// CHECK-SAME: %[[B:.*]]: tensor<2x5x7xf32>,
1034+
// CHECK-SAME: %[[C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
10351035
// CHECK: linalg.generic
10361036
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
10371037
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
10381038
// CHECK: arith.mulf
10391039
// CHECK: arith.addf
10401040
// CHECK: linalg.yield
10411041

1042-
func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
1042+
func.func @batch_reduce_matmul(%A: tensor<2x3x5xf32>, %B: tensor<2x5x7xf32>, %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
10431043
%0 = linalg.batch_reduce_matmul indexing_maps = [
10441044
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
10451045
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
10461046
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
10471047
]
1048-
ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
1049-
outs(%arg2: tensor<3x7xf32>) -> tensor<3x7xf32>
1048+
ins(%A, %B: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
1049+
outs(%C: tensor<3x7xf32>) -> tensor<3x7xf32>
10501050
return %0 : tensor<3x7xf32>
10511051
}
10521052

0 commit comments

Comments
 (0)