Skip to content

Commit bfda0e7

Browse files
authored
[mlir][EmitC] Expand the MemRefToEmitC pass - Lowering CopyOp (#151206)
This patch lowers `memref.copy` to `emitc.call_opaque "memcpy"`.

From:
```
func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
  memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
  return
}
```

To:
```cpp
#include <cstring>
void copying(float v1[9][4][5][7], float v2[9][4][5][7]) {
  size_t v3 = 0;
  float* v4 = &v1[v3][v3][v3][v3];
  float* v5 = &v2[v3][v3][v3][v3];
  size_t v6 = sizeof(float);
  size_t v7 = 1260;
  size_t v8 = v6 * v7;
  memcpy(v5, v4, v8);
  return;
}
```
1 parent 6d08a39 commit bfda0e7

File tree

6 files changed

+201
-35
lines changed

6 files changed

+201
-35
lines changed

mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
constexpr const char *alignedAllocFunctionName = "aligned_alloc";
1212
constexpr const char *mallocFunctionName = "malloc";
13+
constexpr const char *memcpyFunctionName = "memcpy";
1314
constexpr const char *cppStandardLibraryHeader = "cstdlib";
1415
constexpr const char *cStandardLibraryHeader = "stdlib.h";
16+
constexpr const char *cppStringLibraryHeader = "cstring";
17+
constexpr const char *cStringLibraryHeader = "string.h";
1518

1619
namespace mlir {
1720
class DialectRegistry;

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/IR/Builders.h"
1919
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/IR/Diagnostics.h"
2021
#include "mlir/IR/PatternMatch.h"
2122
#include "mlir/IR/TypeRange.h"
2223
#include "mlir/IR/Value.h"
2324
#include "mlir/Transforms/DialectConversion.h"
2425
#include <cstdint>
26+
#include <numeric>
2527

2628
using namespace mlir;
2729

@@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
9799
return resultTy;
98100
}
99101

102+
/// Emits the EmitC ops that compute the storage size of a memref in bytes,
/// i.e. `sizeof(elementType) * numElements`.
///
/// Three ops are created at the builder's current insertion point, in this
/// order: an opaque `sizeof` call, an index constant holding the element
/// count, and the `emitc.mul` combining them.
///
/// \param loc        Location attached to every created op.
/// \param memrefType Memref type to measure; it must satisfy
///                   `isMemRefTypeLegalForEmitC` (asserted).
/// \param builder    Builder used to create the ops.
/// \returns The `!emitc.size_t` result of the final multiplication.
static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
                                           OpBuilder &builder) {
  assert(isMemRefTypeLegalForEmitC(memrefType) &&
         "incompatible memref type for EmitC conversion");
  emitc::SizeTType sizeTType = emitc::SizeTType::get(builder.getContext());

  // `sizeof` takes no SSA operands; the element type is passed through the
  // call's `args` attribute instead.
  emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
      builder, loc, sizeTType, builder.getStringAttr("sizeof"), ValueRange{},
      ArrayAttr::get(builder.getContext(),
                     {TypeAttr::get(memrefType.getElementType())}));

  // The element count is the product of all dimensions. Let the shaped type
  // compute it rather than folding the shape by hand.
  emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
      builder, loc, builder.getIndexType(),
      builder.getIndexAttr(memrefType.getNumElements()));

  emitc::MulOp totalSizeBytes = emitc::MulOp::create(
      builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);

  return totalSizeBytes.getResult();
}
125+
126+
/// Produces a pointer to the first element of an EmitC array by emitting
/// `&array[0]...[0]`, with one zero index per array dimension.
///
/// \param loc        Location attached to every created op.
/// \param builder    Builder used to create the ops at its insertion point.
/// \param arrayValue Array whose first element is addressed.
/// \returns The `emitc.apply "&"` op; its result has type pointer-to-element.
static emitc::ApplyOp
createPointerFromEmitcArray(Location loc, OpBuilder &builder,
                            TypedValue<emitc::ArrayType> arrayValue) {
  // One zero constant is shared by the subscripts of all dimensions.
  emitc::ConstantOp zero = emitc::ConstantOp::create(
      builder, loc, builder.getIndexType(), builder.getIndexAttr(0));

  emitc::ArrayType arrayTy = arrayValue.getType();
  llvm::SmallVector<mlir::Value> zeroIndices(arrayTy.getRank(), zero);

  // Select element [0]...[0], then take its address.
  emitc::SubscriptOp firstElement = emitc::SubscriptOp::create(
      builder, loc, arrayValue, ValueRange(zeroIndices));
  return emitc::ApplyOp::create(
      builder, loc, emitc::PointerType::get(arrayTy.getElementType()),
      builder.getStringAttr("&"), firstElement);
}
143+
100144
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
101145
using OpConversionPattern::OpConversionPattern;
102146
LogicalResult
@@ -159,6 +203,47 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
159203
}
160204
};
161205

