1919#include " mlir/IR/BuiltinTypes.h"
2020#include " mlir/IR/PatternMatch.h"
2121#include " mlir/IR/TypeRange.h"
22+ #include " mlir/IR/Value.h"
2223#include " mlir/Transforms/DialectConversion.h"
24+ #include < cstdint>
2325
2426using namespace mlir ;
2527
28+ static bool isMemRefTypeLegalForEmitC (MemRefType memRefType) {
29+ return memRefType.hasStaticShape () && memRefType.getLayout ().isIdentity () &&
30+ memRefType.getRank () != 0 &&
31+ !llvm::is_contained (memRefType.getShape (), 0 );
32+ }
33+
2634namespace {
2735// / Implement the interface to convert MemRef to EmitC.
2836struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
8997 return resultTy;
9098}
9199
100+ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
101+ using OpConversionPattern::OpConversionPattern;
102+ LogicalResult
103+ matchAndRewrite (memref::AllocOp allocOp, OpAdaptor operands,
104+ ConversionPatternRewriter &rewriter) const override {
105+ Location loc = allocOp.getLoc ();
106+ MemRefType memrefType = allocOp.getType ();
107+ if (!isMemRefTypeLegalForEmitC (memrefType)) {
108+ return rewriter.notifyMatchFailure (
109+ loc, " incompatible memref type for EmitC conversion" );
110+ }
111+
112+ Type sizeTType = emitc::SizeTType::get (rewriter.getContext ());
113+ Type elementType = memrefType.getElementType ();
114+ IndexType indexType = rewriter.getIndexType ();
115+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create <emitc::CallOpaqueOp>(
116+ loc, sizeTType, rewriter.getStringAttr (" sizeof" ), ValueRange{},
117+ ArrayAttr::get (rewriter.getContext (), {TypeAttr::get (elementType)}));
118+
119+ int64_t numElements = 1 ;
120+ for (int64_t dimSize : memrefType.getShape ()) {
121+ numElements *= dimSize;
122+ }
123+ Value numElementsValue = rewriter.create <emitc::ConstantOp>(
124+ loc, indexType, rewriter.getIndexAttr (numElements));
125+
126+ Value totalSizeBytes = rewriter.create <emitc::MulOp>(
127+ loc, sizeTType, sizeofElementOp.getResult (0 ), numElementsValue);
128+
129+ emitc::CallOpaqueOp allocCall;
130+ StringAttr allocFunctionName;
131+ Value alignmentValue;
132+ SmallVector<Value, 2 > argsVec;
133+ if (allocOp.getAlignment ()) {
134+ allocFunctionName = rewriter.getStringAttr (alignedAllocFunctionName);
135+ alignmentValue = rewriter.create <emitc::ConstantOp>(
136+ loc, sizeTType,
137+ rewriter.getIntegerAttr (indexType,
138+ allocOp.getAlignment ().value_or (0 )));
139+ argsVec.push_back (alignmentValue);
140+ } else {
141+ allocFunctionName = rewriter.getStringAttr (mallocFunctionName);
142+ }
143+
144+ argsVec.push_back (totalSizeBytes);
145+ ValueRange args (argsVec);
146+
147+ allocCall = rewriter.create <emitc::CallOpaqueOp>(
148+ loc,
149+ emitc::PointerType::get (
150+ emitc::OpaqueType::get (rewriter.getContext (), " void" )),
151+ allocFunctionName, args);
152+
153+ emitc::PointerType targetPointerType = emitc::PointerType::get (elementType);
154+ emitc::CastOp castOp = rewriter.create <emitc::CastOp>(
155+ loc, targetPointerType, allocCall.getResult (0 ));
156+
157+ rewriter.replaceOp (allocOp, castOp);
158+ return success ();
159+ }
160+ };
161+
92162struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
93163 using OpConversionPattern::OpConversionPattern;
94164
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
223293void mlir::populateMemRefToEmitCTypeConversion (TypeConverter &typeConverter) {
224294 typeConverter.addConversion (
225295 [&](MemRefType memRefType) -> std::optional<Type> {
226- if (!memRefType.hasStaticShape () ||
227- !memRefType.getLayout ().isIdentity () || memRefType.getRank () == 0 ||
228- llvm::is_contained (memRefType.getShape (), 0 )) {
296+ if (!isMemRefTypeLegalForEmitC (memRefType)) {
229297 return {};
230298 }
231299 Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
252320
253321void mlir::populateMemRefToEmitCConversionPatterns (
254322 RewritePatternSet &patterns, const TypeConverter &converter) {
255- patterns.add <ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad ,
256- ConvertStore>(converter, patterns.getContext ());
323+ patterns.add <ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal ,
324+ ConvertLoad, ConvertStore>(converter, patterns.getContext ());
257325}
0 commit comments