Skip to content

Commit 5f37e76

Browse files
committed
Address @banach-space's comments
1 parent a676cf8 commit 5f37e76

File tree

11 files changed

+378
-363
lines changed

11 files changed

+378
-363
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
688688
AttrSizedOperandSegments,
689689
LinalgContractionOpInterface]> {
690690
let summary = [{
691-
Perform a contraction on two inputs, accumulating on top of a third.
691+
Perform a contraction on two inputs, accumulating into the third.
692692
}];
693693
let description = [{
694694
The semantics of contracting inputs `A` and `B` on top of `C` to produce

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3535,22 +3535,16 @@ FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
35353535
if (parser.parseOptionalKeyword("indexing_maps"))
35363536
return {nullptr}; // Success in case indexing_maps was not provided.
35373537

3538-
SmallVector<Attribute> indexingMaps;
3539-
3540-
auto parseIndexingMap = [&]() -> ParseResult {
3541-
AffineMapAttr affineMapAttr;
3542-
if (parser.parseAttribute(affineMapAttr))
3543-
return failure();
3544-
indexingMaps.push_back(affineMapAttr);
3545-
return success();
3546-
};
3547-
3548-
if (parser.parseEqual() ||
3549-
parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
3550-
parseIndexingMap))
3538+
ArrayAttr arrayAttr;
3539+
if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
35513540
return failure();
35523541

3553-
return parser.getBuilder().getArrayAttr(indexingMaps);
3542+
if (llvm::any_of(arrayAttr,
3543+
[](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3544+
return parser.emitError(parser.getCurrentLocation())
3545+
<< "element of indexing_maps array is not an affine_map";
3546+
3547+
return arrayAttr;
35543548
}
35553549

35563550
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -3680,7 +3674,7 @@ ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
36803674
FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
36813675
if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
36823676
return parser.emitError(parser.getCurrentLocation(),
3683-
"expected 'indexing_map' attribute");
3677+
"expected 'indexing_maps' attribute");
36843678
result.addAttribute("indexing_maps", *indexingMapsAttr);
36853679

36863680
return parseNamedStructuredOp(parser, result, getNumRegionArgs(),

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 127 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -999,193 +999,206 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
999999

10001000
// -----
10011001

1002-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
1003-
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
1004-
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1005-
1006-
// CHECK-LABEL: func.func @contract_matmul(
1007-
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
1008-
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
1009-
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
1010-
1011-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
1012-
// CHECK-NEXT: ^{{.+}}(
1002+
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
1003+
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
1004+
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1005+
1006+
// CHECK-LABEL: func.func @contract_matmul(
1007+
// CHECK-SAME: %[[A:.*]]: memref<3x5xf32>,
1008+
// CHECK-SAME: %[[B:.*]]: memref<5x7xf32>,
1009+
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
1010+
1011+
// CHECK: linalg.generic
1012+
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1013+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
1014+
// CHECK-NEXT: ^{{.+}}(
10131015
// CHECK-NEXT: arith.mulf
10141016
// CHECK-NEXT: arith.addf
10151017
// CHECK-NEXT: linalg.yield
10161018

10171019
func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
1018-
linalg.contract indexing_maps = [
1019-
affine_map<(d0, d1, d2) -> (d0, d2)>,
1020-
affine_map<(d0, d1, d2) -> (d2, d1)>,
1021-
affine_map<(d0, d1, d2) -> (d0, d1)>
1022-
]
1023-
ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
1024-
outs(%arg2: memref<3x7xf32>)
1020+
linalg.contract
1021+
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
1022+
affine_map<(d0, d1, d2) -> (d2, d1)>,
1023+
affine_map<(d0, d1, d2) -> (d0, d1)>]
1024+
ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
1025+
outs(%arg2: memref<3x7xf32>)
10251026

10261027
return
10271028
}
10281029

10291030
// -----
10301031

1031-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
1032-
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
1033-
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1032+
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
1033+
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
1034+
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
10341035

1035-
// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
1036-
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
1037-
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
1038-
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
1036+
// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
1037+
// CHECK-SAME: %[[A:.*]]: memref<5x3xf32>,
1038+
// CHECK-SAME: %[[B:.*]]: memref<7x5xf32>,
1039+
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
10391040

1040-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
1041-
// CHECK-NEXT: ^{{.+}}(
1041+
// CHECK: linalg.generic
1042+
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1043+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
1044+
// CHECK-NEXT: ^{{.+}}(
10421045
// CHECK-NEXT: arith.mulf
10431046
// CHECK-NEXT: arith.addf
10441047
// CHECK-NEXT: linalg.yield
10451048

10461049
func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
1047-
linalg.contract indexing_maps = [
1048-
affine_map<(d0, d1, d2) -> (d2, d0)>,
1049-
affine_map<(d0, d1, d2) -> (d1, d2)>,
1050-
affine_map<(d0, d1, d2) -> (d0, d1)>
1051-
]
1052-
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
1053-
outs(%arg2: memref<3x7xf32>)
1054-
1050+
linalg.contract
1051+
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
1052+
affine_map<(d0, d1, d2) -> (d1, d2)>,
1053+
affine_map<(d0, d1, d2) -> (d0, d1)>]
1054+
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
1055+
outs(%arg2: memref<3x7xf32>)
10551056
return
10561057
}
10571058

