1212#include " mlir/Dialect/EmitC/Transforms/Transforms.h"
1313#include " mlir/IR/Attributes.h"
1414#include " mlir/IR/Builders.h"
15+ #include " mlir/IR/BuiltinAttributes.h"
1516#include " mlir/IR/PatternMatch.h"
1617#include " mlir/IR/TypeRange.h"
18+ #include " mlir/IR/Value.h"
1719#include " mlir/Pass/Pass.h"
1820#include " mlir/Transforms/DialectConversion.h"
1921#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
22+ #include " llvm/ADT/StringRef.h"
2023#include " llvm/Support/GraphWriter.h"
2124#include " llvm/Support/LogicalResult.h"
25+ #include < string>
2226
2327namespace mlir {
2428namespace emitc {
@@ -67,7 +71,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
6771 if (funcOp->getParentOfType <emitc::ClassOp>()) {
6872 return failure ();
6973 }
70- auto className = " My " + funcOp.getSymNameAttr ().str () + " Class" ;
74+ auto className = funcOp.getSymNameAttr ().str () + " Class" ;
7175 mlir::emitc::ClassOp newClassOp =
7276 rewriter.create <emitc::ClassOp>(funcOp.getLoc (), className);
7377
@@ -76,25 +80,33 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
7680 rewriter.setInsertionPointToStart (&newClassOp.getBody ().front ());
7781
7882 auto argAttrs = funcOp.getArgAttrs ();
79- if (argAttrs) {
80- for (const auto &[arg, val] :
81- llvm::zip (*argAttrs, funcOp.getArguments ())) {
82- if (auto namedAttr =
83- dyn_cast<mlir::DictionaryAttr>(arg).getNamed (attributeName)) {
84- Attribute nv = namedAttr->getValue ();
85- StringAttr fieldName =
86- cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0 ]);
87- TypeAttr typeAttr = TypeAttr::get (val.getType ());
88- fields.push_back ({fieldName, typeAttr});
89-
90- rewriter.create <emitc::FieldOp>(funcOp.getLoc (), fieldName, typeAttr,
91- /* attributes*/ arg);
83+ size_t idx = 0 ;
84+
85+ for (const BlockArgument &val : funcOp.getArguments ()) {
86+ StringAttr fieldName;
87+ Attribute argAttr = nullptr ;
88+
89+ if (argAttrs && idx < argAttrs->size ()) {
90+ if (DictionaryAttr dictAttr =
91+ dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
92+ if (auto namedAttr = dictAttr.getNamed (attributeName)) {
93+ Attribute nv = namedAttr->getValue ();
94+ fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0 ]);
95+ argAttr = (*argAttrs)[idx];
96+ }
9297 }
9398 }
94- } else {
95- funcOp->emitOpError (" arguments should have attributes so we can "
96- " initialize class fields." );
97- return failure ();
99+
100+ if (!fieldName) {
101+ fieldName = rewriter.getStringAttr (" fieldName" + std::to_string (idx));
102+ }
103+
104+ TypeAttr typeAttr = TypeAttr::get (val.getType ());
105+ fields.push_back ({fieldName, typeAttr});
106+ rewriter.create <emitc::FieldOp>(funcOp.getLoc (), fieldName, typeAttr,
107+ argAttr);
108+
109+ ++idx;
98110 }
99111
100112 rewriter.setInsertionPointToEnd (&newClassOp.getBody ().front ());
@@ -112,7 +124,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
112124 rewriter.setInsertionPointToStart (&newFuncOp.getBody ().front ());
113125 std::vector<Value> newArguments;
114126 for (auto [fieldName, attr] : fields) {
115- auto arg =
127+ GetFieldOp arg =
116128 rewriter.create <emitc::GetFieldOp>(loc, attr.getValue (), fieldName);
117129 newArguments.push_back (arg);
118130 }
@@ -122,14 +134,13 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
122134 rewriter.replaceAllUsesWith (oldArg, newArg);
123135 }
124136
125- while (!newFuncOp.getArguments ().empty ()) {
126- if (failed (newFuncOp.eraseArgument (0 ))) {
127- break ;
128- }
137+ llvm::BitVector argsToErase (newFuncOp.getNumArguments (), true );
138+ if (failed (newFuncOp.eraseArguments (argsToErase))) {
139+ newFuncOp->emitOpError (" Failed to erase all arguments using BitVector." );
129140 }
130141
131142 rewriter.replaceOp (funcOp, newClassOp);
132- return funcOp-> use_empty () ? success () : failure ();
143+ return success ();
133144 }
134145};
135146
0 commit comments