206+
/// Lowers `memref.copy` to an opaque EmitC call to `memcpy`.
///
/// The pattern takes the address of the first element of the converted source
/// and target arrays, computes the byte count from the source memref type, and
/// emits `memcpy(target, source, size)`.
///
/// The match fails, without creating any IR, when either operand is not a
/// ranked memref accepted by `isMemRefTypeLegalForEmitC`, or when an adapted
/// operand is not an `!emitc.array` value.
struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = copyOp.getLoc();

    // Use dyn_cast: an operand that is not a ranked MemRefType must reject the
    // match instead of tripping the assertion inside cast<>.
    auto srcMemrefType = dyn_cast<MemRefType>(copyOp.getSource().getType());
    if (!srcMemrefType || !isMemRefTypeLegalForEmitC(srcMemrefType))
      return rewriter.notifyMatchFailure(
          loc, "incompatible source memref type for EmitC conversion");

    auto targetMemrefType =
        dyn_cast<MemRefType>(copyOp.getTarget().getType());
    if (!targetMemrefType || !isMemRefTypeLegalForEmitC(targetMemrefType))
      return rewriter.notifyMatchFailure(
          loc, "incompatible target memref type for EmitC conversion");

    // Validate both adapted operands before creating any op, so a failed
    // match leaves the IR untouched.
    auto srcArrayValue =
        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
    if (!srcArrayValue)
      return rewriter.notifyMatchFailure(
          loc, "expected source operand to be converted to an emitc array");

    auto targetArrayValue =
        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
    if (!targetArrayValue)
      return rewriter.notifyMatchFailure(
          loc, "expected target operand to be converted to an emitc array");

    emitc::ApplyOp srcPtr =
        createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
    emitc::ApplyOp targetPtr =
        createPointerFromEmitcArray(loc, rewriter, targetArrayValue);

    // The byte count comes from the source type only.
    // NOTE(review): this assumes source and target hold the same number of
    // bytes and never overlap (memcpy requires non-overlapping buffers);
    // confirm that memref.copy guarantees both.
    Value totalSizeBytes =
        calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter);

    // Use the shared callee constant so the name stays in sync with the pass
    // that adds the matching include for it.
    emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
        rewriter, loc, TypeRange{}, memcpyFunctionName,
        ValueRange{targetPtr.getResult(), srcPtr.getResult(), totalSizeBytes});

    rewriter.replaceOp(copyOp, memCpyCall.getResults());

    return success();
  }
};
246+
162247
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
163248
using OpConversionPattern::OpConversionPattern;
164249

@@ -320,6 +405,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
320405

321406
void mlir::populateMemRefToEmitCConversionPatterns(
322407
RewritePatternSet &patterns, const TypeConverter &converter) {
323-
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
324-
ConvertLoad, ConvertStore>(converter, patterns.getContext());
408+
patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
409+
ConvertGetGlobal, ConvertLoad, ConvertStore>(
410+
converter, patterns.getContext());
325411
}

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "mlir/IR/Attributes.h"
1919
#include "mlir/Pass/Pass.h"
2020
#include "mlir/Transforms/DialectConversion.h"
21+
#include "llvm/ADT/SmallSet.h"
22+
#include "llvm/ADT/StringRef.h"
2123

