diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt index ec433284e17ad..75f65f39f2371 100644 --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgDialect LinalgOps.cpp LinalgDialect.cpp ValueBoundsOpInterfaceImpl.cpp + RegionBuilderHelper.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b0c8cbd..f029613a280e1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -9,7 +9,6 @@ // This file implements the Linalg operations. // //===----------------------------------------------------------------------===// - #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/AsmParser/AsmParser.h" @@ -50,6 +49,8 @@ #include #include +#include "RegionBuilderHelper.h" + using namespace mlir; using namespace mlir::linalg; @@ -411,296 +412,6 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, // Region is elided. } -//===----------------------------------------------------------------------===// -// Region builder helper. -// TODO: Move this to a utility library. -// The public methods on this class are referenced directly from generated code. -// Helper build the unary, binary, and type conversion functions defined by the -// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this -// class. -// -// Implementations of the math functions must be polymorphic over numeric types, -// internally performing necessary casts. If the function application makes no -// sense, then the only recourse is to assert and return nullptr. This can be -// extended later if it becomes possible to fail construction of the region. The -// invariant should be enforced at a higher level. -// -// TODO: These helpers are currently type polymorphic over the class of integer -// and floating point types, but they will not internally cast within bit -// widths of a class (mixed precision such as i8->i32) or across classes -// (i.e. mixed float and integer). Many such combinations are ambiguous or need -// to be handled with care and work is being considered to extend the op -// language to make such cases explicit. In the mean-time, violating this will -// fail verification, which is deemed acceptable. -//===----------------------------------------------------------------------===// - -namespace { - -class RegionBuilderHelper { -public: - RegionBuilderHelper(OpBuilder &builder, Block &block) - : builder(builder), block(block) {} - - // Build the unary functions defined by OpDSL. - Value buildUnaryFn(UnaryFn unaryFn, Value arg, - function_ref emitError = {}) { - if (!isFloatingPoint(arg)) { - if (emitError) { - emitError() << "unsupported non numeric type"; - return nullptr; - } - llvm_unreachable("unsupported non numeric type"); - } - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointToEnd(&block); - switch (unaryFn) { - case UnaryFn::exp: - return math::ExpOp::create(builder, arg.getLoc(), arg); - case UnaryFn::log: - return math::LogOp::create(builder, arg.getLoc(), arg); - case UnaryFn::abs: - return math::AbsFOp::create(builder, arg.getLoc(), arg); - case UnaryFn::ceil: - return math::CeilOp::create(builder, arg.getLoc(), arg); - case UnaryFn::floor: - return math::FloorOp::create(builder, arg.getLoc(), arg); - case UnaryFn::negf: - return arith::NegFOp::create(builder, arg.getLoc(), arg); - case UnaryFn::reciprocal: { - Attribute oneAttr = builder.getOneAttr(arg.getType()); - auto one = arith::ConstantOp::create(builder, arg.getLoc(), - ::cast(oneAttr)); - return arith::DivFOp::create(builder, arg.getLoc(), one, arg); - } - case UnaryFn::round: - return math::RoundOp::create(builder, arg.getLoc(), arg); - case UnaryFn::sqrt: - return math::SqrtOp::create(builder, arg.getLoc(), arg); - case UnaryFn::rsqrt: - return math::RsqrtOp::create(builder, arg.getLoc(), arg); - case UnaryFn::square: - return arith::MulFOp::create(builder, arg.getLoc(), arg, arg); - case UnaryFn::tanh: - return math::TanhOp::create(builder, arg.getLoc(), arg); - case UnaryFn::erf: - return math::ErfOp::create(builder, arg.getLoc(), arg); - } - if (emitError) { - emitError() << "unsupported unary function"; - return nullptr; - } - llvm_unreachable("unsupported unary function"); - } - - // Build the binary functions defined by OpDSL. - // If emitError is provided, an error will be emitted if the operation is not - // supported and a nullptr will be returned, otherwise an assertion will be - // raised. - Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1, - function_ref emitError = {}) { - bool allComplex = isComplex(arg0) && isComplex(arg1); - bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1); - bool allInteger = isInteger(arg0) && isInteger(arg1); - bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 && - arg1.getType().getIntOrFloatBitWidth() == 1; - if (!allComplex && !allFloatingPoint && !allInteger) { - if (emitError) { - emitError() - << "Cannot build binary Linalg operation: expects allComplex, " - "allFloatingPoint, or allInteger, got " - << arg0.getType() << " and " << arg1.getType(); - return nullptr; - } - llvm_unreachable("unsupported non numeric type"); - } - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointToEnd(&block); - switch (binaryFn) { - case BinaryFn::add: - if (allComplex) - return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allFloatingPoint) - return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allBool) - return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1); - return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::sub: - if (allComplex) - return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allFloatingPoint) - return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allBool) { - if (emitError) { - emitError() << "unsupported operation: sub with bools"; - return nullptr; - } - llvm_unreachable("unsupported operation: sub with bools"); - } - return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::mul: - if (allComplex) - return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allFloatingPoint) - return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allBool) - return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1); - return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::div: - if (allComplex) - return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allFloatingPoint) - return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1); - if (allBool) { - if (emitError) { - emitError() << "unsupported operation: div with bools"; - return nullptr; - } - llvm_unreachable("unsupported operation: div with bools"); - } - return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::div_unsigned: - if (!allInteger || allBool) { - if (emitError) { - emitError() << "unsupported operation: unsigned div not on uint"; - return nullptr; - } - llvm_unreachable("unsupported operation: unsigned div not on uint"); - } - return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::max_signed: - assert(!allComplex); - if (allFloatingPoint) - return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); - return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::min_signed: - assert(!allComplex); - if (allFloatingPoint) - return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); - return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::max_unsigned: - assert(!allComplex); - if (allFloatingPoint) - return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); - return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::min_unsigned: - assert(!allComplex); - if (allFloatingPoint) - return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); - return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1); - case BinaryFn::powf: - assert(allFloatingPoint); - return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1); - } - if (emitError) { - emitError() << "unsupported binary function"; - return nullptr; - } - llvm_unreachable("unsupported binary function"); - } - - // Build the ternary functions defined by OpDSL. - Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2, - function_ref emitError = {}) { - bool headBool = - isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1; - bool tailFloatingPoint = - isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2); - bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2); - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointToEnd(&block); - switch (ternaryFn) { - case TernaryFn::select: - if (!headBool && !(tailFloatingPoint || tailInteger)) - llvm_unreachable("unsupported non numeric type"); - return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2); - } - if (emitError) { - emitError() << "unsupported ternary function"; - return nullptr; - } - llvm_unreachable("unsupported ternary function"); - } - - // Build the type functions defined by OpDSL. - Value buildTypeFn(TypeFn typeFn, Type toType, Value operand, - function_ref emitError = {}) { - switch (typeFn) { - case TypeFn::cast_signed: - return cast(toType, operand, false); - case TypeFn::cast_unsigned: - return cast(toType, operand, true); - } - if (emitError) { - emitError() << "unsupported type conversion function"; - return nullptr; - } - llvm_unreachable("unsupported type conversion function"); - } - - void yieldOutputs(ValueRange values) { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointToEnd(&block); - Location loc = builder.getUnknownLoc(); - YieldOp::create(builder, loc, values); - } - - Value constant(const std::string &value) { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointToEnd(&block); - Location loc = builder.getUnknownLoc(); - Attribute valueAttr = parseAttribute(value, builder.getContext()); - return arith::ConstantOp::create(builder, loc, - ::cast(valueAttr)); - } - - Value index(int64_t dim) { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointToEnd(&block); - return IndexOp::create(builder, builder.getUnknownLoc(), dim); - } - - Type getIntegerType(unsigned width) { - return IntegerType::get(builder.getContext(), width); - } - - Type getFloat32Type() { return Float32Type::get(builder.getContext()); } - Type getFloat64Type() { return Float64Type::get(builder.getContext()); } - -private: - // Generates operations to cast the given operand to a specified type. - // If the cast cannot be performed, a warning will be issued and the - // operand returned as-is (which will presumably yield a verification - // issue downstream). - Value cast(Type toType, Value operand, bool isUnsignedCast) { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointToEnd(&block); - auto loc = operand.getLoc(); - if (isa(loc)) { - if (operand.getDefiningOp()) - loc = operand.getDefiningOp()->getLoc(); - else if (operand.getParentBlock() && - operand.getParentBlock()->getParentOp()) - loc = operand.getParentBlock()->getParentOp()->getLoc(); - } - return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast); - } - - bool isComplex(Value value) { - return llvm::isa(value.getType()); - } - bool isFloatingPoint(Value value) { - return llvm::isa(value.getType()); - } - bool isInteger(Value value) { - return llvm::isa(value.getType()); - } - - OpBuilder &builder; - Block █ -}; - -} // namespace - //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp new file mode 100644 index 0000000000000..8e4b7e98eb5ce --- /dev/null +++ b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.cpp @@ -0,0 +1,237 @@ +//===- RegionBuilderHelper.cpp - Region Builder Helper class -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of RegionBuilderHelper class. +// +//===----------------------------------------------------------------------===// + +#include "RegionBuilderHelper.h" + +namespace mlir { +namespace linalg { + +Value RegionBuilderHelper::buildUnaryFn( + UnaryFn unaryFn, Value arg, function_ref emitError) { + if (!isFloatingPoint(arg)) { + if (emitError) { + emitError() << "unsupported non numeric type"; + return nullptr; + } + llvm_unreachable("unsupported non numeric type"); + } + + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + switch (unaryFn) { + case UnaryFn::exp: + return math::ExpOp::create(builder, arg.getLoc(), arg); + case UnaryFn::log: + return math::LogOp::create(builder, arg.getLoc(), arg); + case UnaryFn::abs: + return math::AbsFOp::create(builder, arg.getLoc(), arg); + case UnaryFn::ceil: + return math::CeilOp::create(builder, arg.getLoc(), arg); + case UnaryFn::floor: + return math::FloorOp::create(builder, arg.getLoc(), arg); + case UnaryFn::negf: + return arith::NegFOp::create(builder, arg.getLoc(), arg); + case UnaryFn::reciprocal: { + Attribute oneAttr = builder.getOneAttr(arg.getType()); + auto one = arith::ConstantOp::create(builder, arg.getLoc(), + llvm::cast(oneAttr)); + return arith::DivFOp::create(builder, arg.getLoc(), one, arg); + } + case UnaryFn::round: + return math::RoundOp::create(builder, arg.getLoc(), arg); + case UnaryFn::sqrt: + return math::SqrtOp::create(builder, arg.getLoc(), arg); + case UnaryFn::rsqrt: + return math::RsqrtOp::create(builder, arg.getLoc(), arg); + case UnaryFn::square: + return arith::MulFOp::create(builder, arg.getLoc(), arg, arg); + case UnaryFn::tanh: + return math::TanhOp::create(builder, arg.getLoc(), arg); + case UnaryFn::erf: + return math::ErfOp::create(builder, arg.getLoc(), arg); + } + + if (emitError) { + emitError() << "unsupported unary function"; + return nullptr; + } + llvm_unreachable("unsupported unary function"); +} + +Value RegionBuilderHelper::buildBinaryFn( + BinaryFn binaryFn, Value arg0, Value arg1, + function_ref emitError) { + + bool allComplex = isComplex(arg0) && isComplex(arg1); + bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1); + bool allInteger = isInteger(arg0) && isInteger(arg1); + bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 && + arg1.getType().getIntOrFloatBitWidth() == 1; + + if (!allComplex && !allFloatingPoint && !allInteger) { + if (emitError) { + emitError() + << "Cannot build binary Linalg operation: expects allComplex, " + "allFloatingPoint, or allInteger, got " + << arg0.getType() << " and " << arg1.getType(); + return nullptr; + } + llvm_unreachable("unsupported non numeric type"); + } + + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + switch (binaryFn) { + case BinaryFn::add: + if (allComplex) + return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allFloatingPoint) + return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allBool) + return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::sub: + if (allComplex) + return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allFloatingPoint) + return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allBool) { + if (emitError) { + emitError() << "unsupported operation: sub with bools"; + return nullptr; + } + llvm_unreachable("unsupported operation: sub with bools"); + } + return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::mul: + if (allComplex) + return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allFloatingPoint) + return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allBool) + return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::div: + if (allComplex) + return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allFloatingPoint) + return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1); + if (allBool) { + if (emitError) { + emitError() << "unsupported operation: div with bools"; + return nullptr; + } + llvm_unreachable("unsupported operation: div with bools"); + } + return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::div_unsigned: + if (!allInteger || allBool) { + if (emitError) { + emitError() << "unsupported operation: unsigned div not on uint"; + return nullptr; + } + llvm_unreachable("unsupported operation: unsigned div not on uint"); + } + return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::max_signed: + assert(!allComplex); + if (allFloatingPoint) + return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::min_signed: + assert(!allComplex); + if (allFloatingPoint) + return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::max_unsigned: + assert(!allComplex); + if (allFloatingPoint) + return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::min_unsigned: + assert(!allComplex); + if (allFloatingPoint) + return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1); + case BinaryFn::powf: + assert(allFloatingPoint); + return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1); + } + + if (emitError) { + emitError() << "unsupported binary function"; + return nullptr; + } + llvm_unreachable("unsupported binary function"); +} + +Value RegionBuilderHelper::buildTernaryFn( + TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2, + function_ref emitError) { + bool headBool = + isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1; + bool tailFloatingPoint = + isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2); + bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + + switch (ternaryFn) { + case TernaryFn::select: + if (!headBool && !(tailFloatingPoint || tailInteger)) + llvm_unreachable("unsupported non numeric type"); + return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2); + } + + if (emitError) { + emitError() << "unsupported ternary function"; + return nullptr; + } + llvm_unreachable("unsupported ternary function"); +} + +Value RegionBuilderHelper::buildTypeFn( + TypeFn typeFn, Type toType, Value operand, + function_ref emitError) { + switch (typeFn) { + case TypeFn::cast_signed: + return cast(toType, operand, false); + case TypeFn::cast_unsigned: + return cast(toType, operand, true); + } + + if (emitError) { + emitError() << "unsupported type conversion function"; + return nullptr; + } + llvm_unreachable("unsupported type conversion function"); +} + +void RegionBuilderHelper::yieldOutputs(ValueRange values) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + Location loc = builder.getUnknownLoc(); + YieldOp::create(builder, loc, values); +} + +Value RegionBuilderHelper::constant(const std::string &value) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + Location loc = builder.getUnknownLoc(); + Attribute valueAttr = parseAttribute(value, builder.getContext()); + return arith::ConstantOp::create(builder, loc, + llvm::cast(valueAttr)); +} + +} // namespace linalg + +} // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h new file mode 100644 index 0000000000000..d5a59be1d526d --- /dev/null +++ b/mlir/lib/Dialect/Linalg/IR/RegionBuilderHelper.h @@ -0,0 +1,148 @@ +//===- RegionBuilderHelper.h - Region-Builder-Helper class declaration ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper builds the unary, binary, and type conversion functions defined by the +// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this +// class. +// +// Implementations of the math functions must be polymorphic over numeric types, +// internally performing necessary casts. If the function application makes no +// sense, then the only recourse is to assert and return nullptr. This can be +// extended later if it becomes possible to fail construction of the region. The +// invariant should be enforced at a higher level. +// +// TODO: These helpers are currently type polymorphic over the class of integer +// and floating point types, but they will not internally cast within bit +// widths of a class (mixed precision such as i8->i32) or across classes +// (i.e. mixed float and integer). Many such combinations are ambiguous or need +// to be handled with care and work is being considered to extend the op +// language to make such cases explicit. In the mean-time, violating this will +// fail verification, which is deemed acceptable. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LINALG_REGION_BUILDER_HELPER_H +#define MLIR_LINALG_REGION_BUILDER_HELPER_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InterleavedRange.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +namespace mlir { +namespace linalg { + +class RegionBuilderHelper { +public: + RegionBuilderHelper(OpBuilder &builder, Block &block) + : builder(builder), block(block) {} + + // Build the unary functions. + Value buildUnaryFn(UnaryFn unaryFn, Value arg, + function_ref emitError = {}); + + // Build the binary functions. + // If emitError is provided, an error will be emitted if the operation is not + // supported and a nullptr will be returned, otherwise an assertion is raised. + Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1, + function_ref emitError = {}); + + // Build the ternary functions defined by OpDSL. + Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2, + function_ref emitError = {}); + + // Build the type functions defined by OpDSL. + Value buildTypeFn(TypeFn typeFn, Type toType, Value operand, + function_ref emitError = {}); + + // Create a `yieldOp` to yield `values` passed in as arg. + void yieldOutputs(ValueRange values); + + // Create a constant op with value parsed from string `value`. + Value constant(const std::string &value); + + // Create an `index` op to extract iteration index `dim`. + Value index(int64_t dim) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + return IndexOp::create(builder, builder.getUnknownLoc(), dim); + } + + // Create an integer of size `width`. + Type getIntegerType(unsigned width) { + return IntegerType::get(builder.getContext(), width); + } + + Type getFloat32Type() { return Float32Type::get(builder.getContext()); } + Type getFloat64Type() { return Float64Type::get(builder.getContext()); } + +private: + // Generates operations to cast the given operand to a specified type. + // If the cast cannot be performed, a warning will be issued and the + // operand returned as-is (which will presumably yield a verification + // issue downstream). + Value cast(Type toType, Value operand, bool isUnsignedCast) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&block); + auto loc = operand.getLoc(); + if (isa(loc)) { + if (operand.getDefiningOp()) + loc = operand.getDefiningOp()->getLoc(); + else if (operand.getParentBlock() && + operand.getParentBlock()->getParentOp()) + loc = operand.getParentBlock()->getParentOp()->getLoc(); + } + return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast); + } + + bool isComplex(Value value) { + return llvm::isa(value.getType()); + } + bool isFloatingPoint(Value value) { + return llvm::isa(value.getType()); + } + bool isInteger(Value value) { + return llvm::isa(value.getType()); + } + + OpBuilder &builder; + Block █ +}; + +} // namespace linalg +} // namespace mlir + +#endif // MLIR_LINALG_REGION_BUILDER_HELPER_H