@@ -79,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
7979 }
8080};
8181
82+ Type convertMemRefType (MemRefType opTy, const TypeConverter *typeConverter) {
83+ Type resultTy;
84+ if (opTy.getRank () == 0 ) {
85+ resultTy = typeConverter->convertType (mlir::getElementTypeOrSelf (opTy));
86+ } else {
87+ resultTy = typeConverter->convertType (opTy);
88+ }
89+ return resultTy;
90+ }
91+
8292struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
8393 using OpConversionPattern::OpConversionPattern;
8494
8595 LogicalResult
8696 matchAndRewrite (memref::GlobalOp op, OpAdaptor operands,
8797 ConversionPatternRewriter &rewriter) const override {
88- MemRefType type = op.getType ();
98+ MemRefType opTy = op.getType ();
8999 if (!op.getType ().hasStaticShape ()) {
90100 return rewriter.notifyMatchFailure (
91101 op.getLoc (), " cannot transform global with dynamic shape" );
@@ -98,11 +108,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
98108 " currently not supported" );
99109 }
100110
101- Type resultTy;
102- if (type.getRank () == 0 )
103- resultTy = getTypeConverter ()->convertType (type.getElementType ());
104- else
105- resultTy = getTypeConverter ()->convertType (type);
111+ Type resultTy = convertMemRefType (opTy, getTypeConverter ());
106112
107113 if (!resultTy) {
108114 return rewriter.notifyMatchFailure (op.getLoc (),
@@ -122,7 +128,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
122128 bool externSpecifier = !staticSpecifier;
123129
124130 Attribute initialValue = operands.getInitialValueAttr ();
125- if (type .getRank () == 0 ) {
131+ if (opTy .getRank () == 0 ) {
126132 auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
127133 initialValue = elementsAttr.getSplatValue <Attribute>();
128134 }
@@ -144,16 +150,10 @@ struct ConvertGetGlobal final
144150 matchAndRewrite (memref::GetGlobalOp op, OpAdaptor operands,
145151 ConversionPatternRewriter &rewriter) const override {
146152
147- MemRefType type = op.getType ();
148- Type resultTy;
149- if (type.getRank () == 0 )
150- resultTy = emitc::LValueType::get (
151- getTypeConverter ()->convertType (type.getElementType ()));
152- else
153- resultTy = getTypeConverter ()->convertType (type);
154-
155- if (!resultTy)
156- return rewriter.notifyMatchFailure (op.getLoc (), " cannot convert type" );
153+ MemRefType opTy = op.getType ();
154+ Type resultTy = convertMemRefType (opTy, getTypeConverter ());
155+ if (opTy.getRank () == 0 )
156+ resultTy = emitc::LValueType::get (resultTy);
157157
158158 if (!resultTy) {
159159 return rewriter.notifyMatchFailure (op.getLoc (),
0 commit comments