66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " mlir-c/BuiltinAttributes.h"
910#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1011#include " mlir/Dialect/EmitC/IR/EmitC.h"
1112#include " mlir/Dialect/Func/IR/FuncOps.h"
13+ #include " mlir/IR/Attributes.h"
1214#include " mlir/IR/BuiltinOps.h"
1315#include " mlir/IR/BuiltinTypes.h"
1416#include " mlir/IR/Dialect.h"
@@ -1137,8 +1139,35 @@ static LogicalResult printOperation(CppEmitter &emitter,
11371139 " with multiple blocks needs variables declared at top" );
11381140 }
11391141
1140- CppEmitter::Scope scope (emitter);
1142+ CppEmitter::Scope classScope (emitter);
11411143 raw_indented_ostream &os = emitter.ostream ();
1144+ os << " class MyClass final {\n " ;
1145+ auto argAttrs = functionOp.getArgAttrs ();
1146+
1147+ std::map<std::string, Value> fields;
1148+ if (argAttrs)
1149+ for (auto [a,v] : zip (*argAttrs, functionOp.getArguments ())) {
1150+ if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
1151+ auto name = cast<mlir::StringAttr>(
1152+ da.getNamed (" ml_program.identifier" )->getValue ())
1153+ .str ();
1154+ fields[name] = v;
1155+ os << " " ;
1156+ if (failed (emitter.emitType (functionOp.getLoc (), v.getType ())))
1157+ return failure ();
1158+ os << " " << emitter.getOrCreateName (v) << " ;\n " ;
1159+ }
1160+ }
1161+ os << " std::map<std::string, char*> _buffer_map {" ;
1162+ for (auto &[n,v]:fields)
1163+ os << " { \" " << n << " \" " << " , reinterpret_cast<char*>(" << emitter.getOrCreateName (v) << " ) }," ;
1164+ os << " };\n " ;
1165+ os << " char* getBufferForName(const std::string& name) const {\n " ;
1166+ os << " auto it = _buffer_map.find(name);\n " ;
1167+ os << " return (it == _buffer_map.end()) ? nullptr : it->second;\n " ;
1168+ os << " }\n " ;
1169+ CppEmitter::Scope scope (emitter);
1170+
11421171 if (functionOp.getSpecifiers ()) {
11431172 for (Attribute specifier : functionOp.getSpecifiersAttr ()) {
11441173 os << cast<StringAttr>(specifier).str () << " " ;
@@ -1159,13 +1188,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
11591188 os << " );" ;
11601189 return success ();
11611190 }
1162- if (failed (printFunctionArgs (emitter, operation, functionOp.getArguments ())))
1163- return failure ();
1191+ // if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1192+ // return failure();
11641193 os << " ) {\n " ;
11651194 if (failed (printFunctionBody (emitter, operation, functionOp.getBlocks ())))
11661195 return failure ();
11671196 os << " }\n " ;
1168-
1197+ os << " }; \n " ;
11691198 return success ();
11701199}
11711200
0 commit comments