@@ -80,29 +80,11 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
8080
8181 LogicalResult matchAndRewrite (mlir::emitc::ClassOp classOp,
8282 PatternRewriter &rewriter) const override {
83- mlir::MLIRContext *context = rewriter.getContext ();
84- emitc::OpaqueType stringViewType =
85- mlir::emitc::OpaqueType::get (rewriter.getContext (), " std::string_view" );
86- emitc::OpaqueType charType =
87- mlir::emitc::OpaqueType::get (rewriter.getContext (), " char" );
88- emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get (
89- rewriter.getContext (), " const std::map<std::string, char*>" );
90-
91- FunctionType funcType =
92- rewriter.getFunctionType ({stringViewType}, {charType});
93- emitc::FuncOp executeFunc =
94- classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
95- if (executeFunc)
96- rewriter.setInsertionPoint (executeFunc);
97- else
98- classOp.emitError () << " ClassOp must contain a function named 'execute' "
99- " to add reflection map" ;
83+ MLIRContext *context = rewriter.getContext ();
10084
101- emitc::FuncOp getBufferFunc = rewriter.create <mlir::emitc::FuncOp>(
102- classOp.getLoc (), " getBufferForName" , funcType);
103-
104- Block *funcBody = getBufferFunc.addEntryBlock ();
105- rewriter.setInsertionPointToStart (funcBody);
85+ // Define the opaque types
86+ emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get (
87+ context, " const std::map<std::string, char*>" );
10688
10789 // Collect all field names
10890 std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -138,35 +120,35 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
138120 if (i < fieldNames.size () - 1 )
139121 mapInitializer += " , " ;
140122 }
141- mapInitializer += " }} " ;
123+ mapInitializer += " }" ;
142124
143- emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get (
144- context, " std::map<std::string, char*>::const_iterator" );
125+ emitc::FuncOp executeFunc =
126+ classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
127+ if (executeFunc)
128+ rewriter.setInsertionPoint (executeFunc);
129+ else
130+ classOp.emitError () << " ClassOp must contain a function named 'execute' "
131+ " to add reflection map" ;
132+
133+ emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get (
134+ context, " const std::map<std::string, char*>" );
135+
136+ // Create the getFeatures function
137+ emitc::FuncOp getFeaturesFunc = rewriter.create <mlir::emitc::FuncOp>(
138+ classOp.getLoc (), " getFeatures" ,
139+ rewriter.getFunctionType ({}, {returnType}));
140+
141+ // Add the body of the getFeatures function
142+ Block *funcBody = getFeaturesFunc.addEntryBlock ();
143+ rewriter.setInsertionPointToStart (funcBody);
145144
145+ // Create the constant map
146146 emitc::ConstantOp bufferMap = rewriter.create <emitc::ConstantOp>(
147147 classOp.getLoc (), mapType,
148148 emitc::OpaqueAttr::get (context, mapInitializer));
149149
150- mlir::Value nameArg = getBufferFunc.getArgument (0 );
151- emitc::CallOpaqueOp it = rewriter.create <emitc::CallOpaqueOp>(
152- classOp.getLoc (), iteratorType, rewriter.getStringAttr (" find" ),
153- mlir::ValueRange{bufferMap.getResult (), nameArg});
154- emitc::CallOpaqueOp endIt = rewriter.create <emitc::CallOpaqueOp>(
155- classOp.getLoc (), iteratorType, rewriter.getStringAttr (" end" ),
156- bufferMap.getResult ());
157- emitc::CallOpaqueOp isEnd = rewriter.create <emitc::CallOpaqueOp>(
158- classOp.getLoc (), rewriter.getI1Type (),
159- " operator==" , mlir::ValueRange{it.getResult (0 ), endIt.getResult (0 )});
160- emitc::ConstantOp nullPtr = rewriter.create <emitc::ConstantOp>(
161- classOp.getLoc (), charType, emitc::OpaqueAttr::get (context, " nullptr" ));
162- emitc::CallOpaqueOp second = rewriter.create <emitc::CallOpaqueOp>(
163- classOp.getLoc (), charType, " second" , it.getResult (0 ));
164-
165- emitc::ConditionalOp result = rewriter.create <emitc::ConditionalOp>(
166- classOp.getLoc (), charType, isEnd.getResult (0 ), nullPtr.getResult (),
167- second.getResult (0 ));
168-
169- rewriter.create <emitc::ReturnOp>(classOp.getLoc (), result.getResult ());
150+ rewriter.create <mlir::emitc::ReturnOp>(classOp.getLoc (),
151+ bufferMap.getResult ());
170152
171153 return success ();
172154 }
0 commit comments