diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index a90777c82bf63..37eec6e07963b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -622,7 +622,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ CArg<"ArrayRef", "{}">:$attributes), [{ buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs, - attributes, MatmulOp::getRegionBuilder()); + attributes, MatmulOp::getRegionBuilder(), + MatmulOp::getDefaultIndexingMaps($_builder.getContext())); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -630,16 +631,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ CArg<"ArrayRef", "{}">:$attributes), [{ buildMatmulOp($_builder, $_state, resultTensorTypes, - inputs, outputs, attributes, MatmulOp::getRegionBuilder()); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), - [{ - $_state.addOperands(operands); - $_state.addAttributes(attributes); - $_state.addTypes(resultTensorTypes); - (void)$_state.addRegion(); + inputs, outputs, attributes, MatmulOp::getRegionBuilder(), + MatmulOp::getDefaultIndexingMaps($_builder.getContext())); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -648,7 +641,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ [{ $_state.addAttribute("cast", cast); buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs, - attributes, MatmulOp::getRegionBuilder()); + attributes, MatmulOp::getRegionBuilder(), + MatmulOp::getDefaultIndexingMaps($_builder.getContext())); }]> ]; @@ -664,7 +658,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ Block &block, ArrayRef attrs); /// Returns a list of AffineMap with the typical matmul indexing charactristic. - SmallVector getDefaultIndexingMaps(); + static SmallVector getDefaultIndexingMaps(MLIRContext *context); /// Returns true if the given broadcast map \p bcastMap is valid for this op. bool isValidLhsRhsBroadcastMap(AffineMap bcastMap); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index dee8a4e27e6b2..cce54695c0e6c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -155,27 +155,6 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, // iterator_types is an auto-generated method. } -/// Helper to create a typical indexing map for MatmulOp. Returns a list of -/// AffineMap. -static SmallVector -getDefaultIndexingMapsForMatmul(MLIRContext *context) { - AffineExpr d0, d1, d2; - SmallVector indexingMaps; - bindDims(context, d0, d1, d2); - indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context)); - indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context)); - indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context)); - return indexingMaps; -} - -/// Wrapper to return the typical indexing map array attribute for MatmulOp. -static SmallVector -getDefaultMatmulIndexingMapAttr(MLIRContext *context) { - return llvm::map_to_vector( - getDefaultIndexingMapsForMatmul(context), - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); -} - /// Creates a structured operation given `inputs`, `outputs`, and `attributes`. /// The result types are derived automatically if `resultTensorTypes` is none. /// The body of the operation is filled using `regionBuilder`. All ods-gen @@ -208,24 +187,18 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state, state.attributes.getAttrs(), regionBuilder); } -static void -buildMatmulOp(OpBuilder &b, OperationState &state, - std::optional resultTensorTypes, ValueRange inputs, - ValueRange outputs, ArrayRef attributes, - RegionBuilderFn regionBuilder, - std::optional> indexingMaps = std::nullopt) { - // Initialize indexingMaps, for MatmulOp. +static void buildMatmulOp(OpBuilder &b, OperationState &state, + std::optional resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef attributes, + RegionBuilderFn regionBuilder, + ArrayRef indexingMaps) { + // Initialize indexingMaps attribute, for MatmulOp. SmallVector indexingMapsAttrVal; - if (indexingMaps.has_value()) { - for (mlir::AffineMap map : *indexingMaps) { - // Convert each AffineMap to an AffineMapAttr - indexingMapsAttrVal.push_back(AffineMapAttr::get(map)); - } - state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); - } else { - indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext()); - state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); - } + indexingMapsAttrVal = llvm::map_to_vector( + MatmulOp::getDefaultIndexingMaps(b.getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, attributes, regionBuilder); } @@ -3457,7 +3430,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex) { SmallVector opIndexingMaps = matmulOp.getIndexingMapsArray(); SmallVector defaultIndexingMaps = - matmulOp.getDefaultIndexingMaps(); + matmulOp.getDefaultIndexingMaps(matmulOp->getContext()); auto opIndexingMap = opIndexingMaps[opIndex]; auto defaultIndexingMap = defaultIndexingMaps[opIndex]; @@ -3484,6 +3457,17 @@ namespace linalg { // MatMulOp //===----------------------------------------------------------------------===// +/// Returns a list of AffineMap with the typical matmul indexing charactristic. +SmallVector MatmulOp::getDefaultIndexingMaps(MLIRContext *context) { + AffineExpr d0, d1, d2; + SmallVector indexingMaps; + bindDims(context, d0, d1, d2); + indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context)); + indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context)); + indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context)); + return indexingMaps; +} + SmallVector MatmulOp::getIteratorTypesArray() { return SmallVector{utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -3501,7 +3485,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; } /// Check if the op has broadcast and/or transpose semantic. Returns true if /// the user defined indexing maps are not equal to default map. bool MatmulOp::hasUserDefinedMaps() { - SmallVector defaultMaps = getDefaultIndexingMaps(); + SmallVector defaultMaps = + getDefaultIndexingMaps(this->getContext()); SmallVector explicitMaps = getIndexingMapsArray(); return defaultMaps != explicitMaps; } @@ -3535,13 +3520,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, helper.yieldOutputs(yields); } -/// Returns a list of AffineMap with the typical matmul indexing -/// charactristic. -SmallVector MatmulOp::getDefaultIndexingMaps() { - MLIRContext *context = this->getContext(); - return getDefaultIndexingMapsForMatmul(context); -} - /// Returns true if the given broadcast map \p bcastMap is valid for this op. bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) { assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr."); @@ -3578,7 +3556,9 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) { } // Initialize indexingMaps, if not supplied explicitly. if (indexingMapsAttr.empty()) { - indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext()); + indexingMapsAttr = llvm::map_to_vector( + MatmulOp::getDefaultIndexingMaps(parser.getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); } result.addAttribute("indexing_maps", parser.getBuilder().getArrayAttr(indexingMapsAttr)); @@ -3592,8 +3572,9 @@ void MatmulOp::print(OpAsmPrinter &p) { printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), elidedAttrs); - SmallVector indexingMaps = - getDefaultMatmulIndexingMapAttr(getContext()); + SmallVector indexingMaps = llvm::map_to_vector( + MatmulOp::getDefaultIndexingMaps(getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); if (!llvm::equal(getIndexingMaps(), indexingMaps)) { p << " indexing_maps = ["; llvm::interleaveComma(getIndexingMaps(), p,