@@ -87,28 +87,13 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
8787 }
8888};
8989
90- static Type convertGlobalMemrefTypeToEmitc (MemRefType type,
91- const TypeConverter &typeConverter) {
92- Type elementType = typeConverter.convertType (type.getElementType ());
93- Type arrayTy = elementType;
94- // Shape has the outermost dim at index 0, so need to walk it backwards
95- auto shape = type.getShape ();
96- if (shape.empty ()) {
97- arrayTy = emitc::ArrayType::get ({1 }, arrayTy);
98- } else {
99- // For non-zero dimensions, use the original shape
100- arrayTy = emitc::ArrayType::get (shape, arrayTy);
101- }
102- return arrayTy;
103- }
104-
10590struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
10691 using OpConversionPattern::OpConversionPattern;
10792
10893 LogicalResult
10994 matchAndRewrite (memref::GlobalOp op, OpAdaptor operands,
11095 ConversionPatternRewriter &rewriter) const override {
111-
96+ auto type = op. getType ();
11297 if (!op.getType ().hasStaticShape ()) {
11398 return rewriter.notifyMatchFailure (
11499 op.getLoc (), " cannot transform global with dynamic shape" );
@@ -120,8 +105,23 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
120105 op.getLoc (), " global variable with alignment requirement is "
121106 " currently not supported" );
122107 }
123- auto resultTy =
124- convertGlobalMemrefTypeToEmitc (op.getType (), *getTypeConverter ());
108+ // auto resultTy =
109+ // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
110+ Type resultTy;
111+ Type elementType = getTypeConverter ()->convertType (type.getElementType ());
112+ auto shape = type.getShape ();
113+
114+ if (shape.empty ()) {
115+ if (emitc::isSupportedFloatType (elementType)) {
116+ resultTy = rewriter.getF32Type ();
117+ }
118+ if (emitc::isSupportedIntegerType (elementType)) {
119+ resultTy = rewriter.getIntegerType (elementType.getIntOrFloatBitWidth ());
120+ }
121+ } else {
122+ resultTy = emitc::ArrayType::get (shape, elementType);
123+ }
124+
125125 if (!resultTy) {
126126 return rewriter.notifyMatchFailure (op.getLoc (),
127127 " cannot convert result type" );
@@ -142,12 +142,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
142142 Attribute initialValue = operands.getInitialValueAttr ();
143143 if (op.getType ().getRank () == 0 ) {
144144 auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
145- auto scalarValue = elementsAttr.getSplatValue <Attribute>();
146-
147- // Convert scalar value to single-element array
148- initialValue = DenseElementsAttr::get (
149- RankedTensorType::get ({1 }, elementsAttr.getElementType ()),
150- {scalarValue});
145+ initialValue = elementsAttr.getSplatValue <Attribute>();
151146 }
152147 if (isa_and_present<UnitAttr>(initialValue))
153148 initialValue = {};
0 commit comments