Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@

constexpr const char *alignedAllocFunctionName = "aligned_alloc";
constexpr const char *mallocFunctionName = "malloc";
constexpr const char *memcpyFunctionName = "memcpy";
constexpr const char *cppStandardLibraryHeader = "cstdlib";
constexpr const char *cStandardLibraryHeader = "stdlib.h";
constexpr const char *cppStringLibraryHeader = "cstring";
constexpr const char *cStringLibraryHeader = "string.h";

namespace mlir {
class DialectRegistry;
Expand Down
90 changes: 88 additions & 2 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
#include <numeric>

using namespace mlir;

Expand Down Expand Up @@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}

static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
OpBuilder &builder) {
assert(isMemRefTypeLegalForEmitC(memrefType) &&
"incompatible memref type for EmitC conversion");
emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
builder, loc, emitc::SizeTType::get(builder.getContext()),
builder.getStringAttr("sizeof"), ValueRange{},
ArrayAttr::get(builder.getContext(),
{TypeAttr::get(memrefType.getElementType())}));

IndexType indexType = builder.getIndexType();
int64_t numElements = std::accumulate(memrefType.getShape().begin(),
memrefType.getShape().end(), int64_t{1},
std::multiplies<int64_t>());
emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
builder, loc, indexType, builder.getIndexAttr(numElements));

Type sizeTType = emitc::SizeTType::get(builder.getContext());
emitc::MulOp totalSizeBytes = emitc::MulOp::create(
builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);

return totalSizeBytes.getResult();
}

static emitc::ApplyOp
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm making this a function instead a lambda so that we can use it later when lowering other memref ops like extract_strided_metadata and reinterpret_cast that need a pointer to the first element of the array.

createPointerFromEmitcArray(Location loc, OpBuilder &builder,
TypedValue<emitc::ArrayType> arrayValue) {

emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
builder, loc, builder.getIndexType(), builder.getIndexAttr(0));

emitc::ArrayType arrayType = arrayValue.getType();
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
emitc::ApplyOp ptr = emitc::ApplyOp::create(
builder, loc, emitc::PointerType::get(arrayType.getElementType()),
builder.getStringAttr("&"), subPtr);

return ptr;
}

struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
Expand Down Expand Up @@ -159,6 +203,47 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
}
};

struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = copyOp.getLoc();
MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
MemRefType targetMemrefType =
cast<MemRefType>(copyOp.getTarget().getType());

if (!isMemRefTypeLegalForEmitC(srcMemrefType))
return rewriter.notifyMatchFailure(
loc, "incompatible source memref type for EmitC conversion");

if (!isMemRefTypeLegalForEmitC(targetMemrefType))
return rewriter.notifyMatchFailure(
loc, "incompatible target memref type for EmitC conversion");

auto srcArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
emitc::ApplyOp srcPtr =
createPointerFromEmitcArray(loc, rewriter, srcArrayValue);

auto targetArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
emitc::ApplyOp targetPtr =
createPointerFromEmitcArray(loc, rewriter, targetArrayValue);

emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
rewriter, loc, TypeRange{}, "memcpy",
ValueRange{
targetPtr.getResult(), srcPtr.getResult(),
calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});

rewriter.replaceOp(copyOp, memCpyCall.getResults());

return success();
}
};

struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -320,6 +405,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {

void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
ConvertLoad, ConvertStore>(converter, patterns.getContext());
patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
ConvertGetGlobal, ConvertLoad, ConvertStore>(
converter, patterns.getContext());
}
56 changes: 31 additions & 25 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
Expand All @@ -27,6 +29,15 @@ namespace mlir {
using namespace mlir;

namespace {

emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
StringRef headerName) {
StringAttr includeAttr = builder.getStringAttr(headerName);
return builder.create<emitc::IncludeOp>(
module.getLoc(), includeAttr,
/*is_standard_include=*/builder.getUnitAttr());
}

struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
Expand Down Expand Up @@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass
return signalPassFailure();

mlir::ModuleOp module = getOperation();
llvm::SmallSet<StringRef, 4> existingHeaders;
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
module.walk([&](mlir::emitc::IncludeOp includeOp) {
if (includeOp.getIsStandardInclude())
existingHeaders.insert(includeOp.getInclude());
});

module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
if (callOp.getCallee() != alignedAllocFunctionName &&
callOp.getCallee() != mallocFunctionName) {
StringRef expectedHeader;
if (callOp.getCallee() == alignedAllocFunctionName ||
callOp.getCallee() == mallocFunctionName)
expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
: cStandardLibraryHeader;
else if (callOp.getCallee() == memcpyFunctionName)
expectedHeader =
options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
else
return mlir::WalkResult::advance();
if (!existingHeaders.contains(expectedHeader)) {
addStandardHeader(builder, module, expectedHeader);
existingHeaders.insert(expectedHeader);
}

for (auto &op : *module.getBody()) {
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
if (!includeOp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change itself is good (removing braces from single-line blocks) but should be done on a separate PR to avoid cluttering this one with unrelated modifications.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally, yes.
but im already modifying this portion of code and this would be a single line change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally, yes. but im already modifying this portion of code and this would be a single line change.

Practically too. It's not about the number of lines or their proximity to other changes. LLVM's contribution policy requires patches to be minimal. More specifically:

* not contain any unrelated changes
* be an isolated change. Independent changes should be submitted as separate patches as this makes reviewing easier.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the feedback!

continue;
}
if (includeOp.getIsStandardInclude() &&
((options.lowerToCpp &&
includeOp.getInclude() == cppStandardLibraryHeader) ||
(!options.lowerToCpp &&
includeOp.getInclude() == cStandardLibraryHeader))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.
Also, shouldn't the code here also check for c/cppStringLibraryHeader?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, i should!
thanks for the pointer!

return mlir::WalkResult::interrupt();
}
}

mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
StringAttr includeAttr =
builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
: cStandardLibraryHeader);
builder.create<mlir::emitc::IncludeOp>(
module.getLoc(), includeAttr,
/*is_standard_include=*/builder.getUnitAttr());
return mlir::WalkResult::interrupt();
return mlir::WalkResult::advance();
});
}
};
Expand Down
50 changes: 50 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP

func.func @alloc_copy(%arg0: memref<999xi32>) {
%alloc = memref.alloc() : memref<999xi32>
memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
%alloc_1 = memref.alloc() : memref<999xi32>
memref.copy %arg0, %alloc_1 : memref<999xi32> to memref<999xi32>
return
}

// CHECK: module {
// NOCPP: emitc.include <"stdlib.h">
// NOCPP-NEXT: emitc.include <"string.h">

// CPP: emitc.include <"cstdlib">
// CPP-NEXT: emitc.include <"cstring">

// CHECK-LABEL: alloc_copy
// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
// CHECK-NEXT: builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32>
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t
// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t
// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// CHECK-NEXT: emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32>
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t
// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
// CHECK-NEXT: return
29 changes: 29 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP

func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
return
}

// CHECK: module {
// NOCPP: emitc.include <"string.h">
// CPP: emitc.include <"cstring">

// CHECK-LABEL: copying
// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT:}

Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics

func.func @memref_op(%arg0 : memref<2x4xf32>) {
// expected-error@+1 {{failed to legalize operation 'memref.copy'}}
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
return
}

// -----

func.func @alloca_with_dynamic_shape() {
%0 = index.constant 1
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
Expand Down
Loading