Skip to content

Commit db115ba

Browse files
authored
[mlir][Linalg] Fix non-matmul linalg structured ops (#116412)
3ad0148 broke linalg structured ops other than MatmulOp. The patch: - Changes the printer to hide additional attributes, which weren't hidden before: "indexing_maps". - Changes the build of every linalg structured op to have an indexing map for matmul. These changes, combined, hide the problem until you print the operation in its generic form. Reproducer: ```mlir func.func public @bug(%arg0 : tensor<5x10x20xf32>, %arg1 : tensor<5x20x40xf32>, %arg3 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32> { %out = linalg.batch_matmul ins(%arg0, %arg1 : tensor<5x10x20xf32>, tensor<5x20x40xf32>) outs(%arg3 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32> func.return %out : tensor<5x10x40xf32> } ``` Prints fine, with `mlir-opt <file>`, but if you do `mlir-opt --mlir-print-op-generic <file>`: ``` #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> #map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> "builtin.module"() ({ "func.func"() <{function_type = (tensor<5x10x20xf32>, tensor<5x20x40xf32>, tensor<5x10x40xf32>) -> tensor<5x10x40xf32>, sym_name = "bug", sym_visibility = "public"}> ({ ^bb0(%arg0: tensor<5x10x20xf32>, %arg1: tensor<5x20x40xf32>, %arg2: tensor<5x10x40xf32>): %0 = "linalg.batch_matmul"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> ({ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): %1 = "arith.mulf"(%arg3, %arg4) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32 %2 = "arith.addf"(%arg5, %1) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32 "linalg.yield"(%2) : (f32) -> () }) {indexing_maps = [#map, #map1, #map2], linalg.memoized_indexing_maps = [#map3, #map4, #map5]} : (tensor<5x10x20xf32>, tensor<5x20x40xf32>, tensor<5x10x40xf32>) -> tensor<5x10x40xf32> "func.return"(%0) : (tensor<5x10x40xf32>) -> () }) : () -> () }) : () -> () ``` The batch_matmul 
operation's builder now always inserts an indexing_map which is unrelated to the operation itself. This was caught when a transformation from one LinalgStructuredOp to another tried to pass its attributes to the other op's builder and there were multiple indexing_map attributes in the result. This patch fixes this by specializing the builders for MatmulOp with indexing map information.
1 parent 2906fca commit db115ba

File tree

4 files changed

+79
-65
lines changed

4 files changed

+79
-65
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,15 +621,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
621621
(ins "ValueRange":$inputs, "ValueRange":$outputs,
622622
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
623623
[{
624-
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
624+
buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
625625
attributes, MatmulOp::getRegionBuilder());
626626
}]>,
627627
OpBuilder<
628628
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
629629
"ValueRange":$outputs,
630630
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
631631
[{
632-
buildStructuredOp($_builder, $_state, resultTensorTypes,
632+
buildMatmulOp($_builder, $_state, resultTensorTypes,
633633
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
634634
}]>,
635635
OpBuilder<
@@ -647,7 +647,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
647647
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
648648
[{
649649
$_state.addAttribute("cast", cast);
650-
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
650+
buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
651651
attributes, MatmulOp::getRegionBuilder());
652652
}]>
653653

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 67 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
169169
}
170170

171171
/// Wrapper to return the typical indexing map array attribute for MatmulOp.
172-
static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
172+
static SmallVector<Attribute>
173+
getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
173174
return llvm::map_to_vector(
174175
getDefaultIndexingMapsForMatmul(context),
175176
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -179,12 +180,11 @@ static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
179180
/// The result types are derived automatically if `resultTensorTypes` is none.
180181
/// The body of the operation is filled using `regionBuilder`. All ods-gen
181182
/// created structured operations use the method to implement their builders.
182-
static void buildStructuredOp(
183-
OpBuilder &b, OperationState &state,
184-
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
185-
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
186-
RegionBuilderFn regionBuilder,
187-
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
183+
static void buildStructuredOp(OpBuilder &b, OperationState &state,
184+
std::optional<TypeRange> resultTensorTypes,
185+
ValueRange inputs, ValueRange outputs,
186+
ArrayRef<NamedAttribute> attributes,
187+
RegionBuilderFn regionBuilder) {
188188
// Derive the result types if needed.
189189
SmallVector<Type> derivedResultTypes =
190190
resultTensorTypes.value_or(TypeRange());
@@ -196,6 +196,24 @@ static void buildStructuredOp(
196196
state.addOperands(outputs);
197197
state.addTypes(derivedResultTypes);
198198

199+
state.addAttributes(attributes);
200+
state.addAttribute(
201+
"operandSegmentSizes",
202+
b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
203+
static_cast<int32_t>(outputs.size())}));
204+
205+
// Create and fill the region of the structured operation.
206+
Region &region = *state.addRegion();
207+
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
208+
state.attributes.getAttrs(), regionBuilder);
209+
}
210+
211+
static void
212+
buildMatmulOp(OpBuilder &b, OperationState &state,
213+
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
214+
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
215+
RegionBuilderFn regionBuilder,
216+
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
199217
// Initialize indexingMaps, for MatmulOp.
200218
SmallVector<Attribute, 3> indexingMapsAttrVal;
201219
if (indexingMaps.has_value()) {
@@ -205,20 +223,11 @@ static void buildStructuredOp(
205223
}
206224
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
207225
} else {
208-
indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
226+
indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
209227
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
210228
}
211-
212-
state.addAttributes(attributes);
213-
state.addAttribute(
214-
"operandSegmentSizes",
215-
b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
216-
static_cast<int32_t>(outputs.size())}));
217-
218-
// Create and fill the region of the structured operation.
219-
Region &region = *state.addRegion();
220-
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
221-
state.attributes.getAttrs(), regionBuilder);
229+
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
230+
attributes, regionBuilder);
222231
}
223232

