@@ -80,10 +80,29 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
8080
8181 LogicalResult matchAndRewrite (mlir::emitc::ClassOp classOp,
8282 PatternRewriter &rewriter) const override {
83- MLIRContext *context = rewriter.getContext ();
84-
83+ mlir::MLIRContext *context = rewriter.getContext ();
84+ emitc::OpaqueType stringViewType =
85+ mlir::emitc::OpaqueType::get (rewriter.getContext (), " std::string_view" );
86+ emitc::OpaqueType charType =
87+ mlir::emitc::OpaqueType::get (rewriter.getContext (), " char" );
8588 emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get (
86- context, " const std::map<std::string, char*>" );
89+ rewriter.getContext (), " const std::map<std::string, char*>" );
90+
91+ FunctionType funcType =
92+ rewriter.getFunctionType ({stringViewType}, {charType});
93+ emitc::FuncOp executeFunc =
94+ classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
95+ if (executeFunc)
96+ rewriter.setInsertionPoint (executeFunc);
97+ else
98+ classOp.emitError () << " ClassOp must contain a function named 'execute' "
99+ " to add reflection map" ;
100+
101+ emitc::FuncOp getBufferFunc = rewriter.create <mlir::emitc::FuncOp>(
102+ classOp.getLoc (), " getBufferForName" , funcType);
103+
104+ Block *funcBody = getBufferFunc.addEntryBlock ();
105+ rewriter.setInsertionPointToStart (funcBody);
87106
88107 // Collect all field names
89108 std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -110,45 +129,44 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
110129 }
111130 });
112131
113- // Construct the map initializer string
114132 std::string mapInitializer = " { " ;
115133 for (size_t i = 0 ; i < fieldNames.size (); ++i) {
116134 mapInitializer += " { \" " + fieldNames[i].first + " \" , " +
117135 " reinterpret_cast<char*>(&" + fieldNames[i].second +
118- " )" ;
119- mapInitializer += " }" ;
136+ " )" ,
137+ mapInitializer += " }" ;
120138 if (i < fieldNames.size () - 1 )
121139 mapInitializer += " , " ;
122140 }
123141 mapInitializer += " }" ;
124142
125- emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get (
126- context, " const std::map<std::string, char*>" );
127-
128- emitc::FuncOp executeFunc =
129- classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
130- if (executeFunc)
131- rewriter.setInsertionPoint (executeFunc);
132- else
133- classOp.emitError () << " ClassOp must contain a function named 'execute' "
134- " to add reflection map" ;
135-
136- // Create the getFeatures function
137- emitc::FuncOp getFeaturesFunc = rewriter.create <mlir::emitc::FuncOp>(
138- classOp.getLoc (), " getFeatures" ,
139- rewriter.getFunctionType ({}, {returnType}));
140-
141- // Add the body of the getFeatures function
142- Block *funcBody = getFeaturesFunc.addEntryBlock ();
143- rewriter.setInsertionPointToStart (funcBody);
143+ emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get (
144+ context, " std::map<std::string, char*>::const_iterator" );
144145
145- // Create the constant map
146146 emitc::ConstantOp bufferMap = rewriter.create <emitc::ConstantOp>(
147147 classOp.getLoc (), mapType,
148148 emitc::OpaqueAttr::get (context, mapInitializer));
149149
150- rewriter.create <mlir::emitc::ReturnOp>(classOp.getLoc (),
151- bufferMap.getResult ());
150+ mlir::Value nameArg = getBufferFunc.getArgument (0 );
151+ emitc::CallOpaqueOp it = rewriter.create <emitc::CallOpaqueOp>(
152+ classOp.getLoc (), iteratorType, rewriter.getStringAttr (" find" ),
153+ mlir::ValueRange{bufferMap.getResult (), nameArg});
154+ emitc::CallOpaqueOp endIt = rewriter.create <emitc::CallOpaqueOp>(
155+ classOp.getLoc (), iteratorType, rewriter.getStringAttr (" end" ),
156+ bufferMap.getResult ());
157+ emitc::CallOpaqueOp isEnd = rewriter.create <emitc::CallOpaqueOp>(
158+ classOp.getLoc (), rewriter.getI1Type (),
159+ " operator==" , mlir::ValueRange{it.getResult (0 ), endIt.getResult (0 )});
160+ emitc::ConstantOp nullPtr = rewriter.create <emitc::ConstantOp>(
161+ classOp.getLoc (), charType, emitc::OpaqueAttr::get (context, " nullptr" ));
162+ emitc::CallOpaqueOp second = rewriter.create <emitc::CallOpaqueOp>(
163+ classOp.getLoc (), charType, " second" , it.getResult (0 ));
164+
165+ emitc::ConditionalOp result = rewriter.create <emitc::ConditionalOp>(
166+ classOp.getLoc (), charType, isEnd.getResult (0 ), nullPtr.getResult (),
167+ second.getResult (0 ));
168+
169+ rewriter.create <emitc::ReturnOp>(classOp.getLoc (), result.getResult ());
152170
153171 return success ();
154172 }
0 commit comments