@@ -23,16 +23,48 @@ namespace emitc {
2323#include " mlir/Dialect/EmitC/Transforms/Passes.h.inc"
2424
2525namespace {
26+ constexpr const char *kMapLibraryHeader = " map" ;
27+ constexpr const char *kStringLibraryHeader = " string" ;
2628class AddReflectionMapPass
2729 : public impl::AddReflectionMapPassBase<AddReflectionMapPass> {
2830 using AddReflectionMapPassBase::AddReflectionMapPassBase;
2931 void runOnOperation () override {
30- Operation *rootOp = getOperation ();
32+ mlir::ModuleOp module = getOperation ();
3133
3234 RewritePatternSet patterns (&getContext ());
3335 populateAddReflectionMapPatterns (patterns, namedAttribute);
3436
35- walkAndApplyPatterns (rootOp, std::move (patterns));
37+ walkAndApplyPatterns (module , std::move (patterns));
38+ bool hasMap = false ;
39+ bool hasString = false ;
40+ for (auto &op : *module .getBody ()) {
41+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
42+ if (!includeOp)
43+ continue ;
44+ if (includeOp.getIsStandardInclude ()) {
45+ if (includeOp.getInclude () == kMapLibraryHeader )
46+ hasMap = true ;
47+ if (includeOp.getInclude () == kStringLibraryHeader )
48+ hasString = true ;
49+ }
50+ }
51+
52+ if (hasMap && hasString)
53+ return ;
54+
55+ mlir::OpBuilder builder (module .getBody (), module .getBody ()->begin ());
56+ if (!hasMap) {
57+ StringAttr includeAttr = builder.getStringAttr (kMapLibraryHeader );
58+ builder.create <mlir::emitc::IncludeOp>(
59+ module .getLoc (), includeAttr,
60+ /* is_standard_include=*/ builder.getUnitAttr ());
61+ }
62+ if (!hasString) {
63+ StringAttr includeAttr = builder.getStringAttr (kStringLibraryHeader );
64+ builder.create <emitc::IncludeOp>(
65+ module .getLoc (), includeAttr,
66+ /* is_standard_include=*/ builder.getUnitAttr ());
67+ }
3668 }
3769};
3870
@@ -50,16 +82,20 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
5082 mlir::MLIRContext *context = rewriter.getContext ();
5183 emitc::OpaqueType stringViewType =
5284 mlir::emitc::OpaqueType::get (rewriter.getContext (), " std::string_view" );
53- emitc::OpaqueType charPtrType =
85+ emitc::OpaqueType charType =
5486 mlir::emitc::OpaqueType::get (rewriter.getContext (), " char" );
5587 emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get (
5688 rewriter.getContext (), " const std::map<std::string, char*>" );
5789
5890 FunctionType funcType =
59- rewriter.getFunctionType ({stringViewType}, {charPtrType });
91+ rewriter.getFunctionType ({stringViewType}, {charType });
6092 emitc::FuncOp executeFunc =
6193 classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
62- rewriter.setInsertionPoint (executeFunc);
94+ if (executeFunc)
95+ rewriter.setInsertionPoint (executeFunc);
96+ else
97+ classOp.emitError () << " ClassOp must contain a function named 'execute' "
98+ " to add reflection map" ;
6399
64100 emitc::FuncOp getBufferFunc = rewriter.create <mlir::emitc::FuncOp>(
65101 classOp.getLoc (), " getBufferForName" , funcType);
@@ -74,9 +110,8 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
74110 fieldOp->getAttrDictionary ().get (" attrs" )) {
75111 if (DictionaryAttr innerDictAttr =
76112 dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
77- auto indexPathAttr = innerDictAttr.getNamed (attributeName);
78- ArrayAttr arrayAttr =
79- dyn_cast<mlir::ArrayAttr>(indexPathAttr->getValue ());
113+ ArrayAttr arrayAttr = dyn_cast<mlir::ArrayAttr>(
114+ innerDictAttr.getNamed (attributeName)->getValue ());
80115 if (!arrayAttr.empty ()) {
81116 StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0 ]);
82117 std::string indexPath = stringAttr.getValue ().str ();
@@ -122,13 +157,12 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
122157 classOp.getLoc (), rewriter.getI1Type (),
123158 " operator==" , mlir::ValueRange{it.getResult (0 ), endIt.getResult (0 )});
124159 emitc::ConstantOp nullPtr = rewriter.create <emitc::ConstantOp>(
125- classOp.getLoc (), charPtrType,
126- emitc::OpaqueAttr::get (context, " nullptr" ));
160+ classOp.getLoc (), charType, emitc::OpaqueAttr::get (context, " nullptr" ));
127161 emitc::CallOpaqueOp second = rewriter.create <emitc::CallOpaqueOp>(
128- classOp.getLoc (), charPtrType , " second" , it.getResult (0 ));
162+ classOp.getLoc (), charType , " second" , it.getResult (0 ));
129163
130164 emitc::ConditionalOp result = rewriter.create <emitc::ConditionalOp>(
131- classOp.getLoc (), charPtrType , isEnd.getResult (0 ), nullPtr.getResult (),
165+ classOp.getLoc (), charType , isEnd.getResult (0 ), nullPtr.getResult (),
132166 second.getResult (0 ));
133167
134168 rewriter.create <emitc::ReturnOp>(classOp.getLoc (), result.getResult ());
0 commit comments