diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e578f4b956ef5..a90777c82bf63 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -621,7 +621,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ (ins "ValueRange":$inputs, "ValueRange":$outputs, CArg<"ArrayRef", "{}">:$attributes), [{ - buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs, + buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs, attributes, MatmulOp::getRegionBuilder()); }]>, OpBuilder< @@ -629,7 +629,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ "ValueRange":$outputs, CArg<"ArrayRef", "{}">:$attributes), [{ - buildStructuredOp($_builder, $_state, resultTensorTypes, + buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs, attributes, MatmulOp::getRegionBuilder()); }]>, OpBuilder< @@ -647,7 +647,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ "Attribute":$cast, CArg<"ArrayRef", "{}">:$attributes), [{ $_state.addAttribute("cast", cast); - buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs, + buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs, attributes, MatmulOp::getRegionBuilder()); }]> diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index c909d13e4314b..dee8a4e27e6b2 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) { } /// Wrapper to return the typical indexing map array attribute for MatmulOp. -static SmallVector getDefaultIndexingMapAttr(MLIRContext *context) { +static SmallVector +getDefaultMatmulIndexingMapAttr(MLIRContext *context) { return llvm::map_to_vector( getDefaultIndexingMapsForMatmul(context), [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); @@ -179,12 +180,11 @@ static SmallVector getDefaultIndexingMapAttr(MLIRContext *context) { /// The result types are derived automatically if `resultTensorTypes` is none. /// The body of the operation is filled using `regionBuilder`. All ods-gen /// created structured operations use the method to implement their builders. -static void buildStructuredOp( - OpBuilder &b, OperationState &state, - std::optional resultTensorTypes, ValueRange inputs, - ValueRange outputs, ArrayRef attributes, - RegionBuilderFn regionBuilder, - std::optional> indexingMaps = std::nullopt) { +static void buildStructuredOp(OpBuilder &b, OperationState &state, + std::optional resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef attributes, + RegionBuilderFn regionBuilder) { // Derive the result types if needed. SmallVector derivedResultTypes = resultTensorTypes.value_or(TypeRange()); @@ -196,6 +196,24 @@ static void buildStructuredOp( state.addOperands(outputs); state.addTypes(derivedResultTypes); + state.addAttributes(attributes); + state.addAttribute( + "operandSegmentSizes", + b.getDenseI32ArrayAttr({static_cast(inputs.size()), + static_cast(outputs.size())})); + + // Create and fill the region of the structured operation. + Region ®ion = *state.addRegion(); + fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs), + state.attributes.getAttrs(), regionBuilder); +} + +static void +buildMatmulOp(OpBuilder &b, OperationState &state, + std::optional resultTensorTypes, ValueRange inputs, + ValueRange outputs, ArrayRef attributes, + RegionBuilderFn regionBuilder, + std::optional> indexingMaps = std::nullopt) { // Initialize indexingMaps, for MatmulOp. SmallVector indexingMapsAttrVal; if (indexingMaps.has_value()) { @@ -205,20 +223,11 @@ static void buildStructuredOp( } state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); } else { - indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext()); + indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext()); state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); } - - state.addAttributes(attributes); - state.addAttribute( - "operandSegmentSizes", - b.getDenseI32ArrayAttr({static_cast(inputs.size()), - static_cast(outputs.size())})); - - // Create and fill the region of the structured operation. - Region ®ion = *state.addRegion(); - fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs), - state.attributes.getAttrs(), regionBuilder); + return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, + attributes, regionBuilder); } /// Common parsing used for both named structured ops created by ods-gen and by @@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder) { - - SmallVector indexingMapsAttr; - Attribute mapAttr; - if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) { - if (parser.parseEqual()) - return failure(); - - if (parser.parseLSquare()) - return failure(); - - do { - if (parser.parseAttribute(mapAttr)) - return failure(); - if (!isa(mapAttr)) { - return parser.emitError(parser.getCurrentLocation(), - "expected affine map attribute"); - } - indexingMapsAttr.push_back(mapAttr); - - if (parser.parseOptionalComma()) - break; - } while (true); - - if (parser.parseRSquare()) - return failure(); - } - // Initialize indexingMaps, if not supplied explicitly. - if (indexingMapsAttr.empty()) { - indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext()); - } - result.addAttribute("indexing_maps", - parser.getBuilder().getArrayAttr(indexingMapsAttr)); - // TODO: Enable when ods-gen supports captures. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) @@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, namespace mlir { namespace linalg { + //===----------------------------------------------------------------------===// // MatMulOp //===----------------------------------------------------------------------===// + SmallVector MatmulOp::getIteratorTypesArray() { return SmallVector{utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() { bool MatmulOp::hasDynamicIndexingMaps() { return true; } -/// Check if the op has broadcast and/or transpose semantic. Returns true if the -/// user defined indexing maps are not equal to default map. +/// Check if the op has broadcast and/or transpose semantic. Returns true if +/// the user defined indexing maps are not equal to default map. bool MatmulOp::hasUserDefinedMaps() { SmallVector defaultMaps = getDefaultIndexingMaps(); SmallVector explicitMaps = getIndexingMapsArray(); @@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, helper.yieldOutputs(yields); } -/// Returns a list of AffineMap with the typical matmul indexing charactristic. +/// Returns a list of AffineMap with the typical matmul indexing +/// charactristic. SmallVector MatmulOp::getDefaultIndexingMaps() { MLIRContext *context = this->getContext(); return getDefaultIndexingMapsForMatmul(context); @@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) { } ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector indexingMapsAttr; + Attribute mapAttr; + if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) { + if (parser.parseEqual()) + return failure(); + + if (parser.parseLSquare()) + return failure(); + + do { + if (parser.parseAttribute(mapAttr)) + return failure(); + if (!isa(mapAttr)) { + return parser.emitError(parser.getCurrentLocation(), + "expected affine map attribute"); + } + indexingMapsAttr.push_back(mapAttr); + + if (parser.parseOptionalComma()) + break; + } while (true); + + if (parser.parseRSquare()) + return failure(); + } + // Initialize indexingMaps, if not supplied explicitly. + if (indexingMapsAttr.empty()) { + indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext()); + } + result.addAttribute("indexing_maps", + parser.getBuilder().getArrayAttr(indexingMapsAttr)); + return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(), MatmulOp::getRegionBuilder()); } @@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) { elidedAttrs); SmallVector indexingMaps = - getDefaultIndexingMapAttr(getContext()); + getDefaultMatmulIndexingMapAttr(getContext()); if (!llvm::equal(getIndexingMaps(), indexingMaps)) { p << " indexing_maps = ["; llvm::interleaveComma(getIndexingMaps(), p, diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir index c086d0fd7e633..ebdbe70ff46eb 100644 --- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir +++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir @@ -43,7 +43,8 @@ func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor< // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]] // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]] // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>) + // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec + // CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>) // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128] // CHECK-NEXT: return %[[RES]] %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>) @@ -62,7 +63,8 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]] // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]] // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]] - // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) + // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat + // CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]] // CHECK-NEXT: return %[[RES]] @@ -113,7 +115,8 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor, %arg1: tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) + // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec + // CHECK-SAME: ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]] // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1] // CHECK-NEXT: return %[[RES]] @@ -140,7 +143,8 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) + // CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat + // CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor, tensor) outs(%[[COLLAPSED_INIT]] : tensor) // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]] // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]] // CHECK-NEXT: return %[[RES]] diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 6be7d4320c656..80d979864921d 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -679,8 +679,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{ } void {0}::print(OpAsmPrinter &p) {{ SmallVector elidedAttrs = {{"operandSegmentSizes", - "linalg.memoized_indexing_maps", - "indexing_maps"}; + "linalg.memoized_indexing_maps"}; ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), elidedAttrs); }