19
19
#include " mlir/IR/BuiltinTypes.h"
20
20
#include " mlir/IR/PatternMatch.h"
21
21
#include " mlir/IR/TypeRange.h"
22
+ #include " mlir/IR/Value.h"
22
23
#include " mlir/Transforms/DialectConversion.h"
24
+ #include < cstdint>
23
25
24
26
using namespace mlir ;
25
27
28
+ static bool isMemRefTypeLegalForEmitC (MemRefType memRefType) {
29
+ return memRefType.hasStaticShape () && memRefType.getLayout ().isIdentity () &&
30
+ memRefType.getRank () != 0 &&
31
+ !llvm::is_contained (memRefType.getShape (), 0 );
32
+ }
33
+
26
34
namespace {
27
35
// / Implement the interface to convert MemRef to EmitC.
28
36
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
89
97
return resultTy;
90
98
}
91
99
100
+ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
101
+ using OpConversionPattern::OpConversionPattern;
102
+ LogicalResult
103
+ matchAndRewrite (memref::AllocOp allocOp, OpAdaptor operands,
104
+ ConversionPatternRewriter &rewriter) const override {
105
+ Location loc = allocOp.getLoc ();
106
+ MemRefType memrefType = allocOp.getType ();
107
+ if (!isMemRefTypeLegalForEmitC (memrefType)) {
108
+ return rewriter.notifyMatchFailure (
109
+ loc, " incompatible memref type for EmitC conversion" );
110
+ }
111
+
112
+ Type sizeTType = emitc::SizeTType::get (rewriter.getContext ());
113
+ Type elementType = memrefType.getElementType ();
114
+ IndexType indexType = rewriter.getIndexType ();
115
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create <emitc::CallOpaqueOp>(
116
+ loc, sizeTType, rewriter.getStringAttr (" sizeof" ), ValueRange{},
117
+ ArrayAttr::get (rewriter.getContext (), {TypeAttr::get (elementType)}));
118
+
119
+ int64_t numElements = 1 ;
120
+ for (int64_t dimSize : memrefType.getShape ()) {
121
+ numElements *= dimSize;
122
+ }
123
+ Value numElementsValue = rewriter.create <emitc::ConstantOp>(
124
+ loc, indexType, rewriter.getIndexAttr (numElements));
125
+
126
+ Value totalSizeBytes = rewriter.create <emitc::MulOp>(
127
+ loc, sizeTType, sizeofElementOp.getResult (0 ), numElementsValue);
128
+
129
+ emitc::CallOpaqueOp allocCall;
130
+ StringAttr allocFunctionName;
131
+ Value alignmentValue;
132
+ SmallVector<Value, 2 > argsVec;
133
+ if (allocOp.getAlignment ()) {
134
+ allocFunctionName = rewriter.getStringAttr (alignedAllocFunctionName);
135
+ alignmentValue = rewriter.create <emitc::ConstantOp>(
136
+ loc, sizeTType,
137
+ rewriter.getIntegerAttr (indexType,
138
+ allocOp.getAlignment ().value_or (0 )));
139
+ argsVec.push_back (alignmentValue);
140
+ } else {
141
+ allocFunctionName = rewriter.getStringAttr (mallocFunctionName);
142
+ }
143
+
144
+ argsVec.push_back (totalSizeBytes);
145
+ ValueRange args (argsVec);
146
+
147
+ allocCall = rewriter.create <emitc::CallOpaqueOp>(
148
+ loc,
149
+ emitc::PointerType::get (
150
+ emitc::OpaqueType::get (rewriter.getContext (), " void" )),
151
+ allocFunctionName, args);
152
+
153
+ emitc::PointerType targetPointerType = emitc::PointerType::get (elementType);
154
+ emitc::CastOp castOp = rewriter.create <emitc::CastOp>(
155
+ loc, targetPointerType, allocCall.getResult (0 ));
156
+
157
+ rewriter.replaceOp (allocOp, castOp);
158
+ return success ();
159
+ }
160
+ };
161
+
92
162
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
93
163
using OpConversionPattern::OpConversionPattern;
94
164
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
223
293
void mlir::populateMemRefToEmitCTypeConversion (TypeConverter &typeConverter) {
224
294
typeConverter.addConversion (
225
295
[&](MemRefType memRefType) -> std::optional<Type> {
226
- if (!memRefType.hasStaticShape () ||
227
- !memRefType.getLayout ().isIdentity () || memRefType.getRank () == 0 ||
228
- llvm::is_contained (memRefType.getShape (), 0 )) {
296
+ if (!isMemRefTypeLegalForEmitC (memRefType)) {
229
297
return {};
230
298
}
231
299
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
252
320
253
321
void mlir::populateMemRefToEmitCConversionPatterns (
254
322
RewritePatternSet &patterns, const TypeConverter &converter) {
255
- patterns.add <ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad ,
256
- ConvertStore>(converter, patterns.getContext ());
323
+ patterns.add <ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal ,
324
+ ConvertLoad, ConvertStore>(converter, patterns.getContext ());
257
325
}
0 commit comments