@@ -690,34 +690,32 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
 
     Example Transpose:
     ```mlir
-    linalg.matmul indexing_maps = [
-                      affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
-                      affine_map<(d0, d1, d2) -> (d2, d1)>,
-                      affine_map<(d0, d1, d2) -> (d0, d1)>
-                      ]
-                      ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
-                      outs(%arg2: memref<3x7xf32>)
+    linalg.matmul
+        indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
 
     Example Broadcast:
-    ```mlir
-    linalg.matmul indexing_maps = [
-                      affine_map<(d0, d1, d2) -> (d2)>, // broadcast
-                      affine_map<(d0, d1, d2) -> (d2, d1)>,
-                      affine_map<(d0, d1, d2) -> (d0, d1)>
-                      ]
-                      ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
-                      outs(%arg2: memref<3x7xf32>)
+    ```mlir
+    linalg.matmul
+        indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
 
     Example Broadcast and transpose:
     ```mlir
-    linalg.matmul indexing_maps = [
-                      affine_map<(d0, d1, d2 ) -> (d2, d0 )>, // transpose
-                      affine_map<(d0, d1, d2 ) -> (d2 )>, // broadcast
-                      affine_map<(d0, d1, d2 ) -> (d0, d1)>
-                      ]
-                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    linalg.matmul
+        indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+                         affine_map<(m, n, k) -> (k)>,    // broadcast
+                         affine_map<(m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
   }];
 
@@ -775,7 +773,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     static void regionBuilder(ImplicitLocOpBuilder &b,
                               Block &block, ArrayRef<NamedAttribute> attrs);
 
-    /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+    /// Returns a list of AffineMap with the default matmul indexing characteristic.
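+    /// The defaults describe the canonical matmul: (m, k) for the first input,
+    /// (k, n) for the second input and (m, n) for the output.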
     static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
 
     /// Returns true if the given broadcast map \p bcastMap is valid for this op.
@@ -954,35 +952,32 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 
     Example Transpose:
     ```mlir
-    linalg.batch_matmul indexing_maps = [
-                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
-                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                      ]
-                      ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
-                      outs(%arg2: memref<2x3x7xf32>)
+    linalg.batch_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<2x3x7xf32>)
     ```
 
     Example Broadcast:
     ```mlir
-    linalg.batch_matmul indexing_maps = [
-                      affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
-                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                      ]
-                      ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
-                      outs(%arg2: memref<2x3x7xf32>)
+    linalg.batch_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<2x3x7xf32>)
     ```
 
     Example Broadcast and Transpose:
     ```mlir
-    linalg.batch_matmul indexing_maps = [
-                      affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
-                      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
-                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                      ]
-                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
-                      outs(%arg2: memref<2x3x7xf32>)
+    linalg.batch_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+        outs(%arg2: memref<2x3x7xf32>)
     ```
   }];
 
@@ -1065,6 +1060,134 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 }
 
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+    AttrSizedOperandSegments,
+    LinalgContractionOpInterface]> {
+
+  let summary = [{Performs a batch-reduce matrix multiplication on two inputs.
+                  The partial multiplication results are reduced into a 2D output.}];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit
+    attribute 'indexing_maps' as shown below. This is a list attribute, so it
+    must include maps for all arguments if specified.
+
+    Example Transpose:
+    ```mlir
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```mlir
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```mlir
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+                         affine_map<(batch, m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+        outs(%arg2: memref<3x7xf32>)
+    ```
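+
+    For reference, a minimal form relying on the default indexing maps (no
+    explicit 'indexing_maps' attribute); the operand shapes below are
+    illustrative only:
+    ```mlir
+    linalg.batch_reduce_matmul
+        ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
+    ```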
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    DefaultValuedOptionalAttr<
+      AffineMapArrayAttr,
+      "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+    >:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+                                 attributes, BatchReduceMatmulOp::getRegionBuilder(),
+                                 BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+                                 inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+                                 BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+                                 attributes, BatchReduceMatmulOp::getRegionBuilder(),
+                                 BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMap with the default batch_reduce_matmul indexing characteristic.
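+    /// The defaults describe the canonical batch-reduce matmul: (batch, m, k)
+    /// for the first input, (batch, k, n) for the second input and (m, n) for
+    /// the output.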
+    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
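+    /// For example, affine_map<(batch, m, n, k) -> (k)> is a valid LHS
+    /// broadcast map (see the broadcast example in the op description).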
+    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    static unsigned getNumRegionArgs();
+    std::string getLibraryCallName();
+    bool hasDynamicIndexingMaps() { return true; };
+    /// Returns true if the user-defined indexing maps are not equal to the default maps.
+    bool hasUserDefinedMaps();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//