1616#include " mlir/Dialect/EmitC/IR/EmitC.h"
1717#include " mlir/Dialect/MemRef/IR/MemRef.h"
1818#include " mlir/IR/Builders.h"
19+ #include " mlir/IR/BuiltinTypes.h"
1920#include " mlir/IR/PatternMatch.h"
21+ #include " mlir/IR/TypeRange.h"
2022#include " mlir/Transforms/DialectConversion.h"
2123
2224using namespace mlir ;
@@ -77,23 +79,13 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
7779 }
7880};
7981
80- struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
81- using OpConversionPattern::OpConversionPattern;
82-
83- LogicalResult
84- matchAndRewrite (memref::CopyOp op, OpAdaptor operands,
85- ConversionPatternRewriter &rewriter) const override {
86- return failure ();
87- }
88- };
89-
9082struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
9183 using OpConversionPattern::OpConversionPattern;
9284
9385 LogicalResult
9486 matchAndRewrite (memref::GlobalOp op, OpAdaptor operands,
9587 ConversionPatternRewriter &rewriter) const override {
96- auto type = op.getType ();
88+ MemRefType type = op.getType ();
9789 if (!op.getType ().hasStaticShape ()) {
9890 return rewriter.notifyMatchFailure (
9991 op.getLoc (), " cannot transform global with dynamic shape" );
@@ -105,22 +97,12 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
10597 op.getLoc (), " global variable with alignment requirement is "
10698 " currently not supported" );
10799 }
108- // auto resultTy =
109- // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
100+
110101 Type resultTy;
111- Type elementType = getTypeConverter ()->convertType (type.getElementType ());
112- auto shape = type.getShape ();
113-
114- if (shape.empty ()) {
115- if (emitc::isSupportedFloatType (elementType)) {
116- resultTy = rewriter.getF32Type ();
117- }
118- if (emitc::isSupportedIntegerType (elementType)) {
119- resultTy = rewriter.getIntegerType (elementType.getIntOrFloatBitWidth ());
120- }
121- } else {
122- resultTy = emitc::ArrayType::get (shape, elementType);
123- }
102+ if (type.getRank () == 0 )
103+ resultTy = getTypeConverter ()->convertType (type.getElementType ());
104+ else
105+ resultTy = getTypeConverter ()->convertType (type);
124106
125107 if (!resultTy) {
126108 return rewriter.notifyMatchFailure (op.getLoc (),
@@ -140,7 +122,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
140122 bool externSpecifier = !staticSpecifier;
141123
142124 Attribute initialValue = operands.getInitialValueAttr ();
143- if (op. getType () .getRank () == 0 ) {
125+ if (type .getRank () == 0 ) {
144126 auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
145127 initialValue = elementsAttr.getSplatValue <Attribute>();
146128 }
@@ -162,7 +144,17 @@ struct ConvertGetGlobal final
162144 matchAndRewrite (memref::GetGlobalOp op, OpAdaptor operands,
163145 ConversionPatternRewriter &rewriter) const override {
164146
165- auto resultTy = getTypeConverter ()->convertType (op.getType ());
147+ MemRefType type = op.getType ();
148+ Type resultTy;
149+ if (type.getRank () == 0 )
150+ resultTy = emitc::LValueType::get (
151+ getTypeConverter ()->convertType (type.getElementType ()));
152+ else
153+ resultTy = getTypeConverter ()->convertType (type);
154+
155+ if (!resultTy)
156+ return rewriter.notifyMatchFailure (op.getLoc (), " cannot convert type" );
157+
166158 if (!resultTy) {
167159 return rewriter.notifyMatchFailure (op.getLoc (),
168160 " cannot convert result type" );
0 commit comments