@@ -67,7 +67,7 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
6767 LogicalResult
6868 matchAndRewrite (memref::AllocaOp op, OpAdaptor operands,
6969 ConversionPatternRewriter &rewriter) const override {
70-
70+ auto memRefType = op. getType ();
7171 if (!op.getType ().hasStaticShape ()) {
7272 return rewriter.notifyMatchFailure (
7373 op.getLoc (), " cannot transform alloca with dynamic shape" );
@@ -80,12 +80,48 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
8080 op.getLoc (), " cannot transform alloca with alignment requirement" );
8181 }
8282
83- auto resultTy = getTypeConverter ()->convertType (op.getType ());
84- if (!resultTy) {
85- return rewriter.notifyMatchFailure (op.getLoc (), " cannot convert type" );
83+ if (op.getType ().getRank () == 0 ||
84+ llvm::is_contained (memRefType.getShape (), 0 )) {
85+ return rewriter.notifyMatchFailure (
86+ op.getLoc (), " cannot transform alloca with rank 0 or zero-sized dim" );
8687 }
88+
89+ auto convertedTy = getTypeConverter ()->convertType (memRefType);
90+ if (!convertedTy) {
91+ return rewriter.notifyMatchFailure (op.getLoc (),
92+ " cannot convert memref type" );
93+ }
94+
95+ auto arrayTy = emitc::ArrayType::get (memRefType.getShape (),
96+ memRefType.getElementType ());
97+ auto elemTy = memRefType.getElementType ();
98+
8799 auto noInit = emitc::OpaqueAttr::get (getContext (), " " );
88- rewriter.replaceOpWithNewOp <emitc::VariableOp>(op, resultTy, noInit);
100+ auto arrayVar =
101+ rewriter.create <emitc::VariableOp>(op.getLoc (), arrayTy, noInit);
102+
103+ // Build zero indices for the base subscript.
104+ SmallVector<Value> indices;
105+ for (unsigned i = 0 ; i < memRefType.getRank (); ++i) {
106+ auto zero = rewriter.create <emitc::ConstantOp>(
107+ op.getLoc (), rewriter.getIndexType (), rewriter.getIndexAttr (0 ));
108+ indices.push_back (zero);
109+ }
110+
111+ auto current = rewriter.create <emitc::SubscriptOp>(
112+ op.getLoc (), emitc::LValueType::get (elemTy), arrayVar.getResult (),
113+ indices);
114+
115+ auto ptrElemTy = emitc::PointerType::get (elemTy);
116+ auto addrOf = rewriter.create <emitc::ApplyOp>(op.getLoc (), ptrElemTy,
117+ rewriter.getStringAttr (" &" ),
118+ current.getResult ());
119+
120+ auto ptrArrayTy = emitc::PointerType::get (arrayTy);
121+ auto casted = rewriter.create <emitc::CastOp>(op.getLoc (), ptrArrayTy,
122+ addrOf.getResult ());
123+
124+ rewriter.replaceOp (op, casted.getResult ());
89125 return success ();
90126 }
91127};
@@ -122,24 +158,6 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
122158 return totalSizeBytes.getResult ();
123159}
124160
125- static emitc::ApplyOp
126- createPointerFromEmitcArray (Location loc, OpBuilder &builder,
127- TypedValue<emitc::ArrayType> arrayValue) {
128-
129- emitc::ConstantOp zeroIndex = emitc::ConstantOp::create (
130- builder, loc, builder.getIndexType (), builder.getIndexAttr (0 ));
131-
132- emitc::ArrayType arrayType = arrayValue.getType ();
133- llvm::SmallVector<mlir::Value> indices (arrayType.getRank (), zeroIndex);
134- emitc::SubscriptOp subPtr =
135- emitc::SubscriptOp::create (builder, loc, arrayValue, ValueRange (indices));
136- emitc::ApplyOp ptr = emitc::ApplyOp::create (
137- builder, loc, emitc::PointerType::get (arrayType.getElementType ()),
138- builder.getStringAttr (" &" ), subPtr);
139-
140- return ptr;
141- }
142-
143161struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
144162 using OpConversionPattern::OpConversionPattern;
145163 LogicalResult
@@ -224,20 +242,10 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
224242 return rewriter.notifyMatchFailure (
225243 loc, " incompatible target memref type for EmitC conversion" );
226244
227- auto srcArrayValue =
228- cast<TypedValue<emitc::ArrayType>>(operands.getSource ());
229- emitc::ApplyOp srcPtr =
230- createPointerFromEmitcArray (loc, rewriter, srcArrayValue);
231-
232- auto targetArrayValue =
233- cast<TypedValue<emitc::ArrayType>>(operands.getTarget ());
234- emitc::ApplyOp targetPtr =
235- createPointerFromEmitcArray (loc, rewriter, targetArrayValue);
236-
237245 emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create (
238246 rewriter, loc, TypeRange{}, " memcpy" ,
239247 ValueRange{
240- targetPtr. getResult (), srcPtr. getResult (),
248+ operands. getTarget (), operands. getSource (),
241249 calculateMemrefTotalSizeBytes (loc, srcMemrefType, rewriter)});
242250
243251 rewriter.replaceOp (copyOp, memCpyCall.getResults ());
@@ -265,11 +273,14 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
265273 " currently not supported" );
266274 }
267275
268- Type resultTy = convertMemRefType (opTy, getTypeConverter ());
269-
270- if (!resultTy) {
271- return rewriter.notifyMatchFailure (op.getLoc (),
272- " cannot convert result type" );
276+ Type elemTy = getTypeConverter ()->convertType (opTy.getElementType ());
277+ Type globalType;
278+ if (opTy.getRank () == 0 ) {
279+ globalType = elemTy;
280+ } else {
281+ SmallVector<int64_t > shape (opTy.getShape ().begin (),
282+ opTy.getShape ().end ());
283+ globalType = emitc::ArrayType::get (shape, elemTy);
273284 }
274285
275286 SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility (op);
@@ -293,7 +304,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
293304 initialValue = {};
294305
295306 rewriter.replaceOpWithNewOp <emitc::GlobalOp>(
296- op, operands.getSymName (), resultTy , initialValue, externSpecifier,
307+ op, operands.getSymName (), globalType , initialValue, externSpecifier,
297308 staticSpecifier, operands.getConstant ());
298309 return success ();
299310 }
@@ -308,24 +319,64 @@ struct ConvertGetGlobal final
308319 ConversionPatternRewriter &rewriter) const override {
309320
310321 MemRefType opTy = op.getType ();
322+ Location loc = op.getLoc ();
323+
324+ Type elemTy = getTypeConverter ()->convertType (opTy.getElementType ());
325+ if (!elemTy)
326+ return rewriter.notifyMatchFailure (loc, " cannot convert element type" );
327+
311328 Type resultTy = convertMemRefType (opTy, getTypeConverter ());
312329
313330 if (!resultTy) {
314331 return rewriter.notifyMatchFailure (op.getLoc (),
315332 " cannot convert result type" );
316- }
333+
334+ Type globalType;
335+ if (opTy.getRank () == 0 ) {
336+ globalType = elemTy;
337+ } else {
338+ SmallVector<int64_t > shape (opTy.getShape ().begin (),
339+ opTy.getShape ().end ());
340+ globalType = emitc::ArrayType::get (shape, elemTy);
341+ }
317342
318343 if (opTy.getRank () == 0 ) {
319- emitc::LValueType lvalueType = emitc::LValueType::get (resultTy );
344+ emitc::LValueType lvalueType = emitc::LValueType::get (globalType );
320345 emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create (
321346 rewriter, op.getLoc (), lvalueType, operands.getNameAttr ());
322- emitc::PointerType pointerType = emitc::PointerType::get (resultTy);
323- rewriter.replaceOpWithNewOp <emitc::ApplyOp>(
324- op, pointerType, rewriter.getStringAttr (" &" ), globalLValue);
347+ emitc::PointerType pointerType = emitc::PointerType::get (globalType);
348+ auto addrOf = rewriter.create <emitc::ApplyOp>(
349+ loc, ptrElemTy, rewriter.getStringAttr (" &" ), globalLVal.getResult ());
350+
351+ auto arrayTy = emitc::ArrayType::get ({1 }, globalType);
352+ auto ptrArrayTy = emitc::PointerType::get (arrayTy);
353+ auto casted =
354+ rewriter.create <emitc::CastOp>(loc, ptrArrayTy, addrOf.getResult ());
355+ rewriter.replaceOp (op, casted.getResult ());
325356 return success ();
326357 }
327- rewriter.replaceOpWithNewOp <emitc::GetGlobalOp>(op, resultTy,
328- operands.getNameAttr ());
358+
359+ auto getGlobal = rewriter.create <emitc::GetGlobalOp>(
360+ loc, globalType, operands.getNameAttr ());
361+
362+ SmallVector<Value> indices;
363+ for (unsigned i = 0 ; i < opTy.getRank (); ++i) {
364+ auto zero = rewriter.create <emitc::ConstantOp>(
365+ loc, rewriter.getIndexType (), rewriter.getIndexAttr (0 ));
366+ indices.push_back (zero);
367+ }
368+
369+ auto current = rewriter.create <emitc::SubscriptOp>(
370+ loc, emitc::LValueType::get (elemTy), getGlobal.getResult (), indices);
371+
372+ auto ptrElemTy = emitc::PointerType::get (opTy.getElementType ());
373+ auto addrOf = rewriter.create <emitc::ApplyOp>(
374+ loc, ptrElemTy, rewriter.getStringAttr (" &" ), current.getResult ());
375+
376+ auto casted =
377+ rewriter.create <emitc::CastOp>(loc, resultTy, addrOf.getResult ());
378+
379+ rewriter.replaceOp (op, casted.getResult ());
329380 return success ();
330381 }
331382};
0 commit comments