diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h index 57bf6305a469d..a0fb0111d6ace 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -16,6 +16,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" @@ -26,6 +27,9 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" + +#include "llvm/ADT/STLFunctionalExtras.h" + #include namespace mlir { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index 33601c5d6dad9..a459656b982e6 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect { kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps"; using RegionBuilderFunType = llvm::function_ref< - void(ImplicitLocOpBuilder &b, Block &, ArrayRef)>; + void(ImplicitLocOpBuilder &b, Block &, ArrayRef, + function_ref)>; RegionBuilderFunType getRegionBuilder(StringRef name) { return namedStructuredOpRegionBuilders.lookup(name); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index ca1cba8747bd8..ba73cfbbed845 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -720,7 +720,7 @@ def LinalgStructuredInterface Returns a null function if this named op does not define a region builder. }], - /*retTy=*/"std::function)>", + /*retTy=*/"std::function, function_ref)>", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 61783812920bc..7bbc56f549c0b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [ } static std::function)> + Block &, ArrayRef, + function_ref)> getRegionBuilder() { return nullptr; } @@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [ } static std::function)> + mlir::ArrayRef, + function_ref)> getRegionBuilder() { return nullptr; } @@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ // Implement functions necessary for DestinationStyleOpInterface. static std::function)> + mlir::ArrayRef, + function_ref)> getRegionBuilder() { return nullptr; } @@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [ MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, - mlir::ArrayRef) { + mlir::ArrayRef, function_ref emitError) { OpBuilder::InsertionGuard guard(b); b.create(b.getLoc(), block.getArgument(0)); } static std::function)> + mlir::ArrayRef, + function_ref)> getRegionBuilder() { return regionBuilder; } @@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, - mlir::ArrayRef) { + mlir::ArrayRef, + function_ref emitError) { OpBuilder::InsertionGuard guard(b); b.create(b.getLoc(), block.getArgument(0)); } static std::function)> + mlir::ArrayRef, + function_ref)> getRegionBuilder() { return regionBuilder; } @@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [ /// Implements the block region builder for the elementwiseOp. This is /// called by the 'fillStructuredOpRegion'. static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); + Block &block, ArrayRef attrs, + function_ref emitError); static std::function)> + Block &, ArrayRef, + function_ref)> getRegionBuilder() { return regionBuilder; } @@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ /// Implements the block region builder. static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); + Block &block, ArrayRef attrs, + function_ref emitError); /// Returns a list of AffineMap with the default matmul indexing charactristic. static SmallVector getDefaultIndexingMaps(MLIRContext *context); @@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap); static std::function)> + Block &, ArrayRef, + function_ref)> getRegionBuilder() { return regionBuilder; } @@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [ static unsigned getNumRegionArgs(); static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); + Block &block, ArrayRef attrs, + function_ref emitError); static std::function)> + Block &, ArrayRef, + function_ref)> getRegionBuilder() { return regionBuilder; } @@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz SmallVector getIteratorTypesArray(); static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); + Block &block, ArrayRef attrs, + function_ref emitError); static std::function)> + Block &, ArrayRef, + function_ref)> getRegionBuilder() { return regionBuilder; } @@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [ /// Implements the block region builder. static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); + Block &block, ArrayRef attrs, + function_ref emitError); /// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic. static SmallVector getDefaultIndexingMaps(MLIRContext *context); @@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true); static std::function)> + Block &, ArrayRef, + function_ref)> getRegionBuilder() { return regionBuilder; } diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 0c4f6e88e7078..21db18dfd47ed 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); - fun(b, *body, op->getAttrs()); + fun(b, *body, op->getAttrs(), /*emitError=*/{}); } MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 5dbb2403eddbd..5ab44607d6c4a 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -117,8 +117,9 @@ OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source, // Support for named Linalg ops defined in ods-gen. //===----------------------------------------------------------------------===// -using RegionBuilderFn = llvm::function_ref)>; +using RegionBuilderFn = llvm::function_ref, + function_ref)>; /// Fills the region of a structured operation using the provided /// `regionBuilder`. The method is used by both named structured ops created by @@ -128,6 +129,7 @@ using RegionBuilderFn = llvm::function_ref attrs, + function_ref emitError, RegionBuilderFn regionBuilder) { SmallVector argTypes; SmallVector argLocs; @@ -148,7 +150,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, opBuilder.setInsertionPointToStart(body); ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - regionBuilder(b, *body, attrs); + regionBuilder(b, *body, attrs, emitError); // indexing_maps is an auto-generated method. @@ -184,7 +186,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state, // Create and fill the region of the structured operation. Region ®ion = *state.addRegion(); fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs), - state.attributes.getAttrs(), regionBuilder); + state.attributes.getAttrs(), /*emitError=*/{}, + regionBuilder); } static void buildMatmulOp(OpBuilder &b, OperationState &state, @@ -329,7 +332,7 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, static ParseResult parseNamedStructuredOpRegion( OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef attrs, - RegionBuilderFn regionBuilder) { + RegionBuilderFn regionBuilder, SMLoc loc) { if (numRegionArgs != inputTypes.size() + outputTypes.size()) { return parser.emitError( parser.getCurrentLocation(), @@ -339,9 +342,15 @@ static ParseResult parseNamedStructuredOpRegion( } OpBuilder opBuilder(parser.getContext()); - fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs, - regionBuilder); - return success(); + ParseResult result = success(); + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, attrs, + [&]() { + result = failure(); + return parser.emitError(loc); + }, + regionBuilder); + return result; } static ParseResult @@ -358,6 +367,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser, RegionBuilderFn regionBuilder) { // TODO: Enable when ods-gen supports captures. SmallVector inputTypes, outputTypes; + SMLoc loc = parser.getCurrentLocation(); if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); @@ -375,7 +385,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser, std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes, outputTypes, result.attributes.getAttrs(), - regionBuilder)) + regionBuilder, loc)) return failure(); result.addRegion(std::move(region)); @@ -435,9 +445,15 @@ class RegionBuilderHelper { : builder(builder), block(block) {} // Build the unary functions defined by OpDSL. - Value buildUnaryFn(UnaryFn unaryFn, Value arg) { - if (!isFloatingPoint(arg)) + Value buildUnaryFn(UnaryFn unaryFn, Value arg, + function_ref emitError = {}) { + if (!isFloatingPoint(arg)) { + if (emitError) { + emitError() << "unsupported non numeric type"; + return nullptr; + } llvm_unreachable("unsupported non numeric type"); + } OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToEnd(&block); switch (unaryFn) { @@ -472,18 +488,34 @@ class RegionBuilderHelper { case UnaryFn::erf: return builder.create(arg.getLoc(), arg); } + if (emitError) { + emitError() << "unsupported unary function"; + return nullptr; + } llvm_unreachable("unsupported unary function"); } // Build the binary functions defined by OpDSL. - Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) { + // If emitError is provided, an error will be emitted if the operation is not + // supported and a nullptr will be returned, otherwise an assertion will be + // raised. + Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1, + function_ref emitError = {}) { bool allComplex = isComplex(arg0) && isComplex(arg1); bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1); bool allInteger = isInteger(arg0) && isInteger(arg1); bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 && arg1.getType().getIntOrFloatBitWidth() == 1; - if (!allComplex && !allFloatingPoint && !allInteger) + if (!allComplex && !allFloatingPoint && !allInteger) { + if (emitError) { + emitError() + << "Cannot build binary Linalg operation: expects allComplex, " + "allFloatingPoint, or allInteger, got " + << arg0.getType() << " and " << arg1.getType(); + return nullptr; + } llvm_unreachable("unsupported non numeric type"); + } OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToEnd(&block); switch (binaryFn) { @@ -500,8 +532,13 @@ class RegionBuilderHelper { return builder.create(arg0.getLoc(), arg0, arg1); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); - if (allBool) + if (allBool) { + if (emitError) { + emitError() << "unsupported operation: sub with bools"; + return nullptr; + } llvm_unreachable("unsupported operation: sub with bools"); + } return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::mul: if (allComplex) @@ -516,12 +553,22 @@ class RegionBuilderHelper { return builder.create(arg0.getLoc(), arg0, arg1); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); - if (allBool) + if (allBool) { + if (emitError) { + emitError() << "unsupported operation: div with bools"; + return nullptr; + } llvm_unreachable("unsupported operation: div with bools"); + } return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::div_unsigned: - if (!allInteger || allBool) + if (!allInteger || allBool) { + if (emitError) { + emitError() << "unsupported operation: unsigned div not on uint"; + return nullptr; + } llvm_unreachable("unsupported operation: unsigned div not on uint"); + } return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::max_signed: assert(!allComplex); @@ -547,12 +594,16 @@ class RegionBuilderHelper { assert(allFloatingPoint); return builder.create(arg0.getLoc(), arg0, arg1); } + if (emitError) { + emitError() << "unsupported binary function"; + return nullptr; + } llvm_unreachable("unsupported binary function"); } // Build the ternary functions defined by OpDSL. - Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, - Value arg2) { + Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2, + function_ref emitError = {}) { bool headBool = isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1; bool tailFloatingPoint = @@ -566,17 +617,26 @@ class RegionBuilderHelper { llvm_unreachable("unsupported non numeric type"); return builder.create(arg0.getLoc(), arg0, arg1, arg2); } + if (emitError) { + emitError() << "unsupported ternary function"; + return nullptr; + } llvm_unreachable("unsupported ternary function"); } // Build the type functions defined by OpDSL. - Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { + Value buildTypeFn(TypeFn typeFn, Type toType, Value operand, + function_ref emitError = {}) { switch (typeFn) { case TypeFn::cast_signed: return cast(toType, operand, false); case TypeFn::cast_unsigned: return cast(toType, operand, true); } + if (emitError) { + emitError() << "unsupported type conversion function"; + return nullptr; + } llvm_unreachable("unsupported type conversion function"); } @@ -617,6 +677,13 @@ class RegionBuilderHelper { OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToEnd(&block); auto loc = operand.getLoc(); + if (isa(loc)) { + if (operand.getDefiningOp()) + loc = operand.getDefiningOp()->getLoc(); + else if (operand.getParentBlock() && + operand.getParentBlock()->getParentOp()) + loc = operand.getParentBlock()->getParentOp()->getLoc(); + } return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast); } @@ -3664,9 +3731,15 @@ bool MatmulOp::hasUserDefinedMaps() { /// Implements the block region builder for the MatmulOp. This is called by /// 'fillStructuredOpRegion'. void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { - assert(3 > 0 && block.getNumArguments() == 3 && - "MatmulOp regionBuilder expects 3 (>=0) args"); + ArrayRef attrs, + function_ref emitError) { + if (emitError && block.getNumArguments() != 3) { + emitError() << "MatmulOp regionBuilder expects 3 args, got " + << block.getNumArguments(); + return; + } + assert(block.getNumArguments() == 3 && + "MatmulOp regionBuilder expects 3 args"); RegionBuilderHelper helper(b, block); SmallVector yields; @@ -3683,9 +3756,13 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, block.getArgument(0)); Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(), block.getArgument(1)); - Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2); - Value value4 = - helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3); + Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError); + if (!value3) + return; + Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), + value3, emitError); + if (!value4) + return; yields.push_back(value4); helper.yieldOutputs(yields); } @@ -3813,7 +3890,13 @@ unsigned ContractOp::getNumRegionArgs() { return 3; } /// Implement block region builder, which is called by 'fillStructuredOpRegion'. void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { + ArrayRef attrs, + function_ref emitError) { + if (emitError && block.getNumArguments() != 3) { + emitError() << "ContractOp regionBuilder expects 3 args, got " + << block.getNumArguments(); + return; + } assert(block.getNumArguments() == 3 && "ContractOp regionBuilder expects 3 args"); RegionBuilderHelper helper(b, block); @@ -3833,10 +3916,14 @@ void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, helper.buildTypeFn(castSignedness, outType, block.getArgument(0)); Value rhsAtOutType = helper.buildTypeFn(castSignedness, outType, block.getArgument(1)); - Value productAtOutType = - helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType); + Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, + rhsAtOutType, emitError); + if (!productAtOutType) + return; Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), - productAtOutType); + productAtOutType, emitError); + if (!result) + return; helper.yieldOutputs({result}); } @@ -4028,10 +4115,16 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) { return isValid; } -void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { +void BatchMatmulOp::regionBuilder( + ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs, + function_ref emitError) { + if (emitError && block.getNumArguments() != 3) { + emitError() << "BatchMatmulOp regionBuilder expects 3 args, got " + << block.getNumArguments(); + return; + } assert(block.getNumArguments() == 3 && - "BatchMatmulOp regionBuilder expects 3 (>=0) args"); + "BatchMatmulOp regionBuilder expects 3 args"); RegionBuilderHelper helper(b, block); SmallVector yields; @@ -4303,8 +4396,9 @@ LogicalResult ElementwiseOp::verify() { /// Implements the block region builder for the ElementwiseOp. This is called by /// 'fillStructuredOpRegion'. -void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { +void ElementwiseOp::regionBuilder( + ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs, + function_ref emitError) { ElementwiseKind elemwiseKind; for (auto attr : attrs) { if (attr.getName() == b.getStringAttr("kind")) { @@ -4318,6 +4412,13 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind); auto arityGroup = groupAndKind.arityGroup; auto kind = groupAndKind.kind; + if (emitError && block.getNumArguments() != + getArityGroupAsUInt(arityGroup) + 1 /*output*/) { + emitError() << "Elementwise regionBuilder expects " + << (getArityGroupAsUInt(arityGroup) + 1) << " args, got " + << block.getNumArguments(); + return; + } assert(block.getNumArguments() == getArityGroupAsUInt(arityGroup) + 1 /*output*/ && "Elementwise regionBuilder number of block args mismatch"); @@ -5501,10 +5602,16 @@ bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, return isValid; } -void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ArrayRef attrs) { +void BatchReduceMatmulOp::regionBuilder( + ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs, + function_ref emitError) { + if (emitError && block.getNumArguments() != 3) { + emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got " + << block.getNumArguments(); + return; + } assert(block.getNumArguments() == 3 && - "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args"); + "BatchReduceMatmulOp regionBuilder expects 3 args"); RegionBuilderHelper helper(b, block); SmallVector yields; diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index ca40301f04fa1..7e338283ff7f8 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1868,9 +1868,51 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape( // ----- +//===----------------------------------------------------------------------===// +// linalg.reduce +//===----------------------------------------------------------------------===// + + func.func @reduce_non_operation_name(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor { // expected-error @below {{expected bare identifier or keyword}} %0 = linalg.reduce {@reduce_fusion_elementwise} ins( %arg0: tensor<4xf32>) outs(%arg1: tensor) dimensions = [0] return %0 : tensor } + +// ----- + + +//===----------------------------------------------------------------------===// +// Tests for generic infrastructure for named Ops. The actual Ops used are +// secondary - we merely want to ensure that the diagnostic infra triggers +// correctly. +//===----------------------------------------------------------------------===// + +module { + func.func @add_invalid_mixed_types(%in_f32: memref<3xf32>, %in_i32 : memref< 3xi32>, %out_f32: memref<3xf32>, %arg3: memref<3xf32>) { + // expected-error @below {{Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'f32' and 'i32'}} + linalg.add ins(%in_f32, %in_i32 : memref<3xf32>, memref< 3xi32>) outs(%out_f32 : memref<3xf32>) + return + } +} + +// ----- + +func.func @elemwise_unary_invalid_mixed_types(%arg0 : tensor) -> tensor { + // expected-error @below {{unsupported non numeric type}} + %0 = linalg.elemwise_unary ins(%arg0 : tensor) outs(%arg0 : tensor) -> tensor + return %0 : tensor +} + +// ----- + +func.func @matmul_invalid_mixed_types(%t: tensor, %f: vector<4xf16>) + -> (tensor, vector<4xf16>) +{ + // expected-warning @unknown {{could not cast operand of type 'f16' to 'vector<4xf16>'}} + // expected-error @below {{Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'vector<4xf16>' and 'f16'}} + %0 = linalg.matmul ins(%t, %t : tensor, tensor) + outs(%f : vector<4xf16>) -> tensor + func.return %0, %f : tensor, vector<4xf16> +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 5fc7c33f4fb2b..1c961d272f192 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2737,12 +2737,14 @@ def TestLinalgConvOp : bool hasIndexSemantics() { return false; } static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, - mlir::ArrayRef attrs) { + mlir::ArrayRef attrs, + llvm::function_ref emitError) { b.create(block.getArguments().back()); } static std::function)> + mlir::ArrayRef, + llvm::function_ref)> getRegionBuilder() { return ®ionBuilder; } @@ -2798,12 +2800,14 @@ def TestLinalgFillOp : bool hasIndexSemantics() { return false; } static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, - mlir::ArrayRef attrs) { + mlir::ArrayRef attrs, + llvm::function_ref emitError) { b.create(block.getArguments().back()); } static std::function)> + mlir::ArrayRef, + llvm::function_ref)> getRegionBuilder() { return ®ionBuilder; } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index ab7b86125f693..00c70705cbb35 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -87,7 +87,8 @@ structured_op: !LinalgStructuredOpConfig # ODS-NEXT: } # IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, -# IMPL-NEXT: Block &block, ArrayRef attrs) +# IMPL-NEXT: Block &block, ArrayRef attrs, +# IMPL-NEXT: function_ref emitError) # IMPL: TypeFn castVal = TypeFn::cast_signed; # IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) { # IMPL-NEXT: return attr.getName() == "cast"; }); @@ -97,10 +98,10 @@ structured_op: !LinalgStructuredOpConfig # IMPL-NEXT: } # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); -# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]); +# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]], emitError); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); -# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]); -# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]); +# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]], emitError); +# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]], emitError); # @linalg_structured_op @@ -186,7 +187,8 @@ structured_op: !LinalgStructuredOpConfig # IMPL: "incorrect element type for index attribute 'strides'" # IMPL: "incorrect shape for index attribute 'strides'" # IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, -# IMPL-NEXT: Block &block, ArrayRef attrs) +# IMPL-NEXT: Block &block, ArrayRef attrs, +# IMPL-NEXT: function_ref emitError) # IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 && # IMPL: yields.push_back(block.getArgument(0)); @@ -315,13 +317,18 @@ structured_op: !LinalgStructuredOpConfig # ODS-NEXT: $_state.addAttribute("binary_fun", binary_fun) # IMPL-LABEL: void Test4Op::regionBuilder(ImplicitLocOpBuilder &b, -# IMPL-NEXT: Block &block, ArrayRef attrs) +# IMPL-NEXT: Block &block, ArrayRef attrs, +# IMPL-NEXT: function_ref emitError) # IMPL: UnaryFn unary_funVal = UnaryFn::exp # IMPL: BinaryFn binary_funVal = BinaryFn::add -# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0)) -# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0)) -# IMPL-NEXT: yields.push_back([[VAL1]]) +# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0), emitError); +# IMPL-NEXT: if (![[VAL0]]) +# IMPL-NEXT: return; +# IMPL: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0), emitError); +# IMPL-NEXT: if (![[VAL1]]) +# IMPL-NEXT: return; +# IMPL: yields.push_back([[VAL1]]) # @linalg_structured_op # def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)): diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 93a300e0b24a2..0a1693cff1d36 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -559,9 +559,10 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs); + Block &block, ArrayRef attrs, + function_ref emitError); static std::function)> + Block &, ArrayRef, function_ref emitError)> getRegionBuilder() {{ return regionBuilder; } @@ -1010,7 +1011,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ // {3}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( void {0}::regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ArrayRef attrs) {{ + Block &block, ArrayRef attrs, + function_ref emitError) {{ assert({1} > 0 && block.getNumArguments() == {1} && "{0} regionBuilder expects {1} (>=0) args"); RegionBuilderHelper helper(b, block); @@ -1137,8 +1139,13 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, // Call the function builder. std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back(llvm::formatv( - "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName, - funcType, interleaveToString(operandCppValues, ", "))); + R"mlir( + Value {0} = helper.build{1}({2}, {3}, emitError); + if (!{0}) + return; + )mlir", + cppIdent, enumName, funcType, + interleaveToString(operandCppValues, ", "))); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type";