2224
namespace mlir {
2325
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +29,15 @@ namespace mlir {
2729
using namespace mlir;
2830

2931
namespace {
32+
33+
/// Creates an `emitc.include` op for `headerName`, marked as a standard
/// include (printed as `<"header">`), at the builder's current insertion
/// point.
///
/// \param builder    Builder that creates the op; its insertion point decides
///                   where the include lands.
/// \param module     Module whose location is attached to the new op.
/// \param headerName Header file name, e.g. "cstring" or "string.h".
/// \returns The newly created include op.
emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
                                   StringRef headerName) {
  StringAttr includeAttr = builder.getStringAttr(headerName);
  // Use the `Op::create(builder, ...)` form, matching the other ops created by
  // this conversion.
  return emitc::IncludeOp::create(
      builder, module.getLoc(), includeAttr,
      /*is_standard_include=*/builder.getUnitAttr());
}
40+
3041
struct ConvertMemRefToEmitCPass
3142
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
3243
using Base::Base;
@@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass
5566
return signalPassFailure();
5667

5768
mlir::ModuleOp module = getOperation();
69+
llvm::SmallSet<StringRef, 4> existingHeaders;
70+
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
71+
module.walk([&](mlir::emitc::IncludeOp includeOp) {
72+
if (includeOp.getIsStandardInclude())
73+
existingHeaders.insert(includeOp.getInclude());
74+
});
75+
5876
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
59-
if (callOp.getCallee() != alignedAllocFunctionName &&
60-
callOp.getCallee() != mallocFunctionName) {
77+
StringRef expectedHeader;
78+
if (callOp.getCallee() == alignedAllocFunctionName ||
79+
callOp.getCallee() == mallocFunctionName)
80+
expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
81+
: cStandardLibraryHeader;
82+
else if (callOp.getCallee() == memcpyFunctionName)
83+
expectedHeader =
84+
options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
85+
else
6186
return mlir::WalkResult::advance();
87+
if (!existingHeaders.contains(expectedHeader)) {
88+
addStandardHeader(builder, module, expectedHeader);
89+
existingHeaders.insert(expectedHeader);
6290
}
63-
64-
for (auto &op : *module.getBody()) {
65-
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
66-
if (!includeOp) {
67-
continue;
68-
}
69-
if (includeOp.getIsStandardInclude() &&
70-
((options.lowerToCpp &&
71-
includeOp.getInclude() == cppStandardLibraryHeader) ||
72-
(!options.lowerToCpp &&
73-
includeOp.getInclude() == cStandardLibraryHeader))) {
74-
return mlir::WalkResult::interrupt();
75-
}
76-
}
77-
78-
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
79-
StringAttr includeAttr =
80-
builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
81-
: cStandardLibraryHeader);
82-
builder.create<mlir::emitc::IncludeOp>(
83-
module.getLoc(), includeAttr,
84-
/*is_standard_include=*/builder.getUnitAttr());
85-
return mlir::WalkResult::interrupt();
91+
return mlir::WalkResult::advance();
8692
});
8793
}
8894
};
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// NOTE(review): both RUN lines pass only --check-prefix=CPP or
// --check-prefix=NOCPP to FileCheck, so the directives below that use the
// plain CHECK prefix are presumably never evaluated. Confirm, and consider
// --check-prefixes=CHECK,CPP and --check-prefixes=CHECK,NOCPP; the expected
// lines would then need to be re-validated against the real output.
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
2+
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
3+
4+
// Copies the same argument into two freshly allocated buffers, so the lowered
// module is expected to contain both malloc and memcpy calls and therefore
// both the stdlib and the string header includes.
func.func @alloc_copy(%arg0: memref<999xi32>) {
5+
%alloc = memref.alloc() : memref<999xi32>
6+
memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
7+
%alloc_1 = memref.alloc() : memref<999xi32>
8+
memref.copy %arg0, %alloc_1 : memref<999xi32> to memref<999xi32>
9+
return
10+
}
11+
12+
// CHECK: module {
13+
// NOCPP: emitc.include <"stdlib.h">
14+
// NOCPP-NEXT: emitc.include <"string.h">
15+
16+
// CPP: emitc.include <"cstdlib">
17+
// CPP-NEXT: emitc.include <"cstring">
18+
19+
// CHECK-LABEL: alloc_copy
20+
// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
21+
// CHECK-NEXT: builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
22+
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
23+
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
24+
// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
25+
// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
26+
// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
27+
// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32>
28+
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
29+
// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
30+
// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
31+
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
32+
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
33+
// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t
34+
// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
35+
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
36+
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
37+
// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t
38+
// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
39+
// CHECK-NEXT: emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
40+
// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32>
41+
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
42+
// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
43+
// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
44+
// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
45+
// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
46+
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
47+
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
48+
// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t
49+
// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
50+
// CHECK-NEXT: return
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// NOTE(review): both RUN lines pass only --check-prefix=CPP or
// --check-prefix=NOCPP to FileCheck, so the directives below that use the
// plain CHECK prefix are presumably never evaluated. Confirm, and consider
// --check-prefixes=CHECK,CPP and --check-prefixes=CHECK,NOCPP; the expected
// lines would then need to be re-validated against the real output.
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
2+
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
3+
4+
// Copies one 4-D function argument into the other; with no allocation in the
// function, only the string header include is expected in the output.
func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
5+
memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
6+
return
7+
}
8+
9+
// CHECK: module {
10+
// NOCPP: emitc.include <"string.h">
11+
// CPP: emitc.include <"cstring">
12+
13+
// CHECK-LABEL: copying
14+
// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
15+
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
16+
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
17+
// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
18+
// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
19+
// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
20+
// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
21+
// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
22+
// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
23+
// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
24+
// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
25+
// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
26+
// CHECK-NEXT: return
27+
// CHECK-NEXT: }
28+
// CHECK-NEXT:}
29+

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
22

3-
func.func @memref_op(%arg0 : memref<2x4xf32>) {
4-
// expected-error@+1 {{failed to legalize operation 'memref.copy'}}
5-
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
6-
return
7-
}
8-
9-
// -----
10-
113
func.func @alloca_with_dynamic_shape() {
124
%0 = index.constant 1
135
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}

0 commit comments

Comments
 (0)