1616#include " mlir/Dialect/EmitC/IR/EmitC.h"
1717#include " mlir/Dialect/MemRef/IR/MemRef.h"
1818#include " mlir/IR/Builders.h"
19+ #include " mlir/IR/BuiltinTypes.h"
1920#include " mlir/IR/PatternMatch.h"
21+ #include " mlir/IR/TypeRange.h"
2022#include " mlir/Transforms/DialectConversion.h"
2123
2224using namespace mlir ;
@@ -77,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
7779 }
7880};
7981
82+ Type convertMemRefType (MemRefType opTy, const TypeConverter *typeConverter) {
83+ Type resultTy;
84+ if (opTy.getRank () == 0 ) {
85+ resultTy = typeConverter->convertType (mlir::getElementTypeOrSelf (opTy));
86+ } else {
87+ resultTy = typeConverter->convertType (opTy);
88+ }
89+ return resultTy;
90+ }
91+
8092struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
8193 using OpConversionPattern::OpConversionPattern;
8294
8395 LogicalResult
8496 matchAndRewrite (memref::GlobalOp op, OpAdaptor operands,
8597 ConversionPatternRewriter &rewriter) const override {
86-
98+ MemRefType opTy = op. getType ();
8799 if (!op.getType ().hasStaticShape ()) {
88100 return rewriter.notifyMatchFailure (
89101 op.getLoc (), " cannot transform global with dynamic shape" );
@@ -95,7 +107,9 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
95107 op.getLoc (), " global variable with alignment requirement is "
96108 " currently not supported" );
97109 }
98- auto resultTy = getTypeConverter ()->convertType (op.getType ());
110+
111+ Type resultTy = convertMemRefType (opTy, getTypeConverter ());
112+
99113 if (!resultTy) {
100114 return rewriter.notifyMatchFailure (op.getLoc (),
101115 " cannot convert result type" );
@@ -114,6 +128,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
114128 bool externSpecifier = !staticSpecifier;
115129
116130 Attribute initialValue = operands.getInitialValueAttr ();
131+ if (opTy.getRank () == 0 ) {
132+ auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
133+ initialValue = elementsAttr.getSplatValue <Attribute>();
134+ }
117135 if (isa_and_present<UnitAttr>(initialValue))
118136 initialValue = {};
119137
@@ -132,11 +150,23 @@ struct ConvertGetGlobal final
132150 matchAndRewrite (memref::GetGlobalOp op, OpAdaptor operands,
133151 ConversionPatternRewriter &rewriter) const override {
134152
135- auto resultTy = getTypeConverter ()->convertType (op.getType ());
153+ MemRefType opTy = op.getType ();
154+ Type resultTy = convertMemRefType (opTy, getTypeConverter ());
155+
136156 if (!resultTy) {
137157 return rewriter.notifyMatchFailure (op.getLoc (),
138158 " cannot convert result type" );
139159 }
160+
161+ if (opTy.getRank () == 0 ) {
162+ emitc::LValueType lvalueType = emitc::LValueType::get (resultTy);
163+ emitc::GetGlobalOp globalLValue = rewriter.create <emitc::GetGlobalOp>(
164+ op.getLoc (), lvalueType, operands.getNameAttr ());
165+ emitc::PointerType pointerType = emitc::PointerType::get (resultTy);
166+ rewriter.replaceOpWithNewOp <emitc::ApplyOp>(
167+ op, pointerType, rewriter.getStringAttr (" &" ), globalLValue);
168+ return success ();
169+ }
140170 rewriter.replaceOpWithNewOp <emitc::GetGlobalOp>(op, resultTy,
141171 operands.getNameAttr ());
142172 return success ();
0 commit comments