10581059
// -----
10591060

1060-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1061-
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
1062-
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1061+
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1062+
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
1063+
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
10631064

1064-
// CHECK-LABEL: func.func @contract_batch_matmul(
1065-
// CHECK-SAME: %[[VAL_0:.*]]: memref<9x3x5xf32>,
1066-
// CHECK-SAME: %[[VAL_1:.*]]: memref<9x5x7xf32>,
1067-
// CHECK-SAME: %[[VAL_2:.*]]: memref<9x3x7xf32>) {
1065+
// CHECK-LABEL: func.func @contract_batch_matmul(
1066+
// CHECK-SAME: %[[A:.*]]: memref<9x3x5xf32>,
1067+
// CHECK-SAME: %[[B:.*]]: memref<9x5x7xf32>,
1068+
// CHECK-SAME: %[[C:.*]]: memref<9x3x7xf32>) {
10681069

1069-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
1070-
// CHECK-NEXT: ^{{.+}}(
1070+
// CHECK: linalg.generic
1071+
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1072+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
1073+
// CHECK-NEXT: ^{{.+}}(
10711074
// CHECK-NEXT: arith.mulf
10721075
// CHECK-NEXT: arith.addf
10731076
// CHECK-NEXT: linalg.yield
10741077

10751078
func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
1076-
linalg.contract indexing_maps = [
1077-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
1078-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1079-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1080-
]
1081-
ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
1082-
outs(%arg2: memref<9x3x7xf32>)
1083-
1079+
linalg.contract
1080+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
1081+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1082+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
1083+
ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
1084+
outs(%arg2: memref<9x3x7xf32>)
10841085
return
10851086
}
10861087

10871088
// -----
10881089

1089-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1090-
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
1091-
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1090+
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1091+
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
1092+
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
10921093

1093-
// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
1094-
// CHECK-SAME: %[[VAL_0:.*]]: memref<9x3x5xf32>,
1095-
// CHECK-SAME: %[[VAL_1:.*]]: memref<9x5x7xf32>,
1096-
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
1094+
// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
1095+
// CHECK-SAME: %[[A:.*]]: memref<9x3x5xf32>,
1096+
// CHECK-SAME: %[[B:.*]]: memref<9x5x7xf32>,
1097+
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
10971098

1098-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
1099-
// CHECK-NEXT: ^{{.+}}(
1099+
// CHECK: linalg.generic
1100+
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1101+
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]
1102+
// CHECK-NEXT: ^{{.+}}(
11001103
// CHECK-NEXT: arith.mulf
11011104
// CHECK-NEXT: arith.addf
11021105
// CHECK-NEXT: linalg.yield
11031106

