@@ -89,38 +89,32 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
8989 // Collect all field names
9090 std::vector<std::pair<std::string, std::string>> fieldNames;
9191 classOp.walk ([&](mlir::emitc::FieldOp fieldOp) {
92- if (mlir::Attribute attrsAttr =
93- fieldOp->getAttrDictionary ().get (" attrs" )) {
94- if (DictionaryAttr innerDictAttr =
95- dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
96- ArrayAttr arrayAttr = dyn_cast<mlir::ArrayAttr>(
97- innerDictAttr.getNamed (attributeName)->getValue ());
98- if (!arrayAttr.empty ()) {
99- StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0 ]);
100- std::string indexPath = stringAttr.getValue ().str ();
101- fieldNames.emplace_back (indexPath, fieldOp.getName ().str ());
102- }
103- if (arrayAttr.size () > 1 ) {
104- fieldOp.emitError () << attributeName
105- << " attribute must "
106- " contain at most one value, but found "
107- << arrayAttr.size () << " values." ;
108- return ;
109- }
110- }
92+ if (DictionaryAttr innerDictAttr = dyn_cast<mlir::DictionaryAttr>(
93+ fieldOp->getAttrDictionary ().get (" attrs" ))) {
94+ ArrayAttr arrayAttr = cast<mlir::ArrayAttr>(
95+ innerDictAttr.getNamed (attributeName)->getValue ());
96+ StringAttr stringAttr = cast<mlir::StringAttr>(arrayAttr[0 ]);
97+ fieldNames.emplace_back (stringAttr.getValue ().str (),
98+ fieldOp.getName ().str ());
99+
100+ } else {
101+ fieldOp.emitError ()
102+ << " FieldOp must have a dictionary attribute named '"
103+ << attributeName << " '"
104+ << " with an array containing a string attribute" ;
111105 }
112106 });
113107
114- std::string mapInitializer = " { " ;
108+ std::stringstream ss;
109+ ss << " { " ;
115110 for (size_t i = 0 ; i < fieldNames.size (); ++i) {
116- mapInitializer += " { \" " + fieldNames[i].first + " \" , " +
117- " reinterpret_cast<char*>(&" + fieldNames[i].second +
118- " )" ,
119- mapInitializer += " }" ;
120- if (i < fieldNames.size () - 1 )
121- mapInitializer += " , " ;
111+ ss << " { \" " << fieldNames[i].first << " \" , reinterpret_cast<char*>(&"
112+ << fieldNames[i].second << " ) }" ;
113+ if (i < fieldNames.size () - 1 ) {
114+ ss << " , " ;
115+ }
122116 }
123- mapInitializer += " }" ;
117+ ss << " }" ;
124118
125119 emitc::FuncOp executeFunc =
126120 classOp.lookupSymbol <mlir::emitc::FuncOp>(" execute" );
@@ -144,8 +138,7 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
144138
145139 // Create the constant map
146140 emitc::ConstantOp bufferMap = rewriter.create <emitc::ConstantOp>(
147- classOp.getLoc (), mapType,
148- emitc::OpaqueAttr::get (context, mapInitializer));
141+ classOp.getLoc (), mapType, emitc::OpaqueAttr::get (context, ss.str ()));
149142
150143 rewriter.create <mlir::emitc::ReturnOp>(classOp.getLoc (),
151144 bufferMap.getResult ());
0 commit comments