Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ namespace LLVM {
/// Generate IR that prints the given string to stdout.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef symbolName, StringRef string,
const LLVMTypeConverter &typeConverter,
bool addNewline = true,
std::optional<StringRef> runtimeFunctionName = {});
LogicalResult createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter,
bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
} // namespace LLVM

} // namespace mlir
Expand Down
58 changes: 31 additions & 27 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include <optional>

namespace mlir {
class Location;
Expand All @@ -29,42 +28,47 @@ class ValueRange;
namespace LLVM {
class LLVMFuncOp;

/// Helper functions to lookup or create the declaration for commonly used
/// Helper functions to look up or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Failure if an unexpected version of function is found.
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintStringFn(Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);

/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {}, bool isVarArg = false);
/// Return a failure if the FuncOp found has unexpected signature.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {}, Type resultType = {},
bool isVarArg = false, bool isReserved = false);

} // namespace LLVM
} // namespace mlir
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
if (failed(allocFuncOp))
return failure();
auto coroAlloc = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});

// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
Expand Down Expand Up @@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
// Free the memory.
auto freeFuncOp =
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
if (failed(freeFuncOp))
return failure();
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
ValueRange(coroMem.getResult()));

return success();
Expand Down
10 changes: 7 additions & 3 deletions mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {

// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
*getTypeConverter(), /*addNewLine=*/false,
/*runtimeFunctionName=*/"puts");
if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg",
op.getMsg(), *getTypeConverter(),
/*addNewLine=*/false,
/*runtimeFunctionName=*/"puts")
.failed()) {
return failure();
}
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
Expand Down
17 changes: 12 additions & 5 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(

// Find the malloc and free, or declare them if necessary.
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
LLVM::LLVMFuncOp freeFunc, mallocFunc;
if (toDynamic)
FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
if (!toDynamic)
if (failed(mallocFunc))
return failure();
}
if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(module);
if (failed(freeFunc))
return failure();
}

unsigned unrankedMemrefPos = 0;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
Expand All @@ -293,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
toDynamic
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
? builder
.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
.getResult()
: builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
IntegerType::get(getContext(), 8),
Expand All @@ -302,7 +309,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
Value source = desc.memRefDescPtr(builder, loc);
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
if (!toDynamic)
builder.create<LLVM::CallOp>(loc, freeFunc, source);
builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);

// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
Expand Down
15 changes: 10 additions & 5 deletions mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
return uniqueName;
}

void mlir::LLVM::createPrintStrCall(
LogicalResult mlir::LLVM::createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
std::optional<StringRef> runtimeFunctionName) {
Expand Down Expand Up @@ -59,8 +59,13 @@ void mlir::LLVM::createPrintStrCall(
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep =
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
Operation *printer =
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
gep);
if (auto printer =
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
succeeded(printer)) {
builder.create<LLVM::CallOp>(loc, TypeRange(),
SymbolRefAttr::get(printer.value()), gep);
} else {
return failure();
}
return success();
}
23 changes: 15 additions & 8 deletions mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
using namespace mlir;

namespace {
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
Operation *module, Type indexType) {
FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);

return LLVM::lookupOrCreateMallocFn(module, indexType);
}

LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
Operation *module, Type indexType) {
FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

if (useGenericFn)
Expand Down Expand Up @@ -80,10 +82,13 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
<< " to integer address space "
"failed. Consider adding memory space conversions.";
}
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
if (failed(allocFuncOp))
return std::make_tuple(Value(), Value());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this result in a crash later on? If so, this would also require a FailureOr wrapping.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The returned pair of value is populated and finally verified here: AllocLikeConvertion.cpp

  // Allocate the underlying buffer.
  auto [allocatedPtr, alignedPtr] =
      this->allocateBuffer(rewriter, loc, size, op);

  if (!allocatedPtr || !alignedPtr)
    return rewriter.notifyMatchFailure(loc,
                                       "underlying buffer allocation failed");

Here empty value is assigned as the invalid state and properly handled. I prefer to leave it as is, and a separate PR might be more appropriate to modernize the error handling style if we do have a preference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me 🙂

auto results =
rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);

Value allocatedPtr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
Expand Down Expand Up @@ -146,11 +151,13 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
if (failed(allocFuncOp))
return Value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handled at MemRefToLLVM.cpp

    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);

in the same style.

auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));

return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
elementPtrType, *getTypeConverter());
Expand Down
15 changes: 10 additions & 5 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamic(strideOrOffset);
}

LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
ModuleOp module) {
FailureOr<LLVM::LLVMFuncOp> getFreeFn(const LLVMTypeConverter *typeConverter,
ModuleOp module) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

if (useGenericFn)
Expand Down Expand Up @@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
LLVM::LLVMFuncOp freeFunc =
auto freeFunc =
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
if (failed(freeFunc))
return failure();
Value allocatedPtr;
if (auto unrankedTy =
llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
Expand All @@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
allocatedPtr = MemRefDescriptor(adaptor.getMemref())
.allocatedPtr(rewriter, op.getLoc());
}
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
allocatedPtr);
return success();
}
};
Expand Down Expand Up @@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
rewriter.create<LLVM::CallOp>(loc, copyFn,
if (failed(copyFn))
return failure();
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
ValueRange{elemSize, sourcePtr, targetPtr});

// Restore stack used for descriptors
Expand Down
48 changes: 29 additions & 19 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,24 +1546,32 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {

auto punct = printOp.getPunctuation();
if (auto stringLiteral = printOp.getStringLiteral()) {
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter(),
/*addNewline=*/false);
if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter(),
/*addNewline=*/false)
.failed()) {
return failure();
}
} else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
return LLVM::lookupOrCreatePrintCloseFn(parent);
case PrintPunctuation::Open:
return LLVM::lookupOrCreatePrintOpenFn(parent);
case PrintPunctuation::Comma:
return LLVM::lookupOrCreatePrintCommaFn(parent);
case PrintPunctuation::NewLine:
return LLVM::lookupOrCreatePrintNewlineFn(parent);
default:
llvm_unreachable("unexpected punctuation");
}
}());
if (auto op = [&] -> FailureOr<LLVM::LLVMFuncOp> {
switch (punct) {
case PrintPunctuation::Close:
return LLVM::lookupOrCreatePrintCloseFn(parent);
case PrintPunctuation::Open:
return LLVM::lookupOrCreatePrintOpenFn(parent);
case PrintPunctuation::Comma:
return LLVM::lookupOrCreatePrintCommaFn(parent);
case PrintPunctuation::NewLine:
return LLVM::lookupOrCreatePrintNewlineFn(parent);
default:
llvm_unreachable("unexpected punctuation");
}
}();
succeeded(op))
emitCall(rewriter, printOp->getLoc(), op.value());
else {
return failure();
}
}

rewriter.eraseOp(printOp);
Expand All @@ -1588,7 +1596,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {

// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
Operation *printer;
FailureOr<Operation *> printer;
if (printType.isF32()) {
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
} else if (printType.isF64()) {
Expand Down Expand Up @@ -1631,6 +1639,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
} else {
return failure();
}
if (failed(printer))
return failure();

switch (conversion) {
case PrintConversion::ZeroExt64:
Expand All @@ -1648,7 +1658,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
case PrintConversion::None:
break;
}
emitCall(rewriter, loc, printer, value);
emitCall(rewriter, loc, printer.value(), value);
return success();
}

Expand Down
Loading
Loading