1104-
func.func @contract_batch_reduce_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<3x7xf32>) {
1105-
linalg.contract indexing_maps = [
1106-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
1107-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
1108-
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1109-
]
1110-
ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
1111-
outs(%arg2: memref<3x7xf32>)
1107+
#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
1108+
#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
1109+
#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1110+
func.func @contract_batch_reduce_matmul(
1111+
%A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<3x7xf32>) {
1112+
linalg.contract
1113+
indexing_maps = [#accessA, #accessB, #accessC]
1114+
ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
1115+
outs(%C: memref<3x7xf32>)
11121116
return
11131117
}
11141118

11151119
// -----
11161120

1117-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
1118-
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
1119-
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1121+
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
1122+
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
1123+
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
11201124

1121-
// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
1122-
// CHECK-SAME: %[[VAL_0:.*]]: memref<9x5x3xf32>,
1123-
// CHECK-SAME: %[[VAL_1:.*]]: memref<9x7x5xf32>,
1124-
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
1125+
// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
1126+
// CHECK-SAME: %[[A:.*]]: memref<9x5x3xf32>,
1127+
// CHECK-SAME: %[[B:.*]]: memref<9x7x5xf32>,
1128+
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
11251129

1126-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
1127-
// CHECK-NEXT: ^{{.+}}(
1130+
// CHECK: linalg.generic
1131+
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
1132+
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]
1133+
// CHECK-NEXT: ^{{.+}}(
11281134
// CHECK-NEXT: arith.mulf
11291135
// CHECK-NEXT: arith.addf
11301136
// CHECK-NEXT: linalg.yield
11311137

1132-
func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(%arg0: memref<9x5x3xf32>, %arg1: memref<9x7x5xf32>, %arg2: memref<3x7xf32>) {
1133-
linalg.contract indexing_maps = [
1134-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
1135-
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
1136-
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1137-
]
1138-
ins(%arg0, %arg1 : memref<9x5x3xf32>, memref<9x7x5xf32>)
1139-
outs(%arg2: memref<3x7xf32>)
1138+
#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
1139+
#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
1140+
#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
1141+
func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
1142+
%A: memref<9x5x3xf32>, %B: memref<9x7x5xf32>, %C: memref<3x7xf32>) {
1143+
linalg.contract
1144+
indexing_maps = [#accessA, #accessB, #accessC]
1145+
ins(%A, %B : memref<9x5x3xf32>, memref<9x7x5xf32>)
1146+
outs(%C: memref<3x7xf32>)
11401147
return
11411148
}
11421149

11431150
// -----
11441151

1145-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
1146-
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> ()>
1152+
// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0) -> (d0)>
1153+
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0) -> ()>
11471154

1148-
// CHECK-LABEL: func.func @contract_dot
1149-
// CHECK-SAME: (%[[VAL_0:.*]]: memref<9xf32>, %[[VAL_1:.*]]: memref<9xf32>, %[[VAL_2:.*]]: memref<f32>) {
1155+
// CHECK-LABEL: func.func @contract_dot(
1156+
// CHECK-SAME: %[[A:.*]]: memref<9xf32>, %[[B:.*]]: memref<9xf32>,
1157+
// CHECK-SAME: %[[C:.*]]: memref<f32>) {
11501158

1151-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["reduction"]}
1152-
// CHECK-NEXT: ^{{.+}}(
1159+
// CHECK: linalg.generic
1160+
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
1161+
// CHECK-SAME: iterator_types = ["reduction"]
1162+
// CHECK-NEXT: ^{{.+}}(
11531163
// CHECK-NEXT: arith.mulf
11541164
// CHECK-NEXT: arith.addf
11551165
// CHECK-NEXT: linalg.yield
11561166

1157-
func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memref<f32>) {
1158-
linalg.contract indexing_maps = [
1159-
affine_map<(d0) -> (d0)>,
1160-
affine_map<(d0) -> (d0)>,
1161-
affine_map<(d0) -> ()>
1162-
]
1163-
ins(%arg0, %arg1 : memref<9xf32>, memref<9xf32>)
1164-
outs(%arg2: memref<f32>)
1167+
#accessAB = affine_map<(d0) -> (d0)>
1168+
#accessC = affine_map<(d0) -> ()>
1169+
func.func @contract_dot(
1170+
%A: memref<9xf32>, %B: memref<9xf32>, %C: memref<f32>) {
1171+
linalg.contract
1172+
indexing_maps = [#accessAB, #accessAB, #accessC]
1173+
ins(%A, %B : memref<9xf32>, memref<9xf32>)
1174+
outs(%C: memref<f32>)
11651175
return
11661176
}
11671177

11681178
// -----
11691179

1170-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1171-
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
1180+
// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
1181+
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
11721182

1173-
// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
1174-
// CHECK-SAME: (%[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, %[[VAL_2:.*]]: memref<3x7xf32>) {
1183+
// CHECK-LABEL: func.func @contract_matmul_bcast_a_b(
1184+
// CHECK-SAME: %[[A:.*]]: memref<5xf32>, %[[B:.*]]: memref<5xf32>,
1185+
// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
11751186

1176-
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
1177-
// CHECK-NEXT: ^{{.+}}(
1187+
// CHECK: linalg.generic
1188+
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
1189+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
1190+
// CHECK-NEXT: ^{{.+}}(
11781191
// CHECK-NEXT: arith.mulf
11791192
// CHECK-NEXT: arith.addf
11801193
// CHECK-NEXT: linalg.yield
11811194

1182-
func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
1183-
linalg.contract indexing_maps = [
1184-
affine_map<(d0, d1, d2) -> (d2)>,
1185-
affine_map<(d0, d1, d2) -> (d2)>,
1186-
affine_map<(d0, d1, d2) -> (d0, d1)>
1187-
]
1188-
ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>)
1189-
outs(%arg2: memref<3x7xf32>)
1195+
#accessAB = affine_map<(d0, d1, d2) -> (d2)>
1196+
#accessC = affine_map<(d0, d1, d2) -> (d0, d1)>
1197+
func.func @contract_matmul_bcast_a_b(
1198+
%A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
1199+
linalg.contract
1200+
indexing_maps = [#accessAB, #accessAB, #accessC]
1201+
ins(%A, %B : memref<5xf32>, memref<5xf32>)
1202+
outs(%C: memref<3x7xf32>)
11901203
return
11911204
}

0 commit comments

Comments
 (0)