From fbe1abb62c9d502fcb8db1978359cd3a88a2a906 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 15 Nov 2024 12:41:45 +0000 Subject: [PATCH 1/2] Revert "[MLIR][Linalg] Fix unclosed code block which broke generated docs - NFC (#115763)" This reverts commit c02b8a01b7caf2e4ffe17a123f1bcf59192e4b39. --- .../Dialect/Linalg/IR/LinalgStructuredOps.td | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e578f4b956ef5..2b47414ff5e92 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -561,7 +561,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ AttrSizedOperandSegments, LinalgContractionOpInterface]> { - + let summary = [{ Performs a matrix multiplication of two 2D inputs without broadcast or transpose. }]; @@ -593,17 +593,16 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ ] ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) - ``` + ``` - Example Broadcast and transpose: - ``` - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose - affine_map<(d0, d1, d2) -> (d2)>, // broadcast - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) - ``` + Example Broadcast and transpose: + ``` + linalg.matmul indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose + affine_map<(d0, d1, d2) -> (d2)>, // broadcast + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) }]; let arguments = (ins From 828fcb06329ea8ff36ac1be08a4ef5ad0ae8e30e Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 15 Nov 2024 12:42:08 +0000 Subject: [PATCH 2/2] Revert "[MLIR][Linalg] Re-land linalg.matmul move to ODS. + Remove/update failing obsolete OpDSL tests. (#115319)" This reverts commit 3ad0148020ca91cc288bffd8ad36e25f7555a3bb. --- .../Dialect/Linalg/IR/LinalgInterfaces.td | 10 - .../Linalg/IR/LinalgNamedStructuredOps.yaml | 72 +++++ .../Dialect/Linalg/IR/LinalgStructuredOps.td | 134 --------- .../Dialect/Linalg/IR/LinalgInterfaces.cpp | 17 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 263 +----------------- .../Linalg/Transforms/TransposeMatmul.cpp | 7 - .../Linalg/Transforms/Vectorization.cpp | 5 - .../NVGPU/TransformOps/NVGPUTransformOps.cpp | 6 - .../linalg/opdsl/ops/core_named_ops.py | 17 ++ .../Dialect/Linalg/generalize-named-ops.mlir | 111 -------- mlir/test/Dialect/Linalg/invalid.mlir | 159 ----------- mlir/test/Dialect/Linalg/named-ops.mlir | 243 ---------------- mlir/test/python/dialects/linalg/ops.py | 75 +++++ .../integration/dialects/linalg/opsrun.py | 115 ++++++++ .../python/integration/dialects/transform.py | 28 +- .../mlir-linalg-ods-yaml-gen.cpp | 6 +- 16 files changed, 309 insertions(+), 959 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index c0eff99c85075..b81a4c9c8760c 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -708,16 +708,6 @@ def LinalgStructuredInterface return; }] >, - InterfaceMethod< - /*desc=*/[{ - Return true if the user has supplied an explicit indexing maps for this op. - }], - /*retTy=*/"bool", - /*methodName=*/"hasUserDefinedMaps", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ return false; }] - >, //===------------------------------------------------------------------===// // Linalg generalization hooks. //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index b0ea1f7695581..8cd63bc927075 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -1065,6 +1065,78 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: rhs --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: matmul + cpp_class_name: MatmulOp + doc: |- + Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + implements: + - LinalgContractionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: A + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> + - !LinalgOperandDefConfig + name: B + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> + - !LinalgOperandDefConfig + name: C + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> + - !LinalgOperandDefConfig + name: cast + kind: type_fn_attr + default_fn: cast_signed + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> + iterator_types: + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: C + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: add + operands: + - !ScalarExpression + scalar_arg: C + - !ScalarExpression + scalar_fn: + kind: binary + fn_name: mul + operands: + - !ScalarExpression + scalar_fn: + kind: type + attr_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: A + - !ScalarExpression + scalar_fn: + kind: type + attr_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: B +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_matmul cpp_class_name: QuantizedMatmulOp diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 2b47414ff5e92..c2fee8ea55c96 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -554,140 +554,6 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ let hasCanonicalizer = 1; } -//===----------------------------------------------------------------------===// -// Op definition for MatmulOp -//===----------------------------------------------------------------------===// - -def MatmulOp : LinalgStructuredBase_Op<"matmul", [ - AttrSizedOperandSegments, - LinalgContractionOpInterface]> { - - let summary = [{ - Performs a matrix multiplication of two 2D inputs without broadcast or transpose. - }]; - let description = [{ - Numeric casting is performed on the operands to the inner multiply, - promoting them to the same data type as the accumulator/output. - - Broadcast and Transpose semantics can be appiled by specifying the explicit attribute - 'indexing_maps' as shown below.This is a list attribute, so the list must include all - the maps if specified. - - Example Transpose: - ``` - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>) - outs(%arg2: memref<3x7xf32>) - ``` - - Example Broadcast: - ``` - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, // broadcast - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) - outs(%arg2: memref<3x7xf32>) - ``` - - Example Broadcast and transpose: - ``` - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose - affine_map<(d0, d1, d2) -> (d2)>, // broadcast - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) - }]; - - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs, - DefaultValuedOptionalAttr:$indexing_maps, - DefaultValuedOptionalAttr:$cast - ); - let results = (outs Variadic:$result_tensors); - let regions = (region AnyRegion:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder< - (ins "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs, - attributes, MatmulOp::getRegionBuilder()); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - buildStructuredOp($_builder, $_state, resultTensorTypes, - inputs, outputs, attributes, MatmulOp::getRegionBuilder()); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), - [{ - $_state.addOperands(operands); - $_state.addAttributes(attributes); - $_state.addTypes(resultTensorTypes); - (void)$_state.addRegion(); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, - "Attribute":$cast, CArg<"ArrayRef", "{}">:$attributes), - [{ - $_state.addAttribute("cast", cast); - buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs, - attributes, MatmulOp::getRegionBuilder()); - }]> - - ]; - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - let hasVerifier = 1; - - let extraClassDeclaration = structuredOpsBaseDecls # [{ - SmallVector getIteratorTypesArray(); - - /// Implements the block region builder. - static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); - - /// Returns a list of AffineMap with the typical matmul indexing charactristic. - SmallVector getDefaultIndexingMaps(); - - /// Returns true if the given broadcast map \p bcastMap is valid for this op. - bool isValidLhsRhsBroadcastMap(AffineMap bcastMap); - - static std::function)> - getRegionBuilder() { - return regionBuilder; - } - - ::mlir::MutableOperandRange getDpsInitsMutable() { - return getOutputsMutable(); - } - - // Generic methods. - static unsigned getNumRegionArgs(); - std::string getLibraryCallName(); - bool hasDynamicIndexingMaps(); - /// Check if the op has broadcast and/or transpose semantic. Returns true if the - /// user defined indexing maps are not equal to default map. - bool hasUserDefinedMaps(); - }]; -} - //===----------------------------------------------------------------------===// // Named Linalg ops, implemented as a declarative configurations of generic ops. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 0cffadf8fb64a..bd77965194b27 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -15,21 +15,14 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/TypeUtilities.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" #include #include -#include using namespace mlir; using namespace mlir::linalg; @@ -1218,6 +1211,7 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) { LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); + // Mixed tensor/buffer operands are not allowed. if (!linalgOp.hasPureTensorSemantics() && !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0) @@ -1237,8 +1231,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { << ") to be equal to the number of input/output operands (" << linalgOp->getNumOperands() << ")"; - // Set this flag if this op has user defined maps. This is required to guard - // the below error condition which assume default indexing maps. for (OpOperand &opOperand : linalgOp->getOpOperands()) { AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); @@ -1255,13 +1247,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { << " dim(s) to match the number of loops"; int64_t rank = linalgOp.getRank(&opOperand); - if (indexingMap.getNumResults() != rank) return op->emitOpError("expected operand rank (") << rank << ") to match the result rank of indexing_map #" << opOperand.getOperandNumber() << " (" << indexingMap.getNumResults() << ")"; } + SmallVector redDims; linalgOp.getReductionDims(redDims); @@ -1271,8 +1263,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { // Check if given shapes match to inferred shapes. SmallVector endLoopRangeValues = linalgOp.getStaticLoopRanges(); SmallVector startLoopRangeValues(endLoopRangeValues.size(), 0); - // Verify only static cases since we can't get exact dimension sizes and - // loop ranges for dynamic cases in this stage. + + // Verify only static cases since we can't get exact dimension sizes and loop + // ranges for dynamic cases in this stage. if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { for (int64_t &range : endLoopRangeValues) range -= 1; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index c909d13e4314b..730c478c2883e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -27,7 +27,6 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Matchers.h" @@ -38,17 +37,12 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include #include using namespace mlir; @@ -155,36 +149,15 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, // iterator_types is an auto-generated method. } -/// Helper to create a typical indexing map for MatmulOp. Returns a list of -/// AffineMap. -static SmallVector -getDefaultIndexingMapsForMatmul(MLIRContext *context) { - AffineExpr d0, d1, d2; - SmallVector indexingMaps; - bindDims(context, d0, d1, d2); - indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context)); - indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context)); - indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context)); - return indexingMaps; -} - -/// Wrapper to return the typical indexing map array attribute for MatmulOp. -static SmallVector getDefaultIndexingMapAttr(MLIRContext *context) { - return llvm::map_to_vector( - getDefaultIndexingMapsForMatmul(context), - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); -} - /// Creates a structured operation given `inputs`, `outputs`, and `attributes`. /// The result types are derived automatically if `resultTensorTypes` is none. /// The body of the operation is filled using `regionBuilder`. All ods-gen /// created structured operations use the method to implement their builders. -static void buildStructuredOp( - OpBuilder &b, OperationState &state, - std::optional resultTensorTypes, ValueRange inputs, - ValueRange outputs, ArrayRef attributes, - RegionBuilderFn regionBuilder, - std::optional> indexingMaps = std::nullopt) { +static void buildStructuredOp(OpBuilder &b, OperationState &state, + std::optional resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef attributes, + RegionBuilderFn regionBuilder) { // Derive the result types if needed. SmallVector derivedResultTypes = resultTensorTypes.value_or(TypeRange()); @@ -195,20 +168,6 @@ static void buildStructuredOp( state.addOperands(inputs); state.addOperands(outputs); state.addTypes(derivedResultTypes); - - // Initialize indexingMaps, for MatmulOp. - SmallVector indexingMapsAttrVal; - if (indexingMaps.has_value()) { - for (mlir::AffineMap map : *indexingMaps) { - // Convert each AffineMap to an AffineMapAttr - indexingMapsAttrVal.push_back(AffineMapAttr::get(map)); - } - state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); - } else { - indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext()); - state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); - } - state.addAttributes(attributes); state.addAttribute( "operandSegmentSizes", @@ -340,48 +299,11 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder) { - - SmallVector indexingMapsAttr; - Attribute mapAttr; - if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) { - if (parser.parseEqual()) - return failure(); - - if (parser.parseLSquare()) - return failure(); - - do { - if (parser.parseAttribute(mapAttr)) - return failure(); - if (!isa(mapAttr)) { - return parser.emitError(parser.getCurrentLocation(), - "expected affine map attribute"); - } - indexingMapsAttr.push_back(mapAttr); - - if (parser.parseOptionalComma()) - break; - } while (true); - - if (parser.parseRSquare()) - return failure(); - } - // Initialize indexingMaps, if not supplied explicitly. - if (indexingMapsAttr.empty()) { - indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext()); - } - result.addAttribute("indexing_maps", - parser.getBuilder().getArrayAttr(indexingMapsAttr)); - // TODO: Enable when ods-gen supports captures. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); - // Parse optional attributes. - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - // TODO: consider merging results parsing into region parsing. // Need to wait for declarative assembly resolution to decide. SmallVector outputTensorsTypes; @@ -407,9 +329,13 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p, } static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, - ValueRange inputs, ValueRange outputs, - ArrayRef elidedAttrs = {}) { - p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); + ValueRange inputs, ValueRange outputs) { + p.printOptionalAttrDict( + op->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", + // See generated code in + // LinalgNamedStructuredOps.yamlgen.cpp.inc + "linalg.memoized_indexing_maps"}); // Printing is shared with generic ops, except for the region and // attributes. @@ -3456,168 +3382,3 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder, Location loc) { return arith::ConstantOp::materialize(builder, value, type, loc); } - -/// Returns true if the result AffineExpr of the \p explicitMap is same as \p -/// defaultMap. -static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) { - auto explicitRange = explictMap.getResults(); - auto defaultRange = defaultMap.getResults(); - DenseSet explicitSet(explicitRange.begin(), explicitRange.end()); - DenseSet defaultSet(defaultRange.begin(), defaultRange.end()); - llvm::set_union(explicitSet, defaultSet); - return explicitSet == defaultSet; -} - -/// Returns true if the \p explictMap is broadcasted with respect to the -/// \p defaultMap. -static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) { - return explictMap.getNumResults() < defaultMap.getNumResults(); -} - -/// Verifies the broadcast and transpose semantic sepecified by the explicit -/// indexing map for the MatmulOp \p op for each operand specified by \p -/// opIndex. -static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, - unsigned opIndex) { - SmallVector opIndexingMaps = matmulOp.getIndexingMapsArray(); - SmallVector defaultIndexingMaps = - matmulOp.getDefaultIndexingMaps(); - - auto opIndexingMap = opIndexingMaps[opIndex]; - auto defaultIndexingMap = defaultIndexingMaps[opIndex]; - // Check general validity of indexing map results. - if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap)) - return matmulOp->emitOpError() - << "Unexpected dim expression in map result."; - - // Check if the requested broadcast is valid. - if (isBroadcasted(opIndexingMap, defaultIndexingMap)) { - if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) { - return matmulOp->emitOpError() - << "Invalid broadcast requested, should be (d2)."; - } - return success(); - } - return success(); -} - -namespace mlir { -namespace linalg { -//===----------------------------------------------------------------------===// -// MatMulOp -//===----------------------------------------------------------------------===// -SmallVector MatmulOp::getIteratorTypesArray() { - return SmallVector{utils::IteratorType::parallel, - utils::IteratorType::parallel, - utils::IteratorType::reduction}; -} - -unsigned MatmulOp::getNumRegionArgs() { return 3; } - -std::string MatmulOp::getLibraryCallName() { - return generateLibraryCallName(getOperation()); -} - -bool MatmulOp::hasDynamicIndexingMaps() { return true; } - -/// Check if the op has broadcast and/or transpose semantic. Returns true if the -/// user defined indexing maps are not equal to default map. -bool MatmulOp::hasUserDefinedMaps() { - SmallVector defaultMaps = getDefaultIndexingMaps(); - SmallVector explicitMaps = getIndexingMapsArray(); - return defaultMaps != explicitMaps; -} - -/// Implements the block region builder for the MatmulOp. This is called by -/// 'fillStructuredOpRegion'. -void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { - assert(3 > 0 && block.getNumArguments() == 3 && - "MatmulOp regionBuilder expects 3 (>=0) args"); - RegionBuilderHelper helper(b, block); - SmallVector yields; - - TypeFn castVal = TypeFn::cast_signed; - auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) { - return attr.getName() == "cast"; - }); - if (castIter != attrs.end()) { - if (auto attr = llvm::dyn_cast(castIter->getValue())) - castVal = attr.getValue(); - } - - Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(), - block.getArgument(0)); - Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(), - block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); - yields.push_back(value4); - helper.yieldOutputs(yields); -} - -/// Returns a list of AffineMap with the typical matmul indexing charactristic. -SmallVector MatmulOp::getDefaultIndexingMaps() { - MLIRContext *context = this->getContext(); - return getDefaultIndexingMapsForMatmul(context); -} - -/// Returns true if the given broadcast map \p bcastMap is valid for this op. -bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) { - assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr."); - AffineExpr exp = bcastMap.getResult(0); - // Invalid map if the common dimension of matmul not found. - return exp.isFunctionOfDim(bcastMap.getNumDims() - 1); -} - -ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) { - return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(), - MatmulOp::getRegionBuilder()); -} -void MatmulOp::print(OpAsmPrinter &p) { - SmallVector elidedAttrs = { - "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; - printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), - elidedAttrs); - - SmallVector indexingMaps = - getDefaultIndexingMapAttr(getContext()); - if (!llvm::equal(getIndexingMaps(), indexingMaps)) { - p << " indexing_maps = ["; - llvm::interleaveComma(getIndexingMaps(), p, - [&](Attribute attr) { p.printAttribute(attr); }); - p << "]"; - } -} - -/// Verify the user defined indexing maps. -LogicalResult MatmulOp::verify() { - // Verification of pure matmul is handled by verifyStructuredOpInterface(). - if (!hasUserDefinedMaps()) - return success(); - - for (unsigned opIndex = 0; opIndex < 2; opIndex++) { - if (failed(verifyExtendedMatmulSemantic(*this, opIndex))) - return failure(); - } - return success(); -} - -LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl &) { - return memref::foldMemRefCast(*this); -} -void MatmulOp::getEffects( - SmallVectorImpl> - &effects) { - if (hasPureTensorSemantics()) - return; - getGenericEffectsImpl(effects, cast(getOperation())); -} - -Speculation::Speculatability MatmulOp::getSpeculatability() { - return getGenericSpeculatabilityImpl(cast(getOperation())); -} - -} // namespace linalg -} // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index 6b934f7e8157d..aa0052ce47fa7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -31,13 +31,6 @@ using namespace mlir::linalg; FailureOr mlir::linalg::transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp matmulOp, bool transposeLHS) { - // Check to not let go the matmul with extended semantic, through this - // transform. - if (matmulOp.hasUserDefinedMaps()) { - return rewriter.notifyMatchFailure( - matmulOp, "only matmul ops with non-extended semantics are supported"); - } - if (!bufferization::hasTensorSemantics(matmulOp)) return rewriter.notifyMatchFailure( matmulOp, "only matmul ops with tensors are supported"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 23b46a2ee55f8..7f562a3a99b3b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -2089,11 +2089,6 @@ vectorizeScalableVectorPrecondition(Operation *op, return failure(); } - // Check to not let go the matmul with extended semantic, through this - // transform. - if (linalgOp.hasUserDefinedMaps()) - return failure(); - // Cond 4: Only the following ops are supported in the // presence of scalable vectors return success(isElementwise(linalgOp) || isa(op) || diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 3c508ed6e324b..0c2275bbc4b22 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -821,12 +821,6 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( bool fail = true; // TODO: more robust detection of matmulOp, with transposes etc. if (isa_and_nonnull(linalgOp.getOperation())) { - // Check to not let go the matmul with extended semantic, through this - // transform. - if (linalgOp.hasUserDefinedMaps()) { - return emitSilenceableError() - << "only matmul ops with non-extended semantics are supported"; - } Location loc = linalgOp.getLoc(); // TODO: more robust computation of laneId, for now assume a single warp. Value laneId = rewriter.create( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index c95cd5eecfffc..89895760cad74 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -383,6 +383,23 @@ def select( O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None]) +@linalg_structured_op +def matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + + @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index aba26c35931fd..1e8f1435ca0fa 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -29,34 +29,6 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, // ----- -func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-LABEL: func.func @matmul_bcast_a( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) { -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32): -// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32 -// CHECK: linalg.yield %[[VAL_7]] : f32 -// CHECK: } -// CHECK: return -// CHECK: } - -// ----- - func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> @@ -919,86 +891,3 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor, tensor, tensor> } - -// ----- - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @matmul_transpose_a_explicit( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { - -// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} -// CHECK: arith.mulf -// CHECK: arith.addf - -func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) - outs(%arg2: memref<3x7xf32>) - - return -} - -// ----- - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-LABEL: func.func @matmul_transpose_b_explicit( -// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { - -// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} -// CHECK: arith.mulf -// CHECK: arith.addf - -func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) - outs(%arg2: memref<3x7xf32>) - - return -} - -// ----- - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { - -// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} -// CHECK: arith.mulf -// CHECK: arith.addf - -func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) - outs(%arg2: memref<3x7xf32>) - - return -} - -// ----- - diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index a59472377a732..4b5a66f8fb5b9 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -370,165 +370,6 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, // ----- -func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) { - // expected-error @+1 {{expected attribute value}} - linalg.matmul indexing_maps = [ - , - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>) - outs(%arg2 :memref<2x4xf32>) - return -} - -// ----- - -func.func @invalid_matmul_dim_a(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) { - // expected-error @+1 {{Unexpected dim expression in map result}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>) - return -} - -// ----- - -func.func @invalid_matmul_dim_b(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) { - // expected-error @+1 {{Unexpected dim expression in map result}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>) - return -} - -// ----- - -func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> { - // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}} - %0 = linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) - outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32> - return %0: tensor<4x64xf32> -} - -// ----- - -func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> { - // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}} - %0 = linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) - outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32> - return %0: tensor<4x64xf32> -} - -// ----- - -func.func @invalid_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// ----- - -func.func @invalid_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) { - // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// ----- - -func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #0 (1)}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// ----- - -func.func @invalid_bcast_b_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #1 (1)}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// ----- - -func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) { - // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// ----- - -func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) { - // expected-error @+1 {{'linalg.matmul' op Unexpected dim expression in map result.}} - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// ----- - -func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) { - // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}} - linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>) - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - return -} - -// ----- - func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) { // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}} linalg.conv_2d_nhwc_hwcf diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 68aa5a85b5e0e..bc0cfd52e8b51 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1233,249 +1233,6 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a // ----- -// CHECK-LABEL: func @matmul_transpose_a_explicit -// CHECK: linalg.matmul -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>) -// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) -func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) - outs(%arg2: memref<3x7xf32>) - - return -} - -// ----- - -func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) - outs(%arg2: memref<3x7xf32>) - - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @matmul_transpose_b_explicit( -// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] -// CHECK: return -// CHECK: } - -// ----- - -func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) - outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] -// CHECK: return -// CHECK: } - -// ----- - -func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-LABEL: func @matmul_bcast_a -// CHECK: linalg.matmul -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>) -// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) - -// ----- - -func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-LABEL: func @matmul_bcast_a_dim1 -// CHECK: linalg.matmul -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>) -// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) - -// ----- - -func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-LABEL: func @matmul_bcast_b -// CHECK: linalg.matmul -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>) -// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) - -// ----- - -func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @matmul_bcast_a_b( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]] -// CHECK: return -// CHECK: } - -// ----- - -func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-LABEL: func @matmul_bcast_b_dim1 -// CHECK: linalg.matmul -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>) -// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>) - -// ----- - -func.func @dynamic_matmul_bcast_a(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref, memref) outs(%arg2: memref) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @dynamic_matmul_bcast_a( -// CHECK-SAME: %[[VAL_0:.*]]: memref, -// CHECK-SAME: %[[VAL_1:.*]]: memref, -// CHECK-SAME: %[[VAL_2:.*]]: memref) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref, memref) outs(%[[VAL_2]] : memref) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] -// CHECK: return -// CHECK: } - -// ----- - -func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @matmul_bcast_a_transpose_b( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] -// CHECK: return -// CHECK: } - -// ----- - -func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) { - linalg.matmul indexing_maps = [ - affine_map<(d0, d1, d2) -> (d2, d0)>, - affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ] - ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>) - return -} - -// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)> -// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK-LABEL: func.func @matmul_bcast_b_transpose_a( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) { -// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] -// CHECK: return -// CHECK: } - -// ----- - // CHECK-LABEL: func @matmul_transpose_b // CHECK: linalg.matmul_transpose_b // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 72045a07b2da8..3bfbcf7d7f7c8 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -84,6 +84,81 @@ def named_form(lhs, rhs): print(module) + +# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm +@run +def testNamedStructuredOpGenericForm(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) + ) + def named_form(lhs, rhs): + init_result = tensor.empty([4, 8], f32) + # CHECK: "linalg.matmul"(%{{.*}}) + # CHECK-SAME: cast = #linalg.type_fn + # CHECK-SAME: operandSegmentSizes = array + # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32): + # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 + # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 + # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () + # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + return linalg.matmul(lhs, rhs, outs=[init_result]) + + module.operation.print(print_generic_op_form=True) + + +# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp +@run +def testNamedStructuredAsGenericOp(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) + ) + def generic_form(lhs, rhs): + init_result = tensor.EmptyOp([4, 8], f32) + # CHECK: linalg.generic + return linalg.matmul( + lhs, rhs, outs=[init_result.result], emit_generic=True + ) + + print(module) + + +# CHECK-LABEL: TEST: testOpResultFromOtherOp +@run +def testOpResultFromOtherOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) + ) + def pass_an_op_directly(arg0, arg1): + one = arith.ConstantOp(F32Type.get(), 1.0) + # CHECK: %[[LHS:.*]] = linalg.fill + lhs = linalg.fill(one, outs=[arg0]) + # CHECK: %[[RHS:.*]] = linalg.fill + rhs = linalg.fill(one, outs=[arg1]) + # CHECK: %[[INIT:.*]] = tensor.empty + init = tensor.EmptyOp([4, 8], f32) + # CHECK: linalg.matmul + # CHECK: ins(%[[LHS]], %[[RHS]] + # CHECK: outs(%[[INIT]] + return linalg.matmul(lhs, rhs, outs=init) + + print(module) + + # CHECK-LABEL: TEST: testIdentityRegionOps @run def testIdentityRegionOps(): diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py index f77900bc27773..f6519fb17a6b9 100644 --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -50,6 +50,37 @@ def log(*args): } """ +matmul_boiler = """ +func.func @main() -> f32 attributes {llvm.emit_c_interface} { + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant -1 : i8 + %v2 = arith.constant 2.0 : f32 + + %A = memref.alloc() : memref<4x16xi8> + %B = memref.alloc() : memref<16x8xf32> + %C0 = memref.alloc() : memref<4x8xf32> + %C1 = memref.alloc() : memref<4x8xf32> + linalg.fill ins(%v1 : i8) outs(%A : memref<4x16xi8>) + linalg.fill ins(%v2 : f32) outs(%B : memref<16x8xf32>) + linalg.fill ins(%v0 : f32) outs(%C0 : memref<4x8xf32>) + linalg.fill ins(%v0 : f32) outs(%C1 : memref<4x8xf32>) + + call @matmul_signed_on_buffers(%A, %B, %C0) : + (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> () + call @matmul_unsigned_on_buffers(%A, %B, %C1) : + (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> () + + %c0 = arith.constant 0 : index + %res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32> + %res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32> + + %0 = arith.addf %res0, %res1 : f32 + + // TODO: FFI-based solution to allow testing and printing with python code. + return %0 : f32 +} +""" + fill_boiler = """ func.func @main() -> i32 attributes {llvm.emit_c_interface} { %O0 = memref.alloc() : memref @@ -265,6 +296,90 @@ def elemwise_log_mul_on_buffers(lhs, rhs, out): test_elemwise_generic() +def test_matmul_builtin(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_signed_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out]) + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned) + + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.0) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 + + +test_matmul_builtin() + + +def test_matmul_generic(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_signed_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul( + lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True + ) + + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.0) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 + + +test_matmul_generic() + + def test_fill_builtin(): with Context() as ctx, Location.unknown(): module = Module.create() diff --git a/mlir/test/python/integration/dialects/transform.py b/mlir/test/python/integration/dialects/transform.py index 303274a8f8828..bc88a61314d0d 100644 --- a/mlir/test/python/integration/dialects/transform.py +++ b/mlir/test/python/integration/dialects/transform.py @@ -99,28 +99,26 @@ def basic(target: any_op_t()): # CHECK-LABEL: TEST: test_apply_patterns @construct_and_print_in_module def test_apply_patterns(module_): - b, M, N, K = 1, 3, 5, 3 + M, N, K = 3, 5, 3 - # CHECK-LABEL: func.func @batch_reduce_matmul( - # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>, - # CHECK-SAME: %[[VAL_1:.*]]: tensor<1x5x3xf32>, - # CHECK-SAME: %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> { + # CHECK-LABEL: func.func @matmul( + # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> { # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32 - # CHECK: %[[VAL_5:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32> + # CHECK: %[[VAL_5:.*]] = linalg.matmul {cast = #linalg.type_fn} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32> # CHECK: return %[[VAL_5]] : tensor<3x3xf32> # CHECK: } @func.func( - T.tensor(b, M, N, T.f32()), T.tensor(b, N, K, T.f32()), T.tensor(M, K, T.f32()) + T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32()) ) - def batch_reduce_matmul(A, B, C): + def matmul(A, B, C): i = arith.constant(T.i32(), 1) v = arith.addi(i, i) - return linalg.batch_reduce_matmul(A, B, outs=[C]) + return linalg.matmul(A, B, outs=[C]) # CHECK-LABEL: module attributes {transform.with_named_sequence} { # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) { - # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op + # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation # CHECK: transform.apply_patterns to %[[VAL_2]] { # CHECK: transform.apply_patterns.canonicalization @@ -134,9 +132,7 @@ def batch_reduce_matmul(A, B, C): def mod(): @named_sequence("__transform_main", [any_op_t()], []) def basic(variant_op: any_op_t()): - matmul = structured_match( - any_op_t(), variant_op, ops=["linalg.batch_reduce_matmul"] - ) + matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"]) top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func") @apply_patterns(top_func) @@ -151,9 +147,9 @@ def pats(): pm = PassManager.parse("builtin.module(transform-interpreter)") pm.run(module_.operation) - # CHECK-LABEL: func.func @batch_reduce_matmul( - # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>, %[[VAL_1:.*]]: tensor<1x5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> { - # CHECK: %[[VAL_3:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32> + # CHECK-LABEL: func.func @matmul( + # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> { + # CHECK: %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32> # CHECK: return %[[VAL_3]] : tensor<3x3xf32> # CHECK: } print(module_) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 6be7d4320c656..5f86c0cd74707 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -678,11 +678,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{ {0}::getNumRegionArgs(), {0}::getRegionBuilder()); } void {0}::print(OpAsmPrinter &p) {{ - SmallVector elidedAttrs = {{"operandSegmentSizes", - "linalg.memoized_indexing_maps", - "indexing_maps"}; - ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), - elidedAttrs); + ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); } )FMT";