Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
LinalgOps.cpp
LinalgDialect.cpp
ValueBoundsOpInterfaceImpl.cpp
RegionBuilderHelper.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
Expand Down
293 changes: 2 additions & 291 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
Expand Down Expand Up @@ -50,6 +49,8 @@
#include <cassert>
#include <optional>

#include "RegionBuilderHelper.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: full include path mlir/lib/...?


using namespace mlir;
using namespace mlir::linalg;

Expand Down Expand Up @@ -411,296 +412,6 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
// Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helper build the unary, binary, and type conversion functions defined by the
// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
// class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the mean-time, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
RegionBuilderHelper(OpBuilder &builder, Block &block)
: builder(builder), block(block) {}

// Build the unary functions defined by OpDSL.
Value buildUnaryFn(UnaryFn unaryFn, Value arg,
function_ref<InFlightDiagnostic()> emitError = {}) {
if (!isFloatingPoint(arg)) {
if (emitError) {
emitError() << "unsupported non numeric type";
return nullptr;
}
llvm_unreachable("unsupported non numeric type");
}
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (unaryFn) {
case UnaryFn::exp:
return math::ExpOp::create(builder, arg.getLoc(), arg);
case UnaryFn::log:
return math::LogOp::create(builder, arg.getLoc(), arg);
case UnaryFn::abs:
return math::AbsFOp::create(builder, arg.getLoc(), arg);
case UnaryFn::ceil:
return math::CeilOp::create(builder, arg.getLoc(), arg);
case UnaryFn::floor:
return math::FloorOp::create(builder, arg.getLoc(), arg);
case UnaryFn::negf:
return arith::NegFOp::create(builder, arg.getLoc(), arg);
case UnaryFn::reciprocal: {
Attribute oneAttr = builder.getOneAttr(arg.getType());
auto one = arith::ConstantOp::create(builder, arg.getLoc(),
::cast<TypedAttr>(oneAttr));
return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
}
case UnaryFn::round:
return math::RoundOp::create(builder, arg.getLoc(), arg);
case UnaryFn::sqrt:
return math::SqrtOp::create(builder, arg.getLoc(), arg);
case UnaryFn::rsqrt:
return math::RsqrtOp::create(builder, arg.getLoc(), arg);
case UnaryFn::square:
return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
case UnaryFn::tanh:
return math::TanhOp::create(builder, arg.getLoc(), arg);
case UnaryFn::erf:
return math::ErfOp::create(builder, arg.getLoc(), arg);
}
if (emitError) {
emitError() << "unsupported unary function";
return nullptr;
}
llvm_unreachable("unsupported unary function");
}

// Build the binary functions defined by OpDSL.
// If emitError is provided, an error will be emitted if the operation is not
// supported and a nullptr will be returned, otherwise an assertion will be
// raised.
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
function_ref<InFlightDiagnostic()> emitError = {}) {
bool allComplex = isComplex(arg0) && isComplex(arg1);
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
bool allInteger = isInteger(arg0) && isInteger(arg1);
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger) {
if (emitError) {
emitError()
<< "Cannot build binary Linalg operation: expects allComplex, "
"allFloatingPoint, or allInteger, got "
<< arg0.getType() << " and " << arg1.getType();
return nullptr;
}
llvm_unreachable("unsupported non numeric type");
}
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (binaryFn) {
case BinaryFn::add:
if (allComplex)
return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool)
return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::sub:
if (allComplex)
return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool) {
if (emitError) {
emitError() << "unsupported operation: sub with bools";
return nullptr;
}
llvm_unreachable("unsupported operation: sub with bools");
}
return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allComplex)
return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool)
return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::div:
if (allComplex)
return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool) {
if (emitError) {
emitError() << "unsupported operation: div with bools";
return nullptr;
}
llvm_unreachable("unsupported operation: div with bools");
}
return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::div_unsigned:
if (!allInteger || allBool) {
if (emitError) {
emitError() << "unsupported operation: unsigned div not on uint";
return nullptr;
}
llvm_unreachable("unsupported operation: unsigned div not on uint");
}
return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
assert(!allComplex);
if (allFloatingPoint)
return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_signed:
assert(!allComplex);
if (allFloatingPoint)
return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
assert(!allComplex);
if (allFloatingPoint)
return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
assert(!allComplex);
if (allFloatingPoint)
return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::powf:
assert(allFloatingPoint);
return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
}
if (emitError) {
emitError() << "unsupported binary function";
return nullptr;
}
llvm_unreachable("unsupported binary function");
}

// Build the ternary functions defined by OpDSL.
Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
function_ref<InFlightDiagnostic()> emitError = {}) {
bool headBool =
isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
bool tailFloatingPoint =
isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (ternaryFn) {
case TernaryFn::select:
if (!headBool && !(tailFloatingPoint || tailInteger))
llvm_unreachable("unsupported non numeric type");
return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
}
if (emitError) {
emitError() << "unsupported ternary function";
return nullptr;
}
llvm_unreachable("unsupported ternary function");
}

// Build the type functions defined by OpDSL.
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
function_ref<InFlightDiagnostic()> emitError = {}) {
switch (typeFn) {
case TypeFn::cast_signed:
return cast(toType, operand, false);
case TypeFn::cast_unsigned:
return cast(toType, operand, true);
}
if (emitError) {
emitError() << "unsupported type conversion function";
return nullptr;
}
llvm_unreachable("unsupported type conversion function");
}

void yieldOutputs(ValueRange values) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
YieldOp::create(builder, loc, values);
}

Value constant(const std::string &value) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return arith::ConstantOp::create(builder, loc,
::cast<TypedAttr>(valueAttr));
}

Value index(int64_t dim) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
return IndexOp::create(builder, builder.getUnknownLoc(), dim);
}

Type getIntegerType(unsigned width) {
return IntegerType::get(builder.getContext(), width);
}

Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
// Generates operations to cast the given operand to a specified type.
// If the cast cannot be performed, a warning will be issued and the
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand, bool isUnsignedCast) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
auto loc = operand.getLoc();
if (isa<UnknownLoc>(loc)) {
if (operand.getDefiningOp())
loc = operand.getDefiningOp()->getLoc();
else if (operand.getParentBlock() &&
operand.getParentBlock()->getParentOp())
loc = operand.getParentBlock()->getParentOp()->getLoc();
}
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
}

bool isComplex(Value value) {
return llvm::isa<ComplexType>(value.getType());
}
bool isFloatingPoint(Value value) {
return llvm::isa<FloatType>(value.getType());
}
bool isInteger(Value value) {
return llvm::isa<IntegerType>(value.getType());
}

OpBuilder &builder;
Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading