Skip to content

Commit 85426f6

Browse files
committed
Fix matmul's indexing_maps issue
1 parent 2e6bf66 commit 85426f6

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,22 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
555555
// Op definition for MatmulOp
556556
//===----------------------------------------------------------------------===//
557557

558+
559+
// DONOTMERGE(rolfmorel): explain why the below is necessary
560+
def DefaultValuedMatmulIndexingMapsAttr :
561+
Attr<AffineMapArrayAttr.predicate, AffineMapArrayAttr.summary> {
562+
let storageType = AffineMapArrayAttr.storageType;
563+
let returnType = AffineMapArrayAttr.returnType;
564+
let convertFromStorage = AffineMapArrayAttr.convertFromStorage;
565+
let constBuilderCall = "$_builder.getAffineMapArrayAttr($0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0)";
566+
let defaultValue = "SmallVector<AffineMap>()";
567+
let valueType = AffineMapArrayAttr.valueType;
568+
let isOptional = 1;
569+
570+
let baseAttr = AffineMapArrayAttr;
571+
}
572+
573+
558574
def MatmulOp : LinalgStructuredBase_Op<"matmul", [
559575
AttrSizedOperandSegments,
560576
LinalgContractionOpInterface]> {
@@ -606,7 +622,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
606622
let arguments = (ins
607623
Variadic<AnyType>:$inputs,
608624
Variadic<AnyShaped>:$outputs,
609-
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
625+
DefaultValuedMatmulIndexingMapsAttr:$indexing_maps, // DONOTMERGE(rolfmorel): explain why this is necessary
610626
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
611627
);
612628
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def __init__(
149149
generic = region_op(GenericOp_, terminator=YieldOp)
150150

151151

152+
@register_attribute_builder("DefaultValuedMatmulIndexingMapsAttr")
153+
def _DefaultValuedMatmulIndexingMapsAttr(x, context):
154+
return ArrayAttr.get([AffineMapAttr.get(v) for v in x])
155+
156+
152157
def matmul(
153158
inputs: Sequence[Union[Operation, OpView, Value]],
154159
*,

mlir/test/python/dialects/linalg/ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
337337
)
338338
linalg.fill_builtin_region(res.operation)
339339
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
340-
linalg.matmul(
341-
(Amem, Bmem), outs=(Cmem,)
342-
)
340+
linalg.matmul((Amem, Bmem), outs=(Cmem,))
343341

344342
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
345343
res = linalg.MatmulOp(

0 commit comments

Comments
 (0)