diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index b44af2defc3e4..6344861c53ac5 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: BZp --- !LinalgOpConfig -metadata: !LinalgOpMetadata - name: batch_reduce_matmul - cpp_class_name: BatchReduceMatmulOp - doc: |- - Performs a batch-reduce matrix multiplication of two 3D inputs. - The partial multiplication results are reduced into a 2D output. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - implements: - - LinalgContractionOpInterface -structured_op: !LinalgStructuredOpConfig - args: - - !LinalgOperandDefConfig - name: A - kind: input_tensor - type_var: T1 - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> - - !LinalgOperandDefConfig - name: B - kind: input_tensor - type_var: T2 - shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> - - !LinalgOperandDefConfig - name: C - kind: output_tensor - type_var: U - shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)> - indexing_maps: !LinalgIndexingMapsConfig - static_indexing_maps: - - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)> - - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)> - - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)> - iterator_types: - - reduction - - parallel - - parallel - - reduction - assignments: - - !ScalarAssign - arg: C - value: !ScalarExpression - scalar_fn: - kind: binary - fn_name: add - operands: - - !ScalarExpression - scalar_arg: C - - !ScalarExpression - scalar_fn: - kind: binary - fn_name: mul - operands: - - !ScalarExpression - scalar_fn: - kind: type - fn_name: cast_signed - type_var: U - operands: - - !ScalarExpression - scalar_arg: A - - !ScalarExpression - scalar_fn: - kind: type - fn_name: cast_signed - type_var: U - operands: - - !ScalarExpression - scalar_arg: B ---- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec cpp_class_name: MatvecOp diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index f3dbeb274deda..61783812920bc 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -690,34 +690,32 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ Example Transpose: ```mlir - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>) - outs(%arg2: memref<3x7xf32>) + linalg.matmul + indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)>] + ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>) + outs(%arg2: memref<3x7xf32>) ``` Example Broadcast: - ```mlir - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, // broadcast - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) - outs(%arg2: memref<3x7xf32>) + ```mlir + linalg.matmul + indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast + affine_map<(m, n, k) -> (k, n)>, + 
affine_map<(m, n, k) -> (m, n)>] + ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) + outs(%arg2: memref<3x7xf32>) ``` Example Broadcast and transpose: ```mlir - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose - affine_map<(d0, d1, d2) -> (d2)>, // broadcast - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) + linalg.matmul + indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose + affine_map<(m, n, k) -> (k)>, // broadcast + affine_map<(m, n, k) -> (m, n)>] + ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) + outs(%arg2: memref<3x7xf32>) ``` }]; @@ -775,7 +773,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs); - /// Returns a list of AffineMap with the typical matmul indexing charactristic. + /// Returns a list of AffineMap with the default matmul indexing charactristic. static SmallVector getDefaultIndexingMaps(MLIRContext *context); /// Returns true if the given broadcast map \p bcastMap is valid for this op. @@ -954,35 +952,32 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz Example Transpose: ```mlir - linalg.batch_matmul indexing_maps = [ - affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose - affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, - affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - ] - ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) - outs(%arg2: memref<2x3x7xf32>) + linalg.batch_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (batch, m, n)>] + ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) + outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ```mlir - linalg.batch_matmul indexing_maps = [ - affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast - affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, - affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - ] - ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) - outs(%arg2: memref<2x3x7xf32>) + linalg.batch_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (batch, m, n)>] + ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) + outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast and Transpose: ```mlir - linalg.batch_matmul indexing_maps = [ - affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast - affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose - affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) - outs(%arg2: memref<2x3x7xf32>) + linalg.batch_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast + affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose + affine_map<(batch, m, n, k) -> (batch, m, n)>] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) + outs(%arg2: memref<2x3x7xf32>) ``` }]; @@ -1065,6 +1060,134 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz } +//===----------------------------------------------------------------------===// +// Op definition for BatchReduceMatmulOp +//===----------------------------------------------------------------------===// + +def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [ + AttrSizedOperandSegments, + LinalgContractionOpInterface]> { + + let summary = 
[{Performs a batch-reduce matrix multiplication on two inputs. + The partial multiplication results are reduced into a 2D output.}]; + let description = [{ + Numeric casting is performed on the operands to the inner multiply, + promoting them to the same data type as the accumulator/output. + + Broadcast and Transpose semantics can be applied by specifying the explicit attribute + 'indexing_maps' as shown below. This is a list attribute, so must include maps for all + arguments if specified. + + Example Transpose: + ```mlir + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) + outs(%arg2: memref<3x7xf32>) + ``` + + Example Broadcast: + ```mlir + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) + outs(%arg2: memref<3x7xf32>) + ``` + + Example Broadcast and Transpose: + ```mlir + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast + affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) + outs(%arg2: memref<3x7xf32>) + ``` + }]; + + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, + DefaultValuedOptionalAttr< + AffineMapArrayAttr, + "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())" + >:$indexing_maps, + DefaultValuedOptionalAttr:$cast + ); + let results = (outs Variadic:$result_tensors); + let regions = (region AnyRegion:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder< + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs, + attributes, BatchReduceMatmulOp::getRegionBuilder(), + BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, + inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(), + BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, + "Attribute":$cast, CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addAttribute("cast", cast); + buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs, + attributes, BatchReduceMatmulOp::getRegionBuilder(), + BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())); + }]> + + ]; + let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let hasVerifier = 1; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + SmallVector getIteratorTypesArray(); + + /// Implements the block region builder. + static void regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs); + + /// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic. + static SmallVector getDefaultIndexingMaps(MLIRContext *context); + + /// Returns true if the given broadcast map \p bcastMap is valid for this op. 
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true); + + static std::function)> + getRegionBuilder() { + return regionBuilder; + } + + ::mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputsMutable(); + } + + // Generic methods. + static unsigned getNumRegionArgs(); + std::string getLibraryCallName(); + bool hasDynamicIndexingMaps() { return true; }; + /// Returns true if the user defined indexing maps are not equal to default maps. + bool hasUserDefinedMaps(); + }]; +} + //===----------------------------------------------------------------------===// // Named Linalg ops, implemented as a declarative configurations of generic ops. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 089ccc6680e48..cee51730bd743 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -220,6 +220,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, attributes, regionBuilder); } +static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, + std::optional resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef attributes, + RegionBuilderFn regionBuilder, + ArrayRef indexingMaps) { + // Initialize indexingMaps attribute, for BatchReduceMatmulOp. + SmallVector indexingMapsAttrVal; + indexingMapsAttrVal = + llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + }); + state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); + return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, + attributes, regionBuilder); +} + /// Common parsing used for both named structured ops created by ods-gen and by /// manually defined C++ ops. Does not handle regions. static ParseResult @@ -3485,19 +3502,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, return success(); } -// Check general validity of input indexing map. -static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, +// Check general validity of input indexing map of +// BatchMatmulOp/BatchReduceMatmulOp. +template +static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS) { + assert((isa(batchVariantMatmulOp) || + isa(batchVariantMatmulOp)) && + "Expected BatchMatmulOp or BatchReduceMatmulOp"); // Check the result dims are valid. if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap)) - return batchMatmulOp->emitOpError() + return batchVariantMatmulOp->emitOpError() << "Unexpected result dim expression (outside the set of default " "result dims)."; // Check for valid number of result dims of input maps. if (opIndexingMap.getNumResults() > 3) - return batchMatmulOp->emitOpError() + return batchVariantMatmulOp->emitOpError() << "no. of result dim expressions exceeds 3."; auto hasValidBatchDim = [](AffineMap map) { @@ -3507,60 +3529,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, // Check if the requested broadcast is valid. 
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { - if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS)) - return batchMatmulOp->emitOpError() << "Invalid broadcast requested."; + if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS)) + return batchVariantMatmulOp->emitOpError() + << "Invalid broadcast requested."; } else if (!hasValidBatchDim(opIndexingMap)) { - return batchMatmulOp->emitOpError() + return batchVariantMatmulOp->emitOpError() << "Invalid batch dimension expression."; } return success(); } /// This function checks if the given AffineMap for the output of a -/// BatchMatmulOp has exactly 3 result dimensions and if the output map result -/// dimensions are valid. -static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, +/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result +/// dimensions and if the output map result dimensions are valid. +template +static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap) { - if (opIndexingMap.getNumResults() != 3) - return batchMatmulOp->emitOpError() + assert((isa(batchVariantMatmulOp) || + isa(batchVariantMatmulOp)) && + "Expected BatchMatmulOp or BatchReduceMatmulOp"); + if (isa(batchVariantMatmulOp) && + opIndexingMap.getNumResults() != 3) { + + return batchVariantMatmulOp->emitOpError() << "expects 3 dims, but got (" << opIndexingMap.getNumResults() << ")."; + } + if (isa(batchVariantMatmulOp) && + opIndexingMap.getNumResults() != 2) { + return batchVariantMatmulOp->emitOpError() + << "expects 2 dims, but got (" << opIndexingMap.getNumResults() + << ")."; + } - auto areValidOutputResultDim = [](AffineMap outputMap) { - return outputMap.getResult(0).isFunctionOfDim(0) && - outputMap.getResult(1).isFunctionOfDim(1) && - outputMap.getResult(2).isFunctionOfDim(2); + auto areValidOutputResultDim = [&](AffineMap outputMap) { + return isa(batchVariantMatmulOp) + ? outputMap.getResult(0).isFunctionOfDim(0) && + outputMap.getResult(1).isFunctionOfDim(1) && + outputMap.getResult(2).isFunctionOfDim(2) + : outputMap.getResult(0).isFunctionOfDim(1) && + outputMap.getResult(1).isFunctionOfDim(2); }; - if (!areValidOutputResultDim(opIndexingMap)) - return batchMatmulOp->emitOpError() + if (!areValidOutputResultDim(opIndexingMap)) { + return batchVariantMatmulOp->emitOpError() << "Invalid output map result dimension."; + } return success(); } /// Verifies the broadcast and transpose semantic specified by the explicit -/// indexing map for the BatchMatmulOp op for each operand specified by opIndex. +/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand +/// specified by opIndex. 
+template static LogicalResult -verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, - unsigned opIndex) { +verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, + unsigned opIndex) { SmallVector opIndexingMaps = - batchMatmulOp.getIndexingMapsArray(); + batchVariantMatmulOp.getIndexingMapsArray(); SmallVector defaultIndexingMaps = - batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext()); + batchVariantMatmulOp.getDefaultIndexingMaps( + batchVariantMatmulOp->getContext()); if (opIndexingMaps.size() != 3) - return batchMatmulOp->emitOpError() + return batchVariantMatmulOp->emitOpError() << "Indexing_map attribute must have 3 affine maps."; auto opIndexingMap = opIndexingMaps[opIndex]; auto defaultIndexingMap = defaultIndexingMaps[opIndex]; - if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap))) + if (opIndex == 2 && + failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap))) return failure(); - if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap, - opIndex == 0))) + if (opIndex != 2 && + failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap, + defaultIndexingMap, opIndex == 0))) return failure(); return success(); @@ -3636,12 +3681,18 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, helper.yieldOutputs(yields); } -/// Returns true if the given broadcast map \p bcastMap is valid for this op. +/// Returns true if the given bcastMap map is a valid broadcast map. A valid +/// broadcast map must include K dimension. +/// TODO: Strict inclusion of K dimension in the broadcast map is not +/// necessary for both input matrices simultaneously. We can relax this +/// condition to have K dimension for one input matrix map and infer the K +/// dimension for other input matrix map from the one already having K +/// dimension. bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) { assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr."); - AffineExpr exp = bcastMap.getResult(0); + AffineExpr expr = bcastMap.getResult(0); // Invalid map if the common dimension of matmul not found. - return exp.isFunctionOfDim(bcastMap.getNumDims() - 1); + return expr.isFunctionOfDim(bcastMap.getNumDims() - 1); } FailureOr parseIndexingMapsAttr(OpAsmParser &parser) { @@ -3939,21 +3990,31 @@ bool BatchMatmulOp::hasUserDefinedMaps() { return defaultMaps != explicitMaps; } -/// Returns true if the given broadcast map bcastMap is valid for this op. +/// Returns true if the given bcastMap map is a valid broadcast map. A valid +/// broadcast map must include K dimension. +/// TODO: Strict inclusion of K dimension in the broadcast map is not +/// necessary for both input matrices simultaneously. We can relax this +/// condition to have K dimension for one input matrix map and infer the K +/// dimension for other input matrix map from the one already having K +/// dimension. bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) { assert(bcastMap.getNumResults() < 3 && "Expected less than 3 result dim expr."); bool isValid = false; enum Indices { batchPos, mPos, nPos, kPos }; if (bcastMap.getNumResults() == 1) { - AffineExpr exp = bcastMap.getResult(0); - isValid = exp.isFunctionOfDim(kPos); + AffineExpr expr = bcastMap.getResult(0); + isValid = expr.isFunctionOfDim(kPos); } else if (bcastMap.getNumResults() == 2) { - AffineExpr exp0 = bcastMap.getResult(0); - AffineExpr exp1 = bcastMap.getResult(1); - isValid = isLHS - ? 
(exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos)) - : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)); + AffineExpr expr0 = bcastMap.getResult(0); + AffineExpr expr1 = bcastMap.getResult(1); + isValid = + isLHS ? ((expr0.isFunctionOfDim(batchPos) || + expr0.isFunctionOfDim(mPos)) && + expr1.isFunctionOfDim(kPos)) + : ((expr0.isFunctionOfDim(batchPos) && + expr1.isFunctionOfDim(kPos)) || + (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos))); } return isValid; } @@ -4045,7 +4106,7 @@ LogicalResult BatchMatmulOp::verify() { return success(); for (unsigned opIndex = 0; opIndex < 3; opIndex++) { - if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex))) + if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex))) return failure(); } return success(); @@ -5366,6 +5427,176 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// BatchReduceMatmulOp +//===----------------------------------------------------------------------===// +SmallVector BatchReduceMatmulOp::getIteratorTypesArray() { + return SmallVector{ + utils::IteratorType::reduction, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction}; +} + +SmallVector +BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) { + AffineExpr d0, d1, d2, d3; + SmallVector indexingMaps; + bindDims(context, d0, d1, d2, d3); + indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context)); + indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context)); + indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context)); + return indexingMaps; +} + +unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; } + +std::string BatchReduceMatmulOp::getLibraryCallName() { + return generateLibraryCallName(getOperation()); +} + +/// Check if the op has broadcast and/or transpose semantic. Returns true if +/// the user defined indexing maps are not equal to default map. +bool BatchReduceMatmulOp::hasUserDefinedMaps() { + SmallVector defaultMaps = + getDefaultIndexingMaps(this->getContext()); + SmallVector explicitMaps = getIndexingMapsArray(); + return defaultMaps != explicitMaps; +} + +/// Returns true if the given bcastMap map is a valid broadcast map. A valid +/// broadcast map must include K dimension. +/// TODO: Strict inclusion of K dimension in the broadcast map is not +/// necessary for both input matrices simultaneously. We can relax this +/// condition to have K dimension for one input matrix map and infer the K +/// dimension for other input matrix map from the one already having K +/// dimension. +bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, + bool isLHS) { + assert(bcastMap.getNumResults() < 3 && + "Expected less than 3 result dim expr."); + bool isValid = false; + enum Indices { batchPos, mPos, nPos, kPos }; + if (bcastMap.getNumResults() == 1) { + AffineExpr expr = bcastMap.getResult(0); + isValid = expr.isFunctionOfDim(kPos); + } else if (bcastMap.getNumResults() == 2) { + AffineExpr expr0 = bcastMap.getResult(0); + AffineExpr expr1 = bcastMap.getResult(1); + isValid = + isLHS ? 
((expr0.isFunctionOfDim(batchPos) || + expr0.isFunctionOfDim(mPos)) && + expr1.isFunctionOfDim(kPos)) + : ((expr0.isFunctionOfDim(batchPos) && + expr1.isFunctionOfDim(kPos)) || + (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos))); + } + return isValid; +} + +void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ArrayRef attrs) { + assert(block.getNumArguments() == 3 && + "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args"); + RegionBuilderHelper helper(b, block); + SmallVector yields; + + auto toType = block.getArgument(2).getType(); + Value castValA = + helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0)); + Value castValB = + helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1)); + Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB); + Value addVal = + helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal); + yields.push_back(addVal); + helper.yieldOutputs(yields); +} + +ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector indexingMapsAttr; + Attribute mapAttr; + if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) { + if (parser.parseEqual()) + return failure(); + if (parser.parseLSquare()) + return failure(); + + do { + if (parser.parseAttribute(mapAttr)) + return failure(); + if (!isa(mapAttr)) { + return parser.emitError(parser.getCurrentLocation(), + "expected affine map attribute"); + } + indexingMapsAttr.push_back(mapAttr); + + if (parser.parseOptionalComma()) + break; + } while (true); + + if (parser.parseRSquare()) + return failure(); + } + // Initialize indexingMaps, if not supplied explicitly. + if (indexingMapsAttr.empty()) { + indexingMapsAttr = llvm::map_to_vector( + BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + } + result.addAttribute("indexing_maps", + parser.getBuilder().getArrayAttr(indexingMapsAttr)); + return ::parseNamedStructuredOp(parser, result, + BatchReduceMatmulOp::getNumRegionArgs(), + BatchReduceMatmulOp::getRegionBuilder()); +} + +void BatchReduceMatmulOp::print(OpAsmPrinter &p) { + SmallVector indexingMaps = llvm::map_to_vector( + BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + + if (!llvm::equal(getIndexingMaps(), indexingMaps)) { + p << " indexing_maps = ["; + llvm::interleaveComma(getIndexingMaps(), p, + [&](Attribute attr) { p.printAttribute(attr); }); + p << "]"; + } + + SmallVector elidedAttrs = { + "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; + ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), + elidedAttrs); +} + +/// Verify the user defined indexing maps. +LogicalResult BatchReduceMatmulOp::verify() { + // Verification of pure batch_reduce_matmul is handled by + // verifyStructuredOpInterface(). 
+ if (!hasUserDefinedMaps()) + return success(); + + for (unsigned opIndex = 0; opIndex < 3; opIndex++) { + if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex))) + return failure(); + } + return success(); +} +LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor, + SmallVectorImpl &) { + return memref::foldMemRefCast(*this); +} +void BatchReduceMatmulOp::getEffects( + SmallVectorImpl> + &effects) { + if (hasPureTensorSemantics()) + return; + getGenericEffectsImpl(effects, cast(getOperation())); +} + +Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() { + return getGenericSpeculatabilityImpl(cast(getOperation())); +} + } // namespace linalg } // namespace mlir diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 63586a5bb8bbb..9a8a7b40e25e4 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -203,6 +203,19 @@ def batch_matmul( ) +def batch_reduce_matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return _get_op_result_or_op_results( + _create_matmul_like_op( + BatchReduceMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) + + def contract( *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index 0ec71c35497b1..ae07b1b82228c 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -1024,6 +1024,34 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg // ----- +// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @batch_reduce_matmul( +// CHECK-SAME: %[[A:.*]]: tensor<2x3x5xf32>, +// CHECK-SAME: %[[B:.*]]: tensor<2x5x7xf32>, +// CHECK-SAME: %[[C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> { +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]], +// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]} +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield + +func.func @batch_reduce_matmul(%A: tensor<2x3x5xf32>, %B: tensor<2x5x7xf32>, %C: tensor<3x7xf32>) -> tensor<3x7xf32> { + %0 = linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B: tensor<2x3x5xf32>, tensor<2x5x7xf32>) + outs(%C: tensor<3x7xf32>) -> tensor<3x7xf32> + return %0 : tensor<3x7xf32> +} + +// ----- + // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> // CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 90ceadebbc1fa..04c59777d9d7a 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1364,10 +1364,10 @@ func.func @invalid_bcast_batch_matmul_a(%arg0: memref, %arg1: memref, %arg1: 
memref, %arg2: memref) { +func.func @invalid_single_dim_bcast_expr_batch_matmul_a(%arg0: memref, %arg1: memref, %arg2: memref) { // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}} linalg.batch_matmul indexing_maps = [ - affine_map<(d0, d1, d2, d3) -> (d0, d3)>, + affine_map<(d0, d1, d2, d3) -> (d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] @@ -1377,14 +1377,14 @@ func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref, % // ----- -func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref, %arg1: memref, %arg2: memref) { +func.func @invalid_single_dim_bcast_expr_batch_matmul_B(%A: memref, %B: memref, %C: memref) { // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}} linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] - ins(%arg0, %arg1 : memref, memref) outs(%arg2: memref) + ins(%A, %B : memref, memref) outs(%C: memref) return } @@ -1484,6 +1484,205 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref, %arg1 } +// ----- + +//===----------------------------------------------------------------------===// +// linalg.batch_reduce_matmul +//===----------------------------------------------------------------------===// + +func.func @missing_one_indexing_map(%arg0: memref, + %arg1: memref, %arg2: memref) { + // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (batch, n, k)>] + ins(%arg0, %arg1 : memref, memref) + outs(%arg2: memref) + return +} + +// ----- + +func.func @missing_two_indexing_map(%arg0: memref, + %arg1: memref, %arg2: memref) { + // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>] + ins(%arg0, %arg1 : memref, memref) + outs(%arg2: memref) + return + +} + +// ----- + +func.func @missing_indexing_map(%arg0: memref, %arg1: memref, %arg2: memref) { + // expected-error @+1 {{expected attribute value}} + linalg.batch_reduce_matmul indexing_maps = [ + , + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 :memref) + return +} + +// ----- + +func.func @invalid_dim_expr_A(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, n, k)>, + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B : memref, memref) + outs(%C :memref) + return +} + +// ----- + +func.func @invalid_dim_expr_B(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (batch, k, m)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B : memref, memref) + outs(%C :memref) + return +} + +// ----- + +func.func @invalid_bcast_A(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Invalid broadcast requested}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, 
k) -> (batch)>, + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B : memref, memref) + outs(%C: memref) + return +} + +// ----- + +func.func @invalid_multi_dim_bcast_expr_A(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Invalid broadcast requested}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (k, batch)>, + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B : memref, memref) + outs(%C: memref) + return +} + +// ----- + +func.func @invalid_multi_dim_bcast_expr_B(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Invalid broadcast requested}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (k, batch)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B : memref, memref) + outs(%C: memref) + return +} + +// ----- + +func.func @invalid_bcast_B(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Invalid broadcast requested}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (n)>, + affine_map<(batch, m, n, k) -> (batch, m, n)>] + ins(%A, %B : memref, memref) + outs(%C: memref) + return +} + +// ----- + +func.func @invalid_batch_dim_A(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Invalid batch dimension expression}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (m, batch, k)>, + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B : memref, memref) + outs(%C :memref) + return +} + +// ----- + +func.func @invalid_batch_dim_B(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Invalid batch dimension expression}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (n, k, batch)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B : memref, memref) + outs(%C :memref) + return +} + +// ----- + +func.func @invalid_A_map_result_num(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{no. of result dim expressions exceeds 3.}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k, k)>, + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B: memref, memref) + outs(%C: memref) + return +} + +// ----- + +func.func @invalid_B_map_result_num(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{no. 
of result dim expressions exceeds 3.}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (batch, k, n, k)>, + affine_map<(batch, m, n, k) -> (m, n)>] + ins(%A, %B: memref, memref) + outs(%C: memref) + return +} + +// ----- + +func.func @invalid_C_map_result_num(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{expects 2 dims, but got (1).}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m)>] + ins(%A, %B: memref, memref) + outs(%C: memref) + return +} + +// ----- + +func.func @invalid_C_map_result_dim(%A: memref, %B: memref, %C: memref) { + // expected-error @+1 {{Invalid output map result dimension.}} + linalg.batch_reduce_matmul + indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>, + affine_map<(batch, m, n, k) -> (batch, k, n)>, + affine_map<(batch, m, n, k) -> (m, k)>] + ins(%A, %B: memref, memref) + outs(%C: memref) + return +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 1bd9c8825b05e..470bc1c78640c 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1637,6 +1637,175 @@ func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memre // ----- +//===----------------------------------------------------------------------===// +// linalg.batch_reduce_matmul +//===----------------------------------------------------------------------===// + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @bcast_k_to_fill_missing_dims_A( +// CHECK-SAME: %[[A:.*]]: memref<5xf32>, +// CHECK-SAME: %[[B:.*]]: memref<2x5x7xf32>, +// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) { +// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>) +// CHECK: return +// CHECK: } + +func.func @bcast_k_to_fill_missing_dims_A(%A: memref<5xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) { + linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B : memref<5xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @bcast_batch_dim_A( +// CHECK-SAME: %[[A:.*]]: memref<3x5xf32>, +// CHECK-SAME: %[[B:.*]]: memref<2x5x7xf32>, +// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) { +// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>) +// CHECK: return +// CHECK: } + +func.func @bcast_batch_dim_A(%A: memref<3x5xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) { + linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d3)>, + affine_map<(d0, d1, 
d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B : memref<3x5xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @bcast_batch_and_n_dim_B( +// CHECK-SAME: %[[A:.*]]: memref<2x3x5xf32>, +// CHECK-SAME: %[[B:.*]]: memref<5xf32>, +// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) { +// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[C]] : memref<3x7xf32>) +// CHECK: return +// CHECK: } + +func.func @bcast_batch_and_n_dim_B(%A: memref<2x3x5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) { + linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d3)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B : memref<2x3x5xf32>, memref<5xf32>) outs(%C: memref<3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @bcast_batch_dim_B( +// CHECK-SAME: %[[A:.*]]: memref<2x3x5xf32>, +// CHECK-SAME: %[[B:.*]]: memref<5x7xf32>, +// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) { +// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[C]] : memref<3x7xf32>) +// CHECK: return +// CHECK: } + +func.func @bcast_batch_dim_B(%A: memref<2x3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) { + linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B : memref<2x3x5xf32>, memref<5x7xf32>) outs(%C: memref<3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @explicit_transpose_A( +// CHECK-SAME: %[[A:.*]]: memref<2x5x3xf32>, +// CHECK-SAME: %[[B:.*]]: memref<2x5x7xf32>, +// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) { +// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>) +// CHECK: return +// CHECK: } +func.func @explicit_transpose_A(%A: memref<2x5x3xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) { + linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @explicit_transpose_B( +// CHECK-SAME: %[[A:.*]]: memref<2x3x5xf32>, +// CHECK-SAME: %[[B:.*]]: 
memref<2x7x5xf32>, +// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) { +// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[C]] : memref<3x7xf32>) +// CHECK: return +// CHECK: } +func.func @explicit_transpose_B(%A: memref<2x3x5xf32>, %B: memref<2x7x5xf32>, %C: memref<3x7xf32>) { + linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%C: memref<3x7xf32>) + return +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK-LABEL: func.func @bcast_A_transpose_B( +// CHECK-SAME: %[[A:.*]]: memref<3x5xf32>, +// CHECK-SAME: %[[B:.*]]: memref<2x7x5xf32>, +// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) { +// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[C]] : memref<3x7xf32>) +// CHECK: return +// CHECK: } +func.func @bcast_A_transpose_B(%A: memref<3x5xf32>, %B: memref<2x7x5xf32>, %C: memref<3x7xf32>) { + linalg.batch_reduce_matmul indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2)> + ] + ins(%A, %B : memref<3x5xf32>, memref<2x7x5xf32>) outs(%C: memref<3x7xf32>) + return +} + +// ----- + // CHECK-LABEL: func @batchmatmul_transpose_a // CHECK: linalg.batch_matmul_transpose_a // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index e32a911b24b11..79bb576ff6738 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem): print(module) +# CHECK-LABEL: TEST: testBatchReduceMatmulOp +@run +def testBatchReduceMatmulOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + a_shape = (5, 4, 8) + b_shape = (5, 8, 12) + b_transposed_shape = (5, 12, 8) + c_shape = (4, 12) + + dimBatch = ir.AffineDimExpr.get(0) + dimM = ir.AffineDimExpr.get(1) + dimN = ir.AffineDimExpr.get(2) + dimK = ir.AffineDimExpr.get(3) + + # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> + # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> + # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK]) + b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK]) + c_map = ir.AffineMap.get(4, 0, [dimM, dimN]) + + # CHECK: func.func @batch_reduce_matmul_op( + @func.FuncOp.from_py_func( + # CHECK-SAME: %[[A:.*]]: tensor<5x4x8xf32>, + RankedTensorType.get(a_shape, f32), + # CHECK-SAME: %[[Amem:.*]]: memref<5x4x8xf32>, + MemRefType.get(a_shape, f32), + # CHECK-SAME: %[[B:.*]]: tensor<5x8x12xf32>, + RankedTensorType.get(b_shape, f32), + # CHECK-SAME: %[[Bmem:.*]]: memref<5x8x12xf32>, + MemRefType.get(b_shape, f32), + # CHECK-SAME: %[[BTrans:.*]]: tensor<5x12x8xf32>, + RankedTensorType.get(b_transposed_shape, f32), + # CHECK-SAME: 
%[[BTransmem:.*]]: memref<5x12x8xf32>, + MemRefType.get(b_transposed_shape, f32), + # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>, + RankedTensorType.get(c_shape, f32), + # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>) + MemRefType.get(c_shape, f32), + ) + def batch_reduce_matmul_op( + A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem + ): + # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.BatchReduceMatmulOp( + result_tensors=(C.type,), + inputs=(A, B), + outputs=(C,), + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.batch_reduce_matmul(A, B, outs=(C,)) + + # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.BatchReduceMatmulOp( + result_tensors=(C.type,), + inputs=(A, Btransposed), + outputs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>) + res = linalg.batch_reduce_matmul( + A, + Btransposed, + outs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + + # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + res = linalg.BatchReduceMatmulOp( + result_tensors=[], + inputs=(Amem, Bmem), + outputs=(Cmem,), + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + linalg.batch_reduce_matmul(Amem, Bmem, outs=(Cmem,)) + + # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + res = linalg.BatchReduceMatmulOp( + result_tensors=[], + inputs=(Amem, Btransposedmem), + outputs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>) + linalg.batch_reduce_matmul( + Amem, + Btransposedmem, + outs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + + print(module) + + # CHECK-LABEL: TEST: testPackUnPackOp @run def testPackUnPackOp():
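
For quick reference, a minimal sketch of the new op in its default form (no explicit `indexing_maps`, so the defaults from `getDefaultIndexingMaps` apply: A indexed by (batch, m, k), B by (batch, k, n), C by (m, n)); the function name and shapes are illustrative only and mirror the generalize test above:

```mlir
// Hypothetical example: plain batch_reduce_matmul with the default indexing maps.
// The batch and k dimensions are reduced into the 2D accumulator %C.
func.func @batch_reduce_matmul_default(%A: tensor<2x3x5xf32>,
                                       %B: tensor<2x5x7xf32>,
                                       %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
  %0 = linalg.batch_reduce_matmul
         ins(%A, %B : tensor<2x3x5xf32>, tensor<2x5x7xf32>)
         outs(%C : tensor<3x7xf32>) -> tensor<3x7xf32>
  return %0 : tensor<3x7xf32>
}
```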