@@ -80,29 +80,10 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
8080
8181 LogicalResult matchAndRewrite (mlir::emitc::ClassOp classOp,
8282 PatternRewriter &rewriter) const override {
83- mlir::MLIRContext *context = rewriter.getContext ();
84- emitc::OpaqueType stringViewType =
85- mlir::emitc::OpaqueType::get (rewriter.getContext (), " std::string_view" );
86- emitc::OpaqueType charType =
87- mlir::emitc::OpaqueType::get (rewriter.getContext (), " char" );
88- emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get (
89- rewriter.getContext (), " const std::map<std::string, char*>" );
90-
91- FunctionType funcType =
92- rewriter.getFunctionType ({stringViewType}, {charType});
93- emitc::FuncOp executeFunc =
94- classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
95- if (executeFunc)
96- rewriter.setInsertionPoint (executeFunc);
97- else
98- classOp.emitError () << " ClassOp must contain a function named 'execute' "
99- " to add reflection map" ;
83+ MLIRContext *context = rewriter.getContext ();
10084
101- emitc::FuncOp getBufferFunc = rewriter.create <mlir::emitc::FuncOp>(
102- classOp.getLoc (), " getBufferForName" , funcType);
103-
104- Block *funcBody = getBufferFunc.addEntryBlock ();
105- rewriter.setInsertionPointToStart (funcBody);
85+ emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get (
86+ context, " const std::map<std::string, char*>" );
10687
10788 // Collect all field names
10889 std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -129,44 +110,45 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
129110 }
130111 });
131112
113+ // Construct the map initializer string
132114 std::string mapInitializer = " { " ;
133115 for (size_t i = 0 ; i < fieldNames.size (); ++i) {
134116 mapInitializer += " { \" " + fieldNames[i].first + " \" , " +
135117 " reinterpret_cast<char*>(&" + fieldNames[i].second +
136- " )" ,
137- mapInitializer += " }" ;
118+ " )" ;
119+ mapInitializer += " }" ;
138120 if (i < fieldNames.size () - 1 )
139121 mapInitializer += " , " ;
140122 }
141123 mapInitializer += " }" ;
142124
143- emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get (
144- context, " std::map<std::string, char*>::const_iterator" );
125+ emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get (
126+ context, " const std::map<std::string, char*>" );
127+
128+ emitc::FuncOp executeFunc =
129+ classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
130+ if (executeFunc)
131+ rewriter.setInsertionPoint (executeFunc);
132+ else
133+ classOp.emitError () << " ClassOp must contain a function named 'execute' "
134+ " to add reflection map" ;
135+
136+ // Create the getFeatures function
137+ emitc::FuncOp getFeaturesFunc = rewriter.create <mlir::emitc::FuncOp>(
138+ classOp.getLoc (), " getFeatures" ,
139+ rewriter.getFunctionType ({}, {returnType}));
140+
141+ // Add the body of the getFeatures function
142+ Block *funcBody = getFeaturesFunc.addEntryBlock ();
143+ rewriter.setInsertionPointToStart (funcBody);
145144
145+ // Create the constant map
146146 emitc::ConstantOp bufferMap = rewriter.create <emitc::ConstantOp>(
147147 classOp.getLoc (), mapType,
148148 emitc::OpaqueAttr::get (context, mapInitializer));
149149
150- mlir::Value nameArg = getBufferFunc.getArgument (0 );
151- emitc::CallOpaqueOp it = rewriter.create <emitc::CallOpaqueOp>(
152- classOp.getLoc (), iteratorType, rewriter.getStringAttr (" find" ),
153- mlir::ValueRange{bufferMap.getResult (), nameArg});
154- emitc::CallOpaqueOp endIt = rewriter.create <emitc::CallOpaqueOp>(
155- classOp.getLoc (), iteratorType, rewriter.getStringAttr (" end" ),
156- bufferMap.getResult ());
157- emitc::CallOpaqueOp isEnd = rewriter.create <emitc::CallOpaqueOp>(
158- classOp.getLoc (), rewriter.getI1Type (),
159- " operator==" , mlir::ValueRange{it.getResult (0 ), endIt.getResult (0 )});
160- emitc::ConstantOp nullPtr = rewriter.create <emitc::ConstantOp>(
161- classOp.getLoc (), charType, emitc::OpaqueAttr::get (context, " nullptr" ));
162- emitc::CallOpaqueOp second = rewriter.create <emitc::CallOpaqueOp>(
163- classOp.getLoc (), charType, " second" , it.getResult (0 ));
164-
165- emitc::ConditionalOp result = rewriter.create <emitc::ConditionalOp>(
166- classOp.getLoc (), charType, isEnd.getResult (0 ), nullPtr.getResult (),
167- second.getResult (0 ));
168-
169- rewriter.create <emitc::ReturnOp>(classOp.getLoc (), result.getResult ());
150+ rewriter.create <mlir::emitc::ReturnOp>(classOp.getLoc (),
151+ bufferMap.getResult ());
170152
171153 return success ();
172154 }
0 commit comments