224233
/// Common parsing used for both named structured ops created by ods-gen and by
@@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
340349
OperationState &result,
341350
unsigned numRegionArgs,
342351
RegionBuilderFn regionBuilder) {
343-
344-
SmallVector<Attribute, 3> indexingMapsAttr;
345-
Attribute mapAttr;
346-
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
347-
if (parser.parseEqual())
348-
return failure();
349-
350-
if (parser.parseLSquare())
351-
return failure();
352-
353-
do {
354-
if (parser.parseAttribute(mapAttr))
355-
return failure();
356-
if (!isa<AffineMapAttr>(mapAttr)) {
357-
return parser.emitError(parser.getCurrentLocation(),
358-
"expected affine map attribute");
359-
}
360-
indexingMapsAttr.push_back(mapAttr);
361-
362-
if (parser.parseOptionalComma())
363-
break;
364-
} while (true);
365-
366-
if (parser.parseRSquare())
367-
return failure();
368-
}
369-
// Initialize indexingMaps, if not supplied explicitly.
370-
if (indexingMapsAttr.empty()) {
371-
indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
372-
}
373-
result.addAttribute("indexing_maps",
374-
parser.getBuilder().getArrayAttr(indexingMapsAttr));
375-
376352
// TODO: Enable when ods-gen supports captures.
377353
SmallVector<Type, 1> inputTypes, outputTypes;
378354
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
@@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
35033479

35043480
namespace mlir {
35053481
namespace linalg {
3482+
35063483
//===----------------------------------------------------------------------===//
35073484
// MatMulOp
35083485
//===----------------------------------------------------------------------===//
3486+
35093487
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
35103488
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
35113489
utils::IteratorType::parallel,
@@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() {
35203498

35213499
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
35223500

3523-
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
3524-
/// user defined indexing maps are not equal to default map.
3501+
/// Check if the op has broadcast and/or transpose semantic. Returns true if
3502+
/// the user defined indexing maps are not equal to default map.
35253503
bool MatmulOp::hasUserDefinedMaps() {
35263504
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
35273505
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
@@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
35573535
helper.yieldOutputs(yields);
35583536
}
35593537

3560-
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3538+
/// Returns a list of AffineMap with the typical matmul indexing
3539+
/// charactristic.
35613540
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
35623541
MLIRContext *context = this->getContext();
35633542
return getDefaultIndexingMapsForMatmul(context);
@@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
35723551
}
35733552

35743553
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3554+
SmallVector<Attribute, 3> indexingMapsAttr;
3555+
Attribute mapAttr;
3556+
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
3557+
if (parser.parseEqual())
3558+
return failure();
3559+
3560+
if (parser.parseLSquare())
3561+
return failure();
3562+
3563+
do {
3564+
if (parser.parseAttribute(mapAttr))
3565+
return failure();
3566+
if (!isa<AffineMapAttr>(mapAttr)) {
3567+
return parser.emitError(parser.getCurrentLocation(),
3568+
"expected affine map attribute");
3569+
}
3570+
indexingMapsAttr.push_back(mapAttr);
3571+
3572+
if (parser.parseOptionalComma())
3573+
break;
3574+
} while (true);
3575+
3576+
if (parser.parseRSquare())
3577+
return failure();
3578+
}
3579+
// Initialize indexingMaps, if not supplied explicitly.
3580+
if (indexingMapsAttr.empty()) {
3581+
indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
3582+
}
3583+
result.addAttribute("indexing_maps",
3584+
parser.getBuilder().getArrayAttr(indexingMapsAttr));
3585+
35753586
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
35763587
MatmulOp::getRegionBuilder());
35773588
}
@@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) {
35823593
elidedAttrs);
35833594

35843595
SmallVector<Attribute, 3> indexingMaps =
3585-
getDefaultIndexingMapAttr(getContext());
3596+
getDefaultMatmulIndexingMapAttr(getContext());
35863597
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
35873598
p << " indexing_maps = [";
35883599
llvm::interleaveComma(getIndexingMaps(), p,

mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<
4343
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
4444
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
4545
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
46-
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
46+
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
47+
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
4748
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
4849
// CHECK-NEXT: return %[[RES]]
4950
%1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>)
@@ -62,7 +63,8 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
6263
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
6364
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
6465
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
65-
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
66+
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat
67+
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
6668
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
6769
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
6870
// CHECK-NEXT: return %[[RES]]
@@ -113,7 +115,8 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32
113115
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
114116
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
115117
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
116-
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
118+
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
119+
// CHECK-SAME: ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
117120
// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
118121
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
119122
// CHECK-NEXT: return %[[RES]]
@@ -140,7 +143,8 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32
140143
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
141144
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
142145
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
143-
// CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
146+
// CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat
147+
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
144148
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
145149
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
146150
// CHECK-NEXT: return %[[RES]]

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,8 +679,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
679679
}
680680
void {0}::print(OpAsmPrinter &p) {{
681681
SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
682-
"linalg.memoized_indexing_maps",
683-
"indexing_maps"};
682+
"linalg.memoized_indexing_maps"};
684683
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
685684
elidedAttrs);
686685
}

0 commit comments

Comments
 (0)