Skip to content

Commit 5870e6b

Browse files
committed
ensure both headers are added
1 parent 12e967e commit 5870e6b

File tree

3 files changed

+70
-65
lines changed

3 files changed

+70
-65
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,6 @@ emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
3737
/*is_standard_include=*/builder.getUnitAttr());
3838
}
3939

40-
bool isExpectedStandardInclude(ConvertMemRefToEmitCOptions options,
41-
emitc::IncludeOp includeOp) {
42-
return ((options.lowerToCpp &&
43-
(includeOp.getInclude() == cppStandardLibraryHeader ||
44-
includeOp.getInclude() == cppStringLibraryHeader)) ||
45-
(!options.lowerToCpp &&
46-
(includeOp.getInclude() == cStandardLibraryHeader ||
47-
includeOp.getInclude() == cStringLibraryHeader)));
48-
}
49-
5040
struct ConvertMemRefToEmitCPass
5141
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
5242
using Base::Base;
@@ -75,34 +65,33 @@ struct ConvertMemRefToEmitCPass
7565
return signalPassFailure();
7666

7767
mlir::ModuleOp module = getOperation();
68+
llvm::SmallVector<StringRef> requiredHeaders;
7869
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
79-
if (callOp.getCallee() != alignedAllocFunctionName &&
80-
callOp.getCallee() != mallocFunctionName &&
81-
callOp.getCallee() != memcpyFunctionName)
82-
return mlir::WalkResult::advance();
83-
84-
for (auto &op : *module.getBody()) {
85-
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
86-
if (!includeOp)
87-
continue;
88-
89-
if (includeOp.getIsStandardInclude() &&
90-
isExpectedStandardInclude(options, includeOp))
91-
return mlir::WalkResult::interrupt();
92-
}
93-
94-
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
95-
StringRef headerName;
96-
if (callOp.getCallee() == memcpyFunctionName)
97-
headerName =
70+
StringRef expectedHeader;
71+
if (callOp.getCallee() == alignedAllocFunctionName ||
72+
callOp.getCallee() == mallocFunctionName)
73+
expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
74+
: cStandardLibraryHeader;
75+
else if (callOp.getCallee() == memcpyFunctionName)
76+
expectedHeader =
9877
options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
9978
else
100-
headerName = options.lowerToCpp ? cppStandardLibraryHeader
101-
: cStandardLibraryHeader;
102-
103-
addStandardHeader(builder, module, headerName);
104-
return mlir::WalkResult::interrupt();
79+
return mlir::WalkResult::advance();
80+
requiredHeaders.push_back(expectedHeader);
81+
return mlir::WalkResult::advance();
10582
});
83+
for (StringRef expectedHeader : requiredHeaders) {
84+
bool headerFound = llvm::any_of(*module.getBody(), [&](Operation &op) {
85+
auto includeOp = dyn_cast<mlir::emitc::IncludeOp>(op);
86+
return includeOp && includeOp.getIsStandardInclude() &&
87+
(includeOp.getInclude() == expectedHeader);
88+
});
89+
90+
if (!headerFound) {
91+
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
92+
addStandardHeader(builder, module, expectedHeader);
93+
}
94+
}
10695
}
10796
};
10897
} // namespace
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
2+
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
3+
4+
5+
func.func @alloc_copy(%arg0: memref<999xi32>) {
6+
%alloc = memref.alloc() : memref<999xi32>
7+
memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
8+
return
9+
}
10+
11+
// NOCPP: module {
12+
// NOCPP-NEXT: emitc.include <"string.h">
13+
// NOCPP-NEXT: emitc.include <"stdlib.h">
14+
15+
// CPP: module {
16+
// CPP-NEXT: emitc.include <"cstring">
17+
// CHECK-NEXT: emitc.include <"cstdlib">
18+
// CHECK-LABEL: alloc_copy
19+
// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
20+
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
21+
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
22+
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
23+
// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
24+
// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
25+
// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
26+
// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32>
27+
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
28+
// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
29+
// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
30+

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,24 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
88

99
// NOCPP: module {
1010
// NOCPP-NEXT: emitc.include <"string.h">
11-
// NOCPP-LABEL: copying
12-
// NOCPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
13-
// NOCPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
14-
// NOCPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
15-
// NOCPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
16-
// NOCPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
17-
// NOCPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
18-
// NOCPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
19-
// NOCPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
20-
// NOCPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
21-
// NOCPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
22-
// NOCPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
23-
// NOCPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
24-
// NOCPP-NEXT: return
25-
// NOCPP-NEXT: }
26-
// NOCPP-NEXT:}
11+
2712

2813
// CPP: module {
2914
// CPP-NEXT: emitc.include <"cstring">
3015
// CPP-LABEL: copying
31-
// CPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
32-
// CPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
33-
// CPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
34-
// CPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
35-
// CPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
36-
// CPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
37-
// CPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
38-
// CPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
39-
// CPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
40-
// CPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
41-
// CPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
42-
// CPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
43-
// CPP-NEXT: return
44-
// CPP-NEXT: }
45-
// CPP-NEXT:}
16+
// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
17+
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
18+
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
19+
// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
20+
// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
21+
// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
22+
// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
23+
// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
24+
// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
25+
// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
26+
// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
27+
// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
28+
// CHECK-NEXT: return
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT:}
31+

0 commit comments

Comments